2 条题解
-
112322132131231 (褚战) LV 10 @ 2023-07-12 17:46:30
// Counts the vertices of L^kk(T) modulo 998244353, where T is the input tree
// and L(.) is the line-graph operator (see get_node).
// Approach:
//  1. Enumerate every small tree with 2..kk+1 vertices.
//  2. Give each rooted shape a weight: |V(L^kk(shape))| minus the weights of
//     its proper connected sub-trees.
//  3. Count the occurrences of every rooted shape inside T with a tree DP.
//  4. Output the sum of weight * occurrences.
#include <cstdio>
#include <map>
#include <vector>
#include <cassert>
#include <bitset>
#include <ctime>
#include <iostream>
#include <algorithm>
using namespace std;

vector<int> E[5005];            // adjacency lists of the input tree (1-indexed)
vector<int> kind[1250];         // kind[id]: sorted child-shape ids of rooted shape id
vector<int> T[15];              // scratch adjacency lists of the current small tree
int rnd[2005];                  // random weight per shape id (additive shape hash)
int dp[5005][1250];             // tree-DP table, see find()
int crt[20] , ntd[1250];        // crt: find() scratch; ntd[id]: 1/(m!) symmetry factors
int C[5005][20];                // C[i][j] = binom(i, j) mod p for j <= 15
int inv[21];                    // inv[i] = (i!)^{-1} mod p
const int mod = 998244353;
const int inv6 = 166374059;     // 6^{-1} mod p (unused)
int n , kk , ans = 0 , limit , cnt = 0;
int p[20] , tot = 0 , e_tot = 0;
map<int,int> mp;                // rooted shape id -> weight
map<long long,int> hashh;       // shape hash -> shape id
map<pair<int,int> , int> edge;  // edge (u <= v) -> edge id, used by get_node
int seed = 91478513 , a = 16554871 , b = 35659598;

// Returns a^b mod `mod` by square-and-multiply.
int power(int a,int b)
{
    long long result = 1;
    long long square = a;
    for (; b != 0; b >>= 1) {
        if (b & 1) result = result * square % mod;
        square = square * square % mod;
    }
    return static_cast<int>(result);
}

// Advances the generator state and returns the new seed. The original name
// `rand` is reserved for the C library function; this is the renamed version.
inline int next_random()
{
    seed = static_cast<int>(1LL * (seed ^ a) * b % mod);
    return seed;
}

// Returns the shape id of the subtree of T rooted at u (parent fa, 0 for the root).
// A new shape gets the next id and has kind[id] / ntd[id] filled in:
//  - kind[id]: sorted child ids.
//  - ntd[id]: product of 1/m! over each group of m equal child ids other than 1,
//    where id 1 is the single-vertex shape.
int get_hash(int fa,int u)
{
    std::vector<int> child_ids;
    long long key = 1;
    for (std::size_t idx = 0; idx < T[u].size(); ++idx) {
        const int v = T[u][idx];
        if (v == fa) continue;
        const int cid = get_hash(u, v);
        key += rnd[cid];
        child_ids.push_back(cid);
    }
    const std::map<long long, int>::iterator known = hashh.find(key);
    if (known != hashh.end()) return known->second;

    hashh.insert(std::pair<long long, int>{key, ++tot});
    std::sort(child_ids.begin(), child_ids.end());
    kind[tot] = child_ids;

    long long weight = 1;
    std::size_t pos = 0;
    while (pos < child_ids.size() && child_ids[pos] == 1) ++pos;   // leaves sort first
    while (pos < child_ids.size()) {
        std::size_t run_end = pos + 1;
        while (run_end < child_ids.size() && child_ids[run_end] == child_ids[pos]) ++run_end;
        weight = weight * inv[run_end - pos] % mod;
        pos = run_end;
    }
    ntd[tot] = static_cast<int>(weight);
    return tot;
}

// Lookup-only variant of get_hash: returns the shape id, or -1 if the shape
// (or any sub-shape) is not registered yet. Never creates new ids.
int get_p(int fa,int u)
{
    long long key = 1;
    for (std::size_t idx = 0; idx < T[u].size(); ++idx) {
        const int v = T[u][idx];
        if (v == fa) continue;
        const int cid = get_p(u, v);
        if (cid == -1) return -1;
        key += rnd[cid];
    }
    const std::map<long long, int>::iterator known = hashh.find(key);
    return known == hashh.end() ? -1 : known->second;
}

