模板

洛谷 P4719 【模板】动态 DP

因为要动态维护树上的DP方程,所以很容易想到利用树链剖分进行优化。进而考虑使用重链的性质来化简DP过程的复杂度。这里将重链和轻链的对DP的贡献分开考虑,从而使用广义矩阵乘法,使DP过程得以在线段树上实现。

证明广义矩阵乘法的结合律
因为max+*都满足结合律和交换律,并且*+有分配律,+max也有分配律。我们将max+分别看成*+,和证明普通矩阵乘法的方法一样,把计算的每一项都展开即可。然后可发现对于矩阵A,B,CA,B,C做广义矩阵乘法,满足(AB)C=A(BC)(AB)C=A(BC)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include <bits/stdc++.h>

using namespace std;

#define REP(i, a, b) for (int i = (a), _end_ = (b); i <= _end_; ++i)
#define mem(a) memset((a), 0, sizeof(a))
#define str(a) strlen(a)
#define lson root << 1
#define rson root << 1 | 1
typedef long long LL;

const int maxn = 500010;
const int INF = 0x3f3f3f3f;

int Begin[maxn], Next[maxn], To[maxn], e, n, m;
int size[maxn], son[maxn], top[maxn], fa[maxn], dis[maxn], p[maxn], id[maxn],
End[maxn];
// p[i]表示i树剖后的编号,id[p[i]] = i
int cnt, tot, a[maxn], f[maxn][2];

struct matrix {
int g[2][2];
matrix() { memset(g, 0, sizeof(g)); }
matrix operator*(const matrix &b) const // 重载矩阵乘
{
matrix c;
REP(i, 0, 1)
REP(j, 0, 1) REP(k, 0, 1) c.g[i][j] = max(c.g[i][j], g[i][k] + b.g[k][j]);
return c;
}
} Tree[maxn], g[maxn]; // Tree[]是建出来的线段树,g[]是维护的每个点的矩阵

inline void PushUp(int root) { Tree[root] = Tree[lson] * Tree[rson]; }

inline void Build(int root, int l, int r) {
if (l == r) {
Tree[root] = g[id[l]];
return;
}
int Mid = l + r >> 1;
Build(lson, l, Mid);
Build(rson, Mid + 1, r);
PushUp(root);
}

inline matrix Query(int root, int l, int r, int L, int R) {
if (L <= l && r <= R) return Tree[root];
int Mid = l + r >> 1;
if (R <= Mid) return Query(lson, l, Mid, L, R);
if (Mid < L) return Query(rson, Mid + 1, r, L, R);
return Query(lson, l, Mid, L, R) * Query(rson, Mid + 1, r, L, R);
// 注意查询操作的书写
}

inline void Modify(int root, int l, int r, int pos) {
if (l == r) {
Tree[root] = g[id[l]];
return;
}
int Mid = l + r >> 1;
if (pos <= Mid)
Modify(lson, l, Mid, pos);
else
Modify(rson, Mid + 1, r, pos);
PushUp(root);
}

inline void Update(int x, int val) {
g[x].g[1][0] += val - a[x];
a[x] = val;
// 首先修改x的g矩阵
while (x) {
matrix last = Query(1, 1, n, p[top[x]], End[top[x]]);
// 查询top[x]的原本g矩阵
Modify(1, 1, n,
p[x]); // 进行修改(x点的g矩阵已经进行修改但线段树上的未进行修改)
matrix now = Query(1, 1, n, p[top[x]], End[top[x]]);
// 查询top[x]的新g矩阵
x = fa[top[x]];
g[x].g[0][0] +=
max(now.g[0][0], now.g[1][0]) - max(last.g[0][0], last.g[1][0]);
g[x].g[0][1] = g[x].g[0][0];
g[x].g[1][0] += now.g[0][0] - last.g[0][0];
// 根据变化量修改fa[top[x]]的g矩阵
}
}

