https://codeforces.com/contest/1163
D. Mysterious Code
题意:给定3个字符串,串c中有小写字母和星号,串s和t只有小写字母,现在要把所有星号变成小写字母,要求最大化 串s在c中的出现次数-串t在c中的出现次数,这里的出现指作为子串出现。
KMP/AC自动机+dp
dp[i][j][k] 表示c串到第 i 位,s串到第 j 位,t串到第 k 位,最大的s,t出现次数差值。
每当s串到结尾时,差值+1;当t串到结尾时差值-1。
遍历i,j,k。当c在第i位,s在第j位,t在第k位,要假设s串前j-1位就是c串的i前j-1位,并且t串前k-1位就是c串前k-1位。当加入了c串第i位后,要考虑s串可以到第几位,t串到第几位。这里可以用kmp预处理出来,因为kmp是不断压缩后缀,且前后缀相同,所以如果s串前i-1位是c串到i为止的后缀,那么s的next[j-1] 一定也是c到i-1为止的后缀,所以不断跳next,直到c的第i位能接到后面。
用AC自动机的话每个节点记录这个节点表示的前缀的后缀是否包含s和t,复杂度更低。
https://www.luogu.com.cn/blog/Coding-life/solution-cf1163d
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
| #include<bits/stdc++.h> using namespace std; #define ll long long const int INF = 0x3f3f3f3f; const ll inf = 0x3f3f3f3f3f3f3f3f; const ll mod = 1e9 + 7; const int N = 2e5 + 10; void getNextVal(int* nxt, char* p, int kmp[60][30]) { int n = strlen(p + 1); nxt[0] = -1; nxt[1] = 0; for (int i = 2; i <= n; i++) { int j = nxt[i - 1]; while (j > 0 && p[i] != p[j + 1]) j = nxt[j]; if (p[j + 1] == p[i])j++; nxt[i] = j; } for (int i = 0; i <= n; i++) { for (int j = 0; j < 26; j++) { int cur = i; while (cur > 0 && p[cur + 1] - 'a' != j) cur = nxt[cur]; if (p[cur + 1] - 'a' == j)cur++; kmp[i][j] = cur; } } } int nxts[60], nxtt[60], kmps[60][30], kmpt[60][30]; int dp[1010][60][60]; char c[1010], s[60], t[60]; int main() { scanf("%s%s%s", c + 1, s + 1, t + 1); getNextVal(nxts, s, kmps); getNextVal(nxtt, t, kmpt); int n = strlen(c + 1); int p = strlen(s + 1), q = strlen(t + 1); for (int i = 0; i <= n; i++) for (int j = 0; j <= p; j++) for (int k = 0; k <= q; k++)dp[i][j][k] = -INF; dp[0][0][0] = 0; for (int i = 0; i < n; i++) { for (int j = 0; j <= p; j++) { for (int k = 0; k <= q; k++) { for (int x = 0; x < 26; x++) { if (c[i + 1] == '*' || c[i + 1] - 'a' == x) { int ns = kmps[j][x], nt = kmpt[k][x]; int tmp = dp[i][j][k] + (ns == p) - (nt == q); dp[i + 1][ns][nt] = max(dp[i + 1][ns][nt], tmp); } } } } } int ans = -INF; for (int i = 0; i <= p; i++) { for (int j = 0; j <= q; j++) ans = max(ans, dp[n][i][j]); } printf("%d\n", ans); return 0; }
|