BOJ 풀이

[BOJ/백준 2042/C++] 구간 합 구하기

Vfly 2023. 3. 23. 17:12

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

[문제 요약]

- 어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다.

 

- 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.

 

- 첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

 

[문제 조건]

- 첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다.

- M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다.

- 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.

입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

- 시간 제한 : 2초

- 메모리 제한 : 256MB

 

[문제 풀이]

이번에는 "세그먼트 트리"라는 자료구조에 대해 알아보자. 물론 지난번 히스토그램 문제에서 한번 다룬 적이 있지만 이번 문제와 다음에 이어질 포스팅의 문제가 세그먼트 트리의 근본이 되는 문제다.

https://tobrother.tistory.com/89

 

[BOJ/백준 1725/C++] 히스토그램

https://www.acmicpc.net/problem/1725 1725번: 히스토그램 첫 행에는 N (1 ≤ N ≤ 100,000) 이 주어진다. N은 히스토그램의 가로 칸의 수이다. 다음 N 행에 걸쳐 각 칸의 높이가 왼쪽에서부터 차례대로 주어진다.

tobrother.tistory.com

 

필자는 다음 링크에서 학습했다. https://www.acmicpc.net/blog/view/9

 

세그먼트 트리 (Segment Tree)

글이 업데이트 되었습니다. https://book.acmicpc.net/ds/segment-tree 문제 배열 A가 있고, 여기서 다음과 같은 두 연산을 수행해야하는 문제를 생각해봅시다. 구간 l, r (l ≤ r)이 주어졌을 때, A[l] + A[l+1] + ..

www.acmicpc.net

 

먼저 세그먼트 트리란 무엇일까? 

필자는 다음과 같이 이해했다.

 - 배열 간격에 대한 정보를 이진트리의 형태로 저장하는 자료구조

 

예를들면 다음과 같은 경우를 생각해보자.

 

A = {1, 2, 3, 4, 5, ... , N}이라는 배열이 있을 때 아래와 같은 연산들이 총 M번 있다고 생각해보자

  • [i ~ j] 구간의 합
  • i번째 값을 v로 변경

단순하게 생각한다면 첫번째 작업은 시간복잡도가 O(N), 두번째 작업은 O(1)이 된다. 따라서 M번 수행하면

O(NM) + O(M) = O(MN)이 된다.

 

하지만 세그먼트 트리를 이용하면 두 연산을 모두 O(logN) 시간에 해결 할 수 있기때문에 N이 커지면 세그먼트리를 이용하는 방법이 더 효과적이다.

 

 

세그먼트 트리

세그먼트 트리는 이진트리의 형태를 띄고 있다. 따라서 1차원 배열을 이용하여 나타낼 수 있다.

 

세그먼트 트리에서 리프노드는 배열의 그 값 자체를 의미하고 다른 노드들은 왼쪽 서브트리와 오른쪽 서브트리의 정보를 합하여 저장하고 있다.

 

이진트리의 형태를 띄고 있기때문에 현재 노드의 번호가 n일 때 왼쪽 서브트리의 번호는 n*2, 오른쪽 서브트리의 번호는 n*2+1가 된다.

 

만약 N=10일 경우 세그먼트 트리는 다음 그림과 같이 나타낼 수 있다.

- 출처 : https://www.acmicpc.net/blog/view/9

 

각 노드에 적힌 수 들은 노드가 저장하고 있는 합의 범위를 나타낸다.

 

이제 만드는 방법을 알아보자.

 

세그먼트 트리 구현

세그먼트 트리는 이진트리다.

따라서 크기가 1차원 배열을 이용하여 리프노드의 개수가 n개인 세그먼트 트리를 만들 때 필요한 노드 수는 다음과 같다.

따라서 일반적으로 세그먼트 트리의 배열의 크기는 일반적으로 4N개로 잡고 시작한다.

 

세그먼트 트리는 일반적인 for문과 while문으로 구현하는 것 보다 재귀함수의 형태로 구현하는게 더 편하다.

long long init(long long start, long long end, long long node)
{
	if (start == end) return seg_tree[node] = dt[start];

	long long mid = (start + end) / 2;

	return seg_tree[node] = init(start, mid, node * 2) + init(mid + 1, end, node * 2 + 1);
}

위의 코드는 N개의 데이터가 들어왔을 때 세그먼트리를 구현하는 코드다.

 

첫째줄의 start == end의 경우는 리프노드를 의미한다. 리프노드는 배열의 값 자체를 가지게 때문에 세그먼트 트리에 배열의 값을 넣어 준다.

 

node의 왼쪽 서브트리의 번호는 node *2 , 오른쪽 서브트리의 번호는 node * 2 + 1가 된다.

리프노드가 아닌 노드들은 각 왼,오 서브트리에서 구한 값들을 더해서 그 합을 저장한다. 

- 문제에서 구간의 합을 구하라고 했기 때문에 합의 대한 정보만 구할뿐 문제에 따라 노드들이 가지는 정보는 문제마다 다를 수 있다.

 

세그먼트 트리를 이용한 합 찾기

