洛谷 P3746 求调

Rt

#include <algorithm>
#include <bitset>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <map>
#include <queue>
#include <set>
using namespace std;
// Every `int` token after this line is really `long long`; this is why the
// entry point below is spelled `signed main`.
#define int long long
// Pair of two (macro-widened) ints.
#define pii pair<int, int>
// Short aliases for emplace_back and for the members of a pair.
#define eb emplace_back
#define F first
#define S second
// Prints "Test: <x>" followed by a newline on cout (no trailing semicolon in the macro).
#define test(x) cout << "Test: " << x << '\n'
// Prints the marker line "qwq"; note the macro already ends with a semicolon.
#define debug puts("qwq");
// Redirects stdin to "<x>.in" and stdout to "<x>.out" (x is stringized, so pass a bare name).
// Expands to two statements, so it is not safe as the body of an unbraced if/else.
#define open(x) freopen(#x".in", "r", stdin);freopen(#x".out", "w", stdout);
// Closes both standard streams; also expands to two statements.
#define close fclose(stdin);fclose(stdout);
// Small stdio-based helpers: integer parsing/printing and in-place max/min.
namespace FastIO {
	// Reads the next decimal integer from stdin and returns it as T.
	// Non-digit characters before the number are skipped; if any of them is
	// '-', the result is negated. At end of input the function returns what
	// it has parsed so far (0 if nothing) instead of looping forever.
	template <typename T = int>
	inline T read() {
		T s = 0, w = 1;
		// Keep the value as returned by getchar(): it is either EOF or an
		// unsigned-char value, which is exactly what isdigit() accepts.
		int c = getchar();
		while (c != EOF && !isdigit(c)) {
			if (c == '-') w = -1;
			c = getchar();
		}
		while (c != EOF && isdigit(c)) s = s * 10 + (c - '0'), c = getchar();
		return s * w;
	}
	// Same parsing rules as read(), but stores the result in `s`.
	template <typename T>
	inline void read(T &s) {
		s = 0;
		int w = 1;
		int c = getchar();
		while (c != EOF && !isdigit(c)) {
			if (c == '-') w = -1;
			c = getchar();
		}
		while (c != EOF && isdigit(c)) s = s * 10 + (c - '0'), c = getchar();
		s = s * w;
	}
	// Reads several integers in argument order.
	template <typename T, typename... Arp> inline void read(T &x, Arp &...arp) {
		read(x), read(arp...);
	}
	// Prints x in decimal on stdout followed by the single character ch.
	// NOTE: a negative x is negated first, so the most negative value of T
	// is not representable here.
	template <typename T>
	inline void write(T x, char ch) {
		if (x < 0) x = -x, putchar('-');
		static char stk[25];
		int top = 0;
		// Digits are produced least-significant first, then emitted in reverse.
		do {
			stk[top++] = x % 10 + '0', x /= 10;
		} while (x);
		while (top) putchar(stk[--top]);
		putchar(ch);
		return;
	}
	// x = max(x, y)
	template <typename T>
	inline void smax(T &x, T y) {
		if (x < y) x = y;
	}
	// x = min(x, y)
	template <typename T>
	inline void smin(T &x, T y) {
		if (x > y) x = y;
	}
	// Terminates the program with exit status 0.
	void quit() {
		exit(0);
	}
} using namespace FastIO;
// Capacity of each matrix dimension. Indices 1..k are used, so k must not exceed N - 1.
const int N = 55;
// Program inputs, filled in by main. `p` is the modulus and `k` the active
// matrix dimension used by matrix::operator*.
int n, p, k, r;
// Square matrix of residues modulo the global `p`, 1-indexed; only the
// top-left k-by-k block (rows/columns 1..k) takes part in multiplication.
struct matrix {
	int a[N][N];
	// Returns (*this) * x restricted to rows/columns 1..k, every entry
	// reduced modulo p. Cells outside that block are zero in the result.
	// The running sum is reduced after every term, so each intermediate is
	// at most (p - 1) + (p - 1)^2; this relies on the file-level
	// `#define int long long` for moduli whose square exceeds 32 bits.
	matrix operator * (const matrix &x) const {
		matrix ans{};
		for (int i = 1; i <= k; ++i) {
			for (int j = 1; j <= k; ++j) {
				int acc = 0;
				for (int s = 1; s <= k; ++s) acc = (acc + a[i][s] * x.a[s][j]) % p;
				ans.a[i][j] = acc;
			}
		}
		return ans;
	}
} a, b;
// Fast matrix power: returns a^b (entries modulo the global p, via
// matrix::operator*) using binary exponentiation.
// For b <= 0 the identity matrix is returned.
matrix ksm(matrix a, int b) {
	matrix ans;
	// The accumulator must start as the multiplicative identity (ones on the
	// diagonal only); starting from an all-ones matrix would yield J * a^b.
	for (int i = 0; i < N; ++i) for (int j = 0; j < N; ++j) ans.a[i][j] = (i == j) ? 1 : 0;
	// `b > 0` rather than `b`: a negative exponent would never shift down to 0.
	while (b > 0) {
		if (b & 1) ans = ans * a;
		b >>= 1;
		// The square is only needed if another bit is still to be processed.
		if (b) a = a * a;
	}
	return ans;
}
signed main() {
	read(n, p, k, r);
	b.a[1][1] ++; b.a[1][k] ++; a.a[1][1] ++;
	for (int i = 2; i <= k; ++i) {
		b.a[i][i] ++; b.a[i][i-1] ++;
	}
	write((a * ksm(b, n * k)).a[1][r + 1], '\n');
	return 0;
}
2 个赞

passed

3 个赞