题目链接:洛谷P3384-树链剖分(模板)
题意:
已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
- 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
- 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
- 3 x z 表示将以x为根节点的子树内所有节点值都加上z
- 4 x 表示求以x为根节点的子树内所有节点值之和
题解:
树链剖分模板题。
两次dfs预处理之后,一棵树就被剖成了若干条链,把这些链连接在一起就成了区间问题,方便用数据结构进行维护。
参考代码:
#include <bits/stdc++.h>
#define IL inline
using namespace std;
typedef long long ll;
const int MAXN = 1e5 + 5;
int n, m, r;
ll mod;
ll a[MAXN];
struct EDGE
{
int v, nxt;
} edge[MAXN << 1];
int fir[MAXN], ecnt = 0;
IL void addedge(int u, int v)
{
edge[ecnt].v = v;
edge[ecnt].nxt = fir[u];
fir[u] = ecnt++;
}
int dfscnt; //dfs序
int fa[MAXN], dep[MAXN], sz[MAXN], son[MAXN], rk[MAXN], id[MAXN], top[MAXN];
/*
fa[i]:记录i点的父亲
dep[i]:记录i点的深度
sz[i]:记录i点子树最大值
son[i]:记录i点的重儿子
rk[i]:新编号与原编号之间的映射关系
id[i]:记录i点新编号
top[i]:记录i点这条链的顶端结点
*/
ll sum[MAXN << 2], lazy[MAXN << 2];
/* 树剖预处理 */
IL void dfs1(int u, int Fa, int depth)
{
fa[u] = Fa, dep[u] = depth, sz[u] = 1;
int maxson = -1;
for (int i = fir[u]; i != -1; i = edge[i].nxt)
{
int v = edge[i].v;
if (v == Fa)
continue;
dfs1(v, u, depth + 1);
sz[u] += sz[v];
if (sz[v] > maxson)
son[u] = v, maxson = sz[v];
}
}
IL void dfs2(int u, int t)
{
top[u] = t;
id[u] = ++dfscnt;
rk[dfscnt] = u;
if (son[u] == -1)
return;
dfs2(son[u], t);
for (int i = fir[u]; i != -1; i = edge[i].nxt)
{
int v = edge[i].v;
if (v != fa[u] && v != son[u])
dfs2(v, v);
}
}
/* 线段树部分 */
IL void seg_pushup(int rt)
{
sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % mod;
}
IL void seg_pushdown(int l, int r, int rt)
{
if (lazy[rt])
{
lazy[rt << 1] = (lazy[rt << 1] + lazy[rt]) % mod;
lazy[rt << 1 | 1] = (lazy[rt << 1 | 1] + lazy[rt]) % mod;
int mid = (l + r) >> 1;
sum[rt << 1] = (sum[rt << 1] + (mid - l + 1) * lazy[rt] % mod) % mod;
sum[rt << 1 | 1] = (sum[rt << 1 | 1] + (r - mid) * lazy[rt] % mod) % mod;
lazy[rt] = 0;
}
}
IL void seg_build(int l, int r, int rt)
{
sum[rt] = 0, lazy[rt] = 0;
if (l == r)
{
sum[rt] = a[rk[l]];
return;
}
int mid = (l + r) >> 1;
seg_build(l, mid, rt << 1);
seg_build(mid + 1, r, rt << 1 | 1);
seg_pushup(rt);
}
IL void seg_update(int L, int R, ll val, int l, int r, int rt)
{
if (L <= l && r <= R)
{
lazy[rt] = (lazy[rt] + val) % mod;
sum[rt] = (sum[rt] + (r - l + 1) * val % mod) % mod;
return;
}
seg_pushdown(l, r, rt);
int mid = (l + r) >> 1;
if (L <= mid)
seg_update(L, R, val, l, mid, rt << 1);
if (R > mid)
seg_update(L, R, val, mid + 1, r, rt << 1 | 1);
seg_pushup(rt);
}
IL ll seg_query(int L, int R, int l, int r, int rt)
{
if (L <= l && r <= R)
return sum[rt] % mod;
seg_pushdown(l, r, rt);
ll res = 0;
int mid = (l + r) >> 1;
if (L <= mid)
res = (res + seg_query(L, R, l, mid, rt << 1)) % mod;
if (R > mid)
res = (res + seg_query(L, R, mid + 1, r, rt << 1 | 1)) % mod;
return res;
}
/* 树链剖分 */
IL ll query(int x, int y)
{
ll res = 0;
while (top[x] != top[y]) //当两个点不在同一条链上
{
if (dep[top[x]] < dep[top[y]])
swap(x, y); //先爬深度较深的那个点
res = (res + seg_query(id[top[x]], id[x], 1, n, 1)) % mod; //把当前节点到当前链的顶端节点这段区间计入答案
x = fa[top[x]];
}
//当两个点在同一条链上
if (id[x] > id[y])
swap(x, y);
res = (res + seg_query(id[x], id[y], 1, n, 1)) % mod;
return res;
}
IL void update(int x, int y, ll val)
{
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]])
swap(x, y);
seg_update(id[top[x]], id[x], val, 1, n, 1);
x = fa[top[x]];
}
if (id[x] > id[y])
swap(x, y);
seg_update(id[x], id[y], val, 1, n, 1);
return;
}
int main()
{
scanf("%d%d%d%lld", &n, &m, &r, &mod);
for (int i = 1; i <= n; i++)
scanf("%lld", &a[i]), a[i] %= mod;
ecnt = 0;
memset(fir, -1, sizeof(fir));
for (int i = 1; i <= n - 1; i++)
{
int u, v;
scanf("%d%d", &u, &v);
addedge(u, v), addedge(v, u);
}
dfscnt = 0;
memset(son, -1, sizeof(son));
dfs1(r, -1, 1);
dfs2(r, r);
seg_build(1, n, 1);
while (m--)
{
int op, x, y;
ll z;
scanf("%d", &op);
if (op == 1) //x到y节点的路径上的点权值+z
{
scanf("%d%d%lld", &x, &y, &z);
update(x, y, z);
}
else if (op == 2) //查询x到y节点的路径上的点权和
{
scanf("%d%d", &x, &y);
printf("%lld\n", query(x, y) % mod);
}
else if (op == 3) //以x为根的子树内的所有节点权值+z
{
scanf("%d%lld", &x, &z);
seg_update(id[x], id[x] + sz[x] - 1, z, 1, n, 1);
}
else if (op == 4) //查询以x为根的子树内所有节点点权和
{
scanf("%d", &x);
printf("%lld\n", seg_query(id[x], id[x] + sz[x] - 1, 1, n, 1));
}
}
return 0;
}