啊啊啊,真没招了,有AC代码(题解)对照着也瞪不出问题
MY CODE
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
ll dp[105][55], ms[55][55], a[105];
int main(){
ll t;
cin >> t;
while(t--){
ll n, m;
cin >> n >> m;
for(ll i=0;i<m;i++){
for(ll j=0;j<m;j++){
cin >> ms[i][j];
}
}
for(ll i=0;i<n;i++){
cin >> a[i];
}
memset(dp, 0, sizeof(dp));
for(ll i=1;i<n;i++){
if(a[i]>0&&a[i-1]>0){
dp[i][a[i]]=dp[i-1][a[i-1]]+ms[a[i-1]][a[i]];
}
else if(a[i]>0&&a[i-1]<0){
for(ll k=0;k<m;k++)
dp[i][a[i]]=max(dp[i][a[i]],dp[i-1][k]+ms[k][a[i]]);
}
else if(a[i]<0&&a[i-1]>0){
for(ll j=0;j<m;j++)
dp[i][j]=max(dp[i][j],dp[i-1][a[i-1]]+ms[a[i-1]][j]);
}
else{
for(ll j=0;j<m;j++)
for(ll k=0;k<m;k++)
dp[i][j]=max(dp[i][j],dp[i-1][k]+ms[k][j]);
}
}
ll ans=0;
for(ll i=0;i<m;i++){
ans=max(ans, dp[n-1][i]);
}
cout << ans << endl;
}
}
AC CODE