思路:
每个结点维护一棵(动态开点的)权值线段树,下标 x 处的值表示该结点的权值恰好为 x 的概率。
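合并两棵线段树时先递归左儿子、再递归右儿子,因此已经遍历过的值都比当前值小。维护 prel 与 prer,分别记录合并前的左、右两棵树在这些更小值上的总概率,这样就能迅速算出当前值的新概率:设 p 为该结点取最大值的概率,若当前值 x 来自左树,其新概率为 P_L(x)·(p·prer + (1−p)·(1−prer)),来自右树同理(把 prer 换成 prel)。注意:当某一棵树在当前区间为空时,另一棵树在整个区间内的所有值都要乘上同一个系数,必须用乘法懒标记整体处理,而不能原样返回。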
用一次 DFS 自底向上合并子树,最后遍历一遍根结点的线段树统计答案。
能通过样例,但提交后全部 WA。
代码:
#include <bits/stdc++.h>
using namespace std;
// Widen every `int` in the rest of the file to 64 bits. This is why the entry
// point is written as `signed main()` (`int main` would become `long long main`),
// and why Mint can multiply two residues without overflow.
#define int long long
// Plain newline instead of std::endl: output is not flushed on every line.
#define endl '\n'
// Debug / competitive-programming shorthands below.
#define debug(x) cerr << #x << " = " << x << endl
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
// NOTE(review): expects an adjacency object named G with a member G; no such
// object is defined in the visible code.
#define gn(u, v) for (int v : G.G[u])
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define sz(x) (int)(x).size()
#define pii pair<int, int>
#define vi vector<int>
#define vpii vector<pii>
#define vvi vector<vi>
#define no cout << "NO" << endl
#define yes cout << "YES" << endl
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
#define tomin(x, y) ((x) = min((x), (y)))
#define tomax(x, y) ((x) = max((x), (y)))
#define ck(mask, i) (((mask) >> (i)) & 1)
#define pq priority_queue
#define FLG (cerr << "Alive!" << endl);
// Capacity of the per-vertex arrays; vertices are indexed 1..n, so this
// presumably assumes n < MAXN (not checked).
constexpr int MAXN = 3e5+5;
// Prime modulus used by qpow and all Mint arithmetic.
constexpr int MOD = 998244353;
// Computes x^y mod MOD by binary exponentiation.
// The base is first reduced into [0, MOD), so an unreduced or negative base is
// safe: afterwards x * x < MOD^2 fits in 64 bits (`int` is long long in this file).
// y is expected to be >= 0. A negative exponent returns 1 instead of spinning
// forever, because `y >>= 1` on -1 never reaches 0.
int qpow(int x, int y) {
    x %= MOD;
    if (x < 0)
        x += MOD;
    int res = 1;
    while (y > 0) {
        if (y & 1)
            res = res * x % MOD;
        x = x * x % MOD;
        y >>= 1;
    }
    return res;
}
// Residue modulo MOD. Every value built through the constructor or the operators
// below satisfies 0 <= x < MOD. Because `int` is widened to long long by the file's
// macro, the product of two residues (< MOD^2) fits without overflow.
struct Mint {
    int x;
    Mint() { x = 0; }
    // Reduces any integer, including a negative one, into [0, MOD).
    Mint(const int _x) {
        x = _x % MOD;
        if (x < 0)
            x += MOD;
    }
    friend Mint operator+(Mint x, Mint y) {
        int t = x.x + y.x;
        return (t >= MOD) ? (t - MOD) : t;
    }
    friend Mint operator-(Mint x, Mint y) {
        int t = x.x - y.x;
        return (t < 0) ? (t + MOD) : t;
    }
    friend Mint operator*(Mint x, Mint y) {
        return x.x * y.x % MOD;
    }
    // Division through Fermat's little theorem: y^(MOD-2) is the inverse of y for
    // y != 0 (MOD is prime). Dividing by 0 silently yields 0.
    friend Mint operator/(Mint x, Mint y) {
        return x.x * qpow(y.x, MOD - 2) % MOD;
    }
    // Exponentiation (not XOR). The base is taken by value so temporaries work too.
    friend Mint operator^(Mint x, int y) {
        return Mint(qpow(x.x, y));
    }
    friend Mint& operator+=(Mint& x, Mint y) {
        return x = x + y;
    }
    friend Mint& operator-=(Mint& x, Mint y) {
        return x = x - y;
    }
    friend Mint& operator*=(Mint& x, Mint y) {
        return x = x * y;
    }
    friend Mint& operator/=(Mint& x, Mint y) {
        return x = x / y;
    }
    friend Mint& operator^=(Mint& x, int y) {
        return x = x ^ y;
    }
    friend ostream& operator<<(ostream& o, Mint y) {
        o << y.x;
        return o;
    }
    // Reads an integer into y (by reference, so the caller sees it) and reduces it.
    friend istream& operator>>(istream& i, Mint& y) {
        int raw;
        if (i >> raw)
            y = Mint(raw);
        return i;
    }
    Mint& operator++() {
        x++;
        if (x >= MOD)
            x -= MOD;
        return *this;
    }
    // Postfix increment: returns the value held before incrementing.
    Mint operator++(signed) {
        Mint old = *this;
        ++*this;
        return old;
    }
    Mint& operator--() {
        x--;
        if (x < 0)
            x += MOD;
        return *this;
    }
    // Postfix decrement: returns the value held before decrementing.
    Mint operator--(signed) {
        Mint old = *this;
        --*this;
        return old;
    }
    friend bool operator==(const Mint& x, const Mint &y) {
        return x.x == y.x;
    }
    friend bool operator!=(const Mint& x, const Mint &y) {
        return x.x != y.x;
    }
};
// Vertex count, read in main and used by SegTree::init.
int n;
// s[x]: children of x in input order (the root, vertex 1, is stored under its
// parent value, read as s[0] for a 0 parent).
vi s[MAXN];
// v[u]: a leaf's weight, or for an internal vertex the value divided by 10000
// to obtain the merge probability (see solve).
int v[MAXN];
// Prefix probability masses of the left / right tree, shared with SegTree::merge.
Mint prel;
Mint prer;
// rt[u]: index of the root of u's value-domain segment tree.
int rt[MAXN];
// Final answer and the running rank counter used by SegTree::answer.
Mint ans{0};
Mint cnt{0};
// Dynamically allocated value-domain segment trees; all trees share the pool `t`.
// For a node k covering values [l, r]:
// - t[k].p is the probability mass of the owning vertex's weight lying in [l, r];
// - t[k].mul is a pending factor that still has to be applied to k's children.
// Index 0 is the shared null node: its p stays 0 and scale() never touches it.
// Indices 1..n are pre-allocated by init() so that solve() can use node u as the
// root of leaf u's tree.
struct SegTree {
    struct Node {
        // Child indices fit in 32 bits. `signed` keeps Node the same size it had
        // before the lazy tag was added (`int` is long long through the macro).
        signed lc = 0, rc = 0;
        Mint p;        // probability mass in this node's value range
        Mint mul = 1;  // lazy multiplier pending for lc / rc
    };
    vector<Node> t;
    // Appends the null node 0 plus one node per vertex 1..n.
    void init() {
        rep(i, 0, n) {
            t.pb(Node{});
        }
    }
    // Ensures k has a left child, allocating a fresh empty node if needed.
    void mkl(int k) {
        if (t[k].lc) return;
        t[k].lc = sz(t);
        t.pb(Node{});
    }
    // Ensures k has a right child, allocating a fresh empty node if needed.
    void mkr(int k) {
        if (t[k].rc) return;
        t[k].rc = sz(t);
        t.pb(Node{});
    }
    // Recomputes k's mass from its children (t[0].p is always 0).
    void pull(int k) {
        t[k].p = t[t[k].lc].p + t[t[k].rc].p;
    }
    // Multiplies every mass in k's subtree by m: k itself now, its children lazily.
    void scale(int k, Mint m) {
        if (!k) return;
        t[k].p *= m;
        t[k].mul *= m;
    }
    // Hands k's pending multiplier down to its children before descending.
    void pushdown(int k) {
        if (t[k].mul == 1) return;
        scale(t[k].lc, t[k].mul);
        scale(t[k].rc, t[k].mul);
        t[k].mul = 1;
    }
    // Adds probability w at value v inside the tree rooted at k (range [l, r]).
    void insert(int k, int l, int r, int v, Mint w) {
        if (l == r) {
            t[k].p += w;
            return;
        }
        pushdown(k);
        int md = (l + r) >> 1;
        if (v <= md) {
            mkl(k);
            insert(t[k].lc, l, md, v, w);
        } else {
            mkr(k);
            insert(t[k].rc, md + 1, r, v, w);
        }
        pull(k);
    }
    // Merges the distributions rooted at u (one child) and v (the other child) over
    // [l, r] into the parent's distribution. With probability p the parent takes
    // the larger value, otherwise the smaller.
    // Values are visited in increasing order. On entry, prel / prer hold the
    // ORIGINAL mass of u's / v's tree strictly below l; on exit they also include
    // [l, r]. The result reuses u's nodes (or v's when u is empty).
    int merge(int u, int v, int l, int r, Mint p) {
        if (!u && !v) return 0;
        if (!u || !v) {
            // Only one side has values in [l, r]. The other side's mass below any
            // x in this range is its whole prefix (and above x it is 1 - prefix),
            // so every value here is scaled by the same factor.
            Mint a = t[u].p;  // t[0].p == 0, so the empty side contributes 0
            Mint b = t[v].p;
            if (u)
                scale(u, p * prer + (1 - p) * (1 - prer));
            else
                scale(v, p * prel + (1 - p) * (1 - prel));
            prel += a;
            prer += b;
            return u ? u : v;
        }
        if (l == r) {
            // Both sides hold the same value x (only possible with duplicate
            // weights). The result is x if one side is x and the other lies on the
            // correct side of it, or if both are x.
            Mint a = t[u].p;
            Mint b = t[v].p;
            t[u].p = a * (p * prer + (1 - p) * (1 - prer - b))
                   + b * (p * prel + (1 - p) * (1 - prel - a))
                   + a * b;
            prel += a;
            prer += b;
            return u;
        }
        pushdown(u);
        pushdown(v);
        int md = (l + r) >> 1;
        t[u].lc = merge(t[u].lc, t[v].lc, l, md, p);
        t[u].rc = merge(t[u].rc, t[v].rc, md + 1, r, p);
        pull(u);
        return u;
    }
    // Walks the stored values in increasing order. cnt becomes each value's rank
    // i, and ans accumulates i * value * probability^2.
    void answer(int k, int l, int r) {
        if (!k) return;
        if (l == r) {
            ++cnt;
            ans += cnt * Mint(l) * t[k].p * t[k].p;
            return;
        }
        pushdown(k);
        int md = (l + r) >> 1;
        answer(t[k].lc, l, md);
        answer(t[k].rc, md + 1, r);
    }
} seg;
void solve(int u) {
// cerr << format("solve({})", u) << endl;
if (s[u].empty()) {
seg.insert(u, 1, 1000000000, v[u], 1);
rt[u] = u;
} else {
for (int v : s[u]) {
solve(v);
}
if (sz(s[u]) == 1) {
rt[u] = rt[s[u].front()];
} else {
prel = prer = 0;
rt[u] = seg.merge(rt[s[u].front()], rt[s[u].back()], 1, 1000000000, (Mint)v[u] / (Mint) 10000);
}
}
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n;
rep(i, 1, n) {
int x;
cin >> x;
s[x].pb(i);
}
rep(i, 1, n) {
cin >> v[i];
}
seg.init();
solve(1);
seg.answer(rt[1], 1, 1000000000);
cout << ans << endl;
return 0;
}