字符向量的随机样本,不包含相互前缀的元素

44
考虑一个字符向量pool,其元素是最多具有max_len位的(前导零填充的)二进制数字。
max_len <- 4
pool <- unlist(lapply(seq_len(max_len), function(x) 
  do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))

pool
##  [1] "0"    "1"    "00"   "10"   "01"   "11"   "000"  "100"  "010"  "110" 
## [11] "001"  "101"  "011"  "111"  "0000" "1000" "0100" "1100" "0010" "1010"
## [21] "0110" "1110" "0001" "1001" "0101" "1101" "0011" "1011" "0111" "1111"

我想对这些元素进行抽样,数量为n,但有一个限制条件,不能选择任何已经被选中元素的前缀(例如,如果我们选择了1101,则不能再选择以111110开头的元素,而如果我们选择1,则不能再选择以1开头的元素,如1011100等)。

以下是我的while尝试,但当n很大时(或接近2^max_len)速度会很慢。

set.seed(1)
n <- 10
chosen <- sample(pool, n)
while(any(rowSums(outer(paste0('^', chosen), chosen, Vectorize(grepl))) > 1)) {
  prefixes <- rowSums(outer(paste0('^', chosen), chosen, Vectorize(grepl))) > 1
  pool <- pool[rowSums(Vectorize(grepl, 'pattern')(
    paste0('^', chosen[!prefixes]), pool)) == 0]
  chosen <- c(chosen[!prefixes], sample(pool, sum(prefixes)))
}

chosen
## [1] "0100" "0101" "0001" "0011" "1000" "111"  "0000" "0110" "1100" "0111"

这可以稍微改进一下,通过最初从pool中移除那些包含将意味着在pool中剩余的元素不足以取得大小为n的总样本的元素来实现。例如,当max_len = 4n > 9时,我们可以立即从pool中删除01,因为通过包括其中任何一个,最大样本将是9(要么是0和以1开头的八个4字符元素,要么是1和以0开头的八个4字符元素)。

基于这种逻辑,我们可以在取初始样本之前像这样省略pool中的元素:

pool <- pool[
  nchar(pool) > tail(which(n > (2^max_len - rev(2^(0:max_len))[-1] + 1)), 1)]

有人能想到更好的方法吗?我觉得我可能忽略了更简单的方法。


编辑

为了澄清我的意图,我将把池子描绘成一组分支,其中交叉点和末端是节点(pool的元素)。假设在下面的图中黄色节点(即010)被绘制出来。现在,整个红色“分支”,由节点0、01和010组成,将从池子中移除。这就是我所说的禁止采样已经在我们的样本中“前缀”节点(以及那些已经被我们的样本中的节点“前缀”的节点)的含义。

enter image description here

如果采样的节点位于分支的一半位置,例如下图中的01,则所有红色节点(0、01、010和011)都被禁止,因为0是01的前缀,而01既是010的前缀,也是011的前缀。

enter image description here

I don't mean to sample either 1 or 0 at each junction (i.e. walking along the branches flipping coins at forks) - it's fine to have both in the sample, as long as: (1) parents (or grand-parents, etc.) or children (grandchildren, etc.) of the node aren't already sampled; and (2) upon sampling the node there will be sufficient nodes remaining to achieve the desired sample of size n.
In the second figure above, if 010 was the first pick, all nodes at black nodes are still (currently) valid, assuming n <= 4. For example, if n==4 and we sampled node 1 next (and so our picks now included 01 and 1), we would subsequently disallow node 00 (due to rule 2 above) but could still pick 000 and 001, giving us our 4-element sample. If n==5, on the other hand, node 1 would be disallowed at this stage.

3
我对你具体的算法如何改进没有任何见解,但我们知道R的字符串操作速度非常慢。也许一个基于数字向量比较的解决方案会更快。比如,一个检查函数(a, b),将a分解为所有可能长度的向量(a'),然后在所有组合中进行a' == b的检查。 - Vlo
2
@jbaums,你的更新看起来非常像我在链接中指向的哈夫曼树(例如这里)。 - Henrik
1
@jbaums 我在这个主题上远非专家(否则我会建议一个答案;))。我只是发现你的问题和“无前缀编码/霍夫曼编码”的主题相似。在我的天真眼中,它似乎是一种解决你的问题的有益方法,尤其是启发式地(例如,如你在编辑中添加的树所建议的)。我不知道你所需取样的适当、更具体的术语。我不知道“那种方法”是否是_给定_一个池,而不是有一个算法来最小化结果的熵(比特/符号)。 - Henrik
1
我正在思考这个问题,但实际上找不到将“110”或“0110”作为池的用途,因为在二进制中它们是相同的,所以即使样本具有不同的字符长度,它们也具有相同的值。如果您正在处理真正的二进制池,则4个长度仅有16个条目(所有1、2和3个字符长的条目都只是其子集)。您的树形表示使我想到,如果您正在使用固定长度的条目,则位掩码比较可能会有所帮助。 - Tensibai
4
对于那些想知道我的赏金计划进展情况的人...那些已经发布的答案都花费了相当多的时间和精力来制作高质量的答案。我从中学到了很多,并计划授予一些奖励金。授予奖励金有24小时的延迟,所以这需要一些时间。话虽如此,如果有人想尝试的话,新的答案也是受欢迎的 ;) - jbaums
显示剩余18条评论
8个回答

19

介绍

这是我们在另一个答案中实现的字符串算法的数字变体。它更快且不需要创建或排序池。

算法概述

我们可以使用整数来表示二进制字符串,这极大地简化了池生成和顺序消除值的问题。例如,对于 max_len==3,我们可以将数字 1--(其中 - 表示填充)表示为十进制中的 4。此外,我们可以确定选择此数字后需要消除的数字是那些位于 44 + 2 ^ x - 1 之间的数字。这里的 x 是填充元素的数量(在本例中为 2),因此需要消除的数字位于 44 + 2 ^ 2 - 1 之间(或者在 100110111 中表示的 47 之间)。

为了完全匹配您的问题,我们需要做一些微调,因为您在算法的某些部分将二进制中可能相同的数字视为不同。例如,10010-1--都是相同的数字,但在您的方案中需要以不同的方式处理。在max_len==3的情况下,我们有8个可能的数字,但有14种可能的表示形式:
0 - 000: 0--, 00-
1 - 001:
2 - 010: 01-
3 - 011:
4 - 100: 1--, 10-
5 - 101:
6 - 110: 11-
7 - 111:

0和4有三种可能的编码,2和6有两种,其他数字只有一种。我们需要生成一个整数池,表示具有多个表示的数字的更高选择概率,以及跟踪数字包含的空格数量的机制。我们可以通过在数字末尾附加一些位来指示我们想要的加权值来实现这一点。因此,我们的数字变为(这里我们使用两个位):

jbaum | int | bin | bin.enc | int.enc    
  0-- |   0 | 000 |   00000 |       0
  00- |   0 | 000 |   00001 |       1      
  000 |   0 | 000 |   00010 |       2      
  001 |   1 | 001 |   00100 |       3      
  01- |   2 | 010 |   01000 |       4  
  010 |   2 | 010 |   01001 |       5  
  011 |   3 | 011 |   01101 |       6  
  1-- |   4 | 100 |   10000 |       7  
  10- |   4 | 100 |   10001 |       8  
  100 |   4 | 100 |   10010 |       9  
  101 |   5 | 101 |   10100 |      10  
  11- |   6 | 110 |   11000 |      11   
  110 |   6 | 110 |   11001 |      12   
  111 |   7 | 111 |   11100 |      13

一些有用的属性:
  • enc.bits 表示编码所需的位数(在本例中为两位)
  • int.enc %% enc.bits 告诉我们有多少个数字是明确指定的
  • int.enc %/% enc.bits 返回 int
  • int * 2 ^ enc.bits + explicitly.specified 返回 int.enc

