uoj50题解

分治fft是非常明显的做法,不过这样是\(\mathcal O(n\log^2n)\)的. 把生成函数弄出来之后会发现它是个微分方程,可以有一些神奇的方法来解,具体可以看UR3的题解.

我比较菜所以还是决定写分治+卡常,发现了一个卡常技巧. 对\([l,r)\)区间分治的时候,设\(m=\left\lfloor\frac{l+r}{2}\right\rfloor\),如果\(l\neq0\),就要把\(C(z)\)\(F(z)\)\([l,m)\)的部分,\(F(z)\)\([0,r-l)\)的部分卷起来,这里fft的长度看似要开到\(4(r-l)\),但事实上只需要开到\(2(r-l)\),因为超出\(2(r-l)\)的部分小于\(2(r-l)+(m-l)\),这样就算循环到前面去,也会小于\(m-l\),而对\([m,r)\)的贡献是从\(m-l\)开始的,所以不影响答案. 而\(l=0\)的时候fft长度显然也可以只开到\(2(r-l)\). 这样一来可以显著减小常数. 跑得比网上搜到的倍增还快!

另外以后在我学会倍增解微分方程之前,看到微分方程不会解可以考虑分治.

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ele int
#define ll long long
using namespace std;
const ele maxn=(1<<20)+1;
const ele MOD=998244353;
const ele G=3;
const ele inv2=(MOD+1)/2;
inline ele& add(ele&a,ele b){
a+=b;
return a>=MOD?a-=MOD:a;
}
inline ele pw(ele a,ele x){
ele ans=1,tmp=a%MOD;
for (; x; x>>=1,tmp=(ll)tmp*tmp%MOD)
if (x&1) ans=(ll)ans*tmp%MOD;
return ans;
}
ele n,f[maxn],g[maxn],c[maxn],inv[maxn],fac[maxn],ifac[maxn];
char s[maxn];
inline void ntt(ele K,ele n,ele *y){
static ele f[maxn];
f[0]=0;
for (int i=1; i<n; ++i){
f[i]=f[i>>1]>>1;
if (i&1) f[i]+=n>>1;
if (i<f[i]) swap(y[i],y[f[i]]);
}
for (int p=1; p<n; p<<=1){
ele o=pw(G,(MOD-1)/p/2); o=~K?o:pw(o,MOD-2);
for (int i=0; i<n; i+=(p<<1)){
ele o1=1;
for (int j=i; j<i+p; ++j,o1=(ll)o1*o%MOD){
ele v=(ll)y[j+p]*o1%MOD;
y[j+p]=y[j];
add(y[j],v); add(y[j+p],MOD-v);
}
}
}
if (!~K){
ele invn=pw(n,MOD-2);
for (int i=0; i<n; ++i) y[i]=(ll)y[i]*invn%MOD;
}
}
void solve(ele l,ele r){
if (r-l<=1){
f[l+1]=(ll)g[l]*inv[l+1]%MOD;
return;
}
ele mid=(l+r)>>1;
static ele t1[maxn],t2[maxn],t3[maxn];
solve(l,mid);
if (l){
ele tmp=(r-l)<<1;
memset(t1,0,sizeof(ele)*tmp); memcpy(t1,c,sizeof(ele)*(r-l));
memset(t2,0,sizeof(ele)*tmp); memcpy(t2,f+l,sizeof(ele)*(mid-l));
memset(t3,0,sizeof(ele)*tmp); memcpy(t3,f,sizeof(ele)*(r-l));
ntt(1,tmp,t1); ntt(1,tmp,t2); ntt(1,tmp,t3);
for (int i=0; i<tmp; ++i) t1[i]=(ll)t1[i]*t2[i]%MOD*t3[i]%MOD;
ntt(-1,tmp,t1);
for (int i=mid; i<r; ++i) add(g[i],t1[i-l]);
}
else{
ele tmp=(r-l)<<1;
memset(t1,0,sizeof(ele)*tmp); memcpy(t1,c,sizeof(ele)*(r-l));
memset(t2,0,sizeof(ele)*tmp); memcpy(t2,f+l,sizeof(ele)*(mid-l));
ntt(1,tmp,t1); ntt(1,tmp,t2);
for (int i=0; i<tmp; ++i){
t1[i]=(ll)t1[i]*t2[i]%MOD*t2[i]%MOD;
t1[i]=(t1[i]&1)?(t1[i]+MOD)>>1:t1[i]>>1;
}
ntt(-1,tmp,t1);
for (int i=mid; i<r; ++i) add(g[i],t1[i-l]);
}
solve(mid,r);
}
int main(){
scanf("%d%s",&n,s);
ele tmp=1;
while (tmp<n) tmp<<=1;
fac[0]=1;
for (int i=1; i<=tmp; ++i) fac[i]=(ll)fac[i-1]*i%MOD;
ifac[tmp]=pw(fac[tmp],MOD-2);
for (int i=tmp-1; ~i; --i) ifac[i]=(ll)ifac[i+1]*(i+1)%MOD;
for (int i=1; i<=tmp; ++i) inv[i]=(ll)fac[i-1]*ifac[i]%MOD;
for (int i=0; i<n; ++i){
c[i]=s[i]-'0';
c[i]=c[i]*ifac[i];
}
g[0]=1;
solve(0,tmp);
for (int i=1; i<=n; ++i)
printf("%lld\n",(ll)f[i]*fac[i]%MOD);
return 0;
}