1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
| #include <bits/stdc++.h>
// Debug helper: prints an expression's source text followed by its value.
// (Each preprocessor directive must sit on its own line; when this was collapsed onto a
// single line the macro swallowed every declaration after it.)
#define debug(x) cout << #x << ":\t" << (x) << endl;
using namespace std;

using ll = long long;
using ull = unsigned long long;

const int N = 5e5 + 10;             // array capacity; strings are stored 1-indexed
const int INF = 0x3f3f3f3f;
const ll inf = 0x3f3f3f3f3f3f3f3f;
const ll mod = 998244353;           // modulus of the printed answer
const ll inv2 = (mod + 1) / 2;      // 2 * inv2 == mod + 1, i.e. the inverse of 2 modulo mod
using pii = pair<int, int>;
using pil = pair<int, ll>;

int T;          // number of test cases
int n, m;       // per-test string length n and parameter m
char s[N];      // input string, s[1..n]

// Double polynomial hashing: prefix hashes h1/h2 and base powers p1/p2, one table per
// (base, modulus) pair below.
ll h1[N], p1[N], h2[N], p2[N];
const ll m1 = 998244353, m2 = 1000000007;
const ll P1 = 233, P2 = 3993;
// Fills prefix hashes h[1..n] of the global string s and base powers p[0..n], all
// reduced modulo `mod`. Characters are mapped as 'a' -> 1, 'b' -> 2, ...
// h[0] is not written here; it is expected to hold 0.
void init(ll h[], ll p[], ll P, ll mod) {
    p[0] = 1;
    int i = 1;
    while (i <= n) {
        const ll digit = s[i] - 'a' + 1;
        h[i] = (h[i - 1] * P % mod + digit) % mod;
        p[i] = p[i - 1] * P % mod;
        ++i;
    }
}
// Hash of s[l..r] (1-indexed, inclusive) from prefix hashes h and powers p:
// h[r] - h[l-1] * p[len], shifted by +mod before the final reduction so the result is
// non-negative (assuming h and p entries already lie in [0, mod)).
ll hsh(int l, int r, ll h[], ll p[], ll mod) {
    const ll shifted = h[l - 1] * p[r - l + 1] % mod;
    const ll diff = h[r] - shifted + mod;
    return diff % mod;
}
// Combined double hash of s[l..r]: the modulo-m1 hash scaled by 2e9 plus the
// modulo-m2 hash. The second part is below m2 < 2e9, so the two parts do not overlap.
ll hsh(int l, int r) {
    const ll hi = hsh(l, r, h1, p1, m1);
    const ll lo = hsh(l, r, h2, p2, m2);
    return hi * 2000000000ll + lo;
}
unordered_map<ll, int> mp;
void upd(int l, int r) { if (l == r)mp[hsh(l, r)] = 1; else mp[hsh(l, r)] = mp[hsh(l, (l + r - 1) / 2)] + 1; }
// cnt: per-level counters used by main. Ma: the transformed string. Mp: palindromic
// radii around each Ma[i].
ll cnt[N]; char Ma[N * 2]; int Mp[N * 2];
// Manacher's algorithm over s[1..len].
// Ma becomes "$#s1#s2#...#slen#" starting at Ma[1], with Ma[l + 1] = 0 as the right
// sentinel. Ma[1] = '$' and Ma[0] (never written here, so it stays 0) stop leftward
// expansion. Mp[i] is the radius such that Ma[i-Mp[i]+1 .. i+Mp[i]-1] is a palindrome.
// Each time a radius is extended past its initial value and the matched boundary is a
// '#', the corresponding substring of s is reported via upd(l, r) (1-indexed, inclusive).
// Radii inside the mirrored initial value are not reported: those palindromes lie inside
// the palindrome centered at id, so their mirror images (the same strings) were already
// seen at an earlier center.
void Manacher(char s[], int len) {
    int l = 0;
    Ma[++l] = '$';
    Ma[++l] = '#';
    for (int i = 1; i <= len; i++) {
        Ma[++l] = s[i];
        Ma[++l] = '#';
    }
    Ma[l + 1] = 0;
    // mx: one past the rightmost index covered so far; id: the center that reached it.
    int mx = 0, id = 0;
    for (int i = 1; i <= l; i++) {
        Mp[i] = mx > i ? min(Mp[2 * id - i], mx - i) : 1;
        while (Ma[i + Mp[i]] == Ma[i - Mp[i]]) {
            int le = Mp[i];
            if (Ma[i - le] == '#') {
                // Odd i: center is the character s[i / 2]; report the odd-length substring.
                // Even i: center is the '#' between s[i / 2 - 1] and s[i / 2]; report the
                // even-length substring.
                if (i & 1) upd(i / 2 - le / 2, i / 2 + le / 2);
                else upd(i / 2 - le / 2, i / 2 + le / 2 - 1);
            }
            Mp[i]++;
        }
        if (i + Mp[i] > mx) {
            mx = i + Mp[i];
            id = i;
        }
    }
}
int main() { scanf("%d", &T); while (T--) { scanf("%d%d", &n, &m); scanf("%s", s + 1); init(h1, p1, P1, m1); init(h2, p2, P2, m2); mp.clear(); for (int i = 1; i <= n; i++)cnt[i] = 0; Manacher(s, n); for (auto u:mp)cnt[u.second]++; for (int i = n - 1; i >= 1; i--)cnt[i] = (cnt[i] + cnt[i + 1]) % mod; ll ans = 0, tmp = 2; for (int i = 1; i <= m; i++) { ans = (ans + tmp * cnt[i] % mod) % mod; tmp = tmp * 2 % mod; } printf("%lld\n", ans); } return 0; }
|