BOJ 풀이

[BOJ/백준 11401/C++] 이항 계수 3

Vfly 2023. 3. 6. 17:02

https://www.acmicpc.net/problem/11401

 

11401번: 이항 계수 3

자연수 \(N\)과 정수 \(K\)가 주어졌을 때 이항 계수 \(\binom{N}{K}\)를 1,000,000,007로 나눈 나머지를 구하는 프로그램을 작성하시오.

www.acmicpc.net

 

[문제 요약]

- 자연수 과 정수 가 주어졌을 때 이항 계수 nCk 를 1,000,000,007로 나눈 나머지를 구하는 프로그램을 작성하시오.

 

[문제 조건]

- 첫째 줄에 가 주어진다. (1 ≤  ≤ 4,000,000, 0 ≤  ≤ )

- 시간 제한 : 1초

- 메모리 제한 : 256MB

 

[문제 풀이]

문제는 아~~~~~~주 간단한다.

 

nCk의 값을 구하면되는데, n과 k가 최대 400만 까지인 괴랄한 값을 구해야 한다.

 

조합의 값을 구하는 방법은 여러가지가 있다.

코드를 작성하는데 있어서 가장 널리 쓰이는 방법은 2번째와 4번째 방법이 가장 많이 쓰일 것이다.

 

2번째 방법은 그대로 팩토리얼을 계산해서 구하면 되고

4번째 방법은 다이나믹 프로그래밍(DP)를 이용하여 풀 수 있다.

 

하지만 두 방법 모두 n이 커지게 되면 엄청난 메모리를 잡아먹기 때문에 n이 커지면 다른 방법을 이용하여 조합의 값을 구해야한다.

 

새로운 방법은 페르마 소정리를 이용하여 구하는 방법이다.

 

페르마 소정리가 뭐냐?

위의 그림이란다.

 

그러면 위 정리를 어떻게 이용해야되는가?

 

먼저 가장 근본적인 식에서부터 시작해보자. 

우리가 구하고자 하는 값은 결국에는 nCr % mod 다.

하지만 위 식의 분모 분자에 % mod를 적용하면 원래 구하고자 하는 값과 다른값이 나오게 된다. 왜냐하면 모듈러연산은 곱셈이랑 덧셈에서 적용되고 분수에서는 적용되지 않는다. 따라서 위의 형태를 곱셈이나 덧셈의 형태로 바꾸어야 하는데 그때 페르마 소정리를 이용하게 된다.

 

먼저 A = n! , B = r!(n-r)!, p = 1,000,000,007라고 하자

그렇다면 우리가 구하는 식은 다음과 같을 것이다.

먼저 B의 역수를 없애기 위해서 다음과 같이 페르마 소정리를 잘 조작해주자.

이것을 원래 구하려던 식에 대입하면 다음과 같다.

이제 위 식에서 A = n! , B = k!(n-k)! , p = 1,000,000,007 를 대입하여 계산해주면 된다.

 

근데 상식적으로 B의 1,000,000,007승은 너무 괴랄하다.

따라서 분할정복을 이용해서 거듭제곱을 해줘야 한다.

필자는 다음과 같은 방법으로 구하였다.

long long calc(int a, int b)
{
	if (b == 0) return 1;

	if (b % 2 == 0)
	{
		long long tmp = calc(a, b / 2) % mod;
		return (tmp * tmp) % mod;
	}
	else
	{
		return (calc(a, b - 1) * a) % mod;
	}
}

 

 

[전체 소스 코드]

#include <iostream>
#include <algorithm>
#include <queue>
#include <vector>
#include <map>

#define INF 987654321
#define mod 1000000007
#define pii pair<int,int>

using namespace std;

int N, K;

long long calc(int a, int b)
{
	if (b == 0) return 1;

	if (b % 2 == 0)
	{
		long long tmp = calc(a, b / 2) % mod;
		return (tmp * tmp) % mod;
	}
	else
	{
		return (calc(a, b - 1) * a) % mod;
	}
}

int main()
{
	ios::sync_with_stdio(false);
	cin.tie(NULL);
	cout.tie(NULL);

	cin >> N >> K;

	long long A = 1, B = 1;

	//n!
	for (long long i = 1; i <= N; i++) { A *= i; A %= mod; }
    	//k!(n-k)!
	for (long long i = 1; i <= K; i++) { B *= i; B %= mod; }
	for (long long i = 1; i <= N-K; i++) { B *= i; B %= mod; }

	long long tmp = calc(B, mod - 2)%mod;
	cout << (A * tmp) % mod;
}

그래도 언젠가 이 방법을 한번쯤은 써먹지 않을까.....?