如何将一个列表重塑为一个n维列表
输入:
list = [1, 2, 3, 4, 5, 6, 7, 8]
shape = [2, 2, 2]
输出 = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
如何将一个列表重塑为一个n维列表
输入:
list = [1, 2, 3, 4, 5, 6, 7, 8]
shape = [2, 2, 2]
输出 = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
# Sample input: a flat list of 8 items and a target shape whose
# dimensions multiply to 8 (2 * 2 * 2).
lst = [1, 2, 3, 4, 5, 6, 7, 8]
shape = [2, 2, 2]
from functools import reduce
from operator import mul
def reshape(lst, shape):
    """Reshape the flat sequence ``lst`` into nested lists of the given ``shape``.

    ``shape`` is a non-empty sequence of non-negative integers, outermost
    dimension first.  The product of the dimensions must equal ``len(lst)``.

    Example:
        reshape([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2])
        == [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]

    Returns:
        Nested lists.  For a one-dimensional ``shape`` the input object
        itself is returned (no copy is made).

    Raises:
        ValueError: if ``shape`` is empty, contains a negative dimension,
            or does not describe exactly ``len(lst)`` elements.
    """
    if not shape:
        raise ValueError("shape must contain at least one dimension")
    if any(dim < 0 for dim in shape):
        raise ValueError("shape dimensions must be non-negative: %r" % (list(shape),))
    # Validate up front so a mismatched shape fails loudly instead of
    # silently dropping trailing elements.
    expected = reduce(mul, shape, 1)
    if expected != len(lst):
        raise ValueError(
            "cannot reshape %d element(s) into shape %r (needs %d)"
            % (len(lst), list(shape), expected)
        )
    if len(shape) == 1:
        return lst
    # Number of flat elements consumed by each item of the outer dimension.
    n = reduce(mul, shape[1:], 1)
    # Iterate over shape[0] rather than len(lst) // n so that a zero-sized
    # inner dimension does not cause a division by zero.
    return [reshape(lst[i * n:(i + 1) * n], shape[1:]) for i in range(shape[0])]
# Demo call; the return value is not stored here.
reshape(lst, shape)
# Sanity check: the product of the dimensions equals the number of elements.
assert reduce(mul, shape) == len(lst)
虽然这是一个很旧的帖子,但由于我现在正在寻找一种比我自己的方法更加优雅的方式,所以我只是告诉你我的方法。
# Build some sample data: the integers 0 through 255.
l = list(range(256))
# Cut the flat list into consecutive slices of 4 items each.
x = [l[start:start + 4] for start in range(0, len(l), 4)]
这里介绍了一种方法,使用 grouper 函数在除第一维外的每个维度上都应用一次:
import functools as ft
# Example: 24 flat items and the target shape (2, 3, 4).
L = list(range(2*3*4))
S = 2,3,4
# How the one-liners below work:
#   * `S[:0:-1]` walks the shape from the last dimension down to index 1,
#     so the first dimension S[0] is never used.
#   * `y*(x,)` is a tuple holding `y` references to the SAME iterator `x`;
#     unpacking it into `zip` makes each output tuple take `y` consecutive
#     items from that iterator.
#   * `reduce` applies this grouping once per remaining dimension, starting
#     from `iter(L)`.
# Both expressions are evaluated and discarded (nothing is assigned).

# If tuples are acceptable:
tuple(ft.reduce(lambda x, y: zip(*y*(x,)), (iter(L), *S[:0:-1])))
# (((0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)), ((12, 13, 14, 15), (16, 17, 18, 19), (20, 21, 22, 23)))
# If the result must be made of lists (each zip tuple is converted by `map(list, ...)`):
list(ft.reduce(lambda x, y: map(list, zip(*y*(x,))), (iter(L), *S[:0:-1])))
# [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]
`*S` 项在 Python 2.7 中无法识别:SyntaxError: invalid syntax。 - FangQB
如果 B 是一个元组,那么在 py2 中 `(a, *B)` 可以被替换为 `(a,) + B`。 - Paul Panzer
下面的代码应该可以解决问题。
下面给出的解决方案非常通用。输入列表可以是任意形状的嵌套列表; 它不必是整数列表。
此外,代码中还包含一些可以单独复用的工具函数。例如,`all_for_one` 函数就非常方便。
编辑:
我忘记了一件重要的事情。如果在 `shape` 参数中放入 `1`,则可能会得到多余的列表嵌套(即一个只包含单个列表的列表,而不是包含五六个列表的列表)。
例如,如果 `shape` 为 `[1, 1, 2]`,那么返回值可能是 `[[[0.1, 0.2]]]`,而不是 `[0.1, 0.2]`。
`shape` 的长度就是输出列表中有效下标的数量。
例如,
shape = [1, 2]  # length 2, so the example below uses two subscripts
lyst = [[0.1, 0.2]]
print(lyst[0][0])  # valid: both subscripts succeed, no IndexError is raised
如果想要一个由数值(而不是列表)组成的列表,`len(shape)` 必须为 1。
例如,`shape = [49]` 将给你一个长度为 49 的行/列向量。
shape = [2] # length 2
output = [0.1, 0.2]
print(lyst[0])  # NOTE(review): likely meant `output[0]`; `lyst` belongs to the previous example - confirm
以下是代码:
from operator import mul
import itertools as itts
import copy
import functools
def one_for_all(one):
    """Return an iterator that yields ``one`` exactly once."""
    return itts.repeat(one, 1)
