题目:https://www.luogu.com.cn/problem/P5591
我的代码(A了一个点,其他MLE):
#include <iostream>
#include <vector>
using namespace std;
const int MOD = 998244353;
long long mod_exp(long long base, long long exp, long long mod) {
long long res = 1;
while (exp > 0) {
if (exp % 2 == 1) {
res = (res * base) % mod;
}
base = (base * base) % mod;
exp /= 2;
}
return res;
}
void fac(int n, vector<long long>& fact, vector<long long>& inv_fact) {
fact.resize(n + 1);
inv_fact.resize(n + 1);
fact[0] = 1;
for (int i = 1; i <= n; ++i) {
fact[i] = fact[i - 1] * i % MOD;
}
inv_fact[n] = mod_exp(fact[n], MOD - 2, MOD);
for (int i = n - 1; i >= 0; --i) {
inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD;
}
}
long long pig(int n, int k, const vector<long long>& fact, const vector<long long>& inv_fact) {
if (k < 0 || k > n) return 0;
return fact[n] * inv_fact[k] % MOD * inv_fact[n - k] % MOD;
}
long long solve(int n, int p, int k) {
vector<long long> fact, inv_fact;
fac(n, fact, inv_fact);
long long res = 0;
for (int i = 0; i <= n; ++i) {
long long comb = pig(n, i, fact, inv_fact);
long long power = mod_exp(p, i, MOD);
long long floor_val = i / k;
res = (res + comb * power % MOD * floor_val % MOD) % MOD;
}
return res;
}
int main()
{
int n, p, k;
cin >> n >> p >> k;
long long ans = solve(n, p, k);
cout << ans;
return 0;
}