[백준 10830번 C/C++] 행렬 제곱

 

https://www.acmicpc.net/problem/10830

 

10830번: 행렬 제곱

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

www.acmicpc.net


 

해결전략

 

분할 정복

long long mod_pow(long long base, long long exponent, long long mod)

 

base^exponent를 mod로 나눈 나머지를 반환한다.

  • 지수가 0인 경우: base^0 = 1 다. 따라서 1을 반환한다.
  • 지수가 홀수인 경우: base^(2 * k + 1) = (base^k) * (base^k) * base이므로 base를 한 번 곱하고 분할하여 나머지를 곱한다.
  • 지수가 짝수인 경우: base^(2 * k) = (base^k) * (base^k)이므로 분할하여 나머지를 곱한다.

mod_pow 함수를 호출해 계산을 수행하려면 main 함수에서 기본값, 지수 및 나머지를 전달한다.

분할 정복 알고리즘을 사용하여 (A^B) mod C를 계산한다.

 

 

행렬의 곱

 	for (int y = 0; y < arr1.size(); y++) {
        for (int x = 0; x < arr2[0].size(); x++) {
            for (int k = 0; k < arr2.size(); k++)
            {
                answer[y][x] += arr1[y][k] * arr2[k][x];
            }
        }
    }

 

유사 문제

 

2023.07.31 - [⭐ 코딩테스트/백준] - [백준 1629번 C/C++] 곱셈

 

2023.09.13 - [⭐ 코딩테스트/프로그래머스] - [프로그래머스 C++] 행렬의 곱셈


 

처음 시도한 코드 - 메모리 초과, 테스트 케이트 모두 통과 

 

#include <iostream>
#include <vector>
using namespace std;

int n;
long long b;
vector<vector<long long>> ori, v;

void Calculate(int cnt)
{
	if(cnt == b){
		return;
	}

	vector<vector<long long>> temp(n, vector<long long>(n, 0));
	for (int y = 0; y < n; y++) {
		for (int x = 0; x < n; x++) {
			for (int k = 0; k < n; k++) {
				temp[y][x] += (v[y][k] * ori[k][x]);
			}
			temp[y][x] %= 1000;
		}
	}
	v = temp;
	
	Calculate(cnt + 1);
}

int main()
{
	cin >> n >> b;
	ori.resize(n, vector<long long>(n));
	v.resize(n, vector<long long>(n));
	for(int i=0; i<n; i++){
		for(int j=0; j<n; j++){
			cin >> ori[i][j];
		}
	}
	v = ori;

	Calculate(1);

	for (int y = 0; y < n; y++) {
		for (int x = 0; x < n; x++) {
			cout << v[y][x] << " ";
		}
		cout << "\n";
	}

	return 0;
}

 


 

 

정답 코드

 

#include <iostream>
#include <vector>
using namespace std;

int n;
long long b;
vector<vector<long long>> ori;

// 두 행렬의 곱셈 연산을 수행하는 함수
vector<vector<long long>> Calculate(vector<vector<long long>> a, vector<vector<long long>> b)
{
	vector<vector<long long>> temp(n, vector<long long>(n, 0));
	for (int y = 0; y < n; y++) {
		for (int x = 0; x < n; x++) {
			for (int k = 0; k < n; k++) {
				temp[y][x] += (a[y][k] * b[k][x]);
			}
			temp[y][x] %= 1000;
		}
	}

	return temp;
}

vector<vector<long long>> Power(long long exponent)
{
	// 지수 = 1인 경우
	if (exponent == 1){
		// 원소의 값이 1000이상인 경우를 고려하여 %1000 연산
		vector<vector<long long>> tmp(n, vector<long long>(n));
		tmp = ori;
		for (int y = 0; y < n; y++) {
			for (int x = 0; x < n; x++) {
				if (tmp[y][x] >= 1000)
					tmp[y][x] %= 1000;
			}
		}
		return tmp;
	}

	// 지수가 홀수
	else if(exponent % 2 == 1){
		return Calculate(ori, Power(exponent - 1));
	}
	// 지수가 짝수
	else{
		vector<vector<long long>> tmp = Power(exponent / 2);
		return Calculate(tmp, tmp);
	}
}

int main()
{
	ios::sync_with_stdio(false);
	cin.tie(0); cout.tie(0);

	cin >> n >> b;
	ori.resize(n, vector<long long>(n));
	for(int i=0; i<n; i++){
		for(int j=0; j<n; j++){
			cin >> ori[i][j];
		}
	}

	vector<vector<long long>> result = Power(b);

	for (int y = 0; y < n; y++) {
		for (int x = 0; x < n; x++) {
			cout << result[y][x] << " ";
		}
		cout << "\n";
	}

	return 0;
}