[poj1741] 트리 포인트 치료

3012 단어 DFS점분치
처음으로 분치라도 써봤는데, 역시 황 학장의 코드 프레임워크가 좋았어...
이것은 틀림없이 분치의 가장 고전적인 제목이다.대충 자신의 이해를 적어 보세요.
우선, 한 그루의 나무에 대해 그 중심을 구하여 뿌리 노드로 삼는다.그리고 체인은 두 종류로 나눌 수 있는데 그것이 바로 뿌리 노드를 통과하는 것과 뿌리 노드를 통과하지 않는 것이다.루트 노드를 거치지 않는 경우 하위 트리에서 귀속 호출하면 됩니다.매번 루트 노드는 나무의 중심을 취하기 때문에 한 번 돌아오는 점의 개수는 적어도 2로 나누고, 돌아오는 층수는logN층을 초과하지 않는다.다른 한편, 각 층은 대체적으로 O(N)급 개점으로 볼 수 있으며, 이 점에 대한 작업 시간은 O(NlogN)급이다.따라서 전체 시간 복잡도 O(Nlog^2N).
루트 노드를 통과하는 체인에 대해 우리는 다음과 같은 조작을 진행한다.먼저 루트 노드의 깊이를 0으로 정한 다음에 하위 트리의 각 노드의 깊이를 구하십시오 dep[i].그리고 dep[]를 정렬합니다. 만약에 dep[x]+dep[y]<=k가 조건을 만족시키는 체인입니다.정렬 후 O(N)를 한 번 훑어보면 됩니다.
그러나 상술한 조작은 여전히 연구할 만한 점이 많다.우선, 단지 dep[x]+dep[y]<=k만으로는 뿌리 노드를 통과한 것을 충분히 설명할 수 없고, x와 y가 하나의 하위 트리를 사용하고 있을 수도 있다.그러나 뿌리 노드 즉 x와 y를 거쳐 서로 다른 서브트리에서 dep[x]+dep[y]<=k를 동시에 만족시키도록 강제하면 만발한다.그러므로 먼저 dep[x]+dep[y]<=k를 만족시키는 체인의 개수sum를 구한 다음에 매 순간 트리에 대해 dep[x]+dep[y]<=k의 개수tmp를 만족시키고sum에서 모든 tmp를 빼면 된다.이렇게 하면 알고리즘이 정확하게 실현된다.
근데 O(Nlog^2N)의 시간 복잡도는 별로 좋지 않은 것 같아요.실제로 많은 logN은 주로 sort에 소모된다.평면의 가장 가까운 점을 구할 때 우리는 하위 노드 (차이가 많지 않다는 뜻) 에서 정렬한 후에 다시 병합하는 방법을 사용할 수 있다.여기에서 점분치의 dep수조도 마찬가지로 서브트리에서 순서를 다 배열한 후에 병합할 수 있다.이렇게 하면 dep를 구하는 수조도 필요 없어요.그러나 이 합병은 비교적 번거롭다. 왜냐하면 여러 개의 자목이 있기 때문에 무더기로 유지해야 할 수도 있고, 상수의 차이를 고려하면 오히려 더 느릴 수도 있고, 실현의 복잡도를 고려해야 한다.내가 이렇게 약해서 당분간 생각 안 해.
다음은 AC 코드입니다.
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define inf 1000000000
#define N 300005
using namespace std;

int n,m,cnt,tot,rt,sum,ans,fst[N],pnt[N],len[N],nxt[N],c[N],d[N],sz[N],f[N];
bool vis[N];
int read(){
	int x=0; char ch=getchar();
	while (ch<'0' || ch>'9') ch=getchar();
	while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }
	return x;
}
void add(int aa,int bb,int cc){
	pnt[++tot]=bb; nxt[tot]=fst[aa]; len[tot]=cc; fst[aa]=tot;
}
void dfs(int x,int last){
	sz[x]=f[x]=1; int p;
	for (p=fst[x]; p; p=nxt[p]){
		int q=pnt[p]; if (q==last || vis[q]) continue;
		dfs(q,x); sz[x]+=sz[q];
		f[x]=max(f[x],sz[q]);
	}
	f[x]=max(f[x],sum-sz[x]); if (f[x]<f[rt]) rt=x;
}
void getdep(int x,int last){
	c[++cnt]=d[x]; int p;
	for (p=fst[x]; p; p=nxt[p]){
		int q=pnt[p]; if (q==last || vis[q]) continue;
		d[q]=d[x]+len[p]; getdep(q,x);
	}
}
int work(int x,int dep){
	d[x]=dep; cnt=0;
	getdep(x,0); sort(c+1,c+cnt+1);
	int tmp=0,l,r=cnt;
	for (l=1; l<r; l++){
		while (l<r && c[l]+c[r]>m) r--;
		tmp+=r-l;
	}
	return tmp;
}
void solve(int x){
	ans+=work(x,0);vis[x]=1; int p;
	for (p=fst[x]; p; p=nxt[p]){
		int q=pnt[p]; if (vis[q]) continue;
		ans-=work(q,len[p]);
		rt=0; sum=sz[q]; dfs(q,x); solve(rt);
	}
}
int main(){
	while (n=read()){
		m=read(); int i;
		tot=ans=0; memset(fst,0,sizeof(fst));
		memset(vis,0,sizeof(vis));
		for (i=1; i<n; i++){
			int x=read(),y=read(),z=read();
			add(x,y,z); add(y,x,z);
		}
		sum=n; f[rt=0]=inf; dfs(1,0);
		solve(rt); printf("%d
",ans); } return 0; }

2015.11.15
by lych

좋은 웹페이지 즐겨찾기