没有看过前一篇的同学请看这里:蓄水池抽样浅说 (1)
跳跳跳!—— Algorithm X
前面介绍的方法都很好,但是要计算 $n-k$ 个随机数实在是有点浪费时间…… 有兴趣的同学可以把那个调用函数换成比如 reservoirSample(range(10**8), 10**5)
,就知道这东西还是要算上一会儿的了。
再来看看这个过程吧——把前 $k$ 个元素放进去之后,随机决定接下来的元素要不要放进去,但是每次决定时都需要产生一个随机数。要是我们能够确定应该跳过多少个元素,不就可以省掉很多生成随机数的工夫了吗?每一步过程就变成了
- 确定该跳过多少个元素 $S(k, n)$
- 跳过 $S(k, n)$ 个元素
- 从前 $k$ 个元素中随机产生一个要替换的元素,用下一个元素替换
那该跳过的元素数 $S(k,n)$ 要怎么定下来呢?这显然也是一个随机数。我们要计算它的累积分布函数 $\mathbb{P}(S(k,n) \leq s)$,即在第 $n$ 个元素出现时,跳过 $s$ 或 $s$ 以内个元素的概率。这似乎有点不那么直观,但是我们可以反过来计算 $1 – \mathbb{P}(S(k,n) > s)$,要跳过超过 $s$ 个元素,不就是从第 $n + 1$ 个到第 $n+s+1$ 个都不被选中嘛!我们知道第 $n$ 个元素被选中的概率是 $k / n$,所以
$$\begin{align}
F_S(s) & =\mathbb{P}(S(k,n) \leq s) \\ &= 1 – \mathbb{P}(S(k,n) > s) \\
&= 1 – \left(1 – \frac{k}{n+1}\right)\left(1 – \frac{k}{n+2}\right)\cdots\left(1 – \frac{k}{n+s+1}\right) \\
&= 1 – \frac{(n+1-k)(n+2-k)\cdots(n+s-k+1)}{(n+1)(n+2)\cdots(n+s+1)}
\end{align}
$$
好啦,概率分布知道了,怎么生成这个概率分布呢?这里又要用到一个生成任意概率分布的技巧,就是所谓的反变换法。令 $U$ 是均匀分布在 $[0,1]$ 上的随机变量,如果 $X = F^{-1}(U)$,则
$$F_X(a) = \mathbb{P}(X \leq a) = \mathbb{P}(F^{-1}(U) \leq a) = \mathbb{P}(U \leq F(a)) = F(a)$$
那么在这里,我们只需要生成一个 $[0,1]$ 的随机数 $u$,然后找出让 $F_S(s) \leq u$ 的最小 $s$ 值就好了。这就是所谓的 Algorithm X:
import random import sys import time def getS(n, k): u = random.random() S = 0 n += 1 quot = (n - k) / n while (quot > u): S += 1 n += 1 quot *= (n - k) / n return S def reservoirSampleX(stream, sample_size): result = [] s = 0 for index, line in enumerate(stream): if index < sample_size: result.append(line) else: if not s: result[int(random.random() * sample_size)] = line s = getS(index + 1, sample_size) else: s -= 1 return result
在 pypy 下测试,从 $10^8$ 个元素生成 $10^5$ 个元素的样本只需要 1.6 秒,而 Algorithm R 则需要 5 秒。但是,原先的 R 算法对样本大小并不怎么敏感,因为总归要算那么多次,而这个 X 算法所花的时间则会随着样本数增加而有所增加。
上面这个算法来自 Vitter 的一篇著名论文[1],这篇文章针对 $n$ 比较大的时候进一步改进了计算 $s$ 的过程,利用两个函数对 $F(s)$ 进行夹逼从而避免了对 $s$ 的搜索,由此得到了性能进一步提高的 Z 算法。PostgreSQL 就实现了这个 Z 算法来进行数据抽样,有兴趣的同学可以看看源代码。
前面说的是性能,那如果我们需要不均匀的抽样,有没有办法呢?那这个就留到下次再说啦。
- Vitter, Jeffrey S. “Random sampling with a reservoir.” ACM Transactions on Mathematical Software (TOMS) 11.1 (1985): 37-57.
更多算法内容请见《算法拾珠》。