def all_for_one(lyst):
    """Return an iterator over every leaf element of a nested iterable.

    EXAMPLE:
        INPUT:
            [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
        OUTPUT:
            iterator to [1, 2, 3, 4, 5, 6, 7, 8]

    IN GENERAL:
        Walks a list of lists of ... of lists depth-first and yields the
        non-iterable elements in order.

    Strings are treated as leaves and yielded whole.  Every character of a
    string is itself an iterable string, so drilling into strings can never
    terminate reliably (the previous identity check only worked for
    characters that the interpreter happens to cache).

    An element that is the very container being iterated is also yielded
    as a leaf instead of being descended into.
    """
    for item in lyst:
        drill_down = (
            hasattr(item, "__iter__")
            and not isinstance(item, str)
            and item is not lyst
        )
        if drill_down:
            # A generator keeps the traversal lazy and avoids building one
            # nested `chain` object per element.
            yield from all_for_one(item)
        else:
            yield item
# Demo: flatten a 2x2x2 nested list into a flat list and show it.
merged = list(all_for_one([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))
print("merged == ", merged)
def reshape(xread_lyst, xshape):
    """Reshape a (possibly nested) list into nested lists, similar to `numpy.reshape`.

    EXAMPLE:
        lyst = [1, 2, 3, 4, 5, 6, 7, 8]
        shape = [2, 2, 2]
        result = reshape(lyst, shape)
        print(result)
        result ==
            [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]

    Both arguments are flattened with `all_for_one` first, so `xread_lyst`
    may already be nested and `xshape` can be any iterable containing at
    least one element (it is not required to be a tuple, but it can be).

    The length of `xshape` equals the number of desired list nestings:

        If you want a list of integers:            len(xshape) == 1
        If you want a list of lists:               len(xshape) == 2
        If you want a list of lists of lists:      len(xshape) == 3

        If xshape == [1, 2],
            the outermost list has 1 element,
            and that one element is a list of 2 elements.
            result == [[1, 2]]

        If xshape == [2],
            the outermost list has 2 elements,
            and those 2 elements are non-lists.
            result == [1, 2]

        If xshape == [2, 2],
            the outermost list has 2 elements,
            and each element is a list of 2 elements.
            result == [[1, 2], [3, 4]]

    Raises:
        TypeError: if `xshape` is empty, or if the product of its
            dimensions differs from the number of elements supplied.
    """
    # BEGIN SANITIZING INPUTS
    # Iterators are not re-usable and have no `len`, so materialize both
    # inputs as flat lists before validating them.
    iread_lyst = list(all_for_one(xread_lyst))
    ishape = list(all_for_one(xshape))
    if not ishape:
        raise TypeError("\n`xshape` must contain at least one dimension.")
    number_of_elements = functools.reduce(mul, ishape, 1)
    if number_of_elements != len(iread_lyst):
        msg = [str(x) for x in [
            "\nAn array having dimensions ", ishape,
            "\nMust contain ", number_of_elements, " element(s).",
            "\nHowever, we were given ", len(iread_lyst), " element(s)."
        ]]
        if len(iread_lyst) < 10:
            # Short inputs are echoed (truncated to 5 characters per item)
            # to make the mismatch easier to diagnose.
            msg.append('\nList before reshape: ')
            msg.append(str([str(x)[:5] for x in iread_lyst]))
        raise TypeError(''.join(msg))
    # END SANITIZATION OF INPUTS

    shape_iter = iter(ishape)
    parent_list_len = next(shape_iter)
    write_parent = []
    # The recursive fill is identical for the top level and for nested
    # levels, so delegate to the internal helper instead of repeating it.
    ilyst_reshape(write_parent, iter(iread_lyst), parent_list_len, shape_iter)
    return write_parent
def ilyst_reshape(write_parent, iread_lyst, parent_list_len, ishape):
    """Fill `write_parent` in place with `parent_list_len` items.

    You really shouldn't call this function directly.
    Try calling `reshape` instead.

    The `i` in the name of this function stands for "internal".

    Args:
        write_parent: list that receives the items (mutated in place).
        iread_lyst: iterator over the flat source elements; it is shared
            by every recursive call, so elements are consumed in order.
        parent_list_len: number of items to append to `write_parent`.
        ishape: copyable iterator over the remaining (inner) dimensions.
            When it is exhausted, leaves are read from `iread_lyst`;
            otherwise each appended item is a new nested list.

    Returns:
        None.

    Raises:
        ValueError: if `iread_lyst` runs out of elements before the
            requested shape has been filled.
    """
    # Keep the `try` around this single call only: a StopIteration coming
    # from exhausted data must not be mistaken for "no more dimensions".
    try:
        child_list_len = next(ishape)
    except StopIteration:
        leaves = list(itts.islice(iread_lyst, parent_list_len))
        if len(leaves) != parent_list_len:
            raise ValueError(
                "not enough elements to fill the requested shape: "
                "needed %d more, got %d" % (parent_list_len, len(leaves))
            )
        write_parent.extend(leaves)
        return None

    for _ in range(parent_list_len):
        write_child = []
        write_parent.append(write_child)
        # Every sibling must see the same remaining dimensions, so each
        # recursive call gets its own copy of the shape iterator.
        ilyst_reshape(write_child, iread_lyst, child_list_len, copy.copy(ishape))
    return None
# Demo: rebuild a 2x2x2 nesting from the flat `merged` list and show it.
three_dee_mat = reshape(merged, [2, 2, 2])
print("three_dee_mat == ", three_dee_mat)
from functools import reduce
from itertools import islice
# Flat input data (24 items) and the target shape (2 x 3 x 4).
l = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4]
s = [2, 3, 4]
if s and reduce(lambda x, y: x * y, s) == len(l):
    # If the number of elements matches the product of the dimensions,
    # the first dimension is redundant: grouping by the inner dimensions
    # already leaves exactly s[0] top-level items.
    s = s[1:]
    while s:
        size = s.pop()  # how many elements for this dimension (innermost first)
        # Split the list into consecutive chunks of `size` items.
        # `iter(callable, sentinel)` keeps calling the lambda until it
        # returns the sentinel `[]`, i.e. until `it` is exhausted.
        it = iter(l)
        l = list(iter(lambda: list(islice(it, size)), []))
