树链剖分,听起来确实是一种很高级的算法,但其实它并没有想象中的那么难以理解,事实上,个人觉得,树剖其实根本没有什么太大的思维难度(老实说我觉得怕不是背包都比它要难理解),只是码量大亿点,但只要熟练掌握了,打代码其实也并没有那么难

所谓树链剖分,是用来解决一类树上问题的,它将一棵树剖成很多条链,把树上问题转化成序列问题,然后用其它一些数据结构,比如线段树来维护树上路径的信息

举个例子,比如给定一棵树,每个点有自己的权值,要求查询某两点之间的路径上的所有点的点权和,看起来很简单是不是?用倍增在跳的时候顺便统计就行了,那么如果再要求支持改变某个点的点权呢,或是给某两点之间的路径上的所有点的点权全部加上一个数呢?倍增很明显就不行了吧,这时候就要用到树剖了

思路

树链剖分一般指的是重链剖分,在讲解它的思路之前,我们先要明确几个概念

  • 重儿子:一个点的所有儿子中子树节点数量最多的儿子,如果有多个,那就随便选一个
  • 轻儿子:一个点所有儿子中除重儿子以外的其他儿子,也就是说,对于一个点,重儿子是唯一的,但轻儿子不唯一
  • 重边:一个点到它的重儿子之间的边
  • 轻边:一个点到它的轻儿子之间的边
  • 重链:一大堆重边组成的一条链

这么说起来可能有点抽象,还是拿张图最容易理解

在这张图中,蓝色节点表示轻儿子,橙色节点表示重儿子,相应的,蓝色边表示轻边,橙色边表示重边,由绿框框起来的就是重链,特别的,单独一个点也可以叫做重链

这个过程完了之后,整棵树就会被完全剖分成一条一条的重链

接下来是重点

对于这一条一条的重链,很明显我们还是不能直接用线段树去维护,因为每条链中的节点编号并不是连续的,所以,我们还要引入一个东西——DFS 序

这个东西就是树剖把树上问题转化成序列问题的关键,所谓 DFS 序呢,就是在对这棵树进行 DFS 的时候,标记每个点是第几个到达的,其实也就是强连通分量 Tarjan 算法里的 dfn 数组,但树剖不太一样,因为我们需要让一条重链上所有点的 DFS 序连续,这样才好让这条链变成一个区间,所以,我们在对这棵树进行 DFS 的时候,优先遍历重儿子,这样就可以保证一条重链上的点的 DFS 序连续,因为是先把这条重链一拉到底之后再遍历其他的重链

对于上面那棵树,它的 DFS 序如下(拿蓝框框起来的就是)

我们可以用每个节点的 DFS 序建一棵线段树,这样,一条重链就是一个连续的区间了,也便于维护

代码实现

首先我们先来看一些变量

  • dep[]:记录每个点的深度
  • son[]:记录每个点的重儿子
  • size[]:记录以当前点为根的子树的节点个数
  • id[]:记录每个点的 DFS 序
  • rk[]:记录每个 DFS 序对应的点
  • top[]:记录当前点所在重链的顶部(下称链头),也就是深度最浅的点

其实树链剖分一共只需要两次 DFS 就可以解决了,第一次求出每个点的深度和重儿子(dep[],son[],size[]),第二次记录每个点的 DFS 序,相当于是连重边(id[],rk[],top[]),代码很短,也很好理解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
void dfs1(int u,int father)//第一次 DFS
{
dep[u]=dep[father]+1;
fa[u]=father;
size[u]=1;//这个是把 u 点自己给算进去
for(int i=0;i<edge[u].size();i++)//这里是用的 vector 存图
if(edge[u][i]!=father)
{
dfs1(edge[u][i],u);//先遍历子树
size[u]+=size[edge[u][i]];//累加当前点的子树的节点数
son[u]=size[edge[u][i]]>size[son[u]]? edge[u][i]:son[u];//找重儿子,如果当前儿子的子树比重儿子的子树节点数量更多的话就换过去
}
}
void dfs2(int u,int father,int t)//第二次 DFS,t 是当前点所在重链的顶部
{
cnt++;//DFS 序
top[u]=t;
rk[cnt]=u;//当前 DFS 序对应的节点编号
id[u]=cnt;//当前节点对应的 DFS 序
if(son[u])
dfs2(son[u],u,t);//一定要优先遍历重儿子啊!!!
for(int i=0;i<edge[u].size();i++)
if(edge[u][i]!=son[u]&&edge[u][i]!=father)
dfs2(edge[u][i],u,edge[u][i]);
}

在第二个 DFS 中,之所以下面遍历轻儿子的时候把轻儿子所在的重链顶部设成是轻儿子,是因为当前节点和轻儿子并不在一条重链里,自然也就无法充当它所在的重链的顶部

代码是真的很短了,但这其实只是树剖本身的实现,不要忘了,它是用来解决一类问题的,剖的过程确实很简单,但要维护却比较恼火,上面提到过,可以用线段树来维护重链上的信息,这玩意本来码量就很大,再加上有时候还会结合 LCA,就更令人难受了

