286 lines
9.8 KiB
Python
286 lines
9.8 KiB
Python
"""缓存管理器:支持过期时间、内存保护、持久化"""
|
||
import os
|
||
import json
|
||
import hashlib
|
||
import logging
|
||
import threading
|
||
import tempfile
|
||
from pathlib import Path
|
||
from datetime import datetime, timedelta
|
||
from typing import Any, Optional, Dict
|
||
from functools import wraps
|
||
|
||
logger = logging.getLogger(__name__)


def _default_cache_dir() -> Path:
    """Return the root directory under which every file cache is stored.

    Uses the platform temp directory reported by ``tempfile.gettempdir()``;
    if that lookup raises, falls back to ``/tmp/homework_cache``.
    """
    try:
        return Path(tempfile.gettempdir()) / "homework_cache"
    except Exception:
        return Path("/tmp/homework_cache")


# Cache configuration.
# NOTE(review): this is a shared temp location, so other local users may be
# able to read or pre-populate it -- confirm that is acceptable for the data.
CACHE_DIR = _default_cache_dir()

CACHE_EXPIRE_DAYS = 30  # default age, in days, after which a cache entry expires
MAX_MEMORY_CACHE_SIZE = 1000  # default upper bound on entries held in memory
|
||
|
||
|
||
class CacheManager:
    """
    Cache manager: in-memory cache + file cache.

    - Memory cache: fast access, bounded to ``maxsize`` entries with
      least-recently-used (LRU) eviction.
    - File cache: one JSON file per entry under ``CACHE_DIR / cache_name``,
      persisted across restarts.
    - Both tiers honour ``expire_days``: an entry older than that is treated
      as missing (and its file is deleted when encountered).
    - Failure tolerant: file-system errors are logged and never propagate to
      callers; if the cache directory cannot be written the manager runs in
      memory-only mode.

    Values must be JSON-serializable to be persisted. A value read back from
    disk is its JSON round-trip (e.g. tuples come back as lists).
    """

    def __init__(self, cache_name: str, maxsize: int = MAX_MEMORY_CACHE_SIZE, expire_days: int = CACHE_EXPIRE_DAYS):
        """
        Initialise the cache manager.

        Args:
            cache_name: Name of this cache; also the sub-directory of CACHE_DIR
                holding its files, so different caches do not collide.
            maxsize: Maximum number of entries kept in the memory cache.
            expire_days: Age in days after which an entry is considered expired.
        """
        self.cache_name = cache_name
        self.maxsize = maxsize
        self.expire_days = expire_days

        # Memory cache: cache_key -> (cache_time, value). Python dicts keep
        # insertion order, so the first key is always the least recently used
        # entry; a hit is moved to the end by deleting and re-inserting it.
        # This gives O(1) touch and eviction without a separate order list.
        self._memory_cache: Dict[str, tuple] = {}
        self._lock = threading.Lock()  # guards _memory_cache
        self._file_cache_enabled = True  # cleared below if the directory is unusable

        try:
            self.cache_dir = CACHE_DIR / cache_name
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            # Probe write permission with a uniquely named, auto-deleted file so
            # processes initialising the same cache concurrently cannot race on
            # a shared probe file name.
            with tempfile.NamedTemporaryFile(dir=self.cache_dir, prefix=".test_write") as probe:
                probe.write(b"test")
                probe.flush()
            logger.info(f"CacheManager initialized: {cache_name}, dir={self.cache_dir}")
        except Exception as e:
            logger.warning(f"File cache disabled due to permission error: {e}")
            self._file_cache_enabled = False
            logger.info(f"CacheManager initialized (memory only): {cache_name}")

    def _get_cache_key(self, key: str) -> str:
        """Return the MD5 hex digest of ``key``, used as memory key and file name."""
        # MD5 only provides a fixed-length, filename-safe identifier here; it is
        # not used for any security purpose.
        return hashlib.md5(key.encode()).hexdigest()

    def _get_cache_file(self, cache_key: str) -> Path:
        """Return the path of the JSON file that stores ``cache_key``."""
        return self.cache_dir / f"{cache_key}.json"

    @staticmethod
    def _parse_cache_time(cache_time: Any) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp; return None if it is not a valid one."""
        try:
            return datetime.fromisoformat(cache_time)
        except (TypeError, ValueError):
            return None

    def _is_datetime_expired(self, cached_dt: datetime) -> bool:
        """Return True if an entry cached at ``cached_dt`` is past its lifetime."""
        try:
            return datetime.now() > cached_dt + timedelta(days=self.expire_days)
        except (TypeError, OverflowError):
            # TypeError: timezone-aware timestamp vs naive now();
            # OverflowError: expiry date out of datetime's range.
            return True

    def _is_expired(self, cache_time: str) -> bool:
        """Return True if the ISO timestamp ``cache_time`` is expired or unparsable."""
        cached_dt = self._parse_cache_time(cache_time)
        return cached_dt is None or self._is_datetime_expired(cached_dt)

    @staticmethod
    def _remove_file(path: Path) -> bool:
        """Best-effort delete of ``path``; return True only if this call removed it."""
        try:
            path.unlink()
        except OSError as e:
            logger.debug(f"Could not delete cache file {path}: {e}")
            return False
        return True

    def _evict_lru(self):
        """Evict least-recently-used entries until there is room for one more.

        Must be called with ``self._lock`` held.
        """
        while self._memory_cache and len(self._memory_cache) >= self.maxsize:
            oldest_key = next(iter(self._memory_cache))
            del self._memory_cache[oldest_key]
            logger.debug(f"LRU evicted: {oldest_key[:8]}...")

    def _store_in_memory(self, cache_key: str, cached_dt: datetime, value: Any) -> None:
        """Insert or refresh an entry as most recently used. Caller holds the lock."""
        # Remove first so that updating an existing key never evicts another one.
        self._memory_cache.pop(cache_key, None)
        self._evict_lru()
        self._memory_cache[cache_key] = (cached_dt, value)

    def get(self, key: str) -> Optional[Any]:
        """
        Fetch a cached value.

        Lookup order: memory cache > file cache > None. A file hit is promoted
        into the memory cache.

        Args:
            key: Cache key.

        Returns:
            The cached value, or None if it is missing, expired or unreadable
            (so a stored None value is indistinguishable from a miss).
        """
        cache_key = self._get_cache_key(key)

        # 1. Memory cache (thread-safe).
        with self._lock:
            entry = self._memory_cache.get(cache_key)
            if entry is not None:
                cached_dt, value = entry
                if not self._is_datetime_expired(cached_dt):
                    # Re-insert so the key becomes the most recently used.
                    del self._memory_cache[cache_key]
                    self._memory_cache[cache_key] = entry
                    logger.debug(f"Memory cache hit: {cache_key[:8]}...")
                    return value
                # Expired in memory: drop it and consult the file cache, which
                # may hold a newer copy written by another process.
                del self._memory_cache[cache_key]

        # 2. File cache (only when available).
        if not self._file_cache_enabled:
            return None

        cache_file = self._get_cache_file(cache_key)
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                cached_data = json.load(f)
            cache_time = cached_data.get("cache_time", "")
            value = cached_data["data"]
        except FileNotFoundError:
            return None
        except Exception as e:
            logger.warning(f"Failed to read cache file: {e}")
            # Remove the corrupt/unreadable cache file.
            self._remove_file(cache_file)
            return None

        cached_dt = self._parse_cache_time(cache_time)
        if cached_dt is None or self._is_datetime_expired(cached_dt):
            self._remove_file(cache_file)
            logger.debug(f"File cache expired: {cache_key[:8]}...")
            return None

        with self._lock:
            # Do not clobber an entry a concurrent set() stored meanwhile.
            if cache_key not in self._memory_cache:
                self._store_in_memory(cache_key, cached_dt, value)

        logger.debug(f"File cache hit: {cache_key[:8]}...")
        return value

    def set(self, key: str, value: Any):
        """
        Store a value in both the memory cache and the file cache.

        Args:
            key: Cache key.
            value: Value to cache; must be JSON-serializable to be persisted.
        """
        cache_key = self._get_cache_key(key)
        now = datetime.now()

        # 1. Memory cache.
        with self._lock:
            self._store_in_memory(cache_key, now, value)

        # 2. File cache (only when available).
        if not self._file_cache_enabled:
            return

        cache_file = self._get_cache_file(cache_key)
        cached_data = {
            "cache_time": now.isoformat(),
            "data": value,
        }
        tmp_path = None
        try:
            # Write to a temp file in the same directory, then atomically rename
            # it over the target, so readers (and clear_expired) never observe a
            # half-written file.
            fd, tmp_path = tempfile.mkstemp(dir=self.cache_dir, prefix=f".{cache_key}.", suffix=".tmp")
            with os.fdopen(fd, 'w', encoding='utf-8') as f:
                json.dump(cached_data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, cache_file)
            tmp_path = None  # renamed into place; nothing left to clean up
            logger.debug(f"Cache saved: {cache_key[:8]}...")
        except Exception as e:
            logger.warning(f"Failed to save cache file: {e}")
            # Drop any older copy on disk so it cannot later be served in place
            # of the value that was just stored in memory.
            self._remove_file(cache_file)
        finally:
            if tmp_path is not None:
                self._remove_file(Path(tmp_path))

    def clear_expired(self):
        """
        Remove expired entries.

        Expired memory entries are dropped, then every expired or unreadable
        cache file is deleted.

        Returns:
            Number of cache files deleted (0 when the file cache is disabled).
        """
        with self._lock:
            stale = [k for k, (dt, _) in self._memory_cache.items() if self._is_datetime_expired(dt)]
            for k in stale:
                del self._memory_cache[k]

        if not self._file_cache_enabled:
            return 0

        cleaned = 0
        try:
            for cache_file in self.cache_dir.glob("*.json"):
                try:
                    with open(cache_file, 'r', encoding='utf-8') as f:
                        cached_data = json.load(f)
                    expired = self._is_expired(cached_data.get("cache_time", ""))
                except FileNotFoundError:
                    continue  # removed concurrently; nothing to do
                except Exception:
                    expired = True  # corrupt/unreadable files are removed too

                if expired and self._remove_file(cache_file):
                    cleaned += 1

            if cleaned > 0:
                logger.info(f"Cleaned {cleaned} expired cache files")
        except Exception as e:
            logger.error(f"Failed to clear expired cache: {e}")

        return cleaned

    def get_stats(self) -> dict:
        """Return entry counts and configuration of this cache."""
        with self._lock:
            memory_size = len(self._memory_cache)

        # Count cache files (0 when the file cache is disabled or unreadable).
        file_size = 0
        if self._file_cache_enabled:
            try:
                file_size = sum(1 for _ in self.cache_dir.glob("*.json"))
            except OSError:
                pass

        return {
            "cache_name": self.cache_name,
            "memory_cache_size": memory_size,
            "memory_cache_maxsize": self.maxsize,
            "file_cache_size": file_size,
            "file_cache_enabled": self._file_cache_enabled,
            "expire_days": self.expire_days
        }
|
||
|
||
|
||
def cached(cache_manager: "CacheManager"):
    """
    Decorator that memoises a function's return value in ``cache_manager``.

    Usage:
        @cached(answer_doc_cache)
        def parse_answer_doc(url: str):
            # parsing logic
            return result

    The cache key is ``"<qualname>:<str(args)>:<str(kwargs)>"``. Keyword
    arguments are sorted by name first, so the call-site order of keywords
    does not matter. For module-level functions ``__qualname__`` equals
    ``__name__``; for methods and nested functions it also includes the
    enclosing scope, so same-named functions sharing one manager do not
    collide. Arguments should have a stable, value-based ``str()``; objects
    whose repr contains a memory address will never produce a hit across runs.

    ``None`` results are not cached, because the cache's ``get`` returns
    ``None`` to signal a miss.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Build a key that is independent of keyword-argument order.
            normalized_kwargs = dict(sorted(kwargs.items()))
            cache_key = f"{func.__qualname__}:{str(args)}:{str(normalized_kwargs)}"

            cached_result = cache_manager.get(cache_key)
            if cached_result is not None:
                return cached_result

            # Cache miss: compute the value.
            result = func(*args, **kwargs)

            if result is not None:
                cache_manager.set(cache_key, result)

            return result

        return wrapper

    return decorator
|
||
|
||
|
||
# Module-level cache instances shared by importers of this module.

# Cache for answer-document results.
answer_doc_cache = CacheManager(
    "answer_doc", maxsize=MAX_MEMORY_CACHE_SIZE, expire_days=CACHE_EXPIRE_DAYS
)

# Grading standards are few in number, so a smaller memory tier suffices.
grade_standards_cache = CacheManager(
    "grade_standards", maxsize=100, expire_days=CACHE_EXPIRE_DAYS
)
|