[whj什么都不会系列-2]bzoj4734题解

bzoj上的题面真是残缺不全……uoj上有完整的题面.

退役久了脑子都不好用了,这么简单的东西推了半天……

看到这种求和题很自然地就会想到拆成卷积的形式:

那么答案就是$n![z^n]g(z)h(z)$.

很容易看出$h(z)=e^{(1-x)z}$,关键是怎么表示$g(z)$.

这里要用到一个结论:设$A(z)=\sum a _kz^k$,记$P _L(z)=\sum\left\{\begin{matrix}L\\k\end{matrix}\right\}z^k$,那么有

证明的话,对$L$归纳就可以了. 感觉这个结论在很多时候都挺有用的.

设$f(x)=\sum c _ix^i$,那么就有

那么答案就是$n![z^n]e^z\sum _{i=0}^mc _iP _i(xz)$

记$t(z)=\sum _{i=0}^mc _iP _i(xz)$,注意到$\deg t=m$,所以如果我们能把$t(z)$求出来的话,剩下就只需要做一个长度为$m$的卷积了(而不是题目式子里长为$n$的卷积).

为了求出$t(z)$,考虑到斯特林数没有什么太好的性质,我们需要给它乘回一个$e^{xz}$. 我们知道$[z^k]t(z)e^{xz}=\frac{f(k)x^k}{k!}$,所以要求$t(z)$的话我们可以把$\sum _k\frac{f(k)x^k}{k!}z^k$和$e^{-xz}$的前$m+1$项做一个卷积.

这道题,总的来说,这么一大通的变换,主要目的就是分离出一个长度为$\mathcal O(m)$的多项式,这样一切就都好处理了.

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ele int
#define ll long long
using namespace std;
const ele maxn=1<<16|1;
const ele MOD=998244353;
const ele g=3;
ele n,m,x,f[maxn],fac[maxn],ifac[maxn],a[maxn],b[maxn];
inline ele& add(ele&a,ele b){
a+=b;
return a>=MOD?a-=MOD:a;
}
inline ele pw(ele a,ele x){
ele ans=1,tmp=a%MOD;
for (; x; x>>=1,tmp=(ll)tmp*tmp%MOD)
if (x&1) ans=(ll)ans*tmp%MOD;
return ans;
}
inline void ntt(ele K,ele n,ele *y){
static ele f[maxn];
f[0]=0;
for (int i=1; i<n; ++i){
f[i]=f[i>>1]>>1;
if (i&1) f[i]+=n>>1;
if (i<f[i]) swap(y[i],y[f[i]]);
}
for (int p=1; p<n; p<<=1){
ele o=pw(g,(MOD-1)/p/2);
o=~K?o:pw(o,MOD-2);
for (int i=0; i<n; i+=(p<<1))
for (int j=i,o1=1; j<i+p; ++j,o1=(ll)o1*o%MOD){
ele u=y[j],v=(ll)y[j+p]*o1%MOD;
y[j]=y[j+p]=u;
add(y[j],v); add(y[j+p],MOD-v);
}
}
if (!~K){
ele invn=pw(n,MOD-2);
for (int i=0; i<n; ++i) y[i]=(ll)y[i]*invn%MOD;
}
}
int main(){
scanf("%d%d%d",&n,&m,&x);
for (int i=0; i<=m; ++i) scanf("%d",f+i);
fac[0]=1;
for (int i=1; i<=m; ++i) fac[i]=(ll)fac[i-1]*i%MOD;
ifac[m]=pw(fac[m],MOD-2);
for (int i=m-1; ~i; --i) ifac[i]=(ll)ifac[i+1]*(i+1)%MOD;
ele tmp=1; while (tmp<=m+m) tmp<<=1;
memset(a,0,sizeof(ele)*tmp);
memset(b,0,sizeof(ele)*tmp);
for (int i=0,p=1; i<=m; ++i,p=(ll)p*x%MOD){
a[i]=(ll)f[i]*p%MOD*ifac[i]%MOD;
b[i]=(ll)p*ifac[i]%MOD;
if (i&1) b[i]=MOD-b[i];
}
ntt(1,tmp,a); ntt(1,tmp,b);
for (int i=0; i<tmp; ++i) a[i]=(ll)a[i]*b[i]%MOD;
ntt(-1,tmp,a);
ele ans=0;
for (int i=0,p=1; i<=m; p=(ll)p*(n-i)%MOD,++i) add(ans,(ll)a[i]*p%MOD);
printf("%d\n",ans);
return 0;
}