// Counts "ABCBA" subsequences on tree paths.
//
// Input : n q, then a token of n node labels (label of node i is its i-th
//         character), then n-1 undirected edges "u v", then q queries "u v".
// Output: for each query, the number of ways to pick "ABCBA" as a subsequence
//         of the labels read along the path u -> v, modulo 10007.
//
// Approach: root the tree at node 1. For every node u keep a persistent segment
// tree version[u] indexed by depth that holds the labels of the root -> u path.
// A path query splits at the LCA into an upward part (read in reverse) and a
// downward part; both are range queries on those versions.
#include <cctype>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

constexpr int kMaxN = 30000 + 10;  // nodes are numbered 1..n with n < kMaxN
constexpr int kMod = 10007;
constexpr int kLog = 20;  // 2^(kLog-1) exceeds any possible depth difference
// One insertion allocates one node per segment-tree level, i.e.
// ceil(log2(n)) + 1 <= 16 nodes for n < kMaxN, so 16 * kMaxN always suffices
// (index 0 is the shared all-zero empty node).
constexpr int kMaxNodes = kMaxN * 16;

// Subsequence counts (mod kMod), within one segment, of every contiguous
// substring of the pattern "ABCBA". These are exactly the quantities needed to
// merge two adjacent segments.
struct Counts {
  int A = 0, B = 0, C = 0;
  int AB = 0, BA = 0, BC = 0, CB = 0;
  int ABC = 0, BCB = 0, CBA = 0;
  int ABCB = 0, BCBA = 0;
  int ABCBA = 0;
};

// Counts for the concatenation l + r: an occurrence lies wholly in l, wholly
// in r, or is split into a prefix taken from l and a suffix taken from r.
// Every field is < kMod, so each product is < 10^8 and no sum overflows int.
Counts operator+(const Counts& l, const Counts& r) {
  Counts t;
  t.A = (l.A + r.A) % kMod;
  t.B = (l.B + r.B) % kMod;
  t.C = (l.C + r.C) % kMod;
  t.AB = (l.AB + r.AB + l.A * r.B) % kMod;
  t.BA = (l.BA + r.BA + l.B * r.A) % kMod;
  t.BC = (l.BC + r.BC + l.B * r.C) % kMod;
  t.CB = (l.CB + r.CB + l.C * r.B) % kMod;
  t.ABC = (l.ABC + r.ABC + l.A * r.BC + l.AB * r.C) % kMod;
  t.BCB = (l.BCB + r.BCB + l.B * r.CB + l.BC * r.B) % kMod;
  t.CBA = (l.CBA + r.CBA + l.C * r.BA + l.CB * r.A) % kMod;
  t.ABCB = (l.ABCB + r.ABCB + l.A * r.BCB + l.AB * r.CB + l.ABC * r.B) % kMod;
  t.BCBA = (l.BCBA + r.BCBA + l.B * r.CBA + l.BC * r.BA + l.BCB * r.A) % kMod;
  t.ABCBA = (l.ABCBA + r.ABCBA + l.A * r.BCBA + l.AB * r.CBA +
             l.ABC * r.BA + l.ABCB * r.A) % kMod;
  return t;
}

// Counts of a single-character segment; characters other than A/B/C match
// nothing.
Counts leaf(char ch) {
  Counts c;
  if (ch == 'A') {
    c.A = 1;
  } else if (ch == 'B') {
    c.B = 1;
  } else if (ch == 'C') {
    c.C = 1;
  }
  return c;
}

int n = 0;
int q = 0;
char label[kMaxN];               // label[1..n]
std::vector<int> adj[kMaxN];     // undirected adjacency lists
int anc[kMaxN][kLog];            // anc[u][i] = 2^i-th ancestor of u (0 = none)
int depth_of[kMaxN];             // root has depth 1; virtual node 0 has depth 0
int version[kMaxN];              // persistent-tree root for the root -> u path

int node_count = 0;
int lson[kMaxNodes];
int rson[kMaxNodes];
Counts seg[kMaxNodes];

// Returns the root of a new version of the tree rooted at `prev` (covering
// positions [lo, hi]) in which position `pos` holds label `ch`. Subtrees off
// the updated path are shared with `prev`.
int insert(int prev, int lo, int hi, int pos, char ch) {
  const int cur = ++node_count;
  lson[cur] = lson[prev];
  rson[cur] = rson[prev];
  if (lo == hi) {
    seg[cur] = leaf(ch);
    return cur;
  }
  const int mid = (lo + hi) / 2;
  if (pos <= mid) {
    lson[cur] = insert(lson[prev], lo, mid, pos, ch);
  } else {
    rson[cur] = insert(rson[prev], mid + 1, hi, pos, ch);
  }
  seg[cur] = seg[lson[cur]] + seg[rson[cur]];
  return cur;
}

