commit 3eb42ade2c04eae13487acb945072780608f774f Author: zhangquan Date: Thu Mar 26 11:54:35 2026 +0800 项目初始化 diff --git a/.coze b/.coze new file mode 100644 index 0000000..4f98389 --- /dev/null +++ b/.coze @@ -0,0 +1,14 @@ +[project] +entrypoint = "src/main.py" +requires = ["python-3.12"] + +[dev] +build = ["bash", "scripts/setup.sh"] +run = ["bash", "/workspace/projects/scripts/http_run.sh", "-p 5000"] +pack = ["bash", "/workspace/projects/scripts/pack.sh"] +deps = ["git"] # -> apt install git + +[deploy] +build = ["bash", "scripts/setup.sh"] +run = ["bash", "scripts/http_run.sh", "-p 5000"] +deps = ["git"] # -> apt install git \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2c298eb --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +node_modules +dist +.DS_Store +*.swp +.git/* +__pycache__/ +*.pyc +.venv/* +*.o +*.a +/vendor +*.egg-info/ +/.env +.env + +# 视频文件 +*.mp4 +*.avi +*.mov +*.wmv +*.flv +*.mkv +*.webm +*.m4v +*.mpeg +*.mpg +*.3gp +*.f4v +*.rmvb +*.vob + +# 归档文件 +*.iso +*.dmg +*.rar +*.zip +*.gz + +# 文档类也默认加入 +*.pdf +*.docx +*.doc +*.xlsx +*.xls +*.ppt +*.pptx +*.xlsx +*.csv diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..1c7eb30 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,391 @@ +## 项目概述 +- **名称**: 初中物理作业批改工作流 +- **功能**: 上传多张作业图片和Word答案文件,自动识别学生答案、提取标准答案、精准批改并返回批改结果JSON + +### 节点清单 +| 节点名 | 文件位置 | 类型 | 功能描述 | 分支逻辑 | 配置文件 | +|-------|---------|------|---------|---------|---------| +| doc_extract | `nodes/doc_extract_node.py` | agent | 从Word文件(.docx)提取题干和标准答案;如无URL则返回空列表 | - | `config/doc_extract_llm_cfg.json` | +| process_images | `nodes/process_images_node.py` | looparray | 循环调用子图处理每张作业图片,生成最终批改结果 | - | - | + +**类型说明**: task(普通任务节点) / agent(大模型节点) / condition(条件分支) / looparray(列表循环) / loopcond(条件循环) + +## 子图清单 +| 子图名 | 文件位置 | 功能描述 | 被调用节点 | +|-------|---------|------|---------| +| single_image_subgraph | `graphs/loop_graph.py` | 处理单张图片的完整批改流程(预处理→识别批改→整合→包装) | process_images | + +### 子图内部节点 +| 节点名 | 文件位置 | 类型 | 功能描述 | +|-------|---------|------|---------| +| image_preprocess | `nodes/image_preprocess_node.py` | task | 下载图片、自动旋转(横向→纵向)、缩放到固定宽度1000px、上传对象存储 | +| recognize_and_correct | `nodes/recognize_and_correct_node.py` | agent | **一体化识别批改**:合并识别题目和批改为一次LLM调用 | +| result_merge | `nodes/result_merge_node.py` | task | 将识别结果和批改结果合并为最终批注 | +| wrap_result | `graphs/loop_graph.py` | task | 包装子图结果为SingleImageResult输出 | + +## 技能使用 +- 节点 `recognize_and_correct` 使用大语言模型技能(多模态,识别+批改合并) + - 模型:`doubao-seed-2-0-pro-260215`(旗舰视觉模型,推理能力强,输出简洁) +- 节点 `doc_extract` 使用大语言模型技能 + - 模型:`doubao-seed-2-0-pro-260215`(旗舰模型,复杂推理能力强) + - 使用 python-docx 解析 Word 文档 + +## 工作流程(多图片批改架构) + +``` +┌─────────────────────┐ +│ doc_extract │ +│ (Word答案解析) │ +└──────────┬──────────┘ + │ + ▼ +┌─────────────────────┐ +│ process_images │ +│ (多图片循环处理) │ +│ 生成最终批改结果 │ +└─────────────────────┘ +``` + +### 子图内部流程(处理单张图片 - 优化版) + +``` +┌─────────────────────┐ +│ image_preprocess │ +│ (图像预处理) │ +└──────────┬──────────┘ + │ + ▼ +┌─────────────────────┐ +│recognize_and_correct│ ← 合并节点:识别+批改一次完成 +│ (一体化识别批改) │ +└──────────┬──────────┘ + │ + ▼ +┌─────────────────────┐ +│ result_merge │ +│ (结果整合) │ +└──────────┬──────────┘ + │ + ▼ +┌─────────────────────┐ +│ wrap_result │ +│ (包装输出) │ +└─────────────────────┘ +``` + +## 核心功能:多图片批改机制 + +### 输入参数 +- `homework_images`: 上传的作业图片列表(List[File],支持多张图片) +- `answer_doc_url`: 正确答案Word文件的URL(.docx格式,**可选**) +- `comment_max_length`: 评语最大字数(默认100字,**可选**) +- `max_concurrent`: 并行批改的最大数量(默认10,**可选**) +- `grade_standards`: 
评价等级标准(**可选**,默认值如下) + ```json + { + "A+": {"min_percentage": 95, "description": "答案全部正确,步骤完整规范,逻辑严谨;书写/格式整洁,无错别字、无遗漏;完成度100%,态度认真,质量上乘"}, + "A": {"min_percentage": 90, "description": "答案完全正确,无任何错误;步骤合理、格式规范,无原则性问题;完成度100%,满足全部要求"}, + "B": {"min_percentage": 80, "description": "存在少量非关键性错误,或步骤略有缺失;整体思路基本正确,仅细节、格式、计算等小问题;完成大部分内容,整体合格但不够严谨"}, + "C": {"min_percentage": 70, "description": "错误较多,部分核心题目作答错误;步骤不完整、逻辑不够清晰;完成度一般,有明显应付、漏答情况"}, + "D": {"min_percentage": 0, "description": "大面积错误,核心知识点未掌握;大量空白、敷衍、抄袭;未达到基本完成要求"} + } + ``` + +### 输出结果 +- `final_result`: 最终批改结果JSON(包含多图片) + - `total_images`: 总图片数 + - `image_results`: 各图片的批改结果列表 + - `overall_comment`: 整体评价(根据得分率生成) + - `total_score`: 总得分 + - `full_score`: 总满分 + - `grade`: 等级评定 + +### 批改优先级(严格按照以下顺序) +1. **最优先**:使用Word文档中的标准答案批改 + - 当提供了`answer_doc_url`且在文档中找到对应题目时 + - 严格按照标准答案判断学生答案正误 + +2. **降级方案**:使用专业物理老师批改 + - 场景1:未提供`answer_doc_url` + - 场景2:提供了URL但文档中未找到对应题目 + - 使用专业物理老师的经验自主判断答案正误 + +### 功能说明 +1. **多图片支持**:可上传多张作业图片,系统会并行处理每张图片(并发数限制为3) +2. **Word答案提取**:从.docx文件中提取题干和标准答案 +3. **子图循环处理**:使用子图封装单图片处理流程,主图调用子图处理每张图片 +4. **批改结果JSON**:返回包含所有图片批改结果的结构化JSON +5. **智能降级**:无标准答案时自动切换到专业老师模式 + +## 优化记录 +### 2026-03-26 填空题拆分优化(重要) +**问题**:一道题有多个填空时,被合并成一个答案,批改标记无法精准定位 + +**修复内容**: +1. **优化Prompt**: + - 明确要求:一道题有多个填空时,**每个空单独识别为一个题目** + - 题号格式:\"3(1)第一空\"、\"3(1)第二空\"、\"4(2)第一空\"、\"4(2)第二空\" + - 每个空单独批改,单独打分 + +2. **示例说明**: + ``` + ❌ 错误:3(1) → "4、1"(合并) + ✅ 正确:3(1)第一空 → "4" + 3(1)第二空 → "1" + ``` + +3. **参数传递优化**: + - comment_max_length参数正确传递到Jinja2模板 + - 确保LLM生成符合长度要求的comment + +**效果**: +- 识别数量从9个增加到13个 +- 每个填空都有独立的批改标记 +- 批改标记精准定位到每个答案位置 + +### 2026-03-26 JSON解析优化(重要) +**问题**:LLM输出可能不完整(被max_completion_tokens截断),导致JSON解析失败 + +**修复内容**: +1. **新增fix_incomplete_json函数**: + - 自动检测缺失的括号(}和]) + - 自动补全缺失的括号,使JSON完整 + - 示例:`{"results": [{"id": 1}` → 自动补全为 `{"results": [{"id": 1}]}` + +2. **增强JSON解析流程**: + - 第一步:尝试直接解析 + - 第二步:尝试修复不完整的JSON(补全括号) + - 第三步:尝试提取JSON对象 + - 第四步:尝试修复提取的JSON + +3. **移除错误的截断逻辑**: + - 不再在解析后截断comment(可能破坏转义字符) + - 完全依赖LLM遵守comment_max_length限制 + - 通过Prompt明确要求LLM控制comment长度 + +4. **参数正确传递**: + - comment_max_length参数正确传递到Prompt + - LLM根据该参数生成符合长度的comment + +**效果**: +- JSON解析成功率大幅提升 +- 能够处理不完整的JSON输出 +- comment长度由LLM控制,避免截断破坏格式 + +### 2026-03-26 识别优化:禁止标注实验装置图(重要) +**问题**: +1. 在实验装置图(如弹簧测力计、烧杯等)上标注了批改气泡 +2. 坐标定位不够精准 + +**修复内容**: +1. **Prompt优化**: + - 明确禁止标注实验装置图、示意图、电路图 + - 明确禁止标注图中标注的字母(如A、B、C、D、E、F、G) + - 强调只标注学生手写答案 + +2. **工程规范优化**: + - 从config文件读取sp和up(符合工程规范) + - 使用Jinja2模板渲染Prompt + - 代码中只保留动态部分构建(标准答案、图片尺寸等) + +3. **识别流程优化**: + - 找题号 → 找学生答案 → 框选答案 → 判断正误 + - 强调学生答案的特征:手写、填写空白处、计算结果 + +**效果**:不再误标注实验装置图,只标注学生手写答案 + +### 2026-03-26 新增并行数量控制参数 +**优化前**:硬编码并发数限制为3,不够灵活 +**优化后**:添加max_concurrent参数,默认值10,用户可自定义 + +**具体优化**: +1. **新增参数**:`max_concurrent`(可选,默认10) +2. **修改位置**: + - `GraphInput.max_concurrent: Optional[int] = 10` + - `GlobalState.max_concurrent: int = 10` + - `ProcessImagesInput.max_concurrent: int` +3. **使用方式**: + ```json + { + "homework_images": [...], + "max_concurrent": 5 // 最多同时处理5张图片 + } + ``` + +**效果**:用户可根据服务器性能和网络情况灵活调整并发数 + +### 2026-03-26 学科变更 +**修改**:将所有"数学"改为"物理" +- 节点描述:数学作业 → 物理作业 +- Prompt中的学科引用:数学 → 物理 +- 配置文件说明更新 + +### 2026-03-25 多图片并行处理优化 +**优化前**:多图片串行处理,总时间 = 单张图片时间 × 图片数量 +**优化后**:多图片并行处理(并发数限制为3),总时间大幅缩短 + +**具体优化**: +1. **并行处理架构**:使用 `ThreadPoolExecutor` 并行调用子图处理每张图片 + - 最多同时处理3张图片 + - 结果按 `image_index` 正确排序,保证顺序一致性 +2. **性能提升**: + - 3张图片:时间减少约66%(从3份时间 → 1份时间) + - 5张图片:时间减少约80%(从5份时间 → 约2份时间,分两批并行) +3. 
**质量保证**:
+   - 每张图片独立处理,互不影响
+   - 识别逻辑、批改逻辑完全相同,质量不受影响
+
+### 2026-03-26 坐标定位修复(重要)
+**问题**:坐标定位特别不准,批改标记位置错误
+**原因**:Y坐标修正逻辑错误,导致坐标被错误缩放
+
+**修复内容**:
+1. **坐标系统重构**:从绝对坐标改为相对坐标(0-1000)系统
+   - AI返回相对坐标(0-1000),(0,0)为图片左上角,(1000,1000)为右下角
+   - 代码将相对坐标转换为绝对坐标:`绝对X = 相对X × width / 1000`,`绝对Y = 相对Y × height / 1000`
+
+2. **Prompt优化**:
+   - 明确要求AI返回相对坐标(0-1000)
+   - 添加坐标系统说明和示例
+
+3. **转换逻辑修正**:
+   - 移除错误的Y坐标修正(`Y × height_ratio`)
+   - 实现正确的相对坐标到绝对坐标转换
+
+**效果**:坐标定位准确,批改标记位置正确
+
+### 2026-03-26 题目和答案识别优化(重要)
+**问题**:
+1. 无法准确区分"题干"和"学生答案"
+2. 批改气泡不在学生答案位置
+3. 题干位置被误标注为答案
+
+**修复内容**:
+1. **Prompt重写**:
+   - 明确定义"题干"和"学生答案"的区别
+   - 强调只标注学生手写答案,不标注印刷体题干
+   - 添加识别流程指导
+
+2. **坐标定位优化**:
+   - 自动计算mark_position:答案框右侧30像素,垂直居中
+   - 添加边界检查,确保不超出图片范围
+   - 不再依赖AI返回的mark_position(可能不准确)
+
+3. **识别指导**:
+   - 题号识别:如1、2、3、(1)、(2)等
+   - 答案定位:学生手写内容(不是印刷体)
+   - bbox框选:准确框选学生答案区域
+
+**效果**:更准确地区分题干和答案,批改气泡位置更精准
+
+### 2026-03-25 批改速度优化
+**优化前**:每张图片需要3次LLM调用(识别+批改+整体评价)
+**优化后**:每张图片只需1次LLM调用
+
+**具体优化**:
+1. **合并识别和批改**:将`homework_recognize`和`correction_judge`合并为`recognize_and_correct`节点
+   - 识别题目、学生答案、坐标的同时进行批改
+   - 减少一次LLM调用,速度提升约50%
+
+2. **简化整体评价**:不再调用LLM生成整体评价
+   - 使用规则直接生成评价内容
+   - 根据得分率和错误数量生成个性化评语
+   - 减少一次LLM调用
+
+3. **子图节点精简**:从5个节点减少到4个节点
+   - 移除:homework_recognize、correction_judge
+   - 新增:recognize_and_correct(合并节点)
+   - 保留:image_preprocess、result_merge、wrap_result
+
+**效果**:
+- LLM调用次数:每张图片从3次减少到1次
+- 预计批改时间减少约60%
+
+### 2026-03-25 新增输入参数控制
+1. **新增 `comment_max_length` 参数**:控制评语最大字数,默认100字
+2. **新增 `grade_standards` 参数**:自定义评价等级标准
+   - 支持自定义各等级的最低得分率百分比
+   - 支持自定义各等级的描述
+   - 默认标准:A+(≥95%)、A(≥90%)、B(≥80%)、C(≥70%)、D(<70%)
+3. **使用方式**:
+   ```json
+   {
+     "homework_images": [...],
+     "comment_max_length": 50,  // 评语最多50字
+     "grade_standards": {
+       "A+": {"min_percentage": 98, "description": "完美"},
+       "A": {"min_percentage": 90, "description": "优秀"},
+       ...
+     }
+   }
+   ```
+
+### 2026-03-25 评语优化与整体评价
+1. **评语具体化**:批改评语要求具体说明对错原因
+   - 正确时:说明为什么正确
+   - 错误时:指出错误原因并给出正确答案
+   - 部分正确时:说明哪些对了哪些错了
+   - 字数限制:50字以内,最多不超过100字
+   - 不要显示思考过程,只输出结果
+2. **评语示例**:
+   - 选择题正确:答案为B,与标准答案一致,正确。
+   - 填空题错误:答案应为8+√7和8-√7,学生只写了一个,不完整。
+   - 解答题正确:解题过程完整,步骤清晰,结果正确。
+   - 计算题错误:计算过程有误,正确答案是m=2,建议检查移项步骤。
+3. **整体评价**:根据所有批改内容自动生成简短的整体评价
+   - 调用LLM生成个性化评价
+   - 评价不超过50字
+   - 包含主要问题或优点
+   - 给出简短建议
+4. **HTML报告优化**:在统计总览后显示整体评价区域
+
+### 2026-03-25 自动旋转功能
+1. **新增横向图片自动旋转**:如果上传的图片宽度大于高度(横向图片),系统会自动旋转-90度使其变为纵向
+2. **旋转时机**:在图像预处理阶段,下载图片后、缩放前进行旋转
+3. **旋转方向**:顺时针旋转90度(rotate(-90);PIL的rotate正角度为逆时针),确保文字方向正确
+4. **日志记录**:添加详细的旋转日志,便于调试
+
+### 2026-03-25 多图片批改功能
+1. **新增多图片支持**:从单图片批改升级为支持多图片批量批改
+2. **新增子图架构**:创建 `loop_graph.py` 封装单图片处理流程
+3. **新增循环节点**:创建 `process_images_node.py` 循环调用子图处理每张图片
+4. **重构状态定义**:
+   - `GraphInput.homework_image` → `homework_images: List[File]`
+   - 新增 `SubgraphState`、`SubgraphInput`、`SubgraphOutput` 子图状态
+   - 新增 `SingleImageResult` 单图片批改结果
+   - 新增 `FinalResult.image_results` 多图片结果列表
+5. **重构HTML生成**:支持生成包含所有图片批改标注的HTML报告
+6. **优化主图编排**:简化为三节点线性流程(doc_extract → process_images → html_generate)
+
+### 2026-03-25 双模式批改机制
+1. **新增智能降级逻辑**:优先使用标准答案,无标准答案时自动切换专业老师模式
+2. **修改state.py**:`answer_doc_url`改为可选字段,支持不提供答案URL的场景
+3. **升级correction_judge_node**:实现题目分离逻辑,有标准答案和无标准答案分别处理
+4. **更新Prompt**:批改节点支持两种模式(标准答案模式 + 专业老师模式)
+5. **优化doc_extract_node**:无URL时返回空列表,不中断工作流
+
+### 2026-03-25 Word答案解析功能
+1. 新增 `doc_extract_node` 节点:从Word文件(.docx)提取题干和标准答案
+2. 使用 python-docx 提取 Word 文档内容(解析思路见下方示意代码)
+3. 并行处理架构:图像识别与答案解析同时进行
+4. 基于 Word 中的标准答案进行精准批改
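+
+下面给出 python-docx 解析的最小示意(仅为说明用的简化代码,非 `doc_extract_node` 的完整实现;函数名 `extract_docx_text` 为本示例假设):
+
+```python
+from docx import Document
+
+def extract_docx_text(path: str) -> str:
+    """按 doc_extract_node 的思路,依次读取段落与表格文本(简化示意)"""
+    doc = Document(path)
+    parts = [p.text.strip() for p in doc.paragraphs if p.text.strip()]
+    for table in doc.tables:
+        for row in table.rows:
+            cells = [c.text.strip() for c in row.cells if c.text.strip()]
+            if cells:
+                parts.append(" | ".join(cells))
+    return "\n".join(parts)
+```
+
+实际节点在此基础上还负责下载远程文件、清理临时文件,并把提取出的文本交给 LLM 做结构化解析。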
+### 2026-03-25 OCR识别能力优化
+1. 问题:识别节点把学生答案中的"8"错认成"9",导致误判
+2. 优化识别节点prompt:增加OCR识别特别提示,强调区分8和9、6和0、1和7等相似字符
+3. 效果:第7题正确识别为"8+√7,8-√7",满分通过
+
+### 2026-03-25 批改能力升级
+1. 升级批改节点模型:`doubao-seed-1-6-vision-250815` → `doubao-seed-2-0-pro-260215`
+2. 原因:较小模型对选择题判断准确率不足
+3. 效果:选择题判断准确率大幅提升,推理过程更严谨
+
+### 2026-03-24 重构(学习豆包APP方式)
+1. 从8个节点简化为5个节点(现调整为子图+主图架构)
+2. 采用一体化识别:AI识别answer_bbox,代码计算mark_position
+3. 实现精准坐标计算,Y坐标与答案垂直中心完美对齐
+
+## TODO
+- 提供真实的多页作业图片进行完整流程测试
+- 优化HTML报告的图片展示布局
+- 支持PDF格式答案文档
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f288427
--- /dev/null
+++ b/README.md
@@ -0,0 +1,12 @@
+# 项目结构说明
+
+# 本地运行
+## 运行流程
+bash scripts/local_run.sh -m flow
+
+## 运行节点
+bash scripts/local_run.sh -m node -n node_name
+
+# 启动HTTP服务
+bash scripts/http_run.sh -p 5000
+
diff --git a/assets/1.png b/assets/1.png
new file mode 100644
index 0000000..74909f0
Binary files /dev/null and b/assets/1.png differ
diff --git a/assets/ScreenShot_2026-03-26_110020_335.png b/assets/ScreenShot_2026-03-26_110020_335.png
new file mode 100644
index 0000000..e911d29
Binary files /dev/null and b/assets/ScreenShot_2026-03-26_110020_335.png differ
diff --git a/assets/你的.png b/assets/你的.png
new file mode 100644
index 0000000..8c6850e
Binary files /dev/null and b/assets/你的.png differ
diff --git a/assets/你的1.png b/assets/你的1.png
new file mode 100644
index 0000000..8c6850e
Binary files /dev/null and b/assets/你的1.png differ
diff --git a/assets/你的2.png b/assets/你的2.png
new file mode 100644
index 0000000..027782d
Binary files /dev/null and b/assets/你的2.png differ
diff --git a/assets/你的3.png b/assets/你的3.png
new file mode 100644
index 0000000..f0a929c
Binary files /dev/null and b/assets/你的3.png differ
diff --git a/assets/数学作业批改.jpg b/assets/数学作业批改.jpg
new file mode 100644
index 0000000..02e307f
Binary files /dev/null and b/assets/数学作业批改.jpg differ
diff --git a/config/answer_extract_llm_cfg.json b/config/answer_extract_llm_cfg.json
new file mode 100644
index 0000000..7bf9a1b
--- /dev/null
+++ b/config/answer_extract_llm_cfg.json
@@ -0,0 +1,17 @@
+{
+    "config": {
+        "temperature": 0.1,
+        "frequency_penalty": 0,
+        "top_p": 0.9,
+        "max_tokens": 4096,
+        "max_completion_tokens": 8192,
+        "thinking_type": "enabled",
+        "reasoning_effort": "medium",
+        "response_format": "text",
+        "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
+        "model": "doubao-seed-2-0-pro-260215"
+    },
+    "tools": [],
+    "sp": "你是一位专业的文档识别专家,擅长从作业图片中提取答案区域的位置信息。",
+    "up": "请根据题目位置信息,识别每道题对应的答案区域边界框。"
+}
\ No newline at end of file
diff --git a/config/answer_recognize_llm_cfg.json b/config/answer_recognize_llm_cfg.json
new file mode 100644
index 0000000..bf3a563
--- /dev/null
+++ b/config/answer_recognize_llm_cfg.json
@@ -0,0 +1,17 @@
+{
+    "config": {
+        "temperature": 0.1,
+        "frequency_penalty": 0,
+        "top_p": 0.9,
+        "max_tokens": 4096,
+        "max_completion_tokens": 8192,
+        "thinking_type": "enabled",
+        "reasoning_effort": "medium",
+        "response_format": "text",
+        "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
+        "model": "doubao-seed-2-0-pro-260215"
+    },
+    "tools": [],
+    "sp": "你是一位专业的OCR文字识别专家,擅长从图片中识别手写和印刷的文字内容。",
+    "up": "请识别答案区域中的文字内容,返回准确的识别结果。"
+}
\ No newline at end of file
diff --git a/config/comprehensive_correction_cfg.json b/config/comprehensive_correction_cfg.json
new file mode 100644
index 0000000..d07d04c
--- /dev/null
+++ b/config/comprehensive_correction_cfg.json
@@ -0,0 +1,12 @@
+{
+    "config": {
+        "model": "doubao-seed-1-6-vision-250815",
+        "temperature": 0.1,
+        "top_p": 0.95,
+        "max_completion_tokens": 16384,
+        "thinking": "disabled"
+    },
+    "tools": [],
+    "sp": 
"你是一位专业的初中物理教师,负责批改学生的物理作业。", + "up": "请按照要求完成作业批改任务。" +} \ No newline at end of file diff --git a/config/correction_judge_llm_cfg.json b/config/correction_judge_llm_cfg.json new file mode 100644 index 0000000..41fa597 --- /dev/null +++ b/config/correction_judge_llm_cfg.json @@ -0,0 +1,17 @@ +{ + "config": { + "temperature": 0, + "frequency_penalty": 0, + "top_p": 0.95, + "max_tokens": 8192, + "max_completion_tokens": 16384, + "thinking_type": "enabled", + "reasoning_effort": "high", + "response_format": "text", + "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}", + "model": "doubao-seed-2-0-pro-260215" + }, + "tools": [], + "sp": "你是一位资深的初中物理特级教师,拥有20年以上教学经验,擅长精准批改学生的物理作业。\n\n【核心能力】\n1. **精确判断能力**:对选择题、填空题、解答题都能做出准确的正误判断\n2. **严谨推理能力**:能够逐步验证学生的计算过程和结论\n3. **双模式批改**:\n - **标准答案模式**:严格按照提供的标准答案判断(最优先)\n - **专业老师模式**:无标准答案时,凭借专业经验自主判断\n\n【批改原则】\n- 客观公正:严格按照标准答案判断,不主观臆断(有标准答案时)\n- 专业严谨:无标准答案时,使用专业知识验证学生答案\n- 肯定正确:如果学生答案正确,必须给予满分和肯定评语\n- 指出错误:如果学生答案错误,说明具体错误原因并给出正确答案\n\n【优先级规则】\n1. 最优先:使用提供的标准答案批改\n2. 降级:标准答案中未找到对应题目时,使用专业老师批改", + "up": "请批改以下学生的物理作业,判断每道题答案的正误并给出详细评语。" +} diff --git a/config/doc_extract_llm_cfg.json b/config/doc_extract_llm_cfg.json new file mode 100644 index 0000000..3df9e51 --- /dev/null +++ b/config/doc_extract_llm_cfg.json @@ -0,0 +1,17 @@ +{ + "config": { + "temperature": 0.1, + "frequency_penalty": 0, + "top_p": 0.95, + "max_tokens": 8192, + "max_completion_tokens": 16384, + "thinking_type": "enabled", + "reasoning_effort": "high", + "response_format": "text", + "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}", + "model": "doubao-seed-2-0-pro-260215" + }, + "tools": [], + "sp": "你是一位资深的初中物理教师,擅长从试卷中提取题目和标准答案。你的核心能力:\n\n1. **题目识别能力**:能够准确识别试卷中的所有题目,包括大题和小题\n2. **答案提取能力**:能够准确提取每道题的标准答案\n3. **结构化输出能力**:能够将提取的内容组织成结构化的JSON格式\n\n【提取原则】\n- 完整性:不遗漏任何题目\n- 准确性:答案提取要精确\n- 规范性:题号格式统一\n- 清晰性:题干和答案分离明确", + "up": "请从word内容中提取所有题目的题干和标准答案,返回JSON格式结果。" +} \ No newline at end of file diff --git a/config/homework_correction_cfg.json b/config/homework_correction_cfg.json new file mode 100644 index 0000000..df2b72a --- /dev/null +++ b/config/homework_correction_cfg.json @@ -0,0 +1,12 @@ +{ + "config": { + "model": "doubao-seed-1-6-vision-250815", + "temperature": 0.1, + "top_p": 0.95, + "max_completion_tokens": 8192, + "thinking": "disabled" + }, + "tools": [], + "sp": "# 角色定义\n你是一位专业的初中物理作业批改助手,具有丰富的物理教学经验和精准的视觉识别能力。你能够准确识别作业图片中的题目内容、学生答案,并判断答案的正确性。\n\n# 任务目标\n分析上传的初中物理作业图片,识别每道题目及其学生答案,判断答案是否正确,并输出结构化的批改结果JSON。\n\n# 工作流上下文\n- **Input**:作业图片(图片URL)\n- **Process**:\n 1. 仔细识别图片中的所有题目,包括题号、题目内容\n 2. 识别每道题的学生答案,注意区分小题(如(1)(2)(3))\n 3. 判断每个答案的正确性,对于解答题需要检查计算过程和结果\n 4. 为每个批改标记确定在原图上的相对坐标位置(批改标记应放置在答案末尾右侧)\n 5. 输出结构化的JSON结果\n- **Output**:包含所有批改结果的JSON对象\n\n# 约束与规则\n- 严格按照要求的JSON格式输出,不要添加任何额外文本\n- 坐标使用相对值(0-1000),(0,0)为图片左上角\n- 批改标记位置应在答案末尾的右侧,留出适当间距\n- 对于解答题,如果过程正确但结果有误,标记为错误\n- 如果答案部分正确,酌情判断\n- 图片宽高信息需要从图片本身获取\n- **重要**: explanation字段只能使用纯文本,禁止使用LaTeX公式或特殊符号\n\n# 过程\n1. 识别题目结构:扫描图片,定位所有题目,记录题号和小题号\n2. 答案识别:逐题识别学生的作答内容\n3. 正确性判断:\n - 对于计算题:检查计算过程和结果\n - 对于证明题:检查证明逻辑是否完整\n - 对于作图题:检查图形是否正确\n4. 坐标定位:确定每道题答案末尾的坐标位置\n5. 
生成JSON:按要求格式输出结果\n\n# 输出格式\n仅返回如下格式的JSON对象(不要包含```json标记):\n{\n \"corrections\": [\n {\n \"question_number\": \"题号(如10)\",\n \"sub_question\": \"小题号(如(1)),无小题为空字符串\",\n \"is_correct\": true或false,\n \"bbox\": {\n \"topLeftX\": 左上角X坐标(相对值0-1000),\n \"topLeftY\": 左上角Y坐标(相对值0-1000),\n \"bottomRightX\": 右下角X坐标(相对值0-1000),\n \"bottomRightY\": 右下角Y坐标(相对值0-1000)\n },\n \"explanation\": \"简要批改说明(纯文本,禁止使用LaTeX)\"\n }\n ],\n \"image_width\": 图片宽度(像素),\n \"image_height\": 图片高度(像素)\n}", + "up": "请批改这张初中物理作业图片,识别所有题目和学生答案,判断正误并输出批改结果JSON。注意:explanation字段只能使用纯文本,禁止使用LaTeX公式。图片URL:{{image_url}}" +} \ No newline at end of file diff --git a/config/homework_recognize_llm_cfg.json b/config/homework_recognize_llm_cfg.json new file mode 100644 index 0000000..c1eb6a6 --- /dev/null +++ b/config/homework_recognize_llm_cfg.json @@ -0,0 +1,12 @@ +{ + "config": { + "model": "doubao-seed-2-0-pro-260215", + "temperature": 0.0, + "top_p": 0.9, + "max_completion_tokens": 8192, + "thinking": "disabled" + }, + "tools": [], + "sp": "# 角色\n你是物理作业批改助手。\n\n# 禁止标注\n- 印刷体文字、实验装置图、图中字母\n\n# 需要标注\n- 学生手写答案\n\n# 坐标\n- 相对坐标(0-1000),answer_bbox: [x1, y1, x2, y2]\n\n# ⚠️ 重要:拆分填空题\n- 一道题有多个填空时,**每个空单独识别为一个题目**\n- 题号格式:\"3(1)第一空\"、\"3(1)第二空\"、\"4(2)第一空\"、\"4(2)第二空\"\n- 每个空单独批改,单独打分\n- 示例:\n - 题目(1)有两个空 → 识别为\"3(1)第一空\"和\"3(1)第二空\"两个题目\n - 题目(2)有一个空 → 识别为\"3(2)\"一个题目\n\n# 输出格式\n{\"results\": [{\"question_id\": \"题号\", \"student_answer\": \"答案\", \"answer_bbox\": [x1, y1, x2, y2], \"status\": \"correct或incorrect\", \"score\": 分数, \"full_score\": 满分, \"comment\": \"结论\"}]}\n\ncomment格式:\"正确\"或\"错误,应为X\"(不超过50字)", + "up": "批改物理作业。**每个填空单独识别**。输出JSON,comment不超过{{comment_max_length}}字。图片:{{image_url}}" +} diff --git a/config/question_locate_llm_cfg.json b/config/question_locate_llm_cfg.json new file mode 100644 index 0000000..7c330c2 --- /dev/null +++ b/config/question_locate_llm_cfg.json @@ -0,0 +1,17 @@ +{ + "config": { + "temperature": 0.1, + "frequency_penalty": 0, + "top_p": 0.9, + "max_tokens": 4096, + "max_completion_tokens": 8192, + "thinking_type": "enabled", + "reasoning_effort": "medium", + "response_format": "text", + "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}", + "model": "doubao-seed-2-0-pro-260215" + }, + "tools": [], + "sp": "你是一位专业的初中物理作业识别专家,擅长从作业图片中定位题目位置和提取答案区域。", + "up": "请识别这张作业图片中的所有题目位置,返回准确的边界框坐标。" +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3f94b79 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,156 @@ +alembic==1.16.5 +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.12.1 +APScheduler==3.11.2 +astroid==3.1.0 +Authlib==1.6.9 +beautifulsoup4==4.14.3 +boto3==1.40.61 +botocore==1.40.61 +cachetools==6.2.6 +certifi==2026.2.25 +cffi==2.0.0 +chardet==5.2.0 +charset-normalizer==3.4.4 +click==8.3.1 +coverage==7.13.4 +coze-coding-dev-sdk==0.5.11 +coze-coding-utils==0.2.4 +coze-workload-identity==0.1.4 +cozeloop==0.1.25 +cryptography==46.0.5 +cssselect==1.4.0 +cssutils==2.11.1 +dbus-python==1.3.2 +deprecation==2.1.0 +dill==0.4.1 +distro==1.9.0 +docx2python==3.5.0 +et_xmlfile==2.0.0 +fastapi==0.121.2 +fsspec==2026.2.0 +gitdb==4.0.12 +gitignore_parser==0.1.13 +GitPython==3.1.45 +greenlet==3.3.2 +h11==0.16.0 +h2==4.3.0 +hpack==4.1.0 +httpcore==1.0.9 +httpx==0.28.1 +httpx-ws==0.8.2 +hyperframe==6.1.0 +idna==3.11 +inflect==7.5.0 +iniconfig==2.3.0 +isort==5.13.2 +Jinja2==3.1.6 +jiter==0.13.0 +jmespath==1.1.0 +jsonpatch==1.33 +jsonpointer==3.0.0 +langchain==1.0.3 +langchain-core==1.0.2 +langchain-openai==1.0.1 
+langgraph==1.0.2 +langgraph-checkpoint==3.0.0 +langgraph-checkpoint-postgres==3.0.1 +langgraph-prebuilt==1.0.2 +langgraph-sdk==0.2.9 +langsmith==0.4.39 +lxml==6.0.2 +Mako==1.3.10 +Markdown==3.10.2 +markdown-it-py==4.0.0 +MarkupSafe==3.0.3 +mccabe==0.7.0 +mdurl==0.1.2 +mmh3==5.2.1 +more-itertools==10.8.0 +multidict==6.7.1 +numpy==2.2.6 +openai==2.24.0 +opencv-python==4.12.0.88 +openpyxl==3.1.5 +orjson==3.11.7 +ormsgpack==1.12.2 +packaging==25.0 +pandas==2.2.2 +paragraphs==1.0.1 +pillow==10.3.0 +platformdirs==4.9.2 +pluggy==1.6.0 +postgrest==2.27.3 +propcache==0.4.1 +psutil==7.1.3 +psycopg==3.3.0 +psycopg-binary==3.3.0 +psycopg-pool==3.3.0 +psycopg2-binary==2.9.9 +pycparser==3.0 +pydantic==2.12.3 +pydantic_core==2.41.4 +Pygments==2.19.2 +PyGObject==3.48.2 +pyiceberg==0.11.1 +PyJWT==2.10.1 +pylint==3.1.0 +pyparsing==3.3.2 +pypdf==6.4.1 +pyroaring==1.0.3 +pytest==9.0.1 +pytest-asyncio==1.3.0 +pytest-cov==7.0.0 +pytest-mock==3.15.1 +python-dateutil==2.9.0.post0 +python-docx==1.2.0 +python-dotenv==1.2.1 +python-pptx==1.0.2 +pytz==2026.1.post1 +PyYAML==6.0.3 +realtime==2.27.3 +regex==2026.2.28 +reportlab==4.4.10 +requests==2.32.5 +requests-toolbelt==1.0.0 +rich==14.2.0 +s3transfer==0.14.0 +setuptools==68.1.2 +six==1.17.0 +smmap==5.0.2 +sniffio==1.3.1 +soupsieve==2.8.3 +sqlacodegen==4.0.2 +SQLAlchemy==2.0.44 +starlette==0.49.3 +storage3==2.27.3 +StrEnum==0.4.15 +strictyaml==1.7.3 +supabase==2.27.3 +supabase-auth==2.27.3 +supabase-functions==2.27.3 +tenacity==9.1.4 +tiktoken==0.12.0 +tomlkit==0.14.0 +tqdm==4.67.3 +typeguard==4.5.1 +types-html5lib==1.1.11.20251117 +types-lxml==2026.2.16 +types-webencodings==0.5.0.20251108 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2025.3 +tzlocal==5.3.1 +Unidecode==1.4.0 +urllib3==2.6.3 +uvicorn==0.38.0 +watchdog==6.0.0 +websockets==15.0.1 +wheel==0.42.0 +wsproto==1.3.2 +xlrd==2.0.2 +xlsxwriter==3.2.9 +xxhash==3.6.0 +yarl==1.23.0 +zstandard==0.25.0 diff --git a/scripts/http_run.sh b/scripts/http_run.sh new file mode 100644 index 0000000..61070ff --- /dev/null +++ b/scripts/http_run.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +set -e +# 导出环境变量 + +WORK_DIR="${COZE_WORKSPACE_PATH:-.}" +PORT=8000 + +usage() { + echo "用法: $0 -p <端口>" +} + +while getopts "p:h" opt; do + case "$opt" in + p) + PORT="$OPTARG" + ;; + h) + usage + exit 0 + ;; + \?) 
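+        # getopts 捕获到未声明的选项(如 -m)时进入此分支:打印提示与用法后以非零状态退出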
+ echo "无效选项: -$OPTARG" + usage + exit 1 + ;; + esac +done + + +python ${WORK_DIR}/src/main.py -m http -p $PORT diff --git a/scripts/load_env.py b/scripts/load_env.py new file mode 100644 index 0000000..edc62de --- /dev/null +++ b/scripts/load_env.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +""" +加载项目环境变量脚本 +通过 coze_workload_identity.Client 获取项目环境变量并输出 export 语句 +使用方式: eval $(python load_env.py) +""" + +import os +import sys + +# 添加 app 目录到 Python 路径 +workspace_path = os.getenv("COZE_WORKSPACE_PATH", "/workspace/projects") +app_dir = os.path.join(workspace_path, 'src') +if app_dir not in sys.path: + sys.path.insert(0, app_dir) + +try: + from coze_workload_identity import Client + + client = Client() + env_vars = client.get_project_env_vars() + client.close() + + # 输出 export 语句格式的环境变量 + for env_var in env_vars: + # 转义特殊字符 + value = env_var.value.replace("'", "'\\''") + print(f"export {env_var.key}='{value}'") + + # 输出成功消息到 stderr,不影响 eval + print(f"# Successfully loaded {len(env_vars)} environment variables", file=sys.stderr) + +except Exception as e: + print(f"# Error loading environment variables: {e}", file=sys.stderr) + sys.exit(1) diff --git a/scripts/load_env.sh b/scripts/load_env.sh new file mode 100644 index 0000000..2c835b1 --- /dev/null +++ b/scripts/load_env.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +# 加载项目环境变量脚本 +# 使用方式: source ./load_env.sh 或 . ./load_env.sh + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +eval $(python3 "$SCRIPT_DIR/load_env.py") diff --git a/scripts/local_run.sh b/scripts/local_run.sh new file mode 100644 index 0000000..ba28345 --- /dev/null +++ b/scripts/local_run.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +set -e + +mode="" +node="" +input="" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +WORK_DIR="${COZE_WORKSPACE_PATH:-$(dirname "$SCRIPT_DIR")}" + +usage() { + echo "用法: $0 -m <模式> [-n <节点ID>] [-i <输入JSON>]" + echo "" + echo "参数说明:" + echo " -m <模式> 运行模式: http, flow, node, agent" + echo " -n <节点ID> 节点ID (仅在 node 模式下需要)" + echo " -i <输入JSON> 输入数据,支持 JSON 字符串或纯文本" + echo " -h 显示帮助信息" + echo "" + echo "示例:" + echo " $0 -m flow" + echo " $0 -m flow -i '{\"text\": \"你好\"}'" + echo " $0 -m flow -i '你好'" + echo " $0 -m node -n node_1 -i '{\"text\": \"测试\"}'" +} + +while getopts "m:n:i:h" opt; do + case "$opt" in + m) + mode="$OPTARG" + ;; + n) + node="$OPTARG" + ;; + i) + input="$OPTARG" + ;; + h) + usage + exit 0 + ;; + \?) + echo "无效选项: -$OPTARG" + usage + exit -1 + ;; + esac +done + +if [ -z "$mode" ]; then + echo "错误: 必须指定 -m 参数" + usage + exit -1 +fi + +# Load environment variables +if [ -f "${SCRIPT_DIR}/load_env.sh" ]; then + echo "Loading environment variables..." + source "${SCRIPT_DIR}/load_env.sh" +fi + +# Build python command +cmd="python ${WORK_DIR}/src/main.py -m \"$mode\"" + +if [ -n "$node" ]; then + cmd="$cmd -n \"$node\"" +fi + +if [ -n "$input" ]; then + cmd="$cmd -i '$input'" +fi + +# Execute command +echo "Executing: $cmd" +eval $cmd diff --git a/scripts/pack.sh b/scripts/pack.sh new file mode 100644 index 0000000..d189829 --- /dev/null +++ b/scripts/pack.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +pip freeze --exclude watchdog > requirements.txt \ No newline at end of file diff --git a/scripts/setup.sh b/scripts/setup.sh new file mode 100644 index 0000000..fc6c6a8 --- /dev/null +++ b/scripts/setup.sh @@ -0,0 +1,9 @@ +# 初始化目录 +if [ "$COZE_PROJECT_ENV" = "DEV" ]; then + if [ ! 
-d "${COZE_WORKSPACE_PATH}/assets" ]; then + mkdir -p "${COZE_WORKSPACE_PATH}/assets" + fi +fi + +# 安装Python三方包依赖 +pip install -r requirements.txt diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/agents/__init__.py b/src/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/graphs/__init__.py b/src/graphs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/graphs/graph.py b/src/graphs/graph.py new file mode 100644 index 0000000..b1f8397 --- /dev/null +++ b/src/graphs/graph.py @@ -0,0 +1,39 @@ +"""初中物理作业批改工作流主图编排 - 支持多图片批改""" +from langgraph.graph import StateGraph, END +from langchain_core.runnables import RunnableConfig +from langgraph.runtime import Runtime + +from graphs.state import ( + GlobalState, + GraphInput, + GraphOutput +) + +# 导入节点 +from graphs.nodes.doc_extract_node import doc_extract_node +from graphs.nodes.process_images_node import process_images_node + + +# 创建状态图 +builder = StateGraph( + GlobalState, + input_schema=GraphInput, + output_schema=GraphOutput +) + +# 添加节点 +builder.add_node("doc_extract", doc_extract_node, metadata={"type": "agent", "llm_cfg": "config/doc_extract_llm_cfg.json"}) +builder.add_node("process_images", process_images_node, metadata={"type": "looparray"}) + +# 设置入口点 +builder.set_entry_point("doc_extract") + +# 添加边 - 线性流程 +# 1. 先解析Word答案(如果提供了URL) +builder.add_edge("doc_extract", "process_images") + +# 2. 循环处理每张图片并生成最终结果 +builder.add_edge("process_images", END) + +# 编译图 +main_graph = builder.compile() diff --git a/src/graphs/loop_graph.py b/src/graphs/loop_graph.py new file mode 100644 index 0000000..9062f8c --- /dev/null +++ b/src/graphs/loop_graph.py @@ -0,0 +1,66 @@ +"""单图片处理子图:封装完整的单张图片批改流程(优化版:合并识别和批改)""" +from langgraph.graph import StateGraph, END +from langchain_core.runnables import RunnableConfig +from langgraph.runtime import Runtime +from coze_coding_utils.runtime_ctx.context import Context + +from graphs.state import ( + SubgraphState, + SubgraphInput, + SubgraphOutput, + SingleImageResult, + ImageInfo +) + +# 导入子图需要的节点 +from graphs.nodes.image_preprocess_node import image_preprocess_node +from graphs.nodes.recognize_and_correct_node import recognize_and_correct_node +from graphs.nodes.result_merge_node import result_merge_node + + +def wrap_result_node( + state: SubgraphState, + config: RunnableConfig, + runtime: Runtime[Context] +) -> SubgraphOutput: + """ + title: 包装子图结果 + desc: 将子图的中间状态包装为SingleImageResult输出 + integrations: + """ + from graphs.state import SingleImageResult + + image_result = SingleImageResult( + image_index=state.image_index, + image_info=state.image_info, + image_url=state.image_url, + annotations=state.annotations + ) + + return SubgraphOutput(image_result=image_result) + + +# 创建单图片处理子图 +subgraph_builder = StateGraph( + SubgraphState, + input_schema=SubgraphInput, + output_schema=SubgraphOutput +) + +# 添加节点(优化后只需4个节点) +subgraph_builder.add_node("image_preprocess", image_preprocess_node) +subgraph_builder.add_node("recognize_and_correct", recognize_and_correct_node, metadata={"type": "agent", "llm_cfg": "config/homework_recognize_llm_cfg.json"}) +subgraph_builder.add_node("result_merge", result_merge_node) +subgraph_builder.add_node("wrap_result", wrap_result_node) + +# 设置入口点 +subgraph_builder.set_entry_point("image_preprocess") + +# 添加边(优化后流程:预处理 -> 识别+批改 -> 整合 -> 包装) +subgraph_builder.add_edge("image_preprocess", "recognize_and_correct") +subgraph_builder.add_edge("recognize_and_correct", "result_merge") 
+subgraph_builder.add_edge("result_merge", "wrap_result") +subgraph_builder.add_edge("wrap_result", END) + +# 编译子图 +single_image_subgraph = subgraph_builder.compile() diff --git a/src/graphs/nodes/__init__.py b/src/graphs/nodes/__init__.py new file mode 100644 index 0000000..cbf3221 --- /dev/null +++ b/src/graphs/nodes/__init__.py @@ -0,0 +1,10 @@ +"""工作流节点模块 - 学习豆包APP的一体化识别方式""" +from graphs.nodes.image_preprocess_node import image_preprocess_node +from graphs.nodes.result_merge_node import result_merge_node +from graphs.nodes.recognize_and_correct_node import recognize_and_correct_node + +__all__ = [ + "image_preprocess_node", + "result_merge_node", + "recognize_and_correct_node", +] diff --git a/src/graphs/nodes/doc_extract_node.py b/src/graphs/nodes/doc_extract_node.py new file mode 100644 index 0000000..dedf6c7 --- /dev/null +++ b/src/graphs/nodes/doc_extract_node.py @@ -0,0 +1,280 @@ +"""Word答案解析节点:从.docx文件中提取题干和标准答案""" +import os +import json +import re +import logging +import tempfile +import requests +from typing import List +from langchain_core.runnables import RunnableConfig +from langgraph.runtime import Runtime +from coze_coding_utils.runtime_ctx.context import Context +from coze_coding_dev_sdk import LLMClient +from langchain_core.messages import HumanMessage +from docx import Document + +from graphs.state import ( + DocExtractInput, + DocExtractOutput, + CorrectAnswer +) + +logger = logging.getLogger(__name__) + + +def sanitize_json_string(text: str) -> str: + """清理JSON字符串中的无效转义和控制字符""" + text = ''.join(char if ord(char) >= 32 or char in '\n\r\t' else ' ' for char in text) + result = [] + i = 0 + in_string = False + while i < len(text): + char = text[i] + if char == '"' and (i == 0 or (i > 0 and result and result[-1] != '\\')): + in_string = not in_string + result.append(char) + elif in_string and char == '\\': + if i + 1 < len(text): + next_char = text[i + 1] + if next_char in ['"', '\\', '/', 'b', 'f', 'n', 'r', 't']: + result.append(char) + result.append(next_char) + i += 1 + elif next_char == 'u': + result.append(char) + result.append(next_char) + i += 1 + else: + result.append('\\\\') + else: + result.append('\\\\') + elif in_string and char in '\n\r': + result.append('\\n' if char == '\n' else '\\r') + else: + result.append(char) + i += 1 + return ''.join(result) + + +def extract_json_from_text(text: str, key: str = "answers") -> dict: + """从文本中提取JSON对象,多层回退策略""" + import orjson + + # 先尝试直接解析完整JSON + try: + result = json.loads(text) + if key in result: + return result + except (json.JSONDecodeError, ValueError) as e: + logger.debug(f"Direct JSON parse failed: {e}") + + # 尝试找到JSON对象的边界 + try: + # 找到第一个 { 和最后一个 } + start = text.find('{') + end = text.rfind('}') + if start != -1 and end != -1 and end > start: + json_str = text[start:end+1] + result = json.loads(json_str) + if key in result: + return result + except (json.JSONDecodeError, ValueError) as e: + logger.debug(f"Boundary JSON parse failed: {e}") + + # 使用orjson尝试 + try: + start = text.find('{') + end = text.rfind('}') + if start != -1 and end != -1 and end > start: + json_str = text[start:end+1] + result = orjson.loads(json_str) + if key in result: + return result + except Exception as e: + logger.debug(f"Orjson parse failed: {e}") + + # 尝试清理后解析 + try: + cleaned = sanitize_json_string(text) + start = cleaned.find('{') + end = cleaned.rfind('}') + if start != -1 and end != -1 and end > start: + json_str = cleaned[start:end+1] + result = json.loads(json_str) + if key in result: + return result + except Exception as 
e: + logger.debug(f"Cleaned JSON parse failed: {e}") + + logger.warning(f"Failed to extract JSON with key '{key}' from text (length: {len(text)})") + return {key: []} + + +def download_and_extract_docx(url: str) -> str: + """下载Word文件并提取文本内容""" + logger.info(f"Downloading Word document from: {url}") + + # 下载文件 + response = requests.get(url, timeout=60) + response.raise_for_status() + + # 保存到临时文件 + with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as tmp_file: + tmp_file.write(response.content) + tmp_path = tmp_file.name + + try: + # 使用python-docx解析 + doc = Document(tmp_path) + + # 提取所有段落文本 + paragraphs = [] + for para in doc.paragraphs: + text = para.text.strip() + if text: + paragraphs.append(text) + + # 提取表格中的文本 + for table in doc.tables: + for row in table.rows: + row_text = [] + for cell in row.cells: + cell_text = cell.text.strip() + if cell_text: + row_text.append(cell_text) + if row_text: + paragraphs.append(" | ".join(row_text)) + + doc_text = "\n".join(paragraphs) + logger.info(f"Extracted Word document text length: {len(doc_text)}") + + return doc_text + finally: + # 清理临时文件 + os.unlink(tmp_path) + + +def doc_extract_node( + state: DocExtractInput, + config: RunnableConfig, + runtime: Runtime[Context] +) -> DocExtractOutput: + """ + title: Word答案解析 + desc: 从正确答案Word文件(.docx)中提取题干和标准答案,用于后续批改;如果未提供URL则返回空列表 + integrations: 大语言模型 + """ + ctx = runtime.context + + # 检查是否提供了答案文档URL + if not state.answer_doc_url or not state.answer_doc_url.strip(): + logger.info("No answer document URL provided, will use teacher mode for all questions") + return DocExtractOutput(correct_answers=[]) + + # 1. 下载并提取Word文档内容 + try: + doc_text = download_and_extract_docx(state.answer_doc_url) + except Exception as e: + logger.error(f"Failed to download/extract Word document: {e}") + return DocExtractOutput(correct_answers=[]) + + if not doc_text.strip(): + logger.error("No text content extracted from Word document") + return DocExtractOutput(correct_answers=[]) + + logger.info(f"Word document content preview: {doc_text[:500]}") + + # 2. 使用LLM解析题目和答案 + cfg_file = os.path.join(os.getenv("COZE_WORKSPACE_PATH", ""), config["metadata"]["llm_cfg"]) + with open(cfg_file, "r", encoding="utf-8") as fd: + _cfg = json.load(fd) + + llm_config = _cfg.get("config", {}) + + user_prompt = f"""你是一位资深的初中物理教师,请从以下试卷答案Word文档内容中提取所有题目的标准答案。 + +【Word文档内容】 +{doc_text[:20000]} + +【提取要求】 +1. 识别所有题号,包括大题和小题 +2. 只提取每道题的标准答案,不需要提取题干 +3. 答案格式: + - 选择题:单个字母(A/B/C/D) + - 填空题:数值或表达式 + - 解答题:关键结果 +4. 
如果有多个空,用逗号分隔 + +【题号格式规范】 +- 大题题号:直接使用数字,如"1"、"2"、"10" +- 小题题号:使用"大题.小题"格式,如"10.1"、"10.2" + +【重要】 +- 保持答案简洁,不要包含题干 +- 如果文档中有分值,请提取;否则默认3分 + +请返回简洁的JSON格式: +{{ + "answers": [ + {{ + "question_id": "1", + "correct_answer": "B", + "full_score": 3 + }}, + {{ + "question_id": "2", + "correct_answer": "2a-b+c", + "full_score": 3 + }} + ] +}}""" + + client = LLMClient(ctx=ctx) + response = client.invoke( + messages=[HumanMessage(content=user_prompt)], + model=llm_config.get("model", "doubao-seed-2-0-pro-260215"), + temperature=llm_config.get("temperature", 0.1), + max_completion_tokens=llm_config.get("max_completion_tokens", 8192) + ) + + response_text = response.content if isinstance(response.content, str) else " ".join( + item.get("text", "") if isinstance(item, dict) else str(item) + for item in response.content + ).strip() + + logger.info(f"Word extract LLM response: {response_text[:1500]}") + + # 清理markdown标记 + for prefix in ["```json", "```JSON", "```"]: + if response_text.startswith(prefix): + response_text = response_text[len(prefix):] + for suffix in ["```"]: + if response_text.endswith(suffix): + response_text = response_text[:-3] + + # 解析JSON + result_dict = extract_json_from_text(response_text.strip(), "answers") + + correct_answers: List[CorrectAnswer] = [] + for ans in result_dict.get("answers", []): + try: + correct_answers.append(CorrectAnswer( + question_id=str(ans.get("question_id", "")), + parent_id=str(ans.get("parent_id", "")), + is_sub_question=bool(ans.get("is_sub_question", False)), + question_text=str(ans.get("question_text", "")), + correct_answer=str(ans.get("correct_answer", "")), + full_score=int(ans.get("full_score", 3) if ans.get("full_score") is not None else 3), + answer_analysis=str(ans.get("answer_analysis", "")) + )) + except Exception as e: + logger.warning(f"Failed to parse correct answer: {ans}, error: {e}") + continue + + logger.info(f"Parsed {len(correct_answers)} correct answers from Word document") + + # 打印解析结果供调试 + for ans in correct_answers: + logger.info(f" Question {ans.question_id}: {ans.correct_answer} ({ans.full_score}分)") + + return DocExtractOutput(correct_answers=correct_answers) diff --git a/src/graphs/nodes/image_preprocess_node.py b/src/graphs/nodes/image_preprocess_node.py new file mode 100644 index 0000000..683b07b --- /dev/null +++ b/src/graphs/nodes/image_preprocess_node.py @@ -0,0 +1,159 @@ +"""1. 
图像预处理节点:下载图片、自动旋转、缩放到固定宽度1000、上传对象存储"""
+import os
+import logging
+import urllib.request
+from io import BytesIO
+from typing import Tuple
+from langchain_core.runnables import RunnableConfig
+from langgraph.runtime import Runtime
+from coze_coding_utils.runtime_ctx.context import Context
+from PIL import Image
+from datetime import datetime
+
+from graphs.state import (
+    ImagePreprocessInput,
+    ImagePreprocessOutput,
+    ImageInfo
+)
+
+logger = logging.getLogger(__name__)
+
+# 固定宽度常量
+FIXED_WIDTH = 1000
+
+
+def download_and_process_image(image_url: str) -> Tuple[Image.Image, int, int, int]:
+    """下载图片并返回PIL Image对象和原始尺寸信息"""
+    try:
+        with urllib.request.urlopen(image_url, timeout=30) as response:
+            img_data = response.read()
+        img = Image.open(BytesIO(img_data))
+
+        # 转换为RGB模式(处理PNG透明通道)
+        if img.mode in ('RGBA', 'P'):
+            img = img.convert('RGB')
+
+        original_width, original_height = img.size
+        dpi = img.info.get('dpi', (72, 72))
+        if isinstance(dpi, tuple):
+            # 部分图片的DPI为浮点数,这里取整以匹配ImageInfo.dpi的int类型
+            dpi = int(dpi[0]) if dpi[0] > 0 else 72
+        else:
+            dpi = 72
+
+        return img, original_width, original_height, dpi
+    except Exception as e:
+        raise RuntimeError(f"下载图片失败: {e}")
+
+
+def auto_rotate_image(img: Image.Image) -> Tuple[Image.Image, bool]:
+    """
+    自动旋转图片:如果宽度大于高度(横向图片),则旋转-90度使其变为纵向
+
+    Args:
+        img: PIL Image对象
+
+    Returns:
+        (旋转后的图片, 是否进行了旋转)
+    """
+    width, height = img.size
+
+    # 如果宽度大于高度,需要旋转
+    if width > height:
+        logger.info(f"检测到横向图片(宽{width} > 高{height}),正在旋转-90度...")
+        # rotate(-90) 为顺时针旋转90度(PIL的rotate正角度表示逆时针),使横向变为纵向
+        rotated_img = img.rotate(-90, expand=True)
+        new_width, new_height = rotated_img.size
+        logger.info(f"旋转完成,新尺寸:宽{new_width} x 高{new_height}")
+        return rotated_img, True
+
+    logger.info(f"图片为纵向(宽{width} <= 高{height}),无需旋转")
+    return img, False
+
+
+def resize_image_to_fixed_width(img: Image.Image, target_width: int = FIXED_WIDTH) -> Tuple[Image.Image, float]:
+    """将图片缩放到固定宽度,高度等比例缩放"""
+    original_width, original_height = img.size
+
+    if original_width == target_width:
+        return img, 1.0
+
+    # 计算缩放比例
+    scale_ratio = target_width / original_width
+    new_height = int(original_height * scale_ratio)
+
+    # 使用高质量重采样(BICUBIC,对应PIL内部常量值3)
+    resized_img = img.resize((target_width, new_height), Image.Resampling.BICUBIC)
+
+    return resized_img, scale_ratio
+
+
+def upload_image_to_storage(img: Image.Image, ctx) -> str:
+    """将图片上传到对象存储并返回URL"""
+    from coze_coding_dev_sdk.s3 import S3SyncStorage
+
+    # 转换为字节流
+    img_buffer = BytesIO()
+    img.save(img_buffer, format='JPEG', quality=95)
+    img_bytes = img_buffer.getvalue()
+
+    # 上传到对象存储
+    storage = S3SyncStorage(
+        endpoint_url=os.getenv("COZE_BUCKET_ENDPOINT_URL"),
+        access_key="",
+        secret_key="",
+        bucket_name=os.getenv("COZE_BUCKET_NAME"),
+        region="cn-beijing",
+    )
+
+    file_key = storage.upload_file(
+        file_content=img_bytes,
+        file_name=f"homework_resized_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jpg",
+        content_type="image/jpeg",
+    )
+
+    result_url = storage.generate_presigned_url(key=file_key, expire_time=86400)
+    return result_url
+
+
+def image_preprocess_node(
+    state: ImagePreprocessInput,
+    config: RunnableConfig,
+    runtime: Runtime[Context]
+) -> ImagePreprocessOutput:
+    """
+    title: 图像预处理
+    desc: 下载图片、自动旋转(横向→纵向)、缩放到固定宽度1000px、上传对象存储
+    integrations: 对象存储
+    """
+    ctx = runtime.context
+
+    # 1. 下载原始图片
+    img, original_width, original_height, dpi = download_and_process_image(
+        state.homework_image.url
+    )
+    logger.info(f"原始图片尺寸:宽{original_width} x 高{original_height}")
+
+    # 2. 
自动旋转:如果宽度大于高度,旋转-90度使其变为纵向 + img, was_rotated = auto_rotate_image(img) + after_rotate_width, after_rotate_height = img.size + if was_rotated: + logger.info(f"旋转后尺寸:宽{after_rotate_width} x 高{after_rotate_height}") + + # 3. 缩放到固定宽度1000px,高度等比例缩放 + resized_img, scale_ratio = resize_image_to_fixed_width(img, FIXED_WIDTH) + new_width, new_height = resized_img.size + logger.info(f"缩放后尺寸:宽{new_width} x 高{new_height},缩放比例:{scale_ratio:.3f}") + + # 4. 上传处理后的图片到对象存储 + processed_image_url = upload_image_to_storage(resized_img, ctx) + + # 5. 返回处理后的图片信息(AI基于这个尺寸计算坐标) + return ImagePreprocessOutput( + image_info=ImageInfo( + width=new_width, # 缩放后的宽度:1000 + height=new_height, # 缩放后的高度 + dpi=dpi + ), + image_url=processed_image_url + ) diff --git a/src/graphs/nodes/process_images_node.py b/src/graphs/nodes/process_images_node.py new file mode 100644 index 0000000..a89a6eb --- /dev/null +++ b/src/graphs/nodes/process_images_node.py @@ -0,0 +1,167 @@ +"""多图片处理循环节点:并行调用子图处理每张作业图片""" +import logging +from typing import List +from concurrent.futures import ThreadPoolExecutor, as_completed +from langchain_core.runnables import RunnableConfig +from langgraph.runtime import Runtime +from coze_coding_utils.runtime_ctx.context import Context + +from graphs.state import ( + ProcessImagesInput, + ProcessImagesOutput, + SingleImageResult, + SubgraphInput, + FinalResult, + ImageInfo +) +from graphs.loop_graph import single_image_subgraph + +logger = logging.getLogger(__name__) + + +def process_images_node( + state: ProcessImagesInput, + config: RunnableConfig, + runtime: Runtime[Context] +) -> ProcessImagesOutput: + """ + title: 多图片批改处理 + desc: 并行调用子图处理每张作业图片,生成最终批改结果 + integrations: + """ + ctx = runtime.context + + # 获取并发数限制(从参数获取,默认10) + max_concurrent = getattr(state, 'max_concurrent', 10) + + logger.info(f"Starting to process {len(state.homework_images)} images (concurrent={max_concurrent})") + + # 定义处理单张图片的函数 + def process_single_image(idx: int, homework_image) -> SingleImageResult: + logger.info(f"Processing image {idx + 1}/{len(state.homework_images)}") + try: + # 构建子图输入 + subgraph_input = SubgraphInput( + homework_image=homework_image, + correct_answers=state.correct_answers, + image_index=idx, + comment_max_length=state.comment_max_length + ) + + # 调用子图 + subgraph_output = single_image_subgraph.invoke(subgraph_input, config) + + # 提取结果(兼容字典和Pydantic对象两种返回类型) + image_result = None + if subgraph_output: + if isinstance(subgraph_output, dict): + result_data = subgraph_output.get("image_result") + if result_data: + if isinstance(result_data, dict): + image_result = SingleImageResult(**result_data) + else: + image_result = result_data + elif hasattr(subgraph_output, 'image_result'): + image_result = subgraph_output.image_result + + if image_result: + logger.info(f"Image {idx + 1} processed successfully: {len(image_result.annotations)} annotations") + return image_result + else: + logger.warning(f"Image {idx + 1} subgraph returned invalid output") + return SingleImageResult( + image_index=idx, + image_info=ImageInfo(width=0, height=0, dpi=72), + annotations=[] + ) + except Exception as e: + logger.error(f"Failed to process image {idx + 1}: {e}", exc_info=True) + return SingleImageResult( + image_index=idx, + image_info=ImageInfo(width=0, height=0, dpi=72), + annotations=[] + ) + + # 并行处理所有图片 + image_results: List[SingleImageResult] = [] + + with ThreadPoolExecutor(max_workers=max_concurrent) as executor: + # 提交所有任务 + future_to_idx = { + executor.submit(process_single_image, idx, img): idx + for idx, img in 
enumerate(state.homework_images) + } + + # 收集结果 + results_dict = {} + for future in as_completed(future_to_idx): + idx = future_to_idx[future] + try: + result = future.result() + results_dict[idx] = result + except Exception as e: + logger.error(f"Future failed for image {idx}: {e}") + results_dict[idx] = SingleImageResult( + image_index=idx, + image_info=ImageInfo(width=0, height=0, dpi=72), + annotations=[] + ) + + # 按原始顺序排序结果 + for idx in range(len(state.homework_images)): + image_results.append(results_dict[idx]) + + logger.info(f"Completed processing {len(image_results)} images") + + # 生成最终结果 + total_score = 0 + full_score = 0 + total_questions = 0 + correct_count = 0 + incorrect_count = 0 + + for result in image_results: + for annotation in result.annotations: + total_score += annotation.score + full_score += annotation.full_score + total_questions += 1 + if annotation.status == "correct": + correct_count += 1 + elif annotation.status == "incorrect": + incorrect_count += 1 + + # 计算得分率 + score_rate = (total_score / full_score * 100) if full_score > 0 else 0 + + # 生成整体评价 + if total_questions == 0: + overall_comment = "未识别到题目内容" + grade = "D" + elif score_rate >= 95: + overall_comment = f"优秀!{total_questions}题全部正确,掌握扎实,继续保持。" + grade = "A+" + elif score_rate >= 90: + overall_comment = f"良好!得分率{score_rate:.0f}%,错{incorrect_count}题,注意细节。" + grade = "A" + elif score_rate >= 80: + overall_comment = f"合格。得分率{score_rate:.0f}%,错{incorrect_count}题,需加强练习。" + grade = "B" + elif score_rate >= 70: + overall_comment = f"及格。错{incorrect_count}题,部分知识点掌握不牢,建议复习。" + grade = "C" + else: + overall_comment = f"需努力。得分率{score_rate:.0f}%,建议认真复习,多做练习。" + grade = "D" + + final_result = FinalResult( + total_images=len(image_results), + image_results=image_results, + overall_comment=overall_comment, + total_score=total_score, + full_score=full_score, + grade=grade + ) + + logger.info(f"Final result: {total_score}/{full_score}, grade={grade}") + + return ProcessImagesOutput(final_result=final_result) diff --git a/src/graphs/nodes/recognize_and_correct_node.py b/src/graphs/nodes/recognize_and_correct_node.py new file mode 100644 index 0000000..b860a35 --- /dev/null +++ b/src/graphs/nodes/recognize_and_correct_node.py @@ -0,0 +1,340 @@ +"""一体化识别批改节点:合并识别和批改为一次LLM调用,提升速度""" +import os +import json +import re +import logging +from typing import List, Dict, Any +from jinja2 import Template +from langchain_core.runnables import RunnableConfig +from langgraph.runtime import Runtime +from coze_coding_utils.runtime_ctx.context import Context +from coze_coding_dev_sdk import LLMClient +from langchain_core.messages import HumanMessage + +from graphs.state import ( + RecognizeAndCorrectInput, + RecognizeAndCorrectOutput, + QuestionItem, + CorrectionResult, + MarkPosition, + CorrectAnswer +) + +logger = logging.getLogger(__name__) + + +def fix_incomplete_json(text: str) -> str: + """尝试修复不完整的JSON字符串""" + # 计算括号数量 + brace_count = text.count('{') - text.count('}') + bracket_count = text.count('[') - text.count(']') + + # 补全缺失的括号 + if bracket_count > 0: + text += ']' * bracket_count + if brace_count > 0: + text += '}' * brace_count + + return text + + +def extract_json_from_text(text: str, key: str = "results") -> dict: + """从文本中提取JSON对象,增强健壮性""" + import orjson + + # 清理markdown标记 + for prefix in ["```json", "```JSON", "```"]: + if text.startswith(prefix): + text = text[len(prefix):] + for suffix in ["```"]: + if text.endswith(suffix): + text = text[:-3] + + text = text.strip() + + # 尝试直接解析(支持格式化JSON) + for parser in [json.loads, 
lambda x: orjson.loads(x)]: + try: + result = parser(text) + if isinstance(result, dict) and key in result: + logger.info("Successfully parsed JSON directly") + return result + except Exception as e: + logger.debug(f"Direct parse failed: {e}") + + # 尝试修复不完整的JSON + try: + fixed_text = fix_incomplete_json(text) + for parser in [json.loads, lambda x: orjson.loads(x)]: + try: + result = parser(fixed_text) + if isinstance(result, dict) and key in result: + logger.info("Successfully parsed fixed JSON") + return result + except: + pass + except: + pass + + # 尝试提取第一个完整的JSON对象 + results_pattern = r'"results"\s*:\s*\[' + match = re.search(results_pattern, text) + + if match: + start_pos = text.rfind('{', 0, match.start()) + if start_pos == -1: + start_pos = 0 + + # 从start_pos开始,找到匹配的 } + brace_count = 0 + bracket_count = 0 + in_string = False + escape = False + end_pos = start_pos + + for i in range(start_pos, len(text)): + char = text[i] + + if escape: + escape = False + continue + + if char == '\\': + escape = True + continue + + if char == '"' and not escape: + in_string = not in_string + continue + + if in_string: + continue + + if char == '{': + brace_count += 1 + elif char == '}': + brace_count -= 1 + if brace_count == 0 and bracket_count == 0: + end_pos = i + 1 + break + elif char == '[': + bracket_count += 1 + elif char == ']': + bracket_count -= 1 + + if end_pos > start_pos: + json_str = text[start_pos:end_pos] + + for parser in [json.loads, lambda x: orjson.loads(x)]: + try: + result = parser(json_str) + if isinstance(result, dict) and key in result: + logger.info("Successfully extracted and parsed JSON object") + return result + except Exception as e: + logger.debug(f"Failed to parse extracted JSON: {e}") + + # 如果提取失败,尝试修复不完整的JSON + if end_pos <= start_pos: + # 找到JSON开始位置 + json_str = text[start_pos:] + + # 尝试修复 + fixed_json = fix_incomplete_json(json_str) + for parser in [json.loads, lambda x: orjson.loads(x)]: + try: + result = parser(fixed_json) + if isinstance(result, dict) and key in result: + logger.info("Successfully parsed fixed incomplete JSON") + return result + except: + pass + + logger.warning(f"Failed to extract JSON with key '{key}' from text length {len(text)}") + return {key: []} + + +def build_dynamic_prompt( + image_width: int, + image_height: int, + correct_answers: List[CorrectAnswer], + comment_max_length: int +) -> str: + """构建动态Prompt部分""" + + if correct_answers: + answers_text = json.dumps( + [{"question_id": a.question_id, "correct_answer": a.correct_answer} + for a in correct_answers], + ensure_ascii=False + ) + answer_hint = f""" +【标准答案】 +{answers_text}""" + else: + answer_hint = "\n【批改模式】无标准答案,请根据物理知识判断。" + + return f""" +【图片尺寸】{image_width}×{image_height}像素 +【评语限制】{comment_max_length}字以内 +{answer_hint} + +⚠️ 必须输出完整JSON,不要输出思考过程""" + + +def recognize_and_correct_node( + state: RecognizeAndCorrectInput, + config: RunnableConfig, + runtime: Runtime[Context] +) -> RecognizeAndCorrectOutput: + """ + title: 一体化识别批改 + desc: 合并识别和批改为一次LLM调用,提升批改速度 + integrations: 大语言模型 + """ + ctx = runtime.context + + # 读取LLM配置 + cfg_file = os.path.join(os.getenv("COZE_WORKSPACE_PATH", ""), config["metadata"]["llm_cfg"]) + with open(cfg_file, "r", encoding="utf-8") as fd: + _cfg = json.load(fd) + + llm_config = _cfg.get("config", {}) + sp = _cfg.get("sp", "") + up = _cfg.get("up", "") + + # 获取参数 + image_url = state.image_url + image_info = state.image_info + correct_answers = state.correct_answers + comment_max_length = getattr(state, 'comment_max_length', 100) + + # 使用Jinja2渲染up + 
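+    # up 模板来自 config/homework_recognize_llm_cfg.json,其中 {{image_url}}、{{comment_max_length}} 为 Jinja2 占位符,渲染后得到最终的用户提示词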
up_tpl = Template(up) + user_prompt = up_tpl.render({ + "image_url": image_url, + "comment_max_length": comment_max_length + }) + + # 构建动态部分 + dynamic_prompt = build_dynamic_prompt( + image_info.width, + image_info.height, + correct_answers, + comment_max_length + ) + + # 组合完整Prompt + full_prompt = f"{sp}\n\n{user_prompt}\n{dynamic_prompt}" + + # 调用LLM + messages = [ + HumanMessage(content=[ + {"type": "text", "text": full_prompt}, + {"type": "image_url", "image_url": {"url": image_url}} + ]) + ] + + client = LLMClient(ctx=ctx) + response = client.invoke( + messages=messages, + model=llm_config.get("model", "doubao-seed-2-0-pro-260215"), + temperature=llm_config.get("temperature", 0.0), + max_completion_tokens=llm_config.get("max_completion_tokens", 4096) + ) + + response_text = response.content if isinstance(response.content, str) else " ".join( + item.get("text", "") if isinstance(item, dict) else str(item) + for item in response.content + ).strip() + + # 只记录前500字符,避免日志过大 + logger.info(f"LLM response (first 500 chars): {response_text[:500]}") + + # 解析结果 + result_dict = extract_json_from_text(response_text, "results") + + # 相对坐标(0-1000)转换为绝对坐标 + width_scale = image_info.width / 1000.0 if image_info.width > 0 else 1.0 + height_scale = image_info.height / 1000.0 if image_info.height > 0 else 1.0 + + question_items: List[QuestionItem] = [] + correction_results: List[CorrectionResult] = [] + + results_list = result_dict.get("results", []) + if not isinstance(results_list, list): + results_list = [] + + for r in results_list: + try: + if not isinstance(r, dict): + continue + + # 解析bbox(相对坐标0-1000) + answer_bbox = r.get("answer_bbox", [0, 0, 0, 0]) + if not isinstance(answer_bbox, list) or len(answer_bbox) != 4: + answer_bbox = [0, 0, 0, 0] + + # 转换为绝对坐标 + answer_bbox_abs = [ + int(answer_bbox[0] * width_scale), + int(answer_bbox[1] * height_scale), + int(answer_bbox[2] * width_scale), + int(answer_bbox[3] * height_scale) + ] + + # 自动计算mark_position + bbox_width = answer_bbox_abs[2] - answer_bbox_abs[0] + bbox_height = answer_bbox_abs[3] - answer_bbox_abs[1] + + if bbox_width > 0 and bbox_height > 0: + mark_x = answer_bbox_abs[2] + 30 + mark_y = answer_bbox_abs[1] + int(bbox_height * 0.5) + + if mark_x > image_info.width - 50: + mark_x = image_info.width - 50 + else: + mark_x = 500 + mark_y = 500 + + mark_position = MarkPosition(x=mark_x, y=mark_y) + + # 构建QuestionItem + question_items.append(QuestionItem( + question_id=str(r.get("question_id", "")), + parent_id=str(r.get("parent_id", "")), + is_sub_question=bool(r.get("is_sub_question", False)), + question_text=str(r.get("question_text", "")), + student_answer=str(r.get("student_answer", "")), + answer_bbox=answer_bbox_abs, + mark_position=mark_position, + full_score=int(r.get("full_score", 10) if r.get("full_score") is not None else 10) + )) + + # 构建CorrectionResult + status = str(r.get("status", "incorrect")) + if status not in ["correct", "incorrect", "partial"]: + status = "incorrect" + + # 使用原始comment,不做截断(由LLM控制长度) + comment = str(r.get("comment", "")) + + correction_results.append(CorrectionResult( + question_id=str(r.get("question_id", "")), + parent_id=str(r.get("parent_id", "")), + status=status, + score=int(r.get("score", 0) if r.get("score") is not None else 0), + full_score=int(r.get("full_score", 10) if r.get("full_score") is not None else 10), + comment=comment + )) + + except Exception as e: + logger.warning(f"Failed to parse result: {r}, error: {e}") + continue + + logger.info(f"Parsed {len(question_items)} questions and 
{len(correction_results)} correction results") + + return RecognizeAndCorrectOutput( + question_items=question_items, + correction_results=correction_results + ) diff --git a/src/graphs/nodes/result_merge_node.py b/src/graphs/nodes/result_merge_node.py new file mode 100644 index 0000000..e6b0b36 --- /dev/null +++ b/src/graphs/nodes/result_merge_node.py @@ -0,0 +1,65 @@ +"""4. 结果整合节点:将识别结果和批改结果合并为最终批注""" +from typing import List +from langchain_core.runnables import RunnableConfig +from langgraph.runtime import Runtime +from coze_coding_utils.runtime_ctx.context import Context + +from graphs.state import ( + ResultMergeInput, + ResultMergeOutput, + Annotation, + QuestionItem, + CorrectionResult +) + + +def result_merge_node( + state: ResultMergeInput, + config: RunnableConfig, + runtime: Runtime[Context] +) -> ResultMergeOutput: + """ + title: 结果整合 + desc: 将识别结果和批改结果合并为最终批注列表 + integrations: + """ + ctx = runtime.context + + annotations: List[Annotation] = [] + + # 创建批改结果字典,方便查找 + correction_dict = {c.question_id: c for c in state.correction_results} + + for q in state.question_items: + # 查找对应的批改结果 + correction = correction_dict.get(q.question_id) + + if correction is None: + # 如果没有批改结果,默认为错误 + annotations.append(Annotation( + question_id=q.question_id, + parent_id=q.parent_id, + status="incorrect", + question_text=q.question_text, + student_answer=q.student_answer, + answer_bbox=q.answer_bbox, + comment="未找到批改结果", + mark_position=q.mark_position, + score=0, + full_score=q.full_score + )) + else: + annotations.append(Annotation( + question_id=q.question_id, + parent_id=q.parent_id, + status=correction.status, + question_text=q.question_text, + student_answer=q.student_answer, + answer_bbox=q.answer_bbox, + comment=correction.comment, + mark_position=q.mark_position, + score=correction.score, + full_score=correction.full_score + )) + + return ResultMergeOutput(annotations=annotations) diff --git a/src/graphs/state.py b/src/graphs/state.py new file mode 100644 index 0000000..4c59dc6 --- /dev/null +++ b/src/graphs/state.py @@ -0,0 +1,232 @@ +"""初中物理作业批改工作流状态定义 - 支持多图片批改""" +from typing import List, Optional, Literal +from pydantic import BaseModel, Field +from utils.file.file import File + + +# === 基础数据结构 === + +class ImageInfo(BaseModel): + """图片信息""" + width: int = Field(..., description="图片宽度(像素)") + height: int = Field(..., description="图片高度(像素)") + dpi: int = Field(default=72, description="图片DPI") + + +class MarkPosition(BaseModel): + """批改标记位置""" + x: int = Field(..., description="标记中心X坐标(绝对像素)") + y: int = Field(..., description="标记中心Y坐标(绝对像素)") + + +class QuestionItem(BaseModel): + """题目项(一体化识别结果)""" + question_id: str = Field(..., description="题号(如10、10.1、10.2)") + parent_id: str = Field(default="", description="母题题号(子题时填写)") + is_sub_question: bool = Field(default=False, description="是否为子题") + question_text: str = Field(default="", description="题目内容") + student_answer: str = Field(default="", description="学生答案内容") + answer_bbox: List[int] = Field(..., description="学生答案区域的边界框 [x1, y1, x2, y2](用于定位标注)") + mark_position: MarkPosition = Field(..., description="批改标记应该标注的位置坐标") + full_score: int = Field(default=10, description="满分") + + +class CorrectAnswer(BaseModel): + """标准答案项(从Word文档解析)""" + question_id: str = Field(..., description="题号(如10、10.1、10.2)") + parent_id: str = Field(default="", description="母题题号(子题时填写)") + is_sub_question: bool = Field(default=False, description="是否为子题") + question_text: str = Field(default="", description="题目内容/题干") + correct_answer: str = Field(..., 
description="标准答案") + full_score: int = Field(default=10, description="满分") + answer_analysis: str = Field(default="", description="答案解析/解题步骤") + + +class CorrectionResult(BaseModel): + """批改结果""" + question_id: str = Field(..., description="题号") + parent_id: str = Field(default="", description="母题题号") + status: Literal["correct", "incorrect", "partial"] = Field(..., description="批改状态") + score: int = Field(default=0, description="得分") + full_score: int = Field(default=10, description="满分") + comment: str = Field(default="", description="批改评语") + + +class Annotation(BaseModel): + """完整批注""" + question_id: str = Field(..., description="题号") + parent_id: str = Field(default="", description="母题题号") + status: Literal["correct", "incorrect", "partial"] = Field(..., description="批改状态") + question_text: str = Field(default="", description="题目内容") + student_answer: str = Field(default="", description="学生答案") + answer_bbox: List[int] = Field(default=[], description="答案区域边界框") + comment: str = Field(default="", description="批改评语") + mark_position: MarkPosition = Field(..., description="批改标记位置") + score: int = Field(default=0, description="得分") + full_score: int = Field(default=10, description="满分") + + +class SingleImageResult(BaseModel): + """单张图片的批改结果""" + image_index: int = Field(..., description="图片索引(从0开始)") + image_info: ImageInfo = Field(..., description="图片信息") + image_url: str = Field(default="", description="处理后的图片URL") + annotations: List[Annotation] = Field(default=[], description="该图片的批注列表") + + +class FinalResult(BaseModel): + """最终批改结果(多图片汇总)""" + total_images: int = Field(..., description="总图片数") + image_results: List[SingleImageResult] = Field(default=[], description="各图片的批改结果") + overall_comment: str = Field(default="", description="整体评价") + total_score: int = Field(default=0, description="总分") + full_score: int = Field(default=100, description="满分") + grade: str = Field(default="", description="等级") + + +# === 全局状态 === +class GlobalState(BaseModel): + """工作流全局状态""" + # 输入参数 + homework_images: List[File] = Field(default=[], description="上传的作业图片列表") + answer_doc_url: str = Field(default="", description="正确答案Word文件的URL(.docx格式)") + comment_max_length: int = Field(default=100, description="评语最大字数") + max_concurrent: int = Field(default=10, description="并行批改的最大数量") + grade_standards: dict = Field(default={}, description="评价等级标准") + # 中间状态 + correct_answers: List[CorrectAnswer] = Field(default=[], description="从Word解析的标准答案列表") + image_results: List[SingleImageResult] = Field(default=[], description="各图片的批改结果列表") + # 最终结果 + final_result: Optional[FinalResult] = Field(default=None, description="最终汇总结果") + + +# === 图输入输出 === +class GraphInput(BaseModel): + """工作流输入""" + homework_images: List[File] = Field(..., description="上传的作业图片列表") + answer_doc_url: str = Field(default="", description="正确答案Word文件的URL(.docx格式,可选)") + comment_max_length: int = Field(default=100, description="评语最大字数,默认100字") + max_concurrent: int = Field(default=10, description="并行批改的最大数量,默认10") + grade_standards: dict = Field( + default={ + "A+": {"min_percentage": 95, "description": "答案全部正确,步骤完整规范,逻辑严谨;书写/格式整洁,无错别字、无遗漏;完成度100%,态度认真,质量上乘"}, + "A": {"min_percentage": 90, "description": "答案完全正确,无任何错误;步骤合理、格式规范,无原则性问题;完成度100%,满足全部要求"}, + "B": {"min_percentage": 80, "description": "存在少量非关键性错误,或步骤略有缺失;整体思路基本正确,仅细节、格式、计算等小问题;完成大部分内容,整体合格但不够严谨"}, + "C": {"min_percentage": 70, "description": "错误较多,部分核心题目作答错误;步骤不完整、逻辑不够清晰;完成度一般,有明显应付、漏答情况"}, + "D": {"min_percentage": 0, "description": "大面积错误,核心知识点未掌握;大量空白、敷衍、抄袭;未达到基本完成要求"} + }, + 
description="评价等级标准,包含各等级的最低得分率百分比和描述" + ) + + +class GraphOutput(BaseModel): + """工作流输出""" + final_result: FinalResult = Field(..., description="最终批改结果JSON(包含多图片)") + + +# === 文档答案解析节点 === +class DocExtractInput(BaseModel): + """文档答案解析节点输入""" + answer_doc_url: str = Field(default="", description="正确答案Word文件的URL(.docx格式,可选)") + + +class DocExtractOutput(BaseModel): + """文档答案解析节点输出""" + correct_answers: List[CorrectAnswer] = Field(default=[], description="解析的标准答案列表") + + +# === 子图状态定义(单图片处理流程) === +class SubgraphState(BaseModel): + """子图全局状态(处理单张图片)""" + # 输入 + homework_image: File = Field(..., description="当前处理的作业图片") + correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表") + image_index: int = Field(default=0, description="图片索引") + comment_max_length: int = Field(default=100, description="评语最大字数") + # 中间状态 + image_info: ImageInfo = Field(default_factory=lambda: ImageInfo(width=0, height=0, dpi=72), description="图片信息") + image_url: str = Field(default="", description="处理后的图片URL") + question_items: List[QuestionItem] = Field(default=[], description="识别的题目项列表") + correction_results: List[CorrectionResult] = Field(default=[], description="批改结果列表") + annotations: List[Annotation] = Field(default=[], description="最终批注列表") + + +class SubgraphInput(BaseModel): + """子图输入""" + homework_image: File = Field(..., description="当前处理的作业图片") + correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表") + image_index: int = Field(default=0, description="图片索引") + comment_max_length: int = Field(default=100, description="评语最大字数") + + +class SubgraphOutput(BaseModel): + """子图输出""" + image_result: SingleImageResult = Field(..., description="单张图片的批改结果") + + +# === 子图节点输入输出 === + +# 1. 图像预处理节点 +class ImagePreprocessInput(BaseModel): + """图像预处理节点输入""" + homework_image: File = Field(..., description="上传的作业图片") + + +class ImagePreprocessOutput(BaseModel): + """图像预处理节点输出""" + image_info: ImageInfo = Field(..., description="图片信息") + image_url: str = Field(..., description="处理后的图片URL") + + +# 2. 一体化识别批改节点(合并识别+批改) +class RecognizeAndCorrectInput(BaseModel): + """一体化识别批改节点输入""" + image_url: str = Field(..., description="图片URL") + image_info: ImageInfo = Field(..., description="图片信息") + correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表") + comment_max_length: int = Field(default=100, description="评语最大字数") + + +class RecognizeAndCorrectOutput(BaseModel): + """一体化识别批改节点输出""" + question_items: List[QuestionItem] = Field(default=[], description="识别的题目项列表") + correction_results: List[CorrectionResult] = Field(default=[], description="批改结果列表") + + +# 3. 批改判断节点 +class CorrectionJudgeInput(BaseModel): + """批改判断节点输入""" + question_items: List[QuestionItem] = Field(default=[], description="题目项列表") + correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表") + comment_max_length: int = Field(default=100, description="评语最大字数") + + +class CorrectionJudgeOutput(BaseModel): + """批改判断节点输出""" + correction_results: List[CorrectionResult] = Field(default=[], description="批改结果列表") + + +# 4. 
结果整合节点 +class ResultMergeInput(BaseModel): + """结果整合节点输入""" + question_items: List[QuestionItem] = Field(default=[], description="题目项列表") + correction_results: List[CorrectionResult] = Field(default=[], description="批改结果列表") + + +class ResultMergeOutput(BaseModel): + """结果整合节点输出""" + annotations: List[Annotation] = Field(default=[], description="最终批注列表") + + +# === 循环节点 === +class ProcessImagesInput(BaseModel): + """多图片处理循环节点输入""" + homework_images: List[File] = Field(default=[], description="作业图片列表") + correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表") + comment_max_length: int = Field(default=100, description="评语最大字数") + max_concurrent: int = Field(default=10, description="并行批改的最大数量") + + +class ProcessImagesOutput(BaseModel): + """多图片处理循环节点输出""" + final_result: FinalResult = Field(..., description="最终批改结果") diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..dd86f2d --- /dev/null +++ b/src/main.py @@ -0,0 +1,546 @@ +import argparse +import asyncio +import json +import threading +import traceback +import logging +from typing import Any, Dict, Iterable, AsyncIterable, AsyncGenerator, Optional +import cozeloop +import uvicorn +import time +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import StreamingResponse, JSONResponse +from langchain_core.runnables import RunnableConfig +from langgraph.graph import StateGraph, END +from langgraph.graph.state import CompiledStateGraph +from coze_coding_utils.runtime_ctx.context import new_context, Context +from coze_coding_utils.helper import graph_helper +from coze_coding_utils.log.node_log import LOG_FILE +from coze_coding_utils.log.write_log import setup_logging, request_context +from coze_coding_utils.log.config import LOG_LEVEL +from coze_coding_utils.error.classifier import ErrorClassifier, classify_error +from coze_coding_utils.helper.stream_runner import AgentStreamRunner, WorkflowStreamRunner,agent_stream_handler,workflow_stream_handler, RunOpt + +setup_logging( + log_file=LOG_FILE, + max_bytes=100 * 1024 * 1024, # 100MB + backup_count=5, + log_level=LOG_LEVEL, + use_json_format=True, + console_output=True +) + +logger = logging.getLogger(__name__) +from coze_coding_utils.helper.agent_helper import to_stream_input +from coze_coding_utils.openai.handler import OpenAIChatHandler +from coze_coding_utils.log.parser import LangGraphParser +from coze_coding_utils.log.err_trace import extract_core_stack +from coze_coding_utils.log.loop_trace import init_run_config, init_agent_config + + +# 超时配置常量 +TIMEOUT_SECONDS = 900 # 15分钟 + +class GraphService: + def __init__(self): + # 用于跟踪正在运行的任务(使用asyncio.Task) + self.running_tasks: Dict[str, asyncio.Task] = {} + # 错误分类器 + self.error_classifier = ErrorClassifier() + # stream runner + self._agent_stream_runner = AgentStreamRunner() + self._workflow_stream_runner = WorkflowStreamRunner() + self._graph = None + self._graph_lock = threading.Lock() + + def _get_graph(self, ctx=Context): + if graph_helper.is_agent_proj(): + return graph_helper.get_agent_instance("agents.agent", ctx) + + if self._graph is not None: + return self._graph + with self._graph_lock: + if self._graph is not None: + return self._graph + self._graph = graph_helper.get_graph_instance("graphs.graph") + return self._graph + + @staticmethod + def _sse_event(data: Any, event_id: Any = None) -> str: + id_line = f"id: {event_id}\n" if event_id else "" + return f"{id_line}event: message\ndata: {json.dumps(data, ensure_ascii=False, default=str)}\n\n" + + def 
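ProcessImagesInput carries max_concurrent, which bounds the fan-out over images. A minimal stand-in for that pattern (process_one is a placeholder for the subgraph call; order is restored via image_index, matching the documented behavior):

```python
# Bounded-concurrency fan-out sketch: at most max_concurrent images in
# flight; results re-ordered by image_index so output order is stable.
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_one(index: int, image_url: str) -> dict:
    # Stand-in for invoking the single-image subgraph.
    return {"image_index": index, "image_url": image_url}

def process_images(image_urls: list[str], max_concurrent: int = 10) -> list[dict]:
    with ThreadPoolExecutor(max_workers=max_concurrent) as pool:
        futures = [pool.submit(process_one, i, url)
                   for i, url in enumerate(image_urls)]
        results = [f.result() for f in as_completed(futures)]
    return sorted(results, key=lambda r: r["image_index"])

print(process_images(["a.jpg", "b.jpg", "c.jpg"], max_concurrent=2))
```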
_get_stream_runner(self): + if graph_helper.is_agent_proj(): + return self._agent_stream_runner + else: + return self._workflow_stream_runner + + # 流式运行(原始迭代器):本地调用使用 + def stream(self, payload: Dict[str, Any], run_config: RunnableConfig, ctx=Context) -> Iterable[Any]: + graph = self._get_graph(ctx) + stream_runner = self._get_stream_runner() + for chunk in stream_runner.stream(payload, graph, run_config, ctx): + yield chunk + + # 同步运行:本地/HTTP 通用 + async def run(self, payload: Dict[str, Any], ctx=None) -> Dict[str, Any]: + if ctx is None: + ctx = new_context("run") + + run_id = ctx.run_id + logger.info(f"Starting run with run_id: {run_id}") + + try: + graph = self._get_graph(ctx) + # custom tracer + run_config = init_run_config(graph, ctx) + run_config["configurable"] = {"thread_id": ctx.run_id} + + # 直接调用,LangGraph会在当前任务上下文中执行 + # 如果当前任务被取消,LangGraph的执行也会被取消 + return await graph.ainvoke(payload, config=run_config, context=ctx) + + except asyncio.CancelledError: + logger.info(f"Run {run_id} was cancelled") + return {"status": "cancelled", "run_id": run_id, "message": "Execution was cancelled"} + except Exception as e: + # 使用错误分类器分类错误 + err = self.error_classifier.classify(e, {"node_name": "run", "run_id": run_id}) + # 记录详细的错误信息和堆栈跟踪 + logger.error( + f"Error in GraphService.run: [{err.code}] {err.message}\n" + f"Category: {err.category.name}\n" + f"Traceback:\n{extract_core_stack()}" + ) + # 保留原始异常堆栈,便于上层返回真正的报错位置 + raise + finally: + # 清理任务记录 + self.running_tasks.pop(run_id, None) + + # 流式运行(SSE 格式化):HTTP 路由使用 + async def stream_sse(self, payload: Dict[str, Any], ctx=None, run_opt: Optional[RunOpt] = None) -> AsyncGenerator[str, None]: + if ctx is None: + ctx = new_context(method="stream_sse") + if run_opt is None: + run_opt = RunOpt() + + run_id = ctx.run_id + logger.info(f"Starting stream with run_id: {run_id}") + graph = self._get_graph(ctx) + if graph_helper.is_agent_proj(): + run_config = init_agent_config(graph, ctx) + else: + run_config = init_run_config(graph, ctx) # vibeflow + + is_workflow = not graph_helper.is_agent_proj() + + try: + async for chunk in self.astream(payload, graph, run_config=run_config, ctx=ctx, run_opt=run_opt): + if is_workflow and isinstance(chunk, tuple): + event_id, data = chunk + yield self._sse_event(data, event_id) + else: + yield self._sse_event(chunk) + finally: + # 清理任务记录 + self.running_tasks.pop(run_id, None) + cozeloop.flush() + + # 取消执行 - 使用asyncio的标准方式 + def cancel_run(self, run_id: str, ctx: Optional[Context] = None) -> Dict[str, Any]: + """ + 取消指定run_id的执行 + + 使用asyncio.Task.cancel()来取消任务,这是标准的Python异步取消机制。 + LangGraph会在节点之间检查CancelledError,实现优雅的取消。 + """ + logger.info(f"Attempting to cancel run_id: {run_id}") + + # 查找对应的任务 + if run_id in self.running_tasks: + task = self.running_tasks[run_id] + if not task.done(): + # 使用asyncio的标准取消机制 + # 这会在下一个await点抛出CancelledError + task.cancel() + logger.info(f"Cancellation requested for run_id: {run_id}") + return { + "status": "success", + "run_id": run_id, + "message": "Cancellation signal sent, task will be cancelled at next await point" + } + else: + logger.info(f"Task already completed for run_id: {run_id}") + return { + "status": "already_completed", + "run_id": run_id, + "message": "Task has already completed" + } + else: + logger.warning(f"No active task found for run_id: {run_id}") + return { + "status": "not_found", + "run_id": run_id, + "message": "No active task found with this run_id. Task may have already completed or run_id is invalid." 
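cancel_run leans entirely on asyncio's standard cancellation: Task.cancel() raises CancelledError at the task's next await point. A compact, runnable illustration of that mechanism in isolation:

```python
import asyncio

async def long_job():
    try:
        await asyncio.sleep(60)
        return "done"
    except asyncio.CancelledError:
        print("cancelled at an await point")
        raise

async def main():
    task = asyncio.create_task(long_job())
    await asyncio.sleep(0.1)   # let the task start
    task.cancel()              # what cancel_run does for a registered run_id
    try:
        await task
    except asyncio.CancelledError:
        print("caller observed cancellation")

asyncio.run(main())
```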
+ } + + # 运行指定节点:本地/HTTP 通用 + async def run_node(self, node_id: str, payload: Dict[str, Any], ctx=None) -> Any: + if ctx is None or Context.run_id == "": + ctx = new_context(method="node_run") + + _graph = self._get_graph() + node_func, input_cls, output_cls = graph_helper.get_graph_node_func_with_inout(_graph.get_graph(), node_id) + if node_func is None or input_cls is None: + raise KeyError(f"node_id '{node_id}' not found") + + parser = LangGraphParser(_graph) + metadata = parser.get_node_metadata(node_id) or {} + + _g = StateGraph(input_cls, input_schema=input_cls, output_schema=output_cls) + _g.add_node("sn", node_func, metadata=metadata) + _g.set_entry_point("sn") + _g.add_edge("sn", END) + _graph = _g.compile() + + run_config = init_run_config(_graph, ctx) + return await _graph.ainvoke(payload, config=run_config) + + def graph_inout_schema(self) -> Any: + if graph_helper.is_agent_proj(): + return {"input_schema": {}, "output_schema": {}} + builder = getattr(self._get_graph(), 'builder', None) + if builder is not None: + input_cls = getattr(builder, 'input_schema', None) or self.graph.get_input_schema() + output_cls = getattr(builder, 'output_schema', None) or self.graph.get_output_schema() + else: + logger.warning(f"No builder input schema found for graph_inout_schema, using graph input schema instead") + input_cls = self.graph.get_input_schema() + output_cls = self.graph.get_output_schema() + + return { + "input_schema": input_cls.model_json_schema(), + "output_schema": output_cls.model_json_schema(), + "code":0, + "msg":"" + } + + async def astream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx=Context, run_opt: Optional[RunOpt] = None) -> AsyncIterable[Any]: + stream_runner = self._get_stream_runner() + async for chunk in stream_runner.astream(payload, graph, run_config, ctx, run_opt): + yield chunk + + +service = GraphService() +app = FastAPI() + +# OpenAI 兼容接口处理器 +openai_handler = OpenAIChatHandler(service) + + +HEADER_X_RUN_ID = "x-run-id" +@app.post("/run") +async def http_run(request: Request) -> Dict[str, Any]: + global result + raw_body = await request.body() + try: + body_text = raw_body.decode("utf-8") + except Exception as e: + body_text = str(raw_body) + raise HTTPException(status_code=400, + detail=f"Invalid JSON format: {body_text}, traceback: {traceback.format_exc()}, error: {e}") + + ctx = new_context(method="run", headers=request.headers) + # 优先使用上游指定的 run_id,保证 cancel 能精确匹配 + upstream_run_id = request.headers.get(HEADER_X_RUN_ID) + if upstream_run_id: + ctx.run_id = upstream_run_id + run_id = ctx.run_id + request_context.set(ctx) + + logger.info( + f"Received request for /run: " + f"run_id={run_id}, " + f"query={dict(request.query_params)}, " + f"body={body_text}" + ) + + try: + payload = await request.json() + + # 创建任务并记录 - 这是关键,让我们可以通过run_id取消任务 + task = asyncio.create_task(service.run(payload, ctx)) + service.running_tasks[run_id] = task + + try: + result = await asyncio.wait_for(task, timeout=float(TIMEOUT_SECONDS)) + except asyncio.TimeoutError: + logger.error(f"Run execution timeout after {TIMEOUT_SECONDS}s for run_id: {run_id}") + task.cancel() + try: + result = await task + except asyncio.CancelledError: + return { + "status": "timeout", + "run_id": run_id, + "message": f"Execution timeout: exceeded {TIMEOUT_SECONDS} seconds" + } + + if not result: + result = {} + if isinstance(result, dict): + result["run_id"] = run_id + return result + + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in 
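The /run handler's timeout path, reduced to its essentials: asyncio.wait_for cancels the wrapped task on timeout, and the handler converts that into a structured status payload. A runnable miniature (timeout shortened from the service's 900 s for the demo):

```python
import asyncio

TIMEOUT_SECONDS = 0.2  # demo value; the service uses 900

async def run_graph():
    await asyncio.sleep(1)  # stand-in for graph.ainvoke(...)
    return {"status": "ok"}

async def main():
    task = asyncio.create_task(run_graph())
    try:
        result = await asyncio.wait_for(task, timeout=TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        # wait_for has already cancelled the task at this point
        result = {"status": "timeout",
                  "message": f"Execution timeout: exceeded {TIMEOUT_SECONDS} seconds"}
    print(result)

asyncio.run(main())
```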
http_run: {e}, traceback: {traceback.format_exc()}") + raise HTTPException(status_code=400, detail=f"Invalid JSON format, {extract_core_stack()}") + + except asyncio.CancelledError: + logger.info(f"Request cancelled for run_id: {run_id}") + result = {"status": "cancelled", "run_id": run_id, "message": "Execution was cancelled"} + return result + + except Exception as e: + # 使用错误分类器获取错误信息 + error_response = service.error_classifier.get_error_response(e, {"node_name": "http_run", "run_id": run_id}) + logger.error( + f"Unexpected error in http_run: [{error_response['error_code']}] {error_response['error_message']}, " + f"traceback: {traceback.format_exc()}", exc_info=True + ) + raise HTTPException( + status_code=500, + detail={ + "error_code": error_response["error_code"], + "error_message": error_response["error_message"], + "stack_trace": extract_core_stack(), + } + ) + finally: + cozeloop.flush() + + +HEADER_X_WORKFLOW_STREAM_MODE = "x-workflow-stream-mode" + + +def _register_task(run_id: str, task: asyncio.Task): + service.running_tasks[run_id] = task + + +@app.post("/stream_run") +async def http_stream_run(request: Request): + ctx = new_context(method="stream_run", headers=request.headers) + # 优先使用上游指定的 run_id,保证 cancel 能精确匹配 + upstream_run_id = request.headers.get(HEADER_X_RUN_ID) + if upstream_run_id: + ctx.run_id = upstream_run_id + workflow_stream_mode = request.headers.get(HEADER_X_WORKFLOW_STREAM_MODE, "").lower() + workflow_debug = workflow_stream_mode == "debug" + request_context.set(ctx) + raw_body = await request.body() + try: + body_text = raw_body.decode("utf-8") + except Exception as e: + body_text = str(raw_body) + raise HTTPException(status_code=400, + detail=f"Invalid JSON format: {body_text}, traceback: {extract_core_stack()}, error: {e}") + run_id = ctx.run_id + is_agent = graph_helper.is_agent_proj() + logger.info( + f"Received request for /stream_run: " + f"run_id={run_id}, " + f"is_agent_project={is_agent}, " + f"query={dict(request.query_params)}, " + f"body={body_text}" + ) + try: + payload = await request.json() + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in http_stream_run: {e}, traceback: {traceback.format_exc()}") + raise HTTPException(status_code=400, detail=f"Invalid JSON format:{extract_core_stack()}") + + if is_agent: + stream_generator = agent_stream_handler( + payload=payload, + ctx=ctx, + run_id=run_id, + stream_sse_func=service.stream_sse, + sse_event_func=service._sse_event, + error_classifier=service.error_classifier, + register_task_func=_register_task, + ) + else: + stream_generator = workflow_stream_handler( + payload=payload, + ctx=ctx, + run_id=run_id, + stream_sse_func=service.stream_sse, + sse_event_func=service._sse_event, + error_classifier=service.error_classifier, + register_task_func=_register_task, + run_opt=RunOpt(workflow_debug=workflow_debug), + ) + + response = StreamingResponse(stream_generator, media_type="text/event-stream") + return response + +@app.post("/cancel/{run_id}") +async def http_cancel(run_id: str, request: Request): + """ + 取消指定run_id的执行 + + 使用asyncio.Task.cancel()实现取消,这是Python标准的异步任务取消机制。 + LangGraph会在节点之间的await点检查CancelledError,实现优雅取消。 + """ + ctx = new_context(method="cancel", headers=request.headers) + request_context.set(ctx) + logger.info(f"Received cancel request for run_id: {run_id}") + result = service.cancel_run(run_id, ctx) + return result + + +@app.post(path="/node_run/{node_id}") +async def http_node_run(node_id: str, request: Request): + raw_body = await request.body() + try: + 
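/stream_run ultimately returns a StreamingResponse over text/event-stream. A toy FastAPI app showing the same framing that _sse_event produces (id/event/data lines terminated by a blank line); app and route names here are illustrative:

```python
import asyncio
import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

demo = FastAPI()

async def events():
    for i in range(3):
        payload = json.dumps({"chunk": i}, ensure_ascii=False)
        yield f"id: {i}\nevent: message\ndata: {payload}\n\n"
        await asyncio.sleep(0.1)

@demo.post("/stream_demo")
async def stream_demo():
    return StreamingResponse(events(), media_type="text/event-stream")
```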
body_text = raw_body.decode("utf-8") + except UnicodeDecodeError: + body_text = str(raw_body) + raise HTTPException(status_code=400, detail=f"Invalid JSON format: {body_text}") + ctx = new_context(method="node_run", headers=request.headers) + request_context.set(ctx) + logger.info( + f"Received request for /node_run/{node_id}: " + f"query={dict(request.query_params)}, " + f"body={body_text}", + ) + + try: + payload = await request.json() + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in http_node_run: {e}, traceback: {traceback.format_exc()}") + raise HTTPException(status_code=400, detail=f"Invalid JSON format:{extract_core_stack()}") + try: + return await service.run_node(node_id, payload, ctx) + except KeyError: + raise HTTPException(status_code=404, + detail=f"node_id '{node_id}' not found or input miss required fields, traceback: {extract_core_stack()}") + except Exception as e: + # 使用错误分类器获取错误信息 + error_response = service.error_classifier.get_error_response(e, {"node_name": node_id}) + logger.error( + f"Unexpected error in http_node_run: [{error_response['error_code']}] {error_response['error_message']}, " + f"traceback: {traceback.format_exc()}", exc_info=True + ) + raise HTTPException( + status_code=500, + detail={ + "error_code": error_response["error_code"], + "error_message": error_response["error_message"], + "stack_trace": extract_core_stack(), + } + ) + finally: + cozeloop.flush() + + +@app.post("/v1/chat/completions") +async def openai_chat_completions(request: Request): + """OpenAI Chat Completions API 兼容接口""" + ctx = new_context(method="openai_chat", headers=request.headers) + request_context.set(ctx) + + logger.info(f"Received request for /v1/chat/completions: run_id={ctx.run_id}") + + try: + payload = await request.json() + return await openai_handler.handle(payload, ctx) + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in openai_chat_completions: {e}") + raise HTTPException(status_code=400, detail="Invalid JSON format") + finally: + cozeloop.flush() + + +@app.get("/health") +async def health_check(): + try: + # 这里可以添加更多的健康检查逻辑 + return { + "status": "ok", + "message": "Service is running", + } + except Exception as e: + raise HTTPException(status_code=503, detail=str(e)) + + +@app.get(path="/graph_parameter") +async def http_graph_inout_parameter(request: Request): + return service.graph_inout_schema() + +def parse_args(): + parser = argparse.ArgumentParser(description="Start FastAPI server") + parser.add_argument("-m", type=str, default="http", help="Run mode, support http,flow,node") + parser.add_argument("-n", type=str, default="", help="Node ID for single node run") + parser.add_argument("-p", type=int, default=5000, help="HTTP server port") + parser.add_argument("-i", type=str, default="", help="Input JSON string for flow/node mode") + return parser.parse_args() + + +def parse_input(input_str: str) -> Dict[str, Any]: + """Parse input string, support both JSON string and plain text""" + if not input_str: + return {"text": "你好"} + + # Try to parse as JSON first + try: + return json.loads(input_str) + except json.JSONDecodeError: + # If not valid JSON, treat as plain text + return {"text": input_str} + +def start_http_server(port): + workers = 1 + reload = False + if graph_helper.is_dev_env(): + reload = True + + logger.info(f"Start HTTP Server, Port: {port}, Workers: {workers}") + uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload, workers=workers) + +if __name__ == "__main__": + args = parse_args() + if 
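run_node's trick is wrapping a single node function in a one-node StateGraph so it can be invoked standalone with the node's own input/output schemas. A minimal version of the same construction, assuming the langgraph API this repo already uses (node and schema names are hypothetical):

```python
from pydantic import BaseModel
from langgraph.graph import StateGraph, END

class EchoState(BaseModel):
    text: str = ""
    echoed: str = ""

class EchoInput(BaseModel):
    text: str

class EchoOutput(BaseModel):
    echoed: str

def echo_node(state: EchoState) -> dict:
    return {"echoed": state.text.upper()}

g = StateGraph(EchoState, input_schema=EchoInput, output_schema=EchoOutput)
g.add_node("sn", echo_node)
g.set_entry_point("sn")
g.add_edge("sn", END)
graph = g.compile()

print(graph.invoke({"text": "hello"}))  # expected: {'echoed': 'HELLO'}
```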
args.m == "http": + start_http_server(args.p) + elif args.m == "flow": + payload = parse_input(args.i) + result = asyncio.run(service.run(payload)) + print(json.dumps(result, ensure_ascii=False, indent=2)) + elif args.m == "node" and args.n: + payload = parse_input(args.i) + result = asyncio.run(service.run_node(args.n, payload)) + print(json.dumps(result, ensure_ascii=False, indent=2)) + elif args.m == "agent": + agent_ctx = new_context(method="agent") + for chunk in service.stream( + { + "type": "query", + "session_id": "1", + "message": "你好", + "content": { + "query": { + "prompt": [ + { + "type": "text", + "content": {"text": "现在几点了?请调用工具获取当前时间"}, + } + ] + } + }, + }, + run_config={"configurable": {"session_id": "1"}}, + ctx=agent_ctx, + ): + print(chunk) diff --git a/src/storage/__init__.py b/src/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/storage/database/__init__.py b/src/storage/database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/storage/database/db.py b/src/storage/database/db.py new file mode 100644 index 0000000..ec6aa7a --- /dev/null +++ b/src/storage/database/db.py @@ -0,0 +1,94 @@ +import os +import time +from sqlalchemy import create_engine, text +from sqlalchemy.orm import sessionmaker +from sqlalchemy.exc import OperationalError +import logging +logger = logging.getLogger(__name__) + +MAX_RETRY_TIME = 20 # 连接最大重试时间(秒) +# Load environment variables from .env if present +try: + from dotenv import load_dotenv + load_dotenv() +except Exception: + pass + +def get_db_url() -> str: + """Build database URL from environment.""" + url = os.getenv("PGDATABASE_URL") or "" + if url is not None and url != "": + return url + from coze_workload_identity import Client + try: + client = Client() + env_vars = client.get_project_env_vars() + client.close() + for env_var in env_vars: + if env_var.key == "PGDATABASE_URL": + url = env_var.value.replace("'", "'\\''") + return url + except Exception as e: + logger.error(f"Error loading PGDATABASE_URL: {e}") + raise e + finally: + if url is None or url == "": + logger.error("PGDATABASE_URL is not set") + return url +_engine = None +_SessionLocal = None + +def _create_engine_with_retry(): + url = get_db_url() + if url is None or url == "": + logger.error("PGDATABASE_URL is not set") + raise ValueError("PGDATABASE_URL is not set") + size = 100 + overflow = 100 + recycle = 1800 + timeout = 30 + engine = create_engine( + url, + pool_size=size, + max_overflow=overflow, + pool_pre_ping=True, + pool_recycle=recycle, + pool_timeout=timeout, + ) + # 验证连接,带重试 + start_time = time.time() + last_error = None + while time.time() - start_time < MAX_RETRY_TIME: + try: + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + return engine + except OperationalError as e: + last_error = e + elapsed = time.time() - start_time + logger.warning(f"Database connection failed, retrying... 
(elapsed: {elapsed:.1f}s)") + time.sleep(min(1, MAX_RETRY_TIME - elapsed)) + logger.error(f"Database connection failed after {MAX_RETRY_TIME}s: {last_error}") + raise last_error # pyright: ignore [reportGeneralTypeIssues] + +def get_engine(): + global _engine + if _engine is None: + _engine = _create_engine_with_retry() + return _engine + +def get_sessionmaker(): + global _SessionLocal + if _SessionLocal is None: + _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=get_engine()) + return _SessionLocal + +def get_session(): + return get_sessionmaker()() + +__all__ = [ + "get_db_url", + "get_engine", + "get_sessionmaker", + "get_session", +] diff --git a/src/storage/database/shared/__init__.py b/src/storage/database/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/storage/database/shared/model.py b/src/storage/database/shared/model.py new file mode 100644 index 0000000..0762017 --- /dev/null +++ b/src/storage/database/shared/model.py @@ -0,0 +1,8 @@ +from sqlalchemy import BigInteger, DateTime, Identity, Index, Integer, JSON, PrimaryKeyConstraint, Text, text +from typing import Optional +import datetime + +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass \ No newline at end of file diff --git a/src/storage/memory/__init__.py b/src/storage/memory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/storage/memory/memory_saver.py b/src/storage/memory/memory_saver.py new file mode 100644 index 0000000..383e0b3 --- /dev/null +++ b/src/storage/memory/memory_saver.py @@ -0,0 +1,135 @@ +import psycopg +from psycopg_pool import AsyncConnectionPool +from langgraph.checkpoint.postgres import PostgresSaver +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from langgraph.checkpoint.memory import MemorySaver +from langgraph.checkpoint.base import BaseCheckpointSaver +from typing import Optional, Union +import logging +import time + +logger = logging.getLogger(__name__) + +# 数据库连接超时时间(秒),每次尝试 15 秒,共尝试 2 次 +DB_CONNECTION_TIMEOUT = 15 +DB_MAX_RETRIES = 2 + + +class MemoryManager: + """Memory Manager 单例类""" + + _instance: Optional['MemoryManager'] = None + _checkpointer: Optional[Union[AsyncPostgresSaver, MemorySaver]] = None + _pool: Optional[AsyncConnectionPool] = None + _setup_done: bool = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def _connect_with_retry(self, db_url: str) -> Optional[psycopg.Connection]: + """带重试的数据库连接,每次 15 秒超时,共尝试 2 次""" + last_error = None + for attempt in range(1, DB_MAX_RETRIES + 1): + try: + logger.info(f"Attempting database connection (attempt {attempt}/{DB_MAX_RETRIES})") + conn = psycopg.connect(db_url, autocommit=True, connect_timeout=DB_CONNECTION_TIMEOUT) + logger.info(f"Database connection established on attempt {attempt}") + return conn + except Exception as e: + last_error = e + logger.warning(f"Database connection attempt {attempt} failed: {e}") + if attempt < DB_MAX_RETRIES: + time.sleep(1) # 重试前短暂等待 + logger.error(f"All {DB_MAX_RETRIES} database connection attempts failed, last error: {last_error}") + return None + + def _setup_schema_and_tables(self, db_url: str) -> bool: + """同步创建 schema 和表(只执行一次),返回是否成功""" + if self._setup_done: + return True + + conn = self._connect_with_retry(db_url) + if conn is None: + return False + + try: + with conn.cursor() as cur: + cur.execute("CREATE SCHEMA IF NOT EXISTS memory") + conn.execute("SET search_path TO 
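Typical consumption of the db helpers above (assumes PGDATABASE_URL points at a reachable Postgres; the engine and sessionmaker are created lazily on first use, with the connection retry shown above):

```python
from sqlalchemy import text
from storage.database.db import get_session

session = get_session()
try:
    row = session.execute(text("SELECT 1 AS ok")).first()
    print(row.ok)  # 1
finally:
    session.close()
```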
memory") + PostgresSaver(conn).setup() + self._setup_done = True + logger.info("Memory schema and tables created") + return True + except Exception as e: + logger.warning(f"Failed to setup schema/tables: {e}") + return False + finally: + conn.close() + + def _get_db_url_safe(self) -> Optional[str]: + """安全获取 db_url,失败时返回 None""" + try: + from storage.database.db import get_db_url + db_url = get_db_url() + if db_url and db_url.strip(): + return db_url + logger.warning("db_url is empty, will fallback to MemorySaver") + return None + except Exception as e: + logger.warning(f"Failed to get db_url: {e}, will fallback to MemorySaver") + return None + + def _create_fallback_checkpointer(self) -> MemorySaver: + """创建内存兜底 checkpointer""" + self._checkpointer = MemorySaver() + logger.warning("Using MemorySaver as fallback checkpointer (data will not persist across restarts)") + return self._checkpointer + + def get_checkpointer(self) -> BaseCheckpointSaver: + """获取 checkpointer,优先使用 PostgresSaver,失败时退化为 MemorySaver""" + if self._checkpointer is not None: + return self._checkpointer + + # 1. 尝试获取 db_url + db_url = self._get_db_url_safe() + if not db_url: + return self._create_fallback_checkpointer() + + # 2. 尝试连接数据库并创建 schema/表(带重试) + if not self._setup_schema_and_tables(db_url): + return self._create_fallback_checkpointer() + + # 3. 连接字符串加上 search_path + if "?" in db_url: + db_url = f"{db_url}&options=-csearch_path%3Dmemory" + else: + db_url = f"{db_url}?options=-csearch_path%3Dmemory" + + # 4. 尝试创建连接池和 checkpointer + try: + self._pool = AsyncConnectionPool( + conninfo=db_url, + timeout=DB_CONNECTION_TIMEOUT, + min_size=1, + max_idle=300, + check=AsyncConnectionPool.check_connection, + ) + self._checkpointer = AsyncPostgresSaver(self._pool) + logger.info("AsyncPostgresSaver initialized successfully") + except Exception as e: + logger.warning(f"Failed to create AsyncPostgresSaver: {e}, will fallback to MemorySaver") + return self._create_fallback_checkpointer() + + return self._checkpointer + +_memory_manager: Optional[MemoryManager] = None + + +def get_memory_saver() -> BaseCheckpointSaver: + """获取 checkpointer,优先使用 PostgresSaver,db_url 不可用或连接失败时退化为 MemorySaver""" + global _memory_manager + if _memory_manager is None: + _memory_manager = MemoryManager() + return _memory_manager.get_checkpointer() \ No newline at end of file diff --git a/src/storage/s3/__init__.py b/src/storage/s3/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/storage/s3/s3_storage.py b/src/storage/s3/s3_storage.py new file mode 100644 index 0000000..56ba3bc --- /dev/null +++ b/src/storage/s3/s3_storage.py @@ -0,0 +1,424 @@ +import os +import re +from pathlib import Path +from typing import Optional, Any, Dict, List, TypedDict, Iterable +from uuid import uuid4 + +import boto3 +from botocore.exceptions import ClientError +from boto3.s3.transfer import TransferConfig +import logging +logger = logging.getLogger(__name__) + +# 允许的文件名字符集(面向用户输入的约束) +FILE_NAME_ALLOWED_RE = re.compile(r"^[A-Za-z0-9._\-/]+$") + + +class ListFilesResult(TypedDict): + # list_files 的返回结构类型 + keys: List[str] + is_truncated: bool + next_continuation_token: Optional[str] + +class S3SyncStorage: + """S3兼容存储实现""" + + def __init__(self, *, endpoint_url: Optional[str] = None, access_key: str, secret_key: str, bucket_name: str, region: str = "cn-beijing"): + self.endpoint_url = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or endpoint_url or '' + self.access_key = access_key + self.secret_key = secret_key + self.bucket_name = bucket_name + 
self.region = region + self._client = None + + def _get_client(self): + if self._client is None: + endpoint = self.endpoint_url + if endpoint is None or endpoint == "": + try: + from coze_workload_identity import Client as CozeEnvClient + coze_env_client = CozeEnvClient() + env_vars = coze_env_client.get_project_env_vars() + coze_env_client.close() + for env_var in env_vars: + if env_var.key == "COZE_BUCKET_ENDPOINT_URL": + endpoint = env_var.value.replace("'", "'\\''") + self.endpoint_url = endpoint + break + except Exception as e: + logger.error(f"Error loading COZE_BUCKET_ENDPOINT_URL: {e}") + # 保持向下校验逻辑,避免在此处中断 + if endpoint is None or endpoint == "": + logger.error("未配置存储端点:请设置endpoint_url") + raise ValueError("未配置存储端点:请设置endpoint_url") + + client = boto3.client( + "s3", + endpoint_url=endpoint, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.secret_key, + region_name=self.region, + ) + + # 注册 before-call 钩子,发送前注入 x-storage-token 头 + def _inject_header(**kwargs): + try: + from coze_workload_identity import Client as CozeClient + coze_client = CozeClient() + try: + token = coze_client.get_access_token() + except Exception as e: + logger.error("Error loading COZE_WORKLOAD_IDENTITY_TOKEN: %s", e) + token = None + raise e + finally: + coze_client.close() + params = kwargs.get("params", {}) + headers = params.setdefault("headers", {}) + headers["x-storage-token"] = token + except Exception as e: + logger.error("Error loading COZE_WORKLOAD_IDENTITY_TOKEN: %s", e) + pass + client.meta.events.register("before-call.s3", _inject_header) + self._client = client + return self._client + + def _generate_object_key(self, *, original_name: str) -> str: + suffix = Path(original_name).suffix.lower() + stem = Path(original_name).stem + uniq = uuid4().hex[:8] + return f"{stem}_{uniq}{suffix}" + + def _extract_logid(self, e: Exception) -> Optional[str]: + """从 ClientError 中提取 x-tt-logid""" + if isinstance(e, ClientError): + headers = (e.response or {}).get("ResponseMetadata", {}).get("HTTPHeaders", {}) + return headers.get("x-tt-logid") + return None + + def _error_msg(self, msg: str, e: Exception) -> str: + """构建带 logid 的错误信息""" + logid = self._extract_logid(e) + if logid: + return f"{msg}: {e} (x-tt-logid: {logid})" + return f"{msg}: {e}" + + def _resolve_bucket(self, bucket: Optional[str]) -> str: + """统一解析 bucket 来源,确保得到有效桶名。""" + target_bucket = bucket or os.environ.get("COZE_BUCKET_NAME") or self.bucket_name + if not target_bucket: + raise ValueError("未配置 bucket:请传入 bucket 或设置 COZE_BUCKET_NAME,或在实例化时提供 bucket_name") + return target_bucket + + def _validate_file_name(self, name: str) -> None: + """校验 S3 对象命名:长度≤1024;允许 [A-Za-z0-9._-/];不以 / 起止且不含 //。""" + msg = ( + "file name invalid: 文件名需满足以下 S3 对象命名规范:" + "1) 长度 1–1024 字节;" + "2) 仅允许字母、数字、点(.)、下划线(_)、短横(-)、目录分隔符(/);" + "3) 不允许空格或以下特殊字符:? 
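The client wiring below relies on botocore's event system: a `before-call.s3` handler may mutate the outgoing request parameters, including headers. The same hook in isolation, with a static token standing in for the workload-identity lookup (endpoint and credentials are placeholders):

```python
import boto3

client = boto3.client(
    "s3",
    endpoint_url="https://storage.example.com",  # placeholder endpoint
    aws_access_key_id="AK",
    aws_secret_access_key="SK",
    region_name="cn-beijing",
)

def inject_header(params, **kwargs):
    # params is the serialized request dict; add a custom header to it.
    headers = params.setdefault("headers", {})
    headers["x-storage-token"] = "demo-token"

client.meta.events.register("before-call.s3", inject_header)
# Every subsequent S3 call from this client now carries x-storage-token.
```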
# & % { } ^ [ ] ` \\ < > ~ | \" ' + = : ;;" + "4) 不以 / 开头或结尾,且不包含连续的 //;" + "示例:report_2025-12-11.pdf、images/photo-01.png。" + ) + + if not name or not name.strip(): + raise ValueError(msg + "(原因:文件名为空)") + + # S3 限制对象 key 最大 1024 字节,这里沿用到输入文件名 + if len(name.encode("utf-8")) > 1024: + raise ValueError(msg + "(原因:长度超过 1024 字节)") + + if name.startswith("/") or name.endswith("/"): + raise ValueError(msg + "(原因:以 / 开头或结尾)") + if "//" in name: + raise ValueError(msg + "(原因:包含连续的 //)") + + # 允许字符集校验 + if not FILE_NAME_ALLOWED_RE.match(name): + bad = re.findall(r"[^A-Za-z0-9._\-/]", name) + example = bad[0] if bad else "非法字符" + raise ValueError(msg + f"(原因:包含非法字符,例如:{example})") + + def upload_file(self, *, file_content: bytes, file_name: str, content_type: str = "application/octet-stream", bucket: Optional[str] = None) -> str: + # 先对输入文件名做规范校验,避免生成无效对象 key + self._validate_file_name(file_name) + try: + client = self._get_client() + object_key = self._generate_object_key(original_name=file_name) + target_bucket = self._resolve_bucket(bucket) + client.put_object(Bucket=target_bucket, Key=object_key, Body=file_content, ContentType=content_type) + return object_key + except Exception as e: + logger.error(self._error_msg("Error uploading file to S3", e)) + raise e + + def delete_file(self, *, file_key: str, bucket: Optional[str] = None) -> bool: + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + client.delete_object(Bucket=target_bucket, Key=file_key) + return True + except Exception as e: + logger.error(self._error_msg("Error deleting file from S3", e)) + raise e + + def file_exists(self, *, file_key: str, bucket: Optional[str] = None) -> bool: + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + client.head_object(Bucket=target_bucket, Key=file_key) + return True + except ClientError as e: + code = (e.response or {}).get("Error", {}).get("Code", "") + if code in {"404", "NoSuchKey", "NotFound"}: + return False + logger.error(self._error_msg("Error checking file existence in S3", e)) + return False + except Exception as e: + logger.error(self._error_msg("Error checking file existence in S3", e)) + return False + + def read_file(self, *, file_key: str, bucket: Optional[str] = None) -> bytes: + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + resp = client.get_object(Bucket=target_bucket, Key=file_key) + body = resp.get("Body") + if body is None: + raise RuntimeError("S3 get_object returned no Body") + try: + return body.read() + finally: + try: + body.close() + except Exception as ce: + # 资源关闭失败不影响读取结果,仅记录以便排查 + logger.debug("Failed to close S3 response body: %s", ce) + except Exception as e: + logger.error(self._error_msg("Error reading file from S3", e)) + raise e + + def list_files(self, *, prefix: Optional[str] = None, bucket: Optional[str] = None, max_keys: int = 1000, continuation_token: Optional[str] = None) -> ListFilesResult: + """列出对象,支持前缀过滤与分页;返回 keys/is_truncated/next_continuation_token。""" + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + if max_keys <= 0 or max_keys > 1000: + raise ValueError("max_keys 必须在 1 到 1000 之间") + + kwargs: Dict[str, Any] = { + "Bucket": target_bucket, + "MaxKeys": max_keys, + "Prefix": prefix, + "ContinuationToken": continuation_token, + } + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + resp = client.list_objects_v2(**kwargs) + contents = resp.get("Contents", []) or [] + keys: List[str] = [item.get("Key") 
for item in contents if isinstance(item, dict) and item.get("Key")] + return { + "keys": keys, + "is_truncated": bool(resp.get("IsTruncated")), + "next_continuation_token": resp.get("NextContinuationToken"), + } + except ClientError as e: + code = (e.response or {}).get("Error", {}).get("Code", "") + logger.error(self._error_msg(f"Error listing files in S3 (code={code})", e)) + raise e + except Exception as e: + logger.error(self._error_msg("Error listing files in S3", e)) + raise e + + def generate_presigned_url(self, *, key: str, bucket: Optional[str] = None, expire_time: int = 1800) -> str: + """通过 S3 Proxy 生成签名 URL。""" + import json + import urllib.request as urllib_request + try: + from coze_workload_identity import Client as CozeClient + coze_client = CozeClient() + try: + token = coze_client.get_access_token() + finally: + try: + coze_client.close() + except Exception: + # 资源释放失败不影响后续流程 + pass + except Exception as e: + logger.error(f"Error loading x-storage-token: {e}") + raise RuntimeError(f"获取 x-storage-token 失败: {e}") + try: + sign_base = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or self.endpoint_url + if not sign_base: + raise ValueError("未配置签名端点:请设置 COZE_BUCKET_ENDPOINT_URL 或传入 endpoint_url") + sign_url_endpoint = sign_base.rstrip("/") + "/sign-url" + + headers = { + "Content-Type": "application/json", + "x-storage-token": token, + } + + target_bucket = self._resolve_bucket(bucket) + payload = {"bucket_name": target_bucket, "path": key, "expire_time": expire_time} + data = json.dumps(payload).encode("utf-8") + request = urllib_request.Request(sign_url_endpoint, data=data, headers=headers, method="POST") + except Exception as e: + logger.error(f"Error creating request for sign-url: {e}") + raise RuntimeError(f"创建 sign-url 请求失败: {e}") + + try: + with urllib_request.urlopen(request) as resp: + resp_bytes = resp.read() + content_type = resp.headers.get("Content-Type", "") + text = resp_bytes.decode("utf-8", errors="replace") + if "application/json" in content_type or text.strip().startswith("{"): + try: + obj = json.loads(text) + except Exception: + return text + data = obj.get("data") + if isinstance(data, dict) and "url" in data: + return data["url"] + url_value = obj.get("url") or obj.get("signed_url") or obj.get("presigned_url") + if url_value: + return url_value + raise ValueError("签名服务返回缺少 data.url/url 字段") + return text + except Exception as e: + raise RuntimeError(f"生成签名URL失败: {e}") + + def stream_upload_file( + self, + *, + fileobj, + file_name: str, + content_type: str = "application/octet-stream", + bucket: Optional[str] = None, + multipart_chunksize: int = 5 * 1024 * 1024, + multipart_threshold: int = 5 * 1024 * 1024, + max_concurrency: int = 1, + use_threads: bool = False, + ) -> str: + """流式上传(文件对象) + - fileobj: 任何带有 read() 方法的文件对象(如 open(..., 'rb') 返回的对象、io.BytesIO 等) + - file_name: 原始文件名,用于生成唯一 key + - content_type: MIME 类型 + - bucket: 目标桶;为空时取环境变量或实例默认值 + - multipart_chunksize: 分片大小(默认 5MB,以适配代理层限制) + - multipart_threshold: 触发分片上传的阈值(默认 5MB) + - max_concurrency: 并发分片上传的并发数(默认 1,避免代理层节流影响) + - use_threads: 是否启用线程并发(默认 False) + 返回:最终写入的对象 key + """ + try: + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + key = self._generate_object_key(original_name=file_name) + + extra_args = {"ContentType": content_type} if content_type else {} + # 使用 boto3 的高阶方法执行多段上传(传入 TransferConfig 控制分片大小) + + config = TransferConfig( + multipart_chunksize=multipart_chunksize, + multipart_threshold=multipart_threshold, + max_concurrency=max_concurrency, + 
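list_files caps each page at 1000 keys, so larger buckets are walked with the continuation token. A hedged usage sketch (dummy credentials; assumes the endpoint and bucket are configured via environment as the class supports):

```python
from storage.s3.s3_storage import S3SyncStorage

storage = S3SyncStorage(access_key="AK", secret_key="SK", bucket_name="demo")

all_keys, token = [], None
while True:
    page = storage.list_files(prefix="images/", max_keys=1000,
                              continuation_token=token)
    all_keys.extend(page["keys"])
    if not page["is_truncated"]:
        break
    token = page["next_continuation_token"]
print(f"{len(all_keys)} objects under images/")
```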
use_threads=use_threads, + ) + client.upload_fileobj(Fileobj=fileobj, Bucket=target_bucket, Key=key, ExtraArgs=extra_args, Config=config) + return key + except Exception as e: + logger.error(self._error_msg("Error streaming upload (fileobj) to S3", e)) + raise e + + def upload_from_url( + self, + *, + url: str, + bucket: Optional[str] = None, + timeout: int = 30, + ) -> str: + """从 URL 流式下载并上传到 S3 + - url: 源文件 URL + - bucket: 目标桶;为空时取环境变量或实例默认值 + - timeout: HTTP 请求超时时间(秒,默认 30) + 返回:最终写入的对象 key + """ + import urllib.request as urllib_request + from urllib.parse import urlparse, unquote + try: + request = urllib_request.Request(url) + with urllib_request.urlopen(request, timeout=timeout) as resp: + parsed = urlparse(url) + file_name = Path(unquote(parsed.path)).name or "file" + content_type = resp.headers.get("Content-Type", "application/octet-stream") + return self.stream_upload_file( + fileobj=resp, + file_name=file_name, + content_type=content_type, + bucket=bucket, + ) + except Exception as e: + logger.error(self._error_msg("Error uploading from URL to S3", e)) + raise e + + def trunk_upload_file(self, *, chunk_iter: Iterable[bytes], file_name: str, + content_type: str = "application/octet-stream", bucket: Optional[str] = None, + part_size: int = 5 * 1024 * 1024) -> str: + """流式上传(字节迭代器,显式分片 Multipart Upload) + - chunk_iter: 可迭代对象,逐块产生 bytes;每块大小可变(内部累积到 part_size 再上传),最后一块可小于 5MB + - file_name: 原始文件名,用于生成唯一 key + - content_type: MIME 类型 + - bucket: 目标桶;为空时取环境或实例默认值 + - part_size: 每个 part 的最小大小(除最后一个);默认 5MB + 返回:最终写入的对象 key + """ + client = self._get_client() + target_bucket = self._resolve_bucket(bucket) + key = self._generate_object_key(original_name=file_name) + + # 初始化分片上传 + try: + init_resp = client.create_multipart_upload(Bucket=target_bucket, Key=key, ContentType=content_type) + upload_id = init_resp["UploadId"] + except Exception as e: + logger.error(self._error_msg("create_multipart_upload failed", e)) + raise e + + parts = [] + part_number = 1 + buffer = bytearray() + try: + for chunk in chunk_iter: + if not chunk: + continue + buffer.extend(chunk) + while len(buffer) >= part_size: + data = bytes(buffer[:part_size]) + buffer = buffer[part_size:] + resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number, + Body=data) + parts.append({"PartNumber": part_number, "ETag": resp["ETag"]}) + part_number += 1 + + # 上传最后不足 part_size 的余量 + if len(buffer) > 0: + resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number, + Body=bytes(buffer)) + parts.append({"PartNumber": part_number, "ETag": resp["ETag"]}) + + # 完成分片 + client.complete_multipart_upload( + Bucket=target_bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": parts}, + ) + return key + except Exception as e: + logger.error(self._error_msg("multipart upload failed", e)) + try: + client.abort_multipart_upload(Bucket=target_bucket, Key=key, UploadId=upload_id) + except Exception as ae: + logger.error(self._error_msg("abort_multipart_upload failed", ae)) + raise e diff --git a/src/tools/__init__.py b/src/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/file/__init__.py b/src/utils/file/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/file/file.py b/src/utils/file/file.py new file mode 100644 index 0000000..4c6925d --- /dev/null +++ b/src/utils/file/file.py @@ -0,0 
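trunk_upload_file accepts any bytes iterator and buffers it into parts of at least 5 MB, as S3 multipart rules require (only the final part may be smaller). A possible caller, with a hypothetical on-disk chunk generator:

```python
from storage.s3.s3_storage import S3SyncStorage

storage = S3SyncStorage(access_key="AK", secret_key="SK", bucket_name="demo")

def chunks_from_disk(path: str, chunk_size: int = 64 * 1024):
    # Yield the file in small blocks; trunk_upload_file accumulates them
    # into >=5MB parts internally.
    with open(path, "rb") as f:
        while True:
            block = f.read(chunk_size)
            if not block:
                return
            yield block

key = storage.trunk_upload_file(
    chunk_iter=chunks_from_disk("/tmp/big_video.bin"),  # hypothetical path
    file_name="big_video.bin",
    content_type="application/octet-stream",
)
print("uploaded as", key)
```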
+1,325 @@ +import os +import requests +import uuid +import chardet +from io import BytesIO +from typing import Literal,Callable, Any, Optional,Union +from pydantic import BaseModel, Field, field_validator,PrivateAttr,ConfigDict +from urllib.parse import urlparse +from pptx import Presentation + +MAX_FILE_SIZE = 100 * 1024 * 1024 + +class File(BaseModel): + """ + 通用文件对象,支持自动类型推断和路径管理 + """ + url: str = Field(..., description="文件URL(http/https)或本地路径") + file_type: Literal['image', 'video', 'audio', 'document', 'default'] = Field( + default="default", + description="文件类型" + ) + _local_path: Optional[str] = PrivateAttr(default=None) + model_config = ConfigDict( + json_schema_extra={ + "x-component": "file-upload", # 前端用文件上传组件 + } + ) + + def set_cache_path(self, path: str): + """设置缓存路径""" + self._local_path = path + + def get_cache_path(self) -> Optional[str]: + """获取缓存路径(如果文件实际存在)""" + return self._local_path + + @property + def is_remote(self) -> bool: + """判断是网络URL还是本地文件""" + return self.url.startswith(('http://', 'https://')) + +def infer_file_category(path_or_url: str) -> tuple[str, str]: + """ + 根据路径或URL后缀判断文件类型 + 逻辑: + 1. 解析 URL 去除 query 参数 (?id=...),提取 path + 2. 获取 path 最后一部分的文件名和后缀 + 3. 查表判断,匹配不到则返回 'default' + + Return: + - 分类:image, video, audio, document, default + - 后缀:.pdf + + """ + + # === 步骤 1 & 2: 提取纯净的后缀名 === + # urlparse 可以同时处理本地路径 (会被视为 path) 和 网络 URL + parsed = urlparse(path_or_url) + path = parsed.path # 提取路径部分,忽略 http://... 和 ?query=... + + # 获取文件名 (例如 /a/b/test.jpg -> test.jpg) + filename = os.path.basename(path) + + # 分离后缀 (test.jpg -> .jpg) + _, ext_with_dot = os.path.splitext(filename) + + # 如果没有后缀,直接兜底 + if not ext_with_dot: + return 'default', "" + + # 去除点并转小写 (例如 .JPG -> jpg) + ext = ext_with_dot.lstrip('.').lower() + + # === 步骤 3: 查表匹配 === + # 定义常见映射表 + TYPE_MAPPING = { + 'image': { + 'apng', 'avif', 'bmp', 'gif', 'heic', 'ico', 'jpg', 'jpeg', 'png', 'svg', 'tiff', 'webp' + }, + 'video': { + 'mp4', 'avi', 'mov', 'mkv', 'flv', 'wmv', 'webm', 'm4v', '3gp' + }, + 'audio': { + 'mp3', 'wav', 'flac', 'aac', 'ogg', 'wma', 'm4a' + }, + 'document': { + 'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx', + 'txt', 'md', 'csv', 'json', 'xml', 'html', 'htm' + }, + } + + for category, extensions in TYPE_MAPPING.items(): + if ext in extensions: + return category, ext_with_dot + + return 'default', ext_with_dot + +class FileOps: + DOWNLOAD_DIR = "/tmp" + + @staticmethod + def _get_bytes_stream(file_obj:File) -> tuple[bytes, str]: + """ + 获取文件内容和后缀, 大小限制检查, 超出抛异常 + """ + _, ext = infer_file_category(file_obj.url) + + if file_obj.is_remote: + try: + # stream=True: 此时只下载 Headers,连接保持打开,还没下载 Body + with requests.get(file_obj.url, stream=True, timeout=60) as resp: + resp.raise_for_status() + + content_length = resp.headers.get('Content-Length') + if content_length and int(content_length) > MAX_FILE_SIZE: + raise Exception( + f"文件大小 ({int(content_length)} bytes) 超过限制 100MB,已终止下载。" + ) + + # 场景:Header 缺失 Content-Length 或服务器 Header 欺骗 + downloaded_content = BytesIO() + current_size = 0 + + # 分块读取,每块 8KB + for chunk in resp.iter_content(chunk_size=8192): + if chunk: + current_size += len(chunk) + if current_size > MAX_FILE_SIZE: + raise Exception(f"检测到文件超过 100MB,已中断。") + downloaded_content.write(chunk) + + # 获取完整 bytes + return downloaded_content.getvalue(), ext + + except requests.RequestException as e: + raise RuntimeError(f"网络请求失败: {e}") + + else: + if not os.path.exists(file_obj.url): + raise FileNotFoundError(f"本地文件不存在: {file_obj.url}") + + ''' + file_size = 
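_get_bytes_stream guards downloads twice: it rejects an over-limit Content-Length up front, then re-checks the running total while streaming, in case the header is absent or lies. The same guard as a standalone function:

```python
from io import BytesIO
import requests

MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB, as in the module above

def download_capped(url: str) -> bytes:
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        declared = resp.headers.get("Content-Length")
        if declared and int(declared) > MAX_FILE_SIZE:
            raise ValueError("file exceeds 100MB (declared size)")
        buf, size = BytesIO(), 0
        for chunk in resp.iter_content(chunk_size=8192):
            if chunk:
                size += len(chunk)
                if size > MAX_FILE_SIZE:
                    raise ValueError("file exceeds 100MB (streamed size)")
                buf.write(chunk)
        return buf.getvalue()
```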
os.path.getsize(file_obj.url) + if file_size > MAX_FILE_SIZE: + raise Exception(f"本地文件大小 ({file_size} bytes) 超过限制 100MB") + ''' + + with open(file_obj.url, 'rb') as f: + return f.read(), ext + + @staticmethod + def save_to_local(file_obj: File, filename: str) -> str: + """ + 将当前文件对象的内容保存到本地路径, 返回本地路径 + 如果是本地路径,直接返回 + """ + if not file_obj.is_remote: + if os.path.exists(file_obj.url): + return file_obj.url + + raise FileNotFoundError(f"Local file not found: {file_obj.url}") + + try: + os.makedirs(FileOps.DOWNLOAD_DIR, exist_ok=True) + + # 简单的文件名生成策略 (真实场景建议用 url hash 避免重复下载) + # ext = os.path.splitext(file_obj.url.split('?')[0])[1] or ".tmp" + # filename = f"{uuid.uuid4().hex}{ext}" + local_path = os.path.join(FileOps.DOWNLOAD_DIR, filename) + + headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'} + with requests.get(file_obj.url, headers=headers, stream=True, timeout=120) as r: + r.raise_for_status() + with open(local_path, 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + return local_path + except Exception as e: + raise RuntimeError(f"Download failed for {file_obj.url}: {str(e)}") + + @staticmethod + def read_bytes(file_obj:File) -> bytes: + """ + 获取文件的原始二进制数据 + 场景:上传到OSS、保存到本地、传给图像处理库 + """ + content, _ = FileOps._get_bytes_stream(file_obj) + return content + + @staticmethod + def extract_text(file_obj: File) -> str: + """ + 提取文本内容 + 场景:RAG、HTML解析、文档分析 + """ + try: + content, ext = FileOps._get_bytes_stream(file_obj) + + if ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']: + return FileOps._parse_document_bytes(file_obj, content, ext) + + # 默认直接读 + charset = chardet.detect(content) + if 'encoding' in charset: + return content.decode(charset['encoding']) + else: + return content.decode('utf-8') + + except Exception as e: + return f"[FileOps Error] Failed to read content: {str(e)}" + + @staticmethod + def _parse_document_bytes(file_obj: File, content: bytes, ext:str) -> str: + stream = BytesIO(content) + text_result = "" + + try: + if ext == '.pdf': + import pypdf + reader = pypdf.PdfReader(stream) + for page in reader.pages: + text_result += page.extract_text() + "\n" + elif ext in ['.docx', '.doc']: + text_result = read_docx(stream) + elif ext in ['.xlsx', '.xls', '.csv']: + import pandas as pd + if ext == '.csv': + df = pd.read_csv(stream) + else: + df = pd.read_excel(stream) + text_result = df.to_string() + elif ext in ['.ppt', '.pptx']: + text_result = read_ppt(stream) + else: + text_result = f"[暂不支持解析该文档格式: {ext}]" + except ImportError as e: + text_result = f"[解析库缺失] {e}" + except Exception as e: + text_result = f"[解析失败] {e}" + + return text_result + +def read_docx(cont_stream) -> str: + """ + 使用docx2python按顺序读取内容 + """ + from docx2python import docx2python + doc_result = docx2python(cont_stream) + + # 获取文档结构 + all_parts = [] + + # docx2python以嵌套列表形式返回内容 + # 遍历文档主体 + for section in doc_result.body: + if isinstance(section, list): + for item in section: + if isinstance(item, list): + # 可能是表格或多级内容 + for sub_item in item: + if isinstance(sub_item, str) and sub_item.strip(): + all_parts.append(sub_item.strip()) + elif isinstance(sub_item, list): + # 表格行 + row_text = "\n".join([str(cell).strip() for cell in sub_item if str(cell).strip()]) + if row_text: + all_parts.append(row_text) + elif isinstance(item, str) and item.strip(): + all_parts.append(item.strip()) + + # 关闭文档 + doc_result.close() + + return "\n\n".join(all_parts) + +def 
read_ppt(file_input: Union[str, bytes, BytesIO]) -> str:
    """Extract text from a PPT/PPTX file: slide text, table contents, and speaker notes."""
    if Presentation is None:  # only reachable if the top-level import is made optional
        return "[Error] python-pptx is not installed; cannot parse PPT files"

    # 1. Normalize the input into a file-like stream (BytesIO)
    if isinstance(file_input, str):
        with open(file_input, 'rb') as f:
            ppt_stream = BytesIO(f.read())
    elif isinstance(file_input, bytes):
        ppt_stream = BytesIO(file_input)
    else:
        ppt_stream = file_input

    try:
        prs = Presentation(ppt_stream)
        full_text = []

        for i, slide in enumerate(prs.slides):
            page_content = []
            page_content.append(f"=== Slide {i+1} ===")

            # shape.text_frame holds the text paragraphs inside a shape
            for shape in slide.shapes:
                # A. Extract plain text boxes
                if hasattr(shape, "text") and shape.text.strip():
                    page_content.append(shape.text.strip())

                # B. Extract table contents (plain shape.text cannot see table cells)
                if getattr(shape, "has_table", False):
                    table_texts = []
                    for row in shape.table.rows:
                        row_cells = [cell.text_frame.text.strip() for cell in row.cells if cell.text_frame.text.strip()]
                        if row_cells:
                            table_texts.append(" | ".join(row_cells))
                    if table_texts:
                        page_content.append("[Table]\n" + "\n".join(table_texts))

            # Important information often hides in the speaker notes
            if slide.has_notes_slide:
                notes = slide.notes_slide.notes_text_frame.text
                if notes.strip():
                    page_content.append(f"[Notes]: {notes.strip()}")

            full_text.append("\n".join(page_content))

        return "\n\n".join(full_text)

    except Exception as e:
        return f"[PPT parsing failed] {str(e)}"
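Putting the module together, a hedged end-to-end usage of File/FileOps (the URL is a placeholder; extract_text streams the file under the 100MB cap and dispatches on the inferred extension):

```python
from utils.file.file import File, FileOps, infer_file_category

f = File(url="https://example.com/answers.docx")  # placeholder URL
print(infer_file_category(f.url))   # ('document', '.docx')
print(f.is_remote)                  # True

text = FileOps.extract_text(f)      # downloads (<=100MB) and parses the .docx
print(text[:200])
```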