알고리즘

[Java] 백준 2042번 구간 합 구하기

J3SUNG 2023. 3. 6. 18:56
728x90

세그먼트 트리 문제입니다.

수열이 주어지고 중간에 수열안의 값들이 자주 변경됩니다. 이 때 주어진 구간의 합을 출력하여야 합니다.

구간의 합을 구하는데 걸리는 시간은 O(N)입니다.
100만 * 10000번의 명령이 있으므로 100억번의 연산이 필요합니다.
O(N)으로 진행하게 된다면 시간초과가 발생합니다.

구간의합을 구하는데 걸리는 시간을 logN까지 줄여야지 문제를 해결할 수 있습니다.

이때 구간의 합을 효율적으로 구하기 위해서 세그먼트 트리를 사용합니다.
세그먼트 트리는 여러 개의 데이터가 존재할 때 특정 구간의 합(최솟값, 최댓값, 곱 등)을 구하는 데 사용하는 자료구조입니다.

https://www.acmicpc.net/blog/view/9

위와 같이 위에 있는 노드가 아래있는 노드들의 합을 가지고 있습니다.
재귀함수를 통해서 찾는 구간에 속한 경우 해당 값을 결과값에 추가해줍니다.
추가된 값들이 구간의 합입니다.

수정이 될 때는 해당 위치의 리프노드 부터 부모노드로 타고가면서 전부 수정해주었습니다.

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.StringTokenizer;

public class Main {
	static int n;
	static int m;
	static int k;
	static long[] arr;
	static long[] tree;

	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
		StringTokenizer st = new StringTokenizer(br.readLine());

		n = Integer.parseInt(st.nextToken());
		m = Integer.parseInt(st.nextToken());
		k = Integer.parseInt(st.nextToken());

		arr = new long[n + 1];
		tree = new long[n * 4];
		for (int i = 1; i <= n; ++i) {
			arr[i] = Long.parseLong(br.readLine());
		}
		init(1, n, 1);

		for (int i = 0; i < m + k; ++i) {
			st = new StringTokenizer(br.readLine());
			int a = Integer.parseInt(st.nextToken());
			int b = Integer.parseInt(st.nextToken());
			long c = Long.parseLong(st.nextToken());

			if (a == 1) {
				update(1, n, b, c - arr[b], 1);
				arr[b] = c;
			} else if (a == 2) {
				long sum = search(1, n, b, (int)c, 1);
				bw.write(sum + "\n");
			}
		}
		bw.close();
	}
	public static long search(int start, int end, int left, int right, int index) {
		if(end < left || start > right) {
			return 0;
		}
		if(start >= left && end <= right) {
			return tree[index];
		}
		int mid = (start + end) / 2;
		
		return search(start, mid, left, right, index * 2) + search(mid + 1, end, left, right, index * 2 + 1);
	}
	public static void update(int start, int end, int cur, long val, int index) {
		if (cur < start || cur > end) {
			return;
		}
		tree[index] += val;
		if (start == end) {
			return;
		}
		int mid = (start + end) / 2;
		update(start, mid, cur, val, index * 2);
		update(mid + 1, end, cur, val, index * 2 + 1);
	}

	public static long init(int start, int end, int index) {
		if (start == end) {
			tree[index] = arr[start];
			return tree[index];
		}

		int mid = (start + end) / 2;
		tree[index] = init(start, mid, index * 2) + init(mid + 1, end, index * 2 + 1);
		return tree[index];
	}
}