intro

Splay 是一种二叉查找树,它通过不断将某个节点旋转到根节点,使得整棵树仍然满足二叉查找树的性质,并且保持平衡而不至于退化为链

实现过程

主要过程参照传送门即可,讲得十分详细,这里就不再赘述

说一下删除操作吧,这部分当时看了比较长的时间才看明白

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
void del(int k) {
rk(k);
if (cnt[rt] > 1) {
cnt[rt]--;
maintain(rt);
return;
}
if (!ch[rt][0] && !ch[rt][1]) {
clear(rt);
rt = 0;
return;
}
if (!ch[rt][0]) {
int cur = rt;
rt = ch[rt][1];
fa[rt] = 0;
clear(cur);
return;
}
if (!ch[rt][1]) {
int cur = rt;
rt = ch[rt][0];
fa[rt] = 0;
clear(cur);
return;
}
int cur = rt, x = pre();
fa[ch[cur][1]] = x;
ch[x][1] = ch[cur][1];
clear(cur);
maintain(rt);
}

之所以能这么操作,是因为此时的cur一定没有左儿子,且一定是新rt的右儿子。因为新rt的位置是“左右右右右右…”,所以在把新rt旋转到顶端的时候,原来的顶端一定通过右旋成为了新顶端的右儿子

例题

「Cerc2007」robotic sort 机械排序

这道题需要在logn内实现一个操作,即给定区间[l,r],翻转这个区间。这时就有了splay的用武之地。

具体实现参照代码,我会附上自己的注释

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 100;
int sz[N], cnt[N], fa[N], ch[N][2], rt, n, m, lazy[N], ans[N], tot;
stack<int>S;
struct node {
int a, id;
bool operator < (node A)const {
if (a != A.a) return a < A.a;
return id < A.id;
}
} f[N];
void pushup(int x) {
sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + 1;
}
bool get(int x) {
return ch[fa[x]][1] == x;
}
void pushdown(int x) {
if (lazy[x]) {
lazy[ch[x][0]] ^= 1;
lazy[ch[x][1]] ^= 1;
swap(ch[x][0], ch[x][1]);
lazy[x] = 0;
}
}
void rotate(int x) {
int y = fa[x], z = fa[y], chk = get(x);
pushdown(y);
pushdown(x);
ch[y][chk] = ch[x][chk ^ 1];
fa[ch[x][chk ^ 1]] = y;
ch[x][chk ^ 1] = y;
fa[y] = x;
fa[x] = z;
if (z) ch[z][y == ch[z][1]] = x;
pushup(x);
pushup(y);
}
void splay(int x, int goal) {
while(fa[x] != goal) {
int y = fa[x], z = fa[y];
if (z != goal)
(ch[z][1] == y) ^ (ch[y][1] == x) ? rotate(x) : rotate(y);
rotate(x);
}
if (goal == 0) rt = x;
}
int find(int k) {
int now = rt;
while(1) {
pushdown(now);
if (ch[now][0] && k <= sz[ch[now][0]]) {
now = ch[now][0];
} else {
if (ch[now][0]) k -= sz[ch[now][0]];
k --;
if (k <= 0) return now;
now = ch[now][1];
}
}
}
int build(int l, int r, int root) {
int now = (l + r) >> 1;
fa[now] = root;
if (l < now) ch[now][0] = build(l, now - 1, now);
if (r > now) ch[now][1] = build(now + 1, r, now);
pushup(now);
return now;
}
void solve(int x) {
int y = x;
while(!S.empty()) S.pop();
while(fa[y] != 0) {
S.push(fa[y]);
y = fa[y];
}
while(!S.empty()) {
y = S.top();
S.pop();
pushdown(y);
}
splay(x, 0);
}

int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
scanf("%d", &f[i].a);
f[i].id = i;
}
int x, y;
sort(f + 1, f + n + 1);
f[0].id = 0, f[n + 1].id = n + 1;
build(1, n + 2, 0);//总共有n+2个点
for (int i = 1; i <= n; ++i) {
solve(f[i].id + 1); //每个点的编号是f[i].id+1
ans[i] = sz[ch[rt][0]];
x = find(i); // 区间是 [i+ 1, ans[i] + 1] , 要左边减一,右边加一 .
y = find(ans[i] + 2);
//这两步找出要翻转的区间外的左边和右边两个端点,其中夹的区间就是要翻转的区间
splay(x, 0);
splay(y, x);
lazy[ch[ch[rt][1]][0]] ^= 1;
}
for (int i = 1; i <= n; ++i)
printf("%d ", ans[i]);
printf("\n");
return 0;
}
/*
6
3 4 5 1 6 2

*/