// Returns the id of undirected edge {u, v}, assigning ++e_tot on first sight.
inline int get_num(int u,int v)
{
    const std::pair<int,int> key(std::min(u, v), std::max(u, v));
    const std::pair<std::map<std::pair<int,int>, int>::iterator, bool> placed =
        edge.insert(std::make_pair(key, e_tot + 1));
    if (placed.second) ++e_tot;
    return placed.first->second;
}

// Returns |V(L^k(G))| mod `mod` (L^0(G) = G). G is the graph on 1..cnode given
// by adjacency lists G[1..cnode], with each edge stored at both endpoints.
// k <= 4 uses closed-form counting; k >= 5 builds L(G) and recurses.
// Clobbers `edge` and `e_tot`.
int get_node(int cnode , int k , vector<int> G[])
{
    edge.clear();
    e_tot = 0;
    long long deg_sum = 0;
    for (int i = 1; i <= cnode; i++) deg_sum += G[i].size();
    const int q = static_cast<int>(deg_sum / 2);             // |E(G)|

    if (k <= 0) return cnode % mod;                          // |V(G)|
    if (k == 1) return q % mod;                              // |V(L(G))| = |E(G)|
    if (k == 2) {                                            // sum_v C(deg v, 2)
        long long res = 0;
        for (int i = 1; i <= cnode; i++) {
            const long long d = static_cast<long long>(G[i].size());
            res = (res + d * (d - 1) % mod) % mod;
        }
        return static_cast<int>(res * inv[2] % mod);
    }
    if (k == 3) {                                            // sum_e C(deg_{L(G)} e, 2)
        long long res = 0;
        for (int i = 1; i <= cnode; i++) {
            for (std::size_t idx = 0; idx < G[i].size(); ++idx) {
                const int j = G[i][idx];
                if (i > j) continue;
                const long long d = static_cast<long long>(G[i].size()) + static_cast<long long>(G[j].size()) - 2;
                if (d < 2) continue;
                res = (res + d * (d - 1) % mod * inv[2]) % mod;
            }
        }
        return static_cast<int>(res);
    }
    if (k == 4) {
        std::vector<std::vector<long long>> ed(cnode + 1);
        long long res = 0;
        for (int i = 1; i <= cnode; i++) {
            for (std::size_t idx = 0; idx < G[i].size(); ++idx) {
                const int j = G[i][idx];
                if (i > j) continue;
                const long long d0 = static_cast<long long>(G[i].size()) + static_cast<long long>(G[j].size()) - 2;
                ed[i].push_back(d0 - 1);
                ed[j].push_back(d0 - 1);
                if (d0 > 1) res = (res + d0 * (d0 - 1) % mod * (d0 - 2)) % mod;
            }
        }
        for (int i = 1; i <= cnode; i++) {
            long long x = 0, y = 0;
            for (std::size_t idx = 0; idx < ed[i].size(); ++idx) {
                const long long w = ed[i][idx];
                if (w <= 0) continue;
                x = (x + w) % mod;
                y = (y + w * w) % mod;
            }
            res = (res + (x * x % mod - y + mod) % mod) % mod;
        }
        return static_cast<int>(res * inv[2] % mod);
    }

    // Build L(G) on the heap: one vertex per edge id, adjacent iff the edges meet.
    std::vector<std::vector<int>> G2(q + 1);
    std::vector<int> pst;
    for (int i = 1; i <= cnode; i++) {
        if (G[i].size() < 2) continue;
        pst.clear();
        for (std::size_t idx = 0; idx < G[i].size(); ++idx) pst.push_back(get_num(i, G[i][idx]));
        for (std::size_t x = 0; x < pst.size(); x++) {
            for (std::size_t y = x + 1; y < pst.size(); y++) {
                G2[pst[x]].push_back(pst[y]);
                G2[pst[y]].push_back(pst[x]);
            }
        }
    }
    return get_node(e_tot, k - 1, G2.data());
}

bool connect[(1<<10) + 1];   // connect[S]: subset S induces a connected subgraph
int neigh[20];               // neigh[v]: bitmask of v's neighbours