// Counts for positions [ql, qr] (merged left to right) of the version rooted
// at `node`, which covers [lo, hi].
Counts query(int node, int lo, int hi, int ql, int qr) {
  if (ql <= lo && hi <= qr) return seg[node];
  const int mid = (lo + hi) / 2;
  Counts res;
  if (ql <= mid) res = res + query(lson[node], lo, mid, ql, qr);
  if (qr > mid) res = res + query(rson[node], mid + 1, hi, ql, qr);
  return res;
}

// Roots the tree at node 1 and fills depth_of, anc and version for every
// reachable node. An explicit stack replaces recursion so that a path-shaped
// tree (depth ~n) cannot overflow the call stack. Each node is popped only
// after its parent, so the parent's tables are always ready.
void build_tree() {
  std::vector<std::pair<int, int>> pending;  // (node, parent)
  pending.emplace_back(1, 0);
  while (!pending.empty()) {
    const int u = pending.back().first;
    const int parent = pending.back().second;
    pending.pop_back();

    anc[u][0] = parent;
    depth_of[u] = depth_of[parent] + 1;
    version[u] = insert(version[parent], 1, n, depth_of[u], label[u]);
    for (int i = 1; (1 << i) <= depth_of[u]; ++i) {
      anc[u][i] = anc[anc[u][i - 1]][i - 1];
    }
    for (int v : adj[u]) {
      if (v != parent) pending.emplace_back(v, u);
    }
  }
}

// Lowest common ancestor via binary lifting.
int lca(int u, int v) {
  if (depth_of[u] < depth_of[v]) std::swap(u, v);
  for (int i = kLog - 1; i >= 0; --i) {
    if (depth_of[u] - depth_of[v] >= (1 << i)) u = anc[u][i];
  }
  if (u == v) return u;
  for (int i = kLog - 1; i >= 0; --i) {
    // Entries past a node's depth stay 0 for both nodes and are skipped.
    if (anc[u][i] != anc[v][i]) {
      u = anc[u][i];
      v = anc[v][i];
    }
  }
  return anc[u][0];
}

// Number of "ABCBA" subsequences (mod kMod) in the labels read along u -> v.
int path_count(int u, int v) {
  const int w = lca(u, v);
  // "ABCBA" is a palindrome, so reading a vertical path in either direction
  // gives the same count.
  if (u == w) return query(version[v], 1, n, depth_of[w], depth_of[v]).ABCBA;
  if (v == w) return query(version[u], 1, n, depth_of[w], depth_of[u]).ABCBA;

  // `up` holds w..u in top-down order; the path reads it reversed (u..w).
  const Counts up = query(version[u], 1, n, depth_of[w], depth_of[u]);
  // `down` holds the labels below w down to v, already in path order.
  const Counts down = query(version[v], 1, n, depth_of[w] + 1, depth_of[v]);
  // Split "ABCBA" = P + S with P drawn from reversed(up) and S from down.
  // P occurs in reversed(up) as often as reverse(P) occurs in up:
  //   A -> A, AB -> BA, ABC -> CBA, ABCB -> BCBA, ABCBA -> ABCBA.
  return (up.ABCBA + down.ABCBA + up.A * down.BCBA + up.BA * down.CBA +
          up.CBA * down.BA + up.BCBA * down.A) % kMod;
}

// Reads the next whitespace-delimited token into label[1..n]. Characters past
// the n-th are consumed but dropped, so an over-long token cannot overflow.
bool read_labels() {
  int ch = std::getchar();
  while (ch != EOF && std::isspace(ch)) ch = std::getchar();
  if (ch == EOF) return false;
  for (int i = 1; ch != EOF && !std::isspace(ch); ++i) {
    if (i <= n) label[i] = static_cast<char>(ch);
    ch = std::getchar();
  }
  return true;
}

bool is_node(int x) { return 1 <= x && x <= n; }

}  // namespace

int main() {
  if (std::scanf("%d%d", &n, &q) != 2) return 0;
  if (n < 1 || n >= kMaxN) return 1;
  if (!read_labels()) return 1;

  for (int i = 1; i < n; ++i) {
    int u = 0;
    int v = 0;
    if (std::scanf("%d%d", &u, &v) != 2 || !is_node(u) || !is_node(v)) return 1;
    adj[u].push_back(v);
    adj[v].push_back(u);
  }
  build_tree();

  while (q-- > 0) {
    int u = 0;
    int v = 0;
    if (std::scanf("%d%d", &u, &v) != 2 || !is_node(u) || !is_node(v)) return 1;
    std::printf("%d\n", path_count(u, v));
  }
  return 0;
}