接下来,我会讲一下如何用树剖求 LCA,当然,一般来说用树剖求 LCA,就一定会有路径查询和路径修改,我会顺带着把这两个也讲一下的

树剖+LCA

如果是做过树剖的题的人,应该知道在这种题中一般会有这种要求,即求这棵树上从uuvv的最短路径上所有点权之和,以及将这棵树上从uuvv的最短路径上所有点的点权加上一个数,而这个最短路径,很明显就是要求 LCA

树剖求 LCA,思路其实和倍增差不多,都是往上跳,直到跳到同一个点结束,但树剖与倍增又有一点不同,倍增是往上跳2k2^k个祖先,而树剖则是直接跳到链头的父亲,因为如果uuvv在同一条重链上,那么可以肯定他们中有一个是对方的祖先,比较一下深度就行了,如果不在同一条重链上,那么我们就先让他们跳到同一条重链上,再按照前面的方法执行

首先,我们比较uuvv链头的深度,避免跳过头,接下来,我们把链头深度较深的那一个(假设是uu)跳到他链头的父亲,因为如果只跳到链头,那么很明显这个点所在的重链并没有变,只有跳到链头的父亲才是到了另外一条重链,如此循环,直到uuvv的链头是同一个点,也就是它们处于同一条重链上,这时深度浅的那一个就是LCA(u,v)\text{LCA}(u,v)

老规矩,用上面那张图手动模拟一下

假设我们求 14 和 16 号点的 LCA,过程如下:

  1. 首先比较各自链头的深度,很明显是 16 号点的链头深度更深,因此把 16 号点跳到链头的父亲 11 号点
  2. 再次比较 14 号点和 11 号点的链头深度,这次是 14 号点,跳到链头的父亲,为 7 号点
  3. 发现 7 号点和 11 号点的链头相同,也就是它们处于同一条重链,退出循环
  4. 因为 7 号点深度比 11 号点的深度更浅,所以LCA(14,16)=7\text{LCA}(14,16)=7

当然,一般来说如果只是单纯求 LCA 还用不到树剖上场,如果必须要用树剖,那就肯定是加了路径修改和路径查询,这两个也是导致树剖码量大的一个很重要的原因。

其实只要掌握了树剖求 LCA 的方法,修改和查询也不是什么难事,上面说过,我们是用节点的 DFS 序建的线段树,因此每条重链都是一个连续的区间,而重链上的每个点到链头很明显也是连续的,结合刚刚求 LCA 的跳法,我们只需要在uu(或vv)跳到链头的父亲之前先修改(或查询)uu到链头的点,并在最后出于同一重链上时把uuvv之间的点进行修改(或查询)就行了,具体的可以看一下代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
int lca_query(int u,int v)//查询
{
int ans=0;
while(top[u]!=top[v])//比较链头是否相同
if(dep[top[u]]>=dep[top[v]])
{
ans=ans+query(1,id[top[u]],id[u]);//query 是线段树的区间查询函数,和普通写法是一样的
u=fa[top[u]];//让这个点跳到链头的父亲
}
else
{
ans=ans+query(1,id[top[v]],id[v]);
v=fa[top[v]];
}
if(dep[u]>=dep[v])
ans=ans+query(1,id[v],id[u]);//u 和 v 处于同一重链上时单独处理
else
ans=ans+query(1,id[u],id[v]);
return ans;
}
void lca_update(int u,int v,int k)//和上面是一样的
{
while(top[u]!=top[v])
if(dep[top[u]]>=dep[top[v]])
{
update(1,id[top[u]],id[u],k);//update 是线段树的区间修改函数,和普通写法也是一样的
u=fa[top[u]];
}
else
{
update(1,id[top[v]],id[v],k);
v=fa[top[v]];
}
if(dep[u]>=dep[v])
update(1,id[v],id[u],k);
else
update(1,id[u],id[v],k);
}

(感觉其实不算很难啊)

一般来说,树剖的题基本上就是路径修改,路径查询,单点修改,单点查询,修改子树,查询子树这几种,后面四种都可以直接用线段树来完成,因为一个点的子树的 DFS 序也是一个完整的区间,比如上图中 7 的子树的 DFS 序就是从 3 到 8,由于之前记录了每棵子树的节点数量,所以这里只需要修改从 id[i] 到 id[i]+size[i]-1 这个区间就行了

另外,树剖求 LCA 的时间复杂度是O(logn)O(\log n),且常数较小,不容易被卡掉,预处理也是O(n)O(n)级别的,所以还算比较优秀了

参考代码

题目:P3384 【模板】轻重链剖分

这道题要求支持路径修改,路径查询,子树修改,子树查询四种操作,上面都已经讲过了,就直接贴代码了

