BOJ 풀이

[BOJ/백준 10986/C++] 나머지 합

Vfly 2022. 12. 22. 21:28

https://www.acmicpc.net/problem/10986

 

10986번: 나머지 합

수 N개 A1, A2, ..., AN이 주어진다. 이때, 연속된 부분 구간의 합이 M으로 나누어 떨어지는 구간의 개수를 구하는 프로그램을 작성하시오. 즉, Ai + ... + Aj (i ≤ j) 의 합이 M으로 나누어 떨어지는 (i, j)

www.acmicpc.net

 

[문제 요약]

- 수 N개 A1, A2, ..., AN이 주어진다. 이때, 연속된 부분 구간의 합이 M으로 나누어 떨어지는 구간의 개수를 구하는 프로그램을 작성하시오.

즉, Ai + ... + Aj (i ≤ j) 의 합이 M으로 나누어 떨어지는 (i, j) 쌍의 개수를 구해야 한다.

 

[문제 조건]

- 첫째 줄에 N과 M이 주어진다. (1 ≤ N ≤ 10^6, 2 ≤ M ≤ 10^3)

- 둘째 줄에 N개의 수 A1, A2, ..., AN이 주어진다. (0 ≤ Ai ≤ 10^9)

- 시간 제한 : 1초

- 메모리 제한 : 256MB

 

[문제 풀이]

아마 대부분 사람들이 이 문제를 풀때 시간초과를 가장 많이 경험할 것이다. 아마 시간초과가 났다는건 모든 구간합에 대해 M으로 나누고 나머지 0인 경우를 카운트 해주면 정답이 되겠지만, 이 방법의 경우 O(N^2)의 시간을 갖게 된다. 

 

하지만 문제조건을 보면 N이 최대 1백만이라서 N^2이면 단순계산으로만 약 1조번을 연산한다. 당연히 시간초과가 걸린다.

 

그렇다는건 N^2의 시간보다 빠르게 풀어야한다.

 

먼저, 아주 간단하게 생각해보자. 우리가 구하고자 하는 답은 구간합이 M으로 나눠떨어지는 개수만 구하면된다. 즉 굳이 다 해볼 필요가 없을 수 있다는 것이다.

 

일단 주어진 예제를 이용해 모든 구간합을 표로 한번 나타내 보았다.

5 3
1 2 3 1 2

여기서 위 배열은 MP배열이라고 이름짓고, 이 배열의 값은 MP[i][j] = i부터 j까지 구간 합을 의미한다.

 

아 참고로 미리 말하자면 표를 이용한 방식은 메모리초과가 나니 정확한 풀이 방법은 아니니 주의하자. ( 빌드업이다. )

 

위에 표를 자세히보면 MP배열에서 1번행을 제외하고 나머지 행들은 전부 1번행을 이용하여 값을 구할 수 있다.

 

예를들면 MP[2][4] 의 경우는 MP[1][4]에서 MP[1][1]의 값을 빼면 된다.

 

그렇다는건 1번행만 있으면 나머지는 어떻게든 구할 수 있다. 따라서 1번행만 따로 떼어와서 생각해보자.

 

배열이름 : D

아까 우리가 구간합을 구할때 위의 1번행의 두 값을 이용하여 다른 구간 합들을 구할 수 있었다. 

 

그렇다면 1번행을 통해 두 값을 잘 뽑으면 M으로 나눠지는 구간합을 찾을 수도 있을 것이다. 하지만 어떻게 뽑아야 할까?

 

식으로 한번 써보면 (D[i] - D[j]) % M = 0 인 경우를 찾는 것인데.  이 식을 풀어 보면 아래와 같이 된다.

 

D[i] % M - D[j] % M = 0        →    D[i] % M = D[j] % M

 

최종적으로 정리된 식을 보면 M으로 나눴을때 나머지 같은 두 값을 고르면 우리가 원하는 조건에 맞는 구간합을 구할 수 있게 된다.

 

다시 정리해서 말하자면, 1부터 K까지 (1 ≤ K ≤ N) 구간합을 갖고 있고, 이 값들 중 M으로 나눴을때 나머지가 같은 두 수를 뽑으면 M으로 나눠떨어지는 구간합을 구할 수 있다.

 

따라서 1번행의 각 값들을 M으로 나눠 같은 나머지 L 를 갖는 값들의 개수를 세어준 뒤 nC2 = (n * (n - 1)) / 2를 이용해 L의 나머지를 갖는 경우들로 몇개의 구간합을 만들 수 있는지 구할 수 있다.

 

[소스 코드]

 

#include <stdio.h>
#include <stdlib.h>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <cmath>
#include <deque>
#include <list>
#include <math.h>
#include <map>
#include <queue>
#include <stack>

using namespace std;

int N, M;
vector<unsigned long long> input;
vector<unsigned long long> sum;
vector<unsigned long long> rest;

int main()
{
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);
	cout.tie(NULL);

	cin >> N >> M;
	input.resize(N + 1);
	sum.resize(N + 1);
	rest.resize(M);

	for (int i = 1; i <= N; i++) cin >> input[i];
	for (int i = 1; i <= N; i++) sum[i] = sum[i - 1] + input[i];
	for (int i = 0; i <= N; i++) rest[sum[i] % M]++;
	unsigned long long ans = 0;
	for (int i = 0; i < M; i++)
	{
		if (rest[i] == 0) continue;
		else
		{
			unsigned long long tmp = rest[i];
			ans = ans + ((tmp * (tmp - 1)) / 2);
		}
	}
	cout << ans;
}