[알고리즘] 최소 신장 트리

최소 신장 트리 알고리즘

  • 신장 트리 중에서 최소 비용으로 만들 수 있는 신장 트리를 찾는 알고리즘

신장 트리 (Spanning Tree)

  • 하나의 그래프가 있을 때 모든 노드를 포함하면서 사이클이 존재하지 않는 부분 그래프
  • 간선의 개수 = 노드의 개수 - 1

크루스칼

  • 대표적인 최소 신장 트리 알고리즘
  • 가장 적은 비용으로 모든 노드를 연결할 수 있음
  • 그리디 알고리즘으로 분류됨
  • 가장 거리가 짧은 간선부터 차례대로 집합에 추가하면 된다.
    ❗다만, 사이클을 발생시키는 간선은 제외하고 연결한다!

동작 과정

  1. 간선 데이터를 비용에 따라 오름차순으로 정렬
  2. 간선을 하나씩 확인하면 현재의 간선이 사이클을 발생시키는지 확인
    • 사이클이 발생하지 않는 경우 최소 신장 트리에 포함시킨다.
    • 사이클이 발생하는 경우 최소 신장 트리에 포함시키지 않는다.
  3. 모든 간선에 대하여 2번의 과정을 반복한다.

구현

🕰 시간복잡도 : O(ElogE)

  • 간선의 개수 E개
  • Python 코드
# 특정 원소가 속합 집합을 찾기
def find_parent(parent, x):
	# 루트 노드가 아니라면, 루트 노드를 찾을 때까지 재귀적으로 호출
    if parent[x] != x:
    	parent[x] = find_parent(parent, parent[x])
    return parent[x]

# 두 원소가 속한 집합을 합치기
def union_parent(parent, a, b):
	a = find_parent(parent, a)
    b = find_parent(parent, b)
    if a<b:
    	parent[b] = a
    else:
    	parent[a] = b
        
# 노드의 개수와 간선(union 연산)의 개수 입력받기
v, e = map(int, input().split())
parent = [0] * (v+1)

# 모든 간선을 담을 리스트와 최종 비용을 담을 변수
edges = []
result =0

# 부모 테이블상에서, 부모를 자기 자신으로 초기화
for i in range(1, v+1):
	parent[i] = i

# 모든 간선에 대한 정보를 입력받기
for _ in range(e):
	a, b, cost = map(int, input().split())
    # 비용순으로 정렬하기 위해서 튜플의 첫 번째 원소를 비용으로 산정
   	edges.append((cost, a, b))
    
# 간선을 비용순으로 정렬
edges.sort()

# 간선을 하나씩 확인하며
for edge in edges:
	cost, a, b =edge
    # 사이클이 발생하지 않는 경우에만 집합에 포함
    if find_parent(parent, a) != find_parent(parent, b):
    	union_parent(parent, a, b)
        result+=cost
        
print(result)

프림

  • 대표적인 최소 신장 트리 알고리즘
  • 시작 정점을 선택한 후, 정점에서 인접한 간선 중 최소 간선으로 연결된 정점을 선택하고, 해당 정점에서 다시 최소 간선으로 연결된 정점을 선택하는 방식으로 최소 신장 트리를 확장해나가는 방식

동작 과정

  1. 임의의 정점을 선택, '연결된 노드 집합'에 삽입
  2. 선택된 정점에 연결된 간선들을 간선 리스트에 삽입
  3. 간선 리스트에서 최소 가중치를 가지는 간선부터 추출해서
    • 해당 간선에 연결된 인접 정점이 '연결된 노드 집합'에 이미 들어있다면, 스킵
    • 해당 간선에 연결된 인접 정점이 '연결된 노드 집합'에 들어 있지 않으면, 해당 간선을 선택하고, 해당 간선 정보를 '최소 신장 트리'에 삽입
  4. 추출한 간선을 간선 리스트에서 제거
  5. 간선 리스트에 더 이상 간선이 없을 때까지 3~4번 반복

