BZOJ 3836 Codeforces 280D k-Maximum Subsequence Sum (模拟费用流、线段树)

题目链接

(BZOJ) https://www.lydsy.com/JudgeOnline/problem.php?id=3836
(Codeforces) http://codeforces.com/contest/280/problem/D

题解

似乎是最广为人知的模拟费用流题目。
线段树维护DP可以做,但是合并的复杂度是\(O(k^2)\), 会TLE.
考虑做\(k\)次费用流,很容易建出一个图,中间的边容量都是1,求的是最大费用。
做费用流的过程,我们每次找一条最长路,然后将其增广,增广之后这条路的边权会取负(因为容量都是\(1\)所以要么正要么负,正反向边不同时出现)。
所以现在要做的就是每次找出和最大的一段区间然后取负,直到和全都小于\(0\)为止。
线段树维护最大、最小子段和及其位置即可。

仔细想想,每次给一段区间取负就相当于给这段区间内的元素选或者不选的状态进行反转(inverse).
也就相当于费用流的退流(反悔)。

时间复杂度\(O(kn\log n)\).

代码

BZOJ权限号到期了,所以没在上面交过。

#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<utility>
#include<vector>
#define pii pair<int,int>
#define mkpr make_pair
using namespace std;

void read(int &x)
{
    int f=1;x=0;char s=getchar();
    while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
    while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();}
    x*=f;
}

const int N = 1e5;
int a[N+3];
vector<pii > pool;
int n,q;

struct Data
{
    int sum,mx,mxl,mxr,pl,pr,pml,pmr;
    Data() {sum = mx = mxl = mxr = pl = pr = pml = pmr = 0;}
};
Data operator +(Data x,Data y)
{
    Data ret;
    ret.sum = x.sum+y.sum;
    if(x.mxl>x.sum+y.mxl)
    {
        ret.mxl = x.mxl;
        ret.pl = x.pl;
    }
    else
    {
        ret.mxl = x.sum+y.mxl;
        ret.pl = y.pl;
    }
    if(y.mxr>x.mxr+y.sum)
    {
        ret.mxr = y.mxr;
        ret.pr = y.pr;
    }
    else
    {
        ret.mxr = x.mxr+y.sum;
        ret.pr = x.pr;
    }
    if(x.mx>y.mx)
    {
        ret.mx = x.mx;
        ret.pml = x.pml; ret.pmr = x.pmr;
    }
    else
    {
        ret.mx = y.mx;
        ret.pml = y.pml; ret.pmr = y.pmr;
    }
    if(x.mxr+y.mxl>ret.mx)
    {
        ret.mx = x.mxr+y.mxl;
        ret.pml = x.pr; ret.pmr = y.pl;
    }
    return ret;
}

struct SegmentTree
{
    struct SgTNode
    {
        Data ans0,ans1; int tag;
    } sgt[(N<<2)+3];
    void pushup(int u)
    {
        sgt[u].ans0 = sgt[u<<1].ans0+sgt[u<<1|1].ans0;
        sgt[u].ans1 = sgt[u<<1].ans1+sgt[u<<1|1].ans1;
    }
    void pushdown(int u)
    {
        if(sgt[u].tag)
        {
            swap(sgt[u<<1].ans0,sgt[u<<1].ans1);
            sgt[u<<1].tag ^= 1;
            swap(sgt[u<<1|1].ans0,sgt[u<<1|1].ans1);
            sgt[u<<1|1].tag ^= 1;
            sgt[u].tag = 0;
        }
    }
    void build(int u,int le,int ri)
    {
        if(le==ri)
        {
            sgt[u].ans0.sum = sgt[u].ans0.mxl = sgt[u].ans0.mxr = sgt[u].ans0.mx = a[le];
            sgt[u].ans0.pl = sgt[u].ans0.pr = sgt[u].ans0.pml = sgt[u].ans0.pmr = le;
            sgt[u].ans1.sum = sgt[u].ans1.mxl = sgt[u].ans1.mxr = sgt[u].ans1.mx = -a[le];
            sgt[u].ans1.pl = sgt[u].ans1.pr = sgt[u].ans1.pml = sgt[u].ans1.pmr = le;
            return;
        }
        int mid = (le+ri)>>1;
        build(u<<1,le,mid);
        build(u<<1|1,mid+1,ri);
        pushup(u);
    }
    void modify(int u,int le,int ri,int pos,int x)
    {
        if(le==pos && ri==pos)
        {
            sgt[u].ans0.sum = sgt[u].ans0.mxl = sgt[u].ans0.mxr = sgt[u].ans0.mx = x;
            sgt[u].ans0.pl = sgt[u].ans0.pr = sgt[u].ans0.pml = sgt[u].ans0.pmr = le;
            sgt[u].ans1.sum = sgt[u].ans1.mxl = sgt[u].ans1.mxr = sgt[u].ans1.mx = -x;
            sgt[u].ans1.pl = sgt[u].ans1.pr = sgt[u].ans1.pml = sgt[u].ans1.pmr = le;
            return;
        }
        pushdown(u);
        int mid = (le+ri)>>1;
        if(pos<=mid) modify(u<<1,le,mid,pos,x);
        if(pos>mid) modify(u<<1|1,mid+1,ri,pos,x);
        pushup(u);
    }
    void setoppo(int u,int le,int ri,int lb,int rb)
    {
        if(le>=lb && ri<=rb)
        {
            sgt[u].tag ^= 1;
            swap(sgt[u].ans0,sgt[u].ans1);
            return;
        }
        pushdown(u);
        int mid = (le+ri)>>1;
        if(lb<=mid) setoppo(u<<1,le,mid,lb,rb);
        if(rb>mid) setoppo(u<<1|1,mid+1,ri,lb,rb);
        pushup(u);
    }
    Data query(int u,int le,int ri,int lb,int rb)
    {
        if(le>=lb && ri<=rb) {return sgt[u].ans0;}
        pushdown(u);
        int mid = (le+ri)>>1; Data ret;
        if(rb<=mid) {ret = query(u<<1,le,mid,lb,rb);}
        else if(lb>mid) {ret = query(u<<1|1,mid+1,ri,lb,rb);}
        else {ret = query(u<<1,le,mid,lb,rb)+query(u<<1|1,mid+1,ri,lb,rb);}
        pushup(u);
        return ret;
    }
} sgt;

int main()
{
    scanf("%d",&n);
    for(int i=1; i<=n; i++) scanf("%d",&a[i]);
    sgt.build(1,1,n);
    scanf("%d",&q);
    while(q--)
    {
        int opt; scanf("%d",&opt);
        if(opt==0)
        {
            int x,y; scanf("%d%d",&x,&y);
            a[x] = y; sgt.modify(1,1,n,x,y);
        }
        else if(opt==1)
        {
            int x,y,z; scanf("%d%d%d",&x,&y,&z);
            int ans = 0;
            for(int i=1; i<=z; i++)
            {
                Data cur = sgt.query(1,1,n,x,y);
//              printf("delta=%d [%d,%d]\n",cur.mx,cur.pml,cur.pmr);
                if(cur.mx<=0) break;
                ans += cur.mx;
                sgt.setoppo(1,1,n,cur.pml,cur.pmr);
                pool.push_back(mkpr(cur.pml,cur.pmr));
            }
            printf("%d\n",ans);
            for(int i=0; i<pool.size(); i++) {sgt.setoppo(1,1,n,pool[i].first,pool[i].second);}
            pool.clear();
        }
    }
    return 0;
}
上一篇:SQL in、not in、exists和not exists的区别:


下一篇:Sql中当插入的字符多于8000个字符只能插入一部分,数据丢失的处理