我一直在尝试将我的代码实现在GPU上运行,但是收效甚微。希望有人能帮助实现。
关于问题,我有一个包含N个节点的图G和每个节点x上的分布mx。我想计算所有边缘上的每一对节点之间分布的距离。对于给定的一对节点(x, y),我使用python POT软件包中的ot.sinkhorn(mx, my, dNxNy)
代码来计算距离。这里,mx、my是大小为Nx和Ny的向量,在节点x和y上,并且dNxNy是一个Nx x Ny的距离矩阵。
现在,我发现有一个GPU实现的代码ot.gpu.sinkhorn(mx, my, dNxNy)
。但是这并不够好,因为每次迭代都需要上传mx、my和dNxNy到GPU,这是一个巨大的开销。所以,思路就是在GPU上并行化处理所有边缘。
代码的核心如下所示。mx_all是所有分布的集合。
for i,e in enumerate(G.edges):
W[i] = W_comp(mx_all,dist,e)
def W_comp(mx_all, dist, e):
i = e[0]
j = e[1]
Nx = np.array(mx_all[i][1]).flatten()
Ny = np.array(mx_all[j][1]).flatten()
mx = np.array(mx_all[i][0]).flatten()
my = np.array(mx_all[j][0]).flatten()
dNxNy = dist[Nx,:][:,Ny].copy(order='C')
W = ot.sinkhorn2(mx, my, dNxNy, 1)
以下是一个最小化工作示例。请忽略除虚线
===
之间的部分以外的所有内容。import ot
import numpy as np
import scipy as sc
def main():
import networkx as nx
#some example graph
G = nx.planted_partition_graph(4, 20, 0.6, 0.3, seed=2)
L = nx.normalized_laplacian_matrix(G)
#this just computes all distributions (IGNORE)
mx_all = []
for i in G.nodes:
mx_all.append(mx_comp(L,1,1,i))
#some random distance matrix (IGNORE)
dist = np.random.randint(5,size=(nx.number_of_nodes(G),nx.number_of_nodes(G)))
# =============================================================================
#this is what needs to be parallelised on GPU
W = np.zeros(nx.Graph.size(G))
for i,e in enumerate(G.edges):
print(i)
W[i] = W_comp(mx_all,dist,e)
return W
def W_comp(mx_all, dist, e):
i = e[0]
j = e[1]
Nx = np.array(mx_all[i][1]).flatten()
Ny = np.array(mx_all[j][1]).flatten()
mx = np.array(mx_all[i][0]).flatten()
my = np.array(mx_all[j][0]).flatten()
dNxNy = dist[Nx,:][:,Ny].copy(order='C')
return ot.sinkhorn2(mx, my, dNxNy,1)
# =============================================================================
#some other functions (IGNORE)
def delta(i, n):
p0 = np.zeros(n)
p0[i] = 1.
return p0
# all neighbourhood densities
def mx_comp(L, t, cutoff, i):
N = np.shape(L)[0]
mx_all = sc.sparse.linalg.expm_multiply(-t*L, delta(i, N))
Nx_all = np.argwhere(mx_all > (1-cutoff)*np.max(mx_all))
return mx_all, Nx_all
if __name__ == "__main__":
main()
谢谢!