// Sums mp[shape] over connected subsets S of G (vertices 1..limit) with
// 2 <= |S| < limit. The shape is the induced subtree rooted at S's smallest
// vertex. Overwrites T.
int brout(vector<int> G[])
{
    const int all = 1 << limit;
    std::fill(connect, connect + all, false);
    for (int v = 1; v <= limit; ++v) {
        int bits = 0;
        for (std::size_t idx = 0; idx < G[v].size(); ++idx) bits |= 1 << (G[v][idx] - 1);
        neigh[v] = bits;
        connect[1 << (v - 1)] = true;
    }
    int total = 0;
    for (int S = 1; S + 1 < all; ++S) {
        if (!connect[S]) continue;
        int reach = 0;
        for (int v = 1; v <= limit; ++v)
            if ((S >> (v - 1)) & 1) reach |= neigh[v];
        for (int v = 1; v <= limit; ++v)
            if ((reach >> (v - 1)) & 1) connect[S | (1 << (v - 1))] = true;
        if ((S & (S - 1)) == 0) continue;
        for (int v = 1; v <= limit; ++v) T[v].clear();
        int root = 0;
        for (int v = 1; v <= limit; ++v) {
            if (((S >> (v - 1)) & 1) == 0) continue;
            if (root == 0) root = v;
            for (std::size_t idx = 0; idx < G[v].size(); ++idx) {
                const int w = G[v][idx];
                if ((S >> (w - 1)) & 1) T[v].push_back(w);
            }
        }
        total = (total + mp[get_hash(0, root)]) % mod;
    }
    return total;
}

// Builds T from p[2..limit] and records the weight of its shape rooted at 1.
void count()
{
    for (int v = 1; v <= limit; ++v) T[v].clear();
    for (int v = 2; v <= limit; ++v) {
        T[v].push_back(p[v]);
        T[p[v]].push_back(v);
    }
    const int root_id = get_hash(0, 1);
    if (mp.find(root_id) != mp.end()) return;
    for (int r = 2; r <= limit; ++r) {
        const int rerooted = get_p(0, r);
        if (rerooted == -1) continue;
        const std::map<int,int>::iterator known = mp.find(rerooted);
        if (known == mp.end()) continue;
        mp.insert(std::pair<int,int>{root_id, known->second});
        return;
    }
    std::vector<std::vector<int>> snapshot(limit + 1);   // brout overwrites T
    for (int v = 1; v <= limit; ++v) snapshot[v] = T[v];
    int weight = get_node(limit, kk, T);
    const int sub_weight = brout(snapshot.data());
    weight = (weight + mod - sub_weight) % mod;
    mp.insert(std::pair<int,int>{root_id, weight});
}

// Enumerates parent arrays p[2..limit] with 1 <= p[v] < v; count() on each.
void dfs(int x)
{
    if (x != limit) {
        for (int parent = 1; parent <= x; ++parent) {
            p[x + 1] = parent;
            dfs(x + 1);
        }
        return;
    }
    count();
}

// Tree DP over the input tree (call find(0, 1)). dp[u][1] = 1.
// For shape id i >= 2, dp[u][i] counts placements of shape i rooted at u:
//  - each non-leaf shape child goes to a distinct child v of u, weighted by dp[v][.];
//  - leaf shape children are chosen as an unordered set among the remaining children;
//  - the result is divided by the ntd[i] symmetry factor.
void find(int fa,int u)
{
    dp[u][1] = 1;
    for (std::size_t idx = 0; idx < E[u].size(); ++idx)
        if (E[u][idx] != fa) find(u, E[u][idx]);
    if (E[u].size() == 1 && fa != 0) return;

    const int siz = static_cast<int>(E[u].size()) - (fa == 0 ? 0 : 1);
    // Rolling 1-D buffer (was a siz x 2^kk stack VLA).
    std::vector<int> ways((1 << kk) + 1);
    for (int i = 2; i <= cnt; i++) {
        if (siz < static_cast<int>(kind[i].size())) continue;   // dp[u][i] stays 0
        int len = 0;
        for (int j = 0; j < static_cast<int>(kind[i].size()); j++)
            if (kind[i][j] != 1) crt[++len] = j;
        const int full = (1 << len) - 1;
        std::fill(ways.begin(), ways.begin() + full + 1, 0);
        ways[0] = 1;
        for (std::size_t idx = 0; idx < E[u].size(); ++idx) {
            const int v = E[u][idx];
            if (v == fa) continue;
            // Descending order keeps ways[mask ^ bit] at its value before child v.
            for (int mask = full; mask > 0; --mask) {
                long long acc = ways[mask];
                for (int slot = 0; slot < len; slot++)
                    if ((mask >> slot) & 1)
                        acc = (acc + 1LL * ways[mask ^ (1 << slot)] * dp[v][kind[i][crt[slot + 1]]]) % mod;
                ways[mask] = static_cast<int>(acc);
            }
        }
        dp[u][i] = static_cast<int>(1LL * ways[full] * C[siz - len][kind[i].size() - len] % mod);
        dp[u][i] = static_cast<int>(1LL * dp[u][i] * ntd[i] % mod);
    }
}