inline void add(int u, int v) {
To[++e] = v;
Next[e] = Begin[u];
Begin[u] = e;
}

inline void DFS1(int u) {
size[u] = 1;
int Max = 0;
f[u][1] = a[u];
for (int i = Begin[u]; i; i = Next[i]) {
int v = To[i];
if (v == fa[u]) continue;
dis[v] = dis[u] + 1;
fa[v] = u;
DFS1(v);
size[u] += size[v];
if (size[v] > Max) {
Max = size[v];
son[u] = v;
}
f[u][1] += f[v][0];
f[u][0] += max(f[v][0], f[v][1]);
// DFS1过程中同时求出f[i][0/1]
}
}

inline void DFS2(int u, int t) {
top[u] = t;
p[u] = ++cnt;
id[cnt] = u;
End[t] = cnt;
g[u].g[1][0] = a[u];
g[u].g[1][1] = -INF;
if (!son[u]) return;
DFS2(son[u], t);
for (int i = Begin[u]; i; i = Next[i]) {
int v = To[i];
if (v == fa[u] || v == son[u]) continue;
DFS2(v, v);
g[u].g[0][0] += max(f[v][0], f[v][1]);
g[u].g[1][0] += f[v][0];
// g矩阵根据f[i][0/1]求出
}
g[u].g[0][1] = g[u].g[0][0];
}

int main() {
#ifndef ONLINE_JUDGE
freopen("input.txt", "r", stdin);
freopen("output.txt", "w", stdout);
#endif
scanf("%d%d", &n, &m);
REP(i, 1, n) scanf("%d", &a[i]);
REP(i, 1, n - 1) {
int u, v;
scanf("%d%d", &u, &v);
add(u, v);
add(v, u);
}
dis[1] = 1;
DFS1(1);
DFS2(1, 1);
Build(1, 1, n);
REP(i, 1, m) {
int x, val;
scanf("%d%d", &x, &val);
Update(x, val);
matrix ans = Query(1, 1, n, 1, End[1]); // 查询1所在重链的矩阵乘
printf("%d\n", max(ans.g[0][0], ans.g[1][0]));
}
return 0;
}

例题

SP1716 GSS3 - Can you answer these queries III

常规做法求解

考虑使用线段树维护序列求解。问题在于,如何维护区间最大子段

对于一个区间[l,r][l,r],同时设mid=(l+r)/2mid=(l+r)/2。则最大子段有三种情况:

  • 子段在区间[l,mid][l,mid]
  • 子段在区间[mid+1,r][mid+1,r]
  • 子段有一部分在[l,mid][l,mid]内,有一部分在[mid+1,r][mid+1,r]

则问题在于如何解决第三种情况。

我们记resres为区间最长子段和,sumsum为区间总和,prelprelprerprer 分别表示从区间左端点和右端点开始的最大子段和。

则有:

prel[rt]=max(prel[rt<<1],sum[rt<<1]+prel[rt<<11])prel[rt]=max(prel[rt<<1],sum[rt<<1]+prel[rt<<1|1])
prer[rt]=max(prer[rt<<11],sum[rt<<11]+prer[rt<<1])prer[rt]=max(prer[rt<<1|1],sum[rt<<1|1]+prer[rt<<1])
res[rt]=max(res[rt<<1],res[rt<<11],prer[rt<<1]+prel[r<<11])res[rt]=max(res[rt<<1],res[rt<<1|1],prer[rt<<1]+prel[r<<1|1])

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <cstdio>
#include <algorithm>
#define lson rt<<1
#define rson rt<<1|1
using std::max;
const int N=5e4+5;
int n,m,a[N];

struct Tree {
int prel,prer,res,sum;
}seg[N<<2];

