lgorithm | array - find pivot index

> 문제

배열의 element들 중 특정 값의 왼쪽에 있는 숫자들의 합과 오른쪽에 있는 합이 같을 때 그 숫자를 pivot이라 하고 pivot의 인덱스를 리턴하는 문제입니다.

위와 같은 경우 9의 왼쪽의 element들의 합은 11이고 오른쪽 element들의 합은 11이기 때문에 9는 pivot이 되고 index 3이 리턴됩니다.

위 배열의 경우 pivot이 존재할 수 없고 -1을 리턴하면 정답이 됩니다.




> solution1_brute force

완전탐색 알고리즘을 생각해볼 수 있습니다.

예시와 같이 처음에 8을 기준으로 오른쪽의 숫자들을 더해줍니다.(O(n))


이후 기준을 옮기고 좌 우의 숫자들을 각각 더합니다(O(n) + O(n))

이렇게 계속 진행하면 O(n)*n이 되기 때문에 O(n²)의 time complexity가 됩니다.

더 좋은 방법을 생각해야 합니다.




> solution2_sliding window

전체 합계를 구합니다.(O(n))

전체합계 = 31


8을 pivot으로 잡았다면 전체합계에서 8을 빼줍니다.

전체(오른쪽)합계 = 31 - 8

아직 pivot 왼쪽의 합계는 0입니다.


pivot이 오른쪽으로 이동하고 pivot은 2가 되었습니다.

전체(오른쪽)합계 = 31 - 8 - 2

왼쪽 합계에 이전 pivot이었던 8을 더해줍니다.

왼쪽합계 = 0 + 8


이와 같은 방식으로 pivot이 이동하며 좌•우의 값들을 더하고 빼며 pivot index를 찾아나가게 됩니다.

pivot이 한 번 이동할 때 필요한 계산량은 O(1)이고 총 n개의 pivot이 있으므로 o(n)의 time complexity가 필요하고 맨 처음 전체합계를 한 O(n)까지 총 O(n)으로 array 갯수에 linear한 time complexity를 갖습니다.


▶︎ code

array = [8, 2, 1, 9, 3, 6, 2]

def find_pivot_index(values):
    total_sum = sum(values)
    left_sum = 0
    right_sum = total_sum

    past_pivot = 0
    for i in range(len(array)):
        pivot = values[i]
        right_sum = right_sum - pivot
        left_sum = left_sum + past_pivot

        if left_sum == right_sum:
            return i
        
        past_pivot = pivot
    
    return -1

print(f'The index of pivot is {find_pivot_index(array)}')




출처: 코드없는 프로그래밍

좋은 웹페이지 즐겨찾기