感受树剖的码量吧!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
#include<bits/stdc++.h>
#define MAXN 100005
using namespace std;
struct node
{
int l;
int r;
int data;
int f;
};
node tree[MAXN<<2];
vector<int> edge[MAXN];
int n,m,r,p,cnt,w[MAXN];
int dep[MAXN],fa[MAXN],son[MAXN],size[MAXN],top[MAXN],rk[MAXN],id[MAXN];
void dfs1(int u,int father)//第一次 DFS
{
dep[u]=dep[father]+1;
fa[u]=father;
size[u]=1;
for(int i=0;i<edge[u].size();i++)
if(edge[u][i]!=father)
{
dfs1(edge[u][i],u);
size[u]+=size[edge[u][i]];
son[u]=size[edge[u][i]]>size[son[u]]? edge[u][i]:son[u];
}
}
void dfs2(int u,int father,int t)//第二次 DFS
{
cnt++;
top[u]=t;
rk[cnt]=u;
id[u]=cnt;
if(son[u])
dfs2(son[u],u,t);
for(int i=0;i<edge[u].size();i++)
if(edge[u][i]!=son[u]&&edge[u][i]!=father)
dfs2(edge[u][i],u,edge[u][i]);
}
void push_down(int i)//懒标记下传
{
int l=i<<1,r=i<<1|1;
if(!tree[i].f)
return;
tree[l].data=(tree[l].data+tree[i].f*(tree[l].r-tree[l].l+1)%p)%p;
tree[r].data=(tree[r].data+tree[i].f*(tree[r].r-tree[r].l+1)%p)%p;
tree[l].f=(tree[l].f+tree[i].f)%p;
tree[r].f=(tree[r].f+tree[i].f)%p;
tree[i].f=0;
}
void build(int i,int l,int r)//建线段树
{
tree[i].l=l;
tree[i].r=r;
if(l==r)
{
tree[i].data=w[rk[l]]%p;
return;
}
build(i<<1,l,(l+r)>>1);
build(i<<1|1,((l+r)>>1)+1,r);
tree[i].data=(tree[i<<1].data+tree[i<<1|1].data)%p;
}
void update(int i,int l,int r,int k)//线段树区间修改
{
if(tree[i].l>=l&&tree[i].r<=r)
{
tree[i].data=(tree[i].data+(tree[i].r-tree[i].l+1)*k%p)%p;
tree[i].f=(tree[i].f+k)%p;
return;
}
push_down(i);
if(tree[i<<1].r>=l)
update(i<<1,l,r,k);
if(tree[i<<1|1].l<=r)
update(i<<1|1,l,r,k);
tree[i].data=(tree[i<<1].data+tree[i<<1|1].data)%p;
}
int query(int i,int l,int r)//线段树区间查询
{
int ans=0;
if(tree[i].l>=l&&tree[i].r<=r)
return tree[i].data%p;
push_down(i);
if(tree[i<<1].r>=l)
ans=(ans+query(i<<1,l,r))%p;
if(tree[i<<1|1].l<=r)
ans=(ans+query(i<<1|1,l,r))%p;
return ans%p;
}
int lca_query(int u,int v)//路径查询
{
int ans=0;
while(top[u]!=top[v])
if(dep[top[u]]>=dep[top[v]])
{
ans=(ans+query(1,id[top[u]],id[u]))%p;
u=fa[top[u]];
}
else
{
ans=(ans+query(1,id[top[v]],id[v]))%p;
v=fa[top[v]];
}
if(dep[u]>=dep[v])
ans=(ans+query(1,id[v],id[u]))%p;
else
ans=(ans+query(1,id[u],id[v]))%p;
return ans;
}
void lca_update(int u,int v,int k)//路径修改
{
while(top[u]!=top[v])
if(dep[top[u]]>=dep[top[v]])
{
update(1,id[top[u]],id[u],k);
u=fa[top[u]];
}
else
{
update(1,id[top[v]],id[v],k);
v=fa[top[v]];
}
if(dep[u]>=dep[v])
update(1,id[v],id[u],k);
else
update(1,id[u],id[v],k);
}
int main()
{
scanf("%d%d%d%d",&n,&m,&r,&p);
for(int i=1;i<=n;i++)
scanf("%d",&w[i]);
for(int i=1;i<=n-1;i++)
{
int u,v;
scanf("%d%d",&u,&v);
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs1(r,0);
dfs2(r,0,r);
build(1,1,n);
for(int i=1;i<=m;i++)
{
int op,x,y,z;
scanf("%d",&op);
switch(op)
{
case 1:
scanf("%d%d%d",&x,&y,&z);
lca_update(x,y,z);
break;
case 2:
scanf("%d%d",&x,&y);
printf("%d\n",lca_query(x,y));
break;
case 3:
scanf("%d%d",&x,&z);
update(1,id[x],id[x]+size[x]-1,z);
break;
case 4:
scanf("%d",&x);
printf("%d\n",query(1,id[x],id[x]+size[x]-1));
break;
}
}
return 0;
}

其实不算很难懂啦,关键就是码量太大了。

相关题目

其实真正开始做树剖的题就会发现,其实很多代码是通用的,完全可以把上一道题的代码复制过来稍微改一下就行了,但我个人不推荐这样做,毕竟要想熟练掌握代码,最好的方法就是多打几遍嘛,而且这样还可以练手速,做题做多了之后就可以越打越快(我有一次曾经一天内做了 8 道树剖题,手都快打废了