피보나치 수는 이전 $2$개 항을 더한 값이 지금 항이 되는 수열입니다. $n$번째 피보나치 수의 정의는 다음과 같습니다. $f_n = f_{n-1} + f_{n-2} \ (n \geq 2), \ f_0 = 0, f_1 = 1$ 이러한 피보나치 수를 구하는 알고리즘은 여러가지가 있습니다. 가장 대표적으로는 피보나치 수의 정의 그대로 재귀함수를 구현하여 구하는 시간복잡도 $\mathcal{O}(2^n)$이 걸리는 알고리즘이 있습니다. 이에 대한 간단한 구현은 아래 코드에 나와있습니다.
def fibo(n):
if n == 0: return 0
if n == 1: return 1
return fibo(n-1) + fibo(n-2)
하지만 이런식으로 구하면 $n$이 $40$정도만 되어도 $2^{40} \approx 10^{12}$정도로 구하는데 시간이 매우 오래 걸립니다. 이러한 문제점을 해결하기 위해서 우리는 $memoization$을 사용해 줄 수 있습니다. 재귀 함수 내에서의 중복되는 연산을 메모리를 사용해 저장하여, 중복되는 연산을 줄이는 기법입니다. 이러면 시간복잡도를 $\mathcal{O}(n)$까지 줄일 수 있습니다. 간단한 구현은 아래 코드에 나와있습니다.
dp = [-1]*10000 #10000 크기의 memoization을 할 배열
def fibo(n):
if n == 0:
dp[n] = 0
return dp[n]
if n == 1:
dp[n] = 1
return dp[n]
if dp[n] != -1: return dp[n]
dp[n] = fibo(n-1) + fibo(n-2)
return dp[n]
그러나 이보다 더욱 빠른 속도를 요구하는 문제들이 있습니다.
이 문제의 $n$의 제한은 $10^{18}$으로, $\mathcal{O}(n)$의 시간복잡도로 풀기에는 터무니없이 부족합니다. 어떻게 해야할까요? 답은 행렬을 통한 최적화입니다. 갑자기 행렬이 뜬금없이 튀어나와서 이상하지만, 차근차근 살펴봅시다.
피보나치 수의 점화식은 $f_n = f_{n-1} + f_{n-2}$입니다. 이를 조금 변형해서 행렬로 표현해 봅시다.
$\begin{pmatrix} f_n \\ f_{n-1}\end{pmatrix} = \begin{pmatrix}1 & 1\\1 & 0 \end{pmatrix} \begin{pmatrix} f_{n-1} \\ f_{n-2} \end{pmatrix}$ 이 식처럼 행렬로 표현할 수 있다는 사실을 알 수 있습니다.
이제 이 식을 조금 관찰해봅시다.
$\begin{pmatrix} f_{2} \\ f_{1} \end{pmatrix} = \begin{pmatrix}1 & 1\\1 & 0 \end{pmatrix} \begin{pmatrix} f_{1} \\ f_{0} \end{pmatrix}, \begin{pmatrix} f_{3} \\ f_{2} \end{pmatrix} = \begin{pmatrix}1 & 1\\1 & 0 \end{pmatrix} \begin{pmatrix} f_{2} \\ f_{1} \end{pmatrix} = {\begin{pmatrix}1 & 1\\1 & 0 \end{pmatrix}}^2 \begin{pmatrix} f_{1} \\ f_{0} \end{pmatrix}$
$\cdots
\begin{pmatrix} f_{n} \\ f_{n-1} \end{pmatrix} =
{\begin{pmatrix}1 & 1\\1 & 0 \end{pmatrix}}^{n-1}
\begin{pmatrix} f_{1} \\ f_{0} \end{pmatrix}$이 성립함을 알 수 있습니다.
우리는 $f_1$과 $f_0$의 값을 알고 있으므로 ${\begin{pmatrix}1 & 1\\1 & 0 \end{pmatrix}}^{n-1}$ 의 값만 빠르게 구해주면 $f_n$의 값을 구할 수 있습니다. 이는 분할 정복을 이용한 거듭제곱을 통해 $\mathcal{O}(logn)$에 구해줄 수 있습니다. 이 방법을 사용하면 우리는 $n$번째 피보나치 수를 $\mathcal{O}(logn)$시간에 구할 수 있습니다. 아래는 위 문제의 정답 코드입니다.
import sys
sys.setrecursionlimit(10000000)
mod = 1000000007
def square_matrix(matrix1,matrix2):
result = [[0]*2 for _ in range(2)]
for i in range(2):
for j in range(2):
for k in range(2):
result[i][j] += (matrix1[i][k]%mod*matrix2[k][j]%mod)%mod
return result
def solve(cnt):
if cnt == 1:
return matrix
m = solve(cnt//2)
if cnt&1:
return square_matrix(square_matrix(m,m),matrix)
else:
return square_matrix(m,m)
matrix = [[1,1],[1,0]]
n = int(input())
if n == 1:
print(1)
exit()
print(solve(n-1)[0][0]%mod)