[whj什么都不会系列-2]bzoj4734题解

bzoj上的题面真是残缺不全……uoj上有完整的题面.

退役久了脑子都不好用了,这么简单的东西推了半天……

看到这种求和题很自然地就会想到拆成卷积的形式:

\[Q(f,n,x)=n!\sum _{k=0}^n\frac{f(k)x^k}{k!}\cdot\frac{(1-x)^{n-k}}{(n-k)!}\]

\[\begin{aligned} g(z)&=\sum\frac{f(k)x^k}{k!}z^k\\ h(z)&=\sum\frac{(1-x)^k}{k!}z^k \end{aligned}\]

那么答案就是\(n![z^n]g(z)h(z)\).

很容易看出\(h(z)=e^{(1-x)z}\),关键是怎么表示\(g(z)\).

这里要用到一个结论:设\(A(z)=\sum a _kz^k\),记\(P _L(z)=\sum\left\{\begin{matrix}L\\k\end{matrix}\right\}z^k\),那么有

\[A(z)P _L(z)=\sum a _kk^Lz^k\]

证明的话,对\(L\)归纳就可以了. 感觉这个结论在很多时候都挺有用的.

\(f(x)=\sum c _ix^i\),那么就有

\[\begin{aligned} g(z)&=\sum _k\left(\frac{(xz)^k}{k!}\sum _ic _ik^i\right)\\ &=\sum _ic _i\sum _k\frac{(xz)^kk^i}{k!}\\ &=\sum _ic _ie^{xz}P _i(xz) \end{aligned}\]

那么答案就是\(n![z^n]e^z\sum _{i=0}^mc _iP _i(xz)\)

\(t(z)=\sum _{i=0}^mc _iP _i(xz)\),注意到\(\deg t=m\),所以如果我们能把\(t(z)\)求出来的话,剩下就只需要做一个长度为\(m\)的卷积了(而不是题目式子里长为\(n\)的卷积).

为了求出\(t(z)\),考虑到斯特林数没有什么太好的性质,我们需要给它乘回一个\(e^{xz}\). 我们知道\([z^k]t(z)e^{xz}=\frac{f(k)x^k}{k!}\),所以要求\(t(z)\)的话我们可以把\(\sum _k\frac{f(k)x^k}{k!}z^k\)\(e^{-xz}\)的前\(m+1\)项做一个卷积.

这道题,总的来说,这么一大通的变换,主要目的就是分理出一个长度为\(\mathcal O(m)\)的多项式,这样一切就都好处理了.

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ele int
#define ll long long
using namespace std;
const ele maxn=1<<16|1;
const ele MOD=998244353;
const ele g=3;
ele n,m,x,f[maxn],fac[maxn],ifac[maxn],a[maxn],b[maxn];
inline ele& add(ele&a,ele b){
a+=b;
return a>=MOD?a-=MOD:a;
}
inline ele pw(ele a,ele x){
ele ans=1,tmp=a%MOD;
for (; x; x>>=1,tmp=(ll)tmp*tmp%MOD)
if (x&1) ans=(ll)ans*tmp%MOD;
return ans;
}
inline void ntt(ele K,ele n,ele *y){
static ele f[maxn];
f[0]=0;
for (int i=1; i<n; ++i){
f[i]=f[i>>1]>>1;
if (i&1) f[i]+=n>>1;
if (i<f[i]) swap(y[i],y[f[i]]);
}
for (int p=1; p<n; p<<=1){
ele o=pw(g,(MOD-1)/p/2);
o=~K?o:pw(o,MOD-2);
for (int i=0; i<n; i+=(p<<1))
for (int j=i,o1=1; j<i+p; ++j,o1=(ll)o1*o%MOD){
ele u=y[j],v=(ll)y[j+p]*o1%MOD;
y[j]=y[j+p]=u;
add(y[j],v); add(y[j+p],MOD-v);
}
}
if (!~K){
ele invn=pw(n,MOD-2);
for (int i=0; i<n; ++i) y[i]=(ll)y[i]*invn%MOD;
}
}
int main(){
scanf("%d%d%d",&n,&m,&x);
for (int i=0; i<=m; ++i) scanf("%d",f+i);
fac[0]=1;
for (int i=1; i<=m; ++i) fac[i]=(ll)fac[i-1]*i%MOD;
ifac[m]=pw(fac[m],MOD-2);
for (int i=m-1; ~i; --i) ifac[i]=(ll)ifac[i+1]*(i+1)%MOD;
ele tmp=1; while (tmp<=m+m) tmp<<=1;
memset(a,0,sizeof(ele)*tmp);
memset(b,0,sizeof(ele)*tmp);
for (int i=0,p=1; i<=m; ++i,p=(ll)p*x%MOD){
a[i]=(ll)f[i]*p%MOD*ifac[i]%MOD;
b[i]=(ll)p*ifac[i]%MOD;
if (i&1) b[i]=MOD-b[i];
}
ntt(1,tmp,a); ntt(1,tmp,b);
for (int i=0; i<tmp; ++i) a[i]=(ll)a[i]*b[i]%MOD;
ntt(-1,tmp,a);
ele ans=0;
for (int i=0,p=1; i<=m; p=(ll)p*(n-i)%MOD,++i) add(ans,(ll)a[i]*p%MOD);
printf("%d\n",ans);
return 0;
}