请注意,这里的 explicitly.specified 在我们的实现中介于 0max_len - 1 之间,因为至少指定了一个数字。现在,我们使用整数完全表示您的数据结构。我们可以从整数中取样并产生您想要的结果,带有正确的权重等。这种方法的一个限制是,在 R 中,我们使用 32 位整数,并且必须为编码保留一些位,因此我们将自己限制在大约 max_len==25 的池中。如果使用双精度浮点数指定整数,则可以做得更大,但我们在这里没有这样做。

避免重复选择

有两种粗略的方法可以确保我们不会两次选择相同的值:
  1. 跟踪剩余可供选择的值,并从中随机抽样
  2. 从所有可能的值中随机抽样,然后检查该值是否已被选择过,如果是,则重新抽样
虽然第一种选项看起来最简洁,但它实际上在计算上非常昂贵。它要求每次选择时对所有可能的值进行向量扫描以预先淘汰已选择的值,或者创建一个包含未淘汰值的缩小向量。如果使用C代码通过引用使向量缩小,则缩小选项比向量扫描更有效率,但即使如此,它也需要重复翻译可能大部分向量,并且需要C。

这里我们使用第二种方法。这使我们能够一次随机打乱可能值的宇宙,然后按顺序选择每个值,检查它是否被淘汰,如果是,则选择另一个值,以此类推。这种方法有效,因为根据我们的值编码,检查一个值是否被选中是微不足道的;我们可以仅根据值本身推断出其在排序表中的位置。因此,我们记录了排序表中每个值的状态,并且可以通过直接索引访问(无需扫描)更新或查找该状态。

示例

在R基础实现中,这种算法的实现可作为a gist获得。这种特定的实现仅拉取完整的抽样。这是从max_len==4池中抽取8个元素的10个样本的示例:

# each column represents a draw from a `max_len==4` pool

set.seed(6); replicate(10, sample0110b(4, 8))
     [,1]   [,2]   [,3]   [,4]   [,5]   [,6]   [,7]   [,8]   [,9]   [,10] 
[1,] "1000" "1"    "0011" "0010" "100"  "0011" "0"    "011"  "0100" "1011"
[2,] "111"  "0000" "1101" "0000" "0110" "0100" "1000" "00"   "0101" "1001"
[3,] "0011" "0110" "1001" "0100" "0000" "0101" "1101" "1111" "10"   "1100"
[4,] "0100" "0010" "0000" "0101" "1101" "101"  "1011" "1101" "0110" "1101"
[5,] "101"  "0100" "1100" "1100" "0101" "1001" "1001" "1000" "1111" "1111"
[6,] "110"  "0111" "1011" "111"  "1011" "110"  "1111" "0100" "0011" "000" 
[7,] "0101" "0101" "111"  "011"  "1010" "1000" "1100" "101"  "0001" "0101"
[8,] "011"  "0001" "01"   "1010" "0011" "1110" "1110" "1001" "110"  "1000"

我们最初有两种采用方法#1避免重复的实现,一种是基于R的,另一种是基于C的,但即使当n很大时,即使是C版本也不如新的基于R的版本快。这些函数确实实现了绘制不完整的样本的能力,因此我们在此提供它们供参考:

比较基准测试

这里是一组基准测试,比较了出现在这个Q/A中的多个函数。时间以毫秒为单位计算。 brodie.b 版本是本答案中描述的版本。 brodie 是原始实现,brodie.C 是带有一些C的原始实现。所有这些都强制要求完整的样本。brodie.str 是另一个答案中的基于字符串的版本。

   size    n  jbaum josilber  frank tensibai brodie.b brodie brodie.C brodie.str
1     4   10     11        1      3        1        1      1        1          0
2     4   50      -        -      -        1        -      -        -          1
3     4  100      -        -      -        1        -      -        -          0
4     4  256      -        -      -        1        -      -        -          1
5     4 1000      -        -      -        1        -      -        -          1
6     8   10      1      290      6        3        2      2        1          1
7     8   50    388        -      8        8        3      4        3          4
8     8  100  2,506        -     13       18        6      7        5          5
9     8  256      -        -     22       27       13     14       12          6
10    8 1000      -        -      -       27        -      -        -          7
11   16   10      -        -    615      688       31     61       19        424
12   16   50      -        -  2,123    2,497       28    276       19      1,764
13   16  100      -        -  4,202    4,807       30    451       23      3,166
14   16  256      -        - 11,822   11,942       40  1,077       43      8,717
15   16 1000      -        - 38,132   44,591       83  3,345      130     27,768

这对于更大的资源池相对来说扩展性较好。

system.time(sample0110b(18, 100000))
   user  system elapsed 
  8.441   0.079   8.527 

Benchmark笔记:
  • Frank和Brodie(减去Brodie.str)不需要预先生成池,这将影响比较(见下文)
  • Josilber是LP版本
  • jbaum是OP示例
  • tensibai稍作修改以在池为空时退出而不是失败
  • 未设置运行Python,因此无法进行比较/缓冲区账户
  • -表示不可行的选项或时间过长而无法合理计时

这些时间不包括绘制池(对于大小分别为4816jbaumjosilberbrodie.str运行分别为0.82.5401毫秒),以及排序它们(对于大小分别为4816brodie.str,除了绘制之外还需要0.12.73700毫秒)。是否要包含这些时间取决于您针对特定池运行函数的次数。此外,几乎肯定有更好的生成/排序池的方法。

这些是使用microbenchmark进行三次运行的中间时间。代码可在 gist中获取 ,但请注意,您必须先加载sample0110bsample0110sample01101sample01函数。

很酷;由于我不懂二进制,很多东西都超出了我的理解范围。顺便说一下,似乎有时会排除“0”,比如当我尝试z = replicate(1e4,sample011(2,4)); table(z)时,我不确定为什么;对于size=1,它是存在的... - Frank
@jbaums,抱歉,我认为我理解得很好,但是我在那里留下了一个旧段落。关于前缀,我昨晚上床时意识到了这一点(添加了更新以说明它),但我认为有一种方法可以修复它,我在早期的实现中曾经使用过,但由于速度而放弃了。 - BrodieG
太好了,是的,这正是我想要的——共享前缀没问题,但在示例中包括前缀本身不行。再次感谢。 - jbaums
@Frank,可能是池的大小超出限制或者运行速度太慢了。lbaum只是原始问题中包含的代码(while循环)。 - BrodieG
1
@jbaums,感谢您的赏金。这个问题一定是赏金奖励最高的问题之一。我更新了答案并加入了盲抽样版。结果大多数情况下比C版本更快,并且完全使用基础R语言编写。 - BrodieG
显示剩余6条评论

15

我发现这个问题很有趣,所以我尝试用非常基础的R技能来解决它(所以可能还有改进空间):

经过@Franck建议的更新版本:

library(microbenchmark)
library(lineprof)

