重链剖分

概述

由于重链剖分时,从根节点到叶子节点经过的重链数量必定小于等于lognlogn这一量级,故可以通过对每一条重链进行维护和求解得到最终答案。

只要经过了重链的某一个节点,就算是经过了该重链

一般来说,可以结合线段树维护dfs序,再通过类似于找LCA的方式向上跳,只不过这里向上跳是跳到该重链最上面的端点。然后在跳的过程中利用线段树对经过的每一条重链进行查询,对于每一次查询在log2nlog^2 n的复杂度求出答案。

例题

「SDOI2014」旅行

由于有多个不同类型的点,对于每个种类的点我们都要开一颗线段树。但是由于点的种类能达到10510^5之多,故不能像以往那样每次都直接建立出一个大小为nlognnlogn的线段树。这里在建树时,我们只能一个一个地加点,并且加点时只需要开辟根节点到该节点这条路径上的点即可(每次要存储lognlogn个点)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include <bits/stdc++.h>
using namespace std;
const int maxn=100000+10;
int n,q,w[maxn],c[maxn],T[maxn],sz;
int head[maxn],to[maxn<<1],nxt[maxn<<1],tot;
int dep[maxn],top[maxn],son[maxn],siz[maxn],fa[maxn],id[maxn],mp[maxn],tim;
struct node{
int lson,rson,val,sum;
#define mid (l+r>>1)
#define lson(rt) tree[rt].lson
#define rson(rt) tree[rt].rson
#define val(rt) tree[rt].val
#define sum(rt) tree[rt].sum
}tree[maxn*24];

inline int read(){
register int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return (f==1)?x:-x;
}
inline void add(int x,int y){//连边
to[++tot]=y;
nxt[tot]=head[x];
head[x]=tot;
}

void dfs1(int x,int f){//处理深度、重儿子、子树大小和父亲
dep[x]=dep[f]+1;
siz[x]=1;
fa[x]=f;
int maxson=-1;
for(int i=head[x],y;i;i=nxt[i]){
y=to[i];
if(y==f) continue;
dfs1(y,x);
siz[x]+=siz[y];
if(maxson<siz[y]){
maxson=siz[y];
son[x]=y;
}
}
}

void dfs2(int x,int topf){//处理重链轻边
id[x]=++tim;
mp[tim]=x;
top[x]=topf;
if(son[x]) dfs2(son[x],topf);
for(int i=head[x],y;i;i=nxt[i]){
y=to[i];
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);//不能写dfs2(y,x)!
}
}

void pushup(int rt){//上传操作
val(rt)=max(val(lson(rt)),val(rson(rt)));
sum(rt)=sum(lson(rt))+sum(rson(rt));
}
void pushdown(int rt){//删除结点
val(rt)=sum(rt)=0;
}
void update(int &rt,int x,int l,int r,int v){//动态开点
if(!rt) rt=++sz;
if(l == r){
val(rt)=sum(rt)=v;
return ;
}
if(x <= mid) update(lson(rt),x,l,mid,v);
else update(rson(rt),x,mid+1,r,v);
pushup(rt);
}
void del(int &rt,int x,int l,int r){//删除结点
if(!rt) return ;//剪枝
if(l == r){
pushdown(rt);
return ;
}
if(x <= mid) del(lson(rt),x,l,mid);
else del(rson(rt),x,mid+1,r);
pushup(rt);
}
int query_sum(int &rt,int L,int R,int l,int r){//查询区间和
if(!rt) return 0;
if(L <= l && r <= R){
return sum(rt);
}
int ans=0;
if(L <= mid) ans+=query_sum(lson(rt),L,R,l,mid);
if(R > mid) ans+=query_sum(rson(rt),L,R,mid+1,r);//不能写else!
return ans;
}
int query_max(int &rt,int L,int R,int l,int r){//查询区间最大值
if(!rt) return 0;
if(L <= l && r <= R){
return val(rt);
}
int ans=0;
if(L <= mid) ans=max(ans,query_max(lson(rt),L,R,l,mid));
if(R > mid) ans=max(ans,query_max(rson(rt),L,R,mid+1,r));//不能写else!
return ans;
}