구현

  • 기본적인 구현
    🕰 시간복잡도 : O(ElogE)
    - 간선의 개수 E개
    - python 코드

     from collections import defaultdict
     from heapq import *
    
     def prim(start_node, edges):
         mst = list()
         # 모든 간선의 정보를 adjacent_egdes에 저장
         adjacent_egdes = defaultdict(list)
         for weight, n1, n2 in edges:
             adjacent_egdes[n1].append((weight, n1, n2))
             adjacent_egdes[n2].append((weight, n2, n1))
    
         # 연결된 노드 집합에 시작 노드 포함
         connected_nodes = set(start_node)
    
         # 시작(특정) 노드에 연결된 간선 리스트
         candidate_edge_list = adjacent_egdes[start_node]
         heapify(candidate_edge_list) # 가중치 순으로 간선 리스트 정렬
    
         # 최소 가중치를 가지는 간선부터 추출
         while candidate_edge_list:
             # 최소 가중치 간선이 추출됨
             weight, n1, n2 = heappop(candidate_edge_list) 
             if n2 not in connected_nodes:
                 connected_nodes.add(n2)
                 mst.append((weight, n1, n2))
                 # n2의 간선들 중 
                 # 연결된 노드 집합에 없는 노드의 간선들만 
                 # 후보 간선 리스트에 추가함
                 for edge in adjacent_egdes[n2]:
                     if edge[2] not in connected_nodes:
                         heappush(candidate_edge_list, edge)
         return mst
    
     # 중복된 간선 제외(defaultdict을 이용하기 때문)
     myedges = [
         (7, 'A', 'B'), (5, 'A', 'D'),
         (8, 'B', 'C'), (9, 'B', 'D'), (7, 'B', 'E'),
         (5, 'C', 'E'),
         (7, 'D', 'E'), (6, 'D', 'F'),
         (8, 'E', 'F'), (9, 'E', 'G'),
         (11, 'F', 'G')
     ]
    
     print(prim('A', myedges))
    
     """
         [(5, 'A', 'D'), 
          (6, 'D', 'F'), 
          (7, 'A', 'B'), 
          (7, 'B', 'E'), 
          (5, 'E', 'C'), 
          (9, 'E', 'G')]
     """
     ```
     
  • 개선된 구현
    🕰 시간복잡도 : O(ElogV)
    - 간선의 개수 E개, 노드의 개수 V
    - 간선이 아닌 노드를 중심으로 우선순위 큐를 적용
    - 노드마다 Key 값을 가지고 있고 key 값은 우선순위에 넣는다.
    - python 코드

     
     """ 
       heapdict을 쓰는 이유:
       새롭게 heap에 push하거나 pop하지 않아도
       기존의 heap 내용만 update한다 하더라도
       알아서 최소 힙의 구조로 업데이트 함 
       * heapdict을 쓰기 전에 HeapDict 라이브러리 설치하기:
         -> pip install HeapDict
     """ 
    
     from heapdict import heapdict
    
     # 시간 복잡도: O(E logV) = O(V) + O(V logV) + O(E logV)
     def prim(graph, start):
         # pi: key를 업데이트하게 하는 상대 노드를 저장
         mst, keys, pi, total_weight = list(), heapdict(), dict(), 0
         # 초기화: O(V)
         for node in graph.keys():
             keys[node] = float('inf')
             pi[node] = None
         keys[start], pi[start] = 0, start
         # while문: O(V logV)
         while keys:
             current_node, current_key = keys.popitem()
             mst.append([pi[current_node], current_node, current_key])
             total_weight += current_key
             # for문: O(E logV)
             for adjacent, weight in mygraph[current_node].items():
                 if adjacent in keys and weight < keys[adjacent]:
                     keys[adjacent] = weight
                     pi[adjacent] = current_node
         return mst, total_weight
    
     mygraph = {
         'A': {'B': 7, 'D': 5},
         'B': {'A': 7, 'D': 9, 'C': 8, 'E': 7},
         'C': {'B': 8, 'E': 5},
         'D': {'A': 5, 'B': 9, 'E': 7, 'F': 6},
         'E': {'B': 7, 'C': 5, 'D': 7, 'F': 8, 'G': 9},
         'F': {'D': 6, 'E': 8, 'G': 11},
         'G': {'E': 9, 'F': 11}    
     }
     mst, total_weight = prim(mygraph, 'A')
     print('MST:', mst)
     print('Total Weight:', total_weight)
    
     """ 
         MST: [['A', 'A', 0], ['A', 'D', 5], ['D', 'F', 6], 
         ['A', 'B', 7], ['D', 'E', 7], ['E', 'C', 5], ['E', 'G', 9]]
         Total Weight: 39
     """
    	```