void pushup(int rt) {
Tree L=seg[lson],R=seg[rson];
seg[rt].sum=L.sum+R.sum;
seg[rt].prel=max(L.prel,L.sum+R.prel);
seg[rt].prer=max(R.prer,R.sum+L.prer);
seg[rt].res=max(L.prer+R.prel,max(L.res,R.res));
}
void build(int rt,int l,int r) {
if(l==r) {
seg[rt].prel=seg[rt].prer=seg[rt].res=seg[rt].sum=a[l];
return;
}
int mid=(l+r)>>1;
build(lson,l,mid);
build(rson,mid+1,r);
pushup(rt);
}
void modify(int x,int rt,int l,int r,int val) {
if(l==r) {
seg[rt].prel=seg[rt].prer=seg[rt].res=seg[rt].sum=val;
return;
}
int mid=(l+r)>>1;
if(x<=mid) modify(x,lson,l,mid,val);
else modify(x,rson,mid+1,r,val);
pushup(rt);
}
Tree query(int x,int y,int rt,int l,int r) {
if(x<=l&&r<=y) return seg[rt];
int mid=(l+r)>>1;
if(y<=mid) return query(x,y,lson,l,mid);
if(mid<x) return query(x,y,rson,mid+1,r);
Tree L=query(x,mid,lson,l,mid),R=query(mid+1,y,rson,mid+1,r),res;
res.sum=L.sum+R.sum;
res.prel=max(L.prel,L.sum+R.prel);
res.prer=max(R.prer,R.sum+L.prer);
res.res=max(L.prer+R.prel,max(L.res,R.res));
return res;
}
int main() {
scanf("%d",&n);
for(int i=1;i<=n;++i) scanf("%d",&a[i]);
build(1,1,n);
for(scanf("%d",&m);m--;) {
int opt,x,y;
scanf("%d%d%d",&opt,&x,&y);
if(opt) printf("%d\n",query(x,y,1,1,n).res);
else modify(x,1,1,n,y);
}
return 0;
}

动态DP求解

定义v[i]v[i]A[i]A[i]的值,f[i]f[i]为以第i个结尾的最大字段和的大小,g[i]g[i]表示前i个数的最大子段和。则dp方程为:

f[i]=max(v[i],f[i1]+v[i])f[i]=\max (v[i], f[i-1]+v[i])

g[i]=max(g[i1],f[i])g[i]=\max (g[i-1], f[i])

好巧不巧,这个式子里面只有+max,所以可以使用广义矩阵乘法对dp进行优化。

Ci,j=max{Ai,k+Bk,j}C_{i, j}=\max \left\{A_{i, k}+B_{k, j}\right\}

要求 CC 满足[f[i1]g[i1]0]C=[f[i]g[i]0]\left[\begin{array}{lll}f[i-1] & g[i-1] & 0\end{array}\right] * C=\left[\begin{array}{lll}f[i] & g[i] & 0\end{array}\right]

注意在max运算中,00 是幺元,-\infty是零元

求解得到:

C=[v[i]v[i]0v[i]v[i]0]C=\left[\begin{array}{ccc}v[i] & v[i] & -\infty \\ -\infty & 0 & -\infty \\ v[i] & v[i] & 0\end{array}\right]

然后定义初始的矩阵为:

[f[l1]g[l1]0]=[0]\left[\begin{array}{lll}f[l-1] & g[l-1] & 0\end{array}\right]=\left[\begin{array}{lll}-\infty & -\infty & 0\end{array}\right]