else:
    # Module-level code cannot `return`; the reshape is simply skipped.
    print("length of input list does not match shape")
# Resulting `l`:
# [[[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
#  [[3, 4, 5, 6], [7, 8, 9, 0], [1, 2, 3, 4]]]
def reshape(lst, shape):
    """Group ``lst`` into nested chunks, one grouping pass per entry of ``shape``.

    The entries of ``shape`` are applied in order: the list is first cut
    into chunks of ``shape[0]`` items, those chunks are then grouped
    ``shape[1]`` at a time, and so on.  A final chunk may be shorter when
    the length is not an exact multiple of the chunk size.

    Examples:
        reshape([1, 2, 3, 4, 5, 6, 7, 8], [2, 2])
        == [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
        reshape([1, 2, 3], [2]) == [[1, 2], [3]]

    NOTE(review): sizes are applied innermost-first, and every entry adds
    one nesting level, so passing all three dimensions ``[2, 2, 2]`` for 8
    items wraps the result in one extra outer list.  Confirm this matches
    what callers expect.

    Returns:
        The nested lists; ``lst`` itself when ``shape`` is empty.

    Raises:
        ValueError: if any entry of ``shape`` is smaller than 1.
    """
    for n in shape:
        if n < 1:
            raise ValueError("chunk sizes in shape must be >= 1, got %r" % (n,))
        # Stepping the range by `n` yields ceil(len(lst) / n) chunks using
        # integer arithmetic only, so no `math.ceil` import is needed.
        lst = [lst[start:start + n] for start in range(0, len(lst), n)]
    return lst
reshape([1, 2, 3], [5]) → [1, 2, 3, 1, 2]
。 - zoomlogo