max_len <- 16
pool <- unlist(lapply(seq_len(max_len), function(x) 
  do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
n<-100

library(stringr)
tree_sample <- function(samples,pool) {
  results <- vector("integer",samples)
  # Will be used on a regular basis, compute it in advance
  PoolLen <- str_length(pool)
  # Make a mask vector based on the length of each entry of the pool
  masks <- strtoi(str_pad(str_pad("1",PoolLen,"right","1"),max_len,"right","0"),base=2)

  # Make an integer vector from "0" right padded orignal: for max_len=4 and pool entry "1" we get "1000" => 8
  # This will allow to find this entry as parent of 10 and 11 which become "1000" and "1100", as integer 8 and 12 respectively
  # once bitwise "anded" with the repective mask "1000" the first bit is striclty the same, so it's a parent.
  integerPool <- strtoi(str_pad(pool,max_len,"right","0"),base=2)

  # Create a vector to filter the available value to sample
  ok <- rep(TRUE,length(pool))

  #Precompute the result of the bitwise and betwwen our integer pool and the masks   
  MaskedPool <- bitwAnd(integerPool,masks)

  while(samples) {
    samp <- sample(pool[ok],1) # Get a sample
    results[samples] <- samp # Store it as result
    ok[pool == samp] <- FALSE # Remove it from available entries

    vsamp <- strtoi(str_pad(samp,max_len,"right","0"),base=2) # Get the integer value of the "0" right padded sample
    mlen <- str_length(samp) # Get sample len

    #Creation of unitary mask to remove childs of sample
    mask <- strtoi(paste0(rep(1:0,c(mlen,max_len-mlen)),collapse=""),base=2)

    # Get the result of bitwise And between the integerPool and the sample mask 
    FilterVec <- bitwAnd(integerPool,mask)

    # Get the bitwise and result of the sample and it's mask
    Childm <- bitwAnd(vsamp,mask)

    ok[FilterVec == Childm] <- FALSE  # Remove from available entries the childs of the sample
    ok[MaskedPool == bitwAnd(vsamp,masks)] <- FALSE # compare the sample with all the masks to remove parents matching

    samples <- samples -1
  }
  print(results)
}
microbenchmark(tree_sample(n,pool),times=10L)

主要思路是使用位掩码比较来判断一个样本是否是另一个样本的父节点(共同的位部分),如果是,则从池中删除该元素。

在我的机器上,从长度为16的池中绘制100个样本现在需要1.4秒。


13
你可以对池进行排序以帮助决定要淘汰哪些元素。例如,查看一个有三个元素的已排序池:
 [1] "0"   "00"  "000" "001" "01"  "010" "011" "1"   "10"  "100" "101" "11" 
[13] "110" "111"

我可以判断出在选定项之后具有比我的项更多字符的任何内容,直到第一个具有相同数量或更少字符的项目。例如,如果我选择“01”,我可以立即看到需要删除下面两个项目(“010”,“011”),但不是其后面的那个项目,因为“1”的字符更少。然后删除“0”很容易实现。这是一个实现示例:

library(fastmatch)  # could use `match`, but we repeatedly search against same hash

# `pool` must be sorted!

sample01 <- function(pool, n) {
  picked <- logical(length(pool))
  chrs <- nchar(pool)
  pick.list <- character(n)
  pool.seq <- seq_along(pool)

  for(i in seq(n)) {
    # Make sure pool not exhausted

    left <- which(!picked)
    left.len <- length(left)
    if(!length(left)) break

    # Sample from pool

    seq.left <- seq.int(left)
    pool.left <- pool[left]
    chrs.left <- chrs[left]
    pick <- sample(length(pool.left), 1L)

    # Find all the elements with more characters that are disqualified
    # and store their indices in `valid` (bad name...)

    valid.tmp <- chrs.left > chrs.left[[pick]] & seq.left > pick
    first.invalid <- which(!valid.tmp & seq.left > pick)
    valid <- if(length(first.invalid)) {
      pick:(first.invalid[[1L]] - 1L)
    } else pick:left.len

    # Translate back to original pool indices since we're working on a 
    # subset in `pool.left`

    pool.seq.left <- pool.seq[left]
    pool.idx <- pool.seq.left[valid]
    val <- pool[[pool.idx[[1L]]]]

    # Record the picked value, and all the disqualifications

    pick.list[[i]] <- val
    picked[pool.idx] <- TRUE

    # Disqualify shorter matches

    to.rem <- vapply(
      seq.int(nchar(val) - 1), substr, character(1L), x=val, start=1L
    )
    to.rem.idx <- fmatch(to.rem, pool, nomatch=0)
    picked[to.rem.idx] <- TRUE  
  }
  pick.list  
}

还有一个函数可以生成排序后的池(与您的代码完全相同,但返回已排序的):

make_pool <- function(size)
  sort(
    unlist(
      lapply(
        seq_len(size), 
        function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x))) 
  ) ) )

然后,使用 max_len 3 池化层(用于在视觉上检查事物是否按预期运行):

pool3 <- make_pool(3)
set.seed(1)
sample01(pool3, 8)
# [1] "001" "1"   "010" "011" "000" ""    ""    ""   
sample01(pool3, 8)
# [1] "110" "111" "011" "10"  "00"  ""    ""    ""   
sample01(pool3, 8)
# [1] "000" "01"  "11"  "10"  "001" ""    ""    ""   
sample01(pool3, 8)
# [1] "011" "101" "111" "001" "110" "100" "000" "010"    

在最后一种情况下,我们获取了所有的3位二进制组合(2 ^ 3),因为恰好我们随机抽取的是3位二进制数。此外,只有一个大小为3的池子,存在许多抽样方式使得无法完全进行8次抽取; 您可以通过您提出的方法消除妨碍从池中进行完整绘制的组合。

这非常快。查看使用备用解决方案的max_len == 9示例,需要2秒钟:

pool9 <- make_pool(9)
microbenchmark(sample01(pool9, 4))
# Unit: microseconds
#                expr     min      lq  median      uq     max neval
#  sample01(pool9, 4) 493.107 565.015 571.624 593.791 983.663   100    

大约半毫秒。你甚至可以尝试使用相当大的池:

pool16 <- make_pool(16)  # 131K entries
system.time(sample01(pool16, 100))
#  user  system elapsed 
# 3.407   0.146   3.552 

这并不是非常快的方法,但我们讨论的是一个有130K个项目的池子。另外还有进一步优化的潜力。

请注意,对于大型池子,排序步骤变得相对较慢,但我没有将其计算在内,因为您只需要执行一次,而且您可能可以想出一个合理的算法来预先对池进行排序。

还有一种基于整数转二进制的更快的方法,我在一个已经被删除的答案中探索了这种方法,但这需要更多的工作才能精确地适配到您所需的内容。


你说“你可以用你的建议来解决这个问题”,但实际上,OP的建议(消除特定字符串)是不够的。像josilber的答案一样,需要顺序抽样。以OP的示例为例,当max_len=4n=10时,我们知道01被预先禁止了,但我们的前四次抽样可能是00011110。排除这样的组合将是一个难以想象的麻烦。由于OP忽略了定义概率,我只会选择tail(pool,n) - Frank
@Frank,如果要对特定节点进行采样,计算剩余可供采样的最大节点数并不太难。当然,这可能效率不高,但是作为顺序采样解决方案的一部分,Rcpp可以使其变得可行。 - jbaums
您IP地址为143.198.54.68,由于运营成本限制,当前对于免费用户的使用频率限制为每个IP每72小时10次对话,如需解除限制,请点击左下角设置图标按钮(手机用户先点击左上角菜单按钮)。 - BrodieG
sample01 包装在 while 中,以确保在 n 相对于 2^size 较小时(例如 s <- sample01(pool <- make_pool(6), 10); while(any(s=='')) s <- sample01(pool, 10)),完整的样本运行得非常好,但是当 n 接近或等于 2^size 时(例如 s <- sample01(pool <- make_pool(6), 64); while(any(s=='')) s <- sample01(pool, 64)),就会出现问题;这并不奇怪,因为在这种特殊情况下只存在一种可能的完整样本。理想情况下,我希望能够在两种情况下都有效地进行采样。这是一个很好的方法,但我可能要求过多了! - jbaums
@jbaums,这说得有道理。此外,我已经在数字版本上有一个可行的解决方案,但仍在整理中。事先选择无效的选项可能非常困难。例如,即使 n 相对较小于 size,将 01 作为您的前两个选项也是无效的,因此您需要预先计算无效的计算结果,这可能不切实际或同样耗费费用。 - BrodieG

13

将id映射到字符串。您可以像 @BrodieG 提到的那样,将数字映射到您的0/1向量中:

# some key objects

n_pool      = sum(2^(1:max_len))      # total number of indices
cuts        = cumsum(2^(1:max_len-1)) # new group starts
inds_by_g   = mapply(seq,cuts,cuts*2) # indices grouped by length

