백준 1717번 집합의 표현

백준 1717번 집합의 표현

출처 : https://www.acmicpc.net/problem/1717

코드

import sys

n, m = map(int, sys.stdin.readline().split())
calc = []  
for i in range(m):
  calc.append(list(map(int,sys.stdin.readline().split())))

info_dic = {}
set_dic = {}
now_set = 0

for index in range(m):
  kinds, a, b = calc[index]
  if kinds == 0:    
    if a not in info_dic.keys() and b not in info_dic.keys():
      info_dic[a] = now_set
      info_dic[b] = now_set
      set_dic[now_set] = [a, b]
      now_set += 1
    
    elif a not in info_dic.keys() and b in info_dic.keys():
      info_dic[a] = info_dic[b]
      set_dic[info_dic[b]].append(a)
      
    elif a in info_dic.keys() and b not in info_dic.keys():
      info_dic[b] = info_dic[a]
      set_dic[info_dic[a]].append(b)

    elif info_dic[a] != info_dic[b]:
      li = set_dic.pop(info_dic[b])
      set_dic[info_dic[a]].extend(li)
      for number in li:
        info_dic[number] = info_dic[a]
      
  elif kinds == 1:
    if a in info_dic.keys() and b in info_dic.keys():
      if info_dic[a] == info_dic[b]:
        print('YES')
      else:
        print('NO')
    else:
      if a == b:
        print('YES')
      else:
        print('NO')

풀이방법

저는 이 문제를 파이썬의 딕셔너리 자료형을 이용해서 풀었습니다.

  • info_dic을 통해서 각 숫자가 들어가 있는 집합의 index를 저장

  • set_dic을 통해서 각 set_dic[index]의 집합을 리스트로 저장

  • 주어진 연산에 따라서 조건문을 통해서 구현

  • 조건은 순서대로 다음과 같다.

  1. a, b 가 둘다 정해진 집합이 없을 때
  2. a는 정해진 집합이 없고 b는 정해진 집합이 있을 때
  3. b는 정해진 집합이 없고 a가 정해진 집합이 있을 때
  4. 둘다 정해진 집합이 있어서 합쳐야 하는 경우

이 문제 자체가 0을 포함하기도 하고 여러방면에서 실수가 날 수 있는 문제조건이라서 여러번 틀렸지만 그래도 풀어냈다!!

Union Find

문제를 해결한 후 게시판 및 자료를 찾아보니 Union Find 방법으로 풀면 쉽다는 이야기가 있었다!

union

  • 2개 원소로 이루어진 집합을 하나의 집합으로 합침(합집합)

  • 파이썬 딕셔너리를 이용해서 합집합의 계산이 나올때마다 해당 값들의 루트 노드들을 고쳐준다.

find

  • 특정 원소가 속한 집합이 뭔지 알려주는 연산

  • 유니온의 정보가 담긴 딕셔너리를 통해서 해당 원소가 어디 집합에 속해져 있는지 루트노드를 찾아나가다 보면 찾을 수 있음

  • 특정상황에서는 비효율적인 경로로 루트노드를 찾게된다.

경로 압축

  • 비효율적인 find를 없애기 위해서 union 과정에서 노드의 정보가 루트가 아닌 경우에는 찾아서 딕셔너리를 다시 업데이트 해주는 방식을 사용

소스 코드

import sys

def find_parent(x):  
  if info[x] != x:
    return find_parent(info[x])
  return info[x]

def union(a, b):
  a = find_parent(a)
  b = find_parent(b)
  if a == b:
    return
  elif a < b:
    info[b] = a
  elif a > b:
    info[a] = b

def solve():
  for index in range(m):
    kinds, a, b = map(int, sys.stdin.readline().split())
    if kinds == 0:
      union(a, b)
    elif kinds == 1:
      if find_parent(a) == find_parent(b):
        print("YES")
      else:
        print("NO")

sys.setrecursionlimit(10**5)
n, m = map(int, sys.stdin.readline().split())
info = [i for i in range(n+1)]
solve()

좋은 웹페이지 즐겨찾기