int main()
{
    if (scanf("%d%d", &n, &kk) != 2) return 0;
    C[0][0] = 1;
    for (int i = 1; i <= n; i++) {
        C[i][0] = 1;
        for (int j = 1; j <= i && j <= 15; j++)
            C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
    }
    int fact = 1;
    inv[1] = inv[0] = 1;
    for (int i = 2; i <= 20; i++) {
        fact = static_cast<int>(1LL * fact * i % mod);
        inv[i] = power(fact, mod - 2);
    }
    for (int i = 0; i <= 2000; i++) rnd[i] = next_random();
    for (int i = 1; i < n; i++) {
        int u, v;
        if (scanf("%d%d", &u, &v) != 2) return 0;
        E[u].push_back(v);
        E[v].push_back(u);
    }
    mp.insert(pair<int,int>{1, 0});
    for (int nodes = 2; nodes <= kk + 1; nodes++) {
        limit = nodes;
        dfs(1);
    }
    cnt = tot;
    find(0, 1);

    std::vector<int> weight(cnt + 1, 0);          // one map lookup per shape
    for (int j = 1; j <= cnt; j++) {
        const std::map<int,int>::iterator it = mp.find(j);
        if (it != mp.end()) weight[j] = it->second;
    }
    int result = 0;
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= cnt; j++)
            result = static_cast<int>((result + 1LL * dp[i][j] * weight[j]) % mod);
    printf("%d\n", result);
    return 0;
}
-
-1 2020-04-30 18:37:54 @
想要知道正解的请去看另外两位大佬的题解。
我只是想说一下如何用人类智慧,步步套娃,来骗到一些部分分。
先说一个事情,这个图一定没有重边与自环。我的个人能力只想到了30-50分的办法,不过作为ZJOI,6个题每个题骗50分,就进省队了。
算法一:k = 1,n = 5000嘛,你可以暴力把图建出来,然后就可以获得0分的好成绩了。
算法二:k = 2,不难发现一个图的线图的点数就是这个图的边数,利用算法一里的图,输出边数,期望得分10分。
算法三:k = 3,我们想要 $L^2(G)$ 的点数,就是需要知道 $L(G)$ 的边数。不难发现 $L(G)$ 的边数与 G 每个点的度数有关:点 $i$ 的度数为 $d[i]$,它对答案的贡献是 $C_{d[i]}^2$,期望得分20分。聪明的你一定要预处理逆元,要不然常数不优秀的话就会收到 TLE 好礼。算法四:k = 4,我想要 $L(G)$ 中每个点的度数,不难发现,这与 G 的共点边数有关。对于每条边,它变成的点的度数就是与它相邻的边(与它有公共点的边)的数量,期望得分30分。“共点边”是我自己口胡定义的,不过我相信聪明的你一定可以理解的。
算法五: k = 5,这个是重头戏,值20分呢。我想要 $L(G)$ 的共点边数,不难发现这可以枚举 G 的两条邻边进行统计,不过复杂度是糟糕的 $O(n^4)$,就算你的复杂度是 O(松) 的,都过不去。这个时候你需要一些信仰,还需要一些卡常能力:你要相信图上的边和点会很少,$O(n^2\log n)$ 是能过的。因为没有写过,我也不确定能不能过,如果有人写完,请告诉我一声,非常感谢。不过我们不难发现,对于同一个点,共点边数量相同的边是等价的。而一对公共点为 $x$ 的相邻边 $i, j$ 的贡献是 $g[i] + g[j] - 2$,其中 $g[x]$ 的含义是 $x$ 的共点边数。不难发现这个答案是可以 NTT 的,且 $g[x] \le n$,所以复杂度是 $O(n^2\log n)$ 的。相信信仰的力量,奥利给一下就完事了。最后没有代码,因为懒得敲了,要是在考场上就敲了。
cpp
#include<cstdio>
#include<map>
#include<vector>
#include<cassert>
#include<bitset>
#include<ctime>
#include<iostream>
#include<algorithm>
using namespace std;

// ---- Input tree ------------------------------------------------------------
// E[u]: adjacency list of vertex u of the input tree (vertices are 1-indexed).
vector<int> E[5005];

