SQLAlchemy - 在postgresql中执行批量upsert操作(如果存在,则更新;否则插入)

69

我正在尝试使用SQLAlchemy模块(非SQL)编写Python中的批量upsert。

在执行SQLAlchemy添加操作时,我遇到了以下错误:

sqlalchemy.exc.IntegrityError: (IntegrityError) duplicate key value violates unique constraint "posts_pkey"
DETAIL:  Key (id)=(TEST1234) already exists.

我有一个名为posts的表,其中id列上有一个主键。

在这个例子中,我已经在数据库中有了一个id=TEST1234的行。当我尝试使用设置为TEST1234id添加一个新的 posts 对象时,就会出现上述错误。我原本认为,如果主键已经存在,记录将会被更新。

如何基于主键使用 Flask-SQLAlchemy 进行 upsert?有简单的解决方案吗?

如果没有,我可以始终检查并删除任何具有匹配 id 的记录,然后插入新记录,但对于我不希望进行许多更新的情况来说,那似乎代价太高了。


6
如果原问题没有提到SQLAlchemy,那么它怎么会是重复的呢? - techkuz
你能否考虑接受exhuma的答案?它利用了PosgreSQL的INSERT … ON CONFLICT DO UPDATE功能,效果非常好。 - GG.
6个回答

56

SQLAlchemy 中有一个类似 upsert 的操作:

db.session.merge()

我发现这个命令后,就能执行 upsert 操作,但值得一提的是,这个操作对于批量“upsert”来说速度比较慢。

另一个选择是获取要 upsert 的主键列表,并查询数据库中是否存在匹配的 id:

# Imagine that post1, post5, and post1000 are posts objects with ids 1, 5 and 1000 respectively
# The goal is to "upsert" these posts.
# we initialize a dict which maps id to the post object

my_new_posts = {1: post1, 5: post5, 1000: post1000} 

for each in posts.query.filter(posts.id.in_(my_new_posts.keys())).all():
    # Only merge those posts which already exist in the database
    db.session.merge(my_new_posts.pop(each.id))

# Only add those posts which did not exist in the database 
db.session.add_all(my_new_posts.values())

# Now we commit our modifications (merges) and inserts (adds) to the database!
db.session.commit()

13
合并操作不处理完整性错误。 - Manoj Sahu
6
以上过程非常缓慢,无法使用。 - Manoj Sahu
8
若在唯一索引上捕获到“重复键”错误,合并操作无法解决此问题,它仅适用于主键。 - Logovskii Dmitrii
13
合并没有诚信。 - deed02392

41

您可以利用on_conflict_do_update变体。一个简单的例子如下:

from sqlalchemy.dialects.postgresql import insert

class Post(Base):
    """
    A simple class for demonstration
    """

    id = Column(Integer, primary_key=True)
    title = Column(Unicode)

# Prepare all the values that should be "upserted" to the DB
values = [
    {"id": 1, "title": "mytitle 1"},
    {"id": 2, "title": "mytitle 2"},
    {"id": 3, "title": "mytitle 3"},
    {"id": 4, "title": "mytitle 4"},
]

stmt = insert(Post).values(values)
stmt = stmt.on_conflict_do_update(
    # Let's use the constraint name which was visible in the original posts error msg
    constraint="post_pkey",

    # The columns that should be updated on conflict
    set_={
        "title": stmt.excluded.title
    }
)
session.execute(stmt)

有关ON CONFLICT DO UPDATE的更多细节,请参见Postgres文档

有关on_conflict_do_update的更多细节,请参见SQLAlchemy文档

关于重复的列名的说明

上面的代码在values列表和set_参数中都使用列名作为字典键。如果在类定义中更改了列名,则必须在所有地方进行更改,否则会出错。可以通过访问列定义来避免这种情况,尽管代码会变得有些丑陋,但更加健壮:

coldefs = Post.__table__.c

values = [
    {coldefs.id.name: 1, coldefs.title.name: "mytitlte 1"},
    ...
]

stmt = stmt.on_conflict_do_update(
    ...
    set_={
        coldefs.title.name: stmt.excluded.title
        ...
    }
)

