整理的算法模板合集: ACM模板
实际上是一个全新的精炼模板整合计划
Weblink
https://www.luogu.com.cn/problem/P4463
Problem
k ≤ 1 0 9 , n ≤ 500 k≤10 ^9 ,n≤500 k≤109,n≤500, p ≤ 1 0 9 p \le 10^9 p≤109,并且 p p p 为素数, p > k > n + 1 p>k>n+1 p>k>n+1。
Solution
显然对于一种取值的合法序列,这个序列不管怎么排列,合法序列的值都一样的,我们先考虑暴力计算,设 d p ( i , j ) dp(i,j) dp(i,j) 表示前 i i i 个数取值域范围 [ 1 , j ] [1,j] [1,j] 的所有取值不同的合法序列的值之和。直接转移很不方便,我们可以只考虑递增的序列,即我们仅需讨论第 i i i 个数取还是不取 j j j
即:
d p [ i ] [ j ] = j ∗ d p [ i − 1 ] [ j − 1 ] + d p [ i ] [ j − 1 ] dp[i][j] = j * dp[i - 1][j - 1] + dp[i][j - 1] dp[i][j]=j∗dp[i−1][j−1]+dp[i][j−1]
显然答案就是所有取值不同的合法序列的值之和乘上排列的方案数 n ! n! n!。答案就是 d p [ n ] [ k ] × n ! dp[n][k]\times n! dp[n][k]×n!,但是 k ≤ 1 e 9 k\le 1e9 k≤1e9,考虑优化。
一个DP的递推式可以看作是一个多项式,多项式 f n ( i ) f_n(i) fn(i) 就是 d p [ n ] [ i ] dp[n][i] dp[n][i],那么答案就是 d p [ n ] [ k ] = f n ( k ) dp[n][k] = f_{n}(k) dp[n][k]=fn(k)
代入递推式得:
f i ( j ) − f i ( j − 1 ) = j ∗ f i − 1 ( j − 1 ) f_{i}(j)-f_{i}(j - 1)=j*f_{i - 1}(j - 1) fi(j)−fi(j−1)=j∗fi−1(j−1)
设 f i ( j ) f_i(j) fi(j) 是 g ( n ) g(n) g(n) 次多项式
前面是一个差分的形式,显然有结论:
-
两个 n n n 次多项式的差分是一个 n − 1 n-1 n−1 次多项式
-
两个 n n n 次多项式的前缀和是一个 n + 1 n+1 n+1 次多项式
自己代入展开算一下就知道了
则 f i ( j ) − f i ( j − 1 ) f_i(j)-f_i(j - 1) fi(j)−fi(j−1) 是 g ( n ) − 1 g(n)-1 g(n)−1 次多项式, j ∗ f i − 1 ( j − 1 ) j*f_{i - 1}(j - 1) j∗fi−1(j−1) 是一个 g ( n − 1 ) + 1 g(n-1)+1 g(n−1)+1 次多项式(乘上了一个 j j j 嘛),即:
g ( n ) − 1 = g ( n − 1 ) + 1 , g ( n ) = g ( n − 1 ) + 2 g(n)-1=g(n-1)+1,g(n)=g(n-1)+2 g(n)−1=g(n−1)+1,g(n)=g(n−1)+2
显然 g ( 0 ) = 0 g(0)=0 g(0)=0,则 g ( n ) = 2 × n g(n)=2\times n g(n)=2×n。也就意味着我们只需要计算出 f f f 的前 2 × n + 1 2\times n+1 2×n+1 项的值,就可以唯一确定一个 2 × n 2\times n 2×n 项的多项式,并且因为我们得到的 2 × n + 1 2\times n+1 2×n+1 项的值中的 x x x 还是连续的,也就意味着我们可以用拉格朗日插值法 ,在 O ( 2 ∗ n ) / O ( n 2 ) O(2*n) / O(n^2) O(2∗n)/O(n2) 的复杂度下直接求出 f n ( k ) f_n(k) fn(k),既是所求的答案。由于需要 O ( n 2 ) O(n^2) O(n2) 预处理 d p dp dp 数组,所以复杂度为 O ( n 2 ) O(n^2) O(n2)。
当然本题还有生成函数的做法,利用多项式科技可以做到 O ( n l o g n ) O(nlogn) O(nlogn) :P5850 calc加强版。
Time
O ( n 2 ) O(n^2) O(n2)
Code
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 5007;
int n, m, k, mod;
int dp[N][N];
int qpow(int a, int b)
{
int res = 1;
while(b) {
if(b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
int inv(int x)
{
return qpow(x, mod - 2);
}
signed main()
{
scanf("%lld%lld%lld", &k, &n, &mod);
int m = 2 * n + 1;
for(int i = 0; i <= m; ++ i)
dp[0][i] = 1;
for(int i = 1; i <= n; ++ i) {
for(int j = 1; j <= m; ++ j) {
dp[i][j] = (dp[i][j - 1] + dp[i - 1][j - 1] * j % mod) % mod;
}
}
int ans = 0, fact = 1;
for(int i = 1; i <= n; ++ i)
fact = fact * i % mod;
if(k <= m) {
ans = dp[n][k];
printf("%lld\n", ans * fact % mod);
return 0;
}
for(int i = 1; i <= m; ++ i) {
int up = dp[n][i], down = 1;
for(int j = 1; j <= m; ++ j) {
if(i != j) {
up = up * (k - j + mod) % mod;
down = down * (i - j + mod) % mod;
}
}
ans = (ans + up * inv(down) % mod) % mod;
}
printf("%lld\n", ans * fact % mod);
}