구하고자 하는 구간에 대한 입력 Left 와 Right가 들어올 때 합을 찾으려면 트리를 순회하면서 값을 구할 수 있다.

 

예를 들어 N=10일때 0~9까지 합을 구하는 경우는 루트 노드 하나만 확인해도 값을 구할 수 있다.

 

 

3~9까지 합을 구하는 경우는 다음과 같다.

 

 

 

현재 노드가 저장하고 있는 구간에 대한 정보를 [start, end]라고 하고, 구하고자 하는 구간을 [left,right]라고 할때 생길 수 있는 경우는 다음과 같을 것이다.

  1. 1번과 4번 : [start,end]와 [left,right]가 어떤 부분도 겹치지 않는 경우
  2. 5번과 6번 : 한쪽이 완전히 포함하는 경우
  3. 2번과 3번 : 걸쳐있는 경우

1번의 경우는 if ( left  > end || right < start ) 로 나타낼 수 있다.

2번의 경우는 6번동그라미의 경우는 if( left <= start && end <= right )로 나타낼 수 있다.

3번경우와 5번동그라미의 경우는 왼쪽 서브트리와 오른쪽 서브트리를 루트로 하는 트리에서 다시 탐색을 해야 값을 구할 수 있다.

long long sum(long long start, long long end, long long left, long long right, long long node)
{
	//1번 경우
	if (left > end || right < start) return 0;

	//2번 경우에서 6번 동그라미의 경우
	if (left <= start && end <= right) return seg_tree[node];

	//나머지 경우들
	long long mid = (start + end) / 2;
	return sum(start, mid, left, right, node * 2) + sum(mid + 1, end, left, right, node * 2 + 1);
}

 

 

세그먼트 트리 변경

만약 데이터 중 일부가 변경이 된다면 세그먼트 트리에서 그 변경된 수를 포함하는 구간을 담당하는 모든 노드들에 대해 변경을 해줘야 한다.

 

예를 들어 3번째 수를 변경한다면 다음과 같은 노드들이 변경된다.

5번째 수를 변경한다면 다음 그림과 같다.

 

index번째 수를 val로 변경한다면 ,그 수가 얼마큼 변하는지를 알아햔 한다. 이 값을 필자는 diff로 두었다.

그렇다면 diff = c- dt[index]로 구할 수 있다.

 

트리를 순회하면서 변경에는 2가지 경우가 있다.

  • 현재 노드가 담당하는 구간[start,end]에 index가 포함되는 경우
  • 현재 노드가 담당하는 구간[start,end]에 index가 포함되지 않는 경우

노드의 구간에 포함되는 경우에는 값을 diff만큼 증가시켜서 값을 변경할 수 있다.

void update(long long start, long long end, long long node, long long index, long long diff)
{
	if (index < start || index > end) return;

	seg_tree[node] += diff;

	if (start == end) return;
	long long mid = (start + end) / 2;

	update(start, mid, node * 2, index, diff);
	update(mid + 1, end, node * 2 + 1, index, diff);
}

 

 

[2042 전체 소스 코드]

#include <iostream>
#include <algorithm>
#include <queue>
#include <vector>
#include <map>
#include <string>
#include <set>

#define INF 987654321
#define mod 1000000
#define pii pair<int,int>

using namespace std;

long long N, M, K;
vector<long long> dt;
vector<long long> seg_tree;

long long init(long long start, long long end, long long node)
{
	if (start == end) return seg_tree[node] = dt[start];

	long long mid = (start + end) / 2;

	return seg_tree[node] = init(start, mid, node * 2) + init(mid + 1, end, node * 2 + 1);
}

long long sum(long long start, long long end, long long left, long long right, long long node)
{
	if (left > end || right < start) return 0;

	if (left <= start && end <= right) return seg_tree[node];

	long long mid = (start + end) / 2;
	return sum(start, mid, left, right, node * 2) + sum(mid + 1, end, left, right, node * 2 + 1);
}

void update(long long start, long long end, long long node, long long index, long long diff)
{
	if (index < start || index > end) return;

	seg_tree[node] += diff;

	if (start == end) return;
	long long mid = (start + end) / 2;

	update(start, mid, node * 2, index, diff);
	update(mid + 1, end, node * 2 + 1, index, diff);
}

int main()
{
	ios::sync_with_stdio(false);
	cin.tie(NULL);
	cout.tie(NULL);

	cin >> N >> M >> K;
	dt.resize(N+1); seg_tree.resize(N*4);
	for (int i = 0; i < N; i++) cin >> dt[i];

	init(0, N - 1, 1);

	for (int i = 0; i < M + K; i++)
	{
		long long a, b, c;
		cin >> a >> b >> c;
		//변경
		if (a == 1)
		{
			long long tmp = c - dt[b - 1];
			dt[b - 1] = c;
			update(0, N - 1, 1, b - 1, tmp);
			
		}
		//구하기
		else if(a == 2)
		{
			cout << sum(0, N - 1, b-1, c-1, 1) << '\n';
		}
	}
	return 0;
}

값을 변경할때 dt배열도 c값으로 변경 시켜줘야 한다는 점을 유의하자.