如何使用 SQLAlchemy 实现 UPSERT 操作?

摘要:前言 SQLite 和 PostgreSQL 都支持 UPSERT 操作,即"有则更新,无则新增"。冲突列必须有唯一约束。 语法: PostgreSQL: INSERT ... ON CONFLICT ……
前言 SQLite 和 PostgreSQL 都支持 UPSERT 操作,即"有则更新,无则新增"。冲突列必须有唯一约束(或唯一索引)。 语法: PostgreSQL: INSERT ... ON CONFLICT (column) DO UPDATE/NOTHING SQLite: INSERT ... ON CONFLICT(column) DO UPDATE/NOTHING。两者写法相同,ON CONFLICT 与括号之间的空格可有可无。 场景 PostgreSQL SQLite 说明 基本 UPSERT ON CONFLICT (col) DO UPDATE SET ... ON CONFLICT(col) DO UPDATE SET ... 相同(空格可有可无) 冲突忽略 ON CONFLICT (col) DO NOTHING ON CONFLICT(col) DO NOTHING 相同 引用新值 EXCLUDED.col excluded.col 大小写不敏感,两者通用 返回结果 RETURNING * RETURNING * 相同(SQLite 需 3.35+) 条件更新 DO UPDATE SET ... WHERE condition DO UPDATE SET ... WHERE condition 相同(SQLite 3.24+ 支持) 注意事项 冲突列必须有唯一约束或唯一索引,否则数据库会直接报错。 PostgreSQL 和 SQLite 的 UPSERT 语法基本一致,差异主要在版本要求和内置函数上:例如 NOW() 只有 PostgreSQL 提供,而 CURRENT_TIMESTAMP 两者都支持。使用原生 SQL 时需要注意。 SQLite 3.24+ 才支持 UPSERT,并且同样支持 DO UPDATE SET ... WHERE 条件更新,无需改用 CASE 表达式或应用层过滤。 SQLite 3.35+ 版本才支持 RETURNING。 EXCLUDED 和 RETURNING EXCLUDED EXCLUDED 表示冲突时被拦截的新值。 INSERT INTO users (email, name, age) VALUES ('test@example.com', '新名字', 30) ON CONFLICT (email) DO UPDATE SET name = EXCLUDED.name, -- ← 引用新值 "新名字" age = EXCLUDED.age -- ← 引用新值 30 场景 表达式 含义 示例值 原表字段 users.name 冲突行的当前值 "老名字" 新值字段 EXCLUDED.name 试图插入的新值 "新名字" 混合计算 users.age + EXCLUDED.age 原值 + 新值 25 + 30 = 55 示例 1:累加库存 -- 商品库存累加:原库存 100 + 新增 50 = 150 INSERT INTO products (sku, stock) VALUES ('IPHONE15', 50) ON CONFLICT (sku) DO UPDATE SET stock = products.stock + EXCLUDED.stock -- 100 + 50 RETURNING stock; 示例 2:仅更新非空字段 -- 如果新值为 NULL,保留原值 INSERT INTO users (email, name, age) VALUES ('test@example.com', '新名字', NULL) ON CONFLICT (email) DO UPDATE SET name = COALESCE(EXCLUDED.name, users.name), -- 新名字 age = COALESCE(EXCLUDED.age, users.age) -- 保留原 age 示例 3:时间戳更新 -- 更新时刷新 updated_at INSERT INTO users (email, name) VALUES ('test@example.com', '新名字') ON CONFLICT (email) DO UPDATE SET name = EXCLUDED.name, updated_at = NOW() -- PostgreSQL -- updated_at = CURRENT_TIMESTAMP -- SQLite RETURNING RETURNING 用于返回操作结果。在 INSERT/UPDATE/DELETE 后直接返回指定列,避免额外 SELECT 查询: INSERT INTO users (email, name) VALUES ('test@example.com', '张三') RETURNING id, email, name, created_at; 示例 1:插入后立即获取 ID # PostgreSQL / SQLite 3.35+ sql = text(""" INSERT INTO users (email, name) VALUES (:email, :name) RETURNING id, email,
created_at """) result = await session.execute(sql, {"email": "test@example.com", "name": "张三"}) user = result.mappings().first() print(user["id"]) # 直接获取 ID 示例 2:UPSERT 后统一返回 -- 无论插入还是更新,都返回最终状态 INSERT INTO users (email, name, login_count) VALUES ('test@example.com', '张三', 1) ON CONFLICT (email) DO UPDATE SET name = EXCLUDED.name, login_count = users.login_count + 1 -- 累加登录次数 RETURNING id, email, name, login_count, CASE WHEN xmax = 0 THEN 'inserted' -- PostgreSQL 特有:xmax=0 表示插入 ELSE 'updated' END AS action 示例 3:批量操作返回所有结果 -- PostgreSQL 支持批量 RETURNING INSERT INTO users (email, name) VALUES ('a@example.com', 'A'), ('b@example.com', 'B') ON CONFLICT (email) DO UPDATE SET name = EXCLUDED.name RETURNING id, email, name; Python 处理批量返回: result = await session.execute(sql) users = [dict(row) for row in result.mappings().all()] # [{'id': 1, 'email': 'a@example.com', 'name': 'A'}, ...] 示例:用户登录计数器 async def record_user_login(session: AsyncSession, email: str, name: str) -> dict: """ 用户登录计数器: - 新用户:插入,login_count = 1 - 老用户:更新,login_count += 1 - 返回最终状态 + 操作类型 """ sql = text(""" INSERT INTO users ( email, name, login_count, last_login, created_at ) VALUES ( :email, :name, 1, :now, :now ) ON CONFLICT (email) DO UPDATE SET name = EXCLUDED.name, -- 更新用户名 login_count = users.login_count + 1, -- 累加登录次数 last_login = EXCLUDED.last_login -- 更新最后登录时间 RETURNING id, email, name, login_count, last_login, created_at, CASE WHEN xmax = 0 THEN 'inserted' ELSE 'updated' END AS action -- PostgreSQL 特有:区分插入/更新 """) now = datetime.utcnow() result = await session.execute( sql, {"email": email, "name": name, "now": now} ) row = result.mappings().first() return dict(row) if row else None # 使用示例 user = await record_user_login(session, "test@example.com", "张三") print(f"{user['action']} user {user['email']} with {user['login_count']} logins") # 输出: inserted user test@example.com with 1 logins # 或: updated user test@example.com with 5 logins 示例数据模型类 from sqlalchemy import Column, Integer, String, 
UniqueConstraint from sqlalchemy.orm import DeclarativeBase class Base(DeclarativeBase): pass class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True, autoincrement=True) email = Column(String(100), nullable=False) # 唯一约束由 __table_args__ 中的 uq_users_email 定义(不要再同时写 unique=True,否则会重复创建约束) name = Column(String(50)) age = Column(Integer) balance = Column(Integer, default=0) __table_args__ = ( UniqueConstraint("email", name="uq_users_email"), ) class Product(Base): __tablename__ = "products" id = Column(Integer, primary_key=True) sku = Column(String(50), unique=True, nullable=False) # 唯一 SKU name = Column(String(100)) stock = Column(Integer, default=0) price = Column(Integer) 注:后文示例还用到 created_at、updated_at、last_login、login_count 等列以及 user_points、tags 表,实际使用时需要在模型中补充。 ORM 方式 注意 insert 的导入路径:on_conflict_do_update() 和 on_conflict_do_nothing() 只存在于方言专用的 insert(sqlalchemy.dialects.postgresql.insert 或 sqlalchemy.dialects.sqlite.insert)上;通用的 sqlalchemy.insert 没有这两个方法,调用会抛出 AttributeError。以下示例中的 insert 均指方言专用版本。 基本示例 from sqlalchemy import func from sqlalchemy.dialects.postgresql import insert # PostgreSQL;使用 SQLite 时改为:from sqlalchemy.dialects.sqlite import insert async def upsert_user_orm(session: AsyncSession, user_data: dict) -> dict: """ UPSERT 用户(ORM 风格) 如果 email 冲突则更新,否则插入 """ # insert 必须是方言专用版本(见上方导入),通用 insert 不支持 ON CONFLICT stmt = ( insert(User) .values(**user_data) .on_conflict_do_update( index_elements=["email"], # 冲突检测列(唯一约束) set_={ "name": user_data["name"], "age": user_data.get("age"), "updated_at": func.now() # 假设有 updated_at 列 } ) .returning(User) # 返回插入/更新后的行 ) result = await session.execute(stmt) user = result.scalar_one() return { "id": user.id, "email": user.email, "name": user.name, "age": user.age } async def upsert_user_ignore(session: AsyncSession, user_data: dict) -> bool: """ UPSERT 但冲突时忽略(DO NOTHING) """ stmt = ( insert(User) .values(**user_data) .on_conflict_do_nothing( index_elements=["email"] ) ) result = await session.execute(stmt) return result.rowcount > 0 # 返回是否插入成功 条件更新:仅更新特定字段 async def upsert_user_conditional(session: AsyncSession, user_data: dict) -> dict: """ UPSERT:冲突时只更新调用方提供了的字段(未提供 age 时保留原值) """ stmt = ( insert(User) .values(**user_data) .on_conflict_do_update( index_elements=["email"], set_={ "name": user_data["name"], # 条件:只有提供了 age 才更新 "age": user_data.get("age", User.age), #
保持原值 }, # 可选:添加 WHERE 条件 where=User.email == user_data["email"] ) .returning(User) ) result = await session.execute(stmt) return result.mappings().first() 批量 UPSERT async def bulk_upsert_users(session: AsyncSession, users: list[dict]) -> int: """ 批量 UPSERT 用户 """ stmt = ( insert(User) .values(users) .on_conflict_do_update( index_elements=["email"], set_={ "name": insert(User).excluded.name, # 使用 excluded 表示新值 "age": insert(User).excluded.age, } ) ) result = await session.execute(stmt) return result.rowcount 使用 EXCLUDED 引用新值 async def upsert_product_with_stock(session: AsyncSession, product_data: dict) -> dict: """ UPSERT 产品:冲突时累加库存 """ stmt = ( insert(Product) .values(**product_data) .on_conflict_do_update( index_elements=["sku"], set_={ # 累加库存:原库存 + 新库存 "stock": Product.stock + insert(Product).excluded.stock, # 更新其他字段 "name": insert(Product).excluded.name, "price": insert(Product).excluded.price, } ) .returning(Product) ) result = await session.execute(stmt) return result.mappings().first() 用户服务 class UserService: """用户服务(支持 UPSERT)""" def __init__(self, session: AsyncSession): self.session = session async def create_or_update(self, email: str, name: str, age: int | None = None) -> dict: """创建或更新用户""" stmt = ( insert(User) .values( email=email, name=name, age=age, created_at=datetime.utcnow() ) .on_conflict_do_update( index_elements=["email"], set_={ "name": name, "age": age, "updated_at": datetime.utcnow() } ) .returning(User) ) result = await self.session.execute(stmt) user = result.scalar_one() return { "id": user.id, "email": user.email, "name": user.name, "age": user.age } async def bulk_create_or_update(self, users: list[dict]) -> int: """批量创建或更新""" stmt = ( insert(User) .values(users) .on_conflict_do_update( index_elements=["email"], set_={ "name": insert(User).excluded.name, "age": insert(User).excluded.age, "updated_at": datetime.utcnow() } ) ) result = await self.session.execute(stmt) return result.rowcount async def create_if_not_exists(self, email: 
str, name: str) -> bool: """仅当不存在时创建""" stmt = ( insert(User) .values( email=email, name=name, created_at=datetime.utcnow() ) .on_conflict_do_nothing( index_elements=["email"] ) ) result = await self.session.execute(stmt) return result.rowcount > 0 # True = 插入成功,False = 已存在 原生 SQL 基本示例 PostgreSQL async def upsert_user_pg(session: AsyncSession, user_data: dict) -> dict | None: """ PostgreSQL 原生 UPSERT """ sql = text(""" INSERT INTO users (email, name, age, created_at) VALUES (:email, :name, :age, :created_at) ON CONFLICT (email) DO UPDATE -- 冲突列 SET name = EXCLUDED.name, -- EXCLUDED 表示新插入的值 age = EXCLUDED.age, updated_at = NOW() RETURNING id, email, name, age """) result = await session.execute( sql, { "email": user_data["email"], "name": user_data["name"], "age": user_data.get("age"), "created_at": datetime.utcnow() } ) row = result.mappings().first() return dict(row) if row else None SQLite async def upsert_user_sqlite(session: AsyncSession, user_data: dict) -> dict | None: """ SQLite 原生 UPSERT(语法与 PostgreSQL 几乎相同) """ sql = text(""" INSERT INTO users (email, name, age, created_at) VALUES (:email, :name, :age, :created_at) ON CONFLICT(email) DO UPDATE SET -- SQLite 语法稍有不同 name = excluded.name, age = excluded.age, updated_at = CURRENT_TIMESTAMP RETURNING id, email, name, age """) result = await session.execute( sql, { "email": user_data["email"], "name": user_data["name"], "age": user_data.get("age"), "created_at": datetime.utcnow() } ) row = result.mappings().first() return dict(row) if row else None 冲突时忽略 async def insert_or_ignore_user(session: AsyncSession, user_data: dict) -> bool: """ 插入用户,如果冲突则忽略 """ # PostgreSQL sql = text(""" INSERT INTO users (email, name, age, created_at) VALUES (:email, :name, :age, :created_at) ON CONFLICT (email) DO NOTHING """) # SQLite(语法相同) # sql = text(""" # INSERT INTO users (email, name, age, created_at) # VALUES (:email, :name, :age, :created_at) # ON CONFLICT(email) DO NOTHING # """) result = await session.execute( sql, { 
"email": user_data["email"], "name": user_data["name"], "age": user_data.get("age"), "created_at": datetime.utcnow() } ) return result.rowcount > 0 # 返回是否插入成功 批量 UPSERT async def bulk_upsert_products(session: AsyncSession, products: list[dict]) -> int: """ 批量 UPSERT 产品(原生 SQL) """ # PostgreSQL sql = text(""" INSERT INTO products (sku, name, stock, price, created_at) VALUES ( :sku, :name, :stock, :price, :created_at ) ON CONFLICT (sku) DO UPDATE SET name = EXCLUDED.name, stock = products.stock + EXCLUDED.stock, -- 累加库存 price = EXCLUDED.price, updated_at = NOW() """) # 批量执行 for product in products: await session.execute( sql, { "sku": product["sku"], "name": product["name"], "stock": product.get("stock", 0), "price": product.get("price", 0), "created_at": datetime.utcnow() } ) return len(products) 部分更新 + 条件判断 async def upsert_user_smart(session: AsyncSession, user_data: dict) -> dict | None: """ 智能 UPSERT: - 如果提供了 age,才更新 age - 如果提供了 name,才更新 name - 更新 updated_at """ sql = text(""" INSERT INTO users (email, name, age, created_at) VALUES (:email, :name, :age, :created_at) ON CONFLICT (email) DO UPDATE SET name = COALESCE(:name, users.name), -- 如果新值为 NULL,保持原值 age = COALESCE(:age, users.age), updated_at = NOW() RETURNING id, email, name, age, updated_at """) result = await session.execute( sql, { "email": user_data["email"], "name": user_data.get("name"), # 可能为 None "age": user_data.get("age"), # 可能为 None "created_at": datetime.utcnow() } ) row = result.mappings().first() return dict(row) if row else None 用户注册/登录:存在则更新最后登录时间 async def register_or_login(session: AsyncSession, email: str, name: str) -> dict: """ 用户注册或登录: - 新用户:插入 - 老用户:更新最后登录时间 """ sql = text(""" INSERT INTO users (email, name, last_login, created_at) VALUES (:email, :name, :now, :now) ON CONFLICT (email) DO UPDATE SET last_login = EXCLUDED.last_login, name = EXCLUDED.name -- 可选:更新用户名 RETURNING id, email, name, last_login, created_at """) now = datetime.utcnow() result = await session.execute( sql, 
{"email": email, "name": name, "now": now} ) return dict(result.mappings().first()) 库存累加 async def add_product_stock(session: AsyncSession, sku: str, quantity: int) -> bool: """ 增加商品库存: - 商品不存在:插入 - 商品存在:累加库存 """ sql = text(""" INSERT INTO products (sku, stock, created_at) VALUES (:sku, :quantity, :now) ON CONFLICT (sku) DO UPDATE SET stock = products.stock + EXCLUDED.stock, updated_at = NOW() """) result = await session.execute( sql, { "sku": sku, "quantity": quantity, "now": datetime.utcnow() } ) return result.rowcount > 0 用户积分累加 async def add_user_points(session: AsyncSession, user_id: int, points: int) -> dict | None: """ 增加用户积分(累加) """ sql = text(""" INSERT INTO user_points (user_id, points, created_at) VALUES (:user_id, :points, :now) ON CONFLICT (user_id) DO UPDATE SET points = user_points.points + EXCLUDED.points, updated_at = NOW() RETURNING user_id, points """) result = await session.execute( sql, { "user_id": user_id, "points": points, "now": datetime.utcnow() } ) row = result.mappings().first() return dict(row) if row else None 标签计数 存在则 +1,不存在则创建: async def increment_tag_count(session: AsyncSession, tag_name: str) -> int: """ 标签计数: - 标签不存在:插入 count=1 - 标签存在:count += 1 """ sql = text(""" INSERT INTO tags (name, count, created_at) VALUES (:name, 1, :now) ON CONFLICT (name) DO UPDATE SET count = tags.count + 1, updated_at = NOW() RETURNING count """) result = await session.execute( sql, {"name": tag_name, "now": datetime.utcnow()} ) return result.scalar() or 0