# the mapping to strings (one among many possibilities)

library(data.table)
get_01str <- function(id,max_len){
    cuts = cumsum(2^(1:max_len-1))
    g    = findInterval(id,cuts)
    gid  = id-cuts[g]+1

    data.table(g,gid)[,s:=
      do.call(paste,c(list(sep=""),lapply(
        seq(g[1]), 
        function(x) (gid-1) %/% 2^(x-1) %% 2
      )))
    ,by=g]$s      
} 

找到要删除的id。 我们将从抽样池中顺序删除id

 # the mapping from one index to indices of nixed strings

get_nixstrs <- function(g,gid,max_len){

    cuts         = cumsum(2^(1:max_len-1))
    gids_child   = {
      x = gid%%2^sequence(g-1)
      ifelse(x,x,2^sequence(g-1))
    }
    ids_child    = gids_child+cuts[sequence(g-1)]-1

    ids_parent   = if (g==max_len) gid+cuts[g]-1 else {

      gids_par       = vector(mode="list",max_len)
      gids_par[[g]]  = gid
      for (gg in seq(g,max_len-1)) 
        gids_par[[gg+1]] = c(gids_par[[gg]],gids_par[[gg]]+2^gg)

      unlist(mapply(`+`,gids_par,cuts-1))
    }

    c(ids_child,ids_parent)
}

指数以字符数量 nchar(get_01str(id)) 所属的组 g 进行分组。由于指数已按 g 排序,使用 g=findInterval(id,cuts) 可以更快地找到所在组。

对于组为 g(其中 1 < g < max_len),其索引有一个大小为 g-1 的子索引和两个大小为 g+1 的父索引。对于每个子节点,我们取其子节点直到遇到 g==1;对于每个父节点,我们取其一对父节点直到遇到 g==max_len

树的结构最简单的标识符是该组中的标识符 gidgid 映射到两个父节点 gidgid+2^g,反转此映射可找到子节点。

抽样

drawem <- function(n,max_len){
    cuts        = cumsum(2^(1:max_len-1))
    inds_by_g   = mapply(seq,cuts,cuts*2)

    oklens = (1:max_len)[ n <= 2^max_len*(1-2^(-(1:max_len)))+1 ]
    okinds = unlist(inds_by_g[oklens])

    mysamp = rep(0,n)
    for (i in 1:n){

        id        = if (length(okinds)==1) okinds else sample(okinds,1)
        g         = findInterval(id,cuts)
        gid       = id-cuts[g]+1
        nixed     = get_nixstrs(g,gid,max_len)

        # print(id); print(okinds); print(nixed)

        mysamp[i] = id
        okinds    = setdiff(okinds,nixed)
        if (!length(okinds)) break
    }

    res <- rep("",n)
    res[seq.int(i)] <- get_01str(mysamp[seq.int(i)],max_len)
    res
}
oklens 部分整合了 OP 的想法,即省略一些肯定会使采样失败的字符串。然而,即便这样做了,我们可能会走上一条死路。以 OP 的例子为例,当 max_len=4n=10 时,我们知道必须放弃考虑 01 ,但是如果我们前四次抽取的结果是 00011110,那么我们就没有更多的选择了。这就是为什么你实际上应该定义采样概率。(OP 还有另一个想法,即确定每一步将导致不可能状态的节点,但这似乎是一项艰巨的任务。) 示例
# how the indices line up

n_pool = sum(2^(1:max_len)) 
pdt <- data.table(id=1:n_pool)
pdt[,g:=findInterval(id,cuts)]
pdt[,gid:=1:.N,by=g]
pdt[,s:=get_01str(id,max_len)]

# example run

set.seed(4); drawem(5,5)
# [1] "01100" "1"     "0001"  "0101"  "00101"

set.seed(4); drawem(8,4)
# [1] "1100" "0"    "111"  "101"  "1101" "100"  ""     ""  

基准测试(比 @BrodieG 回答中的测试结果要旧)

require(rbenchmark)
max_len = 8
n = 8

