项链 FZU2289

520要到了,Xenon打算送给Cherry一条项链,这条项链有m个镶孔,编号分别为0到m-1。Xenon手上有n种颜色不同的钻石,他想将钻石镶嵌在项链上,而且每个相邻的镶孔,镶嵌上的钻石颜色要不一样。
Xenon想知道他可以镶嵌出多少种不同的项链
PS:第k号镶孔和第(k-1+m)%m以及(k+1+m)%m号镶孔这两个镶孔是相邻关系。假设m=9,那么0号和1号、8号相邻,3号和2号、4号相邻。
两条项链不同当且仅当存在编号k,两条项链的k号镶孔的钻石颜色不一样。
Input
题目包含多组测试数据,每组测试数据包含两个正整数n和m,分别表示钻石的颜色种数和一条项链所需要的钻石个数,以空格隔开。
n≤50,m≤10^18
Output
输出一个整数,表示Xenon可以制作出多少种项链。

由于答案很大,请将答案对1000000007(10^9+7)取余。

思路:这就是一个数学模型,简化以后发现这样一个递推式,在m>3的时候成立,a[m]=(n-2)*a[n-1]+(n-1)*a[n-2];这时候就只能用矩阵加速了,还是熟悉的矩阵,熟悉的味道(只不过没有改快速幂的参数,T了)((%队友))

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int maxn=1e5+6;
typedef long long LL;
const LL mod=1e9+7;
struct mat{
    LL a[30][30];
    int r,c;
    mat operator *(const mat &b) const  {
        mat ret;
        for(int i=0;i<r;i++) {
            for(int j=0;j<b.c;j++) {
                ret.a[i][j]=0;
                for(int k=0;k<c;k++) {
                    ret.a[i][j]+=a[i][k]*b.a[k][j],ret.a[i][j]%=mod;
                }
            }
        }
        ret.r=r;
        ret.c=b.c;
        return ret;
    }
    mat init_unit(int x) {
        r=c=x;
        for(int i=0;i<r;i++) {
            for(int j=0;j<c;j++) {
                if(i==j)a[i][j]=1;
                else a[i][j]=0;
            }
        }
    }
}unit;
mat pow(mat p,LL n) {
    unit.init_unit(3);
    mat ans=unit;
    while(n) {
        if(n&1)ans=p*ans;
        p=p*p;
        n>>=1;
    }
    return ans;
}
int main() {
    LL n,m;
    while(cin>>n>>m) {
        mat A;
        if(m<=3) {
            if(m==1) printf("%lld\n",n%mod);
            else if(m==2) printf("%lld\n",n*(n-1)%mod);
            else if(m==3) printf("%lld\n",(n*(n-1)%mod)*(n-2)%mod);
            else printf("%d\n",0);
            continue;
        }
        A.r=A.c=2;
        A.a[0][0]=(n-2)%mod;
        A.a[0][1]=1;
        A.a[1][0]=(n-1)%mod;
        A.a[1][1]=0;
        mat ans;
        ans=pow(A,m-3);
        mat tmp;
        tmp.a[0][0]=n*(n-1)*(n-2)%mod;
        tmp.a[0][1]=n*(n-1)%mod;
        ans=tmp*ans;
        printf("%lld\n",ans.a[0][0]%mod);
    }
}