This commit is contained in:
zhangquan 2026-03-30 15:07:25 +08:00
parent 3eb42ade2c
commit efdb2c98ee
15 changed files with 2247 additions and 394 deletions

.env.example Normal file

@@ -0,0 +1,55 @@
# ============================================
# Middle School Physics Homework Grading Workflow - example environment variables
# ============================================
# Copy this file to .env and fill in real values:
# cp .env.example .env
# ============================================
# Required - LLM API
# ============================================
# LLM API key (obtain from Volcengine or OpenAI)
LLM_API_KEY=your-api-key-here
# LLM API base URL
# Volcengine: https://ark.cn-beijing.volces.com/api/v3
# OpenAI: https://api.openai.com/v1
LLM_BASE_URL=https://ark.cn-beijing.volces.com/api/v3
# Model name
# Volcengine recommended: doubao-seed-2-0-pro-260215
# OpenAI recommended: gpt-4o
LLM_MODEL_NAME=doubao-seed-2-0-pro-260215
# Note: no object storage (S3/TOS/OSS) configuration is required.
# Images are used via their original URLs and are never uploaded.
# ============================================
# Optional - Logging and cache
# ============================================
# Log level: DEBUG, INFO, WARNING, ERROR
LOG_LEVEL=INFO
# Cache directory (default: /tmp/cache)
CACHE_DIR=/tmp/cache
# Per-image processing timeout in seconds (default: 120)
SINGLE_IMAGE_TIMEOUT=120
# ============================================
# Optional - Concurrency control
# ============================================
# Maximum concurrency (default: 10)
MAX_CONCURRENT=10
# ============================================
# Working directory (set automatically by the system; do not modify)
# ============================================
# COZE_WORKSPACE_PATH=/workspace/projects

AGENTS.md

@@ -1,6 +1,53 @@
## Project Overview
- **Name**: Middle School Physics Homework Grading Workflow
- **Function**: Upload multiple homework images and a Word answer file; automatically recognize student answers, extract the standard answers, grade precisely, and return the grading result as JSON
- **Function**: Upload homework images for multiple students plus a Word answer file; automatically recognize student answers, extract the standard answers, grade precisely, and return a grading-result JSON per student
### Data Structures (breaking change)
**Input parameters**:
```json
{
"student_homework": [
{
"student_id": 0,
"student_name": "Zhang San",
"homework_images": [
"image URL 1",
"image URL 2"
]
},
{
"student_id": 1,
"student_name": "Li Si",
"homework_images": [
"image URL 3",
"image URL 4"
]
}
],
"answer_doc_url": "answer document URL (optional)",
"subject": "physics",
"comment_max_length": 100,
"max_concurrent": 10
}
```
**Output**:
```json
{
"student_results": [
{
"student_id": 0,
"student_name": "Zhang San",
"total_images": 2,
"image_results": [...],
"overall_comment": "Excellent! All 5 questions correct",
"total_score": 15,
"full_score": 15,
"grade": "A+"
}
]
}
```
### Node Inventory
| Node | File | Type | Description | Branching | Config |
@@ -29,6 +76,63 @@
- Node `doc_extract` uses the LLM skill
- Model: `doubao-seed-2-0-pro-260215` (flagship model, strong at complex reasoning)
- Parses Word documents with python-docx
- **Cache optimization**: parse results are cached via `utils/cache_manager.py` for 30 days
## Caching (optimized, v2026-03-28)
- **Cache manager**: `src/utils/cache_manager.py`
- **Two-tier architecture**:
- In-memory cache: LRU eviction, capacity 1000, fast access
- File cache: persistent storage, survives process restarts
- **TTL**: 30 days; expired entries are cleaned up automatically
- **Cached content**: AI-parsed structured data (a list of CorrectAnswer)
- **Cache key**: `{subject}:{answer_doc_url}` (MD5-hashed)
- **Subject isolation**: the same URL never collides across subjects
- Example: `physics:https://example.com/answer.docx` and `math:https://example.com/answer.docx` are distinct cache entries
- **Thread safety**: concurrent access is protected with a lock
- **Failure safety**: if the file cache fails, the manager degrades to memory-only mode
- **Statistics**: `get_stats()` returns cache statistics
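The two-tier design described above can be sketched as follows. This is a minimal illustration, not the project's `cache_manager.py`: the real manager also handles the 30-day TTL, expiry cleanup, and `get_stats()`, and the class and method names here are illustrative.

```python
import hashlib
import json
import threading
from collections import OrderedDict
from pathlib import Path

class TwoTierCache:
    """Memory tier (LRU, bounded) backed by a file tier, keyed by MD5 of '{subject}:{url}'."""

    def __init__(self, cache_dir="/tmp/cache", max_items=1000):
        self.dir = Path(cache_dir)
        self.dir.mkdir(parents=True, exist_ok=True)
        self.mem = OrderedDict()
        self.max_items = max_items
        self.lock = threading.Lock()  # thread safety for concurrent access

    def _key(self, subject, url):
        # Subject isolation: the same URL hashes differently per subject
        return hashlib.md5(f"{subject}:{url}".encode()).hexdigest()

    def get(self, subject, url):
        k = self._key(subject, url)
        with self.lock:
            if k in self.mem:
                self.mem.move_to_end(k)      # LRU: mark as recently used
                return self.mem[k]
        path = self.dir / f"{k}.json"        # fall back to the file tier
        if path.exists():
            return json.loads(path.read_text())
        return None

    def set(self, subject, url, value):
        k = self._key(subject, url)
        with self.lock:
            self.mem[k] = value
            self.mem.move_to_end(k)
            if len(self.mem) > self.max_items:
                self.mem.popitem(last=False)  # evict least recently used
        try:
            (self.dir / f"{k}.json").write_text(json.dumps(value))
        except OSError:
            pass                              # degrade to memory-only mode
```

Because the file tier persists under `cache_dir`, a process restart loses only the memory tier; the next `get` repopulates from disk.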
### Performance and timeout controls
- **Image download timeout**: 30 seconds per attempt, at most 60 seconds total
- **Retry policy**: failed image fetches are retried up to 2 times
- **Per-image processing timeout**: 120 seconds (including the LLM call)
- **Total task timeout**: 120 seconds × number of images
- **Graceful degradation**: a timed-out task returns an empty result without affecting the others
- **Concurrency safety**: `ThreadPoolExecutor` plus timeout guards
## Grade Standards Configuration (core rule)
- **Parameter name**: `grade_standards`
- **Core rule**: **the prerequisite for A+ and A is "all correct", regardless of score rate**
### Grade determination logic (simplified)
| Grade | Condition | Notes |
|------|------|------|
| A+ | All correct (error count = 0) | Every question correct; score rate is irrelevant |
| A | Reserved (all-correct currently returns A+) | All answers correct |
| B | Has errors, score rate ≥ 80% | A few errors |
| C | Has errors, score rate ≥ 70% | More errors |
| D | Has errors, score rate < 70% | Many errors |
### Key points
- **All correct = A+**: whenever every question is correct (incorrect_count == 0), the grade is A+
- **Any error = B/C/D**: once errors exist, the score rate decides the grade
### Examples
- ✅ 80 points, 0 errors → A+ (all correct; score rate does not matter)
- ✅ 95 points, 0 errors → A+ (all correct)
- ❌ 95 points, 1 error → B (has errors, graded by score rate)
- ❌ 90 points, 2 errors → B (has errors, graded by score rate)
### Example configuration
```json
{
"grade_standards": {
"A+": {"min_percentage": 95, "description": "Excellent"},
"A": {"min_percentage": 85, "description": "Good"},
"B": {"min_percentage": 70, "description": "Average"}
}
}
```
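The rule "all-correct is always A+, otherwise grade by score rate" can be expressed in a few lines. A sketch under the thresholds in the table above; the function name is illustrative:

```python
def determine_grade(incorrect_count: int, total_score: float, full_score: float) -> str:
    """All-correct is always A+; otherwise the grade follows the score rate."""
    if incorrect_count == 0:
        return "A+"  # all questions correct, score rate irrelevant
    rate = (total_score / full_score * 100) if full_score > 0 else 0
    if rate >= 80:
        return "B"
    if rate >= 70:
        return "C"
    return "D"
```

This reproduces the examples above: 95 points with 0 errors is A+, while 95 points with 1 error drops to B.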
## Workflow (multi-image grading architecture)
@@ -73,11 +177,18 @@
└─────────────────────┘
```
## Core feature: multi-image grading
## Core feature: multi-student, multi-image grading
### Input parameters
- `homework_images`: list of uploaded homework images (List[File], multiple images supported)
- `student_homework`: list of student homework (List[StudentHomework], multiple students supported)
- Each student contains:
- `student_id`: student ID (int)
- `student_name`: student name (str, optional)
- `homework_images`: that student's homework image URLs (List[str], plain string array)
- `answer_doc_url`: URL of the correct-answer Word file (.docx, **optional**)
- `subject`: subject identifier (str, **optional**, default "physics")
- Used for cache isolation: the same URL never collides across subjects
- Supported values: physics, math, chinese, english, etc.
- `comment_max_length`: maximum comment length (default 100 characters, **optional**)
- `max_concurrent`: maximum number of parallel grading tasks (default 10, **optional**)
- `grade_standards`: grading standards (**optional**, defaults below)
@@ -92,13 +203,16 @@
```
### Output
- `final_result`: final grading result (JSON, covering multiple images)
- `total_images`: total number of images
- `image_results`: grading results per image
- `overall_comment`: overall comment (generated from the score rate)
- `total_score`: total score
- `full_score`: total full score
- `grade`: grade rating
- `student_results`: grading results per student (List[StudentResult])
- Each student contains:
- `student_id`: student ID (int)
- `student_name`: student name (str)
- `total_images`: that student's image count
- `image_results`: that student's per-image grading results
- `overall_comment`: that student's overall comment
- `total_score`: that student's total score
- `full_score`: that student's full score
- `grade`: that student's grade rating
### Grading priority (strictly in this order)
1. **Highest priority**: grade against the standard answers in the Word document
@@ -118,6 +232,358 @@
5. **Smart fallback**: automatically switch to expert-teacher mode when no standard answers are available
## Optimization Log
### 2026-03-28 Added the subject to the cache key (important)
**Problem**: the same URL shared one cache entry across subjects, so answer-parsing results could collide
**Fix**:
1. **New `subject` parameter**
- Default: `physics`
- Supported values: physics, math, chinese, english, etc.
2. **Changed cache-key generation**:
```python
# Before
cache_key = answer_doc_url
# After
cache_key = f"{subject}:{answer_doc_url}"
```
3. **Isolation effect**:
- `physics:https://example.com/answer.docx`
- `math:https://example.com/answer.docx`
- The two cache entries are fully independent and never collide
**Result**:
- The same URL can hold different parse results per subject
- Cache data is isolated by subject, which is more flexible
### 2026-03-27 Final image-handling approach (important)
**Problem**: how to keep AI recognition accurate without uploading images
**Comparison**:
| Option | Rotate/scale | Upload | AI access | Request size | Chosen |
|------|---------|------|--------|---------|------|
| Option 1 | ✅ | base64 encoding | Data URL | Large | ❌ |
| Option 2 | ❌ | ❌ | Original URL | Small | ✅ |
**Why option 2**:
1. **The AI model is strong enough**: `doubao-seed-2-0-pro-260215` handles images of any size and orientation
2. **Unified coordinate system**: relative coordinates (0-1000) adapt to any size automatically
3. **Simplest and fastest**: no base64 encoding, no upload, fastest processing
**Final logic**:
```python
# Get the original image URL and size
original_url = state.homework_image.url
width, height, dpi = get_image_info(original_url)
# Return the original URL directly (no upload)
return ImagePreprocessOutput(
    image_url=original_url,
    image_info=ImageInfo(width=width, height=height, dpi=dpi)
)
```
**Result**:
- ✅ No new images uploaded to Coze
- ✅ Returns the original image URL
- ✅ Coordinates adapt to the original size automatically
- ✅ Fastest processing
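Because the model reports bounding boxes in the relative 0-1000 system, rendering a mark on the original image only needs the width/height returned above. A minimal sketch of that conversion (the function name is illustrative):

```python
def relative_bbox_to_pixels(bbox, width, height):
    """Convert a [x1, y1, x2, y2] bbox in relative 0-1000 coordinates
    to pixel coordinates for an image of the given size."""
    x1, y1, x2, y2 = bbox
    return (
        int(x1 / 1000 * width),
        int(y1 / 1000 * height),
        int(x2 / 1000 * width),
        int(y2 / 1000 * height),
    )
```

The same relative bbox therefore lands on the same spot regardless of the original image's resolution, which is what makes the no-resize approach work.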
### 2026-03-27 Removed image upload (important)
**Problem**: the system automatically uploaded processed images to Coze object storage, which users do not need
**Old logic**:
1. Download the original image
2. Auto-rotate (landscape → portrait)
3. Scale to a fixed width of 1000px
4. **Upload to Coze object storage**
5. Return the uploaded URL
**New logic**:
1. Get the original image URL
2. Get the image size
3. **Return the original URL and size directly** (no upload)
**Changes**:
- Removed image rotation
- Removed image scaling
- Removed image upload
- Use the original image URL directly
- The coordinate system still uses relative coordinates (0-1000) and adapts to the original size automatically
**Code comparison**:
```python
# Before
img = download_and_rotate(image_url)
img = resize_to_1000px(img)
new_url = upload_to_coze(img)  # upload a new image
return ImagePreprocessOutput(image_url=new_url)
# After
width, height, dpi = get_image_info(image_url)
return ImagePreprocessOutput(
    image_url=original_url,  # return the original URL directly
    image_info=ImageInfo(width=width, height=height, dpi=dpi)
)
```
**Result**:
- ✅ No more uploads to Coze
- ✅ Returns the original image URL
- ✅ Lower storage usage
- ✅ Faster processing
### 2026-03-27 Mark-offset tuning (important)
**Problem**: grading marks sat too far from the student's answer, making placement imprecise
**Cause**:
- The original offsets were too large (30px, 20px)
- Marks ended up visually far from the answer
- This hurt the perceived precision of the grading
**Tuning**:
All offsets were reduced so marks hug the answer:
| Strategy | Old offset | New offset | Notes |
|------|--------|--------|------|
| Strategy 1 (right margin > 80px) | +30px | **+10px** | Just right of the answer box |
| Strategy 2 (right margin 40-80px) | -20px | **-10px** | Inside the box's top-right corner |
| Strategy 3 (right margin < 40px) | +20px | **+10px** | Top-left corner of the answer box |
| Y-axis offset | 15px | **10px** | Near the top |
**Code comparison**:
```python
# Before
mark_x = answer_bbox[2] + 30  # offset too large
# After
mark_x = answer_bbox[2] + 10  # hug the answer box
```
**Result**:
- ✅ Marks sit closer to the student's answer
- ✅ More precise visual placement
- ✅ No more marks drifting far from the answer
### 2026-03-27 Strict coordinate clamping (important)
**Problem**: mark coordinates could exceed the image width, breaking placement
**Fix**:
1. **Strict bounds check**:
```python
# Keep x within the image width and y within the image height
mark_x = max(10, min(mark_x, image_info.width - 10))
mark_y = max(10, min(mark_y, image_info.height - 10))
```
2. **Margin tuning**:
- Margin reduced from 20px to 10px so marks can sit closer to the edges
- Coordinates can never exceed the image width or height
- Grading marks always stay inside the visible image area
**Result**: mark coordinates always stay within the image; out-of-bounds marks no longer occur
### 2026-03-27 Empty-answer handling (important)
**Problem**: when a student left a question blank, the verdict was ambiguous
**Fix**:
Added an "empty answer" rule to the prompt:
```
# ⚠️ Important: empty answers
- If the student did not answer (blank), the verdict MUST be **incorrect**
- Set the status field to "incorrect"
- Set the score field to 0
- Set the comment field to "Not answered" or "Blank, no answer"
```
**Examples**:
- Correct: `"status": "correct", "score": 10, "comment": "Calculation correct"`
- Wrong: `"status": "incorrect", "score": 5, "comment": "Wrong unit"`
- Empty: `"status": "incorrect", "score": 0, "comment": "Not answered"`
**Result**: empty answers are uniformly judged incorrect, score 0, with an unambiguous comment
### 2026-03-27 Coordinate-placement precision (important)
**Problem**: some grading marks drifted too far right, outside the answer area or even overlapping the next question
**Cause**:
- The old logic always placed marks 30px right of the answer box without checking the available space
- When the answer box was already near the right edge, marks landed out of bounds
**Fix** (three-tier strategy):
1. **Strategy 1**: ample right margin (> 100px) → mark on the right (original logic)
```python
mark_x = answer_bbox[2] + 30
mark_y = answer_bbox[1] + height * 0.5
```
2. **Strategy 2**: limited right margin (50-100px) → mark inside the box's top-right corner
```python
mark_x = answer_bbox[2] - 20  # inside the box
mark_y = answer_bbox[1] + 15
```
3. **Strategy 3**: very little right margin (< 50px) → mark at the box's top-left corner
```python
mark_x = answer_bbox[0] + 20  # left side
mark_y = answer_bbox[1] + 15
```
**Result**:
- Grading marks always land within a sensible range
- No more marks outside the answer area
- No more overlap with neighboring questions
- More precise visuals
### 2026-03-27 Fully parallel architecture (important)
**Problem**: the old design processed students serially and only each student's images in parallel; it was slow and risked data mix-ups
**Fix**:
1. **Fully parallel architecture**:
- All images of all students are submitted to the thread pool at once
- Parallelism across students and across images maximizes throughput
- A `(student_id, image_index, image_result)` tuple keeps each result tied to its owner
2. **Data isolation**:
- Results are grouped by `student_id`
- Each student's `total_score`, `full_score`, `overall_comment`, and `grade` are fully independent
- Only that student's own `image_results` feed the score computation
3. **Core code**:
```python
# Return a tuple: (student_id, image_index, image_result)
return (student_id, idx, image_result)
# Group results by student_id
student_image_results[student_id][image_index] = image_result
# Compute each student's result independently
for student in state.student_homework:
    image_results = student_image_results[student_id]
    # compute using only this student's data...
```
**Result**:
- Fully parallel, maximum throughput
- Strict data isolation, no mixing
- Student A's data can never appear in student B's result
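The flat-submission pattern can be sketched end to end as below. This is an illustrative sketch, not the project's node code: the real implementation also applies per-image timeouts and fallback results, and the function names here are assumptions.

```python
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

def grade_all(students, grade_image, max_concurrent=10):
    """Submit every (student, image) pair to one pool at once, then regroup by student."""
    with ThreadPoolExecutor(max_workers=max_concurrent) as pool:
        # Flatten: one task per image, across all students simultaneously
        futures = [
            pool.submit(lambda s=s, i=i, u=u: (s["student_id"], i, grade_image(u)))
            for s in students
            for i, u in enumerate(s["homework_images"])
        ]
        results = [f.result() for f in futures]

    # Regroup by student_id so no student's data can leak into another's
    grouped = defaultdict(dict)
    for student_id, image_index, image_result in results:
        grouped[student_id][image_index] = image_result
    return grouped
```

Carrying `student_id` and `image_index` through the tuple, rather than relying on submission order, is what keeps the per-student aggregation safe under arbitrary completion order.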
### 2026-03-27 Input-format simplification (important)
**Problem**: `homework_images` used `List[File]`, which made user input clumsy
**Fix**:
1. **Simplified input format**:
- `homework_images: List[File]` → `homework_images: List[str]`
- Pass URL strings directly; no File objects to construct
- The code converts URLs to File objects internally
2. **Input example**:
```json
{
"student_homework": [
{
"student_id": 0,
"homework_images": ["url1", "url2"]
}
]
}
```
**Result**:
- Simpler user input
- Less object-construction overhead
- Matches user expectations
### 2026-03-27 Multi-student support (breaking change)
**Problem**: the old architecture graded multiple images for a single student only and could not distinguish students
**Fix**:
1. **Data-structure rework**:
- Input: `homework_images` → `student_homework` (List[StudentHomework])
- Output: `final_result` → `student_results` (List[StudentResult])
- New `StudentHomework` type: student_id plus homework_images
- New `StudentResult` type: student_id plus the grading result
2. **Processing logic**:
- Outer loop: iterate over students
- Inner loop: process each student's images in parallel
- Score, comment, and grade are computed independently per student
3. **Independent results**:
- Each student gets their own overall_comment, total_score, full_score, and grade
- Students' results never affect each other
**Result**:
- Batch grading for multiple students
- Independent, clear per-student results
- Fits real classroom scenarios
### 2026-03-27 Comment quality (important)
**Problem**: the comment field was either too terse (just "correct"/"wrong") or dumped the model's reasoning, violating the "concise comment" requirement
**Fix**:
1. **Clear comment definition**:
- **When correct**: briefly say why (e.g. "Correct by the weighing method, F_buoy = G - F_reading")
- **When wrong**: point out the error and give the correct answer (e.g. "Should be 1.2N; watch the unit conversion")
- **Length limit**: at most comment_max_length characters (default 100)
- **Forbidden**: no reasoning traces, no lengthy walkthroughs
2. **Comment examples**:
```
✅ Correct: computed correctly by the weighing method F_buoy = G - F_reading
✅ Correct: the cause of buoyancy is understood correctly
✅ Wrong: should be 1.2N, from F_buoy = ρ_liquid·g·V_displaced
✅ Wrong: should be ACE; the control-variables method was misapplied
❌ Wrong: "correct" (too terse)
❌ Wrong: "because ... (reasoning) ... therefore correct" (contains reasoning)
```
3. **Parameter plumbing**:
- comment_max_length is passed through to the Jinja2 template
- The LLM generates comments that respect the length limit
**Result**:
- Comments are concise yet meaningful
- They explain why when correct and what went wrong when not
- They respect comment_max_length
- No reasoning traces, no long walkthroughs
### 2026-03-27 JSON-parsing robustness (important)
**Problem**:
- LLM output mixed reasoning into the JSON, corrupting the format
- Very long JSON (11946 characters) failed to parse
- annotations came back empty, so no questions were recognized
**Fix**:
1. **New extract_complete_objects function**:
- Extracts complete objects from JSON interleaved with reasoning text
- Walks object boundaries one at a time, unaffected by surrounding noise
- Recovers valid data even from malformed JSON
2. **New clean_comment function**:
- Detects reasoning tell-tale phrases ("不对", "重新看", "可能我" — roughly "no, wait", "look again", "maybe I")
- Truncates the comment where the reasoning starts
- Keeps complete sentences so the conclusion stays clear
3. **Raised max_completion_tokens**:
- From 8192 to 16384 to avoid truncated JSON
- Ensures every question is emitted in full
4. **Prompt tuning**:
- Explicitly forbids emitting reasoning
- comment holds only the conclusion: "correct" or "wrong, should be X"
- Reinforces: no chain-of-thought in the output
**Result**:
- JSON parse success rate improved substantially
- Valid data is recovered even when reasoning leaks in
- annotations are no longer empty
- Comments are concise and reasoning-free
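One way to realize the object-by-object extraction described above is to scan for balanced braces while skipping brace characters inside string literals, and to keep whatever parses. A simplified sketch of the idea; the project's actual `extract_complete_objects` may differ in detail:

```python
import json

def extract_complete_objects(text: str) -> list:
    """Pull every parseable {...} object out of noisy or truncated JSON text."""
    objects = []
    i = 0
    while i < len(text):
        if text[i] != "{":
            i += 1
            continue
        # Find the matching close brace, ignoring braces inside strings
        depth, in_str, esc, end = 0, False, False, None
        for j in range(i, len(text)):
            ch = text[j]
            if in_str:
                if esc:
                    esc = False
                elif ch == "\\":
                    esc = True
                elif ch == '"':
                    in_str = False
                continue
            if ch == '"':
                in_str = True
            elif ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    end = j
                    break
        if end is None:          # unbalanced: this object was truncated
            i += 1               # step inside and look for nested complete objects
            continue
        try:
            objects.append(json.loads(text[i:end + 1]))
            i = end + 1          # parsed successfully: skip past this object
        except json.JSONDecodeError:
            i += 1               # noise inside: retry from the next brace
    return objects
```

On truncated output the outer object never balances, so the scan falls through to the complete per-question objects nested inside it, which is exactly the recovery behavior the fix describes.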
### 2026-03-26 Fill-in-the-blank splitting (important)
**Problem**: when one question had several blanks, they were merged into a single answer, so grading marks could not be placed precisely

DEPLOYMENT_GUIDE.md Normal file

@@ -0,0 +1,443 @@
# Deployment Guide
This document helps you export the Middle School Physics Homework Grading Workflow and run it on your own server.
## 📋 Contents
- [Prerequisites](#prerequisites)
- [Quick deployment](#quick-deployment)
- [Detailed configuration](#detailed-configuration)
- [Ways to start](#ways-to-start)
- [FAQ](#faq)
---
## Prerequisites
### 1. System requirements
- **OS**: Linux / macOS / Windows (Linux recommended)
- **Python**: 3.10 or later
- **Memory**: 4GB or more recommended
- **Disk**: 10GB or more recommended
### 2. Required third-party services
The project depends on the following third-party service, which **must be prepared in advance**:
#### LLM API
- **Recommended**: Volcengine Doubao (this project uses `doubao-seed-2-0-pro-260215`)
- **Alternatives**:
- OpenAI API
- Any OpenAI-compatible API (e.g. DeepSeek, Kimi)
- **Where to get it**:
- Volcengine: https://console.volcengine.com/ark
- OpenAI: https://platform.openai.com/
**Note**:
- ✅ **No object storage needed** (S3/TOS/OSS, etc.)
- ✅ Images are used via their original URLs; nothing is uploaded
- ✅ Word documents are downloaded directly with requests; no object storage involved
---
---
## Quick Deployment
### Step 1: Export the project code
**Option A: download from the Coze platform**
```bash
# Click "Export project" in Coze Coding
# Unpack the archive on your server
```
**Option B: clone with Git (if you have a repository URL)**
```bash
git clone <your-repo-url>
cd <project-directory>
```
### Step 2: Install dependencies
```bash
# Create a virtual environment (recommended)
python3 -m venv venv
source venv/bin/activate  # Linux/macOS
# or venv\Scripts\activate  # Windows
# Install dependencies
pip install -r requirements.txt
```
### Step 3: Configure environment variables
Create a `.env` file (or set the variables in your server environment):
```bash
# Required variables (only the LLM API needs configuring)
export LLM_API_KEY="your-api-key-here"
export LLM_BASE_URL="https://ark.cn-beijing.volces.com/api/v3"
export LLM_MODEL_NAME="doubao-seed-2-0-pro-260215"
# Optional: log level
export LOG_LEVEL="INFO"
# Note: no object storage (S3/TOS, etc.) configuration is needed
```
### Step 4: Start the service
```bash
# Option 1: use the launch script (recommended)
bash scripts/http_run.sh -p 8000
# Option 2: run directly
python src/main.py -m http -p 8000
```
Once the service is up, visit:
- Health check: `http://localhost:8000/health`
- API docs: `http://localhost:8000/docs` (auto-generated by FastAPI)
---
## Detailed Configuration
### 1. LLM configuration
#### Option A: Volcengine Doubao (recommended)
```bash
# Environment variables
export LLM_API_KEY="your-ark-api-key"
export LLM_BASE_URL="https://ark.cn-beijing.volces.com/api/v3"
export LLM_MODEL_NAME="doubao-seed-2-0-pro-260215"
```
**How to get it**:
1. Open the Volcengine console: https://console.volcengine.com/ark
2. Create an inference endpoint
3. Obtain an API key
#### Option B: OpenAI API
Change the `model` field in the model config files (`config/*.json`) to an OpenAI model:
```json
{
"config": {
"model": "gpt-4o",
"temperature": 0.0
}
}
```
Environment variables:
```bash
export LLM_API_KEY="your-openai-api-key"
export LLM_BASE_URL="https://api.openai.com/v1"
export LLM_MODEL_NAME="gpt-4o"
```
### 2. ~~Object storage configuration~~ (removed)
**Important update (2026-03-27)**:
- ❌ No object storage configuration is needed
- ✅ Images are used via their original URLs, never uploaded
- ✅ Word documents are downloaded directly, never stored
**Why the architecture changed**:
1. The AI model is strong enough to access the original image URL directly
2. The relative coordinate system (0-1000) adapts to any size automatically
3. Lower storage cost, no upload time, faster processing
### 3. Adapting the code to your environment
#### Adapting the LLM calls
The project uses `coze-coding-dev-sdk`; to run standalone, switch to calling the OpenAI API directly.
**Files to change**: `src/graphs/nodes/doc_extract_node.py`, `src/graphs/nodes/recognize_and_correct_node.py`
**Original code** (using coze-coding-dev-sdk):
```python
from coze_coding_dev_sdk import LLM
llm = LLM()
response = llm.invoke(messages)
```
**Change to** (using the OpenAI SDK directly):
```python
import os
from openai import OpenAI
client = OpenAI(
    api_key=os.getenv("LLM_API_KEY"),
    base_url=os.getenv("LLM_BASE_URL")
)
response = client.chat.completions.create(
    model=os.getenv("LLM_MODEL_NAME"),
    messages=messages
)
```
#### ~~Adapting object storage~~ (not needed)
**Removed**: after the 2026-03-27 optimization, object storage is no longer used:
- Images are used via their original URLs
- Word documents are downloaded with requests
- No storage-related code needs changing
### 4. Cache configuration (optional)
The project caches parse results on disk; the default cache directory is `/tmp/cache`.
To change it:
```bash
export CACHE_DIR="/your/custom/cache/dir"
```
---
## Ways to Start
### 1. HTTP service mode (recommended for production)
```bash
# Use the launch script
bash scripts/http_run.sh -p 8000
# Or run directly
python src/main.py -m http -p 8000
```
**Features**:
- REST API
- Streaming responses (SSE)
- Timeout control
- Task cancellation
**API endpoints**:
- `POST /run` - run the workflow synchronously
- `POST /stream_run` - run the workflow with streaming (SSE)
- `POST /cancel/{run_id}` - cancel a run
- `GET /health` - health check
- `GET /graph_parameter` - inspect the workflow parameters
### 2. Command-line mode (local testing)
```bash
# Run the whole workflow
python src/main.py -m flow -i '{"student_homework": [...], "answer_doc_url": "..."}'
# Run a single node
python src/main.py -m node -n doc_extract -i '{"answer_doc_url": "..."}'
```
### 3. Docker deployment (recommended)
Create a `Dockerfile`:
```dockerfile
FROM python:3.10-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    && rm -rf /var/lib/apt/lists/*
# Copy the dependency manifest
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy the project files
COPY . .
# Expose the port
EXPOSE 8000
# Launch command
CMD ["python", "src/main.py", "-m", "http", "-p", "8000"]
```
Build and run:
```bash
# Build the image
docker build -t homework-correction:v1 .
# Run the container
docker run -d \
  --name homework-correction \
  -p 8000:8000 \
  -e LLM_API_KEY="your-api-key" \
  -e LLM_BASE_URL="https://ark.cn-beijing.volces.com/api/v3" \
  -e LLM_MODEL_NAME="doubao-seed-2-0-pro-260215" \
  homework-correction:v1
```
### 4. Docker Compose
Create `docker-compose.yml`:
```yaml
version: '3.8'
services:
  homework-correction:
    build: .
    ports:
      - "8000:8000"
    environment:
      - LLM_API_KEY=${LLM_API_KEY}
      - LLM_BASE_URL=${LLM_BASE_URL}
      - LLM_MODEL_NAME=${LLM_MODEL_NAME}
    restart: unless-stopped
    volumes:
      - ./cache:/tmp/cache  # persist the cache
```
Run:
```bash
docker-compose up -d
```
---
## FAQ
### Q1: How do I verify the environment variables are set correctly?
```bash
# Check the environment variables
echo $LLM_API_KEY
echo $LLM_BASE_URL
# Or print from code
python -c "import os; print(os.getenv('LLM_API_KEY'))"
```
### Q2: Startup fails with "ModuleNotFoundError: No module named 'xxx'"
**Solution**:
```bash
# Make sure the virtual environment is active
source venv/bin/activate
# Reinstall dependencies
pip install -r requirements.txt
```
### Q3: LLM calls fail with "API key not found"
**Cause**: the environment variable is not set correctly
**Solution**:
```bash
# Option 1: persist it in your shell profile
echo "export LLM_API_KEY='your-api-key'" >> ~/.bashrc
source ~/.bashrc
# Option 2: set it inline on the launch command
LLM_API_KEY="your-api-key" python src/main.py -m http -p 8000
```
### Q4: How do I test that the workflow works?
Send a test request with curl:
```bash
curl -X POST http://localhost:8000/run \
  -H "Content-Type: application/json" \
  -d '{
    "student_homework": [
      {
        "student_id": 0,
        "student_name": "Test Student",
        "homework_images": ["https://example.com/homework.jpg"]
      }
    ],
    "answer_doc_url": "https://example.com/answer.docx"
  }'
```
### Q5: How do I view the logs?
```bash
# Tail the log file
tail -f /app/work/logs/bypass/app.log
# Or use Docker logs
docker logs -f homework-correction
```
### Q6: Performance tuning tips
1. **Concurrency**: adjust the `max_concurrent` parameter (default 10)
2. **Timeouts**: adjust the `SINGLE_IMAGE_TIMEOUT` constant (default 120 seconds)
3. **Cache**: clean the `/tmp/cache` directory periodically
4. **Monitoring**: watch resource usage with `htop` or `docker stats`
### Q7: How do I switch to a different LLM?
1. Change the environment variables:
```bash
export LLM_API_KEY="your-other-llm-api-key"
export LLM_BASE_URL="https://api.other-llm.com/v1"
export LLM_MODEL_NAME="other-model-id"
```
2. Change the `model` field in the config files (`config/*.json`)
3. Verify that calls succeed
---
## Project Layout
```
├── src/
│   ├── main.py                  # main entry point
│   ├── graphs/
│   │   ├── graph.py             # main workflow orchestration
│   │   ├── loop_graph.py        # subgraph definition
│   │   ├── state.py             # state definitions
│   │   └── nodes/               # node implementations
│   │       ├── doc_extract_node.py
│   │       ├── process_images_node.py
│   │       ├── recognize_and_correct_node.py
│   │       └── ...
│   └── utils/                   # utility functions
│       ├── file/file.py         # file handling
│       └── cache_manager.py     # cache management
├── config/                      # LLM config files
│   ├── doc_extract_llm_cfg.json
│   ├── homework_correction_cfg.json
│   └── ...
├── scripts/                     # launch scripts
│   ├── http_run.sh
│   └── local_run.sh
├── requirements.txt             # Python dependencies
└── README.md                    # project readme
```
---
## Support
If you hit problems, check:
1. ✅ Environment variables are configured correctly
2. ✅ Dependencies are fully installed
3. ✅ The third-party LLM service is reachable
4. ✅ Error messages in the log files
---
## Changelog
- 2026-03-28: Added subject-scoped caching; fixed grade-determination logic
- 2026-03-27: Removed image upload; use original URLs directly
- 2026-03-26: Improved coordinate placement; fixed recognition issues
- 2026-03-25: Multi-student, multi-image parallel processing

Binary file not shown (779 KiB)

assets/image.png Normal file (365 B)

@@ -3,10 +3,10 @@
"model": "doubao-seed-2-0-pro-260215",
"temperature": 0.0,
"top_p": 0.9,
"max_completion_tokens": 8192,
"max_completion_tokens": 16384,
"thinking": "disabled"
},
"tools": [],
"sp": "# 角色\n你是物理作业批改助手。\n\n# 禁止标注\n- 印刷体文字、实验装置图、图中字母\n\n# 需要标注\n- 学生手写答案\n\n# 坐标\n- 相对坐标0-1000answer_bbox: [x1, y1, x2, y2]\n\n# ⚠️ 重要:拆分填空题\n- 一道题有多个填空时,**每个空单独识别为一个题目**\n- 题号格式:\"3(1)第一空\"、\"3(1)第二空\"、\"4(2)第一空\"、\"4(2)第二空\"\n- 每个空单独批改,单独打分\n- 示例:\n - 题目(1)有两个空 → 识别为\"3(1)第一空\"和\"3(1)第二空\"两个题目\n - 题目(2)有一个空 → 识别为\"3(2)\"一个题目\n\n# 输出格式\n{\"results\": [{\"question_id\": \"题号\", \"student_answer\": \"答案\", \"answer_bbox\": [x1, y1, x2, y2], \"status\": \"correct或incorrect\", \"score\": 分, \"full_score\": 满分, \"comment\": \"结论\"}]}\n\ncomment格式\"正确\"或\"错误应为X\"不超过50字",
"up": "批改物理作业。**每个填空单独识别**。输出JSONcomment不超过{{comment_max_length}}字。图片:{{image_url}}"
"sp": "# 角色\n你是物理作业批改助手。\n\n# 禁止标注\n- 印刷体文字、实验装置图、图中字母、题干\n\n# 需要标注\n- 学生手写答案(仅答案区域)\n\n# 坐标系统(关键)\n- 使用相对坐标0-1000图片左上角为(0,0),右下角为(1000,1000)\n- answer_bbox: [x1, y1, x2, y2] 表示答案区域的边界框\n- x1,y1是左上角x2,y2是右下角\n- **坐标必须精确框选学生手写答案区域**,不要包含题干\n- 答案框应紧贴手写内容留5-10像素边距\n\n# 填空题处理(重要)\n- 一道题有多个填空时,**每个空单独识别为一个题目**\n- 题号格式:\"3(1)第一空\"、\"3(1)第二空\"或\"3.1\"、\"3.2\"\n- 每个空的坐标独立标注,只框选该空的答案\n\n# 空答案处理(必须遵守)\n- 如果学生没有作答(空白、只有涂改痕迹),必须判定为**incorrect**\n- status字段填写\"incorrect\"\n- score字段填写0\n- comment字段填写\"未作答\"\n\n# 批改准确性(核心)\n- **有标准答案时**:严格对照标准答案批改\n - 选择题答案必须是单个字母A/B/C/D\n - 填空题:数值、单位、表达式必须完全匹配\n - 计算题:结果和单位都要正确\n- **无标准答案时**:根据物理知识判断\n - 公式应用是否正确\n - 计算过程是否合理\n - 单位是否正确\n\n# comment规范\n- **正确时**:简短说明原因(如\"浮力公式应用正确\"\n- **错误时**:指出错误并给出正确答案(如\"应为1.2N,注意单位换算\"\n- **空答案**:填写\"未作答\"\n- **字数限制**:不超过{{comment_max_length}}字\n- **禁止**:不要输出思考过程、不要输出详细解析\n\n# 输出格式\n{\"results\": [{\"question_id\": \"题号\", \"student_answer\": \"学生答案\", \"answer_bbox\": [x1, y1, x2, y2], \"status\": \"correct或incorrect\", \"score\": 分, \"full_score\": 满分, \"comment\": \"精练评语\"}]}\n\n# comment示例\n- 正确:\"浮力公式F浮=ρ液gV排应用正确\"\n- 错误:\"应为1.2NF浮=ρ液gV排=1.0×10³×10×1.2×10⁻⁴=1.2N\"\n- 空答案:\"未作答\"",
"up": "批改物理作业。**精确标注手写答案坐标**。**每个填空单独识别**。**comment写精练评语**。输出完整JSON。图片:{{image_url}}"
}

quick_start.sh Normal file

@@ -0,0 +1,111 @@
#!/bin/bash
# ============================================
# Middle School Physics Homework Grading Workflow - quick deployment script
# ============================================
set -e
echo "======================================"
echo " Homework Grading Workflow - Setup Wizard"
echo "======================================"
echo ""
# Detect the operating system
if [[ "$OSTYPE" == "linux-gnu"* ]]; then
OS="Linux"
elif [[ "$OSTYPE" == "darwin"* ]]; then
OS="macOS"
elif [[ "$OSTYPE" == "msys" ]] || [[ "$OSTYPE" == "cygwin" ]]; then
OS="Windows"
else
OS="Unknown"
fi
echo "Detected OS: $OS"
echo ""
# Step 1: check the Python version
echo "Step 1/5: Checking Python version..."
if command -v python3 &> /dev/null; then
PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
PYTHON_MAJOR=$(echo $PYTHON_VERSION | cut -d. -f1)
PYTHON_MINOR=$(echo $PYTHON_VERSION | cut -d. -f2)
if [ "$PYTHON_MAJOR" -gt 3 ] || { [ "$PYTHON_MAJOR" -eq 3 ] && [ "$PYTHON_MINOR" -ge 10 ]; }; then
echo "✅ Python version: $PYTHON_VERSION"
else
echo "❌ Python version too old: $PYTHON_VERSION (3.10+ required)"
exit 1
fi
else
echo "❌ Python 3 not found"
exit 1
fi
echo ""
# Step 2: create a virtual environment
echo "Step 2/5: Creating virtual environment..."
if [ ! -d "venv" ]; then
python3 -m venv venv
echo "✅ Virtual environment created"
else
echo "✅ Virtual environment already exists"
fi
echo ""
# Step 3: activate the virtual environment
echo "Step 3/5: Activating virtual environment..."
if [ "$OS" == "Windows" ]; then
source venv/Scripts/activate
else
source venv/bin/activate
fi
echo "✅ Virtual environment activated"
echo ""
# Step 4: install dependencies
echo "Step 4/5: Installing dependencies..."
if [ -f "requirements.txt" ]; then
pip install --upgrade pip
pip install -r requirements.txt
echo "✅ Dependencies installed"
else
echo "❌ requirements.txt not found"
exit 1
fi
echo ""
# Step 5: configure environment variables
echo "Step 5/5: Configuring environment variables..."
if [ ! -f ".env" ]; then
if [ -f ".env.example" ]; then
cp .env.example .env
echo "✅ Created .env file"
echo ""
echo "⚠️ Edit the .env file and fill in the required settings:"
echo " - LLM_API_KEY"
echo " - LLM_BASE_URL"
echo " - LLM_MODEL_NAME"
echo ""
echo "Note: no object storage is needed; images are used via their original URLs"
echo ""
echo "When done, start the service with:"
echo " source .env"
echo " bash scripts/http_run.sh -p 8000"
else
echo "❌ .env.example not found"
exit 1
fi
else
echo "✅ .env file already exists"
echo ""
echo "Start the service with:"
echo " source .env"
echo " bash scripts/http_run.sh -p 8000"
fi
echo ""
echo "======================================"
echo " ✅ Setup complete!"
echo "======================================"


@@ -1,10 +1,11 @@
"""Word answer-parsing node: extract question stems and standard answers from a .docx file"""
"""Word answer-parsing node: extract question stems and standard answers from a .docx file (with caching)"""
import os
import json
import re
import logging
import tempfile
import requests
import orjson
from typing import List
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
@@ -18,6 +19,7 @@ from graphs.state import (
DocExtractOutput,
CorrectAnswer
)
from utils.cache_manager import answer_doc_cache
logger = logging.getLogger(__name__)
@@ -58,7 +60,6 @@ def sanitize_json_string(text: str) -> str:
def extract_json_from_text(text: str, key: str = "answers") -> dict:
"""Extract a JSON object from text (multi-layer fallback strategy)"""
import orjson
# First, try to parse the whole string as JSON
try:
try:
@@ -114,16 +115,34 @@ def download_and_extract_docx(url: str) -> str:
"""Download a Word file and extract its text content"""
logger.info(f"Downloading Word document from: {url}")
# Download the file
response = requests.get(url, timeout=60)
response.raise_for_status()
# Save to a temporary file
with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as tmp_file:
tmp_file.write(response.content)
tmp_path = tmp_file.name
tmp_path = None
try:
# Download the file
response = requests.get(url, timeout=60, allow_redirects=True)
response.raise_for_status()
# Check the content type
content_type = response.headers.get('Content-Type', '')
logger.debug(f"Response Content-Type: {content_type}")
# Check the file size
if len(response.content) < 100:
raise ValueError(f"Downloaded file too small: {len(response.content)} bytes")
# Check that this is a valid docx (ZIP format, starts with PK)
if not response.content.startswith(b'PK'):
# It may be an HTML error page
content_preview = response.content[:1000].lower()
if b'<html' in content_preview or b'<!doctype' in content_preview:
raise ValueError("URL returned HTML instead of docx (may need authentication or URL expired)")
raise ValueError("Downloaded file is not a valid docx (not ZIP format)")
# Save to a temporary file
with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as tmp_file:
tmp_file.write(response.content)
tmp_path = tmp_file.name
# Parse with python-docx
doc = Document(tmp_path)
@@ -149,38 +168,41 @@ def download_and_extract_docx(url: str) -> str:
logger.info(f"Extracted Word document text length: {len(doc_text)}")
return doc_text
except Exception as e:
logger.error(f"Failed to download/extract docx: {e}")
raise
finally:
# Clean up the temporary file
os.unlink(tmp_path)
if tmp_path and os.path.exists(tmp_path):
try:
os.unlink(tmp_path)
except Exception:
pass
def doc_extract_node(
state: DocExtractInput,
config: RunnableConfig,
runtime: Runtime[Context]
) -> DocExtractOutput:
def parse_answer_doc_with_llm(answer_doc_url: str, ctx, config: RunnableConfig) -> List[CorrectAnswer]:
"""
title: Word answer parsing
desc: Extract question stems and standard answers from the correct-answer Word file (.docx) for later grading; returns an empty list if no URL is provided
integrations: LLM
Parse the answer document with the LLM (the actual parsing logic)
Args:
answer_doc_url: answer document URL
ctx: context
config: configuration
Returns:
the parsed list of correct answers
"""
ctx = runtime.context
# Check whether an answer document URL was provided
if not state.answer_doc_url or not state.answer_doc_url.strip():
logger.info("No answer document URL provided, will use teacher mode for all questions")
return DocExtractOutput(correct_answers=[])
# 1. Download the Word document and extract its content
try:
doc_text = download_and_extract_docx(state.answer_doc_url)
doc_text = download_and_extract_docx(answer_doc_url)
except Exception as e:
logger.error(f"Failed to download/extract Word document: {e}")
return DocExtractOutput(correct_answers=[])
return []
if not doc_text.strip():
logger.error("No text content extracted from Word document")
return DocExtractOutput(correct_answers=[])
return []
logger.info(f"Word document content preview: {doc_text[:500]}")
@@ -277,4 +299,57 @@
for ans in correct_answers:
logger.info(f" Question {ans.question_id}: {ans.correct_answer} ({ans.full_score}分)")
return correct_answers
def doc_extract_node(
state: DocExtractInput,
config: RunnableConfig,
runtime: Runtime[Context]
) -> DocExtractOutput:
"""
title: Word answer parsing
desc: Extract question stems and standard answers from the correct-answer Word file (.docx) for later grading; returns an empty list if no URL is provided; caches results to avoid re-parsing
integrations: LLM
"""
ctx = runtime.context
# Check whether an answer document URL was provided
if not state.answer_doc_url or not state.answer_doc_url.strip():
logger.info("No answer document URL provided, will use teacher mode for all questions")
return DocExtractOutput(correct_answers=[])
# Try the cache first
cached_result = answer_doc_cache.get(state.answer_doc_url)
if cached_result is not None:
logger.info(f"Cache hit for answer doc: {state.answer_doc_url[:50]}...")
# Convert the cached dicts back into CorrectAnswer objects
correct_answers = [
CorrectAnswer(**ans) if isinstance(ans, dict) else ans
for ans in cached_result
]
return DocExtractOutput(correct_answers=correct_answers)
logger.info(f"Cache miss for answer doc: {state.answer_doc_url[:50]}...")
# Cache miss: run the parse
correct_answers = parse_answer_doc_with_llm(state.answer_doc_url, ctx, config)
# Store in the cache (convert CorrectAnswer objects to a list of dicts)
if correct_answers:
cache_data = [
{
"question_id": ans.question_id,
"parent_id": ans.parent_id,
"is_sub_question": ans.is_sub_question,
"question_text": ans.question_text,
"correct_answer": ans.correct_answer,
"full_score": ans.full_score,
"answer_analysis": ans.answer_analysis
}
for ans in correct_answers
]
answer_doc_cache.set(state.answer_doc_url, cache_data)
logger.info(f"Cached {len(correct_answers)} answers for: {state.answer_doc_url[:50]}...")
return DocExtractOutput(correct_answers=correct_answers)


@@ -1,14 +1,13 @@
"""1. Image preprocessing node: download the image, auto-rotate, scale to a fixed width of 1000, upload to object storage"""
import os
"""1. Image preprocessing node: fetch image info and use the original URL directly"""
import logging
import urllib.request
import time
from io import BytesIO
from typing import Tuple
from typing import Tuple, Optional
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context
from PIL import Image
from datetime import datetime
from graphs.state import (
ImagePreprocessInput,
@@ -18,102 +17,79 @@
logger = logging.getLogger(__name__)
# Fixed-width constant
FIXED_WIDTH = 1000
# Default image size (used for fallback handling)
DEFAULT_IMAGE_SIZE = (1000, 1400)
# Timeout settings (seconds)
IMAGE_DOWNLOAD_TIMEOUT = 30  # per-download timeout
MAX_RETRIES = 2  # maximum retries (kept low)
def download_and_process_image(image_url: str) -> Tuple[Image.Image, int, int, int]:
"""Download an image and return the PIL Image plus its original size info"""
try:
with urllib.request.urlopen(image_url, timeout=30) as response:
img_data = response.read()
img = Image.open(BytesIO(img_data))
# Convert to RGB (handles PNG alpha channel)
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
original_width, original_height = img.size
dpi = img.info.get('dpi', (72, 72))
if isinstance(dpi, tuple):
dpi = dpi[0] if dpi[0] > 0 else 72
else:
dpi = 72
return img, original_width, original_height, dpi
except Exception as e:
raise RuntimeError(f"Image download failed: {e}")
def auto_rotate_image(img: Image.Image) -> Tuple[Image.Image, bool]:
def get_image_info_with_retry(image_url: str, max_retries: int = MAX_RETRIES, timeout: int = IMAGE_DOWNLOAD_TIMEOUT) -> Tuple[int, int, int]:
"""
Auto-rotate the image: if width exceeds height (landscape), rotate by -90 degrees to make it portrait
Fetch the image size info, with retries and a cap on total elapsed time
Args:
img: PIL Image object
image_url: image URL
max_retries: maximum retries (default 2)
timeout: per-request timeout (default 30s)
Returns:
(rotated image, whether it was rotated)
(width, height, dpi)
"""
width, height = img.size
last_error = None
total_start_time = time.time()
MAX_TOTAL_TIME = 60  # cap total time at 60 seconds
# Rotate when width exceeds height
if width > height:
logger.info(f"Landscape image detected (width {width} > height {height}), rotating -90 degrees...")
# rotate(-90) rotates the image 90 degrees clockwise, turning landscape into portrait
rotated_img = img.rotate(-90, expand=True)
new_width, new_height = rotated_img.size
logger.info(f"Rotation done, new size: width {new_width} x height {new_height}")
return rotated_img, True
for attempt in range(max_retries):
# Check the total elapsed time
if time.time() - total_start_time > MAX_TOTAL_TIME:
logger.warning(f"Total time exceeded {MAX_TOTAL_TIME}s, giving up")
break
logger.info(f"Portrait image (width {width} <= height {height}), no rotation needed")
return img, False
try:
# Download the image (with timeout)
with urllib.request.urlopen(image_url, timeout=timeout) as response:
img_data = response.read()
# Check the data size
if len(img_data) < 100:
raise ValueError(f"Image data too small: {len(img_data)} bytes")
def resize_image_to_fixed_width(img: Image.Image, target_width: int = FIXED_WIDTH) -> Tuple[Image.Image, float]:
"""Scale the image to a fixed width, scaling the height proportionally"""
original_width, original_height = img.size
img = Image.open(BytesIO(img_data))
if original_width == target_width:
return img, 1.0
# Convert to RGB (handles PNG alpha channel)
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
# Compute the scale ratio
scale_ratio = target_width / original_width
new_height = int(original_height * scale_ratio)
width, height = img.size
# Use high-quality resampling
# BICUBIC = 3, LANCZOS = 4 (PIL internal constants)
resized_img = img.resize((target_width, new_height), 3)  # BICUBIC
# Validate the size
if width <= 0 or height <= 0:
raise ValueError(f"Invalid image size: {width}x{height}")
return resized_img, scale_ratio
dpi = img.info.get('dpi', (72, 72))
if isinstance(dpi, tuple):
dpi = dpi[0] if dpi[0] > 0 else 72
else:
dpi = 72
elapsed = time.time() - total_start_time
logger.info(f"Got image info: {width}x{height} in {elapsed:.1f}s (attempt {attempt + 1})")
return width, height, dpi
def upload_image_to_storage(img: Image.Image, ctx) -> str:
"""Upload the image to object storage and return its URL"""
from coze_coding_dev_sdk.s3 import S3SyncStorage
except Exception as e:
last_error = e
elapsed = time.time() - total_start_time
logger.warning(f"Attempt {attempt + 1}/{max_retries} failed in {elapsed:.1f}s: {str(e)[:100]}")
# Convert to a byte stream
img_buffer = BytesIO()
img.save(img_buffer, format='JPEG', quality=95)
img_bytes = img_buffer.getvalue()
# Wait only between attempts, not after the last one
if attempt < max_retries - 1:
wait_time = 1  # fixed 1-second wait
time.sleep(wait_time)
# Upload to object storage
storage = S3SyncStorage(
endpoint_url=os.getenv("COZE_BUCKET_ENDPOINT_URL"),
access_key="",
secret_key="",
bucket_name=os.getenv("COZE_BUCKET_NAME"),
region="cn-beijing",
)
file_key = storage.upload_file(
file_content=img_bytes,
file_name=f"homework_resized_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jpg",
content_type="image/jpeg",
)
result_url = storage.generate_presigned_url(key=file_key, expire_time=86400)
return result_url
# 所有重试都失败,抛出错误
raise RuntimeError(f"图片获取失败({max_retries}次重试): {str(last_error)[:100]}")
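The retry loop combines a per-attempt timeout with an overall time budget and a fixed backoff. A minimal, self-contained sketch of that pattern (the names `fetch_with_retry` and `flaky` are illustrative, not the project's actual API):

```python
import time

def fetch_with_retry(fetch, max_retries=3, max_total_time=30.0, wait=1.0):
    """Call `fetch()` until it succeeds, honoring both a retry count
    and a total time budget, mirroring get_image_info_with_retry."""
    start = time.time()
    last_error = None
    for attempt in range(max_retries):
        if time.time() - start > max_total_time:
            break  # total budget exhausted, stop retrying
        try:
            return fetch()
        except Exception as e:
            last_error = e
            if attempt < max_retries - 1:
                time.sleep(wait)  # fixed backoff between attempts
    raise RuntimeError(f"failed after {max_retries} attempts: {last_error}")

# A fetcher that fails twice, then succeeds:
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise IOError("transient")
    return (800, 1200, 72)

print(fetch_with_retry(flaky, wait=0.01))  # (800, 1200, 72)
```

Checking the elapsed time at the top of the loop (rather than inside the handler) is what lets the budget cut short a retry sequence even when individual attempts fail quickly.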
def image_preprocess_node(
@ -123,37 +99,38 @@ def image_preprocess_node(
) -> ImagePreprocessOutput:
"""
title: 图像预处理
desc: 下载图片自动旋转横向纵向缩放到固定宽度1000px上传对象存储
integrations: 对象存储
desc: 获取图片尺寸信息直接使用原始图片URL用于AI识别不旋转不缩放不上传支持重试和降级处理
integrations:
"""
ctx = runtime.context
# 1. 下载原始图片
img, original_width, original_height, dpi = download_and_process_image(
state.homework_image.url
)
logger.info(f"原始图片尺寸:宽{original_width} x 高{original_height}")
# 获取原始图片URL
original_image_url = state.homework_image.url
# 2. 自动旋转:如果宽度大于高度,旋转-90度使其变为纵向
img, was_rotated = auto_rotate_image(img)
after_rotate_width, after_rotate_height = img.size
if was_rotated:
logger.info(f"旋转后尺寸:宽{after_rotate_width} x 高{after_rotate_height}")
try:
# 获取图片尺寸信息(带重试)
width, height, dpi = get_image_info_with_retry(original_image_url)
logger.info(f"原始图片尺寸:宽{width} x 高{height}")
# 3. 缩放到固定宽度1000px高度等比例缩放
resized_img, scale_ratio = resize_image_to_fixed_width(img, FIXED_WIDTH)
new_width, new_height = resized_img.size
logger.info(f"缩放后尺寸:宽{new_width} x 高{new_height},缩放比例:{scale_ratio:.3f}")
return ImagePreprocessOutput(
image_info=ImageInfo(
width=width,
height=height,
dpi=dpi
),
image_url=original_image_url
)
# 4. 上传处理后的图片到对象存储
processed_image_url = upload_image_to_storage(resized_img, ctx)
except Exception as e:
# 降级处理:使用默认尺寸,但仍然继续处理
logger.error(f"Failed to get image info after retries, using fallback: {e}")
# 5. 返回处理后的图片信息AI基于这个尺寸计算坐标
return ImagePreprocessOutput(
image_info=ImageInfo(
width=new_width, # 缩放后的宽度1000
height=new_height, # 缩放后的高度
dpi=dpi
),
image_url=processed_image_url
)
# 返回默认尺寸,让后续节点能够继续处理
return ImagePreprocessOutput(
image_info=ImageInfo(
width=DEFAULT_IMAGE_SIZE[0],
height=DEFAULT_IMAGE_SIZE[1],
dpi=72
),
image_url=original_image_url
)

View File

@ -1,7 +1,8 @@
"""图片处理循环节点:并行调用子图处理每张作业图片"""
"""学生作业处理循环节点:完全并行处理,确保数据隔离"""
import logging
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from typing import List, Dict, Any, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context
@ -9,159 +10,348 @@ from coze_coding_utils.runtime_ctx.context import Context
from graphs.state import (
ProcessImagesInput,
ProcessImagesOutput,
StudentResult,
StudentHomework,
SingleImageResult,
SubgraphInput,
FinalResult,
ImageInfo
ImageInfo,
CorrectAnswer
)
from graphs.loop_graph import single_image_subgraph
from utils.file.file import File
logger = logging.getLogger(__name__)
# 默认等级标准与GraphInput中的默认值保持一致
DEFAULT_GRADE_STANDARDS = {
"A+": {"min_percentage": 95, "description": "优秀"},
"A": {"min_percentage": 90, "description": "良好"},
"B": {"min_percentage": 80, "description": "合格"},
"C": {"min_percentage": 70, "description": "及格"},
"D": {"min_percentage": 0, "description": "需努力"}
}
# 超时配置(秒)
SINGLE_IMAGE_TIMEOUT = 120 # 单张图片处理超时120秒
def calculate_grade(score_rate: float, incorrect_count: int, total_questions: int, grade_standards: Dict[str, Any]) -> tuple:
"""
根据得分率错误数量和等级标准计算等级
核心规则按优先级
1. A+ A首要条件是"全对"incorrect_count == 0与得分率无关
- A+全对
- A全对
A+和A的区别由其他因素决定这里都返回全对的最高等级
2. B/C/D有错误时按得分率判断
Args:
score_rate: 得分率百分比如95.5
incorrect_count: 错误题目数量
total_questions: 总题目数量
grade_standards: 等级标准字典
Returns:
(等级, 等级描述)
"""
# 使用传入的等级标准,如果没有则使用默认值
standards = grade_standards if grade_standards else DEFAULT_GRADE_STANDARDS
# 规则1全对 → A+ 或 A这里统一返回A+,因为全对就是最高等级)
if incorrect_count == 0 and total_questions > 0:
return "A+", standards.get("A+", {}).get("description", "优秀")
# 规则2有错误 → 按得分率判断 B/C/D
# 按min_percentage降序排序只排B/C/D
bcd_grades = [(k, v) for k, v in standards.items() if k not in ("A+", "A")]
sorted_grades = sorted(
bcd_grades,
key=lambda x: x[1].get("min_percentage", 0),
reverse=True
)
# 遍历找到匹配的等级
for grade, config in sorted_grades:
min_pct = config.get("min_percentage", 0)
if score_rate >= min_pct:
return grade, config.get("description", "")
# 默认返回D
return "D", "需努力"
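The priority rule above (an all-correct paper is always A+, regardless of score rate; papers with errors fall into B/C/D by descending `min_percentage`) can be exercised in isolation. A small sketch reproducing the logic with the default standards:

```python
DEFAULT_GRADE_STANDARDS = {
    "A+": {"min_percentage": 95, "description": "优秀"},
    "A": {"min_percentage": 90, "description": "良好"},
    "B": {"min_percentage": 80, "description": "合格"},
    "C": {"min_percentage": 70, "description": "及格"},
    "D": {"min_percentage": 0, "description": "需努力"},
}

def calculate_grade(score_rate, incorrect_count, total_questions, standards=None):
    standards = standards or DEFAULT_GRADE_STANDARDS
    # 规则1: 全对 → A+,与得分率无关
    if incorrect_count == 0 and total_questions > 0:
        return "A+", standards.get("A+", {}).get("description", "优秀")
    # 规则2: 有错误时按得分率匹配 B/C/D降序
    bcd = [(k, v) for k, v in standards.items() if k not in ("A+", "A")]
    for grade, cfg in sorted(bcd, key=lambda x: x[1].get("min_percentage", 0), reverse=True):
        if score_rate >= cfg.get("min_percentage", 0):
            return grade, cfg.get("description", "")
    return "D", "需努力"

print(calculate_grade(100.0, 0, 10))  # ('A+', '优秀')
print(calculate_grade(85.0, 2, 10))   # ('B', '合格')
print(calculate_grade(40.0, 6, 10))   # ('D', '需努力')
```

Note a consequence of the rule: a 96% paper with one wrong answer lands in B, because A+/A are reserved for error-free papers.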
def process_single_image(
student_id: int,
idx: int,
image_url: str,
correct_answers: List[CorrectAnswer],
comment_max_length: int,
config: RunnableConfig
) -> tuple:
"""
处理单个学生的单张图片线程安全
返回: (student_id, image_index, SingleImageResult) 元组确保学生ID关联
注意correct_answers 参数是只读的不会被修改
"""
logger.info(f"Processing student {student_id}, image {idx + 1}")
try:
# 将URL字符串转换为File对象
homework_image = File(url=image_url, file_type="image")
# 构建子图输入(创建新的输入对象,确保数据隔离)
# 注意correct_answers 是只读的,不需要复制
subgraph_input = SubgraphInput(
homework_image=homework_image,
correct_answers=correct_answers, # 只读,不需要复制
image_index=idx,
comment_max_length=comment_max_length
)
# 调用子图config 是只读的LangGraph 保证线程安全)
subgraph_output = single_image_subgraph.invoke(subgraph_input, config)
# 提取结果兼容字典和Pydantic对象两种返回类型
image_result = None
if subgraph_output:
if isinstance(subgraph_output, dict):
result_data = subgraph_output.get("image_result")
if result_data:
if isinstance(result_data, dict):
image_result = SingleImageResult(**result_data)
else:
image_result = result_data
elif hasattr(subgraph_output, 'image_result'):
image_result = subgraph_output.image_result
if image_result:
logger.info(f"Student {student_id}, image {idx + 1} processed: {len(image_result.annotations)} annotations")
return (student_id, idx, image_result)
else:
logger.warning(f"Student {student_id}, image {idx + 1} returned invalid output")
return (student_id, idx, SingleImageResult(
image_index=idx,
image_info=ImageInfo(width=0, height=0, dpi=72),
annotations=[]
))
except Exception as e:
logger.error(f"Failed to process student {student_id}, image {idx + 1}: {e}", exc_info=True)
return (student_id, idx, SingleImageResult(
image_index=idx,
image_info=ImageInfo(width=0, height=0, dpi=72),
annotations=[]
))
def process_images_node(
state: ProcessImagesInput,
config: RunnableConfig,
runtime: Runtime[Context]
) -> ProcessImagesOutput:
"""
title: 多图片批改处理
desc: 并行调用子图处理每张作业图片生成最终批改结果
title: 学生作业批改处理
desc: 完全并行处理所有学生的所有图片确保每个学生的数据完全隔离
integrations:
"""
ctx = runtime.context
    # 获取并发数限制从参数获取默认10
    max_concurrent = getattr(state, 'max_concurrent', 10)

    logger.info(f"Starting to process {len(state.homework_images)} images (concurrent={max_concurrent})")

    # 定义处理单张图片的函数
    def process_single_image(idx: int, homework_image) -> SingleImageResult:
        logger.info(f"Processing image {idx + 1}/{len(state.homework_images)}")
        try:
            # 构建子图输入
            subgraph_input = SubgraphInput(
                homework_image=homework_image,
                correct_answers=state.correct_answers,
                image_index=idx,
                comment_max_length=state.comment_max_length
            )
            # 调用子图
            subgraph_output = single_image_subgraph.invoke(subgraph_input, config)
            # 提取结果兼容字典和Pydantic对象两种返回类型
            image_result = None
            if subgraph_output:
                if isinstance(subgraph_output, dict):
                    result_data = subgraph_output.get("image_result")
                    if result_data:
                        if isinstance(result_data, dict):
                            image_result = SingleImageResult(**result_data)
                        else:
                            image_result = result_data
                elif hasattr(subgraph_output, 'image_result'):
                    image_result = subgraph_output.image_result
            if image_result:
                logger.info(f"Image {idx + 1} processed successfully: {len(image_result.annotations)} annotations")
                return image_result
            else:
                logger.warning(f"Image {idx + 1} subgraph returned invalid output")
                return SingleImageResult(
                    image_index=idx,
                    image_info=ImageInfo(width=0, height=0, dpi=72),
                    annotations=[]
                )
        except Exception as e:
            logger.error(f"Failed to process image {idx + 1}: {e}", exc_info=True)
            return SingleImageResult(
                image_index=idx,
                image_info=ImageInfo(width=0, height=0, dpi=72),
                annotations=[]
            )

    # 并行处理所有图片
    image_results: List[SingleImageResult] = []
    with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
        # 提交所有任务
        future_to_idx = {
            executor.submit(process_single_image, idx, img): idx
            for idx, img in enumerate(state.homework_images)
        }
        # 收集结果
        results_dict = {}
        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                result = future.result()
                results_dict[idx] = result
            except Exception as e:
                logger.error(f"Future failed for image {idx}: {e}")
                results_dict[idx] = SingleImageResult(
                    image_index=idx,
                    image_info=ImageInfo(width=0, height=0, dpi=72),
                    annotations=[]
                )
    # 按原始顺序排序结果
    for idx in range(len(state.homework_images)):
        image_results.append(results_dict[idx])

    logger.info(f"Completed processing {len(image_results)} images")

    # 生成最终结果
    total_score = 0
    full_score = 0
    total_questions = 0
    correct_count = 0
    incorrect_count = 0

    for result in image_results:
        for annotation in result.annotations:
            total_score += annotation.score
            full_score += annotation.full_score
            total_questions += 1
            if annotation.status == "correct":
                correct_count += 1
            elif annotation.status == "incorrect":
                incorrect_count += 1

    # 计算得分率
    score_rate = (total_score / full_score * 100) if full_score > 0 else 0

    # 生成整体评价
    if total_questions == 0:
        overall_comment = "未识别到题目内容"
        grade = "D"
    elif score_rate >= 95:
        overall_comment = f"优秀!{total_questions}题全部正确,掌握扎实,继续保持。"
        grade = "A+"
    elif score_rate >= 90:
        overall_comment = f"良好!得分率{score_rate:.0f}%,错{incorrect_count}题,注意细节。"
        grade = "A"
    elif score_rate >= 80:
        overall_comment = f"合格。得分率{score_rate:.0f}%,错{incorrect_count}题,需加强练习。"
        grade = "B"
    elif score_rate >= 70:
        overall_comment = f"及格。错{incorrect_count}题,部分知识点掌握不牢,建议复习。"
        grade = "C"
    else:
        overall_comment = f"需努力。得分率{score_rate:.0f}%,建议认真复习,多做练习。"
        grade = "D"

    final_result = FinalResult(
        total_images=len(image_results),
        image_results=image_results,
        overall_comment=overall_comment,
        total_score=total_score,
        full_score=full_score,
        grade=grade
    )

    logger.info(f"Final result: {total_score}/{full_score}, grade={grade}")

    return ProcessImagesOutput(final_result=final_result)

    # === 输入参数校验 ===
    if not state.student_homework:
        logger.warning("No student homework provided, returning empty result")
        return ProcessImagesOutput(student_results=[])

    # 获取并发数限制从参数获取默认10限制在合理范围1-50
    max_concurrent = max(1, min(getattr(state, 'max_concurrent', 10), 50))

    # 过滤有效的图片URL统计总任务数
    valid_tasks = []  # [(student_id, idx, image_url), ...]
    for student in state.student_homework:
        if not student.homework_images:
            continue
        for idx, image_url in enumerate(student.homework_images):
            # 验证图片URL有效性
            if image_url and isinstance(image_url, str) and image_url.strip():
                valid_tasks.append((student.student_id, idx, image_url.strip()))
            else:
                logger.warning(f"Invalid image URL for student {student.student_id}, image {idx}")

    if not valid_tasks:
        logger.warning("No valid images to process, returning empty result")
        return ProcessImagesOutput(student_results=[])

    logger.info(f"Starting to process {len(state.student_homework)} students, {len(valid_tasks)} valid images (concurrent={max_concurrent})")

    # 第一步:提交所有任务(所有学生的所有图片)
    # 注意使用字典存储结果as_completed 在主线程顺序处理,无需加锁
    student_image_results: Dict[int, Dict[int, SingleImageResult]] = {}
    with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
        # 提交所有有效任务
        future_to_task = {}  # future -> (student_id, idx, image_url)
        for student_id, idx, image_url in valid_tasks:
            future = executor.submit(
                process_single_image,
                student_id,
                idx,
                image_url,
                state.correct_answers,
                state.comment_max_length,
                config
            )
            future_to_task[future] = (student_id, idx, image_url)

        logger.info(f"Submitted {len(future_to_task)} tasks for all students")

        # 第二步收集所有结果按student_id分组
        # as_completed 在主线程中迭代,不需要加锁
        # 双重保险:使用 task_info 中的 student_id而非 result 中的),确保数据归属正确
        try:
            for future in as_completed(future_to_task, timeout=SINGLE_IMAGE_TIMEOUT * len(future_to_task)):
                task_info = future_to_task[future]
                student_id, idx, image_url = task_info  # 从任务信息获取,不依赖返回值
                try:
                    # 返回值中的 student_id 应该与 task_info 一致,但我们优先使用 task_info
                    # 添加单任务超时保护
                    result_student_id, image_index, image_result = future.result(timeout=SINGLE_IMAGE_TIMEOUT)
                    # 双重校验确保返回的学生ID与任务一致
                    if result_student_id != student_id:
                        logger.warning(f"Student ID mismatch: task={student_id}, result={result_student_id}, using task_id")
                    # 使用 task_info 中的 student_id更可靠
                    if student_id not in student_image_results:
                        student_image_results[student_id] = {}
                    # 存储该学生的图片结果
                    student_image_results[student_id][image_index] = image_result
                except FuturesTimeoutError:
                    logger.error(f"Task timeout for student {student_id}, image {idx} (timeout={SINGLE_IMAGE_TIMEOUT}s)")
                    # 存储一个空结果,确保该学生的结果完整
                    if student_id not in student_image_results:
                        student_image_results[student_id] = {}
                    student_image_results[student_id][idx] = SingleImageResult(
                        image_index=idx,
                        image_info=ImageInfo(width=0, height=0, dpi=72),
                        annotations=[]
                    )
                except Exception as e:
                    logger.error(f"Task failed for student {student_id}, image {idx}: {e}", exc_info=True)
                    # 存储一个空结果,确保该学生的结果完整
                    if student_id not in student_image_results:
                        student_image_results[student_id] = {}
                    student_image_results[student_id][idx] = SingleImageResult(
                        image_index=idx,
                        image_info=ImageInfo(width=0, height=0, dpi=72),
                        annotations=[]
                    )
        except FuturesTimeoutError:
            # as_completed 总超时,处理未完成的任务
            logger.error(f"Total timeout exceeded, processing remaining {len(future_to_task)} tasks")
            for future, task_info in future_to_task.items():
                if not future.done():
                    student_id, idx, _ = task_info
                    if student_id not in student_image_results:
                        student_image_results[student_id] = {}
                    student_image_results[student_id][idx] = SingleImageResult(
                        image_index=idx,
                        image_info=ImageInfo(width=0, height=0, dpi=72),
                        annotations=[]
                    )

    logger.info(f"Collected results for {len(student_image_results)} students")

    # 第三步:为每个学生计算独立的结果(数据隔离)
    student_results: List[StudentResult] = []
    for student in state.student_homework:
        student_id = student.student_id
        student_name = student.student_name
        homework_image_urls = student.homework_images

        # 获取该学生的图片结果(确保只使用该学生的数据)
        image_results_dict = student_image_results.get(student_id, {})

        # 按顺序组装该学生的图片结果
        image_results: List[SingleImageResult] = []
        for idx in range(len(homework_image_urls)):
            if idx in image_results_dict:
                image_results.append(image_results_dict[idx])
            else:
                # 补充缺失的结果
                image_results.append(SingleImageResult(
                    image_index=idx,
                    image_info=ImageInfo(width=0, height=0, dpi=72),
                    annotations=[]
                ))

        # 计算该学生的总分(只基于该学生的图片结果)
        total_score = 0
        full_score = 0
        total_questions = 0
        correct_count = 0
        incorrect_count = 0

        for result in image_results:
            for annotation in result.annotations:
                total_score += annotation.score
                full_score += annotation.full_score
                total_questions += 1
                if annotation.status == "correct":
                    correct_count += 1
                elif annotation.status == "incorrect":
                    incorrect_count += 1

        # 计算得分率
        score_rate = (total_score / full_score * 100) if full_score > 0 else 0

        # 获取等级标准从state中获取如果没有则使用默认值
        grade_standards = getattr(state, 'grade_standards', {})

        # 使用等级标准计算等级(传入错误数量和总题数)
        grade, grade_description = calculate_grade(score_rate, incorrect_count, total_questions, grade_standards)

        # 生成该学生的整体评价(与等级匹配)
        if total_questions == 0:
            overall_comment = "未识别到题目内容"
            grade = "D"
        elif grade == "A+":
            # A+:全部正确
            overall_comment = f"优秀!{total_questions}题全部正确,掌握扎实,继续保持。"
        elif grade == "A":
            # A全部正确已确保无错误
            overall_comment = f"良好!{total_questions}题全部正确,步骤规范。"
        elif grade == "B":
            # B有少量错误
            if incorrect_count > 0:
                overall_comment = f"合格。得分率{score_rate:.0f}%,错{incorrect_count}题,需加强练习。"
            else:
                overall_comment = f"合格。得分率{score_rate:.0f}%,继续努力。"
        elif grade == "C":
            overall_comment = f"及格。错{incorrect_count}题,部分知识点掌握不牢,建议复习。"
        else:
            overall_comment = f"需努力。得分率{score_rate:.0f}%,错{incorrect_count}题,建议认真复习,多做练习。"

        # 创建该学生的完整结果(数据完全隔离)
        student_result = StudentResult(
            student_id=student_id,
            student_name=student_name,
            total_images=len(image_results),
            image_results=image_results,
            overall_comment=overall_comment,
            total_score=total_score,
            full_score=full_score,
            grade=grade
        )
        student_results.append(student_result)
        logger.info(f"Student {student_id}: {total_score}/{full_score}, grade={grade}")

    logger.info(f"Completed processing {len(student_results)} students")
    return ProcessImagesOutput(student_results=student_results)
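The submit/collect pattern above keys each future by its task tuple, applies a per-task timeout, and groups results by the `student_id` recorded at submission time rather than the one echoed back by the worker. A minimal sketch of that grouping (the dummy `grade_image` stands in for the subgraph call):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def grade_image(student_id, idx, url):
    # stand-in for single_image_subgraph.invoke
    return (student_id, idx, f"result:{url}")

def process_all(tasks, max_workers=4, timeout=5):
    results = {}  # student_id -> {image_index: result}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        future_to_task = {
            pool.submit(grade_image, sid, idx, url): (sid, idx, url)
            for sid, idx, url in tasks
        }
        for future in as_completed(future_to_task, timeout=timeout * len(future_to_task)):
            sid, idx, _ = future_to_task[future]  # trust the submission record
            try:
                _, _, value = future.result(timeout=timeout)
            except Exception:
                value = None  # placeholder keeps the student's result complete
            results.setdefault(sid, {})[idx] = value
    return results

tasks = [(0, 0, "a.jpg"), (0, 1, "b.jpg"), (1, 0, "c.jpg")]
res = process_all(tasks)
print(res[0][1], res[1][0])  # result:b.jpg result:c.jpg
```

Because `as_completed` is iterated only in the main thread, the `results` dict needs no lock; the worker return value never decides which student a result is attached to.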

View File

@ -3,6 +3,7 @@ import os
import json
import re
import logging
import orjson
from typing import List, Dict, Any
from jinja2 import Template
from langchain_core.runnables import RunnableConfig
@ -22,6 +23,30 @@ from graphs.state import (
logger = logging.getLogger(__name__)
# 思考过程的特征词
THINKING_KEYWORDS = [
"不对", "重新看", "可能我", "哦,", "说明我", "这明显",
"不,", "应该是", "我发现", "等等", "让我", "我理解",
"?不", "?不对", "不是", "搞错了", "看错了"
]
def clean_comment(text: str) -> str:
"""清理comment中的思考过程"""
for keyword in THINKING_KEYWORDS:
if keyword in text:
# 找到思考过程开始的位置
idx = text.find(keyword)
if idx > 0:
# 截断思考过程
cleaned = text[:idx]
# 尝试保留完整句子(最后一个句号之后)
last_period = cleaned.rfind("。")
if last_period > 0:
return cleaned[:last_period + 1]
return cleaned
return text
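The keyword scan truncates at the first thinking marker and then backs up to the last full stop, so the emitted comment ends on a complete sentence. A quick illustration of the intended behavior (using a shortened keyword list):

```python
THINKING_KEYWORDS = ["不对", "等等", "让我", "看错了"]

def clean_comment(text):
    for keyword in THINKING_KEYWORDS:
        idx = text.find(keyword)
        if idx > 0:
            cleaned = text[:idx]
            last_period = cleaned.rfind("。")  # keep up to the last complete sentence
            if last_period > 0:
                return cleaned[:last_period + 1]
            return cleaned
    return text

print(clean_comment("计算正确,单位规范。等等,让我重新看一下受力分析"))
# 计算正确,单位规范。
```

The `idx > 0` guard means a comment that *starts* with a marker word is left untouched, matching the original's behavior of only trimming trailing thinking.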
def fix_incomplete_json(text: str) -> str:
"""尝试修复不完整的JSON字符串"""
@ -38,9 +63,64 @@ def fix_incomplete_json(text: str) -> str:
return text
def extract_complete_objects(text: str) -> List[dict]:
"""从JSON文本中提取完整的对象列表处理思考过程干扰"""
objects = []
# 找到每个对象的开始位置
obj_pattern = r'\{\s*"question_id"'
matches = list(re.finditer(obj_pattern, text))
for i, match in enumerate(matches):
start = match.start()
# 找到这个对象的结束位置
brace_count = 0
in_string = False
escape = False
end = start
for j in range(start, len(text)):
char = text[j]
if escape:
escape = False
continue
if char == '\\':
escape = True
continue
if char == '"' and not escape:
in_string = not in_string
continue
if in_string:
continue
if char == '{':
brace_count += 1
elif char == '}':
brace_count -= 1
if brace_count == 0:
end = j + 1
break
if end > start:
obj_str = text[start:end]
try:
obj = orjson.loads(obj_str)
if isinstance(obj, dict) and "question_id" in obj:
objects.append(obj)
logger.debug(f"Extracted object: {obj.get('question_id')}")
except Exception as e:
logger.debug(f"Failed to parse object: {e}")
return objects
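The object scanner walks the text character by character, tracking brace depth while ignoring braces inside strings (with backslash escapes), so it can recover complete objects even when the model interleaves free-form thinking. A stdlib-only sketch of the same scan (using `json` in place of `orjson`):

```python
import json
import re

def extract_complete_objects(text):
    objects = []
    for match in re.finditer(r'\{\s*"question_id"', text):
        start = match.start()
        depth, in_string, escape, end = 0, False, False, start
        for j in range(start, len(text)):
            ch = text[j]
            if escape:
                escape = False
                continue
            if ch == '\\':
                escape = True
                continue
            if ch == '"':
                in_string = not in_string  # braces inside strings don't count
                continue
            if in_string:
                continue
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    end = j + 1
                    break
        if end > start:
            try:
                obj = json.loads(text[start:end])
                if isinstance(obj, dict) and "question_id" in obj:
                    objects.append(obj)
            except ValueError:
                pass  # skip fragments that still fail to parse
    return objects

messy = '先看第一题 {"question_id": "1", "status": "correct"} 嗯,不对,再看 {"question_id": "2", "status": "incorrect"}'
print([o["question_id"] for o in extract_complete_objects(messy)])  # ['1', '2']
```

Anchoring each scan at a `{"question_id"` match is what makes this robust to prose between objects: anything outside a balanced object is simply never parsed.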
def extract_json_from_text(text: str, key: str = "results") -> dict:
"""从文本中提取JSON对象增强健壮性"""
# 清理markdown标记
for prefix in ["```json", "```JSON", "```"]:
@ -52,7 +132,7 @@ def extract_json_from_text(text: str, key: str = "results") -> dict:
text = text.strip()
# 尝试直接解析支持格式化JSON
# 尝试直接解析
for parser in [json.loads, lambda x: orjson.loads(x)]:
try:
result = parser(text)
@ -76,79 +156,11 @@ def extract_json_from_text(text: str, key: str = "results") -> dict:
except:
pass
# 尝试提取第一个完整的JSON对象
results_pattern = r'"results"\s*:\s*\['
match = re.search(results_pattern, text)
if match:
start_pos = text.rfind('{', 0, match.start())
if start_pos == -1:
start_pos = 0
# 从start_pos开始找到匹配的 }
brace_count = 0
bracket_count = 0
in_string = False
escape = False
end_pos = start_pos
for i in range(start_pos, len(text)):
char = text[i]
if escape:
escape = False
continue
if char == '\\':
escape = True
continue
if char == '"' and not escape:
in_string = not in_string
continue
if in_string:
continue
if char == '{':
brace_count += 1
elif char == '}':
brace_count -= 1
if brace_count == 0 and bracket_count == 0:
end_pos = i + 1
break
elif char == '[':
bracket_count += 1
elif char == ']':
bracket_count -= 1
if end_pos > start_pos:
json_str = text[start_pos:end_pos]
for parser in [json.loads, lambda x: orjson.loads(x)]:
try:
result = parser(json_str)
if isinstance(result, dict) and key in result:
logger.info("Successfully extracted and parsed JSON object")
return result
except Exception as e:
logger.debug(f"Failed to parse extracted JSON: {e}")
# 如果提取失败尝试修复不完整的JSON
if end_pos <= start_pos:
# 找到JSON开始位置
json_str = text[start_pos:]
# 尝试修复
fixed_json = fix_incomplete_json(json_str)
for parser in [json.loads, lambda x: orjson.loads(x)]:
try:
result = parser(fixed_json)
if isinstance(result, dict) and key in result:
logger.info("Successfully parsed fixed incomplete JSON")
return result
except:
pass
# 尝试提取完整的对象列表(处理思考过程干扰)
objects = extract_complete_objects(text)
if objects:
logger.info(f"Extracted {len(objects)} complete objects from JSON with thinking")
return {key: objects}
logger.warning(f"Failed to extract JSON with key '{key}' from text length {len(text)}")
return {key: []}
@ -240,7 +252,7 @@ def recognize_and_correct_node(
messages=messages,
model=llm_config.get("model", "doubao-seed-2-0-pro-260215"),
temperature=llm_config.get("temperature", 0.0),
max_completion_tokens=llm_config.get("max_completion_tokens", 4096)
max_completion_tokens=llm_config.get("max_completion_tokens", 8192)
)
response_text = response.content if isinstance(response.content, str) else " ".join(
@ -255,8 +267,12 @@ def recognize_and_correct_node(
result_dict = extract_json_from_text(response_text, "results")
# 相对坐标0-1000转换为绝对坐标
width_scale = image_info.width / 1000.0 if image_info.width > 0 else 1.0
height_scale = image_info.height / 1000.0 if image_info.height > 0 else 1.0
# 严格检查图片尺寸,防止除零和无效坐标
img_width = max(1, image_info.width) if image_info.width > 0 else 1000
img_height = max(1, image_info.height) if image_info.height > 0 else 1000
width_scale = img_width / 1000.0
height_scale = img_height / 1000.0
question_items: List[QuestionItem] = []
correction_results: List[CorrectionResult] = []
@ -275,6 +291,20 @@ def recognize_and_correct_node(
if not isinstance(answer_bbox, list) or len(answer_bbox) != 4:
answer_bbox = [0, 0, 0, 0]
# 确保bbox值在有效范围内0-1000
answer_bbox = [
max(0, min(1000, int(answer_bbox[0]))) if isinstance(answer_bbox[0], (int, float)) else 0,
max(0, min(1000, int(answer_bbox[1]))) if isinstance(answer_bbox[1], (int, float)) else 0,
max(0, min(1000, int(answer_bbox[2]))) if isinstance(answer_bbox[2], (int, float)) else 0,
max(0, min(1000, int(answer_bbox[3]))) if isinstance(answer_bbox[3], (int, float)) else 0
]
# 确保x2 >= x1, y2 >= y1
if answer_bbox[2] < answer_bbox[0]:
answer_bbox[2] = answer_bbox[0]
if answer_bbox[3] < answer_bbox[1]:
answer_bbox[3] = answer_bbox[1]
# 转换为绝对坐标
answer_bbox_abs = [
int(answer_bbox[0] * width_scale),
@ -283,19 +313,42 @@ def recognize_and_correct_node(
int(answer_bbox[3] * height_scale)
]
# 自动计算mark_position
# 严格限制绝对坐标在图片范围内
answer_bbox_abs = [
max(0, min(img_width, answer_bbox_abs[0])),
max(0, min(img_height, answer_bbox_abs[1])),
max(0, min(img_width, answer_bbox_abs[2])),
max(0, min(img_height, answer_bbox_abs[3]))
]
# 自动计算mark_position智能定位紧贴答案框
bbox_width = answer_bbox_abs[2] - answer_bbox_abs[0]
bbox_height = answer_bbox_abs[3] - answer_bbox_abs[1]
if bbox_width > 0 and bbox_height > 0:
mark_x = answer_bbox_abs[2] + 30
mark_y = answer_bbox_abs[1] + int(bbox_height * 0.5)
# 计算答案框右侧的剩余空间
right_space = img_width - answer_bbox_abs[2]
if mark_x > image_info.width - 50:
mark_x = image_info.width - 50
# 策略1如果右侧空间充足>80px标记紧贴答案框右侧
if right_space > 80:
mark_x = answer_bbox_abs[2] + 10
mark_y = answer_bbox_abs[1] + int(bbox_height * 0.5)
# 策略2如果右侧空间不足但有一定空间>40px标记在答案框右上角内部
elif right_space > 40:
mark_x = answer_bbox_abs[2] - 10
mark_y = answer_bbox_abs[1] + 10
# 策略3如果右侧空间很小标记在答案框左上角
else:
mark_x = answer_bbox_abs[0] + 10
mark_y = answer_bbox_abs[1] + 10
# 最终边界检查严格限制在图片范围内留10px边距
mark_x = max(10, min(mark_x, img_width - 10))
mark_y = max(10, min(mark_y, img_height - 10))
else:
mark_x = 500
mark_y = 500
# bbox无效时使用图片中心
mark_x = img_width // 2
mark_y = img_height // 2
mark_position = MarkPosition(x=mark_x, y=mark_y)
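The placement logic degrades gracefully as the space to the right of the answer box shrinks. Sketched standalone (the 80px/40px thresholds and the 10px margin come from the code above):

```python
def compute_mark_position(bbox, img_w, img_h):
    """bbox = [x1, y1, x2, y2] in absolute pixels."""
    x1, y1, x2, y2 = bbox
    w, h = x2 - x1, y2 - y1
    if w <= 0 or h <= 0:
        return img_w // 2, img_h // 2  # invalid box: fall back to image center
    right_space = img_w - x2
    if right_space > 80:        # plenty of room: mark just right of the box
        mx, my = x2 + 10, y1 + h // 2
    elif right_space > 40:      # tight: tuck into the box's top-right corner
        mx, my = x2 - 10, y1 + 10
    else:                       # no room: use the top-left corner instead
        mx, my = x1 + 10, y1 + 10
    # clamp to a 10px margin inside the image
    return max(10, min(mx, img_w - 10)), max(10, min(my, img_h - 10))

print(compute_mark_position([100, 200, 400, 260], 1000, 1500))  # (410, 230)
print(compute_mark_position([100, 200, 940, 260], 1000, 1500))  # (930, 210)
print(compute_mark_position([100, 200, 990, 260], 1000, 1500))  # (110, 210)
```

The final clamp is unconditional, so even a bbox that touches the image edge can never push the mark outside the drawable area.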
@ -316,8 +369,13 @@ def recognize_and_correct_node(
if status not in ["correct", "incorrect", "partial"]:
status = "incorrect"
# 使用原始comment不做截断由LLM控制长度
comment = str(r.get("comment", ""))
# 清理comment中的思考过程
comment_raw = str(r.get("comment", ""))
comment = clean_comment(comment_raw)
# 如果清理后comment为空设置默认值
if not comment:
comment = "正确" if status == "correct" else "错误"
correction_results.append(CorrectionResult(
question_id=str(r.get("question_id", "")),

View File

@ -1,4 +1,4 @@
"""初中物理作业批改工作流状态定义 - 支持多图片批改"""
"""初中物理作业批改工作流状态定义 - 支持多学生多图片批改"""
from typing import List, Optional, Literal
from pydantic import BaseModel, Field
from utils.file.file import File
@ -6,6 +6,13 @@ from utils.file.file import File
# === 基础数据结构 ===
class StudentHomework(BaseModel):
"""学生作业信息"""
student_id: int = Field(..., description="学生ID")
student_name: str = Field(default="", description="学生姓名")
homework_images: List[str] = Field(default=[], description="该学生的作业图片URL列表")
class ImageInfo(BaseModel):
"""图片信息"""
width: int = Field(..., description="图片宽度(像素)")
@ -74,36 +81,37 @@ class SingleImageResult(BaseModel):
annotations: List[Annotation] = Field(default=[], description="该图片的批注列表")
class FinalResult(BaseModel):
"""最终批改结果(多图片汇总)"""
total_images: int = Field(..., description="总图片数")
class StudentResult(BaseModel):
"""单个学生的批改结果"""
student_id: int = Field(..., description="学生ID")
student_name: str = Field(default="", description="学生姓名")
total_images: int = Field(..., description="该学生的总图片数")
image_results: List[SingleImageResult] = Field(default=[], description="各图片的批改结果")
overall_comment: str = Field(default="", description="整体评价")
total_score: int = Field(default=0, description="总分")
full_score: int = Field(default=100, description="满分")
grade: str = Field(default="", description="等级")
overall_comment: str = Field(default="", description="该学生的整体评价")
total_score: int = Field(default=0, description="该学生的总分")
full_score: int = Field(default=100, description="该学生的满分")
grade: str = Field(default="", description="该学生的等级")
# === 全局状态 ===
class GlobalState(BaseModel):
"""工作流全局状态"""
# 输入参数
homework_images: List[File] = Field(default=[], description="上传的作业图片列表")
student_homework: List[StudentHomework] = Field(default=[], description="学生作业列表")
answer_doc_url: str = Field(default="", description="正确答案Word文件的URL.docx格式")
comment_max_length: int = Field(default=100, description="评语最大字数")
max_concurrent: int = Field(default=10, description="并行批改的最大数量")
grade_standards: dict = Field(default={}, description="评价等级标准")
# 中间状态
correct_answers: List[CorrectAnswer] = Field(default=[], description="从Word解析的标准答案列表")
image_results: List[SingleImageResult] = Field(default=[], description="各图片的批改结果列表")
# 最终结果
final_result: Optional[FinalResult] = Field(default=None, description="最终汇总结果")
student_results: List[StudentResult] = Field(default=[], description="各学生的批改结果列表")
# === 图输入输出 ===
class GraphInput(BaseModel):
"""工作流输入"""
homework_images: List[File] = Field(..., description="上传的作业图片列表")
student_homework: List[StudentHomework] = Field(..., description="学生作业列表每个学生包含ID和作业图片")
answer_doc_url: str = Field(default="", description="正确答案Word文件的URL.docx格式可选")
comment_max_length: int = Field(default=100, description="评语最大字数默认100字")
max_concurrent: int = Field(default=10, description="并行批改的最大数量默认10")
@ -121,7 +129,7 @@ class GraphInput(BaseModel):
class GraphOutput(BaseModel):
"""工作流输出"""
final_result: FinalResult = Field(..., description="最终批改结果JSON包含多图片")
student_results: List[StudentResult] = Field(..., description="各学生的批改结果列表")
# === 文档答案解析节点 ===
@ -220,13 +228,14 @@ class ResultMergeOutput(BaseModel):
# === 循环节点 ===
class ProcessImagesInput(BaseModel):
"""图片处理循环节点输入"""
homework_images: List[File] = Field(default=[], description="作业图片列表")
"""学生作业处理循环节点输入"""
student_homework: List[StudentHomework] = Field(default=[], description="学生作业列表")
correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表")
comment_max_length: int = Field(default=100, description="评语最大字数")
max_concurrent: int = Field(default=10, description="并行批改的最大数量")
grade_standards: dict = Field(default={}, description="评价等级标准")
class ProcessImagesOutput(BaseModel):
"""图片处理循环节点输出"""
final_result: FinalResult = Field(..., description="最终批改结果")
"""学生作业处理循环节点输出"""
student_results: List[StudentResult] = Field(..., description="各学生的批改结果列表")

285
src/utils/cache_manager.py Normal file
View File

@ -0,0 +1,285 @@
"""缓存管理器:支持过期时间、内存保护、持久化"""
import os
import json
import hashlib
import logging
import threading
import tempfile
from pathlib import Path
from datetime import datetime, timedelta
from typing import Any, Optional, Dict
from functools import wraps
logger = logging.getLogger(__name__)
# 缓存配置 - 使用 tempfile 获取安全的临时目录
try:
CACHE_DIR = Path(tempfile.gettempdir()) / "homework_cache"
except Exception:
CACHE_DIR = Path("/tmp/homework_cache")
CACHE_EXPIRE_DAYS = 30 # 缓存过期天数
MAX_MEMORY_CACHE_SIZE = 1000 # 内存缓存最大数量
class CacheManager:
"""
缓存管理器内存缓存 + 文件缓存
- 内存缓存快速访问有大小限制
- 文件缓存持久化存储支持过期时间
- 异常安全文件操作失败不影响主流程
"""
def __init__(self, cache_name: str, maxsize: int = MAX_MEMORY_CACHE_SIZE, expire_days: int = CACHE_EXPIRE_DAYS):
"""
初始化缓存管理器
Args:
cache_name: 缓存名称用于区分不同类型的缓存
maxsize: 内存缓存最大数量
expire_days: 缓存过期天数
"""
self.cache_name = cache_name
self.maxsize = maxsize
self.expire_days = expire_days
# 内存缓存(使用字典 + 简单的LRU淘汰
self._memory_cache: Dict[str, Any] = {}
self._cache_keys: list = [] # 记录访问顺序用于LRU淘汰
self._lock = threading.Lock() # 线程安全锁
self._file_cache_enabled = True # 文件缓存是否可用
# 文件缓存目录 - 带异常处理
try:
self.cache_dir = CACHE_DIR / cache_name
self.cache_dir.mkdir(parents=True, exist_ok=True)
# 测试写入权限
test_file = self.cache_dir / ".test_write"
test_file.write_text("test")
test_file.unlink()
logger.info(f"CacheManager initialized: {cache_name}, dir={self.cache_dir}")
except Exception as e:
logger.warning(f"File cache disabled due to permission error: {e}")
self._file_cache_enabled = False
logger.info(f"CacheManager initialized (memory only): {cache_name}")
def _get_cache_key(self, key: str) -> str:
"""生成缓存键使用MD5哈希"""
return hashlib.md5(key.encode()).hexdigest()
def _get_cache_file(self, cache_key: str) -> Path:
"""获取缓存文件路径"""
return self.cache_dir / f"{cache_key}.json"
def _is_expired(self, cache_time: str) -> bool:
"""检查缓存是否过期"""
try:
cached_dt = datetime.fromisoformat(cache_time)
expire_dt = cached_dt + timedelta(days=self.expire_days)
return datetime.now() > expire_dt
except Exception:
return True # 解析失败视为过期
def _evict_lru(self):
"""LRU淘汰移除最久未使用的缓存项"""
while len(self._memory_cache) >= self.maxsize and self._cache_keys:
oldest_key = self._cache_keys.pop(0)
if oldest_key in self._memory_cache:
del self._memory_cache[oldest_key]
logger.debug(f"LRU evicted: {oldest_key[:8]}...")
def get(self, key: str) -> Optional[Any]:
"""
获取缓存
优先级内存缓存 > 文件缓存 > None
Args:
key: 缓存键
Returns:
缓存值不存在或过期返回None
"""
cache_key = self._get_cache_key(key)
# 1. 检查内存缓存(线程安全)
with self._lock:
if cache_key in self._memory_cache:
# 更新访问顺序(移动到末尾)
if cache_key in self._cache_keys:
self._cache_keys.remove(cache_key)
                self._cache_keys.append(cache_key)
                logger.debug(f"Memory cache hit: {cache_key[:8]}...")
                return self._memory_cache[cache_key]

        # 2. Check the file cache (only when file caching is available)
        if self._file_cache_enabled:
            cache_file = self._get_cache_file(cache_key)
            if cache_file.exists():
                try:
                    with open(cache_file, 'r', encoding='utf-8') as f:
                        cached_data = json.load(f)
                    # Check for expiry
                    if self._is_expired(cached_data.get("cache_time", "")):
                        # Expired: delete the file
                        try:
                            cache_file.unlink()
                        except Exception:
                            pass
                        logger.debug(f"File cache expired: {cache_key[:8]}...")
                        return None
                    # Not expired: promote into the memory cache
                    with self._lock:
                        self._evict_lru()  # Evict old entries first
                        self._memory_cache[cache_key] = cached_data["data"]
                        self._cache_keys.append(cache_key)
                    logger.debug(f"File cache hit: {cache_key[:8]}...")
                    return cached_data["data"]
                except Exception as e:
                    logger.warning(f"Failed to read cache file: {e}")
                    # Delete the corrupted cache file
                    try:
                        cache_file.unlink()
                    except Exception:
                        pass
                    return None
        return None

    def set(self, key: str, value: Any):
        """
        Set a cache entry.

        Writes to both the memory cache and the file cache.

        Args:
            key: cache key
            value: value to cache
        """
        cache_key = self._get_cache_key(key)

        # 1. Store in the memory cache
        with self._lock:
            self._evict_lru()  # Evict old entries first
            self._memory_cache[cache_key] = value
            if cache_key in self._cache_keys:
                self._cache_keys.remove(cache_key)
            self._cache_keys.append(cache_key)

        # 2. Store in the file cache (only when file caching is available)
        if self._file_cache_enabled:
            cache_file = self._get_cache_file(cache_key)
            try:
                cached_data = {
                    "cache_time": datetime.now().isoformat(),
                    "data": value
                }
                with open(cache_file, 'w', encoding='utf-8') as f:
                    json.dump(cached_data, f, ensure_ascii=False, indent=2)
                logger.debug(f"Cache saved: {cache_key[:8]}...")
            except Exception as e:
                logger.warning(f"Failed to save cache file: {e}")

    def clear_expired(self) -> int:
        """Remove all expired file-cache entries; returns the number removed."""
        if not self._file_cache_enabled:
            return 0
        cleaned = 0
        try:
            for cache_file in self.cache_dir.glob("*.json"):
                try:
                    with open(cache_file, 'r', encoding='utf-8') as f:
                        cached_data = json.load(f)
                    if self._is_expired(cached_data.get("cache_time", "")):
                        cache_file.unlink()
                        cleaned += 1
                except Exception:
                    # Corrupted files are deleted as well
                    try:
                        cache_file.unlink()
                    except Exception:
                        pass
                    cleaned += 1
            if cleaned > 0:
                logger.info(f"Cleaned {cleaned} expired cache files")
        except Exception as e:
            logger.error(f"Failed to clear expired cache: {e}")
        return cleaned

    def get_stats(self) -> dict:
        """Return cache statistics."""
        memory_size = len(self._memory_cache)
        # Count file-cache entries
        file_size = 0
        if self._file_cache_enabled:
            try:
                file_size = len(list(self.cache_dir.glob("*.json")))
            except Exception:
                pass
        return {
            "cache_name": self.cache_name,
            "memory_cache_size": memory_size,
            "memory_cache_maxsize": self.maxsize,
            "file_cache_size": file_size,
            "file_cache_enabled": self._file_cache_enabled,
            "expire_days": self.expire_days
        }


def cached(cache_manager: CacheManager):
    """
    Caching decorator.

    Usage:
        @cached(answer_doc_cache)
        def parse_answer_doc(url: str):
            # parsing logic
            return result
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Build the cache key from the function name and arguments
            cache_key = f"{func.__name__}:{str(args)}:{str(kwargs)}"
            # Try the cache first
            cached_result = cache_manager.get(cache_key)
            if cached_result is not None:
                return cached_result
            # Cache miss: call the wrapped function
            result = func(*args, **kwargs)
            # Store the result
            if result is not None:
                cache_manager.set(cache_key, result)
            return result
        return wrapper
    return decorator


# Global cache instances
answer_doc_cache = CacheManager(
    cache_name="answer_doc",
    maxsize=MAX_MEMORY_CACHE_SIZE,
    expire_days=CACHE_EXPIRE_DAYS
)

grade_standards_cache = CacheManager(
    cache_name="grade_standards",
    maxsize=100,  # Grading standards need only a small cache
    expire_days=CACHE_EXPIRE_DAYS
)
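The `cached` decorator only assumes an object with `get`/`set`, so it can be exercised without the full `CacheManager`. A minimal sketch, using a hypothetical dict-backed stand-in (`MiniCache` is for illustration only, not part of the module) and the same key scheme as above:

```python
from functools import wraps


class MiniCache:
    """Dict-backed stand-in for CacheManager (demo only)."""
    def __init__(self):
        self._store = {}

    def get(self, key):
        return self._store.get(key)

    def set(self, key, value):
        self._store[key] = value


def cached(cache_manager):
    """Same shape as the module's decorator: key = name + stringified args."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = f"{func.__name__}:{str(args)}:{str(kwargs)}"
            hit = cache_manager.get(cache_key)
            if hit is not None:
                return hit
            result = func(*args, **kwargs)
            if result is not None:  # None results are never cached
                cache_manager.set(cache_key, result)
            return result
        return wrapper
    return decorator


calls = {"n": 0}
demo_cache = MiniCache()


@cached(demo_cache)
def parse_answer_doc(url: str) -> dict:
    calls["n"] += 1  # count real invocations
    return {"url": url, "answers": ["A", "B"]}


first = parse_answer_doc("https://example.com/answer.docx")
second = parse_answer_doc("https://example.com/answer.docx")  # served from cache
print(calls["n"])  # → 1
```

Two properties of this key scheme worth noting: arguments are stringified, so two calls only hit the same entry if their `str(args)`/`str(kwargs)` match exactly, and because `None` results are skipped, a function that legitimately returns `None` is re-executed on every call.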

test_deployment.sh Normal file

@@ -0,0 +1,168 @@
#!/bin/bash
# ============================================
# Junior High Physics Homework Grading Workflow - Deployment Verification Script
# ============================================
# Note: no `set -e` here on purpose — failing checks are counted and
# reported at the end instead of aborting the script.

echo "======================================"
echo "  Deployment Verification Tests"
echo "======================================"
echo ""

# Color definitions
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Test counters
PASS=0
FAIL=0

# Test helper.
# Counters are incremented with arithmetic assignment rather than
# ((PASS++)), which returns a non-zero status when the value is 0.
test_step() {
    local name="$1"
    local command="$2"
    echo -n "Test: $name ... "
    if eval "$command" > /dev/null 2>&1; then
        echo -e "${GREEN}✅ PASS${NC}"
        PASS=$((PASS + 1))
        return 0
    else
        echo -e "${RED}❌ FAIL${NC}"
        FAIL=$((FAIL + 1))
        return 1
    fi
}

# 1. Environment checks
echo "1. Environment checks"
echo "-----------------------------------"
test_step "Python version" "python3 --version"
test_step "Virtual environment" "test -d venv"
test_step "Dependency: fastapi" "python3 -c 'import fastapi'"
test_step "Dependency: langgraph" "python3 -c 'import langgraph'"
test_step "Dependency: openai" "python3 -c 'import openai'"
echo ""

# 2. Configuration checks
echo "2. Configuration checks"
echo "-----------------------------------"
if [ -f ".env" ]; then
    echo -e "${GREEN}✅ .env file exists${NC}"
    PASS=$((PASS + 1))
    # Check required environment variables
    source .env
    if [ -n "$LLM_API_KEY" ] && [ "$LLM_API_KEY" != "your-api-key-here" ]; then
        echo -e "${GREEN}✅ LLM_API_KEY is configured${NC}"
        PASS=$((PASS + 1))
    else
        echo -e "${RED}❌ LLM_API_KEY is not configured${NC}"
        FAIL=$((FAIL + 1))
    fi
    if [ -n "$LLM_BASE_URL" ]; then
        echo -e "${GREEN}✅ LLM_BASE_URL is configured${NC}"
        PASS=$((PASS + 1))
    else
        echo -e "${RED}❌ LLM_BASE_URL is not configured${NC}"
        FAIL=$((FAIL + 1))
    fi
    if [ -n "$LLM_MODEL_NAME" ]; then
        echo -e "${GREEN}✅ LLM_MODEL_NAME is configured${NC}"
        PASS=$((PASS + 1))
    else
        echo -e "${RED}❌ LLM_MODEL_NAME is not configured${NC}"
        FAIL=$((FAIL + 1))
    fi
    echo -e "${GREEN}✅ No object storage configuration required (optimized away)${NC}"
    PASS=$((PASS + 1))
else
    echo -e "${RED}❌ .env file not found${NC}"
    FAIL=$((FAIL + 1))
fi
echo ""

# 3. File integrity checks
echo "3. File integrity checks"
echo "-----------------------------------"
test_step "Main entry point" "test -f src/main.py"
test_step "Main workflow" "test -f src/graphs/graph.py"
test_step "State definitions" "test -f src/graphs/state.py"
test_step "Config directory" "test -d config"
test_step "Startup script" "test -f scripts/http_run.sh"
echo ""

# 4. Module import tests
# PYTHONPATH=src lets the `graphs.*` packages resolve from the src/ layout.
echo "4. Module import tests"
echo "-----------------------------------"
test_step "Import main module" "PYTHONPATH=src python3 -c 'from graphs.state import GlobalState'"
test_step "Import node module" "PYTHONPATH=src python3 -c 'from graphs.nodes.doc_extract_node import doc_extract_node'"
echo ""

# 5. Service startup test (optional)
echo "5. Service startup test"
echo "-----------------------------------"
if command -v curl &> /dev/null; then
    # Check whether the service is already running
    if curl -s http://localhost:8000/health > /dev/null 2>&1; then
        echo -e "${GREEN}✅ Service is running at http://localhost:8000${NC}"
        PASS=$((PASS + 1))
        # Health check
        HEALTH=$(curl -s http://localhost:8000/health)
        if echo "$HEALTH" | grep -q "ok"; then
            echo -e "${GREEN}✅ Health check passed${NC}"
            PASS=$((PASS + 1))
        else
            echo -e "${RED}❌ Health check failed${NC}"
            FAIL=$((FAIL + 1))
        fi
    else
        echo -e "${YELLOW}⚠️  Service not running, skipping service tests${NC}"
        echo "   Start it with: bash scripts/http_run.sh -p 8000"
    fi
else
    echo -e "${YELLOW}⚠️  curl not installed, skipping service tests${NC}"
fi
echo ""

# Summary
echo "======================================"
echo "  Test Summary"
echo "======================================"
echo ""
echo -e "${GREEN}Passed: $PASS${NC}"
echo -e "${RED}Failed: $FAIL${NC}"
echo ""

if [ $FAIL -eq 0 ]; then
    echo -e "${GREEN}✅ All tests passed! Deployment succeeded!${NC}"
    echo ""
    echo "Next steps:"
    echo "  1. Start the service: bash scripts/http_run.sh -p 8000"
    echo "  2. Open the API docs: http://localhost:8000/docs"
    echo "  3. Send a test request: curl -X POST http://localhost:8000/run -H 'Content-Type: application/json' -d @test_payload.json"
    exit 0
else
    echo -e "${RED}❌ Some tests failed, please check your configuration${NC}"
    echo ""
    echo "Common issues:"
    echo "  - Environment variables missing: edit the .env file"
    echo "  - Dependencies missing: pip install -r requirements.txt"
    echo "  - Files missing: verify project integrity"
    exit 1
fi

test_payload.json Normal file

@@ -0,0 +1,16 @@
{
"student_homework": [
{
"student_id": 0,
"student_name": "Test Student",
"homework_images": [
"https://example.com/homework1.jpg",
"https://example.com/homework2.jpg"
]
}
],
"answer_doc_url": "https://example.com/answer.docx",
"subject": "physics",
"comment_max_length": 100,
"max_concurrent": 10
}
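Before POSTing a payload like the one above to the service, it can help to check that it matches the input schema from the project overview. A minimal validation sketch — the field names come from the schema, while the `validate_payload` helper itself is illustrative, not part of the project:

```python
import json

# Required fields per student entry, per the input schema
REQUIRED_STUDENT_FIELDS = {"student_id", "student_name", "homework_images"}


def validate_payload(payload: dict) -> list:
    """Return a list of problems; an empty list means the payload looks sendable."""
    problems = []
    students = payload.get("student_homework")
    if not isinstance(students, list) or not students:
        problems.append("student_homework must be a non-empty list")
        return problems
    for i, student in enumerate(students):
        missing = REQUIRED_STUDENT_FIELDS - set(student)
        if missing:
            problems.append(f"student[{i}] missing fields: {sorted(missing)}")
        elif not student["homework_images"]:
            problems.append(f"student[{i}] has no homework images")
    return problems


payload = json.loads("""
{
  "student_homework": [
    {"student_id": 0, "student_name": "Test Student",
     "homework_images": ["https://example.com/homework1.jpg"]}
  ],
  "answer_doc_url": "https://example.com/answer.docx",
  "subject": "physics",
  "comment_max_length": 100,
  "max_concurrent": 10
}
""")
print(validate_payload(payload))  # → []
```

A payload that passes can then be sent as shown in the deployment script's next steps, e.g. `curl -X POST http://localhost:8000/run -H 'Content-Type: application/json' -d @test_payload.json`.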