即在下标为l之前的元素都不能选

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+5;
const ll inf=1e18;
int n,q;
ll a[maxn];
ll l,r;
struct Mat{
ll mat[3][3];
Mat(){
for(int i=0;i<3;i++)
for(int j=0;j<3;j++)
mat[i][j]=-inf;
}
const Mat operator*(const Mat &b) const {
Mat c;
for(int i=0;i<3;i++)
for(int j=0;j<3;j++){
for(int k=0;k<3;k++){
c.mat[i][j]=max(c.mat[i][j],mat[i][k]+b.mat[k][j]);
}
}
return c;
}
}t[maxn<<2];
void push_up(int rt){
t[rt]=t[rt<<1]*t[rt<<1|1];
}
void build(int rt,int l,int r){
if(l==r){
t[rt].mat[0][0]=a[l];
t[rt].mat[0][1]=a[l];
t[rt].mat[1][1]=0;
t[rt].mat[2][0]=a[l];
t[rt].mat[2][1]=a[l];
t[rt].mat[2][2]=0;
return ;
}
int mid=l+r>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
push_up(rt);
}
void modify(int rt,int l,int r,int p){
if(l==r){
t[rt].mat[0][0]=a[l];
t[rt].mat[0][1]=a[l];
t[rt].mat[2][0]=a[l];
t[rt].mat[2][1]=a[l];
return ;
}
int mid=l+r>>1;
if(p<=mid)modify(rt<<1,l,mid,p);
else modify(rt<<1|1,mid+1,r,p);
push_up(rt);
}
Mat query(int rt,int l,int r,int xl,int xr){
if(l==xl&&r==xr){
return t[rt];
}
int mid=l+r>>1;
Mat x;
for(int i=0;i<3;i++)
x.mat[i][i]=0;
if(xl<=mid)x=x*query(rt<<1,l,mid,xl,min(xr,mid));
if(xr>mid)x=x*query(rt<<1|1,mid+1,r,max(mid+1,xl),xr);
return x;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
build(1,1,n);
scanf("%d",&q);
while(q--){
int opt;
scanf("%d%lld%lld",&opt,&l,&r);
if(!opt){
a[l]=r;
modify(1,1,n,l);
}
else{
Mat ans;
ans.mat[0][2]=0;
ans=ans*query(1,1,n,l,r);
printf("%lld\n",ans.mat[0][1]);
}
}
return 0;
}

「NOIP2018」保卫王国

一眼看去,“这不是动态DP的模板题吗”

结果做的时候还是踩了好多坑…

当然,还有更加常规但是写起来更麻烦的倍增做法,详见题解 P5024 【保卫王国】。这里仅介绍动态DP做法

先写最原始的递推式,即
f[u][0]=f[v][1]f[u][0]=\sum f[v][1]
f[u][1]=min(f[v][1],f[v][0])+a[u]f[u][1]=\sum \min (f[v][1], f[v][0])+a[u],

这里的广义矩阵乘法变成了Ci,j=min{Ai,k+Bk,j}C_{i, j}=\min \left\{A_{i, k}+B_{k, j}\right\}

由于对于min运算的幺元为 00,零元为 ++\infty ,故写出转移矩阵为

G=[+g[u][1]g[u][0]g[u][1]]G=\left[\begin{array}{ll}+\infty & g[u][1] \\ g[u][0] & g[u][1]\end{array}\right]

定义初始矩阵为S=[00]S=\left[\begin{array}{l}0 & 0\end{array}\right]

Ans=SG1G2GmuAns=S * \underbrace{G_{1} * G_{2} * \cdots \cdot G_{m}}_{重链上u节点到叶子所有矩阵} ,答案就是min(Ans[0][0],Ans[0][1])min(Ans[0][0],Ans[0][1])

