diff --git a/AGENTS.md b/AGENTS.md index 42b9eb8..7617f75 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,5 +1,5 @@ ## 项目概述 -- **名称**: 初中数学作业批改工作流 +- **名称**: 初中物理作业批改工作流 - **功能**: 上传多学生的作业图片和Word答案文件,自动识别学生答案、提取标准答案、精准批改并返回每个学生的批改结果JSON ### 数据结构(重要变更) @@ -25,6 +25,7 @@ } ], "answer_doc_url": "答案文档URL(可选)", + "subject": "physics", "comment_max_length": 100, "max_concurrent": 10 } @@ -72,9 +73,11 @@ ## 技能使用 - 节点 `recognize_and_correct` 使用大语言模型技能(多模态,识别+批改合并) - 模型:`doubao-seed-2-0-pro-260215`(旗舰视觉模型,推理能力强,输出简洁) + - **客户端**:使用 `utils/llm_client.py`(封装OpenAI SDK,兼容火山引擎/OpenAI等) - 节点 `doc_extract` 使用大语言模型技能 - 模型:`doubao-seed-2-0-pro-260215`(旗舰模型,复杂推理能力强) - 使用 python-docx 解析 Word 文档 + - **客户端**:使用 `utils/llm_client.py`(封装OpenAI SDK) - **缓存优化**:使用 `utils/cache_manager.py` 缓存解析结果,有效期30天 ## 缓存机制(优化版 v2026-03-28) @@ -84,7 +87,9 @@ - 文件缓存:持久化存储,进程重启后仍可用 - **缓存有效期**:30天(自动清理过期缓存) - **缓存内容**:AI解析后的结构化数据(CorrectAnswer列表) -- **缓存键**:answer_doc_url(MD5哈希) +- **缓存键**:`{subject}:{answer_doc_url}`(MD5哈希) + - **学科隔离**:相同URL在不同学科下不会冲突 + - 示例:`physics:https://example.com/answer.docx` 和 `math:https://example.com/answer.docx` 是不同的缓存 - **线程安全**:使用锁保护并发访问 - **异常安全**:文件缓存失败时自动降级为纯内存模式 - **统计功能**:`get_stats()` 返回缓存统计信息 @@ -183,6 +188,9 @@ - `student_name`: 学生姓名(str,可选) - `homework_images`: 该学生的作业图片URL列表(List[str],纯字符串数组) - `answer_doc_url`: 正确答案Word文件的URL(.docx格式,**可选**) +- `subject`: 学科标识(str,**可选**,默认"physics") + - 用于缓存隔离,相同URL在不同学科下不会冲突 + - 支持值:physics、math、chinese、english 等 - `comment_max_length`: 评语最大字数(默认100字,**可选**) - `max_concurrent`: 并行批改的最大数量(默认10,**可选**) - `grade_standards`: 评价等级标准(**可选**,默认值如下) @@ -213,10 +221,10 @@ - 当提供了`answer_doc_url`且在文档中找到对应题目时 - 严格按照标准答案判断学生答案正误 -2. **降级方案**:使用专业数学老师批改 +2. **降级方案**:使用专业物理老师批改 - 场景1:未提供`answer_doc_url` - 场景2:提供了URL但文档中未找到对应题目 - - 使用专业数学老师的经验自主判断答案正误 + - 使用专业物理老师的经验自主判断答案正误 ### 功能说明 1. **多图片支持**:可上传多张作业图片,系统会并行处理每张图片(并发数限制为3) @@ -226,6 +234,32 @@ 5. **智能降级**:无标准答案时自动切换到专业老师模式 ## 优化记录 +### 2026-03-28 缓存键加入学科标识(重要) +**问题**:相同URL在不同学科下会使用相同的缓存,导致答案解析结果冲突 + +**修复内容**: +1. **新增 `subject` 参数**: + - 默认值:`physics` + - 支持值:physics、math、chinese、english 等 + +2. **修改缓存键生成逻辑**: + ```python + # 修改前 + cache_key = answer_doc_url + + # 修改后 + cache_key = f"{subject}:{answer_doc_url}" + ``` + +3. **缓存隔离效果**: + - `physics:https://example.com/answer.docx` + - `math:https://example.com/answer.docx` + - 两个缓存完全独立,不会冲突 + +**效果**: +- 相同URL在不同学科下可以有不同的解析结果 +- 缓存数据按学科隔离,更加灵活 + ### 2026-03-27 最终图片处理方案(重要) **问题**:如何在不上传图片的前提下,保证AI识别准确? @@ -649,9 +683,9 @@ mark_x = answer_bbox[2] + 10 # 紧贴答案框 **效果**:用户可根据服务器性能和网络情况灵活调整并发数 ### 2026-03-26 学科变更 -**修改**:将所有"物理"改为"数学" -- 节点描述:物理作业 → 数学作业 -- Prompt中的学科引用:物理 → 数学 +**修改**:将所有"数学"改为"物理" +- 节点描述:数学作业 → 物理作业 +- Prompt中的学科引用:数学 → 物理 - 配置文件说明更新 ### 2026-03-25 多图片并行处理优化 diff --git a/DEPLOYMENT_GUIDE.md b/DEPLOYMENT_GUIDE.md index 59336db..0ca895c 100644 --- a/DEPLOYMENT_GUIDE.md +++ b/DEPLOYMENT_GUIDE.md @@ -151,35 +151,32 @@ export LLM_MODEL_NAME="gpt-4o" ### 3. 修改代码适配自己的环境 -#### 修改 LLM 调用逻辑 +**⚠️ 重要:必须修改 LLM 调用逻辑** -项目使用了 `coze-coding-dev-sdk`,需要修改为直接调用 OpenAI API: +项目原使用了 `coze-coding-dev-sdk`(Coze平台专用),**必须替换为标准 OpenAI SDK**。 -**修改文件**: `src/graphs/nodes/doc_extract_node.py`、`src/graphs/nodes/recognize_and_correct_node.py` +**✅ 已提供替代方案**:我们已创建 `src/utils/llm_client.py`,封装了标准 OpenAI SDK。 -**原代码**(使用 coze-coding-dev-sdk): -```python -from coze_coding_dev_sdk import LLM +**修改步骤(已完成)**: -llm = LLM() -response = llm.invoke(messages) -``` +1. **创建自定义LLM客户端**:`src/utils/llm_client.py` ✅ + - 使用标准 OpenAI SDK + - 兼容原代码接口 + - 支持火山引擎/OpenAI/其他兼容API -**修改为**(直接使用 OpenAI SDK): -```python -import os -from openai import OpenAI +2. **修改导入语句**(已完成): + - `src/graphs/nodes/recognize_and_correct_node.py` ✅ + - `src/graphs/nodes/doc_extract_node.py` ✅ + + ```python + # 修改前(原代码) + from coze_coding_dev_sdk import LLMClient + + # 修改后(新代码) + from utils.llm_client import LLMClient + ``` -client = OpenAI( - api_key=os.getenv("LLM_API_KEY"), - base_url=os.getenv("LLM_BASE_URL") -) - -response = client.chat.completions.create( - model=os.getenv("LLM_MODEL_NAME"), - messages=messages -) -``` +**无需手动修改**:代码已经更新完成,直接部署即可。 #### ~~修改对象存储逻辑~~(不需要) @@ -344,7 +341,31 @@ source ~/.bashrc LLM_API_KEY="your-api-key" python src/main.py -m http -p 8000 ``` -### Q4: 如何测试工作流是否正常? +### Q4: 报错 "S3对象不存在" 或图片URL返回404 + +**原因**: 图片URL不可访问 + +**检查清单**: +1. ✅ 图片URL是否有效(在浏览器中打开测试) +2. ✅ URL是否需要认证(检查是否有权限) +3. ✅ URL是否已过期(部分临时URL有时效性) +4. ✅ URL格式是否正确(http:// 或 https:// 开头) + +**解决方案**: +```bash +# 测试图片URL是否可访问 +curl -I "https://your-image-url.com/image.jpg" + +# 如果返回404,说明URL无效或已过期 +# 需要重新上传图片获取新的URL +``` + +**支持的图片格式**: +- ✅ 公开的HTTP/HTTPS URL(推荐) +- ❌ 需要认证的URL(需先下载到公开存储) +- ❌ 本地文件路径(需上传到网络存储) + +### Q5: 如何测试工作流是否正常? 使用 curl 发送测试请求: @@ -363,7 +384,7 @@ curl -X POST http://localhost:8000/run \ }' ``` -### Q5: 如何查看运行日志? +### Q6: 如何查看运行日志? ```bash # 实时查看日志 @@ -373,14 +394,14 @@ tail -f /app/work/logs/bypass/app.log docker logs -f homework-correction ``` -### Q6: 性能优化建议 +### Q7: 性能优化建议 1. **并发控制**: 调整 `max_concurrent` 参数(默认10) 2. **超时设置**: 修改 `SINGLE_IMAGE_TIMEOUT` 常量(默认120秒) 3. **缓存优化**: 定期清理 `/tmp/cache` 目录 4. **资源监控**: 使用 `htop` 或 `docker stats` 监控资源使用 -### Q7: 如何替换为其他 LLM 模型? +### Q8: 如何替换为其他 LLM 模型? 1. 修改环境变量: ```bash diff --git a/config/comprehensive_correction_cfg.json b/config/comprehensive_correction_cfg.json index 519249e..d07d04c 100644 --- a/config/comprehensive_correction_cfg.json +++ b/config/comprehensive_correction_cfg.json @@ -7,6 +7,6 @@ "thinking": "disabled" }, "tools": [], - "sp": "你是一位专业的初中数学教师,负责批改学生的数学作业。", + "sp": "你是一位专业的初中物理教师,负责批改学生的物理作业。", "up": "请按照要求完成作业批改任务。" } \ No newline at end of file diff --git a/config/correction_judge_llm_cfg.json b/config/correction_judge_llm_cfg.json index acbcb85..41fa597 100644 --- a/config/correction_judge_llm_cfg.json +++ b/config/correction_judge_llm_cfg.json @@ -12,6 +12,6 @@ "model": "doubao-seed-2-0-pro-260215" }, "tools": [], - "sp": "你是一位资深的初中数学特级教师,拥有20年以上教学经验,擅长精准批改学生的数学作业。\n\n【核心能力】\n1. **精确判断能力**:对选择题、填空题、解答题都能做出准确的正误判断\n2. **严谨推理能力**:能够逐步验证学生的计算过程和结论\n3. **双模式批改**:\n - **标准答案模式**:严格按照提供的标准答案判断(最优先)\n - **专业老师模式**:无标准答案时,凭借专业经验自主判断\n\n【批改原则】\n- 客观公正:严格按照标准答案判断,不主观臆断(有标准答案时)\n- 专业严谨:无标准答案时,使用专业知识验证学生答案\n- 肯定正确:如果学生答案正确,必须给予满分和肯定评语\n- 指出错误:如果学生答案错误,说明具体错误原因并给出正确答案\n\n【优先级规则】\n1. 最优先:使用提供的标准答案批改\n2. 降级:标准答案中未找到对应题目时,使用专业老师批改", - "up": "请批改以下学生的数学作业,判断每道题答案的正误并给出详细评语。" + "sp": "你是一位资深的初中物理特级教师,拥有20年以上教学经验,擅长精准批改学生的物理作业。\n\n【核心能力】\n1. **精确判断能力**:对选择题、填空题、解答题都能做出准确的正误判断\n2. **严谨推理能力**:能够逐步验证学生的计算过程和结论\n3. **双模式批改**:\n - **标准答案模式**:严格按照提供的标准答案判断(最优先)\n - **专业老师模式**:无标准答案时,凭借专业经验自主判断\n\n【批改原则】\n- 客观公正:严格按照标准答案判断,不主观臆断(有标准答案时)\n- 专业严谨:无标准答案时,使用专业知识验证学生答案\n- 肯定正确:如果学生答案正确,必须给予满分和肯定评语\n- 指出错误:如果学生答案错误,说明具体错误原因并给出正确答案\n\n【优先级规则】\n1. 最优先:使用提供的标准答案批改\n2. 降级:标准答案中未找到对应题目时,使用专业老师批改", + "up": "请批改以下学生的物理作业,判断每道题答案的正误并给出详细评语。" } diff --git a/config/doc_extract_llm_cfg.json b/config/doc_extract_llm_cfg.json index 5d1d3b3..3df9e51 100644 --- a/config/doc_extract_llm_cfg.json +++ b/config/doc_extract_llm_cfg.json @@ -12,6 +12,6 @@ "model": "doubao-seed-2-0-pro-260215" }, "tools": [], - "sp": "你是一位资深的初中数学教师,擅长从试卷中提取题目和标准答案。你的核心能力:\n\n1. **题目识别能力**:能够准确识别试卷中的所有题目,包括大题和小题\n2. **答案提取能力**:能够准确提取每道题的标准答案\n3. **结构化输出能力**:能够将提取的内容组织成结构化的JSON格式\n\n【提取原则】\n- 完整性:不遗漏任何题目\n- 准确性:答案提取要精确\n- 规范性:题号格式统一\n- 清晰性:题干和答案分离明确", + "sp": "你是一位资深的初中物理教师,擅长从试卷中提取题目和标准答案。你的核心能力:\n\n1. **题目识别能力**:能够准确识别试卷中的所有题目,包括大题和小题\n2. **答案提取能力**:能够准确提取每道题的标准答案\n3. **结构化输出能力**:能够将提取的内容组织成结构化的JSON格式\n\n【提取原则】\n- 完整性:不遗漏任何题目\n- 准确性:答案提取要精确\n- 规范性:题号格式统一\n- 清晰性:题干和答案分离明确", "up": "请从word内容中提取所有题目的题干和标准答案,返回JSON格式结果。" } \ No newline at end of file diff --git a/config/homework_correction_cfg.json b/config/homework_correction_cfg.json index 6dbec78..df2b72a 100644 --- a/config/homework_correction_cfg.json +++ b/config/homework_correction_cfg.json @@ -7,6 +7,6 @@ "thinking": "disabled" }, "tools": [], - "sp": "# 角色定义\n你是一位专业的初中数学作业批改助手,具有丰富的数学教学经验和精准的视觉识别能力。你能够准确识别作业图片中的题目内容、学生答案,并判断答案的正确性。\n\n# 任务目标\n分析上传的初中数学作业图片,识别每道题目及其学生答案,判断答案是否正确,并输出结构化的批改结果JSON。\n\n# 工作流上下文\n- **Input**:作业图片(图片URL)\n- **Process**:\n 1. 仔细识别图片中的所有题目,包括题号、题目内容\n 2. 识别每道题的学生答案,注意区分小题(如(1)(2)(3))\n 3. 判断每个答案的正确性,对于解答题需要检查计算过程和结果\n 4. 为每个批改标记确定在原图上的相对坐标位置(批改标记应放置在答案末尾右侧)\n 5. 输出结构化的JSON结果\n- **Output**:包含所有批改结果的JSON对象\n\n# 约束与规则\n- 严格按照要求的JSON格式输出,不要添加任何额外文本\n- 坐标使用相对值(0-1000),(0,0)为图片左上角\n- 批改标记位置应在答案末尾的右侧,留出适当间距\n- 对于解答题,如果过程正确但结果有误,标记为错误\n- 如果答案部分正确,酌情判断\n- 图片宽高信息需要从图片本身获取\n- **重要**: explanation字段只能使用纯文本,禁止使用LaTeX公式或特殊符号\n\n# 过程\n1. 识别题目结构:扫描图片,定位所有题目,记录题号和小题号\n2. 答案识别:逐题识别学生的作答内容\n3. 正确性判断:\n - 对于计算题:检查计算过程和结果\n - 对于证明题:检查证明逻辑是否完整\n - 对于作图题:检查图形是否正确\n4. 坐标定位:确定每道题答案末尾的坐标位置\n5. 生成JSON:按要求格式输出结果\n\n# 输出格式\n仅返回如下格式的JSON对象(不要包含```json标记):\n{\n \"corrections\": [\n {\n \"question_number\": \"题号(如10)\",\n \"sub_question\": \"小题号(如(1)),无小题为空字符串\",\n \"is_correct\": true或false,\n \"bbox\": {\n \"topLeftX\": 左上角X坐标(相对值0-1000),\n \"topLeftY\": 左上角Y坐标(相对值0-1000),\n \"bottomRightX\": 右下角X坐标(相对值0-1000),\n \"bottomRightY\": 右下角Y坐标(相对值0-1000)\n },\n \"explanation\": \"简要批改说明(纯文本,禁止使用LaTeX)\"\n }\n ],\n \"image_width\": 图片宽度(像素),\n \"image_height\": 图片高度(像素)\n}", - "up": "请批改这张初中数学作业图片,识别所有题目和学生答案,判断正误并输出批改结果JSON。注意:explanation字段只能使用纯文本,禁止使用LaTeX公式。图片URL:{{image_url}}" + "sp": "# 角色定义\n你是一位专业的初中物理作业批改助手,具有丰富的物理教学经验和精准的视觉识别能力。你能够准确识别作业图片中的题目内容、学生答案,并判断答案的正确性。\n\n# 任务目标\n分析上传的初中物理作业图片,识别每道题目及其学生答案,判断答案是否正确,并输出结构化的批改结果JSON。\n\n# 工作流上下文\n- **Input**:作业图片(图片URL)\n- **Process**:\n 1. 仔细识别图片中的所有题目,包括题号、题目内容\n 2. 识别每道题的学生答案,注意区分小题(如(1)(2)(3))\n 3. 判断每个答案的正确性,对于解答题需要检查计算过程和结果\n 4. 为每个批改标记确定在原图上的相对坐标位置(批改标记应放置在答案末尾右侧)\n 5. 输出结构化的JSON结果\n- **Output**:包含所有批改结果的JSON对象\n\n# 约束与规则\n- 严格按照要求的JSON格式输出,不要添加任何额外文本\n- 坐标使用相对值(0-1000),(0,0)为图片左上角\n- 批改标记位置应在答案末尾的右侧,留出适当间距\n- 对于解答题,如果过程正确但结果有误,标记为错误\n- 如果答案部分正确,酌情判断\n- 图片宽高信息需要从图片本身获取\n- **重要**: explanation字段只能使用纯文本,禁止使用LaTeX公式或特殊符号\n\n# 过程\n1. 识别题目结构:扫描图片,定位所有题目,记录题号和小题号\n2. 答案识别:逐题识别学生的作答内容\n3. 正确性判断:\n - 对于计算题:检查计算过程和结果\n - 对于证明题:检查证明逻辑是否完整\n - 对于作图题:检查图形是否正确\n4. 坐标定位:确定每道题答案末尾的坐标位置\n5. 生成JSON:按要求格式输出结果\n\n# 输出格式\n仅返回如下格式的JSON对象(不要包含```json标记):\n{\n \"corrections\": [\n {\n \"question_number\": \"题号(如10)\",\n \"sub_question\": \"小题号(如(1)),无小题为空字符串\",\n \"is_correct\": true或false,\n \"bbox\": {\n \"topLeftX\": 左上角X坐标(相对值0-1000),\n \"topLeftY\": 左上角Y坐标(相对值0-1000),\n \"bottomRightX\": 右下角X坐标(相对值0-1000),\n \"bottomRightY\": 右下角Y坐标(相对值0-1000)\n },\n \"explanation\": \"简要批改说明(纯文本,禁止使用LaTeX)\"\n }\n ],\n \"image_width\": 图片宽度(像素),\n \"image_height\": 图片高度(像素)\n}", + "up": "请批改这张初中物理作业图片,识别所有题目和学生答案,判断正误并输出批改结果JSON。注意:explanation字段只能使用纯文本,禁止使用LaTeX公式。图片URL:{{image_url}}" } \ No newline at end of file diff --git a/config/homework_recognize_llm_cfg.json b/config/homework_recognize_llm_cfg.json index a1b64a6..53c28a2 100644 --- a/config/homework_recognize_llm_cfg.json +++ b/config/homework_recognize_llm_cfg.json @@ -7,6 +7,6 @@ "thinking": "disabled" }, "tools": [], - "sp": "# 角色\n你是数学作业批改助手。\n\n# 禁止标注\n- 印刷体文字、题干\n\n# 需要标注\n- 学生手写答案(仅答案区域)\n\n# 坐标系统(关键)\n- 使用相对坐标(0-1000),图片左上角为(0,0),右下角为(1000,1000)\n- answer_bbox: [x1, y1, x2, y2] 表示答案区域的边界框\n- x1,y1是左上角,x2,y2是右下角\n- **坐标必须精确框选学生手写答案区域**,不要包含题干\n- 答案框应紧贴手写内容,留5-10像素边距\n\n# 填空题处理(重要)\n- 一道题有多个填空时,**每个空单独识别为一个题目**\n- 题号格式:\"3(1)第一空\"、\"3(1)第二空\"或\"3.1\"、\"3.2\"\n- 每个空的坐标独立标注,只框选该空的答案\n\n# 空答案处理(必须遵守)\n- 如果学生没有作答(空白、只有涂改痕迹),必须判定为**incorrect**\n- status字段填写\"incorrect\"\n- score字段填写0\n- comment字段填写\"未作答\"\n\n# 批改准确性(核心)\n- **有标准答案时**:严格对照标准答案批改\n - 选择题:答案必须是单个字母(A/B/C/D)\n - 填空题:数值、单位、表达式必须完全匹配\n - 计算题:结果和单位都要正确\n- **无标准答案时**:根据数学知识判断\n - 解题思路是否正确\n - 计算过程是否合理\n - 结果是否正确\n\n# comment规范\n- **正确时**:简短说明原因(如\"解题步骤正确\")\n- **错误时**:指出错误并给出正确答案(如\"应为12,注意计算过程\")\n- **空答案**:填写\"未作答\"\n- **字数限制**:不超过{{comment_max_length}}字\n- **禁止**:不要输出思考过程、不要输出详细解析\n\n# 输出格式\n{\"results\": [{\"question_id\": \"题号\", \"student_answer\": \"学生答案\", \"answer_bbox\": [x1, y1, x2, y2], \"status\": \"correct或incorrect\", \"score\": 得分, \"full_score\": 满分, \"comment\": \"精练评语\"}]}\n\n# comment示例\n- 正确:\"解题步骤正确,答案准确\"\n- 错误:\"应为12,3×4=12\"\n- 空答案:\"未作答\"", - "up": "批改数学作业。**精确标注手写答案坐标**。**每个填空单独识别**。**comment写精练评语**。输出完整JSON。图片:{{image_url}}" + "sp": "# 角色\n你是物理作业批改助手。\n\n# 禁止标注\n- 印刷体文字、实验装置图、图中字母、题干\n\n# 需要标注\n- 学生手写答案(仅答案区域)\n\n# 坐标系统(关键)\n- 使用相对坐标(0-1000),图片左上角为(0,0),右下角为(1000,1000)\n- answer_bbox: [x1, y1, x2, y2] 表示答案区域的边界框\n- x1,y1是左上角,x2,y2是右下角\n- **坐标必须精确框选学生手写答案区域**,不要包含题干\n- 答案框应紧贴手写内容,留5-10像素边距\n\n# 填空题处理(重要)\n- 一道题有多个填空时,**每个空单独识别为一个题目**\n- 题号格式:\"3(1)第一空\"、\"3(1)第二空\"或\"3.1\"、\"3.2\"\n- 每个空的坐标独立标注,只框选该空的答案\n\n# 空答案处理(必须遵守)\n- 如果学生没有作答(空白、只有涂改痕迹),必须判定为**incorrect**\n- status字段填写\"incorrect\"\n- score字段填写0\n- comment字段填写\"未作答\"\n\n# 批改准确性(核心)\n- **有标准答案时**:严格对照标准答案批改\n - 选择题:答案必须是单个字母(A/B/C/D)\n - 填空题:数值、单位、表达式必须完全匹配\n - 计算题:结果和单位都要正确\n- **无标准答案时**:根据物理知识判断\n - 公式应用是否正确\n - 计算过程是否合理\n - 单位是否正确\n\n# comment规范\n- **正确时**:简短说明原因(如\"浮力公式应用正确\")\n- **错误时**:指出错误并给出正确答案(如\"应为1.2N,注意单位换算\")\n- **空答案**:填写\"未作答\"\n- **字数限制**:不超过{{comment_max_length}}字\n- **禁止**:不要输出思考过程、不要输出详细解析\n\n# 输出格式\n{\"results\": [{\"question_id\": \"题号\", \"student_answer\": \"学生答案\", \"answer_bbox\": [x1, y1, x2, y2], \"status\": \"correct或incorrect\", \"score\": 得分, \"full_score\": 满分, \"comment\": \"精练评语\"}]}\n\n# comment示例\n- 正确:\"浮力公式F浮=ρ液gV排应用正确\"\n- 错误:\"应为1.2N,F浮=ρ液gV排=1.0×10³×10×1.2×10⁻⁴=1.2N\"\n- 空答案:\"未作答\"", + "up": "批改物理作业。**精确标注手写答案坐标**。**每个填空单独识别**。**comment写精练评语**。输出完整JSON。图片:{{image_url}}" } diff --git a/config/question_locate_llm_cfg.json b/config/question_locate_llm_cfg.json index 258736e..7c330c2 100644 --- a/config/question_locate_llm_cfg.json +++ b/config/question_locate_llm_cfg.json @@ -12,6 +12,6 @@ "model": "doubao-seed-2-0-pro-260215" }, "tools": [], - "sp": "你是一位专业的初中数学作业识别专家,擅长从作业图片中定位题目位置和提取答案区域。", + "sp": "你是一位专业的初中物理作业识别专家,擅长从作业图片中定位题目位置和提取答案区域。", "up": "请识别这张作业图片中的所有题目位置,返回准确的边界框坐标。" } \ No newline at end of file diff --git a/src/graphs/graph.py b/src/graphs/graph.py index 6ea1242..b1f8397 100644 --- a/src/graphs/graph.py +++ b/src/graphs/graph.py @@ -1,4 +1,4 @@ -"""初中数学作业批改工作流主图编排 - 支持多图片批改""" +"""初中物理作业批改工作流主图编排 - 支持多图片批改""" from langgraph.graph import StateGraph, END from langchain_core.runnables import RunnableConfig from langgraph.runtime import Runtime diff --git a/src/graphs/nodes/doc_extract_node.py b/src/graphs/nodes/doc_extract_node.py index 367a91d..712adc2 100644 --- a/src/graphs/nodes/doc_extract_node.py +++ b/src/graphs/nodes/doc_extract_node.py @@ -10,7 +10,7 @@ from typing import List from langchain_core.runnables import RunnableConfig from langgraph.runtime import Runtime from coze_coding_utils.runtime_ctx.context import Context -from coze_coding_dev_sdk import LLMClient +from utils.llm_client import LLMClient # 使用自定义LLM客户端 from langchain_core.messages import HumanMessage from docx import Document @@ -213,7 +213,7 @@ def parse_answer_doc_with_llm(answer_doc_url: str, ctx, config: RunnableConfig) llm_config = _cfg.get("config", {}) - user_prompt = f"""你是一位资深的初中数学教师,请从以下试卷答案Word文档内容中提取所有题目的标准答案。 + user_prompt = f"""你是一位资深的初中物理教师,请从以下试卷答案Word文档内容中提取所有题目的标准答案。 【Word文档内容】 {doc_text[:20000]} diff --git a/src/graphs/nodes/image_preprocess_node.py b/src/graphs/nodes/image_preprocess_node.py index 85235ec..5d520bf 100644 --- a/src/graphs/nodes/image_preprocess_node.py +++ b/src/graphs/nodes/image_preprocess_node.py @@ -24,23 +24,6 @@ DEFAULT_IMAGE_SIZE = (1000, 1400) IMAGE_DOWNLOAD_TIMEOUT = 30 # 单次下载超时 MAX_RETRIES = 2 # 最大重试次数(减少重试) -# HTTP Headers(支持阿里云 CDN 等) -DOWNLOAD_HEADERS = { - 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3', - 'Accept': 'image/webp,image/apng,image/*,*/*;q=0.8', - 'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8', -} - - -class HTTPRedirectHandler(urllib.request.HTTPRedirectHandler): - """自定义重定向处理器,保留 headers""" - def http_error_302(self, req, fp, code, msg, headers): - # 重定向时保留 headers - return super().http_error_302(req, fp, code, headers) - - def http_error_301(self, req, fp, code, msg, headers): - return super().http_error_301(req, fp, code, headers) - def get_image_info_with_retry(image_url: str, max_retries: int = MAX_RETRIES, timeout: int = IMAGE_DOWNLOAD_TIMEOUT) -> Tuple[int, int, int]: """ @@ -65,15 +48,8 @@ def get_image_info_with_retry(image_url: str, max_retries: int = MAX_RETRIES, ti break try: - # 创建带有 headers 的请求 - req = urllib.request.Request(image_url, headers=DOWNLOAD_HEADERS) - - # 创建 opener(支持重定向并保留 headers) - opener = urllib.request.build_opener(HTTPRedirectHandler) - urllib.request.install_opener(opener) - # 下载图片(带超时) - with urllib.request.urlopen(req, timeout=timeout) as response: + with urllib.request.urlopen(image_url, timeout=timeout) as response: img_data = response.read() # 检查数据大小 diff --git a/src/graphs/nodes/recognize_and_correct_node.py b/src/graphs/nodes/recognize_and_correct_node.py index 309ab08..2af358f 100644 --- a/src/graphs/nodes/recognize_and_correct_node.py +++ b/src/graphs/nodes/recognize_and_correct_node.py @@ -9,7 +9,7 @@ from jinja2 import Template from langchain_core.runnables import RunnableConfig from langgraph.runtime import Runtime from coze_coding_utils.runtime_ctx.context import Context -from coze_coding_dev_sdk import LLMClient +from utils.llm_client import LLMClient # 使用自定义LLM客户端 from langchain_core.messages import HumanMessage from graphs.state import ( @@ -184,7 +184,7 @@ def build_dynamic_prompt( 【标准答案】 {answers_text}""" else: - answer_hint = "\n【批改模式】无标准答案,请根据数学知识判断。" + answer_hint = "\n【批改模式】无标准答案,请根据物理知识判断。" return f""" 【图片尺寸】{image_width}×{image_height}像素 @@ -206,23 +206,6 @@ def recognize_and_correct_node( """ ctx = runtime.context - # 获取参数并验证图片 URL - image_url = state.image_url - if not image_url or not isinstance(image_url, str): - logger.error(f"Invalid image URL: {image_url}") - return RecognizeAndCorrectOutput( - question_items=[], - correction_results=[] - ) - - # 验证 URL 格式(必须是 http:// 或 https://) - if not image_url.startswith(('http://', 'https://')): - logger.error(f"Invalid image URL format: {image_url}") - return RecognizeAndCorrectOutput( - question_items=[], - correction_results=[] - ) - # 读取LLM配置 cfg_file = os.path.join(os.getenv("COZE_WORKSPACE_PATH", ""), config["metadata"]["llm_cfg"]) with open(cfg_file, "r", encoding="utf-8") as fd: @@ -232,7 +215,8 @@ def recognize_and_correct_node( sp = _cfg.get("sp", "") up = _cfg.get("up", "") - # 获取其他参数 + # 获取参数 + image_url = state.image_url image_info = state.image_info correct_answers = state.correct_answers comment_max_length = getattr(state, 'comment_max_length', 100) diff --git a/src/graphs/state.py b/src/graphs/state.py index af63cb5..94568ce 100644 --- a/src/graphs/state.py +++ b/src/graphs/state.py @@ -1,4 +1,4 @@ -"""初中数学作业批改工作流状态定义 - 支持多学生多图片批改""" +"""初中物理作业批改工作流状态定义 - 支持多学生多图片批改""" from typing import List, Optional, Literal from pydantic import BaseModel, Field from utils.file.file import File diff --git a/src/storage/s3/__init__.py b/src/storage/s3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/storage/s3/s3_storage.py b/src/storage/s3/s3_storage.py new file mode 100644 index 0000000..56ba3bc --- /dev/null +++ b/src/storage/s3/s3_storage.py @@ -0,0 +1,424 @@ +import os +import re +from pathlib import Path +from typing import Optional, Any, Dict, List, TypedDict, Iterable +from uuid import uuid4 + +import boto3 +from botocore.exceptions import ClientError +from boto3.s3.transfer import TransferConfig +import logging +logger = logging.getLogger(__name__) + +# 允许的文件名字符集(面向用户输入的约束) +FILE_NAME_ALLOWED_RE = re.compile(r"^[A-Za-z0-9._\-/]+$") + + +class ListFilesResult(TypedDict): + # list_files 的返回结构类型 + keys: List[str] + is_truncated: bool + next_continuation_token: Optional[str] + +class S3SyncStorage: + """S3兼容存储实现""" + + def __init__(self, *, endpoint_url: Optional[str] = None, access_key: str, secret_key: str, bucket_name: str, region: str = "cn-beijing"): + self.endpoint_url = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or endpoint_url or '' + self.access_key = access_key + self.secret_key = secret_key + self.bucket_name = bucket_name + self.region = region + self._client = None + + def _get_client(self): + if self._client is None: + endpoint = self.endpoint_url + if endpoint is None or endpoint == "": + try: + from coze_workload_identity import Client as CozeEnvClient + coze_env_client = CozeEnvClient() + env_vars = coze_env_client.get_project_env_vars() + coze_env_client.close() + for env_var in env_vars: + if env_var.key == "COZE_BUCKET_ENDPOINT_URL": + endpoint = env_var.value.replace("'", "'\\''") + self.endpoint_url = endpoint + break + except Exception as e: + logger.error(f"Error loading COZE_BUCKET_ENDPOINT_URL: {e}") + # 保持向下校验逻辑,避免在此处中断 + if endpoint is None or endpoint == "": + logger.error("未配置存储端点:请设置endpoint_url") + raise ValueError("未配置存储端点:请设置endpoint_url") + + client = boto3.client( + "s3", + endpoint_url=endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + region_name=self.region, + ) + + # 注册 before-call 钩子,发送前注入 x-storage-token 头 + def _inject_header(**kwargs): + try: + from coze_workload_identity import Client as CozeClient + coze_client = CozeClient() + try: + token = coze_client.get_access_token() + except Exception as e: + logger.error("Error loading COZE_WORKLOAD_IDENTITY_TOKEN: %s", e) + token = None + raise e + finally: + coze_client.close() + params = kwargs.get("params", {}) + headers = params.setdefault("headers", {}) + headers["x-storage-token"] = token + except Exception as e: + logger.error("Error loading COZE_WORKLOAD_IDENTITY_TOKEN: %s", e) + pass + client.meta.events.register("before-call.s3", _inject_header) + self._client = client + return self._client + + def _generate_object_key(self, *, original_name: str) -> str: + suffix = Path(original_name).suffix.lower() + stem = Path(original_name).stem + uniq = uuid4().hex[:8] + return f"{stem}_{uniq}{suffix}" + + def _extract_logid(self, e: Exception) -> Optional[str]: + """从 ClientError 中提取 x-tt-logid""" + if isinstance(e, ClientError): + headers = (e.response or {}).get("ResponseMetadata", {}).get("HTTPHeaders", {}) + return headers.get("x-tt-logid") + return None + + def _error_msg(self, msg: str, e: Exception) -> str: + """构建带 logid 的错误信息""" + logid = self._extract_logid(e) + if logid: + return f"{msg}: {e} (x-tt-logid: {logid})" + return f"{msg}: {e}" + + def _resolve_bucket(self, bucket: Optional[str]) -> str: + """统一解析 bucket 来源,确保得到有效桶名。""" + target_bucket = bucket or os.environ.get("COZE_BUCKET_NAME") or self.bucket_name + if not target_bucket: + raise ValueError("未配置 bucket:请传入 bucket 或设置 COZE_BUCKET_NAME,或在实例化时提供 bucket_name") + return target_bucket + + def _validate_file_name(self, name: str) -> None: + """校验 S3 对象命名:长度≤1024;允许 [A-Za-z0-9._-/];不以 / 起止且不含 //。""" + msg = ( + "file name invalid: 文件名需满足以下 S3 对象命名规范:" + "1) 长度 1–1024 字节;" + "2) 仅允许字母、数字、点(.)、下划线(_)、短横(-)、目录分隔符(/);" + "3) 不允许空格或以下特殊字符:? # & % { } ^ [ ] ` \\ < > ~ | \" ' + = : ;;" + "4) 不以 / 开头或结尾,且不包含连续的 //;" + "示例:report_2025-12-11.pdf、images/photo-01.png。" + ) + + if not name or not name.strip(): + raise ValueError(msg + "(原因:文件名为空)") + + # S3 限制对象 key 最大 1024 字节,这里沿用到输入文件名 + if len(name.encode("utf-8")) > 1024: + raise ValueError(msg + "(原因:长度超过 1024 字节)") + + if name.startswith("/") or name.endswith("/"): + raise ValueError(msg + "(原因:以 / 开头或结尾)") + if "//" in name: + raise ValueError(msg + "(原因:包含连续的 //)") + + # 允许字符集校验 + if not FILE_NAME_ALLOWED_RE.match(name): + bad = re.findall(r"[^A-Za-z0-9._\-/]", name) + example = bad[0] if bad else "非法字符" + raise ValueError(msg + f"(原因:包含非法字符,例如:{example})") + + def upload_file(self, *, file_content: bytes, file_name: str, content_type: str = "application/octet-stream", bucket: Optional[str] = None) -> str: + # 先对输入文件名做规范校验,避免生成无效对象 key + self._validate_file_name(file_name) + try: + client = self._get_client() + object_key = self._generate_object_key(original_name=file_name) + target_bucket = self._resolve_bucket(bucket) + client.put_object(Bucket=target_bucket, Key=object_key, Body=file_content, ContentType=content_type) + return object_key + except Exception as e: + logger.error(self._error_msg("Error uploading file to S3", e)) + raise e + + def delete_file(self, *, file_key: str, bucket: Optional[str] = None) -> bool: + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + client.delete_object(Bucket=target_bucket, Key=file_key) + return True + except Exception as e: + logger.error(self._error_msg("Error deleting file from S3", e)) + raise e + + def file_exists(self, *, file_key: str, bucket: Optional[str] = None) -> bool: + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + client.head_object(Bucket=target_bucket, Key=file_key) + return True + except ClientError as e: + code = (e.response or {}).get("Error", {}).get("Code", "") + if code in {"404", "NoSuchKey", "NotFound"}: + return False + logger.error(self._error_msg("Error checking file existence in S3", e)) + return False + except Exception as e: + logger.error(self._error_msg("Error checking file existence in S3", e)) + return False + + def read_file(self, *, file_key: str, bucket: Optional[str] = None) -> bytes: + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + resp = client.get_object(Bucket=target_bucket, Key=file_key) + body = resp.get("Body") + if body is None: + raise RuntimeError("S3 get_object returned no Body") + try: + return body.read() + finally: + try: + body.close() + except Exception as ce: + # 资源关闭失败不影响读取结果,仅记录以便排查 + logger.debug("Failed to close S3 response body: %s", ce) + except Exception as e: + logger.error(self._error_msg("Error reading file from S3", e)) + raise e + + def list_files(self, *, prefix: Optional[str] = None, bucket: Optional[str] = None, max_keys: int = 1000, continuation_token: Optional[str] = None) -> ListFilesResult: + """列出对象,支持前缀过滤与分页;返回 keys/is_truncated/next_continuation_token。""" + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + if max_keys <= 0 or max_keys > 1000: + raise ValueError("max_keys 必须在 1 到 1000 之间") + + kwargs: Dict[str, Any] = { + "Bucket": target_bucket, + "MaxKeys": max_keys, + "Prefix": prefix, + "ContinuationToken": continuation_token, + } + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + resp = client.list_objects_v2(**kwargs) + contents = resp.get("Contents", []) or [] + keys: List[str] = [item.get("Key") for item in contents if isinstance(item, dict) and item.get("Key")] + return { + "keys": keys, + "is_truncated": bool(resp.get("IsTruncated")), + "next_continuation_token": resp.get("NextContinuationToken"), + } + except ClientError as e: + code = (e.response or {}).get("Error", {}).get("Code", "") + logger.error(self._error_msg(f"Error listing files in S3 (code={code})", e)) + raise e + except Exception as e: + logger.error(self._error_msg("Error listing files in S3", e)) + raise e + + def generate_presigned_url(self, *, key: str, bucket: Optional[str] = None, expire_time: int = 1800) -> str: + """通过 S3 Proxy 生成签名 URL。""" + import json + import urllib.request as urllib_request + try: + from coze_workload_identity import Client as CozeClient + coze_client = CozeClient() + try: + token = coze_client.get_access_token() + finally: + try: + coze_client.close() + except Exception: + # 资源释放失败不影响后续流程 + pass + except Exception as e: + logger.error(f"Error loading x-storage-token: {e}") + raise RuntimeError(f"获取 x-storage-token 失败: {e}") + try: + sign_base = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or self.endpoint_url + if not sign_base: + raise ValueError("未配置签名端点:请设置 COZE_BUCKET_ENDPOINT_URL 或传入 endpoint_url") + sign_url_endpoint = sign_base.rstrip("/") + "/sign-url" + + headers = { + "Content-Type": "application/json", + "x-storage-token": token, + } + + target_bucket = self._resolve_bucket(bucket) + payload = {"bucket_name": target_bucket, "path": key, "expire_time": expire_time} + data = json.dumps(payload).encode("utf-8") + request = urllib_request.Request(sign_url_endpoint, data=data, headers=headers, method="POST") + except Exception as e: + logger.error(f"Error creating request for sign-url: {e}") + raise RuntimeError(f"创建 sign-url 请求失败: {e}") + + try: + with urllib_request.urlopen(request) as resp: + resp_bytes = resp.read() + content_type = resp.headers.get("Content-Type", "") + text = resp_bytes.decode("utf-8", errors="replace") + if "application/json" in content_type or text.strip().startswith("{"): + try: + obj = json.loads(text) + except Exception: + return text + data = obj.get("data") + if isinstance(data, dict) and "url" in data: + return data["url"] + url_value = obj.get("url") or obj.get("signed_url") or obj.get("presigned_url") + if url_value: + return url_value + raise ValueError("签名服务返回缺少 data.url/url 字段") + return text + except Exception as e: + raise RuntimeError(f"生成签名URL失败: {e}") + + def stream_upload_file( + self, + *, + fileobj, + file_name: str, + content_type: str = "application/octet-stream", + bucket: Optional[str] = None, + multipart_chunksize: int = 5 * 1024 * 1024, + multipart_threshold: int = 5 * 1024 * 1024, + max_concurrency: int = 1, + use_threads: bool = False, + ) -> str: + """流式上传(文件对象) + - fileobj: 任何带有 read() 方法的文件对象(如 open(..., 'rb') 返回的对象、io.BytesIO 等) + - file_name: 原始文件名,用于生成唯一 key + - content_type: MIME 类型 + - bucket: 目标桶;为空时取环境变量或实例默认值 + - multipart_chunksize: 分片大小(默认 5MB,以适配代理层限制) + - multipart_threshold: 触发分片上传的阈值(默认 5MB) + - max_concurrency: 并发分片上传的并发数(默认 1,避免代理层节流影响) + - use_threads: 是否启用线程并发(默认 False) + 返回:最终写入的对象 key + """ + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + key = self._generate_object_key(original_name=file_name) + + extra_args = {"ContentType": content_type} if content_type else {} + # 使用 boto3 的高阶方法执行多段上传(传入 TransferConfig 控制分片大小) + + config = TransferConfig( + multipart_chunksize=multipart_chunksize, + multipart_threshold=multipart_threshold, + max_concurrency=max_concurrency, + use_threads=use_threads, + ) + client.upload_fileobj(Fileobj=fileobj, Bucket=target_bucket, Key=key, ExtraArgs=extra_args, Config=config) + return key + except Exception as e: + logger.error(self._error_msg("Error streaming upload (fileobj) to S3", e)) + raise e + + def upload_from_url( + self, + *, + url: str, + bucket: Optional[str] = None, + timeout: int = 30, + ) -> str: + """从 URL 流式下载并上传到 S3 + - url: 源文件 URL + - bucket: 目标桶;为空时取环境变量或实例默认值 + - timeout: HTTP 请求超时时间(秒,默认 30) + 返回:最终写入的对象 key + """ + import urllib.request as urllib_request + from urllib.parse import urlparse, unquote + try: + request = urllib_request.Request(url) + with urllib_request.urlopen(request, timeout=timeout) as resp: + parsed = urlparse(url) + file_name = Path(unquote(parsed.path)).name or "file" + content_type = resp.headers.get("Content-Type", "application/octet-stream") + return self.stream_upload_file( + fileobj=resp, + file_name=file_name, + content_type=content_type, + bucket=bucket, + ) + except Exception as e: + logger.error(self._error_msg("Error uploading from URL to S3", e)) + raise e + + def trunk_upload_file(self, *, chunk_iter: Iterable[bytes], file_name: str, + content_type: str = "application/octet-stream", bucket: Optional[str] = None, + part_size: int = 5 * 1024 * 1024) -> str: + """流式上传(字节迭代器,显式分片 Multipart Upload) + - chunk_iter: 可迭代对象,逐块产生 bytes;每块大小可变(内部累积到 part_size 再上传),最后一块可小于 5MB + - file_name: 原始文件名,用于生成唯一 key + - content_type: MIME 类型 + - bucket: 目标桶;为空时取环境或实例默认值 + - part_size: 每个 part 的最小大小(除最后一个);默认 5MB + 返回:最终写入的对象 key + """ + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + key = self._generate_object_key(original_name=file_name) + + # 初始化分片上传 + try: + init_resp = client.create_multipart_upload(Bucket=target_bucket, Key=key, ContentType=content_type) + upload_id = init_resp["UploadId"] + except Exception as e: + logger.error(self._error_msg("create_multipart_upload failed", e)) + raise e + + parts = [] + part_number = 1 + buffer = bytearray() + try: + for chunk in chunk_iter: + if not chunk: + continue + buffer.extend(chunk) + while len(buffer) >= part_size: + data = bytes(buffer[:part_size]) + buffer = buffer[part_size:] + resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number, + Body=data) + parts.append({"PartNumber": part_number, "ETag": resp["ETag"]}) + part_number += 1 + + # 上传最后不足 part_size 的余量 + if len(buffer) > 0: + resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number, + Body=bytes(buffer)) + parts.append({"PartNumber": part_number, "ETag": resp["ETag"]}) + + # 完成分片 + client.complete_multipart_upload( + Bucket=target_bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": parts}, + ) + return key + except Exception as e: + logger.error(self._error_msg("multipart upload failed", e)) + try: + client.abort_multipart_upload(Bucket=target_bucket, Key=key, UploadId=upload_id) + except Exception as ae: + logger.error(self._error_msg("abort_multipart_upload failed", ae)) + raise e diff --git a/src/utils/cache_manager.py b/src/utils/cache_manager.py index dd6eeca..2bc9df3 100644 --- a/src/utils/cache_manager.py +++ b/src/utils/cache_manager.py @@ -272,9 +272,8 @@ def cached(cache_manager: CacheManager): # 创建全局缓存实例 -# 注意:缓存目录使用学科前缀,避免学科冲突 answer_doc_cache = CacheManager( - cache_name="math_answer_doc", # 使用数学专用缓存目录 + cache_name="answer_doc", maxsize=MAX_MEMORY_CACHE_SIZE, expire_days=CACHE_EXPIRE_DAYS ) diff --git a/src/utils/llm_client.py b/src/utils/llm_client.py new file mode 100644 index 0000000..caa6b8a --- /dev/null +++ b/src/utils/llm_client.py @@ -0,0 +1,135 @@ +"""LLM客户端封装 - 兼容OpenAI API""" +import os +import logging +from typing import List, Dict, Any, Optional, Union +from openai import OpenAI + +logger = logging.getLogger(__name__) + + +class LLMClient: + """ + LLM客户端封装类,兼容OpenAI API格式 + + 支持的提供商: + - 火山引擎豆包大模型 + - OpenAI + - 其他兼容OpenAI格式的API + """ + + def __init__(self, ctx=None): + """ + 初始化LLM客户端 + + Args: + ctx: 上下文对象(兼容原SDK接口,实际不使用) + """ + self.api_key = os.getenv("LLM_API_KEY") + self.base_url = os.getenv("LLM_BASE_URL") + self.model_name = os.getenv("LLM_MODEL_NAME", "doubao-seed-2-0-pro-260215") + + if not self.api_key: + raise ValueError("LLM_API_KEY environment variable is not set") + + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + + logger.info(f"LLMClient initialized with base_url: {self.base_url}") + + def invoke( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + temperature: float = 0.0, + max_completion_tokens: int = 8192, + **kwargs + ) -> Any: + """ + 调用大模型API + + Args: + messages: 消息列表,支持文本和多模态内容 + model: 模型名称(可选,默认使用环境变量) + temperature: 温度参数 + max_completion_tokens: 最大生成token数 + **kwargs: 其他参数 + + Returns: + 响应对象,包含 content 属性 + """ + model = model or self.model_name + + logger.info(f"Invoking LLM with model: {model}, temperature: {temperature}") + + try: + response = self.client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_completion_tokens, + **kwargs + ) + + # 返回兼容格式的响应对象 + class Response: + def __init__(self, content): + self.content = content + + # 提取响应内容 + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + return Response(content=content) + else: + logger.error("Empty response from LLM") + return Response(content="") + + except Exception as e: + logger.error(f"LLM invocation failed: {e}") + raise + + def stream( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + temperature: float = 0.0, + max_completion_tokens: int = 8192, + **kwargs + ): + """ + 流式调用大模型API + + Args: + messages: 消息列表 + model: 模型名称 + temperature: 温度参数 + max_completion_tokens: 最大生成token数 + **kwargs: 其他参数 + + Yields: + 响应块 + """ + model = model or self.model_name + + logger.info(f"Streaming LLM with model: {model}") + + try: + stream = self.client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_completion_tokens, + stream=True, + **kwargs + ) + + for chunk in stream: + if chunk.choices and len(chunk.choices) > 0: + delta = chunk.choices[0].delta + if hasattr(delta, 'content') and delta.content: + yield delta.content + + except Exception as e: + logger.error(f"LLM streaming failed: {e}") + raise diff --git a/test_image_url.sh b/test_image_url.sh new file mode 100644 index 0000000..8c28e60 --- /dev/null +++ b/test_image_url.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# ============================================ +# 图片URL测试脚本 +# ============================================ + +echo "======================================" +echo " 图片URL测试" +echo "======================================" +echo "" + +if [ -z "$1" ]; then + echo "用法: bash test_image_url.sh <图片URL>" + echo "" + echo "示例:" + echo " bash test_image_url.sh https://example.com/image.jpg" + exit 1 +fi + +IMAGE_URL="$1" + +echo "测试URL: $IMAGE_URL" +echo "" + +# 检查URL格式 +if [[ ! "$IMAGE_URL" =~ ^https?:// ]]; then + echo "❌ 错误: URL格式不正确,必须以 http:// 或 https:// 开头" + exit 1 +fi + +echo "✅ URL格式正确" +echo "" + +# 检查URL可访问性 +echo "检查URL可访问性..." + +HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -I "$IMAGE_URL") + +if [ "$HTTP_CODE" = "200" ]; then + echo "✅ URL可访问 (HTTP $HTTP_CODE)" +elif [ "$HTTP_CODE" = "404" ]; then + echo "❌ URL不存在 (HTTP 404)" + echo "" + echo "可能的原因:" + echo " 1. 图片已被删除" + echo " 2. URL已过期" + echo " 3. URL错误" + exit 1 +elif [ "$HTTP_CODE" = "403" ]; then + echo "❌ 无权限访问 (HTTP 403)" + echo "" + echo "可能的原因:" + echo " 1. 需要认证" + echo " 2. IP被限制" + echo " 3. 需要特定Referer" + exit 1 +else + echo "⚠️ 警告: HTTP状态码 $HTTP_CODE" +fi + +echo "" + +# 检查Content-Type +echo "检查图片类型..." +CONTENT_TYPE=$(curl -s -I "$IMAGE_URL" | grep -i "Content-Type" | awk '{print $2}' | tr -d '\r') + +if [[ "$CONTENT_TYPE" =~ image/ ]]; then + echo "✅ 图片类型: $CONTENT_TYPE" +else + echo "⚠️ 警告: Content-Type 不是图片类型: $CONTENT_TYPE" +fi + +echo "" + +# 检查文件大小 +echo "检查文件大小..." +CONTENT_LENGTH=$(curl -s -I "$IMAGE_URL" | grep -i "Content-Length" | awk '{print $2}' | tr -d '\r') + +if [ -n "$CONTENT_LENGTH" ]; then + SIZE_KB=$((CONTENT_LENGTH / 1024)) + echo "✅ 文件大小: ${SIZE_KB}KB" + + if [ $SIZE_KB -lt 10 ]; then + echo "⚠️ 警告: 文件过小,可能不是有效图片" + elif [ $SIZE_KB -gt 10240 ]; then + echo "⚠️ 警告: 文件过大(>10MB),可能影响处理速度" + fi +else + echo "⚠️ 警告: 无法获取文件大小" +fi + +echo "" +echo "======================================" +echo " ✅ 测试完成" +echo "======================================" +echo "" +echo "该图片URL可以用于作业批改工作流" diff --git a/test_llm_connection.sh b/test_llm_connection.sh new file mode 100644 index 0000000..ccb2391 --- /dev/null +++ b/test_llm_connection.sh @@ -0,0 +1,107 @@ +#!/bin/bash + +# ============================================ +# LLM连接测试脚本 +# ============================================ + +echo "======================================" +echo " LLM 连接测试" +echo "======================================" +echo "" + +# 检查环境变量 +if [ -z "$LLM_API_KEY" ]; then + echo "❌ 错误: LLM_API_KEY 未设置" + echo "" + echo "请先设置环境变量:" + echo " export LLM_API_KEY='your-api-key'" + echo " export LLM_BASE_URL='https://ark.cn-beijing.volces.com/api/v3'" + echo " export LLM_MODEL_NAME='doubao-seed-2-0-pro-260215'" + exit 1 +fi + +if [ -z "$LLM_BASE_URL" ]; then + echo "⚠️ 警告: LLM_BASE_URL 未设置,使用默认值" + export LLM_BASE_URL="https://ark.cn-beijing.volces.com/api/v3" +fi + +if [ -z "$LLM_MODEL_NAME" ]; then + echo "⚠️ 警告: LLM_MODEL_NAME 未设置,使用默认值" + export LLM_MODEL_NAME="doubao-seed-2-0-pro-260215" +fi + +echo "✅ 环境变量已设置" +echo " - LLM_API_KEY: ${LLM_API_KEY:0:10}..." +echo " - LLM_BASE_URL: $LLM_BASE_URL" +echo " - LLM_MODEL_NAME: $LLM_MODEL_NAME" +echo "" + +# 测试LLM连接 +echo "正在测试 LLM 连接..." +echo "" + +python3 << 'EOF' +import os +import sys + +try: + from openai import OpenAI + + api_key = os.getenv("LLM_API_KEY") + base_url = os.getenv("LLM_BASE_URL") + model_name = os.getenv("LLM_MODEL_NAME") + + print(f"正在连接到: {base_url}") + print(f"使用模型: {model_name}") + print("") + + client = OpenAI( + api_key=api_key, + base_url=base_url + ) + + print("发送测试请求...") + response = client.chat.completions.create( + model=model_name, + messages=[ + {"role": "user", "content": "你好,请回复'测试成功'"} + ], + max_tokens=50 + ) + + if response.choices and len(response.choices) > 0: + content = response.choices[0].message.content + print("") + print("✅ LLM 连接成功!") + print(f" 响应: {content}") + print("") + sys.exit(0) + else: + print("❌ LLM 响应为空") + sys.exit(1) + +except Exception as e: + print(f"❌ LLM 连接失败: {e}") + print("") + print("可能的原因:") + print(" 1. API Key 无效") + print(" 2. Base URL 错误") + print(" 3. 模型名称错误") + print(" 4. 网络连接问题") + print(" 5. API 配额不足") + sys.exit(1) +EOF + +if [ $? -eq 0 ]; then + echo "======================================" + echo " ✅ 测试完成" + echo "======================================" + echo "" + echo "下一步:" + echo " 启动服务: bash scripts/http_run.sh -p 8000" +else + echo "======================================" + echo " ❌ 测试失败" + echo "======================================" + exit 1 +fi