Codechef TSUM2 Sum on Tree 点分治、李超线段树

传送门


点分治模板题都不会迟早要完

发现这道题需要统计所有路径的信息,考虑点分治统计路径信息。

点分治之后,因为路径是有向的,所以对于每一条路径都有向上和向下的两种。那么如果一条向上的路径,点数为\(s_1\),单独考虑这条路径的权值和为\(v_1\),和一条向下的路径,点权和为\(s_2\),单独考虑这条路径的权值和为\(v_2\),这两条路径进行拼接(分治中心算在向上路径中,这样\(s_1 > 0\)),那么拼接起来的路径的权值和就是\(s_1s_2 + v_1 + v_2\)。如果我们枚举到了一条向上的路径,对于每一条向下路径能够和这条向上路径拼接产生的贡献是一个一次函数的形式,同时横坐标范围在\(1\)到\(50000\)之间,所以可以使用李超线段树维护最值。

注意到来自同一棵子树的向上和向下路径不能拼接,所以对于每一棵子树,先把子树内所有向上路径统计完成,再把向下的路径放进李超树,从前往后做一遍再从后往前做一遍即可。

注意一下分治中心单独作为一条向上路径/向下路径为空的情况。

#include<bits/stdc++.h>
//this code is written by Itst
using namespace std;

int read(){
    int a = 0; char c = getchar(); bool f = 0;
    while(!isdigit(c)){f = c == '-'; c = getchar();}
    while(isdigit(c)){a = a * 10 + c - 48; c = getchar();}
    return f ? -a : a;
}

const int _ = 100003;
#define int long long
struct line{
    int k , b;
    line(int _k = -1e9 , int _b = -1e18) : k(_k) , b(_b){}
};
long double sect(line a , line b){return 1.0 * (b.b - a.b) / (a.k - b.k);}

namespace Tree{
    line now[_ << 2];

#define mid ((l + r) >> 1)
#define lch (x << 1)
#define rch (x << 1 | 1)

    void init(int x , int l , int r){
        if(now[x].k != -1e9 || now[x].b != -1e18){
            now[x] = line();
            if(l != r){init(lch , l , mid); init(rch , mid + 1 , r);}
        }
    }
    
    void ins(int x , int l , int r , line p){
        line q = now[x];
        if(q.k == p.k){now[x] = q.b > p.b ? q : p; return;}
        if(p.k > q.k) swap(p , q);
        long double t = sect(p , q);
        if(t <= l){now[x] = q; return;}
        if(t >= r){now[x] = p; return;}
        if(t <= mid){
            now[x] = q;
            if(l != r) ins(lch , l , mid , p);
        }
        else{
            now[x] = p;
            if(l != r) ins(rch , mid + 1 , r , q);
        }
    }

    int calc(line A , int x){return A.k * x + A.b;}
    
    int qry(int x , int l , int r , int tar){
        if(l == r) return calc(now[x] , tar);
        return max(calc(now[x] , tar) , mid >= tar ? qry(lch , l , mid , tar) : qry(rch , mid + 1 , r , tar));
    }
}

struct Edge{
    int end , upEd;
}Ed[_ << 1];
int head[_] , val[_] , cntEd , N , mnsz , mnid , nowsz , ans;
bool vis[_];

void addEd(int a , int b){
    Ed[++cntEd] = (Edge){b , head[a]};
    head[a] = cntEd;
}

void getsz(int x){
    vis[x] = 1; ++nowsz;
    for(int i = head[x] ; i ; i = Ed[i].upEd)
        if(!vis[Ed[i].end]) getsz(Ed[i].end);
    vis[x] = 0;
}

int getrt(int x){
    int sz = 1 , mx = 1; vis[x] = 1;
    for(int i = head[x] ; i ; i = Ed[i].upEd)
        if(!vis[Ed[i].end]){
            int t = getrt(Ed[i].end);
            mx = max(mx , t); sz += t;
        }
    mx = max(mx , nowsz - sz);
    if(mx < mnsz){mnsz = mx; mnid = x;}
    vis[x] = 0; return sz;
}

void dfs1(int x , int sum , int sz , int v){
    ans = max(ans , Tree::qry(1 , 1 , N , sz) + sum);
    vis[x] = 1;
    for(int i = head[x] ; i ; i = Ed[i].upEd)
        if(!vis[Ed[i].end]) dfs1(Ed[i].end , sum + val[Ed[i].end] + v , sz + 1 , val[Ed[i].end] + v);
    vis[x] = 0;
}

void dfs2(int x , int sum , int sz , int v){
    Tree::ins(1 , 1 , N , line(v , sum));
    vis[x] = 1;
    for(int i = head[x] ; i ; i = Ed[i].upEd)
        if(!vis[Ed[i].end]) dfs2(Ed[i].end , sum + (sz + 1) * val[Ed[i].end] , sz + 1 , val[Ed[i].end] + v);
    vis[x] = 0;
}

void solve(int x){
    mnsz = 1e9; nowsz = 0;
    getsz(x); getrt(x); vis[x = mnid] = 1;
    Tree::init(1 , 1 , N); Tree::ins(1 , 1 , N , line(0 , 0));
    vector < int > nxt;
    for(int i = head[x] ; i ; i = Ed[i].upEd)
        if(!vis[Ed[i].end]){
            nxt.push_back(Ed[i].end);
            dfs1(Ed[i].end , 2 * val[x] + val[Ed[i].end] , 2 , val[x] + val[Ed[i].end]);
            dfs2(Ed[i].end , val[Ed[i].end] , 1 , val[Ed[i].end]);
        }
    ans = max(ans , Tree::qry(1 , 1 , N , 1) + val[x]); Tree::init(1 , 1 , N);
    reverse(nxt.begin() , nxt.end());
    for(auto t : nxt){
        dfs1(t , 2 * val[x] + val[t] , 2 , val[x] + val[t]);
        dfs2(t , val[t] , 1 , val[t]);
    }
    for(auto t : nxt) solve(t);
}

signed main(){
    for(int T = read() ; T ; --T){
        N = read(); cntEd = 0; ans = -1e18;
        memset(vis , 0 , sizeof(bool) * (N + 1));
        memset(head , 0 , sizeof(int) * (N + 1));
        for(int i = 1 ; i <= N ; ++i) val[i] = read();
        for(int i = 1 ; i < N ; ++i){
            int a = read() , b = read(); addEd(a , b); addEd(b , a);
        }
        solve(1); printf("%lld\n" , ans);
    }
    return 0;
}
上一篇:「BZOJ2654」tree


下一篇:跳石头