int main()
{
n=read(),q=read();int x,y,k,ans;
for(int i=1;i<=n;i++)
w[i]=read(),c[i]=read();
for(int i=1;i<n;i++){
x=read(),y=read();
add(x,y);add(y,x);
}
dfs1(1,0);dfs2(1,1);
for(int i=1;i<=n;i++)//预处理
update(T[c[mp[i]]],i,1,n,w[mp[i]]);
char opt[50];
while(q--){
scanf("%s",opt);
if(opt[1]=='C'){//更改宗教
x=read(),k=read();
del(T[c[x]],id[x],1,n);
c[x]=k;
update(T[c[x]],id[x],1,n,w[x]);
}
if(opt[1]=='W'){//更改评分
x=read(),k=read();
del(T[c[x]],id[x],1,n);
w[x]=k;
update(T[c[x]],id[x],1,n,w[x]);
}
if(opt[1]=='S'){//树剖求和
x=read(),y=read();
k=c[x];ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);//不能写dep[x]<dep[y]!
ans+=query_sum(T[k],id[top[x]],id[x],1,n);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=query_sum(T[k],id[x],id[y],1,n);
printf("%d\n",ans);
}
if(opt[1]=='M'){//树剖求最大值
x=read(),y=read();
k=c[x];ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);//不能写dep[x]<dep[y]!
ans=max(ans,query_max(T[k],id[top[x]],id[x],1,n));
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans=max(ans,query_max(T[k],id[x],id[y],1,n));
printf("%d\n",ans);
}
}
return 0;
}

长链剖分

概述

在重链剖分中,一个点的重儿子是子树中最大的儿子。而长链剖分时,“重儿子”则指深度最深的儿子。

长链剖分时,叶子节点到根节点经过的重链数量必定小于等于n\sqrt{n}这一量级

证明
要经过尽量多的重链,不妨设有一个叶子,它到根节点经过的所有边都是轻边(连接轻儿子与其父节点的边)。假设该叶子的深度为depthdepth,则至少有depthdepth条链的叶子节点的深度大于等于depthdepth。则有depth+depth1+depth2+...+1ndepth+depth-1+depth-2+...+1\leq n,故depth(1+depth)/2ndepth*(1+depth)/2\leq n。证毕。

长链剖分多用于DP中。若DP的方程与当前节点的深度有关,可以利用长链剖分优化时间复杂度至O(n)O(n)。不过要利用这一点需要用到指针,具体使用方式参见例题。

例题

[POI2014]HOT-Hotels 加强版

这道题的难点在于DP方程的构思。具体可参见洛谷题解

首先,仍然根据重儿子优先的原则进行遍历。这样一来,对于每一条重链,该链上所有的节点的dfs序都是连着的。由于我们在dfs的过程中利用指针依次把每个点对应的值存放在数组中,故同一条重链上的点的存储地址也是连续的。

当我们需要更新某一节点的DP值时,它将自动继承其重儿子的所有参数。

首先借助这道题目讲讲为什么要在长链剖分中使用指针。
这里的f的更新方式为f[cur][j+1]+=f[to][j]f[cur][j + 1] += f[to][j],其中cur为当前节点,to为轻儿子节点。
有人可能会问,为什么循环中只遍历了轻儿子,更新了轻儿子的答案,那么重儿子呢?可以注意到,其实f[cur][j+1]f[cur][j + 1]f[son[cur]][j]f[son[cur]][j]的储存位置是一样的。这相当于在不花任何时间代价的前提下直接继承了重儿子处理得到的答案,即自动完成了f[cur][j+1]=f[son[cur]][j]f[cur][j + 1] = f[son[cur]][j]这一步。不得不说,这个处理方式实在是巧妙。

还有一点需要注意,这里的g的性质和f有些不一样。由于g[cur][j]=g[son[cur]][j+1]g[cur][j]=g[son[cur]][j + 1],所以对于某点的g,应该从该点重儿子后一位开始继承。所以我们要把在一条重链上的点的g倒着存。这里我想了好久才反应过来QAQ

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define re register int
#define FE(x) for (int h = head[x], to; ~h; h = gg[h].nxt)
const int maxn = 200005;
int n, cnt, head[maxn], height[maxn], son[maxn], dep[maxn], *f[maxn << 2], *g[maxn << 2], ans, tmpf[maxn << 2], tmpg[maxn << 2], *posf, *posg;
struct Arc
{
int from, to, nxt;
} gg[maxn << 1];

