使用基本库优化Python代码

4

我正在尝试使用基本python在一个有170万行和4个变量的表格上进行非等号自连接。数据如下所示:

product     position_min     position_max      count_pos
A.16        167804              167870              20
A.18        167804              167838              15
A.15        167896              167768              18
A.20        238359              238361              33
A.35        167835              167837              8

这是我使用的代码:

import csv
from collections import defaultdict
import sys
import os

list_csv=[]
l=[]
with open(r'product.csv', 'r') as file1:
    my_reader1 = csv.reader(file1, delimiter=';')
    for row in my_reader1:
        list_csv.append(row)
with open(r'product.csv', 'r') as file2:
    my_reader2 = csv.reader(file2, delimiter=';') 
    with open('product_p.csv', "w") as csvfile_write:
        ecriture = csv.writer(csvfile_write, delimiter=';',
                                quotechar='"', quoting=csv.QUOTE_ALL)
        for row in my_reader2:
            res = defaultdict(list)
            for k in range(len(list_csv)):
                comp= list_csv[k]
                try:
                    if int(row[1]) >= int(comp[1]) and int(row[2]) <= int(comp[2]) and row[0] != comp[0]:
                        res[row[0]].append([comp[0],comp[3]]) 
                except:
                    pass
            


            if bool(res):    
                for key, value in res.items():
                    sublists = defaultdict(list)
                    for sublist in value:
                        l=[]
                        sublists[sublist[0]].append(int(sublist[1]))
                    l.append(str(key) + ";"+ str(min(sublists.keys(), key=(lambda k: sublists[k]))))
                        ecriture.writerow(l)

我应该在“product_p.csv”文件中得到这个:
'A.18'; 'A.16'
'A.15'; 'A.18'
'A.35'; 'A.18' 

代码的作用是两次读取同一个文件,第一次完全读取并将其转换为列表,第二次逐行读取,并根据 position_min 和 position_max 的条件找到每个产品(第一个变量)所属的所有产品,然后通过保留 count_pos 最小的产品来选择其中之一。
我在原始数据的样本上尝试了这个代码,它可以工作,但是在170万行的情况下,运行数小时仍然没有结果。有没有一种方法可以在没有或者少循环的情况下进行操作?是否有人可以帮助使用基本的Python库进行优化?
提前感谢您。

@Kshitiz 谢谢你的回答。这段代码的作用是两次读取同一个文件,第一次完全读取并将其转换为列表,第二次逐行读取,以查找每个产品(第一个变量)所属的所有产品,条件是在position_min和position_max上,并选择仅保留count_pos最小的产品。 - Stella
你实际上得到的是相反的,因为你想要得到的是 A16 A35,但你实际上得到的是 A35 A16,这样可以吗? - imxitiz
@Kshitiz,我纠正了编译代码时应该得到的内容。 - Stella
现在我不完全明白你想做什么,但我曾尝试使用pandas来完成,但是pandas的速度比你的代码慢。我已经测试了2000个数据集,但是你的代码比我的快。如果我确切地理解了你要做什么,我可能会尝试其他方法!我注意到你的代码中,你不必为完全相同的数据读取文件两次,你可以在第二次使用之前的数据。我没有检查这是否使你的代码更快,但我注意到了这一点。 - imxitiz
我不能使用pandas,只能使用基本的Python库。我尝试过一次读取文件,但在循环中逐行读取时却无法工作。我不知道是否有一种方法可以在没有或更少循环的情况下完成代码所做的事情,你有什么想法吗? - Stella
显示剩余6条评论
3个回答

3

我认为在这里需要采用不同的方法,因为将每个产品与其他产品进行比较总会产生O(n^2)的时间复杂度。

我按照升序排列了产品列表中的position_min(并且降序排列了position_max,以防万一),并将上面答案中的检查反转:不再看comp是否“包含”ref,而是相反地进行。这样,就可以只针对那些position_min更高的产品进行检查,并且在找到一个compposition_min高于refposition_max时停止搜索。

为了测试这个解决方案,我生成了一个随机的100个产品列表,并分别运行了从上面答案中复制的一个函数和一个基于我的建议的函数。后者执行了大约1000次比较,而不是10000次,根据timeit的显示,尽管由于初始排序而产生的开销,但它快了4倍左右。

代码如下:

##reference function
def f1(basedata):
    outd={}
    for ref in basedata:
        for comp in basedata:
            if ref == comp:
                continue
            elif ref[1] >= comp[1] and ref[2] <= comp[2]:
                if not outd.get(ref[0], False) or comp[3] < outd[ref[0]][1]:
                    outd[ref[0]] = (comp[0], comp[3])
    return outd

##optimized(?) function
def f2(basedata):
    outd={}
    sorteddata = sorted(basedata, key=lambda x:(x[1],-x[2]))
    runs = 0
    for i,ref in enumerate(sorteddata):
        toohigh=False
        j=i
        while j < len(sorteddata)-1 and not toohigh:
            j+=1
            runs+=1
            comp=sorteddata[j]
            if comp[1] > ref[2]:
                toohigh=True
            elif comp[2] <= ref[2]:
                if not outd.get(comp[0], False) or ref[3] < outd[comp[0]][1]:
                    outd[comp[0]] = (ref[0], ref[3])
    print(runs)
    return outd

