트라이(Trie)

트라이 (Trie)

트라이(Trie)란?

트라이는 CS에서 탐색 트리의 일종이다. 문자열이 키인 경우가 많다.

이진 탐색 트리와 달리 트리의 어떤 노드도 그 노드 자체와 연관된 키는 저장하지 않는다.

대신 노드가 트리에서 차지하는 위치가 연관된 키를 정의한다.

즉, 키의 값은 자료 구조 전체에 분산된다. 노드의 모든 자손은 노드에 연관된 문자열의 공통 접두사를 공유한다.

루트(head)는 빈 문자열에 연관된다.

일부 내부 노드가 키에 대응할 수도 있지만, 일반적으로는 키는 단말에 연관되는 경향이 있다. 따라서 모든 노드가 꼭 키와 연결되지는 않는다.

문자열 검색에 효율적이다.

트라이 예시

각 문자열의 알파벳이 노드의 value가 된다.

어떤 노드는 word의 원소에 대한 정보를 담고 있는데 이는 해당 노드까지 탐색하게 되면 탐색 과정에서 등장한 알파벳으로 이뤄진 단어가 현재 트라이에 존재한다는 의미를 나타낸다.

예를 들어 ‘korea’ 에서는 ‘a’ 에서만 ‘korea’ 정보를 갖고 있다.

‘korean’은 ‘n’에서 ‘korean’ 정보를 갖고 있다.

이처럼 트라이는 접두사가 중복되는 단어를 효율적으로 관리할 수 있으며 어떤 단어를 찾는 쿼리에서의 시간 복잡도는 O(N)O(N) (NN = 찾는 단어의 길이) 으로 찾을 수 있다.

시간 복잡도면에서 뛰어나지만 알파벳 마다, 그 순서 마다 만들어야 하는 노드가 매우 많으므로 메모리를 많이 쓰게 된다.

문제

14425번: 문자열 집합

개요

사실 위 문제는 단순하게 Hashing 으로도 해결이 가능하다.

하지만 메모리 제한이 1536MB 인것을 생각해보면 트라이를 연습하기에 좋은 문제라고 생각한다.

트라이 구현

노드 클래스

class Node:
	def __init__(self, key, data=None):
		self.key = key
		self.data = data
		self.children = {}

트라이의 구성 요소인 노드에 관한 클래스이다.

key(알파벳), data(단어 존재 유무), children(현재 노드의 자식 노드)로 구성된다.

data를 통해서 현재 노드의 key를 마지막 철자로 하는 단어가 있는지 없는지를 판별할 수 있다.

children은 Trie 클래스의 메서드에서 재귀적으로 노드를 탐색할 때 사용된다.

Trie 클래스

class Trie:
	def __init__(self):
		self.head = Node(None)

트라이 클래스의 생성자는 key = None인 노드이다.

그리고 이 노드는 Trie의 head, 즉 루트 노드가 된다.

insert

def insert(self, word):
		cur_node = self.head
		for char in word:
			if char not in cur_node.children:
				cur_node.children[char] = Node(char)
			cur_node = cur_node.children[char]
		cur_node.data = word

트라이에 어떤 단어를 삽입하는 메서드이다.

word를 parameter로 받고 word 안의 철자를 하나씩 노드 객체로 생성해서 트라이에 넣어준다.

트라이에 넣어준다는 것은 head 부터 차곡차곡 자식 노드(children)에 할당 해준다는 의미이다.

search

def search(self, word):
		cur_node = self.head
		for char in word:
			if char in cur_node.children:
				cur_node = cur_node.children[char]
			else:
				return False

찾고자 하는 단어가 트라이에 존재하는지 확인하는 메서드이다.

여기서 node.data가 사용된다.

head에서 부터 차례대로 children을 확인하고 철자가 일치하는 노드가 존재하면 재귀적으로 같은 탐색을 반복한다.

만약 해당 철자가 트라이에 없거나 모든 철자가 트라이에 존재하지만 data가 없는 경우에는 False를 리턴한다.

start_with

def start_with(self, prefix):
		cur_node = self.head

		for char in prefix:
			if char in cur_node.children:
				cur_node = cur_node.children[char]
			else:
				return None

		words = []
		next_node = []
		if cur_node.data: #app
			words.append(cur_node.data)
		next_node.extend(list(cur_node.children.values()))
		if len(next_node) == 0: return words

		cur_node = list(next_node)
		next_node = []
		while True:
			for node in cur_node:
				if node.data:
					words.append(node.data)
				next_node.extend(list(node.children.values()))
			if len(next_node) == 0: return words
			cur_node = list(next_node)
			next_node = []

찾고자 하는 문자열을 접두사(prefix)로 가지는 단어를 모두 리턴하는 메서드이다.

head에서 부터 child를 검사하여 prefix의 전체 알파벳이 트라이에 존재하는지 확인한다.

존재하지 않으면 None을 리턴한다.

prefix의 마지막 알파벳 이후로 모든 child를 탐색하고 그 child들의 노드에서 data가 있는지 확인해야한다.

단어들은 words에 append 되고 현재 노드의 자식 노드들을 next_node에 저장한 다음에 cur_node를 next_node로 갱신해준다. 이후 갱신된 cur_node에 대해 같은 작업을 반복한다.

작업 수행 방식이 BFS와 유사하다.

모든 child들을 우선적으로 검사하고 그 child들의 child를 다음 while 반복문에서 작업한다.

만약 더 이상 cur_node가 갱신되지 않는다면 저장된 word를 리턴한다.

문제 정답 코드

import sys

In = lambda: sys.stdin.readline().rstrip()
MIS = lambda: map(int, In().split())

class Node:
	def __init__(self, key, data=None):
		self.key = key
		self.data = data
		self.children = {}

class Trie:
	def __init__(self):
		self.head = Node(None)

	def insert(self, word):
		cur_node = self.head
		for char in word:
			if char not in cur_node.children:
				cur_node.children[char] = Node(char)
			cur_node = cur_node.children[char]
		cur_node.data = word

	def search(self, word):
		cur_node = self.head
		for char in word:
			if char in cur_node.children:
				cur_node = cur_node.children[char]
			else:
				return False

		if cur_node.data:
			return True
		else:
			return False

	def start_with(self, prefix):
		cur_node = self.head

		for char in prefix:
			if char in cur_node.children:
				cur_node = cur_node.children[char]
			else:
				return None

		words = []
		next_node = []
		if cur_node.data: #app
			words.append(cur_node.data)
		next_node.extend(list(cur_node.children.values()))
		if len(next_node) == 0: return words

		cur_node = list(next_node)
		next_node = []
		while True:
			for node in cur_node:
				if node.data:
					words.append(node.data)
				next_node.extend(list(node.children.values()))
			if len(next_node) == 0: return words
			cur_node = list(next_node)
			next_node = []

if __name__ == "__main__":
	N, M = MIS()
	trie = Trie()
	for _ in range(N):
		trie.insert(In())

	ans = 0
	for _ in range(M):
		if trie.search(In()):
			ans += 1
	print(ans)

좋은 웹페이지 즐겨찾기