rt.
using namespace std;

// Array extent for every per-node table; node ids are expected to be below it.
constexpr int N = 3005;
// Modulus applied to all DP counts.
constexpr int mod = 1000000007;

// n: number of nodes; L, R: inclusive range of the j index summed in main; k: threshold used by DFS2.
int n;
int L;
int R;
int k;

int a[N];          // a[u]: label (colour) of node u
vector<int> g[N];  // g[u]: children of node u
int sz[N];         // sz[u]: number of nodes in u's subtree
int cnt[N];        // cnt[u]: nodes in u's subtree whose label equals a[u]
int f[N][N];       // per-node, per-label count table (N*N ints, roughly 36 MB)
int (*dp[N])[N];   // dp[u]: heap-allocated DP table of node u, each row N ints wide
int ci[N];         // ci[u]: largest row index of dp[u] that may hold a non-zero value
int cj[N];         // cj[u]: largest column index of dp[u] that may hold a non-zero value
void DFS1(int u)
{
sz[u] = 1;
for (int c = 1; c <= n; c++)
f[u][c] = 0;
f[u][a[u]] = 1;
for (int v : g[u])
{
DFS1(v);
sz[u] += sz[v];
for (int c = 1; c <= n; c++)
f[u][c] = (f[u][c] + f[v][c]) % mod;
}
cnt[u] = f[u][a[u]];
}
void DFS2(int u)
{
dp[u] = new int [cnt[u] + 1][N];
for (int i = 0; i <= cnt[u]; i++)
for (int j = 0; j <= sz[u]; j++)
dp[u][i][j] = 0;
dp[u][0][0] = 1;
ci[u] = 0;
cj[u] = 0;
for (int v : g[u])
{
DFS2(v);
if (a[v] == a[u])
{
int (*nf)[N] = new int [cnt[u] + 1][N];
for (int i = 0; i <= cnt[u]; i++)
for (int j = 0; j <= sz[u]; j++)
nf[i][j] = 0;
int nci = 0, ncj = 0;
for (int i1 = 0; i1 <= ci[u]; i1++)
for (int j1 = 0; j1 <= cj[u]; j1++)
if (dp[u][i1][j1] != 0)
for (int i2 = 0; i2 <= cnt[v]; i2++)
for (int j2 = 0; j2 <= sz[v]; j2++)
if (dp[v][i2][j2] != 0)
if (i1 + i2 <= cnt[u] && j1 + j2 <= sz[u])
{
nf[i1 + i2][j1 + j2] = (nf[i1 + i2][j1 + j2] + 1LL * dp[u][i1][j1] * dp[v][i2][j2] % mod) % mod;
if (i1 + i2 > nci) nci = i1 + i2;
if (j1 + j2 > ncj) ncj = j1 + j2;
}
delete[] dp[u];
dp[u] = nf;
ci[u] = nci;
cj[u] = ncj;
}
else
{
vector<int> G_v(sz[v] + 1, 0);
for (int i2 = 0; i2 <= cnt[v]; i2++)
for (int j2 = 0; j2 <= sz[v]; j2++)
G_v[j2] = (G_v[j2] + dp[v][i2][j2]) % mod;
int (*nf)[N] = new int [cnt[u] + 1][N];
for (int i = 0; i <= cnt[u]; i++)
for (int j = 0; j <= sz[u]; j++)
nf[i][j] = 0;
int nci = ci[u];
int ncj = 0;
for (int i1 = 0; i1 <= ci[u]; i1++)
for (int j1 = 0; j1 <= cj[u]; j1++)
if (dp[u][i1][j1] != 0)
for (int j2 = 0; j2 <= sz[v]; j2++)
if (G_v[j2] != 0)
if (j1 + j2 <= sz[u])
{
nf[i1][j1 + j2] = (nf[i1][j1 + j2] + 1LL * dp[u][i1][j1] * G_v[j2] % mod) % mod;
if (j1 + j2 > ncj) ncj = j1 + j2;
}
delete[] dp[u];
dp[u] = nf;
cj[u] = ncj;
}
delete[] dp[v];
}
int (*nf)[N] = new int [cnt[u] + 1][N];
for (int i = 0; i <= cnt[u]; i++)
for (int j = 0; j <= sz[u]; j++)
nf[i][j] = 0;
int nci = 0, ncj = 0;
for (int i = 0; i <= ci[u]; i++)
for (int j = 0; j <= cj[u]; j++)
if (dp[u][i][j] != 0)
{
nf[i][j] = (nf[i][j] + dp[u][i][j]) % mod;
if (i > nci) nci = i;
if (j > ncj) ncj = j;
int j1 = j;
if (i + 1 >= k) j1++;
if (i + 1 <= cnt[u] && j1 <= sz[u])
{
nf[i + 1][j1] = (nf[i + 1][j1] + dp[u][i][j]) % mod;
if (i + 1 > nci) nci = i + 1;
if (j1 > ncj) ncj = j1;
}
}
delete[] dp[u];
dp[u] = nf;
ci[u] = nci;
cj[u] = ncj;
}
int main()
{
freopen("occur.in", "r", stdin);
freopen("occur.out", "w", stdout);
cin >> n >> L >> R >> k;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 2; i <= n; i++)
{
int p;
cin >> p;
g[p].push_back(i);
}
DFS1(1);
DFS2(1);
long long ans = 0;
for (int j = L; j <= R; j++)
for (int i = 0; i <= cnt[1]; i++)
ans = (ans + dp[1][i][j]) % mod;
cout << ans << endl;
return 0;
}
这份代码提交后运行时错误(RE),只得 10 分;但如果把 main 里 `ans = (ans + dp[1][i][j]) % mod;` 这一行中的 `dp[1][i][j]`
改成 `*dp[i][j]`,
就会 RE、得 0 分。这是为什么?(样例还能过。)