benchmark(
      jos_lp     = {
        pool <- unlist(lapply(seq_len(max_len),
          function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
        sample.lp(pool, n)},
      bro_string = {pool <- make_pool(max_len);sample01(pool,n)},
      fra_num    = drawem(n,max_len),
      replications=5)[1:5]
#         test replications elapsed relative user.self
# 2 bro_string            5    0.05      2.5      0.05
# 3    fra_num            5    0.02      1.0      0.02
# 1     jos_lp            5    1.56     78.0      1.55

n = 12
max_len = 12
benchmark(
  bro_string={pool <- make_pool(max_len);sample01(pool,n)},
  fra_num=drawem(n,max_len),
  replications=5)[1:5]
#         test replications elapsed relative user.self
# 1 bro_string            5    0.54     6.75      0.51
# 2    fra_num            5    0.08     1.00      0.08

其他答案。 还有另外两个答案:

jos_enum = {pool <- unlist(lapply(seq_len(max_len), 
    function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
  get.template(pool, n)}
bro_num  = sample011(max_len,n)    

因为 @josilber 的枚举方法太耗时,我没有包括它;而 @BrodieG 的数字/索引方法在当时不起作用,但现在已经可以了。请参见 @BrodieG 的更新答案以获取更多基准测试信息。

速度与正确性。 虽然 @josilber 的答案速度要慢得多(对于枚举方法,显然占用更多的内存),但它们保证第一次就会从样本中抽取大小为 n 的数据。使用 @BrodieG 的字符串方法或这个答案,你将不得不一遍又一遍地重新抽样,希望能够抽取完整的 n。对于大的 max_len,那应该不是什么问题,我想。

这个答案比 bro_string 更具可扩展性,因为它不需要预先构建 pool


感谢您坚持下去,弗兰克 - 我很感激您的努力。不幸的是,有时需要重复采样才能确保完整的样本,但我想您是对的 - 在每次抽取之前评估已抽取节点是否会阻止完整样本,将会大大减慢速度。当n相对于n_pool较小时,这是一个很好的解决方案。干杯! - jbaums
get_nixstrs 函数中有一个小错误:当采样到 001(id = 11)或 000(id = 7)时,不允许出现 000(例如 get_nixstrs(11, 3, 3)),实际上应该禁止出现 00(id = 3)。我会看看能否解决这个问题。 - jbaums
@jbaums 不错的发现。我认为现在已经修复了。我知道如何根据“g”和“gid”制作get_nixstrs,所以我只是将其替换并添加了一个解释段落。它比直接从“id”计算要慢,但后者很难调试。(在之前的版本中,每次遇到错误输出的情况时,我都只是微调公式。)哎呀,刚才说错了;等一下... - Frank
1
太好了,Frank。感谢你的修复。系统允许时,我会为你的辛劳奖励一份赏金! - jbaums
@jbaums 嘿,感谢您的慷慨奖励!希望您能尽快恢复10k特权 :) - Frank
显示剩余5条评论

13

这是用Python而不是R编写的,但jbaums说可以。

以下是我的贡献,请参见源代码中的注释以了解关键部分的说明。
我仍在努力分析解决方案,确定深度为tS个样本的树可能组合的数量c,以便改进combs函数。也许有人有这个解决方案吗? 这真的是现在的瓶颈。

从深度为16的树中抽取100个节点,在我的笔记本电脑上大约需要8毫秒。不是第一次,但随着抽样次数的增加,它会变得更快,因为combBuffer被填满了。

import random


class Tree(object):
    """
    :param level: The distance of this node from the root.
    :type level: int
    :param parent: This trees parent node
    :type parent: Tree
    :param isleft: Determines if this is a left or a right child node. Can be
                   omitted if this is the root node.
    :type isleft: bool

    A binary tree representing possible strings which match r'[01]{1,n}'. Its
    purpose is to be able to sample n of its nodes where none of the sampled
    nodes' ids is a prefix for another one.
    It is possible to change Tree.maxdepth and then reuse the root. All
    children are created ON DEMAND, which means everything is lazily evaluated.
    If the Tree gets too big anyway, you can call 'prune' on any node to delete
    its children.

        >>> t = Tree()
        >>> t.sample(8, toString=True, depth=3)
        ['111', '110', '101', '100', '011', '010', '001', '000']
        >>> Tree.maxdepth = 2
        >>> t.sample(4, toString=True)
        ['11', '10', '01', '00']
    """

    maxdepth = 10
    _combBuffer = {}

    def __init__(self, level=0, parent=None, isleft=None):
        self.parent = parent
        self.level = level
        self.isleft = isleft
        self._left = None
        self._right = None

    @classmethod
    def setMaxdepth(cls, depth):
        """
        :param depth: The new depth
        :type depth: int

        Sets the maxdepth of the Trees. This basically is the depth of the root
        node.
        """
        if cls.maxdepth == depth:
            return

        cls.maxdepth = depth

    @property
    def left(self):
        """This tree's left child, 'None' if this is a leave node"""
        if self.depth == 0:
            return None

        if self._left is None:
            self._left = Tree(self.level+1, self, True)
        return self._left

    @property
    def right(self):
        """This tree's right child, 'None' if this is a leave node"""
        if self.depth == 0:
            return None

        if self._right is None:
            self._right = Tree(self.level+1, self, False)
        return self._right

    @property
    def depth(self):
        """
        This tree's depth. (maxdepth-level)
        """
        return self.maxdepth-self.level

    @property
    def id(self):
        """
        This tree's id, string of '0's and '1's equal to the path from the root
        to this subtree. Where '1' means going left and '0' means going right.
        """
        # level 0 is the root node, it has no id
        if self.level == 0:
            return ''
        # This takes at most Tree.maxdepth recursions. Therefore
        # it is save to do it this way. We could also save each nodes
        # id once it is created to avoid recreating it every time, however
        # this won't save much time but use quite some space.
        return self.parent.id + ('1' if self.isleft else '0')

    @property
    def leaves(self):
        """
        The amount of leave nodes, this tree has. (2**depth)
        """
        return 2**self.depth

    def __str__(self):
        return self.id

    def __len__(self):
        return 2*self.leaves-1

    def prune(self):
        """
        Recursively prune this tree's children.
        """
        if self._left is not None:
            self._left.prune()
            self._left.parent = None
            self._left = None

        if self._right is not None:
            self._right.prune()
            self._right.parent = None
            self._right = None

    def combs(self, n):
        """
        :param n: The amount of samples to be taken from this tree
        :type n: int
        :returns: The amount of possible combinations to choose n samples from
                  this tree

        Determines recursively the amount of combinations of n nodes to be
        sampled from this tree.
        Subsequent calls with same n on trees with same depth will return the
        result from the previous computation rather than computing it again.

            >>> t = Tree()
            >>> Tree.maxdepth = 4
            >>> t.combs(16)
            1
            >>> Tree.maxdepth = 3
            >>> t.combs(6)
            58
        """

        # important for the amount of combinations is only n and the depth of
        # this tree
        key = (self.depth, n)

        # We use the dict to save computation time. Calling the function with
        # equal values on equal nodes just returns the alrady computed value if
        # possible.
        if key not in Tree._combBuffer:
            leaves = self.leaves

            if n < 0:
                N = 0
            elif n == 0 or self.depth == 0 or n == leaves:
                N = 1
            elif n == 1:
                return (2*leaves-1)
            else:
                if n > leaves/2:
                    # if n > leaves/2, at least n-leaves/2 have to stay on
                    # either side, otherweise the other one would have to
                    # sample more nodes than possible.
                    nMin = n-leaves/2
                else:
                    nMin = 0

                # The rest n-2*nMin is the amount of samples that are free to
                # fall on either side
                free = n-2*nMin

                N = 0
                # sum up the combinations of all possible splits
                for addLeft in range(0, free+1):
                    nLeft = nMin + addLeft
                    nRight = n - nLeft
                    N += self.left.combs(nLeft)*self.right.combs(nRight)

            Tree._combBuffer[key] = N
            return N
        return Tree._combBuffer[key]

    def sample(self, n, toString=False, depth=None):
        """
        :param n: How may samples to take from this tree
        :type n: int
        :param toString: If 'True' result will direclty be turned into a list
                         of strings
        :type toString: bool
        :param depth: If not None, will overwrite Tree.maxdepth
        :type depth: int or None
        :returns: List of n nodes sampled from this tree
        :throws ValueError: when n is invalid

        Takes n random samples from this tree where none of the sample's ids is
        a prefix for another one's.

        For an example see Tree's docstring.
        """
        if depth is not None:
            Tree.setMaxdepth(depth)

        if toString:
            return [str(e) for e in self.sample(n)]

        if n < 0:
            raise ValueError('Negative sample size is not possible!')

        if n == 0:
            return []

        leaves = self.leaves
        if n > leaves:
            raise ValueError(('Cannot sample {} nodes, with only {} ' +
                              'leaves!').format(n, leaves))

        # Only one sample to choose, that is nice! We are free to take any node
        # from this tree, including this very node itself.
        if n == 1 and self.level > 0:
            # This tree has 2*leaves-1 nodes, therefore
            # the probability that we keep the root node has to be
            # 1/(2*leaves-1) = P_root. Lets create a random number from the
            # interval [0, 2*leaves-1).
            # It will be 0 with probability 1/(2*leaves-1)
            P_root = random.randint(0, len(self)-1)
            if P_root == 0:
                return [self]
            else:
                # The probability to land here is 1-P_root

                # A child tree's size is (leaves-1) and since it obeys the same
                # rule as above, the probability for each of its nodes to
                # 'survive' is 1/(leaves-1) = P_child.
                # However all nodes must have equal probability, therefore to
                # make sure that their probability is also P_root we multiply
                # them by 1/2*(1-P_root). The latter is already done, the
                # former will be achieved by the next condition.
                # If we do everything right, this should hold:
                # 1/2 * (1-P_root) * P_child = P_root

                # Lets see...
                # 1/2 * (1-1/(2*leaves-1)) * (1/leaves-1)
                # (1-1/(2*leaves-1)) * (1/(2*(leaves-1)))
                # (1-1/(2*leaves-1)) * (1/(2*leaves-2))
                # (1/(2*leaves-2)) - 1/((2*leaves-2) * (2*leaves-1))
                # (2*leaves-1)/((2*leaves-2) * (2*leaves-1)) - 1/((2*leaves-2) * (2*leaves-1))
                # (2*leaves-2)/((2*leaves-2) * (2*leaves-1))
                # 1/(2*leaves-1)
                # There we go!
                if random.random() < 0.5:
                    return self.right.sample(1)
                else:
                    return self.left.sample(1)

        # Now comes the tricky part... n > 1 therefore we are NOT going to
        # sample this node. Its probability to be chosen is 0!
        # It HAS to be 0 since we are definitely sampling from one of its
        # children which means that this node will be blocked by those samples.
        # The difficult part now is to prove that the sampling the way we do it
        # is really random.

        if n > leaves/2:
            # if n > leaves/2, at least n-leaves/2 have to stay on either
            # side, otherweise the other one would have to sample more
            # nodes than possible.
            nMin = n-leaves/2
        else:
            nMin = 0
        # The rest n-2*nMin is the amount of samples that are free to fall
        # on either side
        free = n-2*nMin

        # Let's have a look at an example, suppose we were to distribute 5
        # samples among two children which have 4 leaves each.
        # Each child has to get at least 1 sample, so the free samples are 3.
        # There are 4 different ways to split the samples among the
        # children (left, right):
        # (1, 4), (2, 3), (3, 2), (4, 1)
        # The amount of unique sample combinations per child are
        # (7, 1), (11, 6), (6, 11), (1, 7)
        # The amount of total unique samples per possible split are
        #   7   ,   66  ,   66  ,    7
        # In case of the first and last split, all samples have a probability
        # of 1/7, this was already proven above.
        # Lets suppose we are good to go and the per sample probabilities for
        # the other two cases are (1/11, 1/6) and (1/6, 1/11), this way the
        # overall per sample probabilities for the splits would be:
        #  1/7  ,  1/66 , 1/66 , 1/7
        # If we used uniform random to determine the split, all splits would be
        # equally probable and therefore be multiplied with the same value (1/4)
        # But this would mean that NOT every sample is equally probable!
        # We need to know in advance how many sample combinations there will be
        # for a given split in order to find out the probability to choose it.
        # In fact, due to the restrictions, this becomes very nasty to
        # determine. So instead of solving it analytically, I do it numerically
        # with the method 'combs'. It gives me the amount of possible sample
        # combinations for a certain amount of samples and a given tree depth.
        # It will return 146 for this node and 7 for the outer and 66 for the
        # inner splits.
        # What we now do is, we take a number from [0, 146).
        # if it is smaller than 7, we sample from the first split,
        # if it is smaller than 7+66, we sample from the second split,
        # ...
        # This way we get the probabilities we need.

        r = random.randint(0, self.combs(n)-1)
        p = 0
        for addLeft in xrange(0, free+1):
            nLeft = nMin + addLeft
            nRight = n - nLeft

            p += (self.left.combs(nLeft) * self.right.combs(nRight))
            if r < p:
                return self.left.sample(nLeft) + self.right.sample(nRight)
        assert False, ('Something really strange happend, p did not sum up ' +
                       'to combs or r was too big')


def main():
    """
    Do a microbenchmark.
    """
    import timeit
    i = 1
    main.t = Tree()
    template = ' {:>2}  {:>5} {:>4}  {:<5}'
    print(template.format('i', 'depth', 'n', 'time (ms)'))
    N = 100
    for depth in [4, 8, 15, 16, 17, 18]:
        for n in [10, 50, 100, 150]:
            if n > 2**depth:
                time = '--'
            else:
                time = timeit.timeit(
                    'main.t.sample({}, depth={})'.format(n, depth), setup=
                    'from __main__ import main', number=N)*1000./N
            print(template.format(i, depth, n, time))
            i += 1


if __name__ == "__main__":
    main()

基准测试输出:
  i  depth    n  time (ms)
  1      4   10  0.182511806488
  2      4   50  --   
  3      4  100  --   
  4      4  150  --   
  5      8   10  0.397620201111
  6      8   50  1.66054964066
  7      8  100  2.90236949921
  8      8  150  3.48146915436
  9     15   10  0.804011821747
 10     15   50  3.7428188324
 11     15  100  7.34910964966
 12     15  150  10.8230614662
 13     16   10  0.804491043091
 14     16   50  3.66818904877
 15     16  100  7.09567070007
 16     16  150  10.404779911
 17     17   10  0.865840911865
 18     17   50  3.9999294281
 19     17  100  7.70257949829
 20     17  150  11.3758206367
 21     18   10  0.915451049805
 22     18   50  4.22935962677
 23     18  100  8.22361946106
 24     18  150  12.2081303596

从深度为10的树中抽取10个大小为10的样本:

['1111010111', '1110111010', '1010111010', '011110010', '0111100001', '011101110', '01110010', '01001111', '0001000100', '000001010']
['110', '0110101110', '0110001100', '0011110', '0001111011', '0001100010', '0001100001', '0001100000', '0000011010', '0000001111']
['11010000', '1011111101', '1010001101', '1001110001', '1001100110', '10001110', '011111110', '011001100', '0101110000', '001110101']
['11111101', '110111', '110110111', '1101010101', '1101001011', '1001001100', '100100010', '0100001010', '0100000111', '0010010110']
['111101000', '1110111101', '1101101', '1101000000', '1011110001', '0111111101', '01101011', '011010011', '01100010', '0101100110']
['1111110001', '11000110', '1100010100', '101010000', '1010010001', '100011001', '100000110', '0100001111', '001101100', '0001101101']
['111110010', '1110100', '1101000011', '101101', '101000101', '1000001010', '0111100', '0101010011', '0101000110', '000100111']
['111100111', '1110001110', '1100111111', '1100110010', '11000110', '1011111111', '0111111', '0110000100', '0100011', '0010110111']
['1101011010', '1011111', '1011100100', '1010000010', '10010', '1000010100', '0111011111', '01010101', '001101', '000101100']
['111111110', '111101001', '1110111011', '111011011', '1001011101', '1000010100', '0111010101', '010100110', '0100001101', '0010000000']

有趣的是看到了 Python 的替代方案。看起来性能与我基于整数的非完全方法相当。你知道这个在大规模数据下表现如何吗?比如绘制 1K 而不是 100,或使用 size 8 而不是 16。实际上,我认为这应该更快,因为我的方法中主要的减速因素之一是由于 R 的矢量化特性难以有效地创建递减池。 - BrodieG
@BrodieG 最计算密集的部分是数字生成。我发现 random.randint(0,n) 对于非常大的n仍然有效,因此切换到了该方法。性能提升是不可思议的,现在只需要8毫秒!采样时间与深度不太相关,而与n更相关。这是因为 'combs' 需要将 n 分解成代码中我称之为所有可能的“拆分”。然后下一个递归实例也必须这样做。分支系数是疯狂的。即使使用缓冲区,在具有深度17的树中取100k个样本也需要很长时间。 - swenzel
很酷,你介意加上一些示例输出吗? - BrodieG
另外,我可能漏掉了什么,但这是否意味着存在非线性关系?您的时间表明100k个样本应该很接近。也许问题在于2 ^ 17约等于100k? - BrodieG
好的,我马上添加一些示例。时间确实表明随着 n 的增加呈线性增长,但我可以向您保证,采样 100k 个项目所需的时间远远超过了 8 秒。我在 5 分钟后停止了它。我猜只有当 combBuffer 被充分填充时,它才是线性的。 - swenzel

11

如果您不想生成所有可能元组的集合然后随机抽样(正如您所指出的,对于大型输入大小可能是不可行的),另一个选项是使用整数规划来绘制单个样本。基本上,您可以为 pool 中的每个元素分配一个随机值,然后选择具有最大值总和的可行元组。这应该给每个元组被选择的相等概率,因为它们都是相同的大小,并且它们的值是随机选择的。模型的约束条件将确保不会选择任何不允许的元组对,并且将选择正确数量的元素。

以下是使用 lpSolve 包的解决方案:

library(lpSolve)
sample.lp <- function(pool, max_len) {
  pool <- sort(pool)
  pml <- max(nchar(pool))
  runs <- c(rev(cumsum(2^(seq(pml-1)))), 0)
  banned.from <- rep(seq(pool), runs[nchar(pool)])
  banned.to <- banned.from + unlist(lapply(runs[nchar(pool)], seq_len))
  banned.constr <- matrix(0, nrow=length(banned.from), ncol=length(pool))
  banned.constr[cbind(seq(banned.from), banned.from)] <- 1
  banned.constr[cbind(seq(banned.to), banned.to)] <- 1
  mod <- lp(direction="max",
            objective.in=runif(length(pool)),
            const.mat=rbind(banned.constr, rep(1, length(pool))),
            const.dir=c(rep("<=", length(banned.from)), "=="),
            const.rhs=c(rep(1, length(banned.from)), max_len),
            all.bin=TRUE)
  pool[which(mod$solution == 1)]
}
set.seed(144)
pool <- unlist(lapply(seq_len(4), function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
sample.lp(pool, 4)
# [1] "0011" "010"  "1000" "1100"
sample.lp(pool, 8)
# [1] "0000" "0100" "0110" "1001" "1010" "1100" "1101" "1110"

这似乎可扩展到相当大的池子。例如,从大小为510的池中获取长度为20的样本仅需2秒左右:

pool <- unlist(lapply(seq_len(8), function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
length(pool)
# [1] 510
system.time(sample.lp(pool, 20))
#    user  system elapsed 
#   0.232   0.008   0.239 

如果您需要解决非常非常大的问题规模,那么您可以从随 lpSolve 发行的非开源求解器转向商业求解器,如gurobi或cplex(通常不免费,但对于学术用途是免费的)。


不错的方法。不幸的是,对于较大的“pool”(例如,在max_len = 9的池中绘制4需要11.5秒),这似乎变得相当缓慢,而对于大型池和小的“n”,我的问题中的“while”方法几乎是即时的。我想我可以根据“n”和“max_len”之间的关系在这些方法之间切换。感谢您提供有关可用求解器的信息。 - jbaums
@jbaums 你用什么大小的池子进行了11.5秒的计时?是的,我期望这对于max_len的大值最为适用,这里拒绝采样找不到任何东西,而枚举则难以处理。 - josliber
大小为1022(元素的最大长度为9个字符),我正在绘制大小为4的样本。 - jbaums
1
@jbaums 大部分时间都用于计算不能一起使用的元素对。我改成了一个使用 pool 结构的版本,你引用的例子现在只需要 2.2 秒就能运行(其中大部分时间用于解决优化问题)。 - josliber

9

简介

我觉得这个问题很有趣,所以我不得不思考它,并最终提供自己的答案。由于我得出的算法不是立即从问题描述中得出的,因此我将首先解释我是如何得出这个解决方案的,然后提供一个C++的示例实现(我从未编写过R)。

解决方案的开发

初步看法

阅读问题描述最初令人困惑,但是当我看到树的图片编辑时,我立即理解了问题,我的直觉表明二叉树也是一种解决方案:构建一棵树(大小为1的树集合),并在进行选择后消除分支和祖先,将树分解成较小的树集合。

虽然这一开始看起来不错,但是选择过程和集合的维护会很麻烦。不过,树似乎应该在任何解决方案中都扮演重要角色。

修订1

不要分解树。相反,在每个节点上具有一个布尔数据负载,指示它是否已被消除。这只留下一个保持形式的树。

请注意,这不仅仅是任何二叉树,它实际上是深度为max_len-1的完全二叉树。

修订2

完全二叉树可以非常好地表示为数组。树的典型数组表示使用广度优先搜索,具有以下属性:

Let x be the array index.
x = 0 is the root of the entire tree
left_child(x) = 2x + 1
right_child(x) = 2x + 2
parent(x) = floor((n-1)/2)

在下面的图示中,每个节点都标有其数组索引: breadth-first

作为一个数组,它占用的内存较少(没有更多指针),所使用的内存是连续的(对于缓存很好),并且可以完全放在堆栈上而不是堆(假设您的语言给您选择)。当然,这里有一些条件,特别是数组的大小。我稍后会回到这个问题。
就像在修订1中一样,存储在数组中的数据将是布尔值:true表示可用,false表示不可用。由于根节点实际上不是一个有效的选择,因此索引0应该初始化为false。仍然存在如何进行选择的问题:
随着指数的消失,跟踪已经被消除的数量以及剩余的数量是微不足道的。在该范围内选择一个随机数,然后遍历数组,直到看到那么多个设置为true的指数(包括当前指数)。然后到达的索引就是要进行的选择。 选择n个指数,或者没有东西可选。
这是一个完整的算法,它将起作用,但是在选择过程中还有改进的空间,而且还有尚未解决的实际大小问题:数组大小将为O(2^n)。随着n的增大,首先缓存效益消失,然后数据开始被分页到磁盘上,并且在某些时候变得无法存储。
修订3
我决定先解决更容易的问题:改进选择过程。
从左到右扫描数组是浪费的。跟踪已经被消除的范围可能比检查并找到几个错误更有效。然而,我们的树形表示对于这一点并不理想,因为每轮将被消除的节点中很少有连续的节点在数组中。
通过重新排列数组如何映射到树,可以更好地利用这一点。特别是,让我们使用前序深度优先搜索,而不是广度优先搜索。为了这样做,树必须固定大小,这是这个问题的情况。子节点和父节点的索引如何在数学上连接也不那么明显。
通过使用这种排列方式,保证每个非叶子选择都会消除一个连续的范围:它的子树。
修订4
通过跟踪已消除的范围,不再需要真/假数据,因此根本不需要数组或树。 在每次随机抽取时,已消除的范围可以用于快速找到要选择的节点。所有祖先和整个子树都被消除,并且可以表示为范围,可以轻松地与其他范围合并。
最终任务是将所选节点转换为OP想要的字符串表示形式。由于这个二叉树仍然保持着严格的顺序:从根开始遍历,所有元素>=右子节点的都在右边,而其他元素都在左边。因此,在搜索树时,通过在向左遍历时添加“0”或在向右遍历时添加“1”,可以提供祖先列表和二进制字符串。

实现样例

#include <stdint.h>
#include <algorithm>
#include <cmath>
#include <list>
#include <deque>
#include <ctime>
#include <cstdlib>
#include <iostream>

/*
 * A range of values of the form (a, b), where a <= b, and is inclusive.
 * Ex (1,1) is the range from 1 to 1 (ie: just 1)
 */
class Range
{
private:
    friend bool operator< (const Range& lhs, const Range& rhs);
    friend std::ostream& operator<<(std::ostream& os, const Range& obj);

    int64_t m_start;
    int64_t m_end;

public:
    Range(int64_t start, int64_t end) : m_start(start), m_end(end) {}
    int64_t getStart() const { return m_start; }
    int64_t getEnd() const { return m_end; }
    int64_t size() const { return m_end - m_start + 1; }
    bool canMerge(const Range& other) const {
        return !((other.m_start > m_end + 1) || (m_start > other.m_end + 1));
    }
    int64_t merge(const Range& other) {
        int64_t change = 0;
        if (m_start > other.m_start) {
            change += m_start - other.m_start;
            m_start = other.m_start;
        }
        if (other.m_end > m_end) {
            change += other.m_end - m_end;
            m_end = other.m_end;
        }
        return change;
    }
};

inline bool operator< (const Range& lhs, const Range& rhs){return lhs.m_start < rhs.m_start;}
std::ostream& operator<<(std::ostream& os, const Range& obj) {
    os << '(' << obj.m_start << ',' << obj.m_end << ')';
    return os;
}

/*
 * Stuct to allow returning of multiple values
 */
struct NodeInfo {
    int64_t subTreeSize;
    int64_t depth;
    std::list<int64_t> ancestors;
    std::string representation;
};

/*
 * Collection of functions representing a complete binary tree
 * as an array created using pre-order depth-first search,
 * with 0 as the root.
 * Depth of the root is defined as 0.
 */
class Tree
{
private:
    int64_t m_depth;
public:
    Tree(int64_t depth) : m_depth(depth) {}
    int64_t size() const {
        return (int64_t(1) << (m_depth+1))-1;
    }
    int64_t getDepthOf(int64_t node) const{
        if (node == 0) { return 0; }
        int64_t searchDepth = m_depth;
        int64_t currentDepth = 1;
        while (true) {
            int64_t rightChild = int64_t(1) << searchDepth;
            if (node == 1 || node == rightChild) {
                break;
            } else if (node > rightChild) {
                node -= rightChild;
            } else {
                node -= 1;
            }
            currentDepth += 1;
            searchDepth -= 1;
        }
        return currentDepth;
    }
    int64_t getSubtreeSizeOf(int64_t node, int64_t nodeDepth = -1) const {
        if (node == 0) {
            return size();
        }
        if (nodeDepth == -1) {
            nodeDepth = getDepthOf(node);
        }
        return (int64_t(1) << (m_depth + 1 - nodeDepth)) - 1;
    }
    int64_t getLeftChildOf(int64_t node, int64_t nodeDepth = -1) const {
        if (nodeDepth == -1) {
            nodeDepth = getDepthOf(node);
        }
        if (nodeDepth == m_depth) { return -1; }
        return node + 1;
    }
    int64_t getRightChildOf(int64_t node, int64_t nodeDepth = -1) const {
        if (nodeDepth == -1) {
            nodeDepth = getDepthOf(node);
        }
        if (nodeDepth == m_depth) { return -1; }
        return node + 1 + ((getSubtreeSizeOf(node, nodeDepth) - 1) / 2);
    }
    NodeInfo getNodeInfo(int64_t node) const {
        NodeInfo info;
        int64_t depth = 0;
        int64_t currentNode = 0;
        while (currentNode != node) {
            if (currentNode != 0) {
                info.ancestors.push_back(currentNode);
            }
            int64_t rightChild = getRightChildOf(currentNode, depth);
            if (rightChild == -1) {
                break;
            } else if (node >= rightChild) {
                info.representation += '1';
                currentNode = rightChild;
            } else {
                info.representation += '0';
                currentNode = getLeftChildOf(currentNode, depth);
            }
            depth++;
        }
        info.depth = depth;
        info.subTreeSize = getSubtreeSizeOf(node, depth);
        return info;
    }
};

// random selection amongst remaining allowed nodes
int64_t selectNode(const std::deque<Range>& eliminationList, int64_t poolSize, std::mt19937_64& randomGenerator)
{
    std::uniform_int_distribution<> randomDistribution(1, poolSize);
    int64_t selection = randomDistribution(randomGenerator);
    for (auto const& range : eliminationList) {
        if (selection >= range.getStart()) { selection += range.size(); }
        else { break; }
    }
    return selection;
}

// determin how many nodes have been elimintated
int64_t countEliminated(const std::deque<Range>& eliminationList)
{
    int64_t count = 0;
    for (auto const& range : eliminationList) {
        count += range.size();
    }
    return count;
}

// merge all the elimination ranges to listA, and return the number of new elimintations
int64_t mergeEliminations(std::deque<Range>& listA, std::deque<Range>& listB) {
    if(listB.empty()) { return 0; }
    if(listA.empty()) {
        listA.swap(listB);
        return countEliminated(listA);
    }

    int64_t newEliminations = 0;
    int64_t x = 0;
    auto listA_iter = listA.begin();
    auto listB_iter = listB.begin();
    while (listB_iter != listB.end()) {
        if (listA_iter == listA.end()) {
            listA_iter = listA.insert(listA_iter, *listB_iter);
            x = listB_iter->size();
            assert(x >= 0);
            newEliminations += x;
            ++listB_iter;
        } else if (listA_iter->canMerge(*listB_iter)) {
            x = listA_iter->merge(*listB_iter);
            assert(x >= 0);
            newEliminations += x;
            ++listB_iter;
        } else if (*listB_iter < *listA_iter) {
            listA_iter = listA.insert(listA_iter, *listB_iter) + 1;
            x = listB_iter->size();
            assert(x >= 0);
            newEliminations += x;
            ++listB_iter;
        } else if ((listA_iter+1) != listA.end() && listA_iter->canMerge(*(listA_iter+1))) {
            listA_iter->merge(*(listA_iter+1));
            listA_iter = listA.erase(listA_iter+1);
        } else {
            ++listA_iter;
        }
    }
    while (listA_iter != listA.end()) {
        if ((listA_iter+1) != listA.end() && listA_iter->canMerge(*(listA_iter+1))) {
            listA_iter->merge(*(listA_iter+1));
            listA_iter = listA.erase(listA_iter+1);
        } else {
            ++listA_iter;
        }
    }
    return newEliminations;
}

int main (int argc, char** argv)
{
    std::random_device rd;
    std::mt19937_64 randomGenerator(rd());

    int64_t max_len = std::stoll(argv[1]);
    int64_t num_samples = std::stoll(argv[2]);

    int64_t samplesRemaining = num_samples;
    Tree tree(max_len);
    int64_t poolSize = tree.size() - 1;
    std::deque<Range> eliminationList;
    std::deque<Range> eliminated;
    std::list<std::string> foundList;

    while (samplesRemaining > 0 && poolSize > 0) {
        // find a valid node
        int64_t selectedNode = selectNode(eliminationList, poolSize, randomGenerator);
        NodeInfo info = tree.getNodeInfo(selectedNode);
        foundList.push_back(info.representation);
        samplesRemaining--;

        // determine which nodes this choice eliminates
        eliminated.clear();
        for( auto const& ancestor : info.ancestors) {
            Range r(ancestor, ancestor);
            if(eliminated.empty() || !eliminated.back().canMerge(r)) {
                eliminated.push_back(r);
            } else {
                eliminated.back().merge(r);
            }
        }
        Range r(selectedNode, selectedNode + info.subTreeSize - 1);
        if(eliminated.empty() || !eliminated.back().canMerge(r)) {
            eliminated.push_back(r);
        } else {
            eliminated.back().merge(r);
        }

        // add the eliminated nodes to the existing list
        poolSize -= mergeEliminations(eliminationList, eliminated);
    }

    // Print some stats
    // std::cout << "tree: " << tree.size() << " samplesRemaining: "
    //                       << samplesRemaining << " poolSize: "
    //                       << poolSize << " samples: " << foundList.size()
    //                       << " eliminated: "
    //                       << countEliminated(eliminationList) << std::endl;

    // Print list of binary strings
    // std::cout << "list:";
    // for (auto const& s : foundList) {
    //  std::cout << " " << s;
    // }
    // std::cout << std::endl;
}

补充思考

这个算法在 max_len 很大的情况下很容易扩展。虽然根据我的分析,它在 n 方面的扩展效果不是很好,但是似乎比其他解决方案更好。

对于包含除了“0”和“1”之外的其他字符的字符串,可以很容易地修改此算法。 在单词中有更多的字母将增加树的扇区,并且每次选择将消除更广泛的范围 - 仍然保持每个子树中所有节点连续。


9
一种方法是采用迭代的方式生成所有可能的符合大小要求的元组:
  1. 构建所有大小为1的元组(即 pool 中的所有元素)
  2. pool 中的元素进行笛卡尔积运算
  3. 删除使用 pool 中同一元素多次的任何元组
  4. 删除任何其他元组的完全重复项
  5. 删除带有无法一起使用的一对的任何元组
  6. 反复洗涤,直到获得适当的元组大小
对于给定的大小(pool 的长度为30,max_len 为4),可以运行此代码。
get.template <- function(pool, max_len) {
  banned <- which(outer(paste0('^', pool), pool, Vectorize(grepl)), arr.ind=T)
  banned <- banned[banned[,1] != banned[,2],]
  banned <- paste(banned[,1], banned[,2])
  vals <- matrix(seq(length(pool)))
  for (k in 2:max_len) {
    vals <- cbind(vals[rep(1:nrow(vals), each=length(pool)),],
                  rep(1:length(pool), nrow(vals)))
    # Can't sample same value more than once
    vals <- vals[apply(vals, 1, function(x) length(unique(x)) == length(x)),]
    # Sort rows to ensure unique only
    vals <- t(apply(vals, 1, sort))
    vals <- unique(vals)
    # Can't have banned pair
    combos <- combn(ncol(vals), 2)
    for (k in seq(ncol(combos))) {
        c1 <- combos[1,k]
        c2 <- combos[2,k]
        vals <- vals[!paste(vals[,c1], vals[,c2]) %in% banned,]
    }
  }
  return(matrix(pool[vals], nrow=nrow(vals)))
}

max_len <- 4
pool <- unlist(lapply(seq_len(max_len), function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
system.time(template <- get.template(pool, 4))
#   user  system elapsed 
#  4.549   0.050   4.614 

现在您可以从template的行中随意进行多次采样(这将非常快),这与从定义空间中随机采样相同。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接