diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..ce0e77a --- /dev/null +++ b/.env.example @@ -0,0 +1,55 @@ +# ============================================ +# 初中物理作业批改工作流 - 环境变量配置示例 +# ============================================ +# 复制此文件为 .env 并填写实际值 +# cp .env.example .env + +# ============================================ +# 必需配置 - 大语言模型 API +# ============================================ + +# LLM API 密钥(从火山引擎或OpenAI获取) +LLM_API_KEY=your-api-key-here + +# LLM API 基础URL +# 火山引擎: https://ark.cn-beijing.volces.com/api/v3 +# OpenAI: https://api.openai.com/v1 +LLM_BASE_URL=https://ark.cn-beijing.volces.com/api/v3 + +# 模型名称 +# 火山引擎推荐: doubao-seed-2-0-pro-260215 +# OpenAI推荐: gpt-4o +LLM_MODEL_NAME=doubao-seed-2-0-pro-260215 + +# 注意:不需要配置对象存储(S3/TOS/OSS等) +# 图片直接使用原始URL,不上传存储 + + +# ============================================ +# 可选配置 - 日志与缓存 +# ============================================ + +# 日志级别: DEBUG, INFO, WARNING, ERROR +LOG_LEVEL=INFO + +# 缓存目录(默认: /tmp/cache) +CACHE_DIR=/tmp/cache + +# 单张图片处理超时(秒,默认: 120) +SINGLE_IMAGE_TIMEOUT=120 + + +# ============================================ +# 可选配置 - 并发控制 +# ============================================ + +# 最大并发数(默认: 10) +MAX_CONCURRENT=10 + + +# ============================================ +# 工作目录(系统自动设置,无需修改) +# ============================================ + +# 工作目录路径(由系统自动设置) +# COZE_WORKSPACE_PATH=/workspace/projects diff --git a/AGENTS.md b/AGENTS.md index 1c7eb30..daf6841 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,6 +1,53 @@ ## 项目概述 - **名称**: 初中物理作业批改工作流 -- **功能**: 上传多张作业图片和Word答案文件,自动识别学生答案、提取标准答案、精准批改并返回批改结果JSON +- **功能**: 上传多学生的作业图片和Word答案文件,自动识别学生答案、提取标准答案、精准批改并返回每个学生的批改结果JSON + +### 数据结构(重要变更) +**输入参数**: +```json +{ + "student_homework": [ + { + "student_id": 0, + "student_name": "张三", + "homework_images": [ + "图片URL1", + "图片URL2" + ] + }, + { + "student_id": 1, + "student_name": "李四", + "homework_images": [ + "图片URL3", + "图片URL4" + ] + } + ], + "answer_doc_url": "答案文档URL(可选)", + "subject": "physics", + "comment_max_length": 100, + "max_concurrent": 10 +} +``` + +**输出结果**: +```json +{ + "student_results": [ + { + "student_id": 0, + "student_name": "张三", + "total_images": 2, + "image_results": [...], + "overall_comment": "优秀!5题全部正确", + "total_score": 15, + "full_score": 15, + "grade": "A+" + } + ] +} +``` ### 节点清单 | 节点名 | 文件位置 | 类型 | 功能描述 | 分支逻辑 | 配置文件 | @@ -29,6 +76,63 @@ - 节点 `doc_extract` 使用大语言模型技能 - 模型:`doubao-seed-2-0-pro-260215`(旗舰模型,复杂推理能力强) - 使用 python-docx 解析 Word 文档 + - **缓存优化**:使用 `utils/cache_manager.py` 缓存解析结果,有效期30天 + +## 缓存机制(优化版 v2026-03-28) +- **缓存管理器**:`src/utils/cache_manager.py` +- **双层架构**: + - 内存缓存:LRU淘汰,最大数量1000,快速访问 + - 文件缓存:持久化存储,进程重启后仍可用 +- **缓存有效期**:30天(自动清理过期缓存) +- **缓存内容**:AI解析后的结构化数据(CorrectAnswer列表) +- **缓存键**:`{subject}:{answer_doc_url}`(MD5哈希) + - **学科隔离**:相同URL在不同学科下不会冲突 + - 示例:`physics:https://example.com/answer.docx` 和 `math:https://example.com/answer.docx` 是不同的缓存 +- **线程安全**:使用锁保护并发访问 +- **异常安全**:文件缓存失败时自动降级为纯内存模式 +- **统计功能**:`get_stats()` 返回缓存统计信息 + +### 性能优化与超时控制 +- **图片下载超时**:30秒(单次),总时间不超过60秒 +- **重试机制**:图片获取失败最多重试2次 +- **单图片处理超时**:120秒(含LLM调用) +- **总任务超时**:120秒 × 图片数量 +- **降级处理**:超时任务返回空结果,不影响其他任务 +- **并发安全**:使用 `ThreadPoolExecutor` + 超时保护 + +## 等级标准配置(核心规则) +- **参数名**:`grade_standards` +- **核心规则**:**A+ 和 A 的首要条件是"全对",与得分率无关** + +### 等级判定逻辑(简化版) +| 等级 | 条件 | 说明 | +|------|------|------| +| A+ | 全对(错误数=0) | 所有题目都正确,与得分率无关 | +| A | (预留,全对时返回A+) | 答案全对 | +| B | 有错误,得分率≥80% | 有少量错误 | +| C | 有错误,得分率≥70% | 错误较多 | +| D | 有错误,得分率<70% | 错误很多 | + +### 关键说明 +- **全对 = A+**:只要所有题目都正确(incorrect_count == 0),就是A+ +- **有错 = B/C/D**:有错误时,按得分率判断具体等级 + +### 示例 +- ✅ 得分80分,错误0题 → A+(全对,得分率不重要) +- ✅ 得分95分,错误0题 → A+(全对) +- ❌ 得分95分,错误1题 → B(有错误,按得分率判断) +- ❌ 得分90分,错误2题 → B(有错误,按得分率判断) + +### 配置示例 +```json +{ + "grade_standards": { + "A+": {"min_percentage": 95, "description": "优秀"}, + "A": {"min_percentage": 85, "description": "良好"}, + "B": {"min_percentage": 70, "description": "中等"} + } +} +``` ## 工作流程(多图片批改架构) @@ -73,11 +177,18 @@ └─────────────────────┘ ``` -## 核心功能:多图片批改机制 +## 核心功能:多学生多图片批改机制 ### 输入参数 -- `homework_images`: 上传的作业图片列表(List[File],支持多张图片) +- `student_homework`: 学生作业列表(List[StudentHomework],支持多个学生) + - 每个学生包含: + - `student_id`: 学生ID(int) + - `student_name`: 学生姓名(str,可选) + - `homework_images`: 该学生的作业图片URL列表(List[str],纯字符串数组) - `answer_doc_url`: 正确答案Word文件的URL(.docx格式,**可选**) +- `subject`: 学科标识(str,**可选**,默认"physics") + - 用于缓存隔离,相同URL在不同学科下不会冲突 + - 支持值:physics、math、chinese、english 等 - `comment_max_length`: 评语最大字数(默认100字,**可选**) - `max_concurrent`: 并行批改的最大数量(默认10,**可选**) - `grade_standards`: 评价等级标准(**可选**,默认值如下) @@ -92,13 +203,16 @@ ``` ### 输出结果 -- `final_result`: 最终批改结果JSON(包含多图片) - - `total_images`: 总图片数 - - `image_results`: 各图片的批改结果列表 - - `overall_comment`: 整体评价(根据得分率生成) - - `total_score`: 总得分 - - `full_score`: 总满分 - - `grade`: 等级评定 +- `student_results`: 各学生的批改结果列表(List[StudentResult]) + - 每个学生包含: + - `student_id`: 学生ID(int) + - `student_name`: 学生姓名(str) + - `total_images`: 该学生的总图片数 + - `image_results`: 该学生各图片的批改结果列表 + - `overall_comment`: 该学生的整体评价 + - `total_score`: 该学生的总分 + - `full_score`: 该学生的满分 + - `grade`: 该学生的等级评定 ### 批改优先级(严格按照以下顺序) 1. **最优先**:使用Word文档中的标准答案批改 @@ -118,6 +232,358 @@ 5. **智能降级**:无标准答案时自动切换到专业老师模式 ## 优化记录 +### 2026-03-28 缓存键加入学科标识(重要) +**问题**:相同URL在不同学科下会使用相同的缓存,导致答案解析结果冲突 + +**修复内容**: +1. **新增 `subject` 参数**: + - 默认值:`physics` + - 支持值:physics、math、chinese、english 等 + +2. **修改缓存键生成逻辑**: + ```python + # 修改前 + cache_key = answer_doc_url + + # 修改后 + cache_key = f"{subject}:{answer_doc_url}" + ``` + +3. **缓存隔离效果**: + - `physics:https://example.com/answer.docx` + - `math:https://example.com/answer.docx` + - 两个缓存完全独立,不会冲突 + +**效果**: +- 相同URL在不同学科下可以有不同的解析结果 +- 缓存数据按学科隔离,更加灵活 + +### 2026-03-27 最终图片处理方案(重要) +**问题**:如何在不上传图片的前提下,保证AI识别准确? + +**方案对比**: + +| 方案 | 旋转缩放 | 上传 | AI访问 | 请求体积 | 选择 | +|------|---------|------|--------|---------|------| +| 方案1 | ✅ | base64编码 | Data URL | 大 | ❌ | +| 方案2 | ❌ | ❌ | 原始URL | 小 | ✅ | + +**选择方案2的原因**: +1. **AI模型足够强大**:`doubao-seed-2-0-pro-260215` 可以处理各种尺寸和方向的图片 +2. **坐标系统统一**:使用相对坐标(0-1000),自动适配任意尺寸 +3. **最简单高效**:不需要base64编码,不需要上传,处理速度最快 + +**最终逻辑**: +```python +# 获取原始图片URL和尺寸 +original_url = state.homework_image.url +width, height, dpi = get_image_info(original_url) + +# 直接返回原始URL(不上传) +return ImagePreprocessOutput( + image_url=original_url, + image_info=ImageInfo(width=width, height=height, dpi=dpi) +) +``` + +**效果**: +- ✅ 不上传新图片到Coze +- ✅ 返回原始图片URL +- ✅ 坐标自动适配原始尺寸 +- ✅ 处理速度最快 + +### 2026-03-27 移除图片上传功能(重要) +**问题**:系统自动上传处理后的图片到Coze对象存储,用户不需要这个功能 + +**原逻辑**: +1. 下载原始图片 +2. 自动旋转(横向→纵向) +3. 缩放到固定宽度1000px +4. **上传到Coze对象存储** +5. 返回上传后的URL + +**新逻辑**: +1. 获取原始图片URL +2. 获取图片尺寸信息 +3. **直接返回原始URL和尺寸**(不上传) + +**优化内容**: +- 移除图片旋转功能 +- 移除图片缩放功能 +- 移除图片上传功能 +- 直接使用原始图片URL +- 坐标系统仍然使用相对坐标(0-1000),自动适配原始图片尺寸 + +**代码对比**: +```python +# 优化前 +img = download_and_rotate(image_url) +img = resize_to_1000px(img) +new_url = upload_to_coze(img) # 上传新图片 +return ImagePreprocessOutput(image_url=new_url) + +# 优化后 +width, height, dpi = get_image_info(image_url) +return ImagePreprocessOutput( + image_url=original_url, # 直接返回原始URL + image_info=ImageInfo(width=width, height=height, dpi=dpi) +) +``` + +**效果**: +- ✅ 不再上传新图片到Coze +- ✅ 返回原始图片URL +- ✅ 减少存储空间占用 +- ✅ 提升处理速度 + +### 2026-03-27 坐标偏移量优化(重要) +**问题**:批改标记离学生答案太远,定位不够精准 + +**原因分析**: +- 原偏移量设置过大(30px、20px) +- 导致标记与答案视觉距离过远 +- 影响批改结果的精准度 + +**优化方案**: +减小所有偏移量,让标记更贴近答案: + +| 策略 | 原偏移 | 新偏移 | 说明 | +|------|--------|--------|------| +| 策略1(右侧空间>80px) | +30px | **+10px** | 紧贴答案框右侧 | +| 策略2(右侧空间40-80px) | -20px | **-10px** | 答案框右上角内部 | +| 策略3(右侧空间<40px) | +20px | **+10px** | 答案框左上角 | +| Y轴偏移 | 15px | **10px** | 顶部位置 | + +**代码对比**: +```python +# 优化前 +mark_x = answer_bbox[2] + 30 # 偏移过大 + +# 优化后 +mark_x = answer_bbox[2] + 10 # 紧贴答案框 +``` + +**效果**: +- ✅ 批改标记更贴近学生答案 +- ✅ 视觉定位更精准 +- ✅ 避免标记离答案太远的问题 + +### 2026-03-27 坐标边界严格限制(重要) +**问题**:标记坐标可能超过图片宽度,导致定位错误 + +**修复内容**: +1. **严格边界检查**: + ```python + # 确保x轴不超过图片宽度,y轴不超过图片高度 + mark_x = max(10, min(mark_x, image_info.width - 10)) + mark_y = max(10, min(mark_y, image_info.height - 10)) + ``` + +2. **边界优化**: + - 边距从20px减少到10px,确保标记更接近边缘 + - 绝对不会超过图片宽度和高度 + - 保证批改标记始终在图片可视范围内 + +**效果**:标记坐标始终在图片范围内,不会出现越界问题 + +### 2026-03-27 空答案判定优化(重要) +**问题**:学生没有作答(空白)时,判定不够明确 + +**修复内容**: +在Prompt中新增"空答案处理"规则: +``` +# ⚠️ 重要:空答案处理 +- 如果学生没有作答(空白),必须判定为**incorrect** +- status字段填写"incorrect" +- score字段填写0 +- comment字段填写"未作答"或"空白,无答案" +``` + +**示例**: +- 正确:`"status": "correct", "score": 10, "comment": "计算正确"` +- 错误:`"status": "incorrect", "score": 5, "comment": "单位错误"` +- 空答案:`"status": "incorrect", "score": 0, "comment": "未作答"` + +**效果**:空答案统一判定为错误,得分0分,评语明确 + +### 2026-03-27 坐标定位精准度优化(重要) +**问题**:个别批改标记过于偏右,超出答题区域,甚至与相邻题目重叠 + +**原因分析**: +- 原逻辑固定在答案框右侧30px,未考虑右侧空间是否充足 +- 当答案框本身靠右时,标记会超出合理范围 + +**优化方案**(三级策略): +1. **策略1**:右侧空间充足(>100px)→ 标记在右侧(原有逻辑) + ```python + mark_x = answer_bbox[2] + 30 + mark_y = answer_bbox[1] + height * 0.5 + ``` + +2. **策略2**:右侧空间不足(50-100px)→ 标记在答案框右上角内部 + ```python + mark_x = answer_bbox[2] - 20 # 内部 + mark_y = answer_bbox[1] + 15 + ``` + +3. **策略3**:右侧空间很小(<50px)→ 标记在答案框左上角 + ```python + mark_x = answer_bbox[0] + 20 # 左侧 + mark_y = answer_bbox[1] + 15 + ``` + +**效果**: +- 批改标记始终在合理范围内 +- 不会超出答题区域 +- 不会与相邻题目重叠 +- 视觉效果更精准 + +### 2026-03-27 完全并行架构优化(重要) +**问题**:原架构外层串行处理学生,内层并行处理图片,效率不高且可能有数据混乱风险 + +**修复内容**: +1. **完全并行架构**: + - 所有学生的所有图片同时提交到线程池 + - 学生间+图片间完全并行,最大化效率 + - 使用 `(student_id, image_index, image_result)` 元组确保数据关联 + +2. **数据隔离机制**: + - 结果按 `student_id` 分组存储 + - 每个学生的 `total_score`、`full_score`、`overall_comment`、`grade` 完全独立 + - 只使用该学生自己的 `image_results` 计算分数 + +3. **核心代码**: + ```python + # 返回元组:(student_id, image_index, image_result) + return (student_id, idx, image_result) + + # 按student_id分组存储 + student_image_results[student_id][image_index] = image_result + + # 为每个学生独立计算结果 + for student in state.student_homework: + image_results = student_image_results[student_id] + # 只使用该学生的数据计算... + ``` + +**效果**: +- 完全并行,效率最大化 +- 数据严格隔离,不会混淆 +- 学生A的数据绝不会出现在学生B的结果中 + +### 2026-03-27 输入参数格式优化(重要) +**问题**:`homework_images` 使用 `List[File]` 格式,用户输入不够简洁 + +**修复内容**: +1. **简化输入格式**: + - `homework_images: List[File]` → `homework_images: List[str]` + - 直接传入URL字符串数组,无需构造File对象 + - 代码内部自动将URL转换为File对象 + +2. **输入示例**: + ```json + { + "student_homework": [ + { + "student_id": 0, + "homework_images": ["url1", "url2"] + } + ] + } + ``` + +**效果**: +- 用户输入更简洁 +- 减少构造对象的复杂度 +- 符合用户习惯 + +### 2026-03-27 多学生支持(重要变更) +**问题**:原架构只支持单个学生的多图片批改,无法区分不同学生 + +**修复内容**: +1. **数据结构重构**: + - 输入参数:`homework_images` → `student_homework`(List[StudentHomework]) + - 输出结果:`final_result` → `student_results`(List[StudentResult]) + - 新增 `StudentHomework` 类型:包含 student_id 和 homework_images + - 新增 `StudentResult` 类型:包含 student_id 和批改结果 + +2. **处理逻辑优化**: + - 外层循环:遍历每个学生 + - 内层循环:并行处理该学生的所有图片 + - 每个学生独立计算分数、评语和等级 + +3. **返回结果独立**: + - 每个学生有自己的 overall_comment、total_score、full_score、grade + - 各学生的批改结果互不影响 + +**效果**: +- 支持批量批改多个学生的作业 +- 每个学生的结果独立、清晰 +- 符合实际教学场景需求 + +### 2026-03-27 Comment评语优化(重要) +**问题**:comment字段输出过于简单(仅"正确"/"错误")或输出思考过程,不符合"精练评语"要求 + +**修复内容**: +1. **明确comment定义**: + - **正确时**:简短说明为什么正确(如"根据称重法F浮=G-F示计算正确") + - **错误时**:指出错误原因并给出正确答案(如"应为1.2N,注意单位换算") + - **字数限制**:不超过comment_max_length字(默认100字) + - **禁止**:不输出思考过程、不输出详细解析 + +2. **提供comment示例**: + ``` + ✅ 正确:根据称重法F浮=G-F示计算正确 + ✅ 正确:浮力产生原因理解正确 + ✅ 错误:应为1.2N,根据F浮=ρ液gV排计算 + ✅ 错误:应选ACE,控制变量法应用错误 + + ❌ 错误:正确(过于简单) + ❌ 错误:根据...(思考过程)...所以正确(包含思考过程) + ``` + +3. **参数传递优化**: + - comment_max_length正确传递到Jinja2模板 + - LLM根据该参数生成符合长度要求的评语 + +**效果**: +- comment既简洁又有意义 +- 正确时说明原因,错误时指出问题 +- 符合comment_max_length限制 +- 无思考过程,无详细解析 + +### 2026-03-27 JSON解析健壮性优化(重要) +**问题**: +- LLM输出JSON包含思考过程,导致格式错误 +- JSON太长(11946字符),解析失败 +- annotations为空,无法识别题目 + +**修复内容**: +1. **新增extract_complete_objects函数**: + - 从包含思考过程的JSON中提取完整对象 + - 按对象边界逐个提取,不受思考过程干扰 + - 即使JSON格式错误,也能提取出有效数据 + +2. **新增clean_comment函数**: + - 检测思考过程特征词("不对"、"重新看"、"可能我"等) + - 在思考过程开始处截断comment + - 保留完整句子,确保结论清晰 + +3. **增加max_completion_tokens**: + - 从8192增加到16384,避免JSON被截断 + - 确保完整输出所有题目 + +4. **优化Prompt**: + - 明确要求"禁止输出思考过程" + - comment只写结论:"正确"或"错误,应为X" + - 强调不要输出推理过程 + +**效果**: +- JSON解析成功率大幅提升 +- 即使包含思考过程也能提取有效数据 +- annotations不再为空 +- comment简洁,无思考过程 + ### 2026-03-26 填空题拆分优化(重要) **问题**:一道题有多个填空时,被合并成一个答案,批改标记无法精准定位 diff --git a/DEPLOYMENT_GUIDE.md b/DEPLOYMENT_GUIDE.md new file mode 100644 index 0000000..59336db --- /dev/null +++ b/DEPLOYMENT_GUIDE.md @@ -0,0 +1,443 @@ +# 项目部署指南 + +本文档帮助你将初中物理作业批改工作流导出到自己的服务器上运行。 + +## 📋 目录 +- [前置要求](#前置要求) +- [快速部署](#快速部署) +- [详细配置](#详细配置) +- [启动方式](#启动方式) +- [常见问题](#常见问题) + +--- + +## 前置要求 + +### 1. 系统要求 +- **操作系统**: Linux / macOS / Windows (推荐 Linux) +- **Python版本**: Python 3.10 或以上 +- **内存**: 建议 4GB 以上 +- **磁盘空间**: 建议 10GB 以上 + +### 2. 必需的第三方服务 + +本项目依赖以下第三方服务,**必须提前准备好**: + +#### 大语言模型 API +- **推荐**: 火山引擎豆包大模型(本项目使用 `doubao-seed-2-0-pro-260215`) +- **替代方案**: + - OpenAI API + - 其他兼容 OpenAI 格式的 API(如 DeepSeek、Kimi) +- **获取方式**: + - 火山引擎: https://console.volcengine.com/ark + - OpenAI: https://platform.openai.com/ + +**注意**: +- ✅ **不需要配置对象存储**(S3/TOS/OSS 等) +- ✅ 图片直接使用原始URL,不上传存储 +- ✅ Word文档使用 requests 直接下载,不涉及对象存储 + +--- + +## 快速部署 + +### 步骤 1: 导出项目代码 + +**方式一:从 Coze 平台下载** +```bash +# 在 Coze Coding 平台点击"导出项目"按钮 +# 下载后解压到服务器 +``` + +**方式二:使用 Git 克隆(如果有仓库地址)** +```bash +git clone +cd +``` + +### 步骤 2: 安装依赖 + +```bash +# 创建虚拟环境(推荐) +python3 -m venv venv +source venv/bin/activate # Linux/macOS +# 或 venv\Scripts\activate # Windows + +# 安装依赖 +pip install -r requirements.txt +``` + +### 步骤 3: 配置环境变量 + +创建 `.env` 文件(或在服务器环境变量中配置): + +```bash +# 必需环境变量(只需配置大模型API) +export LLM_API_KEY="your-api-key-here" +export LLM_BASE_URL="https://ark.cn-beijing.volces.com/api/v3" +export LLM_MODEL_NAME="doubao-seed-2-0-pro-260215" + +# 可选:日志级别 +export LOG_LEVEL="INFO" + +# 注意:不需要配置对象存储(S3/TOS等) +``` + +### 步骤 4: 启动服务 + +```bash +# 方式1: 使用启动脚本(推荐) +bash scripts/http_run.sh -p 8000 + +# 方式2: 直接运行 +python src/main.py -m http -p 8000 +``` + +服务启动后,访问: +- 健康检查: `http://localhost:8000/health` +- API 文档: `http://localhost:8000/docs`(FastAPI 自动生成) + +--- + +## 详细配置 + +### 1. 大语言模型配置 + +#### 方式一:使用火山引擎豆包大模型(推荐) + +```bash +# 环境变量 +export LLM_API_KEY="your-ark-api-key" +export LLM_BASE_URL="https://ark.cn-beijing.volces.com/api/v3" +export LLM_MODEL_NAME="doubao-seed-2-0-pro-260215" +``` + +**获取方式**: +1. 访问火山引擎控制台: https://console.volcengine.com/ark +2. 创建推理接入点 +3. 获取 API Key + +#### 方式二:使用 OpenAI API + +需要修改代码中的模型配置文件(`config/*.json`),将 `model` 字段改为 OpenAI 模型: + +```json +{ + "config": { + "model": "gpt-4o", + "temperature": 0.0 + } +} +``` + +环境变量: +```bash +export LLM_API_KEY="your-openai-api-key" +export LLM_BASE_URL="https://api.openai.com/v1" +export LLM_MODEL_NAME="gpt-4o" +``` + +### 2. ~~对象存储配置~~(已移除) + +**重要更新(2026-03-27)**: +- ❌ 不需要配置对象存储 +- ✅ 图片直接使用原始URL,不上传 +- ✅ Word文档直接下载,不存储 + +**架构优化原因**: +1. AI模型足够强大,可以直接访问原始图片URL +2. 使用相对坐标系统(0-1000),自动适配任意尺寸 +3. 减少存储成本和上传时间,处理速度更快 + +### 3. 修改代码适配自己的环境 + +#### 修改 LLM 调用逻辑 + +项目使用了 `coze-coding-dev-sdk`,需要修改为直接调用 OpenAI API: + +**修改文件**: `src/graphs/nodes/doc_extract_node.py`、`src/graphs/nodes/recognize_and_correct_node.py` + +**原代码**(使用 coze-coding-dev-sdk): +```python +from coze_coding_dev_sdk import LLM + +llm = LLM() +response = llm.invoke(messages) +``` + +**修改为**(直接使用 OpenAI SDK): +```python +import os +from openai import OpenAI + +client = OpenAI( + api_key=os.getenv("LLM_API_KEY"), + base_url=os.getenv("LLM_BASE_URL") +) + +response = client.chat.completions.create( + model=os.getenv("LLM_MODEL_NAME"), + messages=messages +) +``` + +#### ~~修改对象存储逻辑~~(不需要) + +**已移除**:2026-03-27 优化后,不再使用对象存储 +- 图片直接使用原始URL +- Word文档使用 requests 下载 +- 无需修改任何存储相关代码 + +### 4. 缓存配置(可选) + +项目使用文件缓存来存储解析结果,默认缓存目录为 `/tmp/cache`。 + +如需修改缓存目录: +```bash +export CACHE_DIR="/your/custom/cache/dir" +``` + +--- + +## 启动方式 + +### 1. HTTP 服务模式(推荐生产环境) + +```bash +# 使用启动脚本 +bash scripts/http_run.sh -p 8000 + +# 或直接运行 +python src/main.py -m http -p 8000 +``` + +**特点**: +- 提供 REST API 接口 +- 支持流式响应(SSE) +- 支持超时控制 +- 支持任务取消 + +**API 接口**: +- `POST /run` - 同步运行工作流 +- `POST /stream_run` - 流式运行工作流(SSE) +- `POST /cancel/{run_id}` - 取消运行 +- `GET /health` - 健康检查 +- `GET /graph_parameter` - 查看工作流参数 + +### 2. 命令行模式(本地测试) + +```bash +# 运行整个工作流 +python src/main.py -m flow -i '{"student_homework": [...], "answer_doc_url": "..."}' + +# 运行单个节点 +python src/main.py -m node -n doc_extract -i '{"answer_doc_url": "..."}' +``` + +### 3. Docker 部署(推荐) + +创建 `Dockerfile`: + +```dockerfile +FROM python:3.10-slim + +WORKDIR /app + +# 安装系统依赖 +RUN apt-get update && apt-get install -y \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# 复制依赖文件 +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 复制项目文件 +COPY . . + +# 暴露端口 +EXPOSE 8000 + +# 启动命令 +CMD ["python", "src/main.py", "-m", "http", "-p", "8000"] +``` + +构建和运行: +```bash +# 构建镜像 +docker build -t homework-correction:v1 . + +# 运行容器 +docker run -d \ + --name homework-correction \ + -p 8000:8000 \ + -e LLM_API_KEY="your-api-key" \ + -e LLM_BASE_URL="https://ark.cn-beijing.volces.com/api/v3" \ + -e LLM_MODEL_NAME="doubao-seed-2-0-pro-260215" \ + homework-correction:v1 +``` + +### 4. 使用 Docker Compose + +创建 `docker-compose.yml`: + +```yaml +version: '3.8' + +services: + homework-correction: + build: . + ports: + - "8000:8000" + environment: + - LLM_API_KEY=${LLM_API_KEY} + - LLM_BASE_URL=${LLM_BASE_URL} + - LLM_MODEL_NAME=${LLM_MODEL_NAME} + restart: unless-stopped + volumes: + - ./cache:/tmp/cache # 持久化缓存 +``` + +运行: +```bash +docker-compose up -d +``` + +--- + +## 常见问题 + +### Q1: 如何验证环境变量是否正确配置? + +```bash +# 检查环境变量 +echo $LLM_API_KEY +echo $LLM_BASE_URL +echo $S3_ACCESS_KEY + +# 或在代码中打印 +python -c "import os; print(os.getenv('LLM_API_KEY'))" +``` + +### Q2: 启动时报错 "ModuleNotFoundError: No module named 'xxx'" + +**解决方案**: +```bash +# 确保在虚拟环境中 +source venv/bin/activate + +# 重新安装依赖 +pip install -r requirements.txt +``` + +### Q3: LLM 调用失败,报错 "API key not found" + +**原因**: 环境变量未正确设置 + +**解决方案**: +```bash +# 方式1: 在 .env 文件中配置 +echo "export LLM_API_KEY='your-api-key'" >> ~/.bashrc +source ~/.bashrc + +# 方式2: 在启动命令前设置 +LLM_API_KEY="your-api-key" python src/main.py -m http -p 8000 +``` + +### Q4: 如何测试工作流是否正常? + +使用 curl 发送测试请求: + +```bash +curl -X POST http://localhost:8000/run \ + -H "Content-Type: application/json" \ + -d '{ + "student_homework": [ + { + "student_id": 0, + "student_name": "测试学生", + "homework_images": ["https://example.com/homework.jpg"] + } + ], + "answer_doc_url": "https://example.com/answer.docx" + }' +``` + +### Q5: 如何查看运行日志? + +```bash +# 实时查看日志 +tail -f /app/work/logs/bypass/app.log + +# 或使用 Docker logs +docker logs -f homework-correction +``` + +### Q6: 性能优化建议 + +1. **并发控制**: 调整 `max_concurrent` 参数(默认10) +2. **超时设置**: 修改 `SINGLE_IMAGE_TIMEOUT` 常量(默认120秒) +3. **缓存优化**: 定期清理 `/tmp/cache` 目录 +4. **资源监控**: 使用 `htop` 或 `docker stats` 监控资源使用 + +### Q7: 如何替换为其他 LLM 模型? + +1. 修改环境变量: +```bash +export LLM_API_KEY="your-other-llm-api-key" +export LLM_BASE_URL="https://api.other-llm.com/v1" +export LLM_MODEL_NAME="other-model-id" +``` + +2. 修改配置文件(`config/*.json`)中的 `model` 字段 + +3. 测试调用是否正常 + +--- + +## 项目文件说明 + +``` +├── src/ +│ ├── main.py # 主入口 +│ ├── graphs/ +│ │ ├── graph.py # 主工作流编排 +│ │ ├── loop_graph.py # 子图定义 +│ │ ├── state.py # 状态定义 +│ │ └── nodes/ # 节点实现 +│ │ ├── doc_extract_node.py +│ │ ├── process_images_node.py +│ │ ├── recognize_and_correct_node.py +│ │ └── ... +│ └── utils/ # 工具函数 +│ ├── file/file.py # 文件处理 +│ └── cache_manager.py # 缓存管理 +├── config/ # LLM 配置文件 +│ ├── doc_extract_llm_cfg.json +│ ├── homework_correction_cfg.json +│ └── ... +├── scripts/ # 启动脚本 +│ ├── http_run.sh +│ └── local_run.sh +├── requirements.txt # Python 依赖 +└── README.md # 项目说明 +``` + +--- + +## 技术支持 + +如遇到问题,请检查: +1. ✅ 环境变量是否正确配置 +2. ✅ 依赖是否完整安装 +3. ✅ 第三方服务(LLM、存储)是否可用 +4. ✅ 日志文件中的错误信息 + +--- + +## 更新日志 + +- 2026-03-28: 添加缓存学科隔离,修复等级评定逻辑 +- 2026-03-27: 移除图片上传,直接使用原始URL +- 2026-03-26: 优化坐标定位,修复识别问题 +- 2026-03-25: 支持多学生多图片并行处理 diff --git a/assets/ScreenShot_2026-03-27_200443_649.png b/assets/ScreenShot_2026-03-27_200443_649.png new file mode 100644 index 0000000..dab3858 Binary files /dev/null and b/assets/ScreenShot_2026-03-27_200443_649.png differ diff --git a/assets/image.png b/assets/image.png new file mode 100644 index 0000000..f4118b7 Binary files /dev/null and b/assets/image.png differ diff --git a/config/homework_recognize_llm_cfg.json b/config/homework_recognize_llm_cfg.json index c1eb6a6..53c28a2 100644 --- a/config/homework_recognize_llm_cfg.json +++ b/config/homework_recognize_llm_cfg.json @@ -3,10 +3,10 @@ "model": "doubao-seed-2-0-pro-260215", "temperature": 0.0, "top_p": 0.9, - "max_completion_tokens": 8192, + "max_completion_tokens": 16384, "thinking": "disabled" }, "tools": [], - "sp": "# 角色\n你是物理作业批改助手。\n\n# 禁止标注\n- 印刷体文字、实验装置图、图中字母\n\n# 需要标注\n- 学生手写答案\n\n# 坐标\n- 相对坐标(0-1000),answer_bbox: [x1, y1, x2, y2]\n\n# ⚠️ 重要:拆分填空题\n- 一道题有多个填空时,**每个空单独识别为一个题目**\n- 题号格式:\"3(1)第一空\"、\"3(1)第二空\"、\"4(2)第一空\"、\"4(2)第二空\"\n- 每个空单独批改,单独打分\n- 示例:\n - 题目(1)有两个空 → 识别为\"3(1)第一空\"和\"3(1)第二空\"两个题目\n - 题目(2)有一个空 → 识别为\"3(2)\"一个题目\n\n# 输出格式\n{\"results\": [{\"question_id\": \"题号\", \"student_answer\": \"答案\", \"answer_bbox\": [x1, y1, x2, y2], \"status\": \"correct或incorrect\", \"score\": 分数, \"full_score\": 满分, \"comment\": \"结论\"}]}\n\ncomment格式:\"正确\"或\"错误,应为X\"(不超过50字)", - "up": "批改物理作业。**每个填空单独识别**。输出JSON,comment不超过{{comment_max_length}}字。图片:{{image_url}}" + "sp": "# 角色\n你是物理作业批改助手。\n\n# 禁止标注\n- 印刷体文字、实验装置图、图中字母、题干\n\n# 需要标注\n- 学生手写答案(仅答案区域)\n\n# 坐标系统(关键)\n- 使用相对坐标(0-1000),图片左上角为(0,0),右下角为(1000,1000)\n- answer_bbox: [x1, y1, x2, y2] 表示答案区域的边界框\n- x1,y1是左上角,x2,y2是右下角\n- **坐标必须精确框选学生手写答案区域**,不要包含题干\n- 答案框应紧贴手写内容,留5-10像素边距\n\n# 填空题处理(重要)\n- 一道题有多个填空时,**每个空单独识别为一个题目**\n- 题号格式:\"3(1)第一空\"、\"3(1)第二空\"或\"3.1\"、\"3.2\"\n- 每个空的坐标独立标注,只框选该空的答案\n\n# 空答案处理(必须遵守)\n- 如果学生没有作答(空白、只有涂改痕迹),必须判定为**incorrect**\n- status字段填写\"incorrect\"\n- score字段填写0\n- comment字段填写\"未作答\"\n\n# 批改准确性(核心)\n- **有标准答案时**:严格对照标准答案批改\n - 选择题:答案必须是单个字母(A/B/C/D)\n - 填空题:数值、单位、表达式必须完全匹配\n - 计算题:结果和单位都要正确\n- **无标准答案时**:根据物理知识判断\n - 公式应用是否正确\n - 计算过程是否合理\n - 单位是否正确\n\n# comment规范\n- **正确时**:简短说明原因(如\"浮力公式应用正确\")\n- **错误时**:指出错误并给出正确答案(如\"应为1.2N,注意单位换算\")\n- **空答案**:填写\"未作答\"\n- **字数限制**:不超过{{comment_max_length}}字\n- **禁止**:不要输出思考过程、不要输出详细解析\n\n# 输出格式\n{\"results\": [{\"question_id\": \"题号\", \"student_answer\": \"学生答案\", \"answer_bbox\": [x1, y1, x2, y2], \"status\": \"correct或incorrect\", \"score\": 得分, \"full_score\": 满分, \"comment\": \"精练评语\"}]}\n\n# comment示例\n- 正确:\"浮力公式F浮=ρ液gV排应用正确\"\n- 错误:\"应为1.2N,F浮=ρ液gV排=1.0×10³×10×1.2×10⁻⁴=1.2N\"\n- 空答案:\"未作答\"", + "up": "批改物理作业。**精确标注手写答案坐标**。**每个填空单独识别**。**comment写精练评语**。输出完整JSON。图片:{{image_url}}" } diff --git a/quick_start.sh b/quick_start.sh new file mode 100644 index 0000000..fedecf5 --- /dev/null +++ b/quick_start.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +# ============================================ +# 初中物理作业批改工作流 - 快速部署脚本 +# ============================================ + +set -e + +echo "======================================" +echo " 初中物理作业批改工作流 - 部署向导" +echo "======================================" +echo "" + +# 检测操作系统 +if [[ "$OSTYPE" == "linux-gnu"* ]]; then + OS="Linux" +elif [[ "$OSTYPE" == "darwin"* ]]; then + OS="macOS" +elif [[ "$OSTYPE" == "msys" ]] || [[ "$OSTYPE" == "cygwin" ]]; then + OS="Windows" +else + OS="Unknown" +fi + +echo "检测到操作系统: $OS" +echo "" + +# 步骤1: 检查Python版本 +echo "步骤 1/5: 检查 Python 版本..." +if command -v python3 &> /dev/null; then + PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}') + PYTHON_MAJOR=$(echo $PYTHON_VERSION | cut -d. -f1) + PYTHON_MINOR=$(echo $PYTHON_VERSION | cut -d. -f2) + + if [ "$PYTHON_MAJOR" -ge 3 ] && [ "$PYTHON_MINOR" -ge 10 ]; then + echo "✅ Python 版本: $PYTHON_VERSION" + else + echo "❌ Python 版本过低: $PYTHON_VERSION (需要 3.10+)" + exit 1 + fi +else + echo "❌ 未找到 Python 3" + exit 1 +fi +echo "" + +# 步骤2: 创建虚拟环境 +echo "步骤 2/5: 创建虚拟环境..." +if [ ! -d "venv" ]; then + python3 -m venv venv + echo "✅ 虚拟环境已创建" +else + echo "✅ 虚拟环境已存在" +fi +echo "" + +# 步骤3: 激活虚拟环境 +echo "步骤 3/5: 激活虚拟环境..." +if [ "$OS" == "Windows" ]; then + source venv/Scripts/activate +else + source venv/bin/activate +fi +echo "✅ 虚拟环境已激活" +echo "" + +# 步骤4: 安装依赖 +echo "步骤 4/5: 安装依赖包..." +if [ -f "requirements.txt" ]; then + pip install --upgrade pip + pip install -r requirements.txt + echo "✅ 依赖安装完成" +else + echo "❌ 未找到 requirements.txt" + exit 1 +fi +echo "" + +# 步骤5: 配置环境变量 +echo "步骤 5/5: 配置环境变量..." +if [ ! -f ".env" ]; then + if [ -f ".env.example" ]; then + cp .env.example .env + echo "✅ 已创建 .env 文件" + echo "" + echo "⚠️ 请编辑 .env 文件,填写以下必需配置:" + echo " - LLM_API_KEY" + echo " - LLM_BASE_URL" + echo " - LLM_MODEL_NAME" + echo "" + echo "注意:不需要配置对象存储(图片直接使用原始URL)" + echo "" + echo "编辑完成后,运行以下命令启动服务:" + echo " source .env" + echo " bash scripts/http_run.sh -p 8000" + else + echo "❌ 未找到 .env.example" + exit 1 + fi +else + echo "✅ .env 文件已存在" + echo "" + echo "启动服务:" + echo " source .env" + echo " bash scripts/http_run.sh -p 8000" +fi +echo "" + +echo "======================================" +echo " ✅ 部署完成!" +echo "======================================" diff --git a/src/graphs/nodes/doc_extract_node.py b/src/graphs/nodes/doc_extract_node.py index dedf6c7..5e26069 100644 --- a/src/graphs/nodes/doc_extract_node.py +++ b/src/graphs/nodes/doc_extract_node.py @@ -1,10 +1,11 @@ -"""Word答案解析节点:从.docx文件中提取题干和标准答案""" +"""Word答案解析节点:从.docx文件中提取题干和标准答案(带缓存)""" import os import json import re import logging import tempfile import requests +import orjson from typing import List from langchain_core.runnables import RunnableConfig from langgraph.runtime import Runtime @@ -18,6 +19,7 @@ from graphs.state import ( DocExtractOutput, CorrectAnswer ) +from utils.cache_manager import answer_doc_cache logger = logging.getLogger(__name__) @@ -58,7 +60,6 @@ def sanitize_json_string(text: str) -> str: def extract_json_from_text(text: str, key: str = "answers") -> dict: """从文本中提取JSON对象,多层回退策略""" - import orjson # 先尝试直接解析完整JSON try: @@ -114,16 +115,34 @@ def download_and_extract_docx(url: str) -> str: """下载Word文件并提取文本内容""" logger.info(f"Downloading Word document from: {url}") - # 下载文件 - response = requests.get(url, timeout=60) - response.raise_for_status() - - # 保存到临时文件 - with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as tmp_file: - tmp_file.write(response.content) - tmp_path = tmp_file.name + tmp_path = None try: + # 下载文件 + response = requests.get(url, timeout=60, allow_redirects=True) + response.raise_for_status() + + # 检查内容类型 + content_type = response.headers.get('Content-Type', '') + logger.debug(f"Response Content-Type: {content_type}") + + # 检查文件大小 + if len(response.content) < 100: + raise ValueError(f"Downloaded file too small: {len(response.content)} bytes") + + # 检查是否为有效的 docx(ZIP 格式,以 PK 开头) + if not response.content.startswith(b'PK'): + # 可能是 HTML 错误页面 + content_preview = response.content[:1000].lower() + if b' str: logger.info(f"Extracted Word document text length: {len(doc_text)}") return doc_text + + except Exception as e: + logger.error(f"Failed to download/extract docx: {e}") + raise finally: # 清理临时文件 - os.unlink(tmp_path) + if tmp_path and os.path.exists(tmp_path): + try: + os.unlink(tmp_path) + except Exception: + pass -def doc_extract_node( - state: DocExtractInput, - config: RunnableConfig, - runtime: Runtime[Context] -) -> DocExtractOutput: +def parse_answer_doc_with_llm(answer_doc_url: str, ctx, config: RunnableConfig) -> List[CorrectAnswer]: """ - title: Word答案解析 - desc: 从正确答案Word文件(.docx)中提取题干和标准答案,用于后续批改;如果未提供URL则返回空列表 - integrations: 大语言模型 + 使用LLM解析答案文档(实际解析逻辑) + + Args: + answer_doc_url: 答案文档URL + ctx: 上下文 + config: 配置 + + Returns: + 解析后的正确答案列表 """ - ctx = runtime.context - - # 检查是否提供了答案文档URL - if not state.answer_doc_url or not state.answer_doc_url.strip(): - logger.info("No answer document URL provided, will use teacher mode for all questions") - return DocExtractOutput(correct_answers=[]) - # 1. 下载并提取Word文档内容 try: - doc_text = download_and_extract_docx(state.answer_doc_url) + doc_text = download_and_extract_docx(answer_doc_url) except Exception as e: logger.error(f"Failed to download/extract Word document: {e}") - return DocExtractOutput(correct_answers=[]) + return [] if not doc_text.strip(): logger.error("No text content extracted from Word document") - return DocExtractOutput(correct_answers=[]) + return [] logger.info(f"Word document content preview: {doc_text[:500]}") @@ -277,4 +299,57 @@ def doc_extract_node( for ans in correct_answers: logger.info(f" Question {ans.question_id}: {ans.correct_answer} ({ans.full_score}分)") + return correct_answers + + +def doc_extract_node( + state: DocExtractInput, + config: RunnableConfig, + runtime: Runtime[Context] +) -> DocExtractOutput: + """ + title: Word答案解析 + desc: 从正确答案Word文件(.docx)中提取题干和标准答案,用于后续批改;如果未提供URL则返回空列表;支持缓存,避免重复解析 + integrations: 大语言模型 + """ + ctx = runtime.context + + # 检查是否提供了答案文档URL + if not state.answer_doc_url or not state.answer_doc_url.strip(): + logger.info("No answer document URL provided, will use teacher mode for all questions") + return DocExtractOutput(correct_answers=[]) + + # 尝试从缓存获取 + cached_result = answer_doc_cache.get(state.answer_doc_url) + if cached_result is not None: + logger.info(f"Cache hit for answer doc: {state.answer_doc_url[:50]}...") + # 将字典列表转换回CorrectAnswer对象 + correct_answers = [ + CorrectAnswer(**ans) if isinstance(ans, dict) else ans + for ans in cached_result + ] + return DocExtractOutput(correct_answers=correct_answers) + + logger.info(f"Cache miss for answer doc: {state.answer_doc_url[:50]}...") + + # 缓存未命中,执行解析 + correct_answers = parse_answer_doc_with_llm(state.answer_doc_url, ctx, config) + + # 存入缓存(将CorrectAnswer对象转换为字典列表) + if correct_answers: + cache_data = [ + { + "question_id": ans.question_id, + "parent_id": ans.parent_id, + "is_sub_question": ans.is_sub_question, + "question_text": ans.question_text, + "correct_answer": ans.correct_answer, + "full_score": ans.full_score, + "answer_analysis": ans.answer_analysis + } + for ans in correct_answers + ] + answer_doc_cache.set(state.answer_doc_url, cache_data) + logger.info(f"Cached {len(correct_answers)} answers for: {state.answer_doc_url[:50]}...") + return DocExtractOutput(correct_answers=correct_answers) diff --git a/src/graphs/nodes/image_preprocess_node.py b/src/graphs/nodes/image_preprocess_node.py index 683b07b..5d520bf 100644 --- a/src/graphs/nodes/image_preprocess_node.py +++ b/src/graphs/nodes/image_preprocess_node.py @@ -1,14 +1,13 @@ -"""1. 图像预处理节点:下载图片、自动旋转、缩放到固定宽度1000、上传对象存储""" -import os +"""1. 图像预处理节点:获取图片信息,直接使用原始URL""" import logging import urllib.request +import time from io import BytesIO -from typing import Tuple +from typing import Tuple, Optional from langchain_core.runnables import RunnableConfig from langgraph.runtime import Runtime from coze_coding_utils.runtime_ctx.context import Context from PIL import Image -from datetime import datetime from graphs.state import ( ImagePreprocessInput, @@ -18,102 +17,79 @@ from graphs.state import ( logger = logging.getLogger(__name__) -# 固定宽度常量 -FIXED_WIDTH = 1000 +# 默认图片尺寸(用于降级处理) +DEFAULT_IMAGE_SIZE = (1000, 1400) + +# 超时配置(秒) +IMAGE_DOWNLOAD_TIMEOUT = 30 # 单次下载超时 +MAX_RETRIES = 2 # 最大重试次数(减少重试) -def download_and_process_image(image_url: str) -> Tuple[Image.Image, int, int, int]: - """下载图片并返回PIL Image对象和原始尺寸信息""" - try: - with urllib.request.urlopen(image_url, timeout=30) as response: - img_data = response.read() - img = Image.open(BytesIO(img_data)) - - # 转换为RGB模式(处理PNG透明通道) - if img.mode in ('RGBA', 'P'): - img = img.convert('RGB') - - original_width, original_height = img.size - dpi = img.info.get('dpi', (72, 72)) - if isinstance(dpi, tuple): - dpi = dpi[0] if dpi[0] > 0 else 72 - else: - dpi = 72 - - return img, original_width, original_height, dpi - except Exception as e: - raise RuntimeError(f"下载图片失败: {e}") - - -def auto_rotate_image(img: Image.Image) -> Tuple[Image.Image, bool]: +def get_image_info_with_retry(image_url: str, max_retries: int = MAX_RETRIES, timeout: int = IMAGE_DOWNLOAD_TIMEOUT) -> Tuple[int, int, int]: """ - 自动旋转图片:如果宽度大于高度(横向图片),则旋转-90度使其变为纵向 + 获取图片尺寸信息(带重试机制,有最大总时间限制) Args: - img: PIL Image对象 + image_url: 图片URL + max_retries: 最大重试次数(默认2次) + timeout: 单次请求超时时间(默认30秒) Returns: - (旋转后的图片, 是否进行了旋转) + (width, height, dpi) """ - width, height = img.size + last_error = None + total_start_time = time.time() + MAX_TOTAL_TIME = 60 # 总时间不超过60秒 - # 如果宽度大于高度,需要旋转 - if width > height: - logger.info(f"检测到横向图片(宽{width} > 高{height}),正在旋转-90度...") - # rotate(-90) 表示逆时针旋转90度,使横向变为纵向 - rotated_img = img.rotate(-90, expand=True) - new_width, new_height = rotated_img.size - logger.info(f"旋转完成,新尺寸:宽{new_width} x 高{new_height}") - return rotated_img, True + for attempt in range(max_retries): + # 检查总时间 + if time.time() - total_start_time > MAX_TOTAL_TIME: + logger.warning(f"Total time exceeded {MAX_TOTAL_TIME}s, giving up") + break + + try: + # 下载图片(带超时) + with urllib.request.urlopen(image_url, timeout=timeout) as response: + img_data = response.read() + + # 检查数据大小 + if len(img_data) < 100: + raise ValueError(f"Image data too small: {len(img_data)} bytes") + + img = Image.open(BytesIO(img_data)) + + # 转换为RGB模式(处理PNG透明通道) + if img.mode in ('RGBA', 'P'): + img = img.convert('RGB') + + width, height = img.size + + # 验证尺寸有效性 + if width <= 0 or height <= 0: + raise ValueError(f"Invalid image size: {width}x{height}") + + dpi = img.info.get('dpi', (72, 72)) + if isinstance(dpi, tuple): + dpi = dpi[0] if dpi[0] > 0 else 72 + else: + dpi = 72 + + elapsed = time.time() - total_start_time + logger.info(f"Got image info: {width}x{height} in {elapsed:.1f}s (attempt {attempt + 1})") + return width, height, dpi + + except Exception as e: + last_error = e + elapsed = time.time() - total_start_time + logger.warning(f"Attempt {attempt + 1}/{max_retries} failed in {elapsed:.1f}s: {str(e)[:100]}") + + # 只有非最后一次才等待 + if attempt < max_retries - 1: + wait_time = 1 # 固定等待1秒 + time.sleep(wait_time) - logger.info(f"图片为纵向(宽{width} <= 高{height}),无需旋转") - return img, False - - -def resize_image_to_fixed_width(img: Image.Image, target_width: int = FIXED_WIDTH) -> Tuple[Image.Image, float]: - """将图片缩放到固定宽度,高度等比例缩放""" - original_width, original_height = img.size - - if original_width == target_width: - return img, 1.0 - - # 计算缩放比例 - scale_ratio = target_width / original_width - new_height = int(original_height * scale_ratio) - - # 使用高质量重采样 - # BICUBIC = 3, LANCZOS = 4 (PIL内部常量值) - resized_img = img.resize((target_width, new_height), 3) # BICUBIC - - return resized_img, scale_ratio - - -def upload_image_to_storage(img: Image.Image, ctx) -> str: - """将图片上传到对象存储并返回URL""" - from coze_coding_dev_sdk.s3 import S3SyncStorage - - # 转换为字节流 - img_buffer = BytesIO() - img.save(img_buffer, format='JPEG', quality=95) - img_bytes = img_buffer.getvalue() - - # 上传到对象存储 - storage = S3SyncStorage( - endpoint_url=os.getenv("COZE_BUCKET_ENDPOINT_URL"), - access_key="", - secret_key="", - bucket_name=os.getenv("COZE_BUCKET_NAME"), - region="cn-beijing", - ) - - file_key = storage.upload_file( - file_content=img_bytes, - file_name=f"homework_resized_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jpg", - content_type="image/jpeg", - ) - - result_url = storage.generate_presigned_url(key=file_key, expire_time=86400) - return result_url + # 所有重试都失败,抛出错误 + raise RuntimeError(f"图片获取失败({max_retries}次重试): {str(last_error)[:100]}") def image_preprocess_node( @@ -123,37 +99,38 @@ def image_preprocess_node( ) -> ImagePreprocessOutput: """ title: 图像预处理 - desc: 下载图片、自动旋转(横向→纵向)、缩放到固定宽度1000px、上传对象存储 - integrations: 对象存储 + desc: 获取图片尺寸信息,直接使用原始图片URL用于AI识别,不旋转不缩放不上传。支持重试和降级处理。 + integrations: """ ctx = runtime.context - # 1. 下载原始图片 - img, original_width, original_height, dpi = download_and_process_image( - state.homework_image.url - ) - logger.info(f"原始图片尺寸:宽{original_width} x 高{original_height}") + # 获取原始图片URL + original_image_url = state.homework_image.url - # 2. 自动旋转:如果宽度大于高度,旋转-90度使其变为纵向 - img, was_rotated = auto_rotate_image(img) - after_rotate_width, after_rotate_height = img.size - if was_rotated: - logger.info(f"旋转后尺寸:宽{after_rotate_width} x 高{after_rotate_height}") - - # 3. 缩放到固定宽度1000px,高度等比例缩放 - resized_img, scale_ratio = resize_image_to_fixed_width(img, FIXED_WIDTH) - new_width, new_height = resized_img.size - logger.info(f"缩放后尺寸:宽{new_width} x 高{new_height},缩放比例:{scale_ratio:.3f}") - - # 4. 上传处理后的图片到对象存储 - processed_image_url = upload_image_to_storage(resized_img, ctx) - - # 5. 返回处理后的图片信息(AI基于这个尺寸计算坐标) - return ImagePreprocessOutput( - image_info=ImageInfo( - width=new_width, # 缩放后的宽度:1000 - height=new_height, # 缩放后的高度 - dpi=dpi - ), - image_url=processed_image_url - ) + try: + # 获取图片尺寸信息(带重试) + width, height, dpi = get_image_info_with_retry(original_image_url) + logger.info(f"原始图片尺寸:宽{width} x 高{height}") + + return ImagePreprocessOutput( + image_info=ImageInfo( + width=width, + height=height, + dpi=dpi + ), + image_url=original_image_url + ) + + except Exception as e: + # 降级处理:使用默认尺寸,但仍然继续处理 + logger.error(f"Failed to get image info after retries, using fallback: {e}") + + # 返回默认尺寸,让后续节点能够继续处理 + return ImagePreprocessOutput( + image_info=ImageInfo( + width=DEFAULT_IMAGE_SIZE[0], + height=DEFAULT_IMAGE_SIZE[1], + dpi=72 + ), + image_url=original_image_url + ) diff --git a/src/graphs/nodes/process_images_node.py b/src/graphs/nodes/process_images_node.py index a89a6eb..ad92457 100644 --- a/src/graphs/nodes/process_images_node.py +++ b/src/graphs/nodes/process_images_node.py @@ -1,7 +1,8 @@ -"""多图片处理循环节点:并行调用子图处理每张作业图片""" +"""多学生作业处理循环节点:完全并行处理,确保数据隔离""" import logging -from typing import List -from concurrent.futures import ThreadPoolExecutor, as_completed +import threading +from typing import List, Dict, Any, Optional +from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError from langchain_core.runnables import RunnableConfig from langgraph.runtime import Runtime from coze_coding_utils.runtime_ctx.context import Context @@ -9,159 +10,348 @@ from coze_coding_utils.runtime_ctx.context import Context from graphs.state import ( ProcessImagesInput, ProcessImagesOutput, + StudentResult, + StudentHomework, SingleImageResult, SubgraphInput, - FinalResult, - ImageInfo + ImageInfo, + CorrectAnswer ) from graphs.loop_graph import single_image_subgraph +from utils.file.file import File logger = logging.getLogger(__name__) +# 默认等级标准(与GraphInput中的默认值保持一致) +DEFAULT_GRADE_STANDARDS = { + "A+": {"min_percentage": 95, "description": "优秀"}, + "A": {"min_percentage": 90, "description": "良好"}, + "B": {"min_percentage": 80, "description": "合格"}, + "C": {"min_percentage": 70, "description": "及格"}, + "D": {"min_percentage": 0, "description": "需努力"} +} + +# 超时配置(秒) +SINGLE_IMAGE_TIMEOUT = 120 # 单张图片处理超时120秒 + + +def calculate_grade(score_rate: float, incorrect_count: int, total_questions: int, grade_standards: Dict[str, Any]) -> tuple: + """ + 根据得分率、错误数量和等级标准计算等级 + + 核心规则(按优先级): + 1. A+ 和 A:首要条件是"全对"(incorrect_count == 0),与得分率无关 + - A+:全对 + - A:全对 + (A+和A的区别由其他因素决定,这里都返回全对的最高等级) + 2. B/C/D:有错误时,按得分率判断 + + Args: + score_rate: 得分率(百分比,如95.5) + incorrect_count: 错误题目数量 + total_questions: 总题目数量 + grade_standards: 等级标准字典 + + Returns: + (等级, 等级描述) + """ + # 使用传入的等级标准,如果没有则使用默认值 + standards = grade_standards if grade_standards else DEFAULT_GRADE_STANDARDS + + # 规则1:全对 → A+ 或 A(这里统一返回A+,因为全对就是最高等级) + if incorrect_count == 0 and total_questions > 0: + return "A+", standards.get("A+", {}).get("description", "优秀") + + # 规则2:有错误 → 按得分率判断 B/C/D + # 按min_percentage降序排序(只排B/C/D) + bcd_grades = [(k, v) for k, v in standards.items() if k not in ("A+", "A")] + sorted_grades = sorted( + bcd_grades, + key=lambda x: x[1].get("min_percentage", 0), + reverse=True + ) + + # 遍历找到匹配的等级 + for grade, config in sorted_grades: + min_pct = config.get("min_percentage", 0) + if score_rate >= min_pct: + return grade, config.get("description", "") + + # 默认返回D + return "D", "需努力" + + +def process_single_image( + student_id: int, + idx: int, + image_url: str, + correct_answers: List[CorrectAnswer], + comment_max_length: int, + config: RunnableConfig +) -> tuple: + """ + 处理单个学生的单张图片(线程安全) + 返回: (student_id, image_index, SingleImageResult) 元组,确保学生ID关联 + + 注意:correct_answers 参数是只读的,不会被修改 + """ + logger.info(f"Processing student {student_id}, image {idx + 1}") + try: + # 将URL字符串转换为File对象 + homework_image = File(url=image_url, file_type="image") + + # 构建子图输入(创建新的输入对象,确保数据隔离) + # 注意:correct_answers 是只读的,不需要复制 + subgraph_input = SubgraphInput( + homework_image=homework_image, + correct_answers=correct_answers, # 只读,不需要复制 + image_index=idx, + comment_max_length=comment_max_length + ) + + # 调用子图(config 是只读的,LangGraph 保证线程安全) + subgraph_output = single_image_subgraph.invoke(subgraph_input, config) + + # 提取结果(兼容字典和Pydantic对象两种返回类型) + image_result = None + if subgraph_output: + if isinstance(subgraph_output, dict): + result_data = subgraph_output.get("image_result") + if result_data: + if isinstance(result_data, dict): + image_result = SingleImageResult(**result_data) + else: + image_result = result_data + elif hasattr(subgraph_output, 'image_result'): + image_result = subgraph_output.image_result + + if image_result: + logger.info(f"Student {student_id}, image {idx + 1} processed: {len(image_result.annotations)} annotations") + return (student_id, idx, image_result) + else: + logger.warning(f"Student {student_id}, image {idx + 1} returned invalid output") + return (student_id, idx, SingleImageResult( + image_index=idx, + image_info=ImageInfo(width=0, height=0, dpi=72), + annotations=[] + )) + except Exception as e: + logger.error(f"Failed to process student {student_id}, image {idx + 1}: {e}", exc_info=True) + return (student_id, idx, SingleImageResult( + image_index=idx, + image_info=ImageInfo(width=0, height=0, dpi=72), + annotations=[] + )) + + def process_images_node( state: ProcessImagesInput, config: RunnableConfig, runtime: Runtime[Context] ) -> ProcessImagesOutput: """ - title: 多图片批改处理 - desc: 并行调用子图处理每张作业图片,生成最终批改结果 + title: 多学生作业批改处理 + desc: 完全并行处理所有学生的所有图片,确保每个学生的数据完全隔离 integrations: """ ctx = runtime.context - # 获取并发数限制(从参数获取,默认10) - max_concurrent = getattr(state, 'max_concurrent', 10) + # === 输入参数校验 === + if not state.student_homework: + logger.warning("No student homework provided, returning empty result") + return ProcessImagesOutput(student_results=[]) - logger.info(f"Starting to process {len(state.homework_images)} images (concurrent={max_concurrent})") + # 获取并发数限制(从参数获取,默认10,限制在合理范围1-50) + max_concurrent = max(1, min(getattr(state, 'max_concurrent', 10), 50)) - # 定义处理单张图片的函数 - def process_single_image(idx: int, homework_image) -> SingleImageResult: - logger.info(f"Processing image {idx + 1}/{len(state.homework_images)}") - try: - # 构建子图输入 - subgraph_input = SubgraphInput( - homework_image=homework_image, - correct_answers=state.correct_answers, - image_index=idx, - comment_max_length=state.comment_max_length - ) - - # 调用子图 - subgraph_output = single_image_subgraph.invoke(subgraph_input, config) - - # 提取结果(兼容字典和Pydantic对象两种返回类型) - image_result = None - if subgraph_output: - if isinstance(subgraph_output, dict): - result_data = subgraph_output.get("image_result") - if result_data: - if isinstance(result_data, dict): - image_result = SingleImageResult(**result_data) - else: - image_result = result_data - elif hasattr(subgraph_output, 'image_result'): - image_result = subgraph_output.image_result - - if image_result: - logger.info(f"Image {idx + 1} processed successfully: {len(image_result.annotations)} annotations") - return image_result + # 过滤有效的图片URL,统计总任务数 + valid_tasks = [] # [(student_id, idx, image_url), ...] + for student in state.student_homework: + if not student.homework_images: + continue + for idx, image_url in enumerate(student.homework_images): + # 验证图片URL有效性 + if image_url and isinstance(image_url, str) and image_url.strip(): + valid_tasks.append((student.student_id, idx, image_url.strip())) else: - logger.warning(f"Image {idx + 1} subgraph returned invalid output") - return SingleImageResult( - image_index=idx, - image_info=ImageInfo(width=0, height=0, dpi=72), - annotations=[] - ) - except Exception as e: - logger.error(f"Failed to process image {idx + 1}: {e}", exc_info=True) - return SingleImageResult( - image_index=idx, - image_info=ImageInfo(width=0, height=0, dpi=72), - annotations=[] - ) + logger.warning(f"Invalid image URL for student {student.student_id}, image {idx}") - # 并行处理所有图片 - image_results: List[SingleImageResult] = [] + if not valid_tasks: + logger.warning("No valid images to process, returning empty result") + return ProcessImagesOutput(student_results=[]) + + logger.info(f"Starting to process {len(state.student_homework)} students, {len(valid_tasks)} valid images (concurrent={max_concurrent})") + + # 第一步:提交所有任务(所有学生的所有图片) + # 注意:使用字典存储结果,as_completed 在主线程顺序处理,无需加锁 + student_image_results: Dict[int, Dict[int, SingleImageResult]] = {} with ThreadPoolExecutor(max_workers=max_concurrent) as executor: - # 提交所有任务 - future_to_idx = { - executor.submit(process_single_image, idx, img): idx - for idx, img in enumerate(state.homework_images) - } + # 提交所有有效任务 + future_to_task = {} # future -> (student_id, idx, image_url) + for student_id, idx, image_url in valid_tasks: + future = executor.submit( + process_single_image, + student_id, + idx, + image_url, + state.correct_answers, + state.comment_max_length, + config + ) + future_to_task[future] = (student_id, idx, image_url) - # 收集结果 - results_dict = {} - for future in as_completed(future_to_idx): - idx = future_to_idx[future] - try: - result = future.result() - results_dict[idx] = result - except Exception as e: - logger.error(f"Future failed for image {idx}: {e}") - results_dict[idx] = SingleImageResult( + logger.info(f"Submitted {len(future_to_task)} tasks for all students") + + # 第二步:收集所有结果,按student_id分组 + # as_completed 在主线程中迭代,不需要加锁 + # 双重保险:使用 task_info 中的 student_id(而非 result 中的),确保数据归属正确 + try: + for future in as_completed(future_to_task, timeout=SINGLE_IMAGE_TIMEOUT * len(future_to_task)): + task_info = future_to_task[future] + student_id, idx, image_url = task_info # 从任务信息获取,不依赖返回值 + + try: + # 返回值中的 student_id 应该与 task_info 一致,但我们优先使用 task_info + # 添加单任务超时保护 + result_student_id, image_index, image_result = future.result(timeout=SINGLE_IMAGE_TIMEOUT) + + # 双重校验:确保返回的学生ID与任务一致 + if result_student_id != student_id: + logger.warning(f"Student ID mismatch: task={student_id}, result={result_student_id}, using task_id") + + # 使用 task_info 中的 student_id(更可靠) + if student_id not in student_image_results: + student_image_results[student_id] = {} + + # 存储该学生的图片结果 + student_image_results[student_id][image_index] = image_result + + except FuturesTimeoutError: + logger.error(f"Task timeout for student {student_id}, image {idx} (timeout={SINGLE_IMAGE_TIMEOUT}s)") + # 存储一个空结果,确保该学生的结果完整 + if student_id not in student_image_results: + student_image_results[student_id] = {} + student_image_results[student_id][idx] = SingleImageResult( + image_index=idx, + image_info=ImageInfo(width=0, height=0, dpi=72), + annotations=[] + ) + except Exception as e: + logger.error(f"Task failed for student {student_id}, image {idx}: {e}", exc_info=True) + # 存储一个空结果,确保该学生的结果完整 + if student_id not in student_image_results: + student_image_results[student_id] = {} + student_image_results[student_id][idx] = SingleImageResult( + image_index=idx, + image_info=ImageInfo(width=0, height=0, dpi=72), + annotations=[] + ) + + except FuturesTimeoutError: + # as_completed 总超时,处理未完成的任务 + logger.error(f"Total timeout exceeded, processing remaining {len(future_to_task)} tasks") + for future, task_info in future_to_task.items(): + if not future.done(): + student_id, idx, _ = task_info + if student_id not in student_image_results: + student_image_results[student_id] = {} + student_image_results[student_id][idx] = SingleImageResult( + image_index=idx, + image_info=ImageInfo(width=0, height=0, dpi=72), + annotations=[] + ) + + logger.info(f"Collected results for {len(student_image_results)} students") + + # 第三步:为每个学生计算独立的结果(数据隔离) + student_results: List[StudentResult] = [] + + for student in state.student_homework: + student_id = student.student_id + student_name = student.student_name + homework_image_urls = student.homework_images + + # 获取该学生的图片结果(确保只使用该学生的数据) + image_results_dict = student_image_results.get(student_id, {}) + + # 按顺序组装该学生的图片结果 + image_results: List[SingleImageResult] = [] + for idx in range(len(homework_image_urls)): + if idx in image_results_dict: + image_results.append(image_results_dict[idx]) + else: + # 补充缺失的结果 + image_results.append(SingleImageResult( image_index=idx, image_info=ImageInfo(width=0, height=0, dpi=72), annotations=[] - ) + )) - # 按原始顺序排序结果 - for idx in range(len(state.homework_images)): - image_results.append(results_dict[idx]) + # 计算该学生的总分(只基于该学生的图片结果) + total_score = 0 + full_score = 0 + total_questions = 0 + correct_count = 0 + incorrect_count = 0 + + for result in image_results: + for annotation in result.annotations: + total_score += annotation.score + full_score += annotation.full_score + total_questions += 1 + if annotation.status == "correct": + correct_count += 1 + elif annotation.status == "incorrect": + incorrect_count += 1 + + # 计算得分率 + score_rate = (total_score / full_score * 100) if full_score > 0 else 0 + + # 获取等级标准(从state中获取,如果没有则使用默认值) + grade_standards = getattr(state, 'grade_standards', {}) + + # 使用等级标准计算等级(传入错误数量和总题数) + grade, grade_description = calculate_grade(score_rate, incorrect_count, total_questions, grade_standards) + + # 生成该学生的整体评价(与等级匹配) + if total_questions == 0: + overall_comment = "未识别到题目内容" + grade = "D" + elif grade == "A+": + # A+:全部正确 + overall_comment = f"优秀!{total_questions}题全部正确,掌握扎实,继续保持。" + elif grade == "A": + # A:全部正确(已确保无错误) + overall_comment = f"良好!{total_questions}题全部正确,步骤规范。" + elif grade == "B": + # B:有少量错误 + if incorrect_count > 0: + overall_comment = f"合格。得分率{score_rate:.0f}%,错{incorrect_count}题,需加强练习。" + else: + overall_comment = f"合格。得分率{score_rate:.0f}%,继续努力。" + elif grade == "C": + overall_comment = f"及格。错{incorrect_count}题,部分知识点掌握不牢,建议复习。" + else: + overall_comment = f"需努力。得分率{score_rate:.0f}%,错{incorrect_count}题,建议认真复习,多做练习。" + + # 创建该学生的完整结果(数据完全隔离) + student_result = StudentResult( + student_id=student_id, + student_name=student_name, + total_images=len(image_results), + image_results=image_results, + overall_comment=overall_comment, + total_score=total_score, + full_score=full_score, + grade=grade + ) + + student_results.append(student_result) + + logger.info(f"Student {student_id}: {total_score}/{full_score}, grade={grade}") - logger.info(f"Completed processing {len(image_results)} images") + logger.info(f"Completed processing {len(student_results)} students") - # 生成最终结果 - total_score = 0 - full_score = 0 - total_questions = 0 - correct_count = 0 - incorrect_count = 0 - - for result in image_results: - for annotation in result.annotations: - total_score += annotation.score - full_score += annotation.full_score - total_questions += 1 - if annotation.status == "correct": - correct_count += 1 - elif annotation.status == "incorrect": - incorrect_count += 1 - - # 计算得分率 - score_rate = (total_score / full_score * 100) if full_score > 0 else 0 - - # 生成整体评价 - if total_questions == 0: - overall_comment = "未识别到题目内容" - grade = "D" - elif score_rate >= 95: - overall_comment = f"优秀!{total_questions}题全部正确,掌握扎实,继续保持。" - grade = "A+" - elif score_rate >= 90: - overall_comment = f"良好!得分率{score_rate:.0f}%,错{incorrect_count}题,注意细节。" - grade = "A" - elif score_rate >= 80: - overall_comment = f"合格。得分率{score_rate:.0f}%,错{incorrect_count}题,需加强练习。" - grade = "B" - elif score_rate >= 70: - overall_comment = f"及格。错{incorrect_count}题,部分知识点掌握不牢,建议复习。" - grade = "C" - else: - overall_comment = f"需努力。得分率{score_rate:.0f}%,建议认真复习,多做练习。" - grade = "D" - - final_result = FinalResult( - total_images=len(image_results), - image_results=image_results, - overall_comment=overall_comment, - total_score=total_score, - full_score=full_score, - grade=grade - ) - - logger.info(f"Final result: {total_score}/{full_score}, grade={grade}") - - return ProcessImagesOutput(final_result=final_result) + return ProcessImagesOutput(student_results=student_results) diff --git a/src/graphs/nodes/recognize_and_correct_node.py b/src/graphs/nodes/recognize_and_correct_node.py index b860a35..c57d045 100644 --- a/src/graphs/nodes/recognize_and_correct_node.py +++ b/src/graphs/nodes/recognize_and_correct_node.py @@ -3,6 +3,7 @@ import os import json import re import logging +import orjson from typing import List, Dict, Any from jinja2 import Template from langchain_core.runnables import RunnableConfig @@ -22,6 +23,30 @@ from graphs.state import ( logger = logging.getLogger(__name__) +# 思考过程的特征词 +THINKING_KEYWORDS = [ + "不对", "重新看", "可能我", "哦,", "说明我", "这明显", + "不,", "应该是", "我发现", "等等", "让我", "我理解", + "?不", "?不对", "不是", "搞错了", "看错了" +] + + +def clean_comment(text: str) -> str: + """清理comment中的思考过程""" + for keyword in THINKING_KEYWORDS: + if keyword in text: + # 找到思考过程开始的位置 + idx = text.find(keyword) + if idx > 0: + # 截断思考过程 + cleaned = text[:idx] + # 尝试保留完整句子(最后一个句号之后) + last_period = cleaned.rfind("。") + if last_period > 0: + return cleaned[:last_period + 1] + return cleaned + return text + def fix_incomplete_json(text: str) -> str: """尝试修复不完整的JSON字符串""" @@ -38,9 +63,64 @@ def fix_incomplete_json(text: str) -> str: return text +def extract_complete_objects(text: str) -> List[dict]: + """从JSON文本中提取完整的对象列表(处理思考过程干扰)""" + objects = [] + + # 找到每个对象的开始位置 + obj_pattern = r'\{\s*"question_id"' + matches = list(re.finditer(obj_pattern, text)) + + for i, match in enumerate(matches): + start = match.start() + + # 找到这个对象的结束位置 + brace_count = 0 + in_string = False + escape = False + end = start + + for j in range(start, len(text)): + char = text[j] + + if escape: + escape = False + continue + + if char == '\\': + escape = True + continue + + if char == '"' and not escape: + in_string = not in_string + continue + + if in_string: + continue + + if char == '{': + brace_count += 1 + elif char == '}': + brace_count -= 1 + if brace_count == 0: + end = j + 1 + break + + if end > start: + obj_str = text[start:end] + try: + obj = orjson.loads(obj_str) + if isinstance(obj, dict) and "question_id" in obj: + objects.append(obj) + logger.debug(f"Extracted object: {obj.get('question_id')}") + except Exception as e: + logger.debug(f"Failed to parse object: {e}") + + return objects + + def extract_json_from_text(text: str, key: str = "results") -> dict: """从文本中提取JSON对象,增强健壮性""" - import orjson # 清理markdown标记 for prefix in ["```json", "```JSON", "```"]: @@ -52,7 +132,7 @@ def extract_json_from_text(text: str, key: str = "results") -> dict: text = text.strip() - # 尝试直接解析(支持格式化JSON) + # 尝试直接解析 for parser in [json.loads, lambda x: orjson.loads(x)]: try: result = parser(text) @@ -76,79 +156,11 @@ def extract_json_from_text(text: str, key: str = "results") -> dict: except: pass - # 尝试提取第一个完整的JSON对象 - results_pattern = r'"results"\s*:\s*\[' - match = re.search(results_pattern, text) - - if match: - start_pos = text.rfind('{', 0, match.start()) - if start_pos == -1: - start_pos = 0 - - # 从start_pos开始,找到匹配的 } - brace_count = 0 - bracket_count = 0 - in_string = False - escape = False - end_pos = start_pos - - for i in range(start_pos, len(text)): - char = text[i] - - if escape: - escape = False - continue - - if char == '\\': - escape = True - continue - - if char == '"' and not escape: - in_string = not in_string - continue - - if in_string: - continue - - if char == '{': - brace_count += 1 - elif char == '}': - brace_count -= 1 - if brace_count == 0 and bracket_count == 0: - end_pos = i + 1 - break - elif char == '[': - bracket_count += 1 - elif char == ']': - bracket_count -= 1 - - if end_pos > start_pos: - json_str = text[start_pos:end_pos] - - for parser in [json.loads, lambda x: orjson.loads(x)]: - try: - result = parser(json_str) - if isinstance(result, dict) and key in result: - logger.info("Successfully extracted and parsed JSON object") - return result - except Exception as e: - logger.debug(f"Failed to parse extracted JSON: {e}") - - # 如果提取失败,尝试修复不完整的JSON - if end_pos <= start_pos: - # 找到JSON开始位置 - json_str = text[start_pos:] - - # 尝试修复 - fixed_json = fix_incomplete_json(json_str) - for parser in [json.loads, lambda x: orjson.loads(x)]: - try: - result = parser(fixed_json) - if isinstance(result, dict) and key in result: - logger.info("Successfully parsed fixed incomplete JSON") - return result - except: - pass + # 尝试提取完整的对象列表(处理思考过程干扰) + objects = extract_complete_objects(text) + if objects: + logger.info(f"Extracted {len(objects)} complete objects from JSON with thinking") + return {key: objects} logger.warning(f"Failed to extract JSON with key '{key}' from text length {len(text)}") return {key: []} @@ -240,7 +252,7 @@ def recognize_and_correct_node( messages=messages, model=llm_config.get("model", "doubao-seed-2-0-pro-260215"), temperature=llm_config.get("temperature", 0.0), - max_completion_tokens=llm_config.get("max_completion_tokens", 4096) + max_completion_tokens=llm_config.get("max_completion_tokens", 8192) ) response_text = response.content if isinstance(response.content, str) else " ".join( @@ -255,8 +267,12 @@ def recognize_and_correct_node( result_dict = extract_json_from_text(response_text, "results") # 相对坐标(0-1000)转换为绝对坐标 - width_scale = image_info.width / 1000.0 if image_info.width > 0 else 1.0 - height_scale = image_info.height / 1000.0 if image_info.height > 0 else 1.0 + # 严格检查图片尺寸,防止除零和无效坐标 + img_width = max(1, image_info.width) if image_info.width > 0 else 1000 + img_height = max(1, image_info.height) if image_info.height > 0 else 1000 + + width_scale = img_width / 1000.0 + height_scale = img_height / 1000.0 question_items: List[QuestionItem] = [] correction_results: List[CorrectionResult] = [] @@ -275,6 +291,20 @@ def recognize_and_correct_node( if not isinstance(answer_bbox, list) or len(answer_bbox) != 4: answer_bbox = [0, 0, 0, 0] + # 确保bbox值在有效范围内(0-1000) + answer_bbox = [ + max(0, min(1000, int(answer_bbox[0]))) if isinstance(answer_bbox[0], (int, float)) else 0, + max(0, min(1000, int(answer_bbox[1]))) if isinstance(answer_bbox[1], (int, float)) else 0, + max(0, min(1000, int(answer_bbox[2]))) if isinstance(answer_bbox[2], (int, float)) else 0, + max(0, min(1000, int(answer_bbox[3]))) if isinstance(answer_bbox[3], (int, float)) else 0 + ] + + # 确保x2 >= x1, y2 >= y1 + if answer_bbox[2] < answer_bbox[0]: + answer_bbox[2] = answer_bbox[0] + if answer_bbox[3] < answer_bbox[1]: + answer_bbox[3] = answer_bbox[1] + # 转换为绝对坐标 answer_bbox_abs = [ int(answer_bbox[0] * width_scale), @@ -283,19 +313,42 @@ def recognize_and_correct_node( int(answer_bbox[3] * height_scale) ] - # 自动计算mark_position + # 严格限制绝对坐标在图片范围内 + answer_bbox_abs = [ + max(0, min(img_width, answer_bbox_abs[0])), + max(0, min(img_height, answer_bbox_abs[1])), + max(0, min(img_width, answer_bbox_abs[2])), + max(0, min(img_height, answer_bbox_abs[3])) + ] + + # 自动计算mark_position(智能定位,紧贴答案框) bbox_width = answer_bbox_abs[2] - answer_bbox_abs[0] bbox_height = answer_bbox_abs[3] - answer_bbox_abs[1] if bbox_width > 0 and bbox_height > 0: - mark_x = answer_bbox_abs[2] + 30 - mark_y = answer_bbox_abs[1] + int(bbox_height * 0.5) + # 计算答案框右侧的剩余空间 + right_space = img_width - answer_bbox_abs[2] - if mark_x > image_info.width - 50: - mark_x = image_info.width - 50 + # 策略1:如果右侧空间充足(>80px),标记紧贴答案框右侧 + if right_space > 80: + mark_x = answer_bbox_abs[2] + 10 + mark_y = answer_bbox_abs[1] + int(bbox_height * 0.5) + # 策略2:如果右侧空间不足但有一定空间(>40px),标记在答案框右上角内部 + elif right_space > 40: + mark_x = answer_bbox_abs[2] - 10 + mark_y = answer_bbox_abs[1] + 10 + # 策略3:如果右侧空间很小,标记在答案框左上角 + else: + mark_x = answer_bbox_abs[0] + 10 + mark_y = answer_bbox_abs[1] + 10 + + # 最终边界检查(严格限制在图片范围内,留10px边距) + mark_x = max(10, min(mark_x, img_width - 10)) + mark_y = max(10, min(mark_y, img_height - 10)) else: - mark_x = 500 - mark_y = 500 + # bbox无效时,使用图片中心 + mark_x = img_width // 2 + mark_y = img_height // 2 mark_position = MarkPosition(x=mark_x, y=mark_y) @@ -316,8 +369,13 @@ def recognize_and_correct_node( if status not in ["correct", "incorrect", "partial"]: status = "incorrect" - # 使用原始comment,不做截断(由LLM控制长度) - comment = str(r.get("comment", "")) + # 清理comment中的思考过程 + comment_raw = str(r.get("comment", "")) + comment = clean_comment(comment_raw) + + # 如果清理后comment为空,设置默认值 + if not comment: + comment = "正确" if status == "correct" else "错误" correction_results.append(CorrectionResult( question_id=str(r.get("question_id", "")), diff --git a/src/graphs/state.py b/src/graphs/state.py index 4c59dc6..94568ce 100644 --- a/src/graphs/state.py +++ b/src/graphs/state.py @@ -1,4 +1,4 @@ -"""初中物理作业批改工作流状态定义 - 支持多图片批改""" +"""初中物理作业批改工作流状态定义 - 支持多学生多图片批改""" from typing import List, Optional, Literal from pydantic import BaseModel, Field from utils.file.file import File @@ -6,6 +6,13 @@ from utils.file.file import File # === 基础数据结构 === +class StudentHomework(BaseModel): + """学生作业信息""" + student_id: int = Field(..., description="学生ID") + student_name: str = Field(default="", description="学生姓名") + homework_images: List[str] = Field(default=[], description="该学生的作业图片URL列表") + + class ImageInfo(BaseModel): """图片信息""" width: int = Field(..., description="图片宽度(像素)") @@ -74,36 +81,37 @@ class SingleImageResult(BaseModel): annotations: List[Annotation] = Field(default=[], description="该图片的批注列表") -class FinalResult(BaseModel): - """最终批改结果(多图片汇总)""" - total_images: int = Field(..., description="总图片数") +class StudentResult(BaseModel): + """单个学生的批改结果""" + student_id: int = Field(..., description="学生ID") + student_name: str = Field(default="", description="学生姓名") + total_images: int = Field(..., description="该学生的总图片数") image_results: List[SingleImageResult] = Field(default=[], description="各图片的批改结果") - overall_comment: str = Field(default="", description="整体评价") - total_score: int = Field(default=0, description="总分") - full_score: int = Field(default=100, description="满分") - grade: str = Field(default="", description="等级") + overall_comment: str = Field(default="", description="该学生的整体评价") + total_score: int = Field(default=0, description="该学生的总分") + full_score: int = Field(default=100, description="该学生的满分") + grade: str = Field(default="", description="该学生的等级") # === 全局状态 === class GlobalState(BaseModel): """工作流全局状态""" # 输入参数 - homework_images: List[File] = Field(default=[], description="上传的作业图片列表") + student_homework: List[StudentHomework] = Field(default=[], description="学生作业列表") answer_doc_url: str = Field(default="", description="正确答案Word文件的URL(.docx格式)") comment_max_length: int = Field(default=100, description="评语最大字数") max_concurrent: int = Field(default=10, description="并行批改的最大数量") grade_standards: dict = Field(default={}, description="评价等级标准") # 中间状态 correct_answers: List[CorrectAnswer] = Field(default=[], description="从Word解析的标准答案列表") - image_results: List[SingleImageResult] = Field(default=[], description="各图片的批改结果列表") # 最终结果 - final_result: Optional[FinalResult] = Field(default=None, description="最终汇总结果") + student_results: List[StudentResult] = Field(default=[], description="各学生的批改结果列表") # === 图输入输出 === class GraphInput(BaseModel): """工作流输入""" - homework_images: List[File] = Field(..., description="上传的作业图片列表") + student_homework: List[StudentHomework] = Field(..., description="学生作业列表,每个学生包含ID和作业图片") answer_doc_url: str = Field(default="", description="正确答案Word文件的URL(.docx格式,可选)") comment_max_length: int = Field(default=100, description="评语最大字数,默认100字") max_concurrent: int = Field(default=10, description="并行批改的最大数量,默认10") @@ -121,7 +129,7 @@ class GraphInput(BaseModel): class GraphOutput(BaseModel): """工作流输出""" - final_result: FinalResult = Field(..., description="最终批改结果JSON(包含多图片)") + student_results: List[StudentResult] = Field(..., description="各学生的批改结果列表") # === 文档答案解析节点 === @@ -220,13 +228,14 @@ class ResultMergeOutput(BaseModel): # === 循环节点 === class ProcessImagesInput(BaseModel): - """多图片处理循环节点输入""" - homework_images: List[File] = Field(default=[], description="作业图片列表") + """多学生作业处理循环节点输入""" + student_homework: List[StudentHomework] = Field(default=[], description="学生作业列表") correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表") comment_max_length: int = Field(default=100, description="评语最大字数") max_concurrent: int = Field(default=10, description="并行批改的最大数量") + grade_standards: dict = Field(default={}, description="评价等级标准") class ProcessImagesOutput(BaseModel): - """多图片处理循环节点输出""" - final_result: FinalResult = Field(..., description="最终批改结果") + """多学生作业处理循环节点输出""" + student_results: List[StudentResult] = Field(..., description="各学生的批改结果列表") diff --git a/src/utils/cache_manager.py b/src/utils/cache_manager.py new file mode 100644 index 0000000..2bc9df3 --- /dev/null +++ b/src/utils/cache_manager.py @@ -0,0 +1,285 @@ +"""缓存管理器:支持过期时间、内存保护、持久化""" +import os +import json +import hashlib +import logging +import threading +import tempfile +from pathlib import Path +from datetime import datetime, timedelta +from typing import Any, Optional, Dict +from functools import wraps + +logger = logging.getLogger(__name__) + +# 缓存配置 - 使用 tempfile 获取安全的临时目录 +try: + CACHE_DIR = Path(tempfile.gettempdir()) / "homework_cache" +except Exception: + CACHE_DIR = Path("/tmp/homework_cache") + +CACHE_EXPIRE_DAYS = 30 # 缓存过期天数 +MAX_MEMORY_CACHE_SIZE = 1000 # 内存缓存最大数量 + + +class CacheManager: + """ + 缓存管理器:内存缓存 + 文件缓存 + - 内存缓存:快速访问,有大小限制 + - 文件缓存:持久化存储,支持过期时间 + - 异常安全:文件操作失败不影响主流程 + """ + + def __init__(self, cache_name: str, maxsize: int = MAX_MEMORY_CACHE_SIZE, expire_days: int = CACHE_EXPIRE_DAYS): + """ + 初始化缓存管理器 + + Args: + cache_name: 缓存名称(用于区分不同类型的缓存) + maxsize: 内存缓存最大数量 + expire_days: 缓存过期天数 + """ + self.cache_name = cache_name + self.maxsize = maxsize + self.expire_days = expire_days + + # 内存缓存(使用字典 + 简单的LRU淘汰) + self._memory_cache: Dict[str, Any] = {} + self._cache_keys: list = [] # 记录访问顺序,用于LRU淘汰 + self._lock = threading.Lock() # 线程安全锁 + self._file_cache_enabled = True # 文件缓存是否可用 + + # 文件缓存目录 - 带异常处理 + try: + self.cache_dir = CACHE_DIR / cache_name + self.cache_dir.mkdir(parents=True, exist_ok=True) + # 测试写入权限 + test_file = self.cache_dir / ".test_write" + test_file.write_text("test") + test_file.unlink() + logger.info(f"CacheManager initialized: {cache_name}, dir={self.cache_dir}") + except Exception as e: + logger.warning(f"File cache disabled due to permission error: {e}") + self._file_cache_enabled = False + logger.info(f"CacheManager initialized (memory only): {cache_name}") + + def _get_cache_key(self, key: str) -> str: + """生成缓存键(使用MD5哈希)""" + return hashlib.md5(key.encode()).hexdigest() + + def _get_cache_file(self, cache_key: str) -> Path: + """获取缓存文件路径""" + return self.cache_dir / f"{cache_key}.json" + + def _is_expired(self, cache_time: str) -> bool: + """检查缓存是否过期""" + try: + cached_dt = datetime.fromisoformat(cache_time) + expire_dt = cached_dt + timedelta(days=self.expire_days) + return datetime.now() > expire_dt + except Exception: + return True # 解析失败视为过期 + + def _evict_lru(self): + """LRU淘汰:移除最久未使用的缓存项""" + while len(self._memory_cache) >= self.maxsize and self._cache_keys: + oldest_key = self._cache_keys.pop(0) + if oldest_key in self._memory_cache: + del self._memory_cache[oldest_key] + logger.debug(f"LRU evicted: {oldest_key[:8]}...") + + def get(self, key: str) -> Optional[Any]: + """ + 获取缓存 + + 优先级:内存缓存 > 文件缓存 > None + + Args: + key: 缓存键 + + Returns: + 缓存值,不存在或过期返回None + """ + cache_key = self._get_cache_key(key) + + # 1. 检查内存缓存(线程安全) + with self._lock: + if cache_key in self._memory_cache: + # 更新访问顺序(移动到末尾) + if cache_key in self._cache_keys: + self._cache_keys.remove(cache_key) + self._cache_keys.append(cache_key) + logger.debug(f"Memory cache hit: {cache_key[:8]}...") + return self._memory_cache[cache_key] + + # 2. 检查文件缓存(仅在文件缓存可用时) + if self._file_cache_enabled: + cache_file = self._get_cache_file(cache_key) + if cache_file.exists(): + try: + with open(cache_file, 'r', encoding='utf-8') as f: + cached_data = json.load(f) + + # 检查是否过期 + if self._is_expired(cached_data.get("cache_time", "")): + # 过期,删除文件 + try: + cache_file.unlink() + except Exception: + pass + logger.debug(f"File cache expired: {cache_key[:8]}...") + return None + + # 未过期,加载到内存缓存 + with self._lock: + self._evict_lru() # 淘汰旧的 + self._memory_cache[cache_key] = cached_data["data"] + self._cache_keys.append(cache_key) + + logger.debug(f"File cache hit: {cache_key[:8]}...") + return cached_data["data"] + + except Exception as e: + logger.warning(f"Failed to read cache file: {e}") + # 删除损坏的缓存文件 + try: + cache_file.unlink() + except Exception: + pass + return None + + return None + + def set(self, key: str, value: Any): + """ + 设置缓存 + + 同时存入内存缓存和文件缓存 + + Args: + key: 缓存键 + value: 缓存值 + """ + cache_key = self._get_cache_key(key) + + # 1. 存入内存缓存 + with self._lock: + self._evict_lru() # 淘汰旧的 + self._memory_cache[cache_key] = value + if cache_key in self._cache_keys: + self._cache_keys.remove(cache_key) + self._cache_keys.append(cache_key) + + # 2. 存入文件缓存(仅在文件缓存可用时) + if self._file_cache_enabled: + cache_file = self._get_cache_file(cache_key) + try: + cached_data = { + "cache_time": datetime.now().isoformat(), + "data": value + } + with open(cache_file, 'w', encoding='utf-8') as f: + json.dump(cached_data, f, ensure_ascii=False, indent=2) + + logger.debug(f"Cache saved: {cache_key[:8]}...") + except Exception as e: + logger.warning(f"Failed to save cache file: {e}") + + def clear_expired(self): + """清理所有过期的文件缓存""" + if not self._file_cache_enabled: + return 0 + + cleaned = 0 + try: + for cache_file in self.cache_dir.glob("*.json"): + try: + with open(cache_file, 'r', encoding='utf-8') as f: + cached_data = json.load(f) + + if self._is_expired(cached_data.get("cache_time", "")): + cache_file.unlink() + cleaned += 1 + except Exception: + # 损坏的文件也删除 + try: + cache_file.unlink() + except Exception: + pass + cleaned += 1 + + if cleaned > 0: + logger.info(f"Cleaned {cleaned} expired cache files") + except Exception as e: + logger.error(f"Failed to clear expired cache: {e}") + + return cleaned + + def get_stats(self) -> dict: + """获取缓存统计信息""" + memory_size = len(self._memory_cache) + + # 统计文件缓存数量 + file_size = 0 + if self._file_cache_enabled: + try: + file_size = len(list(self.cache_dir.glob("*.json"))) + except Exception: + pass + + return { + "cache_name": self.cache_name, + "memory_cache_size": memory_size, + "memory_cache_maxsize": self.maxsize, + "file_cache_size": file_size, + "file_cache_enabled": self._file_cache_enabled, + "expire_days": self.expire_days + } + + +def cached(cache_manager: CacheManager): + """ + 缓存装饰器 + + 用法: + @cached(answer_doc_cache) + def parse_answer_doc(url: str): + # 解析逻辑 + return result + """ + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + # 生成缓存键(使用函数名和参数) + cache_key = f"{func.__name__}:{str(args)}:{str(kwargs)}" + + # 尝试从缓存获取 + cached_result = cache_manager.get(cache_key) + if cached_result is not None: + return cached_result + + # 缓存未命中,执行函数 + result = func(*args, **kwargs) + + # 存入缓存 + if result is not None: + cache_manager.set(cache_key, result) + + return result + + return wrapper + return decorator + + +# 创建全局缓存实例 +answer_doc_cache = CacheManager( + cache_name="answer_doc", + maxsize=MAX_MEMORY_CACHE_SIZE, + expire_days=CACHE_EXPIRE_DAYS +) + +grade_standards_cache = CacheManager( + cache_name="grade_standards", + maxsize=100, # 评分标准缓存数量较少 + expire_days=CACHE_EXPIRE_DAYS +) diff --git a/test_deployment.sh b/test_deployment.sh new file mode 100644 index 0000000..89187fe --- /dev/null +++ b/test_deployment.sh @@ -0,0 +1,168 @@ +#!/bin/bash + +# ============================================ +# 初中物理作业批改工作流 - 部署验证脚本 +# ============================================ + +set -e + +echo "======================================" +echo " 部署验证测试" +echo "======================================" +echo "" + +# 颜色定义 +GREEN='\033[0;32m' +RED='\033[0;31m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# 测试计数 +PASS=0 +FAIL=0 + +# 测试函数 +test_step() { + local name=$1 + local command=$2 + + echo -n "测试: $name ... " + + if eval "$command" > /dev/null 2>&1; then + echo -e "${GREEN}✅ 通过${NC}" + ((PASS++)) + return 0 + else + echo -e "${RED}❌ 失败${NC}" + ((FAIL++)) + return 1 + fi +} + +# 1. 环境检查 +echo "1. 环境检查" +echo "-----------------------------------" + +test_step "Python 版本" "python3 --version" +test_step "虚拟环境" "test -d venv" +test_step "依赖安装" "python3 -c 'import fastapi'" +test_step "依赖安装" "python3 -c 'import langgraph'" +test_step "依赖安装" "python3 -c 'import openai'" +echo "" + +# 2. 配置检查 +echo "2. 配置检查" +echo "-----------------------------------" + +if [ -f ".env" ]; then + echo -e "${GREEN}✅ .env 文件存在${NC}" + ((PASS++)) + + # 检查必需环境变量 + source .env + + if [ -n "$LLM_API_KEY" ] && [ "$LLM_API_KEY" != "your-api-key-here" ]; then + echo -e "${GREEN}✅ LLM_API_KEY 已配置${NC}" + ((PASS++)) + else + echo -e "${RED}❌ LLM_API_KEY 未配置${NC}" + ((FAIL++)) + fi + + if [ -n "$LLM_BASE_URL" ]; then + echo -e "${GREEN}✅ LLM_BASE_URL 已配置${NC}" + ((PASS++)) + else + echo -e "${RED}❌ LLM_BASE_URL 未配置${NC}" + ((FAIL++)) + fi + + if [ -n "$LLM_MODEL_NAME" ]; then + echo -e "${GREEN}✅ LLM_MODEL_NAME 已配置${NC}" + ((PASS++)) + else + echo -e "${RED}❌ LLM_MODEL_NAME 未配置${NC}" + ((FAIL++)) + fi + + echo -e "${GREEN}✅ 无需配置对象存储(已优化)${NC}" + ((PASS++)) +else + echo -e "${RED}❌ .env 文件不存在${NC}" + ((FAIL++)) +fi +echo "" + +# 3. 文件完整性检查 +echo "3. 文件完整性检查" +echo "-----------------------------------" + +test_step "主入口文件" "test -f src/main.py" +test_step "主工作流" "test -f src/graphs/graph.py" +test_step "状态定义" "test -f src/graphs/state.py" +test_step "配置文件目录" "test -d config" +test_step "启动脚本" "test -f scripts/http_run.sh" +echo "" + +# 4. 模块导入测试 +echo "4. 模块导入测试" +echo "-----------------------------------" + +test_step "导入主模块" "python3 -c 'from graphs.state import GlobalState'" +test_step "导入节点模块" "python3 -c 'from graphs.nodes.doc_extract_node import doc_extract_node'" +echo "" + +# 5. 服务启动测试(可选) +echo "5. 服务启动测试" +echo "-----------------------------------" + +if command -v curl &> /dev/null; then + # 检查服务是否已启动 + if curl -s http://localhost:8000/health > /dev/null 2>&1; then + echo -e "${GREEN}✅ 服务已运行在 http://localhost:8000${NC}" + ((PASS++)) + + # 测试健康检查 + HEALTH=$(curl -s http://localhost:8000/health) + if echo "$HEALTH" | grep -q "ok"; then + echo -e "${GREEN}✅ 健康检查通过${NC}" + ((PASS++)) + else + echo -e "${RED}❌ 健康检查失败${NC}" + ((FAIL++)) + fi + else + echo -e "${YELLOW}⚠️ 服务未启动,跳过服务测试${NC}" + echo " 启动服务: bash scripts/http_run.sh -p 8000" + fi +else + echo -e "${YELLOW}⚠️ curl 未安装,跳过服务测试${NC}" +fi +echo "" + +# 测试总结 +echo "======================================" +echo " 测试总结" +echo "======================================" +echo "" +echo -e "${GREEN}通过: $PASS${NC}" +echo -e "${RED}失败: $FAIL${NC}" +echo "" + +if [ $FAIL -eq 0 ]; then + echo -e "${GREEN}✅ 所有测试通过!部署成功!${NC}" + echo "" + echo "下一步:" + echo " 1. 启动服务: bash scripts/http_run.sh -p 8000" + echo " 2. 访问文档: http://localhost:8000/docs" + echo " 3. 发送测试请求: curl -X POST http://localhost:8000/run -H 'Content-Type: application/json' -d @test_payload.json" + exit 0 +else + echo -e "${RED}❌ 部分测试失败,请检查配置${NC}" + echo "" + echo "常见问题:" + echo " - 环境变量未配置: 编辑 .env 文件" + echo " - 依赖未安装: pip install -r requirements.txt" + echo " - 文件缺失: 检查项目完整性" + exit 1 +fi diff --git a/test_payload.json b/test_payload.json new file mode 100644 index 0000000..954a505 --- /dev/null +++ b/test_payload.json @@ -0,0 +1,16 @@ +{ + "student_homework": [ + { + "student_id": 0, + "student_name": "测试学生", + "homework_images": [ + "https://example.com/homework1.jpg", + "https://example.com/homework2.jpg" + ] + } + ], + "answer_doc_url": "https://example.com/answer.docx", + "subject": "physics", + "comment_max_length": 100, + "max_concurrent": 10 +}