Postgres Listen/Notify 如何构建轻量级发布订阅系统?

摘要:概述 原先设计一个内部系统的消息模块和缓存模块时,只有一个Postgres依赖。想着没多大用户量,没必要额外安装Redis,徒增运维工夫。缓存好解决,配个UNLOGGED表即可。吭吭哧哧琢磨怎么用数据表实现消息的时候,发现PostgreSQL……(摘要截断,全文见下)
概述 原先设计一个内部系统的消息模块和缓存模块时,只有一个Postgres依赖。想着没多大用户量,没必要额外安装Redis,徒增运维工夫。缓存好解决,配个UNLOGGED表即可。吭吭哧哧琢磨怎么用数据表实现消息的时候,发现PostgreSQL 提供了内置命令 LISTEN 和 NOTIFY,用于在数据库服务器和连接的客户端之间实现异步通信。这个 PostgreSQL 特有的扩展功能使得数据库可以作为一个轻量级的消息队列(MQ)系统使用,允许应用程序从数据库中生成事件,并由其他客户端实时响应。于是一拍即合,上手体验一下。 核心特性 轻量级实现:无需额外的消息中间件,直接利用 PostgreSQL 内置功能 异步通信:支持发布-订阅模式,实现解耦的组件通信 内存高效:通道(Channel)是纯内存对象,不占用磁盘空间 零配置:无需预先创建或管理通道,随用随建 适用场景 实时仪表盘:数据变更时实时推送更新 缓存失效:数据更新时通知缓存层刷新 数据审计:跟踪重要数据变更事件 任务调度:构建简单的分布式任务队列 事件驱动架构:实现微服务间的事件通信 通道(Channel)机制 重要特性: Channel 是纯内存对象,随 LISTEN 命令隐式创建 当所有监听会话断开或执行 UNLISTEN 时自动回收 无需手动创建或删除通道,也不支持此操作 消息传递模型 PostgreSQL 的 NOTIFY 采用典型的 "无监听即丢弃" 机制: 没有监听者时,消息不会入队 不占用磁盘空间 不消耗持久化内存 消息仅在存在活跃监听者时传递 基础使用 psql 命令行示例 -- 监听指定通道 LISTEN task_channel; -- 向通道发送消息 NOTIFY task_channel, '123456'; -- 取消监听所有通道 UNLISTEN *; -- 查看当前监听的通道 SELECT pg_listening_channels(); -- 查看系统通知状态 SELECT * FROM pg_stat_activity WHERE backend_type = 'client backend'; 动态消息生成 标准的 NOTIFY 命令要求消息内容必须明确指定,不支持动态字符串拼接。但可以使用 pg_notify() 函数来生成动态通知: -- 使用 pg_notify 函数支持动态消息 SELECT pg_notify('my_channel', 'Hello, ' || 'World!'); -- 带参数的动态消息 SELECT pg_notify('audit_channel', 'User ' || current_user || ' logged in at ' || now()::text); Python 实现示例 项目结构 ├── main.py # 核心实现:TaskWorker 和 TaskProducer ├── conf/ │ └── config.toml # 配置文件 └── pkg/ └── config/ # 配置管理模块 安装依赖 uv add "psycopg[binary,pool]>=3.3.3" 配置管理 首先,通过配置文件管理数据库连接和通道设置: # conf/config.toml [database.postgres] host = "127.0.0.1" port = 5432 user = "username" password = "password" dbname = "database_name" pool_min_size = 2 pool_max_size = 10 channel = "task_channel" # 默认通道名称 配置模块的代码示例:pkg/config/config.py import tomllib from typing import Any, Dict class BaseConfig: def __init__(self, cfg_file: str): self._cfg_file = cfg_file self._data: Dict[str, Any] = {} self._load_config() def _load_config(self) -> None: if self._data: return try: with open(self._cfg_file, "rb") as f: self._data = tomllib.load(f) except FileNotFoundError: raise RuntimeError(f"配置文件不存在: {self._cfg_file}") except Exception as 
e: raise RuntimeError(f"加载配置文件失败: {e}") class PostgresConfigMixin(BaseConfig): def postgres_host(self) -> str: return self._data.get("database", {}).get("postgres", {}).get("host", "") def postgres_port(self) -> int: return self._data.get("database", {}).get("postgres", {}).get("port", 5432) def postgres_user(self) -> str: return self._data.get("database", {}).get("postgres", {}).get("user", "") def postgres_password(self) -> str: return self._data.get("database", {}).get("postgres", {}).get("password", "") def postgres_dbname(self) -> str: return self._data.get("database", {}).get("postgres", {}).get("dbname", "") def postgres_pool_min_size(self) -> int: return ( self._data.get("database", {}).get("postgres", {}).get("pool_min_size", 2) ) def postgres_pool_max_size(self) -> int: return ( self._data.get("database", {}).get("postgres", {}).get("pool_max_size", 10) ) def postgres_channel(self) -> str: return ( self._data.get("database", {}) .get("postgres", {}) .get("channel", "default_channel") ) def get_postgres_dsn(self, hide_password: bool = False) -> str: """获取PostgreSQL连接DSN""" password = self.postgres_password() if hide_password and password: password = "***" # psycopg3 使用标准的 PostgreSQL 连接字符串格式 return ( f"postgresql://{self.postgres_user()}:{password}@" f"{self.postgres_host()}:{self.postgres_port()}/{self.postgres_dbname()}" ) class LLMConfigMixin(BaseConfig): """LLM配置Mixin""" def llm_model(self) -> str: return self._data.get("llm", {}).get("model", "") def llm_base_url(self) -> str: return self._data.get("llm", {}).get("base_url", "") def llm_api_key(self) -> str: return self._data.get("llm", {}).get("api_key", "") class Config(PostgresConfigMixin, LLMConfigMixin): def __init__(self, cfg_file: str = "conf/config.toml"): super().__init__(cfg_file) def reload(self) -> None: """重新加载配置""" self._data = {} # 清空数据 self._load_config() 核心组件实现 1. 
任务消费者(TaskWorker) TaskWorker 负责监听指定通道并处理接收到的任务: # main.py - TaskWorker 类核心部分 import asyncio import json import signal from typing import Set from psycopg import AsyncConnection, Notify, sql from psycopg_pool import AsyncConnectionPool from pkg.config import cfg MAX_CONCURRENCY = 10 class TaskWorker: def __init__(self, dsn: str, channel: str): self._dsn = dsn self.channel = channel self.pool: AsyncConnectionPool | None = None self.listener_conn: AsyncConnection | None = None self.sem = asyncio.Semaphore(MAX_CONCURRENCY) self.active_tasks: Set[asyncio.Task] = set() async def start(self) -> None: self.pool = AsyncConnectionPool( self._dsn, min_size=cfg.postgres_pool_min_size(), max_size=cfg.postgres_pool_max_size(), open=False, # 延迟打开, 避免阻塞 ) await self.pool.open() # 独立监听连接,防止 LISTEN 状态随连接回收丢失 self.listener_conn = await AsyncConnection.connect(self._dsn, autocommit=True) await self.listener_conn.execute( sql.SQL("LISTEN {}").format(sql.Identifier(self.channel)) ) print(f"Listening on channel: {self.channel}") try: async for notify in self.listener_conn.notifies(): if notify.channel == self.channel and notify.payload: await self._dispatch_task(notify) except asyncio.CancelledError: print("Listener cancelled") except Exception as e: print(f"Listener error: {e}") finally: await self.stop() async def _dispatch_task(self, notify: Notify) -> None: """接收通知并分发任务""" task_info = notify.payload.strip() try: task_data = json.loads(task_info) # 假设 payload 是 JSON 格式的字符串 print(f"Received task notification: {task_data}") except json.JSONDecodeError: task_data = {"task_id": task_info} # 如果不是 JSON 格式,使用原始字符串 print(f"Received non-JSON task notification: {task_data}") except Exception as e: print(f"Invalid task ID received: {task_info}") return if not isinstance(task_data, dict) or "task_id" not in task_data: print( f"Missing task_id in notification: {task_data}, or task_data is not a dict" ) return task = asyncio.create_task(self._process_task(task_data)) self.active_tasks.add(task) 
task.add_done_callback(self.active_tasks.discard) async def _process_task(self, task_data: dict) -> None: """执行业务逻辑""" if not self.pool: raise RuntimeError("Connection pool is not initialized") async with self.sem: async with self.pool.connection() as conn: try: await self._execute_business(task_data) print(f"<= Task {task_data['task_id']} completed successfully") except Exception as e: print(f"Error processing task {task_data['task_id']}: {e}") # 这里可以添加重试逻辑或错误记录 await self._log_failure(conn, task_data["task_id"], str(e)) async def _execute_business(self, task_data: dict) -> None: """执行业务逻辑""" print(f"<= Processing task {task_data}...") await asyncio.sleep(5) # 模拟耗时操作 print(f"<= Task {task_data} done.") async def _log_failure(self, conn: AsyncConnection, task_id: int, error_msg: str): """记录失败日志""" try: # 记录失败日志到独立表,便于后续重试或告警 print(f"Logging failure for task {task_id}: {error_msg}") except Exception as e: print(f"Failed to log error for task {task_id}: {e}") async def stop(self) -> None: print("Stopping TaskWorker...") if self.active_tasks: await asyncio.gather(*self.active_tasks, return_exceptions=True) if self.listener_conn: await self.listener_conn.close() if self.pool: await self.pool.close() print("TaskWorker stopped gracefully.") 2. 
任务发布者(TaskPublisher) TaskPublisher 负责向通道发布任务消息: # main.py - TaskPublisher 类 class TaskPublisher: def __init__(self, dsn: str): self._dsn = dsn self._pool: AsyncConnectionPool | None = None async def start(self): if not self._pool: self._pool = AsyncConnectionPool( self._dsn, min_size=cfg.postgres_pool_min_size(), max_size=cfg.postgres_pool_max_size(), open=False, # 延迟打开, 避免阻塞 ) await self._pool.open() print("TaskPublisher started.") async def publish(self, channel: str, payload: dict): if not self._pool: raise RuntimeError("Connection pool is not initialized") async with self._pool.connection() as conn: try: payload_str = json.dumps( payload, default=str ) # 将 dict 转换为 JSON 字符串 await conn.execute( sql.SQL("NOTIFY {}, {}").format( sql.Identifier(channel), sql.Literal(payload_str) ) ) print(f"=> Published task to channel {channel}: {payload_str}") return True except Exception as e: print(f"Failed to publish task to channel {channel}: {e}") return False async def publish_batch(self, channel: str, payloads: list[dict]): if not self._pool: raise RuntimeError("Connection pool is not initialized") count = 0 async with self._pool.connection() as conn: for payload in payloads: try: payload_str = json.dumps(payload, default=str) await conn.execute( sql.SQL("NOTIFY {}, {}").format( sql.Identifier(channel), sql.Literal(payload_str) ) ) print(f"=> Published task to channel {channel}: {payload_str}") count += 1 except Exception as e: print(f"Failed to publish batch tasks to channel {channel}: {e}") return count async def stop(self): if self._pool: await self._pool.close() print("TaskPublisher stopped.") async def __aenter__(self): await self.start() return self async def __aexit__(self, exc_type, exc_val, exc_tb): await self.stop() 调用演示 以下是main.py中的演示部分: # main.py async def run_worker(): worker = TaskWorker(cfg.get_postgres_dsn(), cfg.postgres_channel()) task = asyncio.create_task(worker.start()) return task async def run_publisher(): async with 
TaskPublisher(cfg.get_postgres_dsn()) as publisher: for i in range(1, 11): payload = {"task_id": i, "data": f"Task data {i}"} await publisher.publish(cfg.postgres_channel(), payload) await asyncio.sleep(0.5) # 模拟发布间隔 async def main(): # worker = TaskWorker(cfg.get_postgres_dsn(), cfg.postgres_channel()) # loop = asyncio.get_running_loop() # stop_evt = asyncio.Event() # for sig in (signal.SIGINT, signal.SIGTERM): # loop.add_signal_handler(sig, stop_evt.set) # listen_task = asyncio.create_task(worker.start()) # await stop_evt.wait() # print("Shutdown signal received, stopping worker...") # listen_task.cancel() # await listen_task worker_task = await run_worker() await asyncio.sleep(2) # 确保 worker 已经启动并监听 await run_publisher() # 发布任务 print("All tasks published successfully.") await asyncio.sleep(30) # 等待 worker 处理完任务 worker_task.cancel() # 停止 worker try: await worker_task except asyncio.CancelledError: pass if __name__ == "__main__": asyncio.run(main()) 使用场景扩展 场景一:实时数据同步 async def sync_data_change(self, table_name: str, record_id: str, operation: str): """数据变更时发送同步通知""" message = f"{table_name}:{record_id}:{operation}" await self.publish("data_sync_channel", message) 场景二:分布式锁通知 async def notify_lock_release(self, lock_name: str): """锁释放时通知等待者""" await self.publish("distributed_lock_channel", f"RELEASE:{lock_name}") 场景三:缓存失效广播 async def invalidate_cache(self, cache_key: str): """缓存失效时广播通知""" await self.publish("cache_invalidation_channel", cache_key) 注意事项 技术限制 消息大小:NOTIFY 消息最大为 8000 字节 无持久化:消息不持久化,重启后丢失 无确认机制:发送方无法知道消息是否被接收 无顺序保证:消息可能不按发送顺序到达 生产环境建议 监控告警:实现通道监听状态监控 错误处理:添加完善的错误处理和日志记录 备份机制:重要消息应有备份存储 性能测试:在高负载下测试系统表现 补充 asyncpg版 asyncpg是python连接postgres的纯异步驱动,性能更好 安装 uv add asyncpg 代码示例: import asyncio import signal from typing import Set import asyncpg from pkg.config import cfg MAX_CONCURRENT_TASKS = 5 # 最大并发任务数 class TaskWorker: def __init__(self, dsn: str, channel: str): self._dsn = dsn self._channel = channel self._pool: asyncpg.Pool | None = None 
self._listener_conn: asyncpg.Connection | None = None self._sem = asyncio.Semaphore(MAX_CONCURRENT_TASKS) self._active_tasks: Set[asyncio.Task] = set() async def start(self): """启动监听器""" if not self._pool: self._pool = await asyncpg.create_pool( dsn=self._dsn, min_size=cfg.postgres_pool_min_size(), max_size=cfg.postgres_pool_max_size(), ) # 创建专用连接用于监听 if not self._listener_conn: self._listener_conn = await asyncpg.connect(dsn=self._dsn) await self._listener_conn.add_listener(self._channel, self._on_notify) print(f"Listening on channel: {self._channel}") try: # 保持监听状态 await asyncio.Future() except asyncio.CancelledError: pass finally: await self.stop() async def _on_notify( self, conn: asyncpg.Connection, pid: int, channel: str, payload: str ): """收到 NOTIFY 时的回调函数""" if not payload: return task_id = payload.strip() print(f"Received notification: {task_id} on channel: {channel}") # 提交到事件循环,带并发限制 task = asyncio.create_task(self._handle_task(task_id)) self._active_tasks.add(task) task.add_done_callback(self._active_tasks.discard) async def _handle_task(self, task_id: str): """任务处理核心逻辑""" if not self._pool: raise RuntimeError("数据库连接池未初始化") async with self._sem: # 并发控制 async with self._pool.acquire() as conn: try: # 1. 执行业务逻辑 await self._execute_business_logic(task_id) # 2. 
更新任务状态(示例) print(f"Task {task_id} completed successfully.") except Exception as e: print(f"Task {task_id} failed: {e}") await self._log_failure(task_id, str(e)) async def _execute_business_logic(self, task_id: str): """模拟业务逻辑处理""" print(f"Processing task {task_id}...") await asyncio.sleep(5) # 模拟耗时操作 print(f"Task {task_id} completed.") async def _log_failure(self, task_id: str, error: str): """记录失败日志""" print(f"Task {task_id} failure logged: {error}") # 实际项目中可记录到数据库 async def stop(self): """优雅关闭""" print("Shutting down gracefully...") # 等待正在运行的任务完成 if self._active_tasks: await asyncio.gather(*self._active_tasks, return_exceptions=True) # 清理资源 if self._listener_conn: await self._listener_conn.remove_listener(self._channel, self._on_notify) await self._listener_conn.close() if self._pool: await self._pool.close() print("Shutdown complete.") async def main(): worker = TaskWorker(cfg.get_postgres_dsn(), cfg.postgres_channel()) loop = asyncio.get_running_loop() stop_evt = asyncio.Event() # 注册信号处理器,优雅关闭 for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, stop_evt.set) # 启动监听器 listen_task = asyncio.create_task(worker.start()) # 等待停止事件 await stop_evt.wait() print("Shutdown signal received, stopping worker...") # 取消监听任务并等待完成 listen_task.cancel() await listen_task if __name__ == "__main__": asyncio.run(main())