// ---- Catalogue of small rooted shapes (ids handed out by get_hash) ----------
// kind[id]: sorted ids of the child shapes of the root of shape `id`.
vector<int> kind[1250];
// T[v]: scratch adjacency lists of the small tree currently being processed.
vector<int> T[15];
// rnd[id]: pseudo-random weight of shape `id`; shape keys are sums of these.
int rnd[2005];
// dp[u][id]: tree-DP table filled by find().
int dp[5005][1250];
// crt[1..len]: positions of the non-leaf children inside kind[id] (find scratch).
int crt[20];
// ntd[id]: product of inv[m] (= 1/m!) over groups of m equal child ids other than 1.
int ntd[1250];
// C[i][j]: binomial coefficient "i choose j" mod `mod`, filled for j <= 15.
int C[5005][20];
// inv[i]: modular inverse of i! (main fills indices 0..20).
int inv[21];

// ---- Modular constants -------------------------------------------------------
const int mod = 998244353;
const int inv6 = 166374059;     // inverse of 6 mod `mod`; NOTE(review): unused

// ---- Problem parameters and counters ------------------------------------------
int n;                          // number of vertices of the input tree
int kk;                         // requested number of line-graph iterations
int ans = 0;                    // NOTE(review): unused; functions use local `ans`
int limit;                      // vertex count of the small trees being enumerated
int cnt = 0;                    // number of shape ids known when find() runs
int p[20];                      // p[v]: parent of v (v >= 2) in the small tree
int tot = 0;                    // last shape id handed out by get_hash
int e_tot = 0;                  // last edge id handed out by get_num

// ---- Lookup tables -------------------------------------------------------------
map<int,int> mp;                // rooted shape id -> weight recorded by count()
map<long long,int> hashh;       // shape key -> shape id
map<pair<int,int> , int> edge;  // (u, v) with u <= v -> edge id