bzoj 2827 千山鸟飞绝

这个问题的关键在于维护一个 数据结构,能动态删除特定点,且维护一个最大值即可。其实动态开点线段树也能解决。比较简单,直接上代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define mp make_pair
using namespace std;
const int maxn=3e5+5;
int siz[maxn],f[maxn],son[maxn][2],val[maxn],tot_size,root[maxn],a[maxn];
int laz1[maxn],laz2[maxn],ans1[maxn],ans2[maxn],q[maxn],num,w[maxn];
map<pii,int>vis;
inline ll read()
{
ll ret=0;char son=' ',c=getchar();
while(!(c<='9'&&c>='0')) son=c,c=getchar();
while(c<='9'&&c>='0') ret=(ret<<1)+(ret<<3)+c-'0',c=getchar();
return son=='-'?-ret:ret;
}
int get(int x) {return x==son[f[x]][1];}
void clear(int x) {son[x][1]=son[x][0]=siz[x]=f[x]=val[x]=laz1[x]=laz2[x]=0;}
void update(int x)
{
siz[x]=siz[son[x][1]]+siz[son[x][0]]+1;
val[x]=max(max(val[son[x][1]],val[son[x][0]]),w[x]);
}
void rotate(int x)
{
int y=f[x],z=f[y],type=get(x);
son[y][type]=son[x][type^1];
f[son[x][type^1]]=y;
son[x][type^1]=y;
f[y]=x,f[x]=z;
if(z) son[z][y==son[z][1]]=x;
update(y);update(x);
}
void pushdown(int x)
{
int k1=laz1[x],k2=laz2[x];
for(int i=0;i<=1;i++)
{
if(son[x][i])
{
int j=son[x][i];
ans1[j]=max(ans1[j],k1);
ans2[j]=max(ans2[j],k2);
laz1[j]=max(laz1[j],k1);
laz2[j]=max(laz2[j],k2);
}
}
laz1[x]=laz2[x]=0;
}
void splay(int x,int &rt)
{
int top=0;
for(int fa=x;fa;fa=f[fa]) q[++top]=fa;
while(top) pushdown(q[top--]);
for(int fa;fa=f[x];rotate(x)) if(f[fa]) rotate(get(fa)==get(x)?fa:x);
rt=x;
}
int find_pre(int p)
{
int now=son[root[p]][0];
while(son[now][1]) now=son[now][1];
return now;
}
void del(int p,int x)
{
splay(x,root[p]);
int now=root[p];
if(!son[now][0]&&!son[now][1])
{
clear(now);
root[p]=0;
return;
}
for(int i=0;i<=1;i++)
{
if(!son[now][i])
{
root[p]=son[now][i^1];
f[root[p]]=0;
clear(now);
return;
}
}
x=find_pre(p);
splay(x,root[p]);
f[son[now][1]]=x;
son[x][1]=son[now][1];
clear(now);
update(root[p]);
}
void insert(int &p,int x,int fa)
{
if(!p)
{
p=x,f[p]=fa,val[p]=w[p],siz[p]=1;
return;
}
pushdown(p);
if(!son[p][0]) insert(son[p][0],x,p);
else insert(son[p][1],x,p);
}
int main()
{
int n=read();
for(int i=1;i<=n;i++)
{
w[i]=read();
int x=read(),y=read();
if(!vis.count(mp(x,y))) vis[mp(x,y)]=++num;
int k=vis[mp(x,y)];
a[i]=k;
ans1[i]=max(ans1[i],val[root[k]]);
ans2[i]=max(ans2[i],siz[root[k]]);
insert(root[k],i,0);
splay(i,root[k]);
laz1[root[k]]=max(laz1[root[k]],w[i]);
laz2[root[k]]=max(laz2[root[k]],siz[root[k]]-1);
}
int t=read();
while(t--)
{
int i=read(),x=read(),y=read();
if(!vis.count(mp(x,y))) vis[mp(x,y)]=++num;
int k=vis[mp(x,y)],pre=a[i];
a[i]=k;
del(pre,i);
ans1[i]=max(ans1[i],val[root[k]]);
ans2[i]=max(ans2[i],siz[root[k]]);
insert(root[k],i,0);
splay(i,root[k]);
laz1[root[k]]=max(laz1[root[k]],w[i]);
laz2[root[k]]=max(laz2[root[k]],siz[root[k]]-1);
}
for(int i=1;i<=n;i++)
{
splay(i,root[a[i]]);
printf("%lld\n",1ll*ans1[i]*ans2[i]);
}
return 0;
}