「题解」树套树 tree

本文将同步发布于:

题目

题目描述

给你一个 \(n\) 个点的小树(正常的树),给你一个 \(m\) 个点的大树,大树的节点是一棵小树,大树的边是跨越了两棵小树之间的边,\(q\) 次询问,求树上距离。

\(1\leq n,m,q\leq 4\times 10^4\)。

题解

预处理

思路非常简单,我们显然可以通过一系列操作 \(\Theta(n)\) 或 \(\Theta(n\log_2n)\) 预处理,使得可以在 \(\Theta(1)\) 或者 \(\Theta(\log_2n)\) 求出小树任意两点间的距离。

大树倍增

我们在大树的每个节点保存一点信息:

  • \(\texttt{fa}_i\):编号为 \(i\) 的大树节点在大树上的祖先为 \(\texttt{fa}_i\)。
  • \(\texttt{rt}_i\):编号为 \(i\) 的大树节点连接 \(\texttt{fa}_i\) 对应小树节点为 \(\texttt{rt}_i\);
  • \(\texttt{ptr}_i\):\(\texttt{rt}_i\) 在实际的树中对应的祖先,也就是编号为 \(\texttt{fa}_i\) 中与 \(i\) 相连的小树节点编号。

维护了以上信息后,我们再维护 \(\texttt{dis}_i\),表示 \(\texttt{rt}_i\) 到实际的树的根的距离。

然后直接倍增加分类讨论即可解决问题。

优化时间复杂度

不难看出,最简单的做法的时间复杂度为 \(\Theta\left(n\log_2n+m\left(\log_2n+\log_2m\right)+q\left(\log_2n+\log_2m\right)\right)\)。

我们可以通过 \(\Theta(n)\) 构造的 ST 表轻松将复杂度降到 \(\Theta(n+m+q)\),考虑到代码复杂度偏大,就没有具体实现。

参考程序

参考程序的时间复杂度为 \(\Theta\left(n+m\left(\log_2n+\log_2m\right)+q\left(\log_2n+\log_2m\right)\right)\),通过树剖消除了一个 \(\log\)。

#pragma GCC optimize("Ofast")
#include<bits/stdc++.h>
using namespace std;
#define reg register
typedef long long ll;

bool st;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
static char buf[1<<21],*p1=buf,*p2=buf;
#define flush() (fwrite(wbuf,1,wp1,stdout),wp1=0)
#define putchar(c) (wp1==wp2&&(flush(),0),wbuf[wp1++]=c)
static char wbuf[1<<21];int wp1;const int wp2=1<<21;
inline int read(void){
	reg char ch=getchar();
	reg int res=0;
	while(!isdigit(ch)) ch=getchar();
	while(isdigit(ch)) res=10*res+(ch^'0'),ch=getchar();
	return res;
}

inline void writeln(reg int x){
	static char buf[32];
	reg int p=-1;
	if(!x) putchar('0');
	else while(x) buf[++p]=(x%10)^'0',x/=10;
	while(~p) putchar(buf[p--]);
	putchar('\n');
	return;
}

inline void swap(reg int &x,reg int &y){
	reg int tmp=x;
	x=y,y=tmp;
	return;
}

const int MAXN=4e4+5;
const int MAXLOG2N=16+1;
const int MAXM=4e4+5;
const int MAXLOG2M=16+1;
const int MAXQ=4e4+5;

int n,m,q;

namespace Small{
	int cnt,head[MAXN],to[MAXN<<1],Next[MAXN<<1];
	inline void Add_Edge(reg int u,reg int v){
		Next[++cnt]=head[u];
		to[cnt]=v;
		head[u]=cnt;
		return;
	}
	inline void Add_Tube(reg int u,reg int v){
		Add_Edge(u,v),Add_Edge(v,u);
		return;
	}
	int fa[MAXN],dep[MAXN];
	int siz[MAXN],son[MAXN];
	inline void dfs1(reg int u,reg int father){
		siz[u]=1;
		fa[u]=father;
		dep[u]=dep[father]+1;
		for(reg int i=head[u];i;i=Next[i]){
			reg int v=to[i];
			if(v!=father){
				dfs1(v,u);
				if(siz[son[u]]<siz[v])
					son[u]=v;
			}
		}
		return;
	}
	int top[MAXN];
	inline void dfs2(reg int u,reg int father,reg int topf){
		top[u]=topf;
		if(!son[u])
			return;
		dfs2(son[u],u,topf);
		for(reg int i=head[u];i;i=Next[i]){
			reg int v=to[i];
			if(v!=father&&v!=son[u])
				dfs2(v,u,v);
		}
		return;
	}
	inline int LCA(reg int x,reg int y){
		while(top[x]!=top[y])
			if(dep[top[x]]>dep[top[y]])
				x=fa[top[x]];
			else
				y=fa[top[y]];
		return dep[x]<dep[y]?x:y;
	}
	inline int getDis(reg int x,reg int y){
		return dep[x]+dep[y]-(dep[LCA(x,y)]<<1);
	}
}

