세그먼트 트리란?

세그먼트 트리는 특정 구간에 대한 연산을 매우 빠르게 처리하기 위한 트리 자료구조입니다. 데이터의 개수가 매우 많고, 특정 구간의 합이나 최댓값/최솟값을 구하는 질의와 데이터 변경이 빈번하게 발생할 때 압도적인 성능을 발휘합니다.

미리 구간별 연산 결과를 계산하여 트리 형태로 저장해두고, 질의가 들어오면 이 값들을 조합하여 **O(log N)**의 시간 복잡도로 결과를 도출합니다.

세그먼트 트리의 특징

  • 빠른 구간 연산: 구간 합, 최대/최소 등을 로그 시간 안에 처리 가능
  • 효율적인 데이터 업데이트: 특정 원소의 값이 변경되어도 로그 시간 안에 트리 전체에 반영 가능
  • 배열 기반 구현: 완전 이진트리 구조를 활용해 배열로 쉽게 구현할 수 있음


세그먼트 트리의 구현 단계

세그먼트 트리는 크게 3단계로 구현됩니다.

  1. 트리 초기화하기 리프 노드가 원본 데이터의 개수(N)를 모두 포함할 수 있도록 트리의 크기를 결정합니다. 일반적으로 N*4의 크기로 배열을 할당하면 충분하며, 재귀적으로 각 노드에 구간의 연산 결과를 저장하며 트리를 구축합니다.

  2. 질의값 구하기 주어진 범위(구간)에 대한 연산 결과를 찾는 과정입니다. 찾고자 하는 구간을 완전히 포함하는 노드들의 값을 효율적으로 조합하여 결과를 반환합니다. 이 과정이 세그먼트 트리의 핵심 효율성을 보여줍니다.

  3. 데이터 업데이트하기 특정 위치의 데이터가 변경되었을 때, 그 변경 사항을 관련된 모든 노드(리프 노드부터 루트 노드까지의 경로)에 전파하여 트리의 정합성을 유지합니다.


예제: 백준 2042번 “구간 합 구하기”

문제 설명

  • N개의 숫자가 주어지고, M번의 데이터 변경과 K번의 구간 합 질의가 발생한다.
  • 데이터 변경: 특정 인덱스의 숫자를 다른 숫자로 변경한다.
  • 구간 합 질의: 특정 구간에 속한 숫자들의 합을 구하여 출력한다.
  • N이 최대 100만, M+K가 최대 2만이므로 O(N) 방식의 순차 합산으로는 시간 초과가 발생한다.


코드 구현

import sys
input=sys.stdin.readline


# start, end: 숫자 배열의 인덱스
# idx: 세그먼트 트리의 인덱스 (1-based index)
def make_tree(start, end, idx):
    if start == end:
        tree[idx] = arr[start]
        return tree[idx]

    mid = (start + end) // 2
    tree[idx] = make_tree(start, mid, idx * 2) + make_tree(mid + 1, end, idx * 2 + 1)

    return tree[idx]


# target: 수정할 값의 숫자 배열 인덱스
# value: 수정할 값 (기존 값에 얼만큼 더해야 하는지의 값)
def update_tree(start, end, idx, target, value):
    if target < start or target > end:
        return
    
    tree[idx] += value
    if start == end:
        return
    
    mid = (start + end) // 2
    update_tree(start, mid, idx * 2, target, value)
    update_tree(mid + 1, end, idx * 2 + 1, target, value)


# left, right: 구하고자 하는 범위
def sum_tree(start, end, idx, left, right):
    if right < start or left > end:
        return 0

    if left <= start and right >= end:
        return tree[idx]

    mid = (start + end) // 2
    return sum_tree(start, mid, idx * 2, left, right) + sum_tree(mid + 1, end, idx * 2 + 1, left, right)


N, M, K = map(int, input().split())
arr = []
tree = [0] * (N * 4)  # 충분한 트리 공간 확보
for _ in range(N):
    arr.append(int(input()))

make_tree(0, N-1, 1)
for _ in range(M + K):
    a, b, c = map(int, input().split())
    if a == 1:
        diff = c - arr[b-1]
        arr[b-1] = c  # 기존 숫자 배열 값 변경
        update_tree(0, N-1, 1, b-1, diff)
    else:
        print(sum_tree(0, N-1, 1, b-1, c-1))


코드 분석

1. 트리 초기화하기 (make_tree)

tree = [0] * (N * 4)
arr = []
for _ in range(N):
    arr.append(int(input()))

make_tree(0, N-1, 1)

→ (개념 1 적용) 원본 데이터를 arr에 저장하고, N*4 크기의 tree 배열을 생성하여 충분한 공간을 확보합니다. make_tree 함수는 루트 노드(인덱스 1)부터 시작하여 재귀적으로 각 노드에 자식 노드들의 합을 채워 넣어 구간 합 트리를 완성합니다.

2. 질의값 구하기 (sum_tree)

def sum_tree(start, end, idx, left, right):
    if right < start or left > end:
        return 0

    if left <= start and right >= end:
        return tree[idx]

    mid = (start + end) // 2
    return sum_tree(start, mid, idx * 2, left, right) + sum_tree(mid + 1, end, idx * 2 + 1, left, right)

→ (개념 2 적용) “선택된 노드를 모두 더한다”는 개념이 여기에 적용됩니다. Case 2에서처럼, 찾으려는 구간을 완전히 대표하는 노드를 만나면 그 노드의 값(미리 계산된 합)을 즉시 반환하여 탐색을 최적화합니다. Case 3에서는 구간을 나누어 필요한 노드들을 찾아 합칩니다. 이 과정을 통해 O(log N)의 빠른 속도를 달성합니다.

3. 데이터 업데이트하기 (update_tree)

def update_tree(start, end, idx, target, value):
    if target < start or target > end:
        return

    tree[idx] += value
    if start == end:
        return

    mid = (start + end) // 2
    update_tree(start, mid, idx * 2, target, value)
    update_tree(mid + 1, end, idx * 2 + 1, target, value)

→ (개념 3 적용) “자신의 부모 노드로 이동하면서 업데이트”하는 원리를 Top-down 재귀 방식으로 구현했습니다. update_tree는 변경이 필요한 인덱스(target)를 포함하는 모든 노드를 루트부터 리프까지 찾아 내려가면서 값을 갱신합니다. 특히, value 파라미터로 새로운 값이 아닌 **기존 값과의 차이(diff)**를 전달하여, 관련된 모든 구간 합에 해당 차이만큼만 더해주면 되므로 매우 효율적입니다.


정리

항목 설명
알고리즘 유형 세그먼트 트리 (Segment Tree)
주요 조건 빈번한 데이터 변경과 구간 합 쿼리 처리
핵심 자료구조 배열 기반 트리, 재귀 함수
시간 복잡도 초기화: O(N), 쿼리/업데이트: O(log N)
응용 예시 구간 최소/최대값, 구간 곱, 펜윅 트리(BIT) 등