GYM101899I-Imperial roads(次小生成树+树链剖分+线段树)

题目链接:GYM101899I-Imperial roads

题意:

给定n个点,m条边,然后有q个询问,每次询问一条边,问这条边所在的最小生成树是多少。

题解:

我们可以先求出MST,然后对于每次询问,我们把询问的这条边加入树上,此时树上形成了一个环,为了保证尽可能的小,我们需要移除环上最大的边权(除了新加入的边)(次小生成树定义)。

因为询问时独立的,所以树的结构并不会改变。所以我们可以在求出MST之后,对MST进行树链剖分,建线段树,维护区间最值。

对于每次询问,若询问的边就在MST上,答案就为MST,若不在MST上,答案为MST-两点间最大边权+两点间的边权。

参考代码:

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1e5 + 5;
const int MAXM = 2e5 + 5;
const int inf = 0x3f3f3f3f;
int n, m, sum;
struct NODE
{
    int u, v, w;
    NODE(){};
    NODE(int u, int v, int w) { this->u = u, this->v = v, this->w = w; }
} E[MAXM];
struct EDGE
{
    int v, w, nxt;
    EDGE() {}
    EDGE(int v, int w, int nxt) { this->v = v, this->w = w, this->nxt = nxt; }

} edge[MAXM << 1];
int ecnt, fir[MAXN];
void addedge(int u, int v, int w)
{
    edge[ecnt] = EDGE(v, w, fir[u]);
    fir[u] = ecnt++;
    swap(u, v);
    edge[ecnt] = EDGE(v, w, fir[u]);
    fir[u] = ecnt++;
}
int Fa[MAXN];
int getf(int x)
{
    if (x == Fa[x])
        return x;
    return Fa[x] = getf(Fa[x]);
}
bool merged(int x, int y)
{
    int tx = getf(x), ty = getf(y);
    if (tx == ty)
        return 0;
    Fa[ty] = tx;
    return 1;
}
int dfscnt, A[MAXN];
int fa[MAXN], dep[MAXN], sz[MAXN], son[MAXN], rk[MAXN], id[MAXN], top[MAXN];
void dfs1(int u, int FA, int depth, int dw)
{
    fa[u] = FA, dep[u] = depth, sz[u] = 1;
    A[u] = dw;
    int maxson = -1;
    for (int i = fir[u]; i != -1; i = edge[i].nxt)
    {
        int v = edge[i].v, w = edge[i].w;
        if (v == FA)
            continue;
        dfs1(v, u, depth + 1, w);
        sz[u] += sz[v];
        if (sz[v] > maxson)
            son[u] = v, maxson = sz[v];
    }
}
void dfs2(int u, int t)
{
    top[u] = t;
    id[u] = ++dfscnt;
    rk[dfscnt] = u;
    if (son[u] == -1)
        return;
    dfs2(son[u], t);
    for (int i = fir[u]; i != -1; i = edge[i].nxt)
    {
        int v = edge[i].v;
        if (v != fa[u] && v != son[u])
            dfs2(v, v);
    }
}
int maxn[MAXN << 2], lazy[MAXN << 2];
void segpush(int rt)
{
    maxn[rt] = max(maxn[rt << 1], maxn[rt << 1 | 1]);
}
void segbuild(int l, int r, int rt)
{
    maxn[rt] = 0;
    lazy[rt] = 0;
    if (l == r)
    {
        maxn[rt] = A[rk[l]];
        return;
    }
    int mid = (l + r) >> 1;
    segbuild(l, mid, rt << 1);
    segbuild(mid + 1, r, rt << 1 | 1);
    segpush(rt);
}
int segquery(int L, int R, int l, int r, int rt)
{
    if (L <= l && r <= R)
        return maxn[rt];
    int res = 0;
    int mid = (l + r) >> 1;
    if (L <= mid)
        res = max(res, segquery(L, R, l, mid, rt << 1));
    if (R > mid)
        res = max(res, segquery(L, R, mid + 1, r, rt << 1 | 1));
    return res;
}
int query(int x, int y)
{
    int res = 0;
    while (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);
        res = max(res, segquery(id[top[x]], id[x], 1, n, 1));
        x = fa[top[x]];
    }
    if (id[x] > id[y])
        swap(x, y);
    res = max(res, segquery(id[son[x]], id[y], 1, n, 1));
    //注意 因为是将边权转化为的点权 LCA(u,v)该点的点权不应该被计算
    return res;
}
map<pair<int, int>, int> mp, mmp;
void init()
{
    mp.clear(), mmp.clear();
    ecnt = sum = dfscnt = 0;
    for (int i = 1; i <= n; i++)
    {
        fir[i] = son[i] = -1;
        Fa[i] = i;
    }
}
int main()
{
    scanf("%d%d", &n, &m);
    init();
    for (int i = 0; i < m; i++)
    {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        if (u > v)
            swap(u, v);
        mp[{u, v}] = w;
        E[i] = NODE(u, v, w);
    }
    sort(E, E + m, [](NODE x, NODE y) { return x.w < y.w; });
    int tot = 0;
    for (int i = 0; i < m; i++)
    {
        int u = E[i].u, v = E[i].v, w = E[i].w;
        if (merged(u, v))
        {
            mmp[{u, v}] = 1;
            sum += w;
            tot++;
            addedge(u, v, w);
        }
        if (tot == n - 1)
            break;
    }
    // printf("%d\n", sum);
    dfs1(1, -1, 1, -inf);
    dfs2(1, 1);
    // for (int i = 1; i <= n; i++)
    //     printf("%d ", A[i]);
    puts("");
    segbuild(1, n, 1);
    int q;
    scanf("%d", &q);
    while (q--)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        if (u > v)
            swap(u, v);
        if (mmp.count({u, v}))
        {
            printf("%d\n", sum);
        }
        else
        {
            printf("%d\n", sum + mp[{u, v}] - query(u, v));
        }
    }
    return 0;
}

发表留言

人生在世,错别字在所难免,无需纠正。