注意,这里初始矩阵是放在左边的,所以在线段树的push_up操作中需要右半段的矩阵乘左半段的矩阵。然后就是每次询问之后,要复原DP值。这时候要依次复原,即优先复原第二个被更改的点,不然会出错

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+5;
const ll inf=1e10+5;
int n,m,cnt;
char s[maxn];
vector<int>e[maxn];
ll a[maxn],g[maxn][2],f[maxn][2];
int sz[maxn],son[maxn],top[maxn],fa[maxn],rk[maxn],dfn[maxn],ed[maxn];
struct mat{
ll g[2][2];
mat(){
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
g[i][j]=1e17;
}
const mat operator*(const mat &b) const{
mat c;
for(int i=0;i<2;i++)
for(int j=0;j<2;j++){
for(int k=0;k<2;k++){
c.g[i][j]=min(c.g[i][j],g[i][k]+b.g[k][j]);
}
}
return c;
}
}t[maxn<<2];
void push_up(int rt){
t[rt]=t[rt<<1|1]*t[rt<<1];
}
void init(int rt,int p){
t[rt].g[0][1]=t[rt].g[1][1]=g[p][1];
t[rt].g[1][0]=g[p][0];
}
void build(int rt,int l,int r){
if(l==r){
init(rt,rk[l]);
return;
}
int mid=l+r>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
push_up(rt);
}
void dfs1(int u,int p){
sz[u]=1;
for(auto v:e[u]){
if(v==p)continue;
dfs1(v,u);
fa[v]=u;
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
}
void dfs2(int u,int pre,int p){
dfn[u]=++cnt;rk[cnt]=u;
top[u]=pre;g[u][1]=a[u];
if(son[u])dfs2(son[u],pre,u);
else ed[pre]=u;
for(auto v:e[u]){
if(v==p||v==son[u])continue;
dfs2(v,v,u);
g[u][1]+=min(f[v][1],f[v][0]);
g[u][0]+=f[v][1];
}
f[u][1]=min(f[son[u]][0],f[son[u]][1])+g[u][1];
f[u][0]=g[u][0]+f[son[u]][1];
}
void modify(int rt,int l,int r,int p){
if(l==r){
init(rt,p);
return ;
}
int mid=l+r>>1;
if(dfn[p]<=mid)modify(rt<<1,l,mid,p);
else modify(rt<<1|1,mid+1,r,p);
push_up(rt);
}
mat query(int rt,int l,int r,int xl,int xr){
if(xl==l&&xr==r){
return t[rt];
}
int mid=l+r>>1;
mat x;x.g[0][0]=x.g[1][1]=0;
if(xr>mid)x=x*query(rt<<1|1,mid+1,r,max(mid+1,xl),xr);
if(xl<=mid)x=x*query(rt<<1,l,mid,xl,min(xr,mid));
return x;
}
void change(int p){
while(p){
modify(1,1,cnt,p);
p=top[p];
mat x;x.g[0][0]=x.g[0][1]=0;
x=x*query(1,1,cnt,dfn[p],dfn[ed[p]]);
if(fa[p]){
g[fa[p]][0]+=x.g[0][1]-f[p][1];
g[fa[p]][1]+=min(x.g[0][0],x.g[0][1])-min(f[p][0],f[p][1]);
}
f[p][0]=x.g[0][0],f[p][1]=x.g[0][1];
p=fa[p];
}
}
int main(){
freopen("defense.in","r",stdin);
freopen("defense.out","w",stdout);
scanf("%d%d%s",&n,&m,s);
for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
dfs1(1,0);dfs2(1,1,0);
build(1,1,cnt);
while(m--){
int p1,p2,w1,w2;
scanf("%d%d%d%d",&p1,&w1,&p2,&w2);
ll temp1[2],temp2[2];
for(int i=0;i<2;i++)temp1[i]=g[p1][i];
if(!w1)g[p1][1]+=inf;else g[p1][0]+=inf;
change(p1);
for(int i=0;i<2;i++)temp2[i]=g[p2][i];
if(!w2)g[p2][1]+=inf;else g[p2][0]+=inf;
change(p2);
mat x;x.g[0][0]=x.g[0][1]=0;
x=x*query(1,1,n,dfn[1],dfn[ed[1]]);
ll ans=min(x.g[0][0],x.g[0][1]);
if(ans<inf)printf("%lld\n",ans);
else printf("-1\n");
for(int i=0;i<2;i++)g[p2][i]=temp2[i];
change(p2);
for(int i=0;i<2;i++)g[p1][i]=temp1[i];
change(p1);
}
return 0;
}

「SDOI2017」切树游戏

要用FWT(快速沃尔什变换)优化才行…然而本菜鸡不会

先埋个坑,过两天学了在来补叭