// ---- State of the pseudo-random generator ------------------------------------
int seed = 91478513;
int a = 16554871;
int b = 35659598;
// Returns a^b modulo `mod` by square-and-multiply. main uses it with
// b = mod - 2 to obtain modular inverses.
int power(int a,int b)
{
    long long result = 1;
    long long square = a;
    for (; b != 0; b >>= 1) {
        if (b & 1) result = result * square % mod;
        square = square * square % mod;
    }
    return static_cast<int>(result);
}
// Advances the global generator state `seed` (mixing in globals a and b,
// reduced modulo `mod`) and returns the new seed.
// NOTE(review): `rand` with external linkage is a reserved C library name;
// renaming this function together with its call in main would avoid relying on
// compiler leniency.
inline int rand()
{
    const long long mixed = 1LL * (seed ^ a) * b;
    seed = static_cast<int>(mixed % mod);
    return seed;
}
int get_hash(int fa,int u)
{
long long h = 1;
vector<int> hh;
for(int i = 0;i < T[u].size();i++){
if(T[u][i] != fa){
int g = get_hash(u , T[u][i]);
h += rnd[g];
hh.push_back(g);
}
}
map<long long , int>::iterator it = hashh.find(h);
if(it != hashh.end()) return it->second;
hashh.insert(pair<long long,int>{h , ++tot});
kind[tot] = hh;ntd[tot] = 1;
sort(kind[tot].begin() , kind[tot].end());
int qq = 1 , start = 0;ntd[tot] = 1;
for(;start < kind[tot].size() && kind[tot][start] == 1;start++);
for(int i = start + 1;i < kind[tot].size();i++){
if(kind[tot][i] == 1) continue;
if(kind[tot][i] == kind[tot][i - 1]) qq++;
else{
ntd[tot] = (1LL * ntd[tot] * inv[qq]) % mod;
qq = 1;
}
}
ntd[tot] = (1LL * ntd[tot] * inv[qq]) % mod;
return tot;
}
// Lookup-only variant of get_hash: returns the id of the subtree of T rooted at
// u (parent fa) when that shape and all its sub-shapes are already registered
// in `hashh`, and -1 otherwise. Never creates new ids.
int get_p(int fa,int u)
{
    long long key = 1;
    for (std::size_t idx = 0; idx < T[u].size(); ++idx) {
        const int v = T[u][idx];
        if (v == fa) continue;
        const int cid = get_p(u, v);
        if (cid == -1) return -1;       // unknown sub-shape => unknown shape
        key += rnd[cid];
    }
    const std::map<long long, int>::iterator known = hashh.find(key);
    return known == hashh.end() ? -1 : known->second;
}
// Returns the id of the undirected edge {u, v}, stored in the global `edge`
// map under (min, max). The first time a pair is seen it gets the next id
// (++e_tot).
inline int get_num(int u,int v)
{
    const std::pair<int,int> key(std::min(u, v), std::max(u, v));
    const std::pair<std::map<std::pair<int,int>, int>::iterator, bool> placed =
        edge.insert(std::make_pair(key, e_tot + 1));
    if (placed.second) ++e_tot;     // newly stored: commit the id
    return placed.first->second;
}
// Returns |V(L^k(G))| mod `mod`, where L(.) is the line-graph operator
// (L^0(G) = G). G is the graph on vertices 1..cnode given by the adjacency
// lists G[1..cnode], with each edge stored at both endpoints.
//  - k <= 4: closed-form counting on G.
//  - k >= 5: builds L(G) (edge ids from get_num) and recurses with k - 1.
// Clobbers the globals `edge` and `e_tot`.
//
// Fixes versus the previous version:
//  - k == 2 now reduces ans * inv[2] mod p (the value was truncated to int).
//  - k <= 1 no longer recurses forever.
//  - Scratch graphs live on the heap instead of stack VLAs.
int get_node(int cnode , int k , vector<int> G[])
{
    edge.clear();
    e_tot = 0;
    long long deg_sum = 0;
    for (int i = 1; i <= cnode; i++) deg_sum += G[i].size();
    const int q = static_cast<int>(deg_sum / 2);            // |E(G)|

    if (k <= 0) return cnode % mod;                         // |V(G)|
    if (k == 1) return q % mod;                             // |V(L(G))| = |E(G)|

    if (k == 2) {
        // |V(L^2(G))| = |E(L(G))| = sum over v of C(deg v, 2).
        long long res = 0;
        for (int i = 1; i <= cnode; i++) {
            const long long d = static_cast<long long>(G[i].size());
            res = (res + d * (d - 1) % mod) % mod;
        }
        return static_cast<int>(res * inv[2] % mod);
    }

    if (k == 3) {
        // |V(L^3(G))| = sum over edges (u, v) of C(deg u + deg v - 2, 2).
        long long res = 0;
        for (int i = 1; i <= cnode; i++) {
            for (std::size_t idx = 0; idx < G[i].size(); ++idx) {
                const int j = G[i][idx];
                if (i > j) continue;                        // each edge once
                const long long d = static_cast<long long>(G[i].size()) + static_cast<long long>(G[j].size()) - 2;
                if (d < 2) continue;
                res = (res + d * (d - 1) % mod * inv[2]) % mod;
            }
        }
        return static_cast<int>(res);
    }

    if (k == 4) {
        // For each edge e with L(G)-degree d0:
        //  - add d0 (d0 - 1) (d0 - 2);
        //  - append d0 - 1 to both endpoints' lists.
        // For each vertex, add (sum)^2 - sum of squares over its positive entries.
        // The grand total is then halved.
        std::vector<std::vector<long long>> ed(cnode + 1);
        long long res = 0;
        for (int i = 1; i <= cnode; i++) {
            for (std::size_t idx = 0; idx < G[i].size(); ++idx) {
                const int j = G[i][idx];
                if (i > j) continue;
                const long long d0 = static_cast<long long>(G[i].size()) + static_cast<long long>(G[j].size()) - 2;
                ed[i].push_back(d0 - 1);
                ed[j].push_back(d0 - 1);
                if (d0 > 1) res = (res + d0 * (d0 - 1) % mod * (d0 - 2)) % mod;
            }
        }
        for (int i = 1; i <= cnode; i++) {
            long long x = 0, y = 0;
            for (std::size_t idx = 0; idx < ed[i].size(); ++idx) {
                const long long w = ed[i][idx];
                if (w <= 0) continue;
                x = (x + w) % mod;
                y = (y + w * w) % mod;                      // 64-bit square: no int overflow
            }
            res = (res + (x * x % mod - y + mod) % mod) % mod;
        }
        return static_cast<int>(res * inv[2] % mod);
    }

    // Build L(G): one vertex per edge id, two vertices adjacent iff their edges
    // share an endpoint. Heap storage: this used to be a stack VLA of vectors.
    std::vector<std::vector<int>> G2(q + 1);
    std::vector<int> pst;
    for (int i = 1; i <= cnode; i++) {
        if (G[i].size() < 2) continue;
        pst.clear();
        for (std::size_t idx = 0; idx < G[i].size(); ++idx) pst.push_back(get_num(i, G[i][idx]));
        for (std::size_t x = 0; x < pst.size(); x++) {
            for (std::size_t y = x + 1; y < pst.size(); y++) {
                G2[pst[x]].push_back(pst[y]);
                G2[pst[y]].push_back(pst[x]);
            }
        }
    }
    return get_node(e_tot, k - 1, G2.data());
}
// Scratch for brout:
//  - connect[S]: set once vertex subset S (bit v-1 <-> vertex v) is known to
//    induce a connected subgraph;
//  - neigh[v]: bitmask of v's neighbours.
bool connect[(1<<10) + 1];
int neigh[20];
// Sums mp[shape] over every connected vertex subset S of the graph G on
// 1..limit with 2 <= |S| < limit. The shape is the subtree induced by S,
// rooted at S's smallest vertex.
// Subsets are scanned in increasing numeric order. A connected S marks S plus
// each neighbour as connected, and those masks are larger, so they are marked
// before they are reached. Overwrites T with each induced subgraph.
int brout(vector<int> G[])
{
    const int all = 1 << limit;
    std::fill(connect, connect + all, false);
    for (int v = 1; v <= limit; ++v) {
        int bits = 0;
        for (std::size_t idx = 0; idx < G[v].size(); ++idx) bits |= 1 << (G[v][idx] - 1);
        neigh[v] = bits;
        connect[1 << (v - 1)] = true;       // a single vertex is connected
    }

    int total = 0;
    for (int S = 1; S + 1 < all; ++S) {     // the full vertex set is excluded
        if (!connect[S]) continue;
        int reach = 0;
        for (int v = 1; v <= limit; ++v)
            if ((S >> (v - 1)) & 1) reach |= neigh[v];
        for (int v = 1; v <= limit; ++v)
            if ((reach >> (v - 1)) & 1) connect[S | (1 << (v - 1))] = true;
        if ((S & (S - 1)) == 0) continue;   // singletons are skipped

        for (int v = 1; v <= limit; ++v) T[v].clear();
        int root = 0;
        for (int v = 1; v <= limit; ++v) {
            if (((S >> (v - 1)) & 1) == 0) continue;
            if (root == 0) root = v;
            for (std::size_t idx = 0; idx < G[v].size(); ++idx) {
                const int w = G[v][idx];
                if ((S >> (w - 1)) & 1) T[v].push_back(w);
            }
        }
        total = (total + mp[get_hash(0, root)]) % mod;
    }
    return total;
}
// Handles one small tree. Builds T from the parent array p[2..limit]; if the
// shape rooted at vertex 1 has no entry in `mp`, records one:
//  - copied from another rooting of the same tree if that one is already known;
//  - otherwise get_node(limit, kk, T) minus brout() over the proper connected
//    sub-trees.
// T is clobbered (brout rewrites it).
void count()
{
    for (int v = 1; v <= limit; ++v) T[v].clear();
    for (int v = 2; v <= limit; ++v) {
        T[v].push_back(p[v]);
        T[p[v]].push_back(v);
    }

    const int root_id = get_hash(0, 1);
    if (mp.find(root_id) != mp.end()) return;

    // Same unrooted tree already weighted under a different root?
    for (int r = 2; r <= limit; ++r) {
        const int rerooted = get_p(0, r);
        if (rerooted == -1) continue;
        const std::map<int,int>::iterator known = mp.find(rerooted);
        if (known == mp.end()) continue;
        mp.insert(std::pair<int,int>{root_id, known->second});
        return;
    }

    // brout overwrites T, so it gets a private copy of the adjacency lists.
    std::vector<std::vector<int>> snapshot(limit + 1);
    for (int v = 1; v <= limit; ++v) snapshot[v] = T[v];
    int weight = get_node(limit, kk, T);
    const int sub_weight = brout(snapshot.data());
    weight = (weight + mod - sub_weight) % mod;
    mp.insert(std::pair<int,int>{root_id, weight});
}
// Enumerates every parent array p[2..limit] with 1 <= p[v] < v, i.e. every
// labelled tree on 1..limit in which parents precede children, and calls
// count() on each. Entry point: dfs(1) after setting `limit`.
void dfs(int x)
{
    if (x != limit) {
        for (int parent = 1; parent <= x; ++parent) {
            p[x + 1] = parent;
            dfs(x + 1);
        }
        return;
    }
    count();
}
// Tree DP over the input tree E (call as find(0, 1)). Sets dp[u][1] = 1 for
// every vertex. For shape id i >= 2, dp[u][i] counts the placements of shape i
// with its root at u:
//  - each non-leaf child of the shape (id != 1) goes to a distinct child v of u,
//    weighted by dp[v][that child's id];
//  - the leaf children of the shape are chosen as an unordered set among the
//    remaining children (binomial C);
//  - the result is multiplied by the symmetry factor ntd[i].
//
// Changes versus the previous version:
//  - The siz x 2^kk stack VLA (up to ~10 MB for a high-degree vertex) is now a
//    rolling 1-D buffer updated in descending mask order.
//  - Shapes with more children than u has are skipped; dp stays 0.
void find(int fa,int u)
{
    dp[u][1] = 1;
    for (std::size_t idx = 0; idx < E[u].size(); ++idx)
        if (E[u][idx] != fa) find(u, E[u][idx]);
    if (E[u].size() == 1 && fa != 0) return;        // non-root leaf: only shape 1

    const int siz = static_cast<int>(E[u].size()) - (fa == 0 ? 0 : 1);   // children of u
    // ways[mask]: weighted number of ways to fill the non-leaf slots in `mask`
    // with distinct children processed so far.
    std::vector<int> ways((1 << kk) + 1);
    for (int i = 2; i <= cnt; i++) {
        if (siz < static_cast<int>(kind[i].size())) continue;   // cannot fit; dp[u][i] stays 0
        int len = 0;                                 // number of non-leaf slots
        for (int j = 0; j < static_cast<int>(kind[i].size()); j++)
            if (kind[i][j] != 1) crt[++len] = j;
        const int full = (1 << len) - 1;
        std::fill(ways.begin(), ways.begin() + full + 1, 0);
        ways[0] = 1;
        for (std::size_t idx = 0; idx < E[u].size(); ++idx) {
            const int v = E[u][idx];
            if (v == fa) continue;
            // Descending order: ways[mask ^ bit] (< mask) still holds its value
            // from before child v, which is exactly the 0/1 transition.
            for (int mask = full; mask > 0; --mask) {
                long long acc = ways[mask];
                for (int slot = 0; slot < len; slot++)
                    if ((mask >> slot) & 1)
                        acc = (acc + 1LL * ways[mask ^ (1 << slot)] * dp[v][kind[i][crt[slot + 1]]]) % mod;
                ways[mask] = static_cast<int>(acc);
            }
        }
        dp[u][i] = static_cast<int>(1LL * ways[full] * C[siz - len][kind[i].size() - len] % mod);
        dp[u][i] = static_cast<int>(1LL * dp[u][i] * ntd[i] % mod);
    }
}
// Reads n, kk and the n - 1 tree edges. Builds the tables, then:
//  - weights every rooted shape with 2..kk+1 vertices (dfs -> count);
//  - runs the tree DP (find);
//  - prints sum over (u, shape) of dp[u][shape] * weight(shape) mod p.
// Changes:
//  - Input reads are checked.
//  - Each shape weight is looked up once, instead of mp[j] inside the n x cnt
//    loop (about 6M map lookups, and operator[] could insert).
int main()
{
    if (scanf("%d%d", &n, &kk) != 2) return 0;

    // C[i][j] = binom(i, j) mod p for j <= 15 (Pascal's triangle).
    C[0][0] = 1;
    for (int i = 1; i <= n; i++) {
        C[i][0] = 1;
        for (int j = 1; j <= i && j <= 15; j++)
            C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
    }

    // inv[i] = (i!)^{-1} mod p via Fermat's little theorem.
    int fact = 1;
    inv[1] = inv[0] = 1;
    for (int i = 2; i <= 20; i++) {
        fact = static_cast<int>(1LL * fact * i % mod);
        inv[i] = power(fact, mod - 2);
    }

    for (int i = 0; i <= 2000; i++) rnd[i] = rand();

    for (int i = 1; i < n; i++) {
        int u, v;
        if (scanf("%d%d", &u, &v) != 2) return 0;
        E[u].push_back(v);
        E[v].push_back(u);
    }

    mp.insert(std::pair<int,int>{1, 0});             // a single vertex has weight 0
    for (int nodes = 2; nodes <= kk + 1; nodes++) {
        limit = nodes;
        dfs(1);
    }
    cnt = tot;
    find(0, 1);

    std::vector<int> weight(cnt + 1, 0);
    for (int j = 1; j <= cnt; j++) {
        const std::map<int,int>::iterator it = mp.find(j);
        if (it != mp.end()) weight[j] = it->second;
    }

    int ans = 0;
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= cnt; j++)
            ans = static_cast<int>((ans + 1LL * dp[i][j] * weight[j]) % mod);
    printf("%d\n", ans);
    return 0;
}
- 1
信息
- ID
- 2041
- 难度
- 4
- 分类
- (无)
- 标签
- 递交数
- 20
- 已通过
- 16
- 通过率
- 80%
- 被复制
- 2
- 上传者