Python笛卡尔积和条件?

6
在Python中,我使用itertools.product()函数来生成模拟的输入参数。
我有一个测试函数需要4个输入参数a1、a2、b1和b2。我使用以下代码生成这些参数。例如:
params = itertools.product(range(10,41,2), range(10,41,2), range(0, 2), range(5, 31, 5))

……这给了我3072种组合。不幸的是,有些组合在逻辑上是毫无意义的。例如,如果a2大于a1,则测试结果是无用的;当b1等于0时,b2的值完全无关紧要——因此测试这样的组合没有任何意义。

除了手动嵌套for循环之外,是否有可能限制或过滤笛卡尔积?因为我的真实用例有超过4个参数,所以我喜欢itertools的笛卡尔积函数的便利性。

有什么想法或替代方案吗? 感谢任何帮助。


1
如果您的用例显著复杂,您可以考虑使用像 numpy 这样的数组包。 (例如,这不完全是您想要的,但可能相关:https://dev59.com/uoPba4cB1Zd3GeqPpTqO) - Lack
itertools.product 没有筛选参数,因此您需要自己处理 - 要么编写类似于 itertools.product 文档中示例的生成器,要么在生成后进行筛选。 - wwii
5个回答

6

Python 3

在Python 3中,您可以使用itertools.filterfalse过滤掉不需要的组合:

# predicate is true when need to skip the combination
predicate = (lambda (a1, a2, b1, b2): a1 <= a2 and (b1 != 0 or b2 == 5), params)
filtered_params = itertools.filterfalse(predicate, params)

Python 2

您可以使用列表推导式或itertools.ifilter

filtered_params = itertools.ifilter
    (lambda (a1, a2, b1, b2): a1 <= a2 and (b1 != 0 or b2 == 5), params)

请注意,这两个版本在底层都会进行循环和筛选操作。如果您想避免这种情况,您需要构建一种改进的算法,可以创建不包含不必要内容的元组。

4
如果您有很多参数,使用像python-constraint这样的模块进行基于约束的方法可能更容易处理 - 让它来解决哪些组合是有效的,并让其完成困难的工作。实现方式大致如下:
from constraint import Problem

prob = Problem()
prob.addVariables(["a1", "a2"], range(10,41,2))
prob.addVariable("b1", [0, 2])
prob.addVariable("b2", range(5, 31, 5))
prob.addConstraint(lambda a1, a2: a2 <= a1, ["a1", "a2"])
prob.addConstraint(lambda b1, b2: b1 != 0 or b2 == 5, ["b1", "b2"])

for params in prob.getSolutionIter():
    run_sim(**params)

太棒了,谢谢。我的理解是这个解决方案不会先创建一个巨大的矩阵,然后再将其缩小...而是只添加那些有意义的部分! - Alen
@Alen:我相信它使用分支定界法——它查看a1的每个值,然后查看不违反任何约束条件的a2的每个值,以此类推。 - Hugh Bothwell

1
一种选择是将params作为另一个生成器,该生成器本身由itertools.product提供输入。

例如:
params = (prod for prod in itertools.product(...) if prod[2] <= prod[1])

你可以根据条件在if后添加任何内容。例如,prod[2] <= prod[1] and prod[3] != 0将检查你在问题中提到的条件,仅允许通过测试的结果,并丢弃未通过测试的产品。

0

您可以将列表推导式与任何参数限制结合使用。我建议在此之前将参数放入一个集合中,以确保没有不必要的代码。虽然在您上面提到的情况下不会发生这种情况,但并不总是使用range来生成参数选项。

例如,在这里创建了一组元组参数列表,仅当参数1大于参数2 + 10时才是有效组合:

acceptableParamCombinations = 
[ (p1,p2) for p1 in set(range(10,41,2)) for p2 in set(range(10,41,2)) if p1 > p2 + 10 ]

0
在这种情况下,使用numpy的向量操作来表达您的规则可能是最方便、直观和易读的。例如:
import numpy as np

arr = np.array(list(params), dtype = [('a1',int),('a2',int),('b1',int),('b2',int)])
arr = arr[ arr['a2'] <= arr['a1'] ]
arr = arr[ arr['b1'] != 0 ]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接