很好。预先对数据进行排序是个好主意。对于小规模的情况,可能不会有太大的影响,但由于 OP 谈到了数百万行,这真的很重要。有一件事让我印象深刻,也许只是为了完整起见:在从 csv 读取后,我们仍然需要进行 int 转换,对吗? - MatBBastos
是的,在实际应用中需要进行转换。我在测试中使用 randint 生成了这些值,所以不需要它。 - gimix

1

我删除了一些未使用的库,并尽可能简化了代码的行为。

代码中最重要的对象是列表input_data,它存储来自输入csv文件的数据,以及字典out_dict,它存储比较的输出。

简单来说,代码的作用是:

  1. product.csv(没有标题)读入列表input_data
  2. 遍历input_data,将每一行与其他行进行比较
    • 如果参考产品范围在比较产品范围,我们检查一个新条件:参考产品是否有out_dict
      • 如果是,则用新的比较产品替换它,如果它具有较低的count_pos
      • 如果没有,则无论如何添加比较产品
  3. out_dict中的信息写入输出文件product_p.csv,但仅适用于具有有效比较产品的产品

这就是它:

import csv

input_data = []
with open('product.csv', 'r') as csv_in:
    reader = csv.reader(csv_in, delimiter=';')
    next(reader)
    for row in reader:
        input_data.append(row)


out_dict = {}
for ref in input_data:
    for comp in input_data:
        if ref == comp:
            continue
        elif int(ref[1]) >= int(comp[1]) and int(ref[2]) <= int(comp[2]):
            if not out_dict.get(ref[0], False) or int(comp[3]) < out_dict[ref[0]][1]:
                out_dict[ref[0]] = (comp[0], int(comp[3]))
                # print(f"In '{ref[0]}': placed '{comp[0]}'")


with open('product_p.csv', "w") as csv_out:
    ecriture = csv.writer(csv_out, delimiter=';', quotechar='"', quoting=csv.QUOTE_ALL)
    for key, value in out_dict.items():
        if value[0]:
            ecriture.writerow([key, value[0]])

同时,我将一个print语句注释掉了,这可以使用仅有几行的示例文件向您展示脚本正在执行的操作。


注意: 我认为你期望的输出是错误的。要么是这样,要么是我在解释中漏掉了什么。如果是这种情况,请告诉我。所示代码已考虑到此问题。
从样例输入中:
product;position_min;position_max;count_pos
A.16;167804;167870;20
A.18;167804;167838;15
A.15;167896;167768;18
A.20;238359;238361;33
A.35;167835;167837;8

预期输出应为:
"A.18";"A.16"
"A.15";"A.35"
"A.35";"A.18"

由于在“A.15”中,“A.35”满足与“A.16”和“A.18”相同的条件并且具有较小的count_pos


刚刚更新了答案,因为我认为没有真正的必要使用defaultdict,只需使用简单的get即可满足键的需求。使用[]符号创建键时,不需要已经有默认值。此外,只需使用out_dict.get(ref[0])就足够了 - 因为它会返回None,这是布尔值False - 虽然这不是最佳实践。 - MatBBastos

0
使用sqlite3内存数据库,可以将搜索移动到B树索引中,这比建议的方法更加优化。以下方法比原始方法快30倍。对于生成的200万行文件,每个项目计算结果需要44小时(原始方法需要约1200小时)。
import csv
import sqlite3
import sys
import time

with sqlite3.connect(':memory:') as con:
    cursor = con.cursor()
    cursor.execute('CREATE TABLE products (id integer PRIMARY KEY, product text, position_min int, position_max int, count_pos int)')
    cursor.execute('CREATE INDEX idx_products_main ON products(position_max, position_min, count_pos)')

    with open('product.csv', 'r') as products_file:
        reader = csv.reader(products_file, delimiter=';')
        # Omit parsing first row in file
        next(reader)

        for row in reader:
            row_id = row[0][len('A.'):] if row[0].startswith('A.') else row[0];
            cursor.execute('INSERT INTO products VALUES (?, ?, ?, ?, ?)', [row_id] + row)

    con.commit()

    with open('product_p.csv', 'wb') as write_file:
        with open('product.csv', 'r') as products_file:
            reader = csv.reader(products_file, delimiter=';')
            # Omit parsing first row in file
            next(reader)

            for row in reader:
                row_product_id, row_position_min, row_position_max, count_pos = row
                result_row = cursor.execute(
                    'SELECT product, count_pos FROM products WHERE position_min <= ? AND position_max >= ? ORDER BY count_pos, id LIMIT 1',
                    (row_position_min, row_position_max)
                ).fetchone()

                if (result_row and result_row[0] == row_product_id):
                    result_row = cursor.execute(
                        'SELECT product, count_pos FROM products WHERE product != ? AND position_min <= ? AND position_max >= ? ORDER BY count_pos, id LIMIT 1',
                        (row_product_id, row_position_min, row_position_max)
                    ).fetchone()

                if (result_row):
                    write_file.write(f'{row_product_id};{result_row[0]};{result_row[1]}\n'.encode())

如果需要,可以使用线程进行进一步优化,并且可以通过使用10个线程将结果处理优化为4-5小时。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接