蓄水池抽样

蓄水池抽样

问题

1、给定一个数据流,数据流长度N很大,且N直到处理完所有数据之前都不可知,请问如何在只遍历一遍数据(O(N))的情况下,能够随机选取出m个不重复的数据

2、在不知道文件行数的情况下,如何在只遍历一遍文件的情况下,随机选取出m行

分析

看到此种问题,我们的第一想法是,把数据流中的数据保存起来,然后通过把数据流中的数据存储起来,然后进行随机获取,我们以leetcode中的某个题目为例,代码如下:

class Solution {
public:
    /** @param head The linked list's head.
        Note that the head is guaranteed to be not null, so it contains at least one node. */
    Solution(ListNode* head) {
        auto h = head;
        while (h) {
            v.push_back(h->val);
            h = h->next;
        }
    }
    
    int getRandom() {
        return v[random() % v.size()];
    }
    private:
      std::vector<int> v;
};

此解法基本上满足了随机获取链表节点值的要求,但是如果再加一个条件,即链表很长,乃至内存中存储不了,此时应该怎么做呢?

解决

这就涉及到蓄水池算法。

蓄水池抽样

蓄水池抽样是一系列随机算法,用于在不替换的情况下,从一个未知大小n的总体中选择一个简单的随机样本(k个项目),只需对这些项目进行一次遍历。总体n的大小对于算法来说是未知的,并且通常对于所有n个项来说都太大而无法放入主内存。随着时间的推移,总体将显示给算法,并且算法不能回顾以前的项目。在任何时候,算法的当前状态必须允许提取一个简单的随机样本,而不替换迄今为止看到的部分总体的大小k。

算法思路大致如下:

  • 如果接收的数据量小于m,则依次放入蓄水池。

  • 当接收到第i个数据时,i >= m,在[0, i]范围内取以随机数d,若d的落在[0, m-1]范围内,则用接收到的第i个数据替换蓄水池中的第d个数据。

  • 重复上一个步骤

证明

为了证明这个解是完全有效的,我们必须证明0<=i<n的任何项流[i]在最终储层[]中的概率是k/n。让我们把证据分为两种情况,因为前k项的处理方式不同。

情况1:对于最后n-k个流项,即,对于流[i],其中k<=i<n

对于每一个这样的流项流[i],我们从0到i选取一个随机索引,如果选取的索引是前k个索引之一,我们将选取索引处的元素替换为流[i]

为了简化证明,让我们先考虑最后一个项目。最后一个项目在最终库中的概率=为最后一个项目选取前k个索引之一的概率=k/n(从大小为n的列表中选取k个项目之一的概率)

现在让我们考虑第二个最后一个项目。最后第二项在最终储层中的概率[]=[在流[n-2]的迭代中选取前k个索引之一的概率]X[在流[n-1]的迭代中选取的索引与在流[n-2]中选取的索引不同的概率]=[k/(n-1)]*[(n-1)/n]=k/n。

类似地,我们可以从流[n-1 ]到流[k]中考虑所有流项的其他项,并推广证明。

情况2:对于前k个流项,即,对于流[i],其中0<=i<k

第一k个项目最初被复制到库[],并且可以在稍后的流[k]到流[n]的迭代中被移除。

来自流[0..k-1]的项目在最终数组中的概率=当项目流[k]、流[k+1]、….时项目未被拾取的概率…。考虑流[n-1]=[k/(k+1)]x[(k+1)/(k+2)]x[(k+2)/(k+3)]x…x[(n-1)/n]=k/n

实现

仍然以leetcode中此题为例,随机获取一个链表中的一个节点值,注意,此处k = 1:

class Solution {
public:
    Solution(ListNode* head) {
        this->head = head;
    }
    
    int getRandom() {
        ListNode* phead = this->head;
        int val = phead->val;
        int count = 1;
        while (phead){
            if (rand() % count++ == 0)
                val = phead->val;
            phead = phead->next;
        }
        return val;
    }
    ListNode* head;
};

那么 如果要获取链表中的k(k != 1)个值,时候,又该怎么实现呢? 此时,需要遍历链表的前k个节点,将前k个节点的值存储在数组中,然后从第k + 1个节点开始遍历链表,从中获取值,代码如下:

class Solution {
public:
    Solution(ListNode* head) {
        this->head = head;
    }
    
    std::vector<int> getRandom(int count) {
        ListNode* phead = this->head;
        int i = 0;
        int k = count;
        std::vector<int> res;
        while (phead && k--) {
          res.push_back(phead->val);
        }
  
        k = count;
        while (phead) {
          int rd = rand() % k++;
          if (rd < k) {
            res[rd] = phead->val;         
          }
  
          phead = phead->next;
        }
        return res;
    }
    ListNode* head;
};

更多文章,请移步

http://studyinfo.top 

上一篇:链表算法操作


下一篇:单链表的翻转