我的 constraint="post_pkey" 代码失败了,因为 sqlalchemy 找不到我在原始 sql 中创建的唯一约束 CREATE UNIQUE INDEX post_pkey...,然后使用 metadata.reflect(eng, only="my_table") 加载到 sqlalchemy 中后,收到了一个警告 base.py:3515: SAWarning: Skipped unsupported reflection of expression-based index post_pkey 。有什么建议如何修复吗? - user1071182
@user1071182 我认为将此作为单独的问题发布会更好。这样可以让您添加更多细节。如果没有看到完整的“CREATE INDEX”语句,很难猜测出错了什么。我不能保证任何东西,因为我还没有使用SQLAlchemy进行部分索引的工作。但也许其他人可能有解决方案。 - exhuma
@exhuma @GG。感谢您提供的解决方案,但我在使用时遇到了问题。当我运行它时,会出现错误提示:“当前数据库版本设置下的'default'方言不支持原地多行插入。”。 当我upsert单个值时,它可以正常工作,但在多个值上会出现此错误。有什么办法可以解决吗? - Nikhil Arora
@NikhilArora 这个错误提示表明你的元数据设置与数据库版本不兼容。请确保你正在使用最新版本的SQLAlchemy,并且确保你正在使用PostgreSQL作为后端。如果两者都是正确的,请检查你正在使用的PostgreSQL驱动程序。以上内容已在“psycopg2”上进行了测试。 - exhuma

6

使用编译扩展的另一种方法 (https://docs.sqlalchemy.org/en/13/core/compiler.html):

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert

@compiles(Insert)
def compile_upsert(insert_stmt, compiler, **kwargs):
    """
    converts every SQL insert to an upsert  i.e;
    INSERT INTO test (foo, bar) VALUES (1, 'a')
    becomes:
    INSERT INTO test (foo, bar) VALUES (1, 'a') ON CONFLICT(foo) DO UPDATE SET (bar = EXCLUDED.bar)
    (assuming foo is a primary key)
    :param insert_stmt: Original insert statement
    :param compiler: SQL Compiler
    :param kwargs: optional arguments
    :return: upsert statement
    """
    pk = insert_stmt.table.primary_key
    insert = compiler.visit_insert(insert_stmt, **kwargs)
    ondup = f'ON CONFLICT ({",".join(c.name for c in pk)}) DO UPDATE SET'
    updates = ', '.join(f"{c.name}=EXCLUDED.{c.name}" for c in insert_stmt.table.columns)
    upsert = ' '.join((insert, ondup, updates))
    return upsert

这将确保所有插入语句都像upserts一样工作。这个实现是在Postgres方言中的,但修改为MySQL方言应该相当容易。


1
在使用该片段时出现以下错误:sqlalchemy.exc.ProgrammingError: (psycopg2.errors.SyntaxError) syntax error at or near ")" LINE 1: ...on) VALUES ('US^WYOMING^ALBANY', '') ON CONFLICT () DO UPDAT... - Mark Coletti
啊,好发现!如果你的表中没有主键,这个操作就不会起作用。让我加一个修复方案。 - danielcahall
实际上,如果您没有主键,我不确定为什么您需要这个-您能详细说明一下问题吗? - danielcahall
2
所有插入转换为upserts是有风险的。有时候,您需要获得完整性错误以确保数据一致性并避免意外覆盖。只有在您120%了解此解决方案的所有影响时才应使用它! - exhuma
1
请注意,如果您正在使用Postgres,则最好使用内置的ON CONFLICT功能 - Ramon Dias

2

我开始研究这个问题,我认为使用bulk_insert_mappingsbulk_update_mappings的组合而不是merge可以在sqlalchemy中高效地执行upsert操作。

import time
import sqlite3

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from contextlib import contextmanager


engine = None
Session = sessionmaker()
Base = declarative_base()


def creat_new_database(db_name="sqlite:///bulk_upsert_sqlalchemy.db"):
    global engine
    engine = create_engine(db_name, echo=False)
    local_session = scoped_session(Session)
    local_session.remove()
    local_session.configure(bind=engine, autoflush=False, expire_on_commit=False)
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)


@contextmanager
def db_session():
    local_session = scoped_session(Session)
    session = local_session()

    session.expire_on_commit = False

    try:
        yield session
    except BaseException:
        session.rollback()
        raise
    finally:
        session.close()


class Customer(Base):
    __tablename__ = "customer"
    id = Column(Integer, primary_key=True)
    name = Column(String(255))


def bulk_upsert_mappings(customers):

    entries_to_update = []
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that needs to be updated and build mappings
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            customer = customers.pop(each.id)
            entries_to_update.append({"id": customer["id"], "name": customer["name"]})

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.bulk_update_mappings(Customer, entries_to_update)
        sess.commit()

    print(
        "Total time for upsert with MAPPING update "
        + str(len(customers))
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(len(entries_to_update))
    )


