[백준] 트리의 지름(1967) - Python

문제

Access

문제 접근

트리를 양쪽으로 당긴 다음 지름(가장 긴 경로)를 구하는 문제입니다. 이는 곧 리프노드 간의 거리 중 가장 긴거리를 구하는 것으로 해석할 수 있습니다.

Algorithm

개요

하지만 리프노드들을 구한 다음 브루트 포스 형식으로 일일히 거리를 구하게 된다면 분명 시간 제한에 걸릴지 모릅니다. 따라서 트리를 순회하면서 간선 값을 누적해서 저장하는 방식, 즉 백트래킹을 이용해서 진행합니다.
각 정점마다 리프노드 의 지름 값을 저장하고 각 간선에는 리프노드 로부터 이어지는 누적 간선 값을 저장합니다.

순서


아래와 같이 트리가 있다고 가정합니다.

각 정점의 값을 -1으로 설정합니다. 아직 리프노드 사이의 비용을 구하지 않았기 때문입니다.

맨 밑 리프노드 까지 이동합니다.

리프노드는 자식노드가 존재하지 않기 때문에 지름이 0이 됩니다. 따라서 각 리프노드의 정점 값은 0이 됩니다.

리프노드를 정리했으니 다시 위로 올라갑니다. 해당 노드에는 두개의 리프노드로 지름을 만들 수 있습니다. 따라서 해당 노드의 지름 값은 7 + 1 = 8이 됩니다.

아까 두개의 리프노트중 좌측의 간선 값(7)이 더 컸습니다. 그렇기 때문에 상위 노드로 가는 간선 값에 이전의 간선 값을 누적합니다.
다른 하위 노드도 똑같이 진행합니다.


이렇게 해서 루트 노드 이전의 노드들의 계산이 끝났습니다. 마지막으로 루트 노드 시점에서의 지름을 구하면 되는데 이도 역시 기존 노드에 했던 방식처럼 똑같이 진행합니다. 여기서는 간선 세개 (11, 6, 5) 중 11과 6이 차례대로 크기 때문에 지름의 값은 17이 됩니다.

그래서 사용할 알고리즘은?

트리를 맨 밑바닥 까지 이동한 다음 위로 올라가면서 계산을 하므로 DFS백트래킹을 사용합니다. 그리고 누적 간선값들을 더하거나 저장해야 하므로 다이나믹 프로그래밍도 사용합니다.

Code

초기화

tree = collections.defaultdict(list)
vals = collections.defaultdict(tuple)
  • tree: 트리를 나타냅니다. key는 부모 노드가 되고, value는 자식 노드 리스트 입니다.
  • vals: 누적으로 저장할 정점과 간선의 값입니다. key는 정점 이름, value는 (간선 값, 해당 위치로부터 가장 긴 리프노드 사이의 거리 입니다.)

입력

for _ in range(N-1):
    start, end, value = list(map(int, input()[:-1].split()))
    tree[start].append(end)
    vals[end] = (value, -1)
vals[1] = (0, -1)

vals에서는 부모 노드가 아닌 자식 노드를 key로 잡고 value를 저장합니다. 그 이유는 자식노드로 부터의 간선 값을 저장함으로서 부모 노드위치에서의 간선 값들을 비교할 수 있기 때문입니다.
하지만 자식노드를 key로 잡았기 때문에 루트노드는 값이 초기화 되지 않았습니다. 따라서 따로 (0, -1)을 저장합니다. 루트에서 뻗어나가는 간선이 없기 때문에 0으로 설정합니다.

루틴

값들을 초기했다면 노드번호 1부터 시작해 재귀형태로 진행합니다.
파라미터로 정점 k를 받고 리턴값은 어차피 vals 배열에 누적 간선 값과 지름을 저장할거라 없습니다.

def __search(k: int):
    # k: key

    # if leafnode
    if k not in tree:
        # 지름 값 없음
        vals[k] = (vals[k][0], 0)
    else:
        # 아닌 경우
        candidate_edges = []

        for child_k in tree[k]:
            # child edge 순회
            __search(child_k)
            # 후보 Edge값 push
            heapq.heappush(candidate_edges, -vals[child_k][0])

        # child edge가 하나인 경우
        # 해당 key에서 지름 못구함 따라서 지름은 추가 하지 않고 누적 edge 값만
        if len(candidate_edges) == 1:
            vals[k] = (vals[k][0] + -(heapq.heappop(candidate_edges)), 0)
        else:
            # 두개 이상이면 가장 값이 큰 두개 사용
            fst_e = -heapq.heappop(candidate_edges)
            snd_e = -heapq.heappop(candidate_edges)
            vals[k] = (vals[k][0] + fst_e, fst_e + snd_e)
        # 간선 값 리턴
        return vals[k][0]

리프노드일 경우

밑에 비교할 만한 간선값이 없기 때문에 vals에 간선값과 0을 저장합니다.

vals[k] = (vals[k][0], 0)

브랜치 노드일 경우

자식노드를 순회하여 누적 간선 값들을 계산해 나갑니다. 재귀가 끝날 시점이면 리프노드로부터 해당 정점 까지 향하는 가장 큰 간선 값의 계산이 다 끝나 있습니다. 따라서 자식 노드의 누적 간선 값들을 불러와 Heap에 저장합니다. Heap에 저장하는 이유는 나중에 가장 큰 두 개의 간선 값들을 지정해야 하는데 이에 대한 정렬 알고리즘의 비용을 줄이기 위해서 입니다.
자식 노드를 순회한 후 프로세스는 자식 노드가 한개일 경우와 자식 노드가 두개인 경우로 나뉘어서 진행합니다.

candidate_edges = []

for child_k in tree[k]:
    # child edge 순회
    __search(child_k)
    # 후보 Edge값 push
    heapq.heappush(candidate_edges, -vals[child_k][0])

자식 노드가 한개일 경우

비교할 게 없으므로 다른 프로세스 없이 해당 노드의 기존 간선 값에 자식노드의 간선 값을 누적으로 더해서 저장합니다. 그리고 자식 노드가 하나 뿐인 노드는 지름을 이룰 수 없으므로 0으로 저장합니다.

vals[k] = (vals[k][0] + -(heapq.heappop(candidate_edges)), 0)

두개 이상인 경우

두개 이상이라면 값이 가장 큰 두 개의 간선 값을 가져옵니다. 그런 다음 해당 노드의 값에 가장 큰 간선 값 하나를 누적해서 저장하고, 자식 노드가 두개 이상인 노드는 지름을 이룰 수 있으므로 두 개의 간선 값의 합을 지름으로 저장합니다.

fst_e = -heapq.heappop(candidate_edges)
snd_e = -heapq.heappop(candidate_edges)
vals[k] = (vals[k][0] + fst_e, fst_e + snd_e)

예외 처리

하지만 이런 형태의 트리도 존재합니다.

기존 형태의 트리와 달리 자식노드가 하나뿐인 루트노드는 자기 자신이 리프노드가 될 수 있습니다. 따라서 vals의 루트 노드(1번)부분에 다음과 같이 수정합니다. 루트노드의 자식노드가 하나밖에 없다면 간선 값이 곧 지름이 됩니다.

if len(tree[1]) == 1:
    vals[1] = (vals[1][0], vals[1][0])

가장 긴 지름 구하기

vals의 value부분을 list로 따로 모은 다음, 크기가 가장 큰 지름값 순서대로 정렬합니다.

r = list(vals.values())
r.sort(key=lambda x: x[1], reverse=True)
answer = r[0][1]
print(answer)

좋은 웹페이지 즐겨찾기