inline void addarc(int from, int to)
{
gg[cnt].from = from, gg[cnt].to = to, gg[cnt].nxt = head[from];
head[from] = cnt++;
}

void dfs(int cur, int fa)
{
height[cur] = dep[cur] = dep[fa] + 1;
FE(cur)
{
to = gg[h].to;
if (to ^ fa)
{
dfs(to, cur);
if (height[cur] < height[to])
{
height[cur] = height[to];
son[cur] = to;
}
}
}
}

void dfs1(int cur, int fa)
{
f[cur][0] = 1;
if (son[cur])
{
f[son[cur]] = f[cur] + 1;
g[son[cur]] = g[cur] - 1;
dfs1(son[cur], cur);
if (son[son[cur]])
ans += g[son[cur]][1];
}
FE(cur)
{
to = gg[h].to;
if (to ^ fa && to ^ son[cur])
{
int len = height[to] - dep[to] + 1;
f[to] = posf;
g[to] = posg + len;
posf += len;
posg += len * 2;
dfs1(to, cur);
for (re j = 0; j < len; j++)
{
if (j)
ans += f[cur][j - 1] * g[to][j];
ans += f[to][j] * g[cur][j + 1];
g[cur][j + 1] += f[to][j] * f[cur][j + 1];
}
for (re j = 0; j < len; j++)
{
if (j)
g[cur][j - 1] += g[to][j];
f[cur][j + 1] += f[to][j];
}
}
}
}

signed main()
{
memset(head, -1, sizeof(head));
scanf("%lld", &n);
for (re i = 1, a, b; i < n; i++)
{
scanf("%lld%lld", &a, &b);
addarc(a, b);
addarc(b, a);
}
dep[1] = 1;
dfs(1, 0);
posf = tmpf, posg = tmpg;
int len = height[1];
f[1] = posf;
g[1] = posg + len;
posf += len, posg += len << 1;
dfs1(1, 0);
printf("%lld", ans);
return 0;
}

#3252. 攻略

看到这道题目,很容易想到类似于找最小生成树的贪心算法。即一直选最长的链,选了这条链之后再选排除已选节点后剩下节点中的最长链,一直重复这样的操作直到选出k条链。

那么问题来了,为什么贪心是正确的呢?

首先,按照我们的贪心方法选出的每条链中深度最深的点一定是叶子节点。

假设根节点有m个儿子,使用贪心策略选出的链的叶子一定属于根节点某个儿子的子树。设对于点i(1im)i(1\leq i\leq m),有f(i)f(i)个这样的叶子位于该点的子树中,产生了一个排列,即f(1),f(2),...,f(m)f(1),f(2),...,f(m)

可以证明,只有这样的排列才是最优解。假设我们在i(1im)i(1\leq i\leq m)中多选了一条链,由于我们只能选k条链,所以我们会在j(1jm)j(1\leq j\leq m)中少选一条链。然而后者的长度大于前者,所以这样选不会更优。同理,我们可以把根节点的每一个儿子节点i,看成新的根节点,把k变成f(i),继续进行这样的分析。最终,就完成了贪心策略正确性的证明。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=2e5+5;
int n,k;
vector<int>e[maxn];
int w[maxn];
int son[maxn];
ll ans;
ll dep[maxn];
void dfs1(int u){
dep[u]=w[u];
for(auto x:e[u]){
dfs1(x);
if(dep[u]<w[u]+dep[x]){
dep[u]=w[u]+dep[x];
son[u]=x;
}
}
}
priority_queue<ll,vector<ll>>q;
void dfs2(int u,int top){
if(son[u])dfs2(son[u],top);
for(auto x:e[u]){
if(x==son[u])continue;
q.push(dep[x]);
dfs2(x,x);
}
}
int main(){
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++)scanf("%d",&w[i]);
for(int i=0;i<n-1;i++){
int u,v;scanf("%d%d",&u,&v);
e[u].push_back(v);
}
dfs1(1);dfs2(1,1);
q.push(dep[1]);
while(k--&&!q.empty()){
ans+=q.top();
q.pop();
}
printf("%lld",ans);
return 0;
}