[경쟁 프로] 반복 제곱법【Java】【Python】

오랜만에 경쟁 프로 기사입니다.

이번은 반복 제곱법에 대해입니다. Java로 작성해 보겠습니다. 또, 반복 제곱법을 사용한 문제를 Java, Python으로 풀어 보겠습니다.

반복 제곱법이란?



승제의 계산량을 줄이는 기법입니다. 다음과 같이 지수를 2의 승표기를 하여 누승 계산을 합니다.
N승의 계산이 O(N)에서 O(logN)가 됩니다.

계산 방법



3^10을 구합니다.
10 = 2^3 + 2^1로 표현할 수 있기 때문에,
3^10 = 3^(2^3 + 2^1) = 3^(2^3) * 3^(2^1)로 표현할 수 있습니다.

깔끔하게 쓰면 다음과 같습니다.
3^{10} = 3^{2^3} * 3^{2^1}

"그건 알지만 왜 그렇게 생각하면 계산량이 줄어들어?"라는 목소리가 들려올 것 같습니다.
조금 정중하게 설명하네요.
이진수 비트 연산을 사용하고 있습니다.
이런 느낌입니다 ↓



계산 과정은 다음과 같습니다.
tmp ... 임시 변수
ans・・・3^10의 해가 들어가는 변수


#
지금 참조하는 비트
설명


1
최하위 비트
bit가 0이므로, ans의 갱신은 하지 않는다(ans=1)tmp에 tmp를 걸어 9로 한다(3*3=9)

2
아래에서 두 번째
비트가 1이므로 ans*=tmp로 한다(ans=1*9=9) tmp에 tmp를 곱하여 81로 한다(9*9=81)

3
아래에서 세 번째
bit가 0이므로, ans의 갱신은 하지 않는다(ans=9)tmp에 tmp를 걸어 6561로 한다(81*81=6561)

4
아래에서 네 번째
비트가 1이므로, ans의 갱신을 한다(ans=9*6561=59049)


이런 느낌입니다.
tmp가 중요하네요. 처음은 3이었습니다만, 상위 bit를 참조해 갈 때마다 제곱씩 늘어나갑니다.
2진수의 자리수가 늘어나는 것처럼 tmp도 늘어나는 느낌입니다.

구현



그럼, 구현해 보겠습니다.
    public static long myPow(long a, long n) {
        // a^nを計算
        long ans = 1l;
        long tmp = a;

        // わかりやすくfor文の中ですべて処理
        for (;;) {
            // すべての桁を見終わったら終了
            if (n < 1l) {
                break;
            }

            // 最下位bitが1かどうかの判定
            if (n % 2l == 1l) {
                ans *= tmp;
            }

            // tmpの更新
            tmp *= tmp;

            // nのbitを一つずらす
            n = n >> 1;
        }
        return ans;
    }

네. 이런 느낌입니다.
myPow(3,10) 를 호출하면 확실히 59049가 돌아옵니다.

문제 연습



2021/04/17(토)의 AtCoder, 제2회 일본 최강 프로그래머 학생 선수권의 D문제, Nowhere P 를 풀어 봅니다.

문제는 링크와 같지만 대답은
(P-1)(P-2)^{n-1} mod 1000000007

을 계산하면 됩니다.
다만, n의 제약은 최고로 10^9로, 단순하게 하면 TLE가 될 가능성이 높습니다.
그래서 반복 제곱법의 등장입니다.

자바 답변 예


import java.util.Scanner;

public class Main {

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        long n = sc.nextLong();
        long p = sc.nextLong();
        long MOD = 1000000007l;

        long ans = (p - 1) * modPow(p - 2, n - 1, MOD) % MOD;

        System.out.println(ans);
    }

    public static long modPow(long a, long n, long mod) {
        // a^nを計算
        long ans = 1l;
        long tmp = a;

        // わかりやすくfor文の中ですべて処理
        for (;;) {
            // すべての桁を見終わったら終了
            if (n < 1l) {
                break;
            }

            // 最下位bitが1かどうかの判定
            if (n % 2l == 1l) {
                ans *= tmp;
                ans %= mod;
            }

            // tmpの更新
            tmp *= tmp;
            tmp %= mod;

            // nのbitを一つずらす
            n = n >> 1;
        }
        return ans;
    }

}


이전의 myPow 함수를 조금 바꾸어 modPow로 만들었습니다. 계산의 도중에 mod 취하고 있을 뿐이군요.
이제 문제가 해결되었습니다.

파이썬 답변 예제



덧붙여서, Python의 pow 함수는 반복 제곱법이 구현되어 있으므로 특히 신경쓰지 않고 구현이 가능합니다.
또, pow의 제3 인수에 값을 넣으면 modPow의 구현이 됩니다.

실장 예(공식 해설과 거의 같습니다만)
N, P = map(int, input().split())
MOD = 1000000007
ans = (P-1) * pow(P-2, N-1, MOD) % MOD
print(ans)

이런 느낌입니다.

반대로 경프로 이외에 반복 제곱법이나 modPow가 필요한 장면을 가르쳐 주었을 정도입니다만・・・.

비교적 간단한 알고리즘이었기 때문에 기억해 두고 싶네요.

반복 제곱법을 소개했습니다. 이번 기사는 여기까지입니다.

좋은 웹페이지 즐겨찾기