namespace Big{
	int cnt,head[MAXN],to[MAXN<<1],st[MAXN<<1],ed[MAXN<<1],Next[MAXN<<1];
	inline void Add_Edge(reg int u,reg int v,reg int s,reg int e){
		Next[++cnt]=head[u];
		to[cnt]=v,st[cnt]=s,ed[cnt]=e;
		head[u]=cnt;
		return;
	}
	inline void Add_Tube(reg int u,reg int v,reg int s,reg int e){
		Add_Edge(u,v,s,e),Add_Edge(v,u,e,s);
		return;
	}
	int fa[MAXM][MAXLOG2M],dep[MAXM];
	int dis[MAXM];
	int rt[MAXM],ptr[MAXM];
	inline void dfs(reg int u,reg int father,reg int e,reg int s){
		dep[u]=dep[father]+1;
		fa[u][0]=father;
		for(reg int i=1;(1<<i)<=dep[u];++i)
			fa[u][i]=fa[fa[u][i-1]][i-1];
		if(father)
			rt[u]=e,ptr[u]=s,dis[u]=dis[father]+Small::getDis(s,rt[father])+1;
		else
			rt[u]=1,ptr[u]=0,dis[u]=0;
		for(reg int i=head[u];i;i=Next[i]){
			reg int v=to[i];
			if(v!=father)
				dfs(v,u,ed[i],st[i]);
		}
		return;
	}
	inline int LCA(int x,int y){
		if(dep[x]>dep[y])
			swap(x,y);
		for(reg int i=MAXLOG2N-1;i>=0;--i)
			if(dep[fa[y][i]]>=dep[x])
				y=fa[y][i];
		if(x==y)
			return x;
		for(reg int i=MAXLOG2N-1;i>=0;--i)
			if(fa[x][i]!=fa[y][i])
				x=fa[x][i],y=fa[y][i];
		return fa[x][0];
	}
	inline pair<int,int> LCA_lower(int x,int y){
		if(dep[x]>dep[y])
			swap(x,y);
		for(reg int i=MAXLOG2N-1;i>=0;--i)
			if(dep[fa[y][i]]>dep[x])
				y=fa[y][i];
		if(fa[y][0]==x)
			return make_pair(y,0);
		if(dep[y]>dep[x])
			y=fa[y][0];
		for(reg int i=MAXLOG2N-1;i>=0;--i)
			if(fa[x][i]!=fa[y][i])
				x=fa[x][i],y=fa[y][i];
		return make_pair(x,y);
	}
}

bool ed;

int main(void){
	n=read(),m=read(),q=read();
	for(reg int i=1;i<n;++i){
		static int x,y;
		x=read(),y=read();
		Small::Add_Tube(x,y);
	}
	Small::dfs1(1,0),Small::dfs2(1,0,1);
	for(reg int i=1;i<m;++i){
		static int w,x,y,z;
		w=read(),x=read(),y=read(),z=read();
		Big::Add_Tube(w,y,x,z);
	}
	Big::dfs(1,0,1,0);

	/*
	puts("============");
	puts("Small:");
	for(reg int i=1;i<=n;++i)
		printf("i=%d fa=%d dep=%d\n",i,Small::fa[i][0],Small::dep[i]);
	puts("============");
	puts("Big:");
	for(reg int i=1;i<=m;++i)
		printf("i=%d fa=%d dep=%d dis=%lld rt=%d ptr=%d\n",i,Big::fa[i][0],Big::dep[i],Big::dis[i],Big::rt[i],Big::ptr[i]);
	puts("============");
	*/

	while(q--){
		static int w,x,y,z,part1,part2,part3,bLca;
		static pair<int,int> p;
		w=read(),x=read(),y=read(),z=read();
		//printf("query w=%d x=%d y=%d z=%d\n",w,x,y,z);
		if(w==y){
			//puts("S1");
			writeln(Small::getDis(x,z));
		}
		else{
			bLca=Big::LCA(w,y);
			if(bLca==w||bLca==y){
				//puts("S2");
				if(bLca==y)
					swap(w,y),swap(x,z);
				p=Big::LCA_lower(w,y);
				part1=Small::getDis(z,Big::rt[y]);
				part2=Big::dis[y]-Big::dis[p.first];
				part3=1+Small::getDis(Big::ptr[p.first],x);
				//printf("part1=%d part2=%d part3=%d\n",part1,part2,part3);
				writeln(part1+part2+part3);
			}
			else{
				//puts("S3");
				p=Big::LCA_lower(w,y);
				part1=Small::getDis(x,Big::rt[w])+Small::getDis(z,Big::rt[y]);
				part2=Big::dis[w]-Big::dis[p.first]+Big::dis[y]-Big::dis[p.second];
				part3=2+Small::getDis(Big::ptr[p.first],Big::ptr[p.second]);
				//printf("part1=%d part2=%d part3=%d\n",part1,part2,part3);
				writeln(part1+part2+part3);
			}
		}
	}
	flush();
	fprintf(stderr,"%.3lf s\n",1.0*clock()/CLOCKS_PER_SEC);
	fprintf(stderr,"%.3lf MiB\n",(&ed-&st)/1048576.0);
	return 0;
}
上一篇:雲雀


下一篇:[图论入门]网络最大流 - 增广路算法