BOJ 11438 LCA2

https://www.acmicpc.net/problem/11438
시간 1.5초, 메모리 256MB

input :

  • N(2 ≤ N ≤ 100,000)
  • N-1개 줄 : 트리 상에서 연결된 두 정점
  • M(1 ≤ M ≤ 100,000)
  • M개 줄에는 정점 쌍

output :

  • 첫 줄에 동호가 받을 수 있는 최대 컵라면 수를 출력

조건 :

  • 트리의 각 정점은 1번부터 N번까지 번호가 매겨져 있으며, 루트는 1번이다.

  • 두 노드의 가장 가까운 공통 조상이 몇 번인지 출력


앞의 LCA를 개선해야만 해결이 가능한 문제이다.

개선

우선 개선된 방향을 본다면 lca함수 내에서 이동을 log의 방식으로 움직인다.
그렇기에 시간 복잡도는 매우 강력해지지만 2^0 ~ 2^x까지의 노드를 기록하고 있어야 한다.
고로 메모리와 시간 복잡도를 교환했다고 볼 수 있지만 일단 시간은 빨라진다....

DP

이 LCA에서는 모든 노드를 기록하고 있어야 한다고 했다.
우리가 LCA를 구현할 때 parent에 정의한 것은 2^0 위 칸에 존재하는 노드의 정보이다.
고로 이를 2차원 배열을 통해 모두 저장하고 있게 할 수 있다.

dfs를 수행한 다음에 2^1 ~ 2^x까지 돌면서 점진적으로 저장핟도록 해야 한다.
왜냐? 맨 처음에 가지고 있는 정보는 2^0에 위치한 놈 밖에 없기 때문에 그러하다.

# 2^i 번째의 부모 노드의 값을 저장.
    for i in range(1, 21):
        # j번째 노드 기준으로.
        for j in range(1, n + 1):
            # 2번쨰 칸 위의 값을 저장하려 할 때
            # 현재 노드의 부모 노드로 이동 부모 노드의 "부모"노드의 값을 가져옴.
            # 이러한 단계로 모든 DP를 초기화 함.
            parent[j][i] = parent[parent[j][i - 1]][i - 1]

parent[j][i - 1] 얘를 통해서 2^(i - 1) 위 칸에 위치하는 놈의 2^(i - 1) 위에 존재하는 노드를 가져올 수 있다. 암튼 타고 타고 올라가서 가져오는 것이다.

lca

lca 파트에서는 이동 가능한 모든 위치를 확인 하는 것이다.
현재 존재하는 깊이의 차이가 9인 경우에는
8칸 위, 1칸 위 이렇게 움직이는 것이다.

저렇게 이동할 수 있는 위치일 때 그 위치에 존재하는 부모를 찾는 것이 이 문제를 관통하는 원칙이다.

import sys
sys.setrecursionlimit(100000)

def dfs(node, depth):
    visit[node] = True
    d[node] = depth

    for next_node in graph[node]:
        if visit[next_node] == 1:
            continue
        parent[next_node][0] = node
        dfs(next_node, depth + 1)

def set_parent():
    """
        dp로 모든 부모 관계를 저장한다.
    """
    dfs(1, 0)
    # 2^i 번째의 부모 노드의 값을 저장.
    for i in range(1, 21):
        # j번째 노드 기준으로.
        for j in range(1, n + 1):
            # 2번쨰 칸 위의 값을 저장하려 할 때
            # 현재 노드의 부모 노드로 이동 부모 노드의 "부모"노드의 값을 가져옴.
            # 이러한 단계로 모든 DP를 초기화 함.
            parent[j][i] = parent[parent[j][i - 1]][i - 1]

def lca(high, low):
    # 가지고 있는 값이 클수록 더 깊이 있는 거임.
    if d[high] > d[low]:
        high, low = low, high

    # 2^20 ~ 2^0 까지 반복을 수행
    for log in range(20, -1, -1):
        if d[low] - d[high] >= (1 << log):
            low = parent[low][log]

    # 서로 동일한 노드 인 경우.
    if high == low:
        return high

    for log in range(20, -1, -1):
        if parent[high][log] != parent[low][log]:
            low = parent[low][log]
            high = parent[high][log]

    return parent[high][0]

n = int(sys.stdin.readline())
graph = [[] for _ in range(n + 1)]
parent = [[0] * 21 for _ in range(n + 1)]
d = [0] * (n + 1)
visit = [0] * (n + 1)

for i in range(n - 1):
    a, b = map(int, sys.stdin.readline().split())
    graph[a].append(b)
    graph[b].append(a)

set_parent()

m = int(sys.stdin.readline())
for i in range(m):
    a, b = map(int, sys.stdin.readline().split())
    print(lca(a, b))

좋은 웹페이지 즐겨찾기