loj2541题解

这种求某个东西在最后一个的概率之类的其实可以考虑容斥,设钦定一个集合\(S\)在它后面的概率为\(f(S)\),那么最后的答案为\(\sum (-1)^{|S|}f(S)\).

现在考虑怎么算\(f(S)\),一种感性的方法是,你可以认为其他人就没有关系了,那么\(f(S)\)即为第一个人第一个死的概率,即\(\frac{w _1}{w _1+\text{sum}(S)}\).

要严谨地证明的话,可以改变一下游戏规则:死去的人不把他踢出去,这样不会改变每个人的死亡顺序,记\(W=\sum _{i=1}^n w _i\),那么\(f(S)=\sum _{i=0}^{+\infty}\left(\frac{W-w _1-\text{sum}(S)}{W}\right)^i\frac{w _1}{W}=\frac{w _1}{w _1+\text{sum}(S)}\).

于是答案就是\(\sum \frac{(-1)^{|S|}}{w _1+\text{sum}(S)}\),注意到\(w _i\)加起来很小,我们可以统计每个\(\text{sum}(S)\)的贡献,这个只需要计算\(\prod _{i=2}^n(1-z^{w _i})\)就可以了.

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ele int
#define ll long long
using namespace std;
static const ele maxn=(1<<18)+1;
#define MOD 998244353
#define G 3
inline ele& add(ele &a,ele b){
a+=b;
return a>=MOD?a-=MOD:a;
}
inline ele pw(ele a,ele x){
ele ans=1,tmp=a%MOD;
for (; x; x>>=1,tmp=(ll)tmp*tmp%MOD)
if (x&1) ans=(ll)ans*tmp%MOD;
return ans;
}
ele n,w[maxn],a[maxn];
inline void ntt(ele K,ele n,ele *y){
static ele f[maxn];
f[0]=0;
for (int i=1; i<n; ++i){
f[i]=f[i>>1]>>1;
if (i&1) f[i]+=n>>1;
if (i<f[i]) swap(y[i],y[f[i]]);
}
for (int p=1; p<n; p<<=1){
ele o=pw(G,(MOD-1)/p/2); o=~K?o:pw(o,MOD-2);
for (int i=0; i<n; i+=(p<<1)){
ele o1=1;
for (int j=i; j<i+p; ++j,o1=(ll)o1*o%MOD){
ele u=y[j],v=(ll)y[j+p]*o1%MOD;
y[j]=y[j+p]=u;
add(y[j],v); add(y[j+p],MOD-v);
}
}
}
if (!~K){
ele invn=pw(n,MOD-2);
for (int i=0; i<n; ++i) y[i]=(ll)y[i]*invn%MOD;
}
}
ele solve(ele *a,ele l,ele r){
if (l==r){
a[0]=1;
a[w[l]]=MOD-1;
for (int i=1; i<w[l]; ++i) a[i]=0;
return w[l];
}
ele mid=(l+r)>>1;
ele s1=solve(a,l,mid);
ele s2=solve(a+s1+1,mid+1,r);
ele tmp=1; while (tmp<=s1+s2) tmp<<=1;
static ele t1[maxn],t2[maxn];
memset(t1,0,sizeof(ele)*tmp); memcpy(t1,a,sizeof(ele)*(s1+1));
memset(t2,0,sizeof(ele)*tmp); memcpy(t2,a+s1+1,sizeof(ele)*(s2+1));
ntt(1,tmp,t1); ntt(1,tmp,t2);
for (int i=0; i<tmp; ++i) t1[i]=(ll)t1[i]*t2[i]%MOD;
ntt(-1,tmp,t1);
memcpy(a,t1,sizeof(ele)*(s1+s2+1));
return s1+s2;
}
int main(){
scanf("%d",&n);
for (int i=0; i<n; ++i) scanf("%d",w+i);
if (n>1){
ele s=solve(a,1,n-1);
ele ans=0;
for (int i=0; i<=s; ++i) add(ans,(ll)w[0]*a[i]%MOD*pw(w[0]+i,MOD-2)%MOD);
printf("%d\n",ans);
}
else puts("1");
return 0;
}