크루스칼 vs 프림

  • 둘다 그리디 알고리즘으로 분류됨
  • 크루스칼 알고리즘은 가장 가중치가 작은 간선부터 선택하면서 MST 구함
  • 프림 알고리즘은 특정 정점에서 시작, 해당 정점에 연결된 가장 가중치가 작은 간선을 선택, 간선으로 연결된 정점들에 연결된 간선 중에서 가장 가중치가 작은 간선을 선택하면서 MST 구함

유니온파인드

서로소 집합 자료구조 (union-find 자료구조)

  • 서로소 부분 집합들로 나누어진 원소들의 데이터를 처리하기 위한 자료구조
  • 연산 : union, find

union(합집합) 연산

  • 2개의 원소가 포함된 집합을 하나의 집합으로 합치는 연산

find(찾기) 연산

  • 특정한 원소가 속한 집합이 어떤 집합인지 알려주는 연산

구현

  • 트리 자료구조를 이용하여 집합을 표현
  1. union(합집합) 연산을 확인하여, 서로 연결된 두 노드 A, B를 확인한다.
    1) A와 B의 루트 노드 A', B'를 각각 찾는다.
    2) A'를 B'의 부모 노드로 설정한다(B'가 A'를 가리키도록 한다)
  2. 모든 union(합집합) 연산을 처리할 때까지 1번 과정을 반복한다.

    부모 테이블을 항상 가지고 있어야 한다.
    서로소 집합 알고리즘으로 루트를 찾기 위해서는 재귀적으로 부모를 거슬러 올라가야 한다.

Python 코드

  • 비효율적인 코드
    🕰 시간복잡도 : O(VM)
    • 노드의 개수 V개, find/union 연산의 개수 M개
# 특정 원소가 속한 집합을 찾기
def find_parent(parent, x):
	# 루트 노드가 아니라면, 루트 노드를 찾을 때까지 재귀적으로 호출
    if parent[x] !=x:
    	return find_parent(parent, parent[x])
    return x

# 두 원소가 속한 집합을 합치기
def union_parent(parent, a, b):
	a = find_parent(parent, a)
    b = find_parent(parent, b)
    if a<b:
    	parent[b]=a
    else:
    	parent[a]=b

# 노드의 개수와 간선(union 연산)의 개수 입력받기
v, e = map(int, input().split())
parent = [0] * (v+1)

# 부모 테이블 상에서 부모를 자기 자신으로 초기화
for i in range(1, v+1):
	parent[i] = i

# union 연산을 각각 수행
for i in range(e):
	a, b = map(int, input().split())
    union_parent(parent, a, b)
  
# 각 원소가 속한 집합 출력
print('각 원소가 속한 집합: ', end = '')
for i in range(1, v+1):
	print(find_parent(parent, i), end=' ')
    
print()

# 부모 테이블 내용 출력
print('부모 테이블: ', end='')
for i in range(1, v+1):
	print(parent[i], end=' ')
  • 최적화된 코드
    🕰 시간복잡도 : O(V+MlogV)
    • 노드의 개수 V개, find/union 연산의 개수 M개
    • 경로 압축(Path Compression) 기법 적용
    • find 함수를 재귀적으로 호출한 뒤에 부모 테이블값을 갱신하는 기법
    • 루트 노드에 더욱 빠르게 접근할 수 있다!
# 특정 원소가 속한 집합을 찾기
def find_parent(parent, x):
	# 루트 노드가 아니라면, 루트 노드를 찾을 때까지 재귀적으로 호출
    if parent[x] !=x:
    	parent[x]= find_parent(parent, parent[x])
    return parent[x]

'서로 다른 개체(객체)가 연결되어 있다'
-> 그래프 알고리즘 떠올리기!

ex) 여러 개의 도시가 연결되어 있다.

좋은 웹페이지 즐겨찾기