[TIL] 알고리즘&자료구조: 트리와 바이너리 인덱스 트리

3. 트리 자료구조 (Tree)

가계도처럼 계층적인 구조를 표현할 때 사용할 수 있는 자료구조

  • 트리 관련 용어
    • 루트 노드: 부모가 없는 최상위 노드
    • 단말 노드: 자식이 없는 노드
    • 크기: 트리에 포함된 모든 노드의 개수
    • 깊이: 루트 노드부터의 거리
    • 높이: 깊이 중 최댓값
    • 차수: 각 노드의 자식 방향 간선 개수

📌 트리 크기가 N일 때 전체 간선의 개수는 N-1

3-1. 이진 탐색 트리(Binary Search Tree)

  • 이진 탐색이 동작할 수 있도록 고안된 효율적인 탐색이 가능한 자료구조
  • 특징
    • 왼쪽 자식 노드 < 부모 노드 < 오른쪽 자식 노드
    • 부모 노드보다 왼쪽 자식 노드가 작다.
    • 부모 노드보다 오른쪽 자식 노드가 크다.
  • 데이터 조회 과정
    1. 루트 노드부터 방문하여 탐색 진행
    2. 현재 노드와 찾는 원소 값을 비교
    3. 찾는 원소가 더 크면 오른쪽 노드 방문
    4. 현재 노드가 찾는 원소보다 더 크면 왼쪽 노드 방문
    5. 찾는 원소에 다다를 때까지 반복

3-2. 트리의 순회

트리 자료구조에 포함된 노드를 특정 방법으로 한 번씩 방문하는 방법. 트리의 정보를 시각적으로 확인할 수 있으므로 자주 사용한다.

  • 트리 순회 방법
    • 전위 순회: 루트➡왼쪽 자식➡오른쪽 자식
    • 중위 순회: 왼쪽 자식➡루트➡오른쪽 자식
    • 후위 순회: 왼쪽 자식➡오른쪽 자식➡루트

📌 구현 예제

class Node:
    def __init__(self, data, left_node, right_node):
        self.data = data
        self.left_node = left_node
        self.right_node = right_node

# 전위 순회
def pre_order(node):
    print(node.data, end=' ')
    if node.left_node != None:
        in_order(tree[node.left_node])
    print(node.data, end=' ')
    if node.right_node != None:
        in_order(tree[node.right_node])
        
# 중위 순회
def in_dorder(node):
    if node.left_node != None:
        in_order(tree[node.left_node])
    print(node.data, end=' ')
    if node.right_node != Nonde:
        in_order(tree[node.right_node])

# 후위 순회
def post_order(node):
    if node.left_node != None:
        post_order(tree[node.left_node])
    if node.right_node != None:
        post_order(tree[node.right_node])
    print(node.data, end=' ')
    
n = int(input())
tree = []

for i in range(n):
    data, left_node, right_node = input().split()
    if left_node == "None":
        left_node = None
    if right_node == "None":
        right_node = None
    tree[data] = Node(data, left_node, right_node)

pre_order(tree['A'])
print()
in_order(tree['A'])
print()
post_order(tree['A'])

4. 바이너리 인덱스 트리 (Binary Indexed Tree)

a.k.a. BIT, 펜윅 트리(Fenwick Tree). 이진법 인덱스 구조를 활용해서 구간 합 문제를 효과적으로 해결해 줄 수 있는 자료구조

👉 구간 합 알고리즘 문제: https://www.acmicpc.net/problem/2042

  • 정수에 따른 이진수 표기 예시

    정수이진수 표기
    700000000 00000000 00000000 0000111
    -711111111 11111111 11111111 1111001
  • 0이 아닌 마지막 비트를 찾는 방법

    • 특정 숫자 K의 0이 아닌 마지막 비트 계산: K & -K
    n = 8
    for i in range(n+1):
        print(i, "의 마지막 비트:", (i & -i))
  • 바이너리 인덱스 트리 구현 동작 원리

    1. 생성: 0이 아닌 마지막 비트 = 내가 저장하고 있는 값들의 개수

    2. 업데이트

      • 특정 값을 변경하는 경우: 0이 아닌 마지막 비트만큼 더하면서 구간들의 값을 변경

        예) 3번 값 변경: 1~4번 값 변경 ➡ 1~8번 값 변경 ➡ 1~16번값 변경

    3. 누적 합 구하기

      • 1부터 N까지의 누적 합: 0이 아닌 마지막 비트만큼 빼면서 구간들의 값의 합을 계산

        예) 11번까지 누적 합 구하기: 11번 값 + 9~10번 값 + 1~8번 값

📌 구현 예시

import sys
input = sys.stdin.readline

# 데이터의 개수(n), 변경 횟수(m), 구간 합 계산 횟수(k)
n, m, k = map(int, input().split())

# 전체 데이터의 개수는 최대 1,000,000개
arr = [0] * (n + 1)
tree = [0] * (n + 1)

# i번째 수까지의 누적 합을 계산하는 함수
def prefix_sum(i):
    result = 0
    while i > 0:
        result += tree[i]
        # 0이 아닌 마지막 비트만큼 빼가면서 이동
        i -= (i & -i)
    return result

# i번째 수를 dif만큼 더하는 함수
def update(i, dif):
    while i <= n:
        tree[i] += dif
        i += (i & -i)

# start부터 end까지의 구간 합을 계산하는 함수
def interval_sum(start, end):
    return prefix_sum(end) - prefix_sum(start - 1)

for i in range(1, n + 1):
    x = int(input())
    arr[i] = x
    update(i, x)

for i in range(m + k):
    a, b, c = map(int, input().split())
    # 업데이트 연산인 경우
    if a == 1:
        update(b, c - arr[b])	# 바뀐 크기(dif)만큼 적용
    else:
        print(interval_sum(b, c))

좋은 웹페이지 즐겨찾기