提高2作业D | LOJ2537:线段树合并求调

思路:
每个点维护一棵权值线段树,下标 $x$ 处存放的是"该点的权值等于 $x$"的概率。
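合并时先遍历左儿子再遍历右儿子,因此已遍历过的值都比当前值小。维护 `prel`、`prer`,分别记录合并前的左树 $L$、右树 $R$ 在这些更小的值上的概率之和,这样就能迅速算出当前值的新概率:设 $p$ 为当前点取较大值的概率,若值 $x$ 来自左树,则 $P'(x)=P_L(x)\bigl(p\cdot\mathrm{prer}+(1-p)(1-\mathrm{prer})\bigr)$;来自右树时对称地把 $\mathrm{prer}$ 换成 $\mathrm{prel}$。注意:当某段区间只有一棵树有值时,这段区间内 $\mathrm{prer}$(或 $\mathrm{prel}$)不变,系数是常数,需要给整棵子树乘上这个系数(打乘法懒标记),不能直接返回。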
用一次 DFS 自底向上合并,最后遍历一遍根节点的线段树统计答案。
能过样例,但提交后全部 WA。

代码:

#include <bits/stdc++.h>
using namespace std;

#define int long long
#define endl '\n'
#define debug(x) cerr << #x << " = " << x << endl
#define rep(i, a, b) for (int i = (a); i <= (b); i++)
#define per(i, a, b) for (int i = (a); i >= (b); i--)
#define gn(u, v) for (int v : G.G[u])
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define sz(x) (int)(x).size()
#define pii pair<int, int>
#define vi vector<int>
#define vpii vector<pii>
#define vvi vector<vi>
#define no cout << "NO" << endl
#define yes cout << "YES" << endl
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
#define tomin(x, y) ((x) = min((x), (y)))
#define tomax(x, y) ((x) = max((x), (y)))
#define ck(mask, i) (((mask) >> (i)) & 1)
#define pq priority_queue
#define FLG (cerr << "Alive!" << endl);

constexpr int MAXN = 3e5+5;
constexpr int MOD = 998244353;

// Fast modular exponentiation: returns x^y mod MOD for y >= 0.
// The base is first reduced into [0, MOD) so negative or oversized inputs are safe.
int qpow(int x, int y) {
    x %= MOD;
    if (x < 0)
        x += MOD;
    int res = 1;
    while (y) {
        if (y & 1)
            res = res * x % MOD;
        x = x * x % MOD;
        y >>= 1;
    }
    return res;
}

// Integer modulo MOD. Invariant: 0 <= x < MOD.
// Implicit construction from int is intentional so that mixed expressions
// such as `1 - p` work.
struct Mint {
    int x;
    Mint() { x = 0; }
    // Reduces any integer, including negative ones, into [0, MOD).
    Mint(const int _x) {
        x = _x % MOD;
        if (x < 0)
            x += MOD;
    }
    friend Mint operator+(Mint x, Mint y) {
        int t = x.x + y.x;
        return (t >= MOD) ? (t - MOD) : t;
    }
    friend Mint operator-(Mint x, Mint y) {
        int t = x.x - y.x;
        return (t < 0) ? (t + MOD) : t;
    }
    friend Mint operator*(Mint x, Mint y) {
        return x.x * y.x % MOD;
    }
    // Division through the inverse y^(MOD-2) (Fermat's little theorem, MOD is prime).
    // y must be nonzero.
    friend Mint operator/(Mint x, Mint y) {
        return x.x * qpow(y.x, MOD - 2) % MOD;
    }
    // Power; takes the base by value so temporaries such as Mint(3) ^ 2 are accepted.
    friend Mint operator^(Mint x, int y) {
        return Mint(qpow(x.x, y));
    }
    friend Mint& operator+=(Mint& x, Mint y) {
        return x = x + y;
    }
    friend Mint& operator-=(Mint& x, Mint y) {
        return x = x - y;
    }
    friend Mint& operator*=(Mint& x, Mint y) {
        return x = x * y;
    }
    friend Mint& operator/=(Mint& x, Mint y) {
        return x = x / y;
    }
    friend Mint& operator^=(Mint& x, int y) {
        return x = x ^ y;
    }
    friend std::ostream& operator<<(std::ostream& o, Mint y) {
        o << y.x;
        return o;
    }
    // Reads an integer into the referenced Mint, reduced into [0, MOD).
    // The target is left unchanged if extraction fails.
    friend std::istream& operator>>(std::istream& i, Mint& y) {
        int raw;
        if (i >> raw)
            y = Mint(raw);
        return i;
    }
    Mint& operator++() {
        x++;
        if (x >= MOD)
            x -= MOD;
        return *this;
    }
    // Postfix: increments and returns the value from BEFORE the increment.
    // `signed` instead of `int` keeps the postfix signature valid under `#define int long long`.
    Mint operator++(signed) {
        Mint old = *this;
        ++*this;
        return old;
    }
    Mint& operator--() {
        x--;
        if (x < 0)
            x += MOD;
        return *this;
    }
    // Postfix: decrements and returns the value from BEFORE the decrement.
    Mint operator--(signed) {
        Mint old = *this;
        --*this;
        return old;
    }
    friend bool operator==(const Mint& x, const Mint &y) {
        return x.x == y.x;
    }
    friend bool operator!=(const Mint& x, const Mint &y) {
        return x.x != y.x;
    }
};

// ---- Problem input ----
int n;                        // number of tree nodes
std::vector<int> s[MAXN];     // s[x] = children of x, built from the parent list
int v[MAXN];                  // leaf: its weight; internal node: max-probability * 10000

// ---- Per-node results ----
int rt[MAXN];                 // rt[u] = root index of u's value segment tree

// ---- Scratch state for SegTree::merge (reset before each top-level merge) ----
Mint prel, prer;

// ---- Answer accumulation ----
Mint cnt(0);                  // rank of the current value during the final walk
Mint ans(0);                  // sum of rank * value * probability^2

struct SegTree {
    struct Node {
        int lc, rc;
        Mint p;
    };
    vector<Node> t;
    void init() {
        t.pb({0, 0, 0});
        rep(i, 1, n) {
            t.pb({0, 0, 0});
        }
    }
    void mkl(int k) {
        if (t[k].lc) return;
        t[k].lc = sz(t);
        t.pb({0, 0, 0});
    }
    void mkr(int k) {
        if (t[k].rc) return;
        t[k].rc = sz(t);
        t.pb({0, 0, 0});
    }
    void pull(int k) {
        t[k].p = t[t[k].lc].p + t[t[k].rc].p;
    }
    void insert(int k, int l, int r, int v, Mint w) {
        // cerr << "insert! " << k << " " << l << " " << r << " " << v << " " << w << endl
        if (l == r) {
            t[k].p += w;
            return;
        }
        int md = l + r >> 1;
        if (v <= md) {
            mkl(k);
            insert(t[k].lc, l, md, v, w);
        } else {
            mkr(k);
            insert(t[k].rc, md + 1, r, v, w);
        }
        pull(k);
    }
    int merge(int u, int v, int l, int r, Mint p) {
        // cerr << format("merge({}, {}, {}, {})", u, v, l, r) << endl;
        if (l == r) {
            Mint a = t[u].p;
            Mint b = t[v].p;
            if (u) t[u].p = p * prer * a + (1 - p) * (1 - prer) * a;
            if (v) t[v].p = p * prel * b + (1 - p) * (1 - prel) * b;
            prel += a;
            prer += b;
            return u | v;
        }
        if (!u || !v) {
            prel += t[u].p;
            prer += t[v].p;
            return u | v;
        }
        int md = l + r >> 1;
        t[u].lc = merge(t[u].lc, t[v].lc, l, md, p);
        t[u].rc = merge(t[u].rc, t[v].rc, md + 1, r, p);
        pull(u);
        return u;
    }

    void answer(int k, int l, int r) {
        // cerr << "answer " << k << " " << l << " " << r << endl;
        if (!k) return;
        if (l == r) {
            cnt++;
            // cerr << l << " " << t[k].p << endl;
            ans += cnt * (Mint)l * t[k].p * t[k].p;
            return;
        }
        int md = l + r >> 1;
        answer(t[k].lc, l, md);
        answer(t[k].rc, md + 1, r);
    }
} seg;

void solve(int u) {
    // cerr << format("solve({})", u) << endl;
    if (s[u].empty()) {
        seg.insert(u, 1, 1000000000, v[u], 1);
        rt[u] = u;
    } else {
        for (int v : s[u]) {
            solve(v);
        }
        if (sz(s[u]) == 1) {
            rt[u] = rt[s[u].front()];
        } else {
            prel = prer = 0;
            rt[u] = seg.merge(rt[s[u].front()], rt[s[u].back()], 1, 1000000000, (Mint)v[u] / (Mint) 10000);
        }
    }
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    cin >> n;
    rep(i, 1, n) {
        int x;
        cin >> x;
        s[x].pb(i);
    }
    rep(i, 1, n) {
        cin >> v[i];
    }

    seg.init();
    solve(1);
    seg.answer(rt[1], 1, 1000000000);
    cout << ans << endl;

    return 0;
}
1 个赞