import json
import logging
import pickle
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional
class AgentMemoryPersistence:
    """Persistence manager for agent memories and session context snapshots.

    Memories and session metadata live in a SQLite database at ``db_path``.
    Full session snapshots are stored as JSON files named
    ``context_<session_id>.json`` in the same directory as the database.
    """

    def __init__(self, db_path: str = "./agent_memory.db"):
        """Open (creating if necessary) the database at *db_path*.

        Args:
            db_path: Path of the SQLite database file. Its parent directory is
                created if missing; session context files are written there too.
        """
        self.db_path = Path(db_path)
        # sqlite3 does not create missing directories; without this a nested
        # path fails with "unable to open database file".
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that commits on success and is always closed.

        ``with sqlite3.connect(...)`` on its own only manages the transaction;
        it does not close the connection, which leaks a handle per call.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    def _context_file(self, session_id: str) -> Path:
        """Return the path of the JSON snapshot file for *session_id*."""
        # NOTE(review): session_id is interpolated into a file name; ids from
        # untrusted sources must not contain path separators or "..".
        return self.db_path.parent / f"context_{session_id}.json"

    def _init_db(self) -> None:
        """Create the ``memories`` and ``sessions`` tables if they do not exist."""
        with self._connect() as conn:
            # Memory table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id TEXT PRIMARY KEY,
                    agent_id TEXT NOT NULL,
                    content TEXT NOT NULL,
                    memory_type TEXT NOT NULL,
                    importance REAL,
                    created_at TEXT NOT NULL,
                    last_updated TEXT NOT NULL,
                    access_count INTEGER DEFAULT 0,
                    data BLOB
                )
            """)
            # Session table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id TEXT PRIMARY KEY,
                    agent_id TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    ended_at TEXT,
                    context_size_tokens INTEGER,
                    final_status TEXT
                )
            """)
            # load_memories always filters by agent_id.
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_memories_agent_id "
                "ON memories (agent_id)"
            )

    def save_memory(
        self,
        agent_id: str,
        memory_id: str,
        content: str,
        memory_type: str,
        importance: float = 0.5,
        data: Optional[Any] = None,
    ) -> None:
        """Insert or update a single memory.

        Re-saving an existing ``memory_id`` overwrites its fields and refreshes
        ``last_updated``, but keeps the row's original ``created_at`` and
        ``access_count``.

        Args:
            agent_id: Owner of the memory.
            memory_id: Primary key of the memory row.
            content: Text content of the memory.
            memory_type: Free-form category used for filtering on load.
            importance: Score compared against ``min_importance`` on load.
            data: Optional extra payload, pickled into the ``data`` BLOB.
                ``None`` stores no payload; other falsy values (``0``, ``[]``,
                ``{}``, ``""``) are stored.
        """
        now = datetime.now().isoformat()
        # `is not None`: a truthiness test would silently drop falsy payloads.
        data_bytes = pickle.dumps(data) if data is not None else None
        with self._connect() as conn:
            # INSERT OR REPLACE deletes and re-inserts a conflicting row; the
            # sub-selects (evaluated before the replace) carry over created_at
            # and access_count from the existing row, if any.
            conn.execute(
                """
                INSERT OR REPLACE INTO memories
                (id, agent_id, content, memory_type, importance,
                 created_at, last_updated, access_count, data)
                VALUES (
                    ?, ?, ?, ?, ?,
                    COALESCE((SELECT created_at FROM memories WHERE id = ?), ?),
                    ?,
                    COALESCE((SELECT access_count FROM memories WHERE id = ?), 0),
                    ?
                )
                """,
                (
                    memory_id,
                    agent_id,
                    content,
                    memory_type,
                    importance,
                    memory_id,
                    now,
                    now,
                    memory_id,
                    data_bytes,
                ),
            )

    def load_memories(
        self,
        agent_id: str,
        memory_type: Optional[str] = None,
        min_importance: float = 0.0,
    ) -> List[Dict]:
        """Load an agent's memories whose importance is at least *min_importance*.

        Rows with a NULL ``importance`` never match the ``>=`` filter.

        Args:
            agent_id: Agent whose memories to load.
            memory_type: If given, only memories of exactly this type.
            min_importance: Inclusive lower bound on ``importance``.

        Returns:
            One dict per row with keys ``id``, ``agent_id``, ``content``,
            ``memory_type``, ``importance``, ``created_at``, ``last_updated``
            and ``access_count``. ``data`` (the unpickled payload) is present
            only when a payload was stored and could be unpickled.
        """
        # Explicit column list: positional unpacking below must not depend on
        # the table's physical column order.
        query = (
            "SELECT id, agent_id, content, memory_type, importance, "
            "created_at, last_updated, access_count, data "
            "FROM memories WHERE agent_id = ? AND importance >= ?"
        )
        params: List[Any] = [agent_id, min_importance]
        if memory_type is not None:
            query += " AND memory_type = ?"
            params.append(memory_type)

        with self._connect() as conn:
            rows = conn.execute(query, params).fetchall()

        memories = []
        for (
            mem_id,
            mem_agent_id,
            content,
            mem_type,
            importance,
            created_at,
            last_updated,
            access_count,
            blob,
        ) in rows:
            memory = {
                'id': mem_id,
                'agent_id': mem_agent_id,
                'content': content,
                'memory_type': mem_type,
                'importance': importance,
                'created_at': created_at,
                'last_updated': last_updated,
                'access_count': access_count,
            }
            if blob is not None:
                try:
                    # SECURITY: pickle.loads can execute arbitrary code; this
                    # is only safe while the database file itself is trusted.
                    memory['data'] = pickle.loads(blob)
                except Exception:
                    # Best-effort: keep the memory, omit the unreadable payload.
                    logging.getLogger(__name__).warning(
                        "Could not unpickle data for memory %r; payload skipped",
                        mem_id,
                        exc_info=True,
                    )
            memories.append(memory)
        return memories

    def save_session_context(
        self,
        agent_id: str,
        session_id: str,
        context_snapshot: Dict,
    ) -> None:
        """Save a snapshot of a whole session's context.

        Used to interrupt and later resume long-running tasks. The snapshot is
        written as JSON to a file next to the database (avoiding large BLOBs).
        The ``sessions`` row is upserted with status ``'in_progress'``. Saving
        again for the same *session_id* overwrites the snapshot and keeps the
        session's original ``created_at``.

        Args:
            agent_id: Agent owning the session.
            session_id: Session identifier; also used in the snapshot file name.
            context_snapshot: JSON-serializable mapping; values that are not
                JSON-serializable are stringified via ``str``.
        """
        snapshot_json = json.dumps(context_snapshot, default=str)

        # Write the file first, atomically (temp file + rename). An
        # interruption then never leaves a truncated snapshot, nor a session
        # row that points at a missing file.
        context_file = self._context_file(session_id)
        tmp_file = context_file.with_name(context_file.name + ".tmp")
        tmp_file.write_text(snapshot_json, encoding="utf-8")
        tmp_file.replace(context_file)

        now = datetime.now().isoformat()
        with self._connect() as conn:
            # Upsert instead of plain INSERT so repeated saves for one session
            # do not fail on the session_id primary key.
            conn.execute(
                """
                INSERT OR REPLACE INTO sessions
                (session_id, agent_id, created_at, context_size_tokens, final_status)
                VALUES (
                    ?, ?,
                    COALESCE((SELECT created_at FROM sessions WHERE session_id = ?), ?),
                    ?, ?
                )
                """,
                (
                    session_id,
                    agent_id,
                    session_id,
                    now,
                    len(snapshot_json) // 4,  # rough token estimate (~4 chars/token)
                    'in_progress',
                ),
            )

    def load_session_context(self, session_id: str) -> Optional[Dict]:
        """Load the saved context snapshot for *session_id*.

        Returns:
            The decoded JSON snapshot, or ``None`` if no snapshot file exists.
        """
        try:
            text = self._context_file(session_id).read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        return json.loads(text)

    def cleanup_old_sessions(self, days_to_keep: int = 30) -> int:
        """Delete completed sessions created more than *days_to_keep* days ago.

        Both the ``sessions`` row and the session's snapshot file are removed.
        ``created_at`` values are compared as ISO-8601 strings, which order
        correctly because all are written by ``datetime.now().isoformat()``.

        Returns:
            The number of session rows deleted.
        """
        cutoff_date = (datetime.now() - timedelta(days=days_to_keep)).isoformat()
        deleted_ids = []
        with self._connect() as conn:
            candidates = [
                row[0]
                for row in conn.execute(
                    "SELECT session_id FROM sessions "
                    "WHERE created_at < ? AND final_status = 'completed'",
                    (cutoff_date,),
                )
            ]
            for session_id in candidates:
                if conn.execute(
                    "DELETE FROM sessions WHERE session_id = ?", (session_id,)
                ).rowcount:
                    deleted_ids.append(session_id)

        # Remove snapshot files only after the deletions have been committed.
        for session_id in deleted_ids:
            try:
                self._context_file(session_id).unlink()
            except FileNotFoundError:
                pass  # snapshot already gone; nothing to clean up
        return len(deleted_ids)