def bulk_upsert_merge(customers):

    entries_to_update = 0
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that needs to be updated and merge
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            values = customers.pop(each.id)
            sess.merge(Customer(id=values["id"], name=values["name"]))
            entries_to_update += 1

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.commit()

    print(
        "Total time for upsert with MERGE update "
        + str(len(customers))
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(entries_to_update)
    )


if __name__ == "__main__":

    batch_size = 10000

    # Only inserts
    customers_insert = {
        i: {"id": i, "name": "customer_" + str(i)} for i in range(batch_size)
    }

    # 50/50 inserts update
    customers_upsert = {
        i: {"id": i, "name": "customer_2_" + str(i)}
        for i in range(int(batch_size / 2), batch_size + int(batch_size / 2))
    }

    creat_new_database()
    bulk_upsert_mappings(customers_insert.copy())
    bulk_upsert_mappings(customers_upsert.copy())
    bulk_upsert_mappings(customers_insert.copy())

    creat_new_database()
    bulk_upsert_merge(customers_insert.copy())
    bulk_upsert_merge(customers_upsert.copy())
    bulk_upsert_merge(customers_insert.copy())

基准测试结果:

Total time for upsert with MAPPING: 0.17138004302978516 sec inserted : 10000 - updated : 0
Total time for upsert with MAPPING: 0.22074174880981445 sec inserted : 5000 - updated : 5000
Total time for upsert with MAPPING: 0.22307634353637695 sec inserted : 0 - updated : 10000
Total time for upsert with MERGE: 0.1724097728729248 sec inserted : 10000 - updated : 0
Total time for upsert with MERGE: 7.852903842926025 sec inserted : 5000 - updated : 5000
Total time for upsert with MERGE: 15.11970829963684 sec inserted : 0 - updated : 10000

1
你的回答确实很有趣,但需要注意的是它存在一些缺点。正如文档所述,出于性能和安全原因,这些批量方法正在逐渐被移入遗留状态。请先查看警告部分。此外,如果您需要在具有关系的表上进行批量更新,则不值得使用该方法。请查看同一链接的参数部分中的return_defaults参数。 - Ramon Dias

1

我知道这有点晚了,但我在@Emil Wåreus给出的答案基础上构建了一个函数,可以用于任何模型(表)。

def upsert_data(self, entries, model, key):
    entries_to_update = []
    entries_to_insert = []
    
    # get all entries to be updated
    for each in session.query(model).filter(getattr(model, key).in_(entries.keys())).all():
        entry = entries.pop(str(getattr(each, key)))
        entries_to_update.append(entry)
        
    # get all entries to be inserted
    for entry in entries.values():
        entries_to_insert.append(entry)

    session.bulk_insert_mappings(model, entries_to_insert)
    session.bulk_update_mappings(model, entries_to_update)

    session.commit()

entries 应该是一个字典,以主键值作为键,值应该是映射(将值与数据库列进行映射)。

model 是您想要 upsert 的 ORM 模型。

key 是表的主键。

您甚至可以使用此函数从字符串中获取要插入的表的模型。

def get_table(self, table_name):
    for c in self.base._decl_class_registry.values():
        if hasattr(c, '__tablename__') and c.__tablename__ == table_name:
            return c

使用这个方法,你只需要将表名作为字符串传递给upsert_data函数即可。
def upsert_data(self, entries, table, key):
    model = get_table(table)
    entries_to_update = []
    entries_to_insert = []
    
    # get all entries to be updated
    for each in session.query(model).filter(getattr(model, key).in_(entries.keys())).all():
        entry = entries.pop(str(getattr(each, key)))
        entries_to_update.append(entry)
        
    # get all entries to be inserted
    for entry in entries.values():
        entries_to_insert.append(entry)

    session.bulk_insert_mappings(model, entries_to_insert)
    session.bulk_update_mappings(model, entries_to_update)

    session.commit()

0

这不是最安全的方法,但它非常简单和快速。我只是试图有选择性地覆盖表的一部分。我删除了已知会冲突的行,然后从pandas数据帧中附加了新行。您的pandas数据帧列名需要与您的SQL表列名匹配。

eng = create_engine('postgresql://...')
conn = eng.connect()

conn.execute("DELETE FROM my_table WHERE col = %s", val)
df.to_sql('my_table', con=eng, if_exists='append')

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接