Project initialization

This commit is contained in:
zhangquan 2026-03-26 11:54:35 +08:00
commit 3eb42ade2c
52 changed files with 3793 additions and 0 deletions

14
.coze Normal file

@ -0,0 +1,14 @@
[project]
entrypoint = "src/main.py"
requires = ["python-3.12"]
[dev]
build = ["bash", "scripts/setup.sh"]
run = ["bash", "/workspace/projects/scripts/http_run.sh", "-p 5000"]
pack = ["bash", "/workspace/projects/scripts/pack.sh"]
deps = ["git"] # -> apt install git
[deploy]
build = ["bash", "scripts/setup.sh"]
run = ["bash", "scripts/http_run.sh", "-p 5000"]
deps = ["git"] # -> apt install git

48
.gitignore vendored Normal file

@ -0,0 +1,48 @@
node_modules
dist
.DS_Store
*.swp
.git/*
__pycache__/
*.pyc
.venv/*
*.o
*.a
/vendor
*.egg-info/
/.env
.env
# Video files
*.mp4
*.avi
*.mov
*.wmv
*.flv
*.mkv
*.webm
*.m4v
*.mpeg
*.mpg
*.3gp
*.f4v
*.rmvb
*.vob
# Archive files
*.iso
*.dmg
*.rar
*.zip
*.gz
# Documents are ignored by default as well
*.pdf
*.docx
*.doc
*.xlsx
*.xls
*.ppt
*.pptx
*.csv

391
AGENTS.md Normal file

@ -0,0 +1,391 @@
## Project Overview
- **Name**: Middle-school physics homework grading workflow
- **Function**: Upload multiple homework images plus a Word answer file; the workflow automatically recognizes student answers, extracts the standard answers, grades precisely, and returns the grading result as JSON
### Node List
| Node | File | Type | Description | Branching | Config |
|-------|---------|------|---------|---------|---------|
| doc_extract | `nodes/doc_extract_node.py` | agent | Extracts question stems and standard answers from the Word file (.docx); returns an empty list when no URL is provided | - | `config/doc_extract_llm_cfg.json` |
| process_images | `nodes/process_images_node.py` | looparray | Loops over the subgraph to process each homework image and produce the final grading result | - | - |

**Type legend**: task (plain task node) / agent (LLM node) / condition (conditional branch) / looparray (list loop) / loopcond (conditional loop)
## Subgraph List
| Subgraph | File | Description | Calling node |
|-------|---------|------|---------|
| single_image_subgraph | `graphs/loop_graph.py` | Full grading pipeline for one image (preprocess → recognize & grade → merge → wrap) | process_images |
### Subgraph Internal Nodes
| Node | File | Type | Description |
|-------|---------|------|---------|
| image_preprocess | `nodes/image_preprocess_node.py` | task | Downloads the image, auto-rotates landscape → portrait, scales to a fixed width of 1000 px, uploads to object storage |
| recognize_and_correct | `nodes/recognize_and_correct_node.py` | agent | **Integrated recognition + grading** (recognition and grading merged into a single LLM call) |
| result_merge | `nodes/result_merge_node.py` | task | Merges the recognition and grading results into the final annotations |
| wrap_result | `graphs/loop_graph.py` | task | Wraps the subgraph result as a SingleImageResult output |
## Skills Used
- Node `recognize_and_correct` uses the LLM skill (multimodal; recognition + grading merged)
  - Model: `doubao-seed-2-0-pro-260215` (flagship vision model; strong reasoning, concise output)
- Node `doc_extract` uses the LLM skill
  - Model: `doubao-seed-2-0-pro-260215` (flagship model; strong complex reasoning)
  - Uses python-docx to parse the Word document
## Workflow (Multi-Image Grading Architecture)
```
┌─────────────────────┐
│     doc_extract     │
│ (parse Word answers)│
└──────────┬──────────┘
           ▼
┌─────────────────────┐
│   process_images    │
│ (loop over images)  │
│ builds final result │
└─────────────────────┘
```
### Subgraph Flow (single image, optimized)
```
┌─────────────────────┐
│  image_preprocess   │
│(image preprocessing)│
└──────────┬──────────┘
           ▼
┌─────────────────────┐
│recognize_and_correct│ ← merged node: recognition + grading in one pass
│(integrated grading) │
└──────────┬──────────┘
           ▼
┌─────────────────────┐
│    result_merge     │
│   (merge results)   │
└──────────┬──────────┘
           ▼
┌─────────────────────┐
│     wrap_result     │
│   (wrap output)     │
└─────────────────────┘
```
## Core Feature: Multi-Image Grading
### Input Parameters
- `homework_images`: list of uploaded homework images (List[File]; multiple images supported)
- `answer_doc_url`: URL of the Word file (.docx) containing the correct answers (**optional**)
- `comment_max_length`: maximum comment length (default 100 characters, **optional**)
- `max_concurrent`: maximum number of images graded in parallel (default 10, **optional**)
- `grade_standards`: grading-level standards (**optional**; defaults below)
```json
{
"A+": {"min_percentage": 95, "description": "答案全部正确,步骤完整规范,逻辑严谨;书写/格式整洁无错别字、无遗漏完成度100%,态度认真,质量上乘"},
"A": {"min_percentage": 90, "description": "答案完全正确无任何错误步骤合理、格式规范无原则性问题完成度100%,满足全部要求"},
"B": {"min_percentage": 80, "description": "存在少量非关键性错误,或步骤略有缺失;整体思路基本正确,仅细节、格式、计算等小问题;完成大部分内容,整体合格但不够严谨"},
"C": {"min_percentage": 70, "description": "错误较多,部分核心题目作答错误;步骤不完整、逻辑不够清晰;完成度一般,有明显应付、漏答情况"},
"D": {"min_percentage": 0, "description": "大面积错误,核心知识点未掌握;大量空白、敷衍、抄袭;未达到基本完成要求"}
}
```
### Output
- `final_result`: final grading result (JSON, covering all images)
  - `total_images`: total number of images
  - `image_results`: list of per-image grading results
  - `overall_comment`: overall evaluation (generated from the score ratio)
  - `total_score`: total score obtained
  - `full_score`: total full score
  - `grade`: grade level
### Grading Priority (strictly in this order)
1. **Highest priority**: grade against the standard answers from the Word document
   - Applies when `answer_doc_url` is provided and the question is found in the document
   - Judge the student answer strictly against the standard answer
2. **Fallback**: grade as a professional physics teacher
   - Case 1: `answer_doc_url` is not provided
   - Case 2: a URL is provided but the question is not found in the document
   - Judge correctness using a professional physics teacher's expertise
### Feature Notes
1. **Multi-image support**: multiple homework images can be uploaded and are processed in parallel (concurrency capped by `max_concurrent`, default 10)
2. **Word answer extraction**: extracts question stems and standard answers from the .docx file
3. **Subgraph loop**: the single-image pipeline is wrapped in a subgraph; the main graph invokes it per image
4. **Grading result JSON**: returns a structured JSON containing the grading results for all images
5. **Smart fallback**: automatically switches to professional-teacher mode when no standard answers are available
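The priority rules above amount to partitioning the recognized questions by whether a standard answer exists for them. A hedged sketch (the helper and its dict shapes are hypothetical; the project performs this inside its grading node):

```python
def split_by_standard_answer(questions, correct_answers):
    # Questions found in the Word answers are graded against them;
    # the rest fall back to "professional teacher" mode.
    by_id = {ans["question_id"]: ans for ans in correct_answers}
    with_answer, without_answer = [], []
    for q in questions:
        (with_answer if q["question_id"] in by_id else without_answer).append(q)
    return with_answer, without_answer
```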
## Optimization Log
### 2026-03-26 Blank-splitting optimization (important)
**Problem**: when one question contained several blanks, they were merged into a single answer, so the grading marks could not be positioned precisely
**Fix**:
1. **Prompt optimization**
   - Explicit requirement: when a question has several blanks, **each blank is recognized as a separate item**
   - Question-ID format: "3(1)第一空" (blank 1), "3(1)第二空" (blank 2), "4(2)第一空", "4(2)第二空"
   - Each blank is graded and scored on its own
2. **Example**
```
❌ Wrong: 3(1) → "4、1" (blanks merged)
✅ Right: 3(1)第一空 → "4"
         3(1)第二空 → "1"
```
3. **Parameter passing**
   - `comment_max_length` is now passed correctly into the Jinja2 template
   - This ensures the LLM produces comments within the length limit
**Result**:
- The number of recognized items rose from 9 to 13
- Every blank gets its own grading mark
- Each mark is positioned precisely on its answer
### 2026-03-26 JSON-parsing hardening (important)
**Problem**: the LLM output could be incomplete (truncated at max_completion_tokens), so JSON parsing failed
**Fix**:
1. **New fix_incomplete_json function**
   - Detects missing closing brackets (} and ])
   - Appends the missing brackets so the JSON becomes complete
   - Example: `{"results": [{"id": 1}` → auto-completed to `{"results": [{"id": 1}]}`
2. **Hardened parsing pipeline**
   - Step 1: try to parse directly
   - Step 2: try to repair incomplete JSON (append missing brackets)
   - Step 3: try to extract the JSON object from surrounding text
   - Step 4: try to repair the extracted JSON
3. **Removed the faulty truncation logic**
   - Comments are no longer truncated after parsing (truncation could break escape sequences)
   - Comment length now relies entirely on the LLM honoring comment_max_length
   - The prompt explicitly instructs the LLM to keep comments within the limit
4. **Correct parameter passing**
   - comment_max_length is passed correctly into the prompt
   - The LLM generates comments of the required length
**Result**:
- JSON-parsing success rate improved substantially
- Incomplete JSON output can now be handled
- Comment length is controlled by the LLM, so truncation no longer corrupts the format
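A minimal sketch of such a bracket-completion repair, assuming the real `fix_incomplete_json` works roughly like the example in the log (this version tracks brackets outside string literals and appends the missing closers):

```python
import json

def fix_incomplete_json(text: str) -> str:
    # Track unmatched { and [ outside string literals, then append
    # the corresponding closers in reverse order of opening.
    stack = []
    in_string = False
    escape = False
    for ch in text:
        if escape:
            escape = False
            continue
        if ch == "\\" and in_string:
            escape = True
        elif ch == '"':
            in_string = not in_string
        elif not in_string:
            if ch in "{[":
                stack.append("}" if ch == "{" else "]")
            elif ch in "}]" and stack:
                stack.pop()
    return text + "".join(reversed(stack))
```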
### 2026-03-26 Recognition fix: no marks on apparatus diagrams (important)
**Problems**:
1. Grading bubbles were placed on apparatus diagrams (spring scales, beakers, etc.)
2. Coordinate placement was imprecise
**Fix**:
1. **Prompt optimization**
   - Explicitly forbid marking apparatus diagrams, schematics, and circuit diagrams
   - Explicitly forbid marking printed letters in figures (A, B, C, D, E, F, G)
   - Emphasize marking handwritten student answers only
2. **Engineering conventions**
   - Read `sp` and `up` from the config file, per project conventions
   - Render the prompt with a Jinja2 template
   - Keep only the dynamic parts (standard answers, image size, etc.) in code
3. **Recognition flow**
   - Find the question number → find the student answer → box the answer → judge correctness
   - Emphasize the traits of student answers: handwritten, filled into blanks, computed results
**Result**: apparatus diagrams are no longer mis-marked; only handwritten student answers are marked
### 2026-03-26 New parallelism-control parameter
**Before**: the concurrency limit was hard-coded to 3, which was inflexible
**After**: added a `max_concurrent` parameter (default 10) that users can tune
**Details**:
1. **New parameter**: `max_concurrent` (optional, default 10)
2. **Changed locations**:
   - `GraphInput.max_concurrent: Optional[int] = 10`
   - `GlobalState.max_concurrent: int = 10`
   - `ProcessImagesInput.max_concurrent: int`
3. **Usage**:
```json
{
    "homework_images": [...],
    "max_concurrent": 5  // process at most 5 images at once
}
```
**Result**: users can tune concurrency to server capacity and network conditions
### 2026-03-26 Subject change
**Change**: replaced every mention of "math" with "physics"
- Node descriptions: math homework → physics homework
- Subject references in prompts: math → physics
- Config file descriptions updated
### 2026-03-25 Parallel multi-image processing
**Before**: images were processed serially; total time = per-image time × number of images
**After**: images are processed in parallel (concurrency capped at 3); total time is greatly reduced
**Details**:
1. **Parallel architecture**: use a `ThreadPoolExecutor` to invoke the subgraph for each image in parallel
   - At most 3 images are processed at once
   - Results are sorted by `image_index` to preserve order
2. **Performance**:
   - 3 images: ~66% faster (3 units of time → 1)
   - 5 images: ~60% faster (5 units of time → ~2, in two parallel batches)
3. **Quality**:
   - Each image is processed independently
   - Recognition and grading logic are unchanged, so quality is unaffected
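The parallel fan-out described above can be sketched with `ThreadPoolExecutor`; here `process_one` stands in for the subgraph invocation and is an assumption, not the project's actual signature:

```python
from concurrent.futures import ThreadPoolExecutor

def process_all(images, process_one, max_concurrent=3):
    # Executor.map() yields results in input order, so the outputs
    # already line up with each image's index.
    with ThreadPoolExecutor(max_workers=max_concurrent) as pool:
        return list(pool.map(process_one, images))
```

Because `map()` preserves input order, no extra sort by `image_index` is needed in this sketch; the project sorts explicitly, which is equivalent when futures are gathered out of order.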
### 2026-03-26 Coordinate-placement fix (important)
**Problem**: coordinates were badly off, so grading marks landed in the wrong places
**Cause**: a faulty Y-coordinate correction scaled coordinates incorrectly
**Fix**:
1. **Coordinate-system rework**: switched from absolute coordinates to a relative 0-1000 system
   - The model returns relative coordinates (0-1000), with (0,0) at the image's top-left corner and (1000,1000) at the bottom-right
   - Code converts relative to absolute coordinates: `absolute X = relative X × width / 1000`, `absolute Y = relative Y × height / 1000`
2. **Prompt optimization**:
   - Explicitly require the model to return relative coordinates (0-1000)
   - Added a coordinate-system explanation and examples
3. **Conversion fix**:
   - Removed the faulty Y correction (`Y × height_ratio`)
   - Implemented the correct relative-to-absolute conversion
**Result**: coordinates are accurate and grading marks land in the right place
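The conversion above is a plain linear scaling of each corner; a sketch (the function name and `[x1, y1, x2, y2]` bbox layout are illustrative):

```python
def to_absolute(bbox_rel, width, height):
    # bbox_rel is [x1, y1, x2, y2] on a 0-1000 grid, (0,0) = top-left corner
    x1, y1, x2, y2 = bbox_rel
    return [round(x1 * width / 1000), round(y1 * height / 1000),
            round(x2 * width / 1000), round(y2 * height / 1000)]
```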
### 2026-03-26 Question/answer recognition optimization (important)
**Problems**:
1. Could not reliably distinguish the printed question stem from the student's answer
2. Grading bubbles were not placed on the student answer
3. Question-stem positions were mis-marked as answers
**Fix**:
1. **Prompt rewrite**
   - Clearly define the difference between "question stem" and "student answer"
   - Emphasize marking only handwritten student answers, never printed stems
   - Added step-by-step recognition guidance
2. **Coordinate placement**
   - Compute mark_position automatically: 30 px to the right of the answer box, vertically centered
   - Added bounds checks so marks never leave the image
   - No longer rely on the model-returned mark_position (which could be inaccurate)
3. **Recognition guidance**
   - Question-number recognition: e.g. 1, 2, 3, (1), (2)
   - Answer location: the student's handwriting (not printed text)
   - bbox selection: box the student-answer region precisely
**Result**: stems and answers are distinguished more reliably, and grading bubbles are placed more precisely
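The placement rule above (30 px right of the answer box, vertically centered, clamped to the image) can be sketched as follows; the helper name and `[x1, y1, x2, y2]` bbox layout are assumptions:

```python
def compute_mark_position(bbox, image_width, offset=30):
    # Place the mark to the right of the answer box, centered vertically,
    # clamped so it never falls outside the image.
    x1, y1, x2, y2 = bbox
    mark_x = min(x2 + offset, image_width - 1)
    mark_y = (y1 + y2) // 2
    return mark_x, mark_y
```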
### 2026-03-25 Grading speed optimization
**Before**: each image required 3 LLM calls (recognition + grading + overall evaluation)
**After**: each image requires only 1 LLM call
**Details**:
1. **Merged recognition and grading**: merged `homework_recognize` and `correction_judge` into the `recognize_and_correct` node
   - Grades while recognizing questions, student answers, and coordinates
   - One fewer LLM call, roughly 50% faster
2. **Simplified overall evaluation**: no longer calls the LLM for the overall evaluation
   - Generates the evaluation directly from rules
   - Produces a personalized comment from the score ratio and error count
   - One fewer LLM call
3. **Leaner subgraph**: reduced from 5 nodes to 4
   - Removed homework_recognize and correction_judge
   - Added the merged recognize_and_correct node
   - Kept image_preprocess, result_merge, and wrap_result
**Result**:
- LLM calls per image: 3 → 1
- Estimated grading time reduced by ~60%
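The rule-based overall evaluation can be sketched as below; the thresholds and wording are illustrative assumptions, not the project's actual rules:

```python
def overall_comment(total_score: int, full_score: int, wrong_count: int) -> str:
    # Pick a base message from the score ratio, then append the error count.
    ratio = total_score / full_score if full_score else 0.0
    if ratio >= 0.9:
        base = "Excellent work, almost everything is correct"
    elif ratio >= 0.7:
        base = "Good overall, but a few answers need another look"
    else:
        base = "Many answers are wrong; please review the relevant concepts"
    return f"{base} ({wrong_count} error(s))."
```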
### 2026-03-25 New input-parameter controls
1. **New `comment_max_length` parameter**: caps comment length (default 100 characters)
2. **New `grade_standards` parameter**: customizable grading-level standards
   - Custom minimum score percentage per level
   - Custom description per level
   - Defaults: A+ (≥95%), A (≥90%), B (≥80%), C (≥70%), D (<70%)
3. **Usage**:
```json
{
    "homework_images": [...],
    "comment_max_length": 50,  // comments capped at 50 characters
    "grade_standards": {
        "A+": {"min_percentage": 98, "description": "Perfect"},
        "A": {"min_percentage": 90, "description": "Excellent"},
        ...
    }
}
```
### 2026-03-25 Comment improvements and overall evaluation
1. **Concrete comments**: grading comments must explain why an answer is right or wrong
   - Correct: explain why it is correct
   - Wrong: point out the error and give the correct answer
   - Partially correct: state what was right and what was wrong
   - Length: within 50 characters, never more than 100
   - No chain-of-thought in the output; results only
2. **Comment examples**
   - Multiple choice, correct: the answer is B, matching the standard answer; correct.
   - Fill-in, wrong: the answer should be 8+√7 and 8-√7; only one was written, so it is incomplete.
   - Free response, correct: complete working, clear steps, correct result.
   - Calculation, wrong: the calculation is flawed; the correct answer is m=2; check the transposition step.
3. **Overall evaluation**: a short overall evaluation is generated from all grading content
   - Calls the LLM for a personalized evaluation
   - At most 50 characters
   - Covers the main problems or strengths
   - Gives a brief suggestion
4. **HTML report**: the overall evaluation is shown after the statistics summary
### 2026-03-25 Auto-rotation
1. **New landscape auto-rotation**: if an uploaded image is wider than it is tall (landscape), it is automatically rotated by -90° into portrait
2. **When**: during image preprocessing, after download and before scaling
3. **Direction**: `rotate(-90)` rotates 90° clockwise, keeping the text upright
4. **Logging**: detailed rotation logs added for debugging
### 2026-03-25 Multi-image grading
1. **Multi-image support**: upgraded from single-image to batch multi-image grading
2. **Subgraph architecture**: created `loop_graph.py` to encapsulate the single-image pipeline
3. **Loop node**: created `process_images_node.py` to invoke the subgraph per image
4. **State refactor**:
   - `GraphInput.homework_image` → `homework_images: List[File]`
   - Added the `SubgraphState`, `SubgraphInput`, and `SubgraphOutput` subgraph states
   - Added `SingleImageResult` for per-image results
   - Added `FinalResult.image_results` for the multi-image result list
5. **HTML generation refactor**: supports an HTML report with annotations for all images
6. **Main-graph simplification**: a three-node linear flow (doc_extract → process_images → html_generate)
### 2026-03-25 Dual-mode grading
1. **Smart fallback**: prefer standard answers; switch to professional-teacher mode automatically when none are available
2. **state.py change**: `answer_doc_url` became optional, supporting runs without an answer URL
3. **correction_judge_node upgrade**: splits questions into those with and without standard answers and handles each separately
4. **Prompt update**: the grading node supports both modes (standard-answer mode + professional-teacher mode)
5. **doc_extract_node tweak**: returns an empty list when no URL is given instead of aborting the workflow
### 2026-03-25 Word answer parsing
1. Added the `doc_extract_node` node to extract question stems and standard answers from the Word file (.docx)
2. Uses python-docx to extract the Word document content
3. Parallel architecture: image recognition and answer parsing run concurrently
4. Grades precisely against the standard answers from the Word document
### 2026-03-25 OCR recognition improvements
1. Problem: the recognition node misread a handwritten "8" as "9", causing a wrong judgment
2. Fix: added OCR-specific hints to the recognition prompt, stressing the distinction between 8/9, 6/0, 1/7, and other similar glyphs
3. Result: question 7 is now recognized correctly as "8+√7, 8-√7" and passes with full marks
### 2026-03-25 Grading capability upgrade
1. Upgraded the grading node model: `doubao-seed-1-6-vision-250815` → `doubao-seed-2-0-pro-260215`
2. Reason: the smaller model was not accurate enough on multiple-choice questions
3. Result: much higher multiple-choice accuracy and more rigorous reasoning
### 2026-03-24 Rework following the Doubao app's approach
1. Simplified from 8 nodes to 5 (now a subgraph + main-graph architecture)
2. Adopted integrated recognition: the model returns answer_bbox; code computes mark_position
3. Implemented precise coordinate calculation: the mark's Y is aligned with the answer's vertical center
## TODO
- Run an end-to-end test with real multi-page homework images
- Improve the image layout of the HTML report
- Support PDF answer documents

12
README.md Normal file

@ -0,0 +1,12 @@
# Project Structure Notes
# Local Run
## Run the full flow
bash scripts/local_run.sh -m flow
## Run a single node
bash scripts/local_run.sh -m node -n node_name
# Start the HTTP Service
bash scripts/http_run.sh -p 5000

BIN assets/1.png Normal file (629 KiB)
BIN (binary file, 787 KiB)
BIN assets/你的.png Normal file (462 KiB)
BIN assets/你的1.png Normal file (462 KiB)
BIN assets/你的2.png Normal file (462 KiB)
BIN assets/你的3.png Normal file (473 KiB)
BIN (binary file, 1.1 MiB)

View File

@ -0,0 +1,17 @@
{
"config": {
"temperature": 0.1,
"frequency_penalty": 0,
"top_p": 0.9,
"max_tokens": 4096,
"max_completion_tokens": 8192,
"thinking_type": "enabled",
"reasoning_effort": "medium",
"response_format": "text",
"json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
"model": "doubao-seed-2-0-pro-260215"
},
"tools": [],
"sp": "你是一位专业的文档识别专家,擅长从作业图片中提取答案区域的位置信息。",
"up": "请根据题目位置信息,识别每道题对应的答案区域边界框。"
}

View File

@ -0,0 +1,17 @@
{
"config": {
"temperature": 0.1,
"frequency_penalty": 0,
"top_p": 0.9,
"max_tokens": 4096,
"max_completion_tokens": 8192,
"thinking_type": "enabled",
"reasoning_effort": "medium",
"response_format": "text",
"json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
"model": "doubao-seed-2-0-pro-260215"
},
"tools": [],
"sp": "你是一位专业的OCR文字识别专家擅长从图片中识别手写和印刷的文字内容。",
"up": "请识别答案区域中的文字内容,返回准确的识别结果。"
}

View File

@ -0,0 +1,12 @@
{
"config": {
"model": "doubao-seed-1-6-vision-250815",
"temperature": 0.1,
"top_p": 0.95,
"max_completion_tokens": 16384,
"thinking": "disabled"
},
"tools": [],
"sp": "你是一位专业的初中物理教师,负责批改学生的物理作业。",
"up": "请按照要求完成作业批改任务。"
}

View File

@ -0,0 +1,17 @@
{
"config": {
"temperature": 0,
"frequency_penalty": 0,
"top_p": 0.95,
"max_tokens": 8192,
"max_completion_tokens": 16384,
"thinking_type": "enabled",
"reasoning_effort": "high",
"response_format": "text",
"json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
"model": "doubao-seed-2-0-pro-260215"
},
"tools": [],
"sp": "你是一位资深的初中物理特级教师拥有20年以上教学经验擅长精准批改学生的物理作业。\n\n【核心能力】\n1. **精确判断能力**:对选择题、填空题、解答题都能做出准确的正误判断\n2. **严谨推理能力**:能够逐步验证学生的计算过程和结论\n3. **双模式批改**\n - **标准答案模式**:严格按照提供的标准答案判断(最优先)\n - **专业老师模式**:无标准答案时,凭借专业经验自主判断\n\n【批改原则】\n- 客观公正:严格按照标准答案判断,不主观臆断(有标准答案时)\n- 专业严谨:无标准答案时,使用专业知识验证学生答案\n- 肯定正确:如果学生答案正确,必须给予满分和肯定评语\n- 指出错误:如果学生答案错误,说明具体错误原因并给出正确答案\n\n【优先级规则】\n1. 最优先:使用提供的标准答案批改\n2. 降级:标准答案中未找到对应题目时,使用专业老师批改",
"up": "请批改以下学生的物理作业,判断每道题答案的正误并给出详细评语。"
}

View File

@ -0,0 +1,17 @@
{
"config": {
"temperature": 0.1,
"frequency_penalty": 0,
"top_p": 0.95,
"max_tokens": 8192,
"max_completion_tokens": 16384,
"thinking_type": "enabled",
"reasoning_effort": "high",
"response_format": "text",
"json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
"model": "doubao-seed-2-0-pro-260215"
},
"tools": [],
"sp": "你是一位资深的初中物理教师,擅长从试卷中提取题目和标准答案。你的核心能力:\n\n1. **题目识别能力**:能够准确识别试卷中的所有题目,包括大题和小题\n2. **答案提取能力**:能够准确提取每道题的标准答案\n3. **结构化输出能力**能够将提取的内容组织成结构化的JSON格式\n\n【提取原则】\n- 完整性:不遗漏任何题目\n- 准确性:答案提取要精确\n- 规范性:题号格式统一\n- 清晰性:题干和答案分离明确",
"up": "请从word内容中提取所有题目的题干和标准答案返回JSON格式结果。"
}

View File

@ -0,0 +1,12 @@
{
"config": {
"model": "doubao-seed-1-6-vision-250815",
"temperature": 0.1,
"top_p": 0.95,
"max_completion_tokens": 8192,
"thinking": "disabled"
},
"tools": [],
"sp": "# 角色定义\n你是一位专业的初中物理作业批改助手具有丰富的物理教学经验和精准的视觉识别能力。你能够准确识别作业图片中的题目内容、学生答案并判断答案的正确性。\n\n# 任务目标\n分析上传的初中物理作业图片识别每道题目及其学生答案判断答案是否正确并输出结构化的批改结果JSON。\n\n# 工作流上下文\n- **Input**作业图片图片URL\n- **Process**\n 1. 仔细识别图片中的所有题目,包括题号、题目内容\n 2. 识别每道题的学生答案,注意区分小题(如(1)(2)(3)\n 3. 判断每个答案的正确性,对于解答题需要检查计算过程和结果\n 4. 为每个批改标记确定在原图上的相对坐标位置(批改标记应放置在答案末尾右侧)\n 5. 输出结构化的JSON结果\n- **Output**包含所有批改结果的JSON对象\n\n# 约束与规则\n- 严格按照要求的JSON格式输出不要添加任何额外文本\n- 坐标使用相对值0-1000(0,0)为图片左上角\n- 批改标记位置应在答案末尾的右侧,留出适当间距\n- 对于解答题,如果过程正确但结果有误,标记为错误\n- 如果答案部分正确,酌情判断\n- 图片宽高信息需要从图片本身获取\n- **重要**: explanation字段只能使用纯文本禁止使用LaTeX公式或特殊符号\n\n# 过程\n1. 识别题目结构:扫描图片,定位所有题目,记录题号和小题号\n2. 答案识别:逐题识别学生的作答内容\n3. 正确性判断:\n - 对于计算题:检查计算过程和结果\n - 对于证明题:检查证明逻辑是否完整\n - 对于作图题:检查图形是否正确\n4. 坐标定位:确定每道题答案末尾的坐标位置\n5. 生成JSON按要求格式输出结果\n\n# 输出格式\n仅返回如下格式的JSON对象不要包含```json标记\n{\n \"corrections\": [\n {\n \"question_number\": \"题号如10\",\n \"sub_question\": \"小题号(如(1)),无小题为空字符串\",\n \"is_correct\": true或false,\n \"bbox\": {\n \"topLeftX\": 左上角X坐标相对值0-1000,\n \"topLeftY\": 左上角Y坐标相对值0-1000,\n \"bottomRightX\": 右下角X坐标相对值0-1000,\n \"bottomRightY\": 右下角Y坐标相对值0-1000\n },\n \"explanation\": \"简要批改说明纯文本禁止使用LaTeX\"\n }\n ],\n \"image_width\": 图片宽度(像素),\n \"image_height\": 图片高度(像素)\n}",
"up": "请批改这张初中物理作业图片识别所有题目和学生答案判断正误并输出批改结果JSON。注意explanation字段只能使用纯文本禁止使用LaTeX公式。图片URL{{image_url}}"
}

View File

@ -0,0 +1,12 @@
{
"config": {
"model": "doubao-seed-2-0-pro-260215",
"temperature": 0.0,
"top_p": 0.9,
"max_completion_tokens": 8192,
"thinking": "disabled"
},
"tools": [],
"sp": "# 角色\n你是物理作业批改助手。\n\n# 禁止标注\n- 印刷体文字、实验装置图、图中字母\n\n# 需要标注\n- 学生手写答案\n\n# 坐标\n- 相对坐标0-1000answer_bbox: [x1, y1, x2, y2]\n\n# ⚠️ 重要:拆分填空题\n- 一道题有多个填空时,**每个空单独识别为一个题目**\n- 题号格式:\"3(1)第一空\"、\"3(1)第二空\"、\"4(2)第一空\"、\"4(2)第二空\"\n- 每个空单独批改,单独打分\n- 示例:\n - 题目(1)有两个空 → 识别为\"3(1)第一空\"和\"3(1)第二空\"两个题目\n - 题目(2)有一个空 → 识别为\"3(2)\"一个题目\n\n# 输出格式\n{\"results\": [{\"question_id\": \"题号\", \"student_answer\": \"答案\", \"answer_bbox\": [x1, y1, x2, y2], \"status\": \"correct或incorrect\", \"score\": 分数, \"full_score\": 满分, \"comment\": \"结论\"}]}\n\ncomment格式\"正确\"或\"错误应为X\"不超过50字",
"up": "批改物理作业。**每个填空单独识别**。输出JSONcomment不超过{{comment_max_length}}字。图片:{{image_url}}"
}

View File

@ -0,0 +1,17 @@
{
"config": {
"temperature": 0.1,
"frequency_penalty": 0,
"top_p": 0.9,
"max_tokens": 4096,
"max_completion_tokens": 8192,
"thinking_type": "enabled",
"reasoning_effort": "medium",
"response_format": "text",
"json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
"model": "doubao-seed-2-0-pro-260215"
},
"tools": [],
"sp": "你是一位专业的初中物理作业识别专家,擅长从作业图片中定位题目位置和提取答案区域。",
"up": "请识别这张作业图片中的所有题目位置,返回准确的边界框坐标。"
}

156
requirements.txt Normal file

@ -0,0 +1,156 @@
alembic==1.16.5
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
APScheduler==3.11.2
astroid==3.1.0
Authlib==1.6.9
beautifulsoup4==4.14.3
boto3==1.40.61
botocore==1.40.61
cachetools==6.2.6
certifi==2026.2.25
cffi==2.0.0
chardet==5.2.0
charset-normalizer==3.4.4
click==8.3.1
coverage==7.13.4
coze-coding-dev-sdk==0.5.11
coze-coding-utils==0.2.4
coze-workload-identity==0.1.4
cozeloop==0.1.25
cryptography==46.0.5
cssselect==1.4.0
cssutils==2.11.1
dbus-python==1.3.2
deprecation==2.1.0
dill==0.4.1
distro==1.9.0
docx2python==3.5.0
et_xmlfile==2.0.0
fastapi==0.121.2
fsspec==2026.2.0
gitdb==4.0.12
gitignore_parser==0.1.13
GitPython==3.1.45
greenlet==3.3.2
h11==0.16.0
h2==4.3.0
hpack==4.1.0
httpcore==1.0.9
httpx==0.28.1
httpx-ws==0.8.2
hyperframe==6.1.0
idna==3.11
inflect==7.5.0
iniconfig==2.3.0
isort==5.13.2
Jinja2==3.1.6
jiter==0.13.0
jmespath==1.1.0
jsonpatch==1.33
jsonpointer==3.0.0
langchain==1.0.3
langchain-core==1.0.2
langchain-openai==1.0.1
langgraph==1.0.2
langgraph-checkpoint==3.0.0
langgraph-checkpoint-postgres==3.0.1
langgraph-prebuilt==1.0.2
langgraph-sdk==0.2.9
langsmith==0.4.39
lxml==6.0.2
Mako==1.3.10
Markdown==3.10.2
markdown-it-py==4.0.0
MarkupSafe==3.0.3
mccabe==0.7.0
mdurl==0.1.2
mmh3==5.2.1
more-itertools==10.8.0
multidict==6.7.1
numpy==2.2.6
openai==2.24.0
opencv-python==4.12.0.88
openpyxl==3.1.5
orjson==3.11.7
ormsgpack==1.12.2
packaging==25.0
pandas==2.2.2
paragraphs==1.0.1
pillow==10.3.0
platformdirs==4.9.2
pluggy==1.6.0
postgrest==2.27.3
propcache==0.4.1
psutil==7.1.3
psycopg==3.3.0
psycopg-binary==3.3.0
psycopg-pool==3.3.0
psycopg2-binary==2.9.9
pycparser==3.0
pydantic==2.12.3
pydantic_core==2.41.4
Pygments==2.19.2
PyGObject==3.48.2
pyiceberg==0.11.1
PyJWT==2.10.1
pylint==3.1.0
pyparsing==3.3.2
pypdf==6.4.1
pyroaring==1.0.3
pytest==9.0.1
pytest-asyncio==1.3.0
pytest-cov==7.0.0
pytest-mock==3.15.1
python-dateutil==2.9.0.post0
python-docx==1.2.0
python-dotenv==1.2.1
python-pptx==1.0.2
pytz==2026.1.post1
PyYAML==6.0.3
realtime==2.27.3
regex==2026.2.28
reportlab==4.4.10
requests==2.32.5
requests-toolbelt==1.0.0
rich==14.2.0
s3transfer==0.14.0
setuptools==68.1.2
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
soupsieve==2.8.3
sqlacodegen==4.0.2
SQLAlchemy==2.0.44
starlette==0.49.3
storage3==2.27.3
StrEnum==0.4.15
strictyaml==1.7.3
supabase==2.27.3
supabase-auth==2.27.3
supabase-functions==2.27.3
tenacity==9.1.4
tiktoken==0.12.0
tomlkit==0.14.0
tqdm==4.67.3
typeguard==4.5.1
types-html5lib==1.1.11.20251117
types-lxml==2026.2.16
types-webencodings==0.5.0.20251108
typing-inspection==0.4.2
typing_extensions==4.15.0
tzdata==2025.3
tzlocal==5.3.1
Unidecode==1.4.0
urllib3==2.6.3
uvicorn==0.38.0
watchdog==6.0.0
websockets==15.0.1
wheel==0.42.0
wsproto==1.3.2
xlrd==2.0.2
xlsxwriter==3.2.9
xxhash==3.6.0
yarl==1.23.0
zstandard==0.25.0

31
scripts/http_run.sh Normal file

@ -0,0 +1,31 @@
#!/bin/bash
set -e
# Resolve the working directory and default port
WORK_DIR="${COZE_WORKSPACE_PATH:-.}"
PORT=8000
usage() {
    echo "Usage: $0 -p <port>"
}
while getopts "p:h" opt; do
    case "$opt" in
    p)
        PORT="$OPTARG"
        ;;
    h)
        usage
        exit 0
        ;;
    \?)
        echo "Invalid option: -$OPTARG"
        usage
        exit 1
        ;;
    esac
done
python "${WORK_DIR}/src/main.py" -m http -p "$PORT"

35
scripts/load_env.py Normal file

@ -0,0 +1,35 @@
#!/usr/bin/env python3
"""
加载项目环境变量脚本
通过 coze_workload_identity.Client 获取项目环境变量并输出 export 语句
使用方式: eval $(python load_env.py)
"""
import os
import sys
# 添加 app 目录到 Python 路径
workspace_path = os.getenv("COZE_WORKSPACE_PATH", "/workspace/projects")
app_dir = os.path.join(workspace_path, 'src')
if app_dir not in sys.path:
sys.path.insert(0, app_dir)
try:
from coze_workload_identity import Client
client = Client()
env_vars = client.get_project_env_vars()
client.close()
# 输出 export 语句格式的环境变量
for env_var in env_vars:
# 转义特殊字符
value = env_var.value.replace("'", "'\\''")
print(f"export {env_var.key}='{value}'")
# 输出成功消息到 stderr不影响 eval
print(f"# Successfully loaded {len(env_vars)} environment variables", file=sys.stderr)
except Exception as e:
print(f"# Error loading environment variables: {e}", file=sys.stderr)
sys.exit(1)

8
scripts/load_env.sh Normal file

@ -0,0 +1,8 @@
#!/bin/bash
# Load project environment variables
# Usage: source ./load_env.sh  (or: . ./load_env.sh)
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
eval "$(python3 "$SCRIPT_DIR/load_env.py")"

75
scripts/local_run.sh Normal file

@ -0,0 +1,75 @@
#!/bin/bash
set -e
mode=""
node=""
input=""
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
WORK_DIR="${COZE_WORKSPACE_PATH:-$(dirname "$SCRIPT_DIR")}"
usage() {
    echo "Usage: $0 -m <mode> [-n <node id>] [-i <input JSON>]"
    echo ""
    echo "Options:"
    echo "  -m <mode>        run mode: http, flow, node, agent"
    echo "  -n <node id>     node id (required in node mode only)"
    echo "  -i <input JSON>  input data, as a JSON string or plain text"
    echo "  -h               show this help"
    echo ""
    echo "Examples:"
    echo "  $0 -m flow"
    echo "  $0 -m flow -i '{\"text\": \"hello\"}'"
    echo "  $0 -m flow -i 'hello'"
    echo "  $0 -m node -n node_1 -i '{\"text\": \"test\"}'"
}
while getopts "m:n:i:h" opt; do
    case "$opt" in
    m)
        mode="$OPTARG"
        ;;
    n)
        node="$OPTARG"
        ;;
    i)
        input="$OPTARG"
        ;;
    h)
        usage
        exit 0
        ;;
    \?)
        echo "Invalid option: -$OPTARG"
        usage
        exit 1
        ;;
    esac
done
if [ -z "$mode" ]; then
    echo "Error: the -m option is required"
    usage
    exit 1
fi
# Load environment variables
if [ -f "${SCRIPT_DIR}/load_env.sh" ]; then
    echo "Loading environment variables..."
    source "${SCRIPT_DIR}/load_env.sh"
fi
# Build python command
cmd="python ${WORK_DIR}/src/main.py -m \"$mode\""
if [ -n "$node" ]; then
    cmd="$cmd -n \"$node\""
fi
if [ -n "$input" ]; then
    cmd="$cmd -i '$input'"
fi
# Execute command
echo "Executing: $cmd"
eval $cmd

3
scripts/pack.sh Normal file

@ -0,0 +1,3 @@
#!/bin/bash
pip freeze --exclude watchdog > requirements.txt

9
scripts/setup.sh Normal file

@ -0,0 +1,9 @@
#!/bin/bash
# Initialize directories
if [ "$COZE_PROJECT_ENV" = "DEV" ]; then
    if [ ! -d "${COZE_WORKSPACE_PATH}/assets" ]; then
        mkdir -p "${COZE_WORKSPACE_PATH}/assets"
    fi
fi
# Install Python third-party dependencies
pip install -r requirements.txt

0
src/__init__.py Normal file

0
src/agents/__init__.py Normal file

0
src/graphs/__init__.py Normal file

39
src/graphs/graph.py Normal file

@ -0,0 +1,39 @@
"""初中物理作业批改工作流主图编排 - 支持多图片批改"""
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from graphs.state import (
GlobalState,
GraphInput,
GraphOutput
)
# 导入节点
from graphs.nodes.doc_extract_node import doc_extract_node
from graphs.nodes.process_images_node import process_images_node
# 创建状态图
builder = StateGraph(
GlobalState,
input_schema=GraphInput,
output_schema=GraphOutput
)
# 添加节点
builder.add_node("doc_extract", doc_extract_node, metadata={"type": "agent", "llm_cfg": "config/doc_extract_llm_cfg.json"})
builder.add_node("process_images", process_images_node, metadata={"type": "looparray"})
# 设置入口点
builder.set_entry_point("doc_extract")
# 添加边 - 线性流程
# 1. 先解析Word答案如果提供了URL
builder.add_edge("doc_extract", "process_images")
# 2. 循环处理每张图片并生成最终结果
builder.add_edge("process_images", END)
# 编译图
main_graph = builder.compile()

66
src/graphs/loop_graph.py Normal file

@ -0,0 +1,66 @@
"""单图片处理子图:封装完整的单张图片批改流程(优化版:合并识别和批改)"""
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context
from graphs.state import (
SubgraphState,
SubgraphInput,
SubgraphOutput,
SingleImageResult,
ImageInfo
)
# 导入子图需要的节点
from graphs.nodes.image_preprocess_node import image_preprocess_node
from graphs.nodes.recognize_and_correct_node import recognize_and_correct_node
from graphs.nodes.result_merge_node import result_merge_node
def wrap_result_node(
state: SubgraphState,
config: RunnableConfig,
runtime: Runtime[Context]
) -> SubgraphOutput:
"""
title: 包装子图结果
desc: 将子图的中间状态包装为SingleImageResult输出
integrations:
"""
from graphs.state import SingleImageResult
image_result = SingleImageResult(
image_index=state.image_index,
image_info=state.image_info,
image_url=state.image_url,
annotations=state.annotations
)
return SubgraphOutput(image_result=image_result)
# 创建单图片处理子图
subgraph_builder = StateGraph(
SubgraphState,
input_schema=SubgraphInput,
output_schema=SubgraphOutput
)
# 添加节点优化后只需4个节点
subgraph_builder.add_node("image_preprocess", image_preprocess_node)
subgraph_builder.add_node("recognize_and_correct", recognize_and_correct_node, metadata={"type": "agent", "llm_cfg": "config/homework_recognize_llm_cfg.json"})
subgraph_builder.add_node("result_merge", result_merge_node)
subgraph_builder.add_node("wrap_result", wrap_result_node)
# 设置入口点
subgraph_builder.set_entry_point("image_preprocess")
# 添加边(优化后流程:预处理 -> 识别+批改 -> 整合 -> 包装)
subgraph_builder.add_edge("image_preprocess", "recognize_and_correct")
subgraph_builder.add_edge("recognize_and_correct", "result_merge")
subgraph_builder.add_edge("result_merge", "wrap_result")
subgraph_builder.add_edge("wrap_result", END)
# 编译子图
single_image_subgraph = subgraph_builder.compile()

View File

@ -0,0 +1,10 @@
"""工作流节点模块 - 学习豆包APP的一体化识别方式"""
from graphs.nodes.image_preprocess_node import image_preprocess_node
from graphs.nodes.result_merge_node import result_merge_node
from graphs.nodes.recognize_and_correct_node import recognize_and_correct_node
__all__ = [
"image_preprocess_node",
"result_merge_node",
"recognize_and_correct_node",
]

View File

@ -0,0 +1,280 @@
"""Word答案解析节点从.docx文件中提取题干和标准答案"""
import os
import json
import re
import logging
import tempfile
import requests
from typing import List
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context
from coze_coding_dev_sdk import LLMClient
from langchain_core.messages import HumanMessage
from docx import Document
from graphs.state import (
DocExtractInput,
DocExtractOutput,
CorrectAnswer
)
logger = logging.getLogger(__name__)
def sanitize_json_string(text: str) -> str:
    """Clean invalid escapes and control characters out of a JSON string"""
    text = ''.join(char if ord(char) >= 32 or char in '\n\r\t' else ' ' for char in text)
    result = []
    i = 0
    in_string = False
    while i < len(text):
        char = text[i]
        if char == '"' and (i == 0 or (i > 0 and result and result[-1] != '\\')):
            in_string = not in_string
            result.append(char)
        elif in_string and char == '\\':
            if i + 1 < len(text):
                next_char = text[i + 1]
                if next_char in ['"', '\\', '/', 'b', 'f', 'n', 'r', 't']:
                    result.append(char)
                    result.append(next_char)
                    i += 1
                elif next_char == 'u':
                    result.append(char)
                    result.append(next_char)
                    i += 1
                else:
                    result.append('\\\\')
            else:
                result.append('\\\\')
        elif in_string and char in '\n\r':
            result.append('\\n' if char == '\n' else '\\r')
        else:
            result.append(char)
        i += 1
    return ''.join(result)
def extract_json_from_text(text: str, key: str = "answers") -> dict:
    """Extract a JSON object from text, with a multi-layer fallback strategy"""
    import orjson
    # First try parsing the full text as JSON
    try:
        result = json.loads(text)
        if key in result:
            return result
    except (json.JSONDecodeError, ValueError) as e:
        logger.debug(f"Direct JSON parse failed: {e}")
    # Try to locate the JSON object's boundaries
    try:
        # Find the first { and the last }
        start = text.find('{')
        end = text.rfind('}')
        if start != -1 and end != -1 and end > start:
            json_str = text[start:end+1]
            result = json.loads(json_str)
            if key in result:
                return result
    except (json.JSONDecodeError, ValueError) as e:
        logger.debug(f"Boundary JSON parse failed: {e}")
    # Retry with orjson
    try:
        start = text.find('{')
        end = text.rfind('}')
        if start != -1 and end != -1 and end > start:
            json_str = text[start:end+1]
            result = orjson.loads(json_str)
            if key in result:
                return result
    except Exception as e:
        logger.debug(f"Orjson parse failed: {e}")
    # Try parsing after sanitizing the text
    try:
        cleaned = sanitize_json_string(text)
        start = cleaned.find('{')
        end = cleaned.rfind('}')
        if start != -1 and end != -1 and end > start:
            json_str = cleaned[start:end+1]
            result = json.loads(json_str)
            if key in result:
                return result
    except Exception as e:
        logger.debug(f"Cleaned JSON parse failed: {e}")
    logger.warning(f"Failed to extract JSON with key '{key}' from text (length: {len(text)})")
    return {key: []}
def download_and_extract_docx(url: str) -> str:
    """Download the Word file and extract its text content"""
    logger.info(f"Downloading Word document from: {url}")
    # Download the file
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    # Save it to a temporary file
    with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as tmp_file:
        tmp_file.write(response.content)
        tmp_path = tmp_file.name
    try:
        # Parse with python-docx
        doc = Document(tmp_path)
        # Extract all paragraph text
        paragraphs = []
        for para in doc.paragraphs:
            text = para.text.strip()
            if text:
                paragraphs.append(text)
        # Extract text from tables
        for table in doc.tables:
            for row in table.rows:
                row_text = []
                for cell in row.cells:
                    cell_text = cell.text.strip()
                    if cell_text:
                        row_text.append(cell_text)
                if row_text:
                    paragraphs.append(" | ".join(row_text))
        doc_text = "\n".join(paragraphs)
        logger.info(f"Extracted Word document text length: {len(doc_text)}")
        return doc_text
    finally:
        # Remove the temporary file
        os.unlink(tmp_path)
def doc_extract_node(
state: DocExtractInput,
config: RunnableConfig,
runtime: Runtime[Context]
) -> DocExtractOutput:
"""
title: Word答案解析
desc: 从正确答案Word文件.docx中提取题干和标准答案用于后续批改如果未提供URL则返回空列表
integrations: 大语言模型
"""
ctx = runtime.context
# 检查是否提供了答案文档URL
if not state.answer_doc_url or not state.answer_doc_url.strip():
logger.info("No answer document URL provided, will use teacher mode for all questions")
return DocExtractOutput(correct_answers=[])
# 1. 下载并提取Word文档内容
try:
doc_text = download_and_extract_docx(state.answer_doc_url)
except Exception as e:
logger.error(f"Failed to download/extract Word document: {e}")
return DocExtractOutput(correct_answers=[])
if not doc_text.strip():
logger.error("No text content extracted from Word document")
return DocExtractOutput(correct_answers=[])
logger.info(f"Word document content preview: {doc_text[:500]}")
# 2. 使用LLM解析题目和答案
cfg_file = os.path.join(os.getenv("COZE_WORKSPACE_PATH", ""), config["metadata"]["llm_cfg"])
with open(cfg_file, "r", encoding="utf-8") as fd:
_cfg = json.load(fd)
llm_config = _cfg.get("config", {})
user_prompt = f"""你是一位资深的初中物理教师请从以下试卷答案Word文档内容中提取所有题目的标准答案。
Word文档内容
{doc_text[:20000]}
提取要求
1. 识别所有题号包括大题和小题
2. 只提取每道题的标准答案不需要提取题干
3. 答案格式
- 选择题单个字母A/B/C/D
- 填空题数值或表达式
- 解答题关键结果
4. 如果有多个空用逗号分隔
题号格式规范
- 大题题号直接使用数字"1""2""10"
- 小题题号使用"大题.小题"格式"10.1""10.2"
重要
- 保持答案简洁不要包含题干
- 如果文档中有分值请提取否则默认3分
请返回简洁的JSON格式
{{
"answers": [
{{
"question_id": "1",
"correct_answer": "B",
"full_score": 3
}},
{{
"question_id": "2",
"correct_answer": "2a-b+c",
"full_score": 3
}}
]
}}"""
client = LLMClient(ctx=ctx)
response = client.invoke(
messages=[HumanMessage(content=user_prompt)],
model=llm_config.get("model", "doubao-seed-2-0-pro-260215"),
temperature=llm_config.get("temperature", 0.1),
max_completion_tokens=llm_config.get("max_completion_tokens", 8192)
)
response_text = response.content if isinstance(response.content, str) else " ".join(
item.get("text", "") if isinstance(item, dict) else str(item)
for item in response.content
).strip()
logger.info(f"Word extract LLM response: {response_text[:1500]}")
# 清理markdown标记
for prefix in ["```json", "```JSON", "```"]:
if response_text.startswith(prefix):
response_text = response_text[len(prefix):]
for suffix in ["```"]:
if response_text.endswith(suffix):
response_text = response_text[:-3]
# 解析JSON
result_dict = extract_json_from_text(response_text.strip(), "answers")
correct_answers: List[CorrectAnswer] = []
for ans in result_dict.get("answers", []):
try:
correct_answers.append(CorrectAnswer(
question_id=str(ans.get("question_id", "")),
parent_id=str(ans.get("parent_id", "")),
is_sub_question=bool(ans.get("is_sub_question", False)),
question_text=str(ans.get("question_text", "")),
correct_answer=str(ans.get("correct_answer", "")),
full_score=int(ans.get("full_score", 3) if ans.get("full_score") is not None else 3),
answer_analysis=str(ans.get("answer_analysis", ""))
))
except Exception as e:
logger.warning(f"Failed to parse correct answer: {ans}, error: {e}")
continue
logger.info(f"Parsed {len(correct_answers)} correct answers from Word document")
# 打印解析结果供调试
for ans in correct_answers:
logger.info(f" Question {ans.question_id}: {ans.correct_answer} ({ans.full_score}分)")
return DocExtractOutput(correct_answers=correct_answers)
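上面的解析逻辑(`answers` 列表、`full_score` 缺省 3 分)可以用一个纯标准库的最小示意独立验证。示例数据为虚构,仅演示缺省值行为,与真实模型输出无关:

```python
import json

# 模拟 LLM 按 Prompt 返回的 JSON虚构演示数据
response_text = '''{
  "answers": [
    {"question_id": "1", "correct_answer": "B", "full_score": 3},
    {"question_id": "10.1", "correct_answer": "2a-b+c"}
  ]
}'''

def parse_answers(text: str) -> list:
    """与节点相同的取值规则question_id/correct_answer 转字符串full_score 缺省为 3。"""
    data = json.loads(text)
    parsed = []
    for ans in data.get("answers", []):
        parsed.append({
            "question_id": str(ans.get("question_id", "")),
            "correct_answer": str(ans.get("correct_answer", "")),
            "full_score": int(ans.get("full_score") or 3),
        })
    return parsed

parsed = parse_answers(response_text)
```

节点中的 `extract_json_from_text` 还额外做了围栏清理和截断修复,这里只演示字段解析部分。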

"""1. 图像预处理节点下载图片、自动旋转、缩放到固定宽度1000、上传对象存储"""
import os
import logging
import urllib.request
from io import BytesIO
from typing import Tuple
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context
from PIL import Image
from datetime import datetime
from graphs.state import (
ImagePreprocessInput,
ImagePreprocessOutput,
ImageInfo
)
logger = logging.getLogger(__name__)
# 固定宽度常量
FIXED_WIDTH = 1000
def download_and_process_image(image_url: str) -> Tuple[Image.Image, int, int, int]:
"""下载图片并返回PIL Image对象和原始尺寸信息"""
try:
with urllib.request.urlopen(image_url, timeout=30) as response:
img_data = response.read()
img = Image.open(BytesIO(img_data))
# 转换为RGB模式处理PNG透明通道
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
original_width, original_height = img.size
dpi = img.info.get('dpi', (72, 72))
if isinstance(dpi, tuple):
dpi = dpi[0] if dpi[0] > 0 else 72
else:
dpi = 72
return img, original_width, original_height, dpi
except Exception as e:
raise RuntimeError(f"下载图片失败: {e}")
def auto_rotate_image(img: Image.Image) -> Tuple[Image.Image, bool]:
"""
    自动旋转图片:如果宽度大于高度(横向图片),则旋转-90度使其变为纵向
Args:
img: PIL Image对象
Returns:
(旋转后的图片, 是否进行了旋转)
"""
width, height = img.size
# 如果宽度大于高度,需要旋转
if width > height:
logger.info(f"检测到横向图片(宽{width} > 高{height}),正在旋转-90度...")
        # rotate(-90, expand=True):顺时针旋转90度,使横向变为纵向PIL中正角度为逆时针
rotated_img = img.rotate(-90, expand=True)
new_width, new_height = rotated_img.size
logger.info(f"旋转完成,新尺寸:宽{new_width} x 高{new_height}")
return rotated_img, True
logger.info(f"图片为纵向(宽{width} <= 高{height}),无需旋转")
return img, False
def resize_image_to_fixed_width(img: Image.Image, target_width: int = FIXED_WIDTH) -> Tuple[Image.Image, float]:
"""将图片缩放到固定宽度,高度等比例缩放"""
original_width, original_height = img.size
if original_width == target_width:
return img, 1.0
# 计算缩放比例
scale_ratio = target_width / original_width
new_height = int(original_height * scale_ratio)
    # 使用高质量重采样BICUBIC
    resized_img = img.resize((target_width, new_height), Image.Resampling.BICUBIC)
return resized_img, scale_ratio
def upload_image_to_storage(img: Image.Image, ctx) -> str:
"""将图片上传到对象存储并返回URL"""
from coze_coding_dev_sdk.s3 import S3SyncStorage
# 转换为字节流
img_buffer = BytesIO()
img.save(img_buffer, format='JPEG', quality=95)
img_bytes = img_buffer.getvalue()
# 上传到对象存储
storage = S3SyncStorage(
endpoint_url=os.getenv("COZE_BUCKET_ENDPOINT_URL"),
access_key="",
secret_key="",
bucket_name=os.getenv("COZE_BUCKET_NAME"),
region="cn-beijing",
)
file_key = storage.upload_file(
file_content=img_bytes,
file_name=f"homework_resized_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jpg",
content_type="image/jpeg",
)
result_url = storage.generate_presigned_url(key=file_key, expire_time=86400)
return result_url
def image_preprocess_node(
state: ImagePreprocessInput,
config: RunnableConfig,
runtime: Runtime[Context]
) -> ImagePreprocessOutput:
"""
title: 图像预处理
    desc: 下载图片、自动旋转(横向→纵向)、缩放到固定宽度1000px、上传对象存储
integrations: 对象存储
"""
ctx = runtime.context
# 1. 下载原始图片
img, original_width, original_height, dpi = download_and_process_image(
state.homework_image.url
)
logger.info(f"原始图片尺寸:宽{original_width} x 高{original_height}")
# 2. 自动旋转:如果宽度大于高度,旋转-90度使其变为纵向
img, was_rotated = auto_rotate_image(img)
after_rotate_width, after_rotate_height = img.size
if was_rotated:
logger.info(f"旋转后尺寸:宽{after_rotate_width} x 高{after_rotate_height}")
# 3. 缩放到固定宽度1000px高度等比例缩放
resized_img, scale_ratio = resize_image_to_fixed_width(img, FIXED_WIDTH)
new_width, new_height = resized_img.size
logger.info(f"缩放后尺寸:宽{new_width} x 高{new_height},缩放比例:{scale_ratio:.3f}")
# 4. 上传处理后的图片到对象存储
processed_image_url = upload_image_to_storage(resized_img, ctx)
# 5. 返回处理后的图片信息AI基于这个尺寸计算坐标
return ImagePreprocessOutput(
image_info=ImageInfo(
width=new_width, # 缩放后的宽度1000
height=new_height, # 缩放后的高度
dpi=dpi
),
image_url=processed_image_url
)
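预处理的几何部分(横向图旋转后宽高互换、再等比缩放到固定宽度)可以脱离 PIL 用纯算术验证。以下是仅计算输出尺寸的示意草图:

```python
FIXED_WIDTH = 1000

def preprocess_dims(width: int, height: int, target_width: int = FIXED_WIDTH):
    """返回 (新宽, 新高, 是否旋转, 缩放比例):横向图先按旋转交换宽高,再缩放到固定宽度。"""
    rotated = width > height
    if rotated:
        width, height = height, width  # 旋转90度后宽高互换
    scale = target_width / width
    return target_width, int(height * scale), rotated, scale

# 横向 3000x2000 的拍照作业:旋转成 2000x3000再缩放到宽 1000
dims = preprocess_dims(3000, 2000)
```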

"""多图片处理循环节点:并行调用子图处理每张作业图片"""
import logging
from typing import List
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context
from graphs.state import (
ProcessImagesInput,
ProcessImagesOutput,
SingleImageResult,
SubgraphInput,
FinalResult,
ImageInfo
)
from graphs.loop_graph import single_image_subgraph
logger = logging.getLogger(__name__)
def process_images_node(
state: ProcessImagesInput,
config: RunnableConfig,
runtime: Runtime[Context]
) -> ProcessImagesOutput:
"""
title: 多图片批改处理
    desc: 并行调用子图处理每张作业图片,生成最终批改结果
integrations:
"""
ctx = runtime.context
# 获取并发数限制从参数获取默认10
max_concurrent = getattr(state, 'max_concurrent', 10)
logger.info(f"Starting to process {len(state.homework_images)} images (concurrent={max_concurrent})")
# 定义处理单张图片的函数
def process_single_image(idx: int, homework_image) -> SingleImageResult:
logger.info(f"Processing image {idx + 1}/{len(state.homework_images)}")
try:
# 构建子图输入
subgraph_input = SubgraphInput(
homework_image=homework_image,
correct_answers=state.correct_answers,
image_index=idx,
comment_max_length=state.comment_max_length
)
# 调用子图
subgraph_output = single_image_subgraph.invoke(subgraph_input, config)
# 提取结果兼容字典和Pydantic对象两种返回类型
image_result = None
if subgraph_output:
if isinstance(subgraph_output, dict):
result_data = subgraph_output.get("image_result")
if result_data:
if isinstance(result_data, dict):
image_result = SingleImageResult(**result_data)
else:
image_result = result_data
elif hasattr(subgraph_output, 'image_result'):
image_result = subgraph_output.image_result
if image_result:
logger.info(f"Image {idx + 1} processed successfully: {len(image_result.annotations)} annotations")
return image_result
else:
logger.warning(f"Image {idx + 1} subgraph returned invalid output")
return SingleImageResult(
image_index=idx,
image_info=ImageInfo(width=0, height=0, dpi=72),
annotations=[]
)
except Exception as e:
logger.error(f"Failed to process image {idx + 1}: {e}", exc_info=True)
return SingleImageResult(
image_index=idx,
image_info=ImageInfo(width=0, height=0, dpi=72),
annotations=[]
)
# 并行处理所有图片
image_results: List[SingleImageResult] = []
with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
# 提交所有任务
future_to_idx = {
executor.submit(process_single_image, idx, img): idx
for idx, img in enumerate(state.homework_images)
}
# 收集结果
results_dict = {}
for future in as_completed(future_to_idx):
idx = future_to_idx[future]
try:
result = future.result()
results_dict[idx] = result
except Exception as e:
logger.error(f"Future failed for image {idx}: {e}")
results_dict[idx] = SingleImageResult(
image_index=idx,
image_info=ImageInfo(width=0, height=0, dpi=72),
annotations=[]
)
# 按原始顺序排序结果
for idx in range(len(state.homework_images)):
image_results.append(results_dict[idx])
logger.info(f"Completed processing {len(image_results)} images")
# 生成最终结果
total_score = 0
full_score = 0
total_questions = 0
correct_count = 0
incorrect_count = 0
for result in image_results:
for annotation in result.annotations:
total_score += annotation.score
full_score += annotation.full_score
total_questions += 1
if annotation.status == "correct":
correct_count += 1
elif annotation.status == "incorrect":
incorrect_count += 1
# 计算得分率
score_rate = (total_score / full_score * 100) if full_score > 0 else 0
# 生成整体评价
if total_questions == 0:
overall_comment = "未识别到题目内容"
grade = "D"
elif score_rate >= 95:
overall_comment = f"优秀!{total_questions}题全部正确,掌握扎实,继续保持。"
grade = "A+"
elif score_rate >= 90:
overall_comment = f"良好!得分率{score_rate:.0f}%,错{incorrect_count}题,注意细节。"
grade = "A"
elif score_rate >= 80:
overall_comment = f"合格。得分率{score_rate:.0f}%,错{incorrect_count}题,需加强练习。"
grade = "B"
elif score_rate >= 70:
overall_comment = f"及格。错{incorrect_count}题,部分知识点掌握不牢,建议复习。"
grade = "C"
else:
overall_comment = f"需努力。得分率{score_rate:.0f}%,建议认真复习,多做练习。"
grade = "D"
final_result = FinalResult(
total_images=len(image_results),
image_results=image_results,
overall_comment=overall_comment,
total_score=total_score,
full_score=full_score,
grade=grade
)
logger.info(f"Final result: {total_score}/{full_score}, grade={grade}")
return ProcessImagesOutput(final_result=final_result)
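上面的等级分支可以抽成纯函数便于单测。以下草图使用与节点相同的硬编码阈值(注意这些阈值与 `GraphInput.grade_standards` 默认值中的 `min_percentage` 重复,调整时需两处同步):

```python
def grade_from_rate(score_rate: float, total_questions: int) -> str:
    """按得分率映射等级,与 process_images_node 中的分支阈值一致。"""
    if total_questions == 0:
        return "D"
    if score_rate >= 95:
        return "A+"
    if score_rate >= 90:
        return "A"
    if score_rate >= 80:
        return "B"
    if score_rate >= 70:
        return "C"
    return "D"
```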

"""一体化识别批改节点合并识别和批改为一次LLM调用提升速度"""
import os
import json
import re
import logging
from typing import List, Dict, Any
from jinja2 import Template
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context
from coze_coding_dev_sdk import LLMClient
from langchain_core.messages import HumanMessage
from graphs.state import (
RecognizeAndCorrectInput,
RecognizeAndCorrectOutput,
QuestionItem,
CorrectionResult,
MarkPosition,
CorrectAnswer
)
logger = logging.getLogger(__name__)
def fix_incomplete_json(text: str) -> str:
    """尝试修复被截断的JSON按未闭合括号的打开顺序逆序补全闭合符
    注意:单纯按数量补全(先补]再补})在数组内有未闭合对象时会产生非法JSON故用栈记录顺序"""
    stack = []
    in_string = False
    escape = False
    for ch in text:
        if escape:
            escape = False
        elif ch == '\\':
            escape = True
        elif ch == '"':
            in_string = not in_string
        elif not in_string:
            if ch in '{[':
                stack.append(ch)
            elif ch in '}]' and stack:
                stack.pop()
    return text + ''.join('}' if ch == '{' else ']' for ch in reversed(stack))
def extract_json_from_text(text: str, key: str = "results") -> dict:
"""从文本中提取JSON对象增强健壮性"""
import orjson
# 清理markdown标记
for prefix in ["```json", "```JSON", "```"]:
if text.startswith(prefix):
text = text[len(prefix):]
for suffix in ["```"]:
if text.endswith(suffix):
text = text[:-3]
text = text.strip()
# 尝试直接解析支持格式化JSON
for parser in [json.loads, lambda x: orjson.loads(x)]:
try:
result = parser(text)
if isinstance(result, dict) and key in result:
logger.info("Successfully parsed JSON directly")
return result
except Exception as e:
logger.debug(f"Direct parse failed: {e}")
# 尝试修复不完整的JSON
try:
fixed_text = fix_incomplete_json(text)
for parser in [json.loads, lambda x: orjson.loads(x)]:
try:
result = parser(fixed_text)
if isinstance(result, dict) and key in result:
logger.info("Successfully parsed fixed JSON")
return result
            except Exception:
                pass
    except Exception:
        pass
# 尝试提取第一个完整的JSON对象
results_pattern = r'"results"\s*:\s*\['
match = re.search(results_pattern, text)
if match:
start_pos = text.rfind('{', 0, match.start())
if start_pos == -1:
start_pos = 0
# 从start_pos开始找到匹配的 }
brace_count = 0
bracket_count = 0
in_string = False
escape = False
end_pos = start_pos
for i in range(start_pos, len(text)):
char = text[i]
if escape:
escape = False
continue
if char == '\\':
escape = True
continue
if char == '"' and not escape:
in_string = not in_string
continue
if in_string:
continue
if char == '{':
brace_count += 1
elif char == '}':
brace_count -= 1
if brace_count == 0 and bracket_count == 0:
end_pos = i + 1
break
elif char == '[':
bracket_count += 1
elif char == ']':
bracket_count -= 1
if end_pos > start_pos:
json_str = text[start_pos:end_pos]
for parser in [json.loads, lambda x: orjson.loads(x)]:
try:
result = parser(json_str)
if isinstance(result, dict) and key in result:
logger.info("Successfully extracted and parsed JSON object")
return result
except Exception as e:
logger.debug(f"Failed to parse extracted JSON: {e}")
# 如果提取失败尝试修复不完整的JSON
if end_pos <= start_pos:
# 找到JSON开始位置
json_str = text[start_pos:]
# 尝试修复
fixed_json = fix_incomplete_json(json_str)
for parser in [json.loads, lambda x: orjson.loads(x)]:
try:
result = parser(fixed_json)
if isinstance(result, dict) and key in result:
logger.info("Successfully parsed fixed incomplete JSON")
return result
            except Exception:
pass
logger.warning(f"Failed to extract JSON with key '{key}' from text length {len(text)}")
return {key: []}
def build_dynamic_prompt(
image_width: int,
image_height: int,
correct_answers: List[CorrectAnswer],
comment_max_length: int
) -> str:
"""构建动态Prompt部分"""
if correct_answers:
answers_text = json.dumps(
[{"question_id": a.question_id, "correct_answer": a.correct_answer}
for a in correct_answers],
ensure_ascii=False
)
answer_hint = f"""
【标准答案】
{answers_text}"""
else:
answer_hint = "\n【批改模式】无标准答案,请根据物理知识判断。"
    return f"""
图片尺寸:{image_width}×{image_height}像素
评语限制:{comment_max_length}字以内
{answer_hint}
必须输出完整JSON不要输出思考过程。"""
def recognize_and_correct_node(
state: RecognizeAndCorrectInput,
config: RunnableConfig,
runtime: Runtime[Context]
) -> RecognizeAndCorrectOutput:
"""
title: 一体化识别批改
    desc: 合并识别和批改为一次LLM调用提升批改速度
integrations: 大语言模型
"""
ctx = runtime.context
# 读取LLM配置
cfg_file = os.path.join(os.getenv("COZE_WORKSPACE_PATH", ""), config["metadata"]["llm_cfg"])
with open(cfg_file, "r", encoding="utf-8") as fd:
_cfg = json.load(fd)
llm_config = _cfg.get("config", {})
sp = _cfg.get("sp", "")
up = _cfg.get("up", "")
# 获取参数
image_url = state.image_url
image_info = state.image_info
correct_answers = state.correct_answers
comment_max_length = getattr(state, 'comment_max_length', 100)
# 使用Jinja2渲染up
up_tpl = Template(up)
user_prompt = up_tpl.render({
"image_url": image_url,
"comment_max_length": comment_max_length
})
# 构建动态部分
dynamic_prompt = build_dynamic_prompt(
image_info.width,
image_info.height,
correct_answers,
comment_max_length
)
# 组合完整Prompt
full_prompt = f"{sp}\n\n{user_prompt}\n{dynamic_prompt}"
# 调用LLM
messages = [
HumanMessage(content=[
{"type": "text", "text": full_prompt},
{"type": "image_url", "image_url": {"url": image_url}}
])
]
client = LLMClient(ctx=ctx)
response = client.invoke(
messages=messages,
model=llm_config.get("model", "doubao-seed-2-0-pro-260215"),
temperature=llm_config.get("temperature", 0.0),
max_completion_tokens=llm_config.get("max_completion_tokens", 4096)
)
response_text = response.content if isinstance(response.content, str) else " ".join(
item.get("text", "") if isinstance(item, dict) else str(item)
for item in response.content
).strip()
# 只记录前500字符避免日志过大
logger.info(f"LLM response (first 500 chars): {response_text[:500]}")
# 解析结果
result_dict = extract_json_from_text(response_text, "results")
# 相对坐标0-1000转换为绝对坐标
width_scale = image_info.width / 1000.0 if image_info.width > 0 else 1.0
height_scale = image_info.height / 1000.0 if image_info.height > 0 else 1.0
question_items: List[QuestionItem] = []
correction_results: List[CorrectionResult] = []
results_list = result_dict.get("results", [])
if not isinstance(results_list, list):
results_list = []
for r in results_list:
try:
if not isinstance(r, dict):
continue
# 解析bbox相对坐标0-1000
answer_bbox = r.get("answer_bbox", [0, 0, 0, 0])
if not isinstance(answer_bbox, list) or len(answer_bbox) != 4:
answer_bbox = [0, 0, 0, 0]
# 转换为绝对坐标
answer_bbox_abs = [
int(answer_bbox[0] * width_scale),
int(answer_bbox[1] * height_scale),
int(answer_bbox[2] * width_scale),
int(answer_bbox[3] * height_scale)
]
# 自动计算mark_position
bbox_width = answer_bbox_abs[2] - answer_bbox_abs[0]
bbox_height = answer_bbox_abs[3] - answer_bbox_abs[1]
if bbox_width > 0 and bbox_height > 0:
mark_x = answer_bbox_abs[2] + 30
mark_y = answer_bbox_abs[1] + int(bbox_height * 0.5)
if mark_x > image_info.width - 50:
mark_x = image_info.width - 50
else:
mark_x = 500
mark_y = 500
mark_position = MarkPosition(x=mark_x, y=mark_y)
# 构建QuestionItem
question_items.append(QuestionItem(
question_id=str(r.get("question_id", "")),
parent_id=str(r.get("parent_id", "")),
is_sub_question=bool(r.get("is_sub_question", False)),
question_text=str(r.get("question_text", "")),
student_answer=str(r.get("student_answer", "")),
answer_bbox=answer_bbox_abs,
mark_position=mark_position,
full_score=int(r.get("full_score", 10) if r.get("full_score") is not None else 10)
))
# 构建CorrectionResult
status = str(r.get("status", "incorrect"))
if status not in ["correct", "incorrect", "partial"]:
status = "incorrect"
# 使用原始comment不做截断由LLM控制长度
comment = str(r.get("comment", ""))
correction_results.append(CorrectionResult(
question_id=str(r.get("question_id", "")),
parent_id=str(r.get("parent_id", "")),
status=status,
score=int(r.get("score", 0) if r.get("score") is not None else 0),
full_score=int(r.get("full_score", 10) if r.get("full_score") is not None else 10),
comment=comment
))
except Exception as e:
logger.warning(f"Failed to parse result: {r}, error: {e}")
continue
logger.info(f"Parsed {len(question_items)} questions and {len(correction_results)} correction results")
return RecognizeAndCorrectOutput(
question_items=question_items,
correction_results=correction_results
)
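坐标换算部分0-1000 相对坐标换算为绝对像素,标记点放在答案框右侧并夹在图内)可以独立验证。以下示意与节点中的换算规则一致:

```python
def to_absolute_bbox(bbox_rel, width, height):
    """把 0-1000 相对坐标按图片实际尺寸换算为绝对像素 [x1, y1, x2, y2]。"""
    wx, hy = width / 1000.0, height / 1000.0
    x1, y1, x2, y2 = bbox_rel
    return [int(x1 * wx), int(y1 * hy), int(x2 * wx), int(y2 * hy)]

def mark_point(bbox_abs, image_width):
    """标记放在答案框右侧 30px、垂直居中x 不超过右边界内 50px。"""
    x1, y1, x2, y2 = bbox_abs
    x = min(x2 + 30, image_width - 50)
    y = y1 + int((y2 - y1) * 0.5)
    return x, y

# 宽 1000、高 1500 的预处理图:宽方向比例为 1.0,高方向为 1.5
bbox = to_absolute_bbox([100, 200, 400, 260], 1000, 1500)
point = mark_point(bbox, 1000)
```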

"""4. 结果整合节点:将识别结果和批改结果合并为最终批注"""
from typing import List
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context
from graphs.state import (
ResultMergeInput,
ResultMergeOutput,
Annotation,
QuestionItem,
CorrectionResult
)
def result_merge_node(
state: ResultMergeInput,
config: RunnableConfig,
runtime: Runtime[Context]
) -> ResultMergeOutput:
"""
title: 结果整合
desc: 将识别结果和批改结果合并为最终批注列表
integrations:
"""
ctx = runtime.context
annotations: List[Annotation] = []
# 创建批改结果字典,方便查找
correction_dict = {c.question_id: c for c in state.correction_results}
for q in state.question_items:
# 查找对应的批改结果
correction = correction_dict.get(q.question_id)
if correction is None:
# 如果没有批改结果,默认为错误
annotations.append(Annotation(
question_id=q.question_id,
parent_id=q.parent_id,
status="incorrect",
question_text=q.question_text,
student_answer=q.student_answer,
answer_bbox=q.answer_bbox,
comment="未找到批改结果",
mark_position=q.mark_position,
score=0,
full_score=q.full_score
))
else:
annotations.append(Annotation(
question_id=q.question_id,
parent_id=q.parent_id,
status=correction.status,
question_text=q.question_text,
student_answer=q.student_answer,
answer_bbox=q.answer_bbox,
comment=correction.comment,
mark_position=q.mark_position,
score=correction.score,
full_score=correction.full_score
))
return ResultMergeOutput(annotations=annotations)
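整合逻辑本质是按 `question_id` 做一次字典连接,缺失批改结果时兜底为 0 分/incorrect。以下是用普通 dict 代替 Pydantic 模型的简化示意:

```python
def merge(question_items, correction_results):
    """按 question_id 连接识别与批改结果;未匹配到批改的题目按错误兜底。"""
    corrections = {c["question_id"]: c for c in correction_results}
    merged = []
    for q in question_items:
        c = corrections.get(q["question_id"])
        merged.append({
            "question_id": q["question_id"],
            "status": c["status"] if c else "incorrect",
            "score": c["score"] if c else 0,
            "comment": c["comment"] if c else "未找到批改结果",
        })
    return merged

out = merge(
    [{"question_id": "1"}, {"question_id": "2"}],
    [{"question_id": "1", "status": "correct", "score": 3, "comment": "正确"}],
)
```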

src/graphs/state.py
"""初中物理作业批改工作流状态定义 - 支持多图片批改"""
from typing import List, Optional, Literal
from pydantic import BaseModel, Field
from utils.file.file import File
# === 基础数据结构 ===
class ImageInfo(BaseModel):
"""图片信息"""
width: int = Field(..., description="图片宽度(像素)")
height: int = Field(..., description="图片高度(像素)")
dpi: int = Field(default=72, description="图片DPI")
class MarkPosition(BaseModel):
"""批改标记位置"""
x: int = Field(..., description="标记中心X坐标绝对像素")
y: int = Field(..., description="标记中心Y坐标绝对像素")
class QuestionItem(BaseModel):
"""题目项(一体化识别结果)"""
question_id: str = Field(..., description="题号如10、10.1、10.2")
parent_id: str = Field(default="", description="母题题号(子题时填写)")
is_sub_question: bool = Field(default=False, description="是否为子题")
question_text: str = Field(default="", description="题目内容")
student_answer: str = Field(default="", description="学生答案内容")
answer_bbox: List[int] = Field(..., description="学生答案区域的边界框 [x1, y1, x2, y2](用于定位标注)")
mark_position: MarkPosition = Field(..., description="批改标记应该标注的位置坐标")
full_score: int = Field(default=10, description="满分")
class CorrectAnswer(BaseModel):
"""标准答案项从Word文档解析"""
question_id: str = Field(..., description="题号如10、10.1、10.2")
parent_id: str = Field(default="", description="母题题号(子题时填写)")
is_sub_question: bool = Field(default=False, description="是否为子题")
question_text: str = Field(default="", description="题目内容/题干")
correct_answer: str = Field(..., description="标准答案")
full_score: int = Field(default=10, description="满分")
answer_analysis: str = Field(default="", description="答案解析/解题步骤")
class CorrectionResult(BaseModel):
"""批改结果"""
question_id: str = Field(..., description="题号")
parent_id: str = Field(default="", description="母题题号")
status: Literal["correct", "incorrect", "partial"] = Field(..., description="批改状态")
score: int = Field(default=0, description="得分")
full_score: int = Field(default=10, description="满分")
comment: str = Field(default="", description="批改评语")
class Annotation(BaseModel):
"""完整批注"""
question_id: str = Field(..., description="题号")
parent_id: str = Field(default="", description="母题题号")
status: Literal["correct", "incorrect", "partial"] = Field(..., description="批改状态")
question_text: str = Field(default="", description="题目内容")
student_answer: str = Field(default="", description="学生答案")
answer_bbox: List[int] = Field(default=[], description="答案区域边界框")
comment: str = Field(default="", description="批改评语")
mark_position: MarkPosition = Field(..., description="批改标记位置")
score: int = Field(default=0, description="得分")
full_score: int = Field(default=10, description="满分")
class SingleImageResult(BaseModel):
"""单张图片的批改结果"""
image_index: int = Field(..., description="图片索引从0开始")
image_info: ImageInfo = Field(..., description="图片信息")
image_url: str = Field(default="", description="处理后的图片URL")
annotations: List[Annotation] = Field(default=[], description="该图片的批注列表")
class FinalResult(BaseModel):
"""最终批改结果(多图片汇总)"""
total_images: int = Field(..., description="总图片数")
image_results: List[SingleImageResult] = Field(default=[], description="各图片的批改结果")
overall_comment: str = Field(default="", description="整体评价")
total_score: int = Field(default=0, description="总分")
full_score: int = Field(default=100, description="满分")
grade: str = Field(default="", description="等级")
# === 全局状态 ===
class GlobalState(BaseModel):
"""工作流全局状态"""
# 输入参数
homework_images: List[File] = Field(default=[], description="上传的作业图片列表")
answer_doc_url: str = Field(default="", description="正确答案Word文件的URL.docx格式")
comment_max_length: int = Field(default=100, description="评语最大字数")
max_concurrent: int = Field(default=10, description="并行批改的最大数量")
grade_standards: dict = Field(default={}, description="评价等级标准")
# 中间状态
correct_answers: List[CorrectAnswer] = Field(default=[], description="从Word解析的标准答案列表")
image_results: List[SingleImageResult] = Field(default=[], description="各图片的批改结果列表")
# 最终结果
final_result: Optional[FinalResult] = Field(default=None, description="最终汇总结果")
# === 图输入输出 ===
class GraphInput(BaseModel):
"""工作流输入"""
homework_images: List[File] = Field(..., description="上传的作业图片列表")
answer_doc_url: str = Field(default="", description="正确答案Word文件的URL.docx格式可选")
comment_max_length: int = Field(default=100, description="评语最大字数默认100字")
max_concurrent: int = Field(default=10, description="并行批改的最大数量默认10")
grade_standards: dict = Field(
default={
"A+": {"min_percentage": 95, "description": "答案全部正确,步骤完整规范,逻辑严谨;书写/格式整洁无错别字、无遗漏完成度100%,态度认真,质量上乘"},
"A": {"min_percentage": 90, "description": "答案完全正确无任何错误步骤合理、格式规范无原则性问题完成度100%,满足全部要求"},
"B": {"min_percentage": 80, "description": "存在少量非关键性错误,或步骤略有缺失;整体思路基本正确,仅细节、格式、计算等小问题;完成大部分内容,整体合格但不够严谨"},
"C": {"min_percentage": 70, "description": "错误较多,部分核心题目作答错误;步骤不完整、逻辑不够清晰;完成度一般,有明显应付、漏答情况"},
"D": {"min_percentage": 0, "description": "大面积错误,核心知识点未掌握;大量空白、敷衍、抄袭;未达到基本完成要求"}
},
description="评价等级标准,包含各等级的最低得分率百分比和描述"
)
class GraphOutput(BaseModel):
"""工作流输出"""
final_result: FinalResult = Field(..., description="最终批改结果JSON包含多图片")
# === 文档答案解析节点 ===
class DocExtractInput(BaseModel):
"""文档答案解析节点输入"""
answer_doc_url: str = Field(default="", description="正确答案Word文件的URL.docx格式可选")
class DocExtractOutput(BaseModel):
"""文档答案解析节点输出"""
correct_answers: List[CorrectAnswer] = Field(default=[], description="解析的标准答案列表")
# === 子图状态定义(单图片处理流程) ===
class SubgraphState(BaseModel):
"""子图全局状态(处理单张图片)"""
# 输入
homework_image: File = Field(..., description="当前处理的作业图片")
correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表")
image_index: int = Field(default=0, description="图片索引")
comment_max_length: int = Field(default=100, description="评语最大字数")
# 中间状态
image_info: ImageInfo = Field(default_factory=lambda: ImageInfo(width=0, height=0, dpi=72), description="图片信息")
image_url: str = Field(default="", description="处理后的图片URL")
question_items: List[QuestionItem] = Field(default=[], description="识别的题目项列表")
correction_results: List[CorrectionResult] = Field(default=[], description="批改结果列表")
annotations: List[Annotation] = Field(default=[], description="最终批注列表")
class SubgraphInput(BaseModel):
"""子图输入"""
homework_image: File = Field(..., description="当前处理的作业图片")
correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表")
image_index: int = Field(default=0, description="图片索引")
comment_max_length: int = Field(default=100, description="评语最大字数")
class SubgraphOutput(BaseModel):
"""子图输出"""
image_result: SingleImageResult = Field(..., description="单张图片的批改结果")
# === 子图节点输入输出 ===
# 1. 图像预处理节点
class ImagePreprocessInput(BaseModel):
"""图像预处理节点输入"""
homework_image: File = Field(..., description="上传的作业图片")
class ImagePreprocessOutput(BaseModel):
"""图像预处理节点输出"""
image_info: ImageInfo = Field(..., description="图片信息")
image_url: str = Field(..., description="处理后的图片URL")
# 2. 一体化识别批改节点(合并识别+批改)
class RecognizeAndCorrectInput(BaseModel):
"""一体化识别批改节点输入"""
image_url: str = Field(..., description="图片URL")
image_info: ImageInfo = Field(..., description="图片信息")
correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表")
comment_max_length: int = Field(default=100, description="评语最大字数")
class RecognizeAndCorrectOutput(BaseModel):
"""一体化识别批改节点输出"""
question_items: List[QuestionItem] = Field(default=[], description="识别的题目项列表")
correction_results: List[CorrectionResult] = Field(default=[], description="批改结果列表")
# 3. 批改判断节点
class CorrectionJudgeInput(BaseModel):
"""批改判断节点输入"""
question_items: List[QuestionItem] = Field(default=[], description="题目项列表")
correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表")
comment_max_length: int = Field(default=100, description="评语最大字数")
class CorrectionJudgeOutput(BaseModel):
"""批改判断节点输出"""
correction_results: List[CorrectionResult] = Field(default=[], description="批改结果列表")
# 4. 结果整合节点
class ResultMergeInput(BaseModel):
"""结果整合节点输入"""
question_items: List[QuestionItem] = Field(default=[], description="题目项列表")
correction_results: List[CorrectionResult] = Field(default=[], description="批改结果列表")
class ResultMergeOutput(BaseModel):
"""结果整合节点输出"""
annotations: List[Annotation] = Field(default=[], description="最终批注列表")
# === 循环节点 ===
class ProcessImagesInput(BaseModel):
"""多图片处理循环节点输入"""
homework_images: List[File] = Field(default=[], description="作业图片列表")
correct_answers: List[CorrectAnswer] = Field(default=[], description="标准答案列表")
comment_max_length: int = Field(default=100, description="评语最大字数")
max_concurrent: int = Field(default=10, description="并行批改的最大数量")
class ProcessImagesOutput(BaseModel):
"""多图片处理循环节点输出"""
final_result: FinalResult = Field(..., description="最终批改结果")

src/main.py
import argparse
import asyncio
import json
import threading
import traceback
import logging
from typing import Any, Dict, Iterable, AsyncIterable, AsyncGenerator, Optional
import cozeloop
import uvicorn
import time
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from coze_coding_utils.runtime_ctx.context import new_context, Context
from coze_coding_utils.helper import graph_helper
from coze_coding_utils.log.node_log import LOG_FILE
from coze_coding_utils.log.write_log import setup_logging, request_context
from coze_coding_utils.log.config import LOG_LEVEL
from coze_coding_utils.error.classifier import ErrorClassifier, classify_error
from coze_coding_utils.helper.stream_runner import AgentStreamRunner, WorkflowStreamRunner, agent_stream_handler, workflow_stream_handler, RunOpt
setup_logging(
log_file=LOG_FILE,
max_bytes=100 * 1024 * 1024, # 100MB
backup_count=5,
log_level=LOG_LEVEL,
use_json_format=True,
console_output=True
)
logger = logging.getLogger(__name__)
from coze_coding_utils.helper.agent_helper import to_stream_input
from coze_coding_utils.openai.handler import OpenAIChatHandler
from coze_coding_utils.log.parser import LangGraphParser
from coze_coding_utils.log.err_trace import extract_core_stack
from coze_coding_utils.log.loop_trace import init_run_config, init_agent_config
# 超时配置常量
TIMEOUT_SECONDS = 900 # 15分钟
class GraphService:
def __init__(self):
# 用于跟踪正在运行的任务使用asyncio.Task
self.running_tasks: Dict[str, asyncio.Task] = {}
# 错误分类器
self.error_classifier = ErrorClassifier()
# stream runner
self._agent_stream_runner = AgentStreamRunner()
self._workflow_stream_runner = WorkflowStreamRunner()
self._graph = None
self._graph_lock = threading.Lock()
def _get_graph(self, ctx=Context):
if graph_helper.is_agent_proj():
return graph_helper.get_agent_instance("agents.agent", ctx)
if self._graph is not None:
return self._graph
with self._graph_lock:
if self._graph is not None:
return self._graph
self._graph = graph_helper.get_graph_instance("graphs.graph")
return self._graph
@staticmethod
def _sse_event(data: Any, event_id: Any = None) -> str:
id_line = f"id: {event_id}\n" if event_id else ""
return f"{id_line}event: message\ndata: {json.dumps(data, ensure_ascii=False, default=str)}\n\n"
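`_sse_event` 输出的是标准 SSE 帧:可选 `id:` 行、`event:` 行、`data:` 行,以空行结束一帧。独立示意如下:

```python
import json

def sse_event(data, event_id=None) -> str:
    """与 _sse_event 相同的帧格式event_id 为空时省略 id 行。"""
    id_line = f"id: {event_id}\n" if event_id else ""
    return f"{id_line}event: message\ndata: {json.dumps(data, ensure_ascii=False, default=str)}\n\n"

frame = sse_event({"status": "ok"}, event_id=7)
```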
def _get_stream_runner(self):
if graph_helper.is_agent_proj():
return self._agent_stream_runner
else:
return self._workflow_stream_runner
# 流式运行(原始迭代器):本地调用使用
def stream(self, payload: Dict[str, Any], run_config: RunnableConfig, ctx=Context) -> Iterable[Any]:
graph = self._get_graph(ctx)
stream_runner = self._get_stream_runner()
for chunk in stream_runner.stream(payload, graph, run_config, ctx):
yield chunk
# 同步运行:本地/HTTP 通用
async def run(self, payload: Dict[str, Any], ctx=None) -> Dict[str, Any]:
if ctx is None:
ctx = new_context("run")
run_id = ctx.run_id
logger.info(f"Starting run with run_id: {run_id}")
try:
graph = self._get_graph(ctx)
# custom tracer
run_config = init_run_config(graph, ctx)
run_config["configurable"] = {"thread_id": ctx.run_id}
# 直接调用LangGraph会在当前任务上下文中执行
# 如果当前任务被取消LangGraph的执行也会被取消
return await graph.ainvoke(payload, config=run_config, context=ctx)
except asyncio.CancelledError:
logger.info(f"Run {run_id} was cancelled")
return {"status": "cancelled", "run_id": run_id, "message": "Execution was cancelled"}
except Exception as e:
# 使用错误分类器分类错误
err = self.error_classifier.classify(e, {"node_name": "run", "run_id": run_id})
# 记录详细的错误信息和堆栈跟踪
logger.error(
f"Error in GraphService.run: [{err.code}] {err.message}\n"
f"Category: {err.category.name}\n"
f"Traceback:\n{extract_core_stack()}"
)
# 保留原始异常堆栈,便于上层返回真正的报错位置
raise
finally:
# 清理任务记录
self.running_tasks.pop(run_id, None)
# 流式运行SSE 格式化HTTP 路由使用
async def stream_sse(self, payload: Dict[str, Any], ctx=None, run_opt: Optional[RunOpt] = None) -> AsyncGenerator[str, None]:
if ctx is None:
ctx = new_context(method="stream_sse")
if run_opt is None:
run_opt = RunOpt()
run_id = ctx.run_id
logger.info(f"Starting stream with run_id: {run_id}")
graph = self._get_graph(ctx)
if graph_helper.is_agent_proj():
run_config = init_agent_config(graph, ctx)
else:
run_config = init_run_config(graph, ctx) # vibeflow
is_workflow = not graph_helper.is_agent_proj()
try:
async for chunk in self.astream(payload, graph, run_config=run_config, ctx=ctx, run_opt=run_opt):
if is_workflow and isinstance(chunk, tuple):
event_id, data = chunk
yield self._sse_event(data, event_id)
else:
yield self._sse_event(chunk)
finally:
# 清理任务记录
self.running_tasks.pop(run_id, None)
cozeloop.flush()
# 取消执行 - 使用asyncio的标准方式
def cancel_run(self, run_id: str, ctx: Optional[Context] = None) -> Dict[str, Any]:
"""
取消指定run_id的执行
使用asyncio.Task.cancel()来取消任务,这是标准的Python异步取消机制
LangGraph会在节点之间检查CancelledError,实现优雅的取消
"""
logger.info(f"Attempting to cancel run_id: {run_id}")
# 查找对应的任务
if run_id in self.running_tasks:
task = self.running_tasks[run_id]
if not task.done():
# 使用asyncio的标准取消机制
# 这会在下一个await点抛出CancelledError
task.cancel()
logger.info(f"Cancellation requested for run_id: {run_id}")
return {
"status": "success",
"run_id": run_id,
"message": "Cancellation signal sent, task will be cancelled at next await point"
}
else:
logger.info(f"Task already completed for run_id: {run_id}")
return {
"status": "already_completed",
"run_id": run_id,
"message": "Task has already completed"
}
else:
logger.warning(f"No active task found for run_id: {run_id}")
return {
"status": "not_found",
"run_id": run_id,
"message": "No active task found with this run_id. Task may have already completed or run_id is invalid."
}
# Run a single node: shared by the local and HTTP entry points
async def run_node(self, node_id: str, payload: Dict[str, Any], ctx=None) -> Any:
if ctx is None or ctx.run_id == "":
ctx = new_context(method="node_run")
_graph = self._get_graph()
node_func, input_cls, output_cls = graph_helper.get_graph_node_func_with_inout(_graph.get_graph(), node_id)
if node_func is None or input_cls is None:
raise KeyError(f"node_id '{node_id}' not found")
parser = LangGraphParser(_graph)
metadata = parser.get_node_metadata(node_id) or {}
_g = StateGraph(input_cls, input_schema=input_cls, output_schema=output_cls)
_g.add_node("sn", node_func, metadata=metadata)
_g.set_entry_point("sn")
_g.add_edge("sn", END)
_graph = _g.compile()
run_config = init_run_config(_graph, ctx)
return await _graph.ainvoke(payload, config=run_config)
def graph_inout_schema(self) -> Any:
if graph_helper.is_agent_proj():
return {"input_schema": {}, "output_schema": {}}
builder = getattr(self._get_graph(), 'builder', None)
if builder is not None:
input_cls = getattr(builder, 'input_schema', None) or self.graph.get_input_schema()
output_cls = getattr(builder, 'output_schema', None) or self.graph.get_output_schema()
else:
logger.warning(f"No builder input schema found for graph_inout_schema, using graph input schema instead")
input_cls = self.graph.get_input_schema()
output_cls = self.graph.get_output_schema()
return {
"input_schema": input_cls.model_json_schema(),
"output_schema": output_cls.model_json_schema(),
"code":0,
"msg":""
}
async def astream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx: Optional[Context] = None, run_opt: Optional[RunOpt] = None) -> AsyncIterable[Any]:
stream_runner = self._get_stream_runner()
async for chunk in stream_runner.astream(payload, graph, run_config, ctx, run_opt):
yield chunk
service = GraphService()
app = FastAPI()
# OpenAI-compatible endpoint handler
openai_handler = OpenAIChatHandler(service)
HEADER_X_RUN_ID = "x-run-id"
@app.post("/run")
async def http_run(request: Request) -> Dict[str, Any]:
raw_body = await request.body()
try:
body_text = raw_body.decode("utf-8")
except Exception as e:
body_text = str(raw_body)
raise HTTPException(status_code=400,
detail=f"Invalid JSON format: {body_text}, traceback: {traceback.format_exc()}, error: {e}")
ctx = new_context(method="run", headers=request.headers)
# Prefer the run_id supplied upstream so that cancel can match it exactly
upstream_run_id = request.headers.get(HEADER_X_RUN_ID)
if upstream_run_id:
ctx.run_id = upstream_run_id
run_id = ctx.run_id
request_context.set(ctx)
logger.info(
f"Received request for /run: "
f"run_id={run_id}, "
f"query={dict(request.query_params)}, "
f"body={body_text}"
)
try:
payload = await request.json()
# Create and register the task - the key step that lets us cancel it by run_id
task = asyncio.create_task(service.run(payload, ctx))
service.running_tasks[run_id] = task
try:
result = await asyncio.wait_for(task, timeout=float(TIMEOUT_SECONDS))
except asyncio.TimeoutError:
logger.error(f"Run execution timeout after {TIMEOUT_SECONDS}s for run_id: {run_id}")
task.cancel()
try:
result = await task
except asyncio.CancelledError:
return {
"status": "timeout",
"run_id": run_id,
"message": f"Execution timeout: exceeded {TIMEOUT_SECONDS} seconds"
}
if not result:
result = {}
if isinstance(result, dict):
result["run_id"] = run_id
return result
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in http_run: {e}, traceback: {traceback.format_exc()}")
raise HTTPException(status_code=400, detail=f"Invalid JSON format, {extract_core_stack()}")
except asyncio.CancelledError:
logger.info(f"Request cancelled for run_id: {run_id}")
result = {"status": "cancelled", "run_id": run_id, "message": "Execution was cancelled"}
return result
except Exception as e:
# Use the error classifier to build the error details
error_response = service.error_classifier.get_error_response(e, {"node_name": "http_run", "run_id": run_id})
logger.error(
f"Unexpected error in http_run: [{error_response['error_code']}] {error_response['error_message']}, "
f"traceback: {traceback.format_exc()}", exc_info=True
)
raise HTTPException(
status_code=500,
detail={
"error_code": error_response["error_code"],
"error_message": error_response["error_message"],
"stack_trace": extract_core_stack(),
}
)
finally:
cozeloop.flush()
HEADER_X_WORKFLOW_STREAM_MODE = "x-workflow-stream-mode"
def _register_task(run_id: str, task: asyncio.Task):
service.running_tasks[run_id] = task
@app.post("/stream_run")
async def http_stream_run(request: Request):
ctx = new_context(method="stream_run", headers=request.headers)
# Prefer the run_id supplied upstream so that cancel can match it exactly
upstream_run_id = request.headers.get(HEADER_X_RUN_ID)
if upstream_run_id:
ctx.run_id = upstream_run_id
workflow_stream_mode = request.headers.get(HEADER_X_WORKFLOW_STREAM_MODE, "").lower()
workflow_debug = workflow_stream_mode == "debug"
request_context.set(ctx)
raw_body = await request.body()
try:
body_text = raw_body.decode("utf-8")
except Exception as e:
body_text = str(raw_body)
raise HTTPException(status_code=400,
detail=f"Invalid JSON format: {body_text}, traceback: {extract_core_stack()}, error: {e}")
run_id = ctx.run_id
is_agent = graph_helper.is_agent_proj()
logger.info(
f"Received request for /stream_run: "
f"run_id={run_id}, "
f"is_agent_project={is_agent}, "
f"query={dict(request.query_params)}, "
f"body={body_text}"
)
try:
payload = await request.json()
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in http_stream_run: {e}, traceback: {traceback.format_exc()}")
raise HTTPException(status_code=400, detail=f"Invalid JSON format:{extract_core_stack()}")
if is_agent:
stream_generator = agent_stream_handler(
payload=payload,
ctx=ctx,
run_id=run_id,
stream_sse_func=service.stream_sse,
sse_event_func=service._sse_event,
error_classifier=service.error_classifier,
register_task_func=_register_task,
)
else:
stream_generator = workflow_stream_handler(
payload=payload,
ctx=ctx,
run_id=run_id,
stream_sse_func=service.stream_sse,
sse_event_func=service._sse_event,
error_classifier=service.error_classifier,
register_task_func=_register_task,
run_opt=RunOpt(workflow_debug=workflow_debug),
)
response = StreamingResponse(stream_generator, media_type="text/event-stream")
return response
@app.post("/cancel/{run_id}")
async def http_cancel(run_id: str, request: Request):
"""
Cancel the execution identified by run_id.
Cancellation uses asyncio.Task.cancel(), Python's standard async task cancellation mechanism.
LangGraph checks for CancelledError at await points between nodes, enabling graceful cancellation.
"""
ctx = new_context(method="cancel", headers=request.headers)
request_context.set(ctx)
logger.info(f"Received cancel request for run_id: {run_id}")
result = service.cancel_run(run_id, ctx)
return result
@app.post(path="/node_run/{node_id}")
async def http_node_run(node_id: str, request: Request):
raw_body = await request.body()
try:
body_text = raw_body.decode("utf-8")
except UnicodeDecodeError:
body_text = str(raw_body)
raise HTTPException(status_code=400, detail=f"Invalid JSON format: {body_text}")
ctx = new_context(method="node_run", headers=request.headers)
request_context.set(ctx)
logger.info(
f"Received request for /node_run/{node_id}: "
f"query={dict(request.query_params)}, "
f"body={body_text}",
)
try:
payload = await request.json()
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in http_node_run: {e}, traceback: {traceback.format_exc()}")
raise HTTPException(status_code=400, detail=f"Invalid JSON format:{extract_core_stack()}")
try:
return await service.run_node(node_id, payload, ctx)
except KeyError:
raise HTTPException(status_code=404,
detail=f"node_id '{node_id}' not found or input miss required fields, traceback: {extract_core_stack()}")
except Exception as e:
# Use the error classifier to build the error details
error_response = service.error_classifier.get_error_response(e, {"node_name": node_id})
logger.error(
f"Unexpected error in http_node_run: [{error_response['error_code']}] {error_response['error_message']}, "
f"traceback: {traceback.format_exc()}", exc_info=True
)
raise HTTPException(
status_code=500,
detail={
"error_code": error_response["error_code"],
"error_message": error_response["error_message"],
"stack_trace": extract_core_stack(),
}
)
finally:
cozeloop.flush()
@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request):
"""OpenAI Chat Completions API 兼容接口"""
ctx = new_context(method="openai_chat", headers=request.headers)
request_context.set(ctx)
logger.info(f"Received request for /v1/chat/completions: run_id={ctx.run_id}")
try:
payload = await request.json()
return await openai_handler.handle(payload, ctx)
except json.JSONDecodeError as e:
logger.error(f"JSON decode error in openai_chat_completions: {e}")
raise HTTPException(status_code=400, detail="Invalid JSON format")
finally:
cozeloop.flush()
@app.get("/health")
async def health_check():
try:
# More health-check logic can be added here
return {
"status": "ok",
"message": "Service is running",
}
except Exception as e:
raise HTTPException(status_code=503, detail=str(e))
@app.get(path="/graph_parameter")
async def http_graph_inout_parameter(request: Request):
return service.graph_inout_schema()
def parse_args():
parser = argparse.ArgumentParser(description="Start FastAPI server")
parser.add_argument("-m", type=str, default="http", help="Run mode, support http,flow,node")
parser.add_argument("-n", type=str, default="", help="Node ID for single node run")
parser.add_argument("-p", type=int, default=5000, help="HTTP server port")
parser.add_argument("-i", type=str, default="", help="Input JSON string for flow/node mode")
return parser.parse_args()
def parse_input(input_str: str) -> Dict[str, Any]:
"""Parse input string, support both JSON string and plain text"""
if not input_str:
return {"text": "你好"}
# Try to parse as JSON first
try:
return json.loads(input_str)
except json.JSONDecodeError:
# If not valid JSON, treat as plain text
return {"text": input_str}
def start_http_server(port):
workers = 1
reload = False
if graph_helper.is_dev_env():
reload = True
logger.info(f"Start HTTP Server, Port: {port}, Workers: {workers}")
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload, workers=workers)
if __name__ == "__main__":
args = parse_args()
if args.m == "http":
start_http_server(args.p)
elif args.m == "flow":
payload = parse_input(args.i)
result = asyncio.run(service.run(payload))
print(json.dumps(result, ensure_ascii=False, indent=2))
elif args.m == "node" and args.n:
payload = parse_input(args.i)
result = asyncio.run(service.run_node(args.n, payload))
print(json.dumps(result, ensure_ascii=False, indent=2))
elif args.m == "agent":
agent_ctx = new_context(method="agent")
for chunk in service.stream(
{
"type": "query",
"session_id": "1",
"message": "你好",
"content": {
"query": {
"prompt": [
{
"type": "text",
"content": {"text": "现在几点了?请调用工具获取当前时间"},
}
]
}
},
},
run_config={"configurable": {"session_id": "1"}},
ctx=agent_ctx,
):
print(chunk)
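The /run and /cancel/{run_id} routes above cooperate through the x-run-id header: a caller that supplies its own run_id can later cancel exactly that task. A minimal client-side sketch of assembling such a request (build_run_request and the base URL are illustrative assumptions, not part of this service):

```python
import json

HEADER_X_RUN_ID = "x-run-id"

def build_run_request(base_url: str, run_id: str, payload: dict):
    """Build the URL, headers, and body for a POST /run pinned to run_id,
    plus the matching cancel URL."""
    base = base_url.rstrip("/")
    url = f"{base}/run"
    # The server copies x-run-id into ctx.run_id, so /cancel can find the task.
    headers = {"Content-Type": "application/json", HEADER_X_RUN_ID: run_id}
    body = json.dumps(payload, ensure_ascii=False)
    cancel_url = f"{base}/cancel/{run_id}"
    return url, headers, body, cancel_url
```

Any HTTP client can then POST `body` to `url` with `headers`, and POST to `cancel_url` to request cancellation.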

0
src/storage/__init__.py Normal file
View File

@ -0,0 +1,94 @@
import os
import time
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import OperationalError
import logging
logger = logging.getLogger(__name__)
MAX_RETRY_TIME = 20  # maximum total time to retry the connection (seconds)
# Load environment variables from .env if present
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
pass
def get_db_url() -> str:
"""Build database URL from environment."""
url = os.getenv("PGDATABASE_URL") or ""
if url is not None and url != "":
return url
from coze_workload_identity import Client
try:
client = Client()
env_vars = client.get_project_env_vars()
client.close()
for env_var in env_vars:
if env_var.key == "PGDATABASE_URL":
url = env_var.value.replace("'", "'\\''")
return url
except Exception as e:
logger.error(f"Error loading PGDATABASE_URL: {e}")
raise e
finally:
if url is None or url == "":
logger.error("PGDATABASE_URL is not set")
return url
_engine = None
_SessionLocal = None
def _create_engine_with_retry():
url = get_db_url()
if url is None or url == "":
logger.error("PGDATABASE_URL is not set")
raise ValueError("PGDATABASE_URL is not set")
size = 100
overflow = 100
recycle = 1800
timeout = 30
engine = create_engine(
url,
pool_size=size,
max_overflow=overflow,
pool_pre_ping=True,
pool_recycle=recycle,
pool_timeout=timeout,
)
# Verify the connection, with retries
start_time = time.time()
last_error = None
while time.time() - start_time < MAX_RETRY_TIME:
try:
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
return engine
except OperationalError as e:
last_error = e
elapsed = time.time() - start_time
logger.warning(f"Database connection failed, retrying... (elapsed: {elapsed:.1f}s)")
time.sleep(min(1, MAX_RETRY_TIME - elapsed))
logger.error(f"Database connection failed after {MAX_RETRY_TIME}s: {last_error}")
raise last_error # pyright: ignore [reportGeneralTypeIssues]
def get_engine():
global _engine
if _engine is None:
_engine = _create_engine_with_retry()
return _engine
def get_sessionmaker():
global _SessionLocal
if _SessionLocal is None:
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=get_engine())
return _SessionLocal
def get_session():
return get_sessionmaker()()
__all__ = [
"get_db_url",
"get_engine",
"get_sessionmaker",
"get_session",
]
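get_session() returns a live Session from the pooled engine and leaves closing it to the caller. A minimal liveness-check sketch with the session factory injected so the pattern is visible without a real database (db_healthcheck is an illustrative helper, not part of this module; in real use the factory would be get_sessionmaker()):

```python
from contextlib import closing

def db_healthcheck(session_factory) -> bool:
    """Run SELECT 1 through a session and report liveness.

    session_factory is any zero-argument callable returning an object
    with execute() and close(), such as get_sessionmaker() above.
    """
    with closing(session_factory()) as session:
        # With SQLAlchemy 2.x, a real call would wrap the SQL in sqlalchemy.text().
        row = session.execute("SELECT 1")
        return row.scalar() == 1
```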

View File

@ -0,0 +1,8 @@
from sqlalchemy import BigInteger, DateTime, Identity, Index, Integer, JSON, PrimaryKeyConstraint, Text, text
from typing import Optional
import datetime
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass

View File

@ -0,0 +1,135 @@
import psycopg
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.base import BaseCheckpointSaver
from typing import Optional, Union
import logging
import time
logger = logging.getLogger(__name__)
# Database connection timeout (seconds): 15s per attempt, 2 attempts in total
DB_CONNECTION_TIMEOUT = 15
DB_MAX_RETRIES = 2
class MemoryManager:
"""Memory Manager 单例类"""
_instance: Optional['MemoryManager'] = None
_checkpointer: Optional[Union[AsyncPostgresSaver, MemorySaver]] = None
_pool: Optional[AsyncConnectionPool] = None
_setup_done: bool = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def _connect_with_retry(self, db_url: str) -> Optional[psycopg.Connection]:
"""带重试的数据库连接,每次 15 秒超时,共尝试 2 次"""
last_error = None
for attempt in range(1, DB_MAX_RETRIES + 1):
try:
logger.info(f"Attempting database connection (attempt {attempt}/{DB_MAX_RETRIES})")
conn = psycopg.connect(db_url, autocommit=True, connect_timeout=DB_CONNECTION_TIMEOUT)
logger.info(f"Database connection established on attempt {attempt}")
return conn
except Exception as e:
last_error = e
logger.warning(f"Database connection attempt {attempt} failed: {e}")
if attempt < DB_MAX_RETRIES:
time.sleep(1)  # brief pause before retrying
logger.error(f"All {DB_MAX_RETRIES} database connection attempts failed, last error: {last_error}")
return None
def _setup_schema_and_tables(self, db_url: str) -> bool:
"""同步创建 schema 和表(只执行一次),返回是否成功"""
if self._setup_done:
return True
conn = self._connect_with_retry(db_url)
if conn is None:
return False
try:
with conn.cursor() as cur:
cur.execute("CREATE SCHEMA IF NOT EXISTS memory")
conn.execute("SET search_path TO memory")
PostgresSaver(conn).setup()
self._setup_done = True
logger.info("Memory schema and tables created")
return True
except Exception as e:
logger.warning(f"Failed to setup schema/tables: {e}")
return False
finally:
conn.close()
def _get_db_url_safe(self) -> Optional[str]:
"""安全获取 db_url失败时返回 None"""
try:
from storage.database.db import get_db_url
db_url = get_db_url()
if db_url and db_url.strip():
return db_url
logger.warning("db_url is empty, will fallback to MemorySaver")
return None
except Exception as e:
logger.warning(f"Failed to get db_url: {e}, will fallback to MemorySaver")
return None
def _create_fallback_checkpointer(self) -> MemorySaver:
"""创建内存兜底 checkpointer"""
self._checkpointer = MemorySaver()
logger.warning("Using MemorySaver as fallback checkpointer (data will not persist across restarts)")
return self._checkpointer
def get_checkpointer(self) -> BaseCheckpointSaver:
"""获取 checkpointer优先使用 PostgresSaver失败时退化为 MemorySaver"""
if self._checkpointer is not None:
return self._checkpointer
# 1. Try to fetch db_url
db_url = self._get_db_url_safe()
if not db_url:
return self._create_fallback_checkpointer()
# 2. Try to connect and create the schema/tables (with retries)
if not self._setup_schema_and_tables(db_url):
return self._create_fallback_checkpointer()
# 3. Append search_path to the connection string
if "?" in db_url:
db_url = f"{db_url}&options=-csearch_path%3Dmemory"
else:
db_url = f"{db_url}?options=-csearch_path%3Dmemory"
# 4. Try to create the connection pool and the checkpointer
try:
self._pool = AsyncConnectionPool(
conninfo=db_url,
timeout=DB_CONNECTION_TIMEOUT,
min_size=1,
max_idle=300,
check=AsyncConnectionPool.check_connection,
)
self._checkpointer = AsyncPostgresSaver(self._pool)
logger.info("AsyncPostgresSaver initialized successfully")
except Exception as e:
logger.warning(f"Failed to create AsyncPostgresSaver: {e}, will fallback to MemorySaver")
return self._create_fallback_checkpointer()
return self._checkpointer
_memory_manager: Optional[MemoryManager] = None
def get_memory_saver() -> BaseCheckpointSaver:
"""获取 checkpointer优先使用 PostgresSaverdb_url 不可用或连接失败时退化为 MemorySaver"""
global _memory_manager
if _memory_manager is None:
_memory_manager = MemoryManager()
return _memory_manager.get_checkpointer()
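Step 3 in get_checkpointer() pins all checkpoint tables to the memory schema by appending a search_path option to the connection string. The rewrite rule, isolated here as a pure function for illustration (with_memory_search_path is a hypothetical name; the real code does this inline):

```python
def with_memory_search_path(db_url: str) -> str:
    # Append options=-csearch_path%3Dmemory, respecting any existing query string,
    # mirroring the inline rewrite in MemoryManager.get_checkpointer().
    sep = "&" if "?" in db_url else "?"
    return f"{db_url}{sep}options=-csearch_path%3Dmemory"
```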

View File

@ -0,0 +1,424 @@
import os
import re
from pathlib import Path
from typing import Optional, Any, Dict, List, TypedDict, Iterable
from uuid import uuid4
import boto3
from botocore.exceptions import ClientError
from boto3.s3.transfer import TransferConfig
import logging
logger = logging.getLogger(__name__)
# Allowed file-name character set (constraint on user input)
FILE_NAME_ALLOWED_RE = re.compile(r"^[A-Za-z0-9._\-/]+$")
class ListFilesResult(TypedDict):
# Return type of list_files
keys: List[str]
is_truncated: bool
next_continuation_token: Optional[str]
class S3SyncStorage:
"""S3兼容存储实现"""
def __init__(self, *, endpoint_url: Optional[str] = None, access_key: str, secret_key: str, bucket_name: str, region: str = "cn-beijing"):
self.endpoint_url = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or endpoint_url or ''
self.access_key = access_key
self.secret_key = secret_key
self.bucket_name = bucket_name
self.region = region
self._client = None
def _get_client(self):
if self._client is None:
endpoint = self.endpoint_url
if endpoint is None or endpoint == "":
try:
from coze_workload_identity import Client as CozeEnvClient
coze_env_client = CozeEnvClient()
env_vars = coze_env_client.get_project_env_vars()
coze_env_client.close()
for env_var in env_vars:
if env_var.key == "COZE_BUCKET_ENDPOINT_URL":
endpoint = env_var.value.replace("'", "'\\''")
self.endpoint_url = endpoint
break
except Exception as e:
logger.error(f"Error loading COZE_BUCKET_ENDPOINT_URL: {e}")
# Fall through to the validation below instead of aborting here
if endpoint is None or endpoint == "":
logger.error("未配置存储端点请设置endpoint_url")
raise ValueError("未配置存储端点请设置endpoint_url")
client = boto3.client(
"s3",
endpoint_url=endpoint,
aws_access_key_id=self.access_key,
aws_secret_access_key=self.secret_key,
region_name=self.region,
)
# Register a before-call hook that injects the x-storage-token header before each request
def _inject_header(**kwargs):
try:
from coze_workload_identity import Client as CozeClient
coze_client = CozeClient()
try:
token = coze_client.get_access_token()
except Exception as e:
logger.error("Error loading COZE_WORKLOAD_IDENTITY_TOKEN: %s", e)
token = None
raise e
finally:
coze_client.close()
params = kwargs.get("params", {})
headers = params.setdefault("headers", {})
headers["x-storage-token"] = token
except Exception as e:
logger.error("Error loading COZE_WORKLOAD_IDENTITY_TOKEN: %s", e)
pass
client.meta.events.register("before-call.s3", _inject_header)
self._client = client
return self._client
def _generate_object_key(self, *, original_name: str) -> str:
suffix = Path(original_name).suffix.lower()
stem = Path(original_name).stem
uniq = uuid4().hex[:8]
return f"{stem}_{uniq}{suffix}"
def _extract_logid(self, e: Exception) -> Optional[str]:
"""从 ClientError 中提取 x-tt-logid"""
if isinstance(e, ClientError):
headers = (e.response or {}).get("ResponseMetadata", {}).get("HTTPHeaders", {})
return headers.get("x-tt-logid")
return None
def _error_msg(self, msg: str, e: Exception) -> str:
"""构建带 logid 的错误信息"""
logid = self._extract_logid(e)
if logid:
return f"{msg}: {e} (x-tt-logid: {logid})"
return f"{msg}: {e}"
def _resolve_bucket(self, bucket: Optional[str]) -> str:
"""统一解析 bucket 来源,确保得到有效桶名。"""
target_bucket = bucket or os.environ.get("COZE_BUCKET_NAME") or self.bucket_name
if not target_bucket:
raise ValueError("未配置 bucket请传入 bucket 或设置 COZE_BUCKET_NAME或在实例化时提供 bucket_name")
return target_bucket
def _validate_file_name(self, name: str) -> None:
"""校验 S3 对象命名长度≤1024允许 [A-Za-z0-9._-/];不以 / 起止且不含 //。"""
msg = (
"file name invalid: 文件名需满足以下 S3 对象命名规范:"
"1) 长度 1-1024 字节;"
"2) 仅允许字母、数字、点(.)、下划线(_)、短横(-)、目录分隔符(/);"
"3) 不允许空格或以下特殊字符:? # & % { } ^ [ ] ` \\ < > ~ | \" ' + = : ;"
"4) 不以 / 开头或结尾,且不包含连续的 //;"
"示例:report_2025-12-11.pdf、images/photo-01.png。"
)
if not name or not name.strip():
raise ValueError(msg + "(原因:文件名为空)")
# S3 limits object keys to 1024 bytes; apply the same limit to the input file name
if len(name.encode("utf-8")) > 1024:
raise ValueError(msg + "(原因:长度超过 1024 字节)")
if name.startswith("/") or name.endswith("/"):
raise ValueError(msg + "(原因:以 / 开头或结尾)")
if "//" in name:
raise ValueError(msg + "(原因:包含连续的 //")
# Allowed-character-set check
if not FILE_NAME_ALLOWED_RE.match(name):
bad = re.findall(r"[^A-Za-z0-9._\-/]", name)
example = bad[0] if bad else "非法字符"
raise ValueError(msg + f"(原因:包含非法字符,例如:{example}")
def upload_file(self, *, file_content: bytes, file_name: str, content_type: str = "application/octet-stream", bucket: Optional[str] = None) -> str:
# Validate the input file name first to avoid generating an invalid object key
self._validate_file_name(file_name)
try:
client = self._get_client()
object_key = self._generate_object_key(original_name=file_name)
target_bucket = self._resolve_bucket(bucket)
client.put_object(Bucket=target_bucket, Key=object_key, Body=file_content, ContentType=content_type)
return object_key
except Exception as e:
logger.error(self._error_msg("Error uploading file to S3", e))
raise e
def delete_file(self, *, file_key: str, bucket: Optional[str] = None) -> bool:
try:
client = self._get_client()
target_bucket = self._resolve_bucket(bucket)
client.delete_object(Bucket=target_bucket, Key=file_key)
return True
except Exception as e:
logger.error(self._error_msg("Error deleting file from S3", e))
raise e
def file_exists(self, *, file_key: str, bucket: Optional[str] = None) -> bool:
try:
client = self._get_client()
target_bucket = self._resolve_bucket(bucket)
client.head_object(Bucket=target_bucket, Key=file_key)
return True
except ClientError as e:
code = (e.response or {}).get("Error", {}).get("Code", "")
if code in {"404", "NoSuchKey", "NotFound"}:
return False
logger.error(self._error_msg("Error checking file existence in S3", e))
return False
except Exception as e:
logger.error(self._error_msg("Error checking file existence in S3", e))
return False
def read_file(self, *, file_key: str, bucket: Optional[str] = None) -> bytes:
try:
client = self._get_client()
target_bucket = self._resolve_bucket(bucket)
resp = client.get_object(Bucket=target_bucket, Key=file_key)
body = resp.get("Body")
if body is None:
raise RuntimeError("S3 get_object returned no Body")
try:
return body.read()
finally:
try:
body.close()
except Exception as ce:
# A failure to close the resource does not affect the read result; log it for debugging
logger.debug("Failed to close S3 response body: %s", ce)
except Exception as e:
logger.error(self._error_msg("Error reading file from S3", e))
raise e
def list_files(self, *, prefix: Optional[str] = None, bucket: Optional[str] = None, max_keys: int = 1000, continuation_token: Optional[str] = None) -> ListFilesResult:
"""列出对象,支持前缀过滤与分页;返回 keys/is_truncated/next_continuation_token。"""
try:
client = self._get_client()
target_bucket = self._resolve_bucket(bucket)
if max_keys <= 0 or max_keys > 1000:
raise ValueError("max_keys 必须在 1 到 1000 之间")
kwargs: Dict[str, Any] = {
"Bucket": target_bucket,
"MaxKeys": max_keys,
"Prefix": prefix,
"ContinuationToken": continuation_token,
}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
resp = client.list_objects_v2(**kwargs)
contents = resp.get("Contents", []) or []
keys: List[str] = [item.get("Key") for item in contents if isinstance(item, dict) and item.get("Key")]
return {
"keys": keys,
"is_truncated": bool(resp.get("IsTruncated")),
"next_continuation_token": resp.get("NextContinuationToken"),
}
except ClientError as e:
code = (e.response or {}).get("Error", {}).get("Code", "")
logger.error(self._error_msg(f"Error listing files in S3 (code={code})", e))
raise e
except Exception as e:
logger.error(self._error_msg("Error listing files in S3", e))
raise e
def generate_presigned_url(self, *, key: str, bucket: Optional[str] = None, expire_time: int = 1800) -> str:
"""通过 S3 Proxy 生成签名 URL。"""
import json
import urllib.request as urllib_request
try:
from coze_workload_identity import Client as CozeClient
coze_client = CozeClient()
try:
token = coze_client.get_access_token()
finally:
try:
coze_client.close()
except Exception:
# A failure to release the resource does not affect the subsequent flow
pass
except Exception as e:
logger.error(f"Error loading x-storage-token: {e}")
raise RuntimeError(f"获取 x-storage-token 失败: {e}")
try:
sign_base = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or self.endpoint_url
if not sign_base:
raise ValueError("未配置签名端点:请设置 COZE_BUCKET_ENDPOINT_URL 或传入 endpoint_url")
sign_url_endpoint = sign_base.rstrip("/") + "/sign-url"
headers = {
"Content-Type": "application/json",
"x-storage-token": token,
}
target_bucket = self._resolve_bucket(bucket)
payload = {"bucket_name": target_bucket, "path": key, "expire_time": expire_time}
data = json.dumps(payload).encode("utf-8")
request = urllib_request.Request(sign_url_endpoint, data=data, headers=headers, method="POST")
except Exception as e:
logger.error(f"Error creating request for sign-url: {e}")
raise RuntimeError(f"创建 sign-url 请求失败: {e}")
try:
with urllib_request.urlopen(request) as resp:
resp_bytes = resp.read()
content_type = resp.headers.get("Content-Type", "")
text = resp_bytes.decode("utf-8", errors="replace")
if "application/json" in content_type or text.strip().startswith("{"):
try:
obj = json.loads(text)
except Exception:
return text
data = obj.get("data")
if isinstance(data, dict) and "url" in data:
return data["url"]
url_value = obj.get("url") or obj.get("signed_url") or obj.get("presigned_url")
if url_value:
return url_value
raise ValueError("签名服务返回缺少 data.url/url 字段")
return text
except Exception as e:
raise RuntimeError(f"生成签名URL失败: {e}")
def stream_upload_file(
self,
*,
fileobj,
file_name: str,
content_type: str = "application/octet-stream",
bucket: Optional[str] = None,
multipart_chunksize: int = 5 * 1024 * 1024,
multipart_threshold: int = 5 * 1024 * 1024,
max_concurrency: int = 1,
use_threads: bool = False,
) -> str:
"""流式上传(文件对象)
- fileobj: 任何带有 read() 方法的文件对象 open(..., 'rb') 返回的对象io.BytesIO
- file_name: 原始文件名用于生成唯一 key
- content_type: MIME 类型
- bucket: 目标桶为空时取环境变量或实例默认值
- multipart_chunksize: 分片大小默认 5MB以适配代理层限制
- multipart_threshold: 触发分片上传的阈值默认 5MB
- max_concurrency: 并发分片上传的并发数默认 1避免代理层节流影响
- use_threads: 是否启用线程并发默认 False
返回最终写入的对象 key
"""
try:
client = self._get_client()
target_bucket = self._resolve_bucket(bucket)
key = self._generate_object_key(original_name=file_name)
extra_args = {"ContentType": content_type} if content_type else {}
# Use boto3's high-level API for the multipart upload (TransferConfig controls the part size)
config = TransferConfig(
multipart_chunksize=multipart_chunksize,
multipart_threshold=multipart_threshold,
max_concurrency=max_concurrency,
use_threads=use_threads,
)
client.upload_fileobj(Fileobj=fileobj, Bucket=target_bucket, Key=key, ExtraArgs=extra_args, Config=config)
return key
except Exception as e:
logger.error(self._error_msg("Error streaming upload (fileobj) to S3", e))
raise e
def upload_from_url(
self,
*,
url: str,
bucket: Optional[str] = None,
timeout: int = 30,
) -> str:
"""从 URL 流式下载并上传到 S3
- url: 源文件 URL
- bucket: 目标桶为空时取环境变量或实例默认值
- timeout: HTTP 请求超时时间默认 30
返回最终写入的对象 key
"""
import urllib.request as urllib_request
from urllib.parse import urlparse, unquote
try:
request = urllib_request.Request(url)
with urllib_request.urlopen(request, timeout=timeout) as resp:
parsed = urlparse(url)
file_name = Path(unquote(parsed.path)).name or "file"
content_type = resp.headers.get("Content-Type", "application/octet-stream")
return self.stream_upload_file(
fileobj=resp,
file_name=file_name,
content_type=content_type,
bucket=bucket,
)
except Exception as e:
logger.error(self._error_msg("Error uploading from URL to S3", e))
raise e
def trunk_upload_file(self, *, chunk_iter: Iterable[bytes], file_name: str,
content_type: str = "application/octet-stream", bucket: Optional[str] = None,
part_size: int = 5 * 1024 * 1024) -> str:
"""流式上传(字节迭代器,显式分片 Multipart Upload
- chunk_iter: 可迭代对象逐块产生 bytes每块大小可变内部累积到 part_size 再上传最后一块可小于 5MB
- file_name: 原始文件名用于生成唯一 key
- content_type: MIME 类型
- bucket: 目标桶为空时取环境或实例默认值
- part_size: 每个 part 的最小大小除最后一个默认 5MB
返回最终写入的对象 key
"""
client = self._get_client()
target_bucket = self._resolve_bucket(bucket)
key = self._generate_object_key(original_name=file_name)
# Initialize the multipart upload
try:
init_resp = client.create_multipart_upload(Bucket=target_bucket, Key=key, ContentType=content_type)
upload_id = init_resp["UploadId"]
except Exception as e:
logger.error(self._error_msg("create_multipart_upload failed", e))
raise e
parts = []
part_number = 1
buffer = bytearray()
try:
for chunk in chunk_iter:
if not chunk:
continue
buffer.extend(chunk)
while len(buffer) >= part_size:
data = bytes(buffer[:part_size])
buffer = buffer[part_size:]
resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number,
Body=data)
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
part_number += 1
# Upload the remaining data smaller than part_size
if len(buffer) > 0:
resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number,
Body=bytes(buffer))
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
# Complete the multipart upload
client.complete_multipart_upload(
Bucket=target_bucket,
Key=key,
UploadId=upload_id,
MultipartUpload={"Parts": parts},
)
return key
except Exception as e:
logger.error(self._error_msg("multipart upload failed", e))
try:
client.abort_multipart_upload(Bucket=target_bucket, Key=key, UploadId=upload_id)
except Exception as ae:
logger.error(self._error_msg("abort_multipart_upload failed", ae))
raise e
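trunk_upload_file buffers variable-sized chunks until part_size bytes are available, uploads fixed-size parts, and flushes a smaller tail at the end. That rule can be checked in isolation as a pure function (split_into_parts is an illustrative name, not part of the class):

```python
def split_into_parts(chunks, part_size=5 * 1024 * 1024):
    """Accumulate variable-sized byte chunks and emit fixed part_size parts,
    mirroring the buffering in trunk_upload_file; the tail may be smaller."""
    buffer = bytearray()
    parts = []
    for chunk in chunks:
        if not chunk:
            continue
        buffer.extend(chunk)
        while len(buffer) >= part_size:
            parts.append(bytes(buffer[:part_size]))
            buffer = buffer[part_size:]
    if buffer:
        parts.append(bytes(buffer))
    return parts
```

S3 itself requires every part except the last to be at least 5MB, which is why part_size defaults to that bound.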

0
src/tools/__init__.py Normal file
View File

0
src/utils/__init__.py Normal file
View File

325
src/utils/file/file.py Normal file
View File

@ -0,0 +1,325 @@
import os
import requests
import uuid
import chardet
from io import BytesIO
from typing import Literal,Callable, Any, Optional,Union
from pydantic import BaseModel, Field, field_validator,PrivateAttr,ConfigDict
from urllib.parse import urlparse
from pptx import Presentation
MAX_FILE_SIZE = 100 * 1024 * 1024
class File(BaseModel):
"""
A generic file object with automatic type inference and path management.
"""
url: str = Field(..., description="文件URL(http/https)或本地路径")
file_type: Literal['image', 'video', 'audio', 'document', 'default'] = Field(
default="default",
description="文件类型"
)
_local_path: Optional[str] = PrivateAttr(default=None)
model_config = ConfigDict(
json_schema_extra={
"x-component": "file-upload", # 前端用文件上传组件
}
)
def set_cache_path(self, path: str):
"""设置缓存路径"""
self._local_path = path
def get_cache_path(self) -> Optional[str]:
"""获取缓存路径(如果文件实际存在)"""
return self._local_path
@property
def is_remote(self) -> bool:
"""判断是网络URL还是本地文件"""
return self.url.startswith(('http://', 'https://'))
def infer_file_category(path_or_url: str) -> tuple[str, str]:
"""
Determine the file category from the path or URL suffix.
Logic:
1. Parse the URL, strip query parameters (?id=...), and extract the path
2. Take the file name and suffix from the last segment of the path
3. Look the suffix up in the mapping table; return 'default' when nothing matches
Returns:
- category: image, video, audio, document, default
- suffix, e.g. .pdf
"""
# === Steps 1 & 2: extract a clean suffix ===
# urlparse handles both local paths (treated as the path) and network URLs
parsed = urlparse(path_or_url)
path = parsed.path  # path component only; drops http://... and ?query=...
# Get the file name (e.g. /a/b/test.jpg -> test.jpg)
filename = os.path.basename(path)
# Split off the suffix (test.jpg -> .jpg)
_, ext_with_dot = os.path.splitext(filename)
# No suffix: fall back immediately
if not ext_with_dot:
return 'default', ""
# 去除点并转小写 (例如 .JPG -> jpg)
ext = ext_with_dot.lstrip('.').lower()
# === 步骤 3: 查表匹配 ===
# 定义常见映射表
TYPE_MAPPING = {
'image': {
'apng', 'avif', 'bmp', 'gif', 'heic', 'ico', 'jpg', 'jpeg', 'png', 'svg', 'tiff', 'webp'
},
'video': {
'mp4', 'avi', 'mov', 'mkv', 'flv', 'wmv', 'webm', 'm4v', '3gp'
},
'audio': {
'mp3', 'wav', 'flac', 'aac', 'ogg', 'wma', 'm4a'
},
'document': {
'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx',
'txt', 'md', 'csv', 'json', 'xml', 'html', 'htm'
},
}
for category, extensions in TYPE_MAPPING.items():
if ext in extensions:
return category, ext_with_dot
return 'default', ext_with_dot
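The query-stripping step matters in practice: a naive `os.path.splitext` on the full URL would fold `?id=3` into the extension. A minimal self-contained sketch of the same parse-then-split logic (the URLs are invented examples):

```python
import os
from urllib.parse import urlparse

def ext_of(path_or_url: str) -> str:
    # urlparse treats a bare local path entirely as `path`;
    # .path drops the scheme, host and any ?query=... suffix
    path = urlparse(path_or_url).path
    _, ext = os.path.splitext(os.path.basename(path))
    return ext.lstrip(".").lower()

print(ext_of("https://cdn.example.com/a/b/test.JPG?id=3"))  # jpg
print(ext_of("/tmp/report.pdf"))                            # pdf
print(ext_of("https://example.com/download"))               # '' (no suffix)
```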
class FileOps:
    DOWNLOAD_DIR = "/tmp"
    @staticmethod
    def _get_bytes_stream(file_obj: File) -> tuple[bytes, str]:
        """
        Return the file content and suffix; enforce the size limit and raise if exceeded.
        """
        _, ext = infer_file_category(file_obj.url)
        if file_obj.is_remote:
            try:
                # stream=True: only the headers are fetched here; the connection
                # stays open and the body has not been downloaded yet
                with requests.get(file_obj.url, stream=True, timeout=60) as resp:
                    resp.raise_for_status()
                    content_length = resp.headers.get('Content-Length')
                    if content_length and int(content_length) > MAX_FILE_SIZE:
                        raise Exception(
                            f"File size ({int(content_length)} bytes) exceeds the 100MB limit; download aborted."
                        )
                    # covers a missing Content-Length header or a lying server
                    downloaded_content = BytesIO()
                    current_size = 0
                    # read in 8KB chunks
                    for chunk in resp.iter_content(chunk_size=8192):
                        if chunk:
                            current_size += len(chunk)
                            if current_size > MAX_FILE_SIZE:
                                raise Exception("File exceeds 100MB; download interrupted.")
                            downloaded_content.write(chunk)
                    # full bytes
                    return downloaded_content.getvalue(), ext
            except requests.RequestException as e:
                raise RuntimeError(f"Network request failed: {e}")
        else:
            if not os.path.exists(file_obj.url):
                raise FileNotFoundError(f"Local file not found: {file_obj.url}")
            # size check for local files is currently disabled:
            # file_size = os.path.getsize(file_obj.url)
            # if file_size > MAX_FILE_SIZE:
            #     raise Exception(f"Local file size ({file_size} bytes) exceeds the 100MB limit")
            with open(file_obj.url, 'rb') as f:
                return f.read(), ext
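The download guard above is deliberately doubled: the declared `Content-Length` is checked first, then the actual byte count is enforced while streaming, because the header may be absent or wrong. The same two-guard pattern, sketched against a plain file-like object (the names and the small limit are illustrative only):

```python
from io import BytesIO

LIMIT = 32  # tiny cap for demonstration; the real code uses 100MB

def read_capped(stream, declared_size=None, limit=LIMIT, chunk_size=8):
    # Guard 1: trust-but-verify the declared length (Content-Length analogue)
    if declared_size is not None and declared_size > limit:
        raise ValueError("declared size exceeds limit")
    # Guard 2: count the bytes actually received, chunk by chunk
    buf, total = BytesIO(), 0
    while True:
        chunk = stream.read(chunk_size)
        if not chunk:
            break
        total += len(chunk)
        if total > limit:
            raise ValueError("stream exceeds limit")
        buf.write(chunk)
    return buf.getvalue()

print(len(read_capped(BytesIO(b"a" * 20))))  # 20
```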
    @staticmethod
    def save_to_local(file_obj: File, filename: str) -> str:
        """
        Save the file content to a local path and return that path.
        If the URL is already a local path, return it directly.
        """
        if not file_obj.is_remote:
            if os.path.exists(file_obj.url):
                return file_obj.url
            raise FileNotFoundError(f"Local file not found: {file_obj.url}")
        try:
            os.makedirs(FileOps.DOWNLOAD_DIR, exist_ok=True)
            # naive file-name strategy (in production, hash the URL to avoid re-downloads)
            # ext = os.path.splitext(file_obj.url.split('?')[0])[1] or ".tmp"
            # filename = f"{uuid.uuid4().hex}{ext}"
            local_path = os.path.join(FileOps.DOWNLOAD_DIR, filename)
            headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
            with requests.get(file_obj.url, headers=headers, stream=True, timeout=120) as r:
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            return local_path
        except Exception as e:
            raise RuntimeError(f"Download failed for {file_obj.url}: {str(e)}")
    @staticmethod
    def read_bytes(file_obj: File) -> bytes:
        """
        Return the raw binary content of the file.
        Use cases: upload to OSS, save locally, feed to an image library.
        """
        content, _ = FileOps._get_bytes_stream(file_obj)
        return content
    @staticmethod
    def extract_text(file_obj: File) -> str:
        """
        Extract the text content of the file.
        Use cases: RAG, HTML parsing, document analysis.
        """
        try:
            content, ext = FileOps._get_bytes_stream(file_obj)
            if ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']:
                return FileOps._parse_document_bytes(file_obj, content, ext)
            # plain text: sniff the encoding before decoding
            # (chardet.detect always returns an 'encoding' key, but its value may be None)
            charset = chardet.detect(content)
            encoding = charset.get('encoding') or 'utf-8'
            return content.decode(encoding)
        except Exception as e:
            return f"[FileOps Error] Failed to read content: {str(e)}"
    @staticmethod
    def _parse_document_bytes(file_obj: File, content: bytes, ext: str) -> str:
        stream = BytesIO(content)
        text_result = ""
        try:
            if ext == '.pdf':
                import pypdf
                reader = pypdf.PdfReader(stream)
                for page in reader.pages:
                    text_result += page.extract_text() + "\n"
            elif ext in ['.docx', '.doc']:
                text_result = read_docx(stream)
            elif ext in ['.xlsx', '.xls', '.csv']:
                import pandas as pd
                if ext == '.csv':
                    df = pd.read_csv(stream)
                else:
                    df = pd.read_excel(stream)
                text_result = df.to_string()
            elif ext in ['.ppt', '.pptx']:
                text_result = read_ppt(stream)
            else:
                text_result = f"[Unsupported document format: {ext}]"
        except ImportError as e:
            text_result = f"[Missing parser library] {e}"
        except Exception as e:
            text_result = f"[Parse failed] {e}"
        return text_result
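The if/elif chain above can equivalently be written as a dispatch table, so supporting a new format becomes a one-line change. A sketch with placeholder parsers (the parser names and return values are invented for illustration):

```python
from typing import BinaryIO, Callable, Dict

def parse_pdf(stream: BinaryIO) -> str: return "pdf-text"
def parse_docx(stream: BinaryIO) -> str: return "docx-text"
def parse_sheet(stream: BinaryIO) -> str: return "sheet-text"

# extension -> parser; several extensions may share one parser
PARSERS: Dict[str, Callable[[BinaryIO], str]] = {
    ".pdf": parse_pdf,
    ".doc": parse_docx, ".docx": parse_docx,
    ".csv": parse_sheet, ".xls": parse_sheet, ".xlsx": parse_sheet,
}

def parse_document(stream, ext: str) -> str:
    parser = PARSERS.get(ext)
    if parser is None:
        return f"[unsupported format: {ext}]"
    return parser(stream)

print(parse_document(None, ".pdf"))  # pdf-text
print(parse_document(None, ".md"))   # [unsupported format: .md]
```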
def read_docx(cont_stream) -> str:
    """
    Read content in document order using docx2python.
    """
    from docx2python import docx2python
    doc_result = docx2python(cont_stream)
    all_parts = []
    # docx2python returns the content as nested lists;
    # walk the document body level by level
    for section in doc_result.body:
        if isinstance(section, list):
            for item in section:
                if isinstance(item, list):
                    # a table or another nested level
                    for sub_item in item:
                        if isinstance(sub_item, str) and sub_item.strip():
                            all_parts.append(sub_item.strip())
                        elif isinstance(sub_item, list):
                            # table row
                            row_text = "\n".join([str(cell).strip() for cell in sub_item if str(cell).strip()])
                            if row_text:
                                all_parts.append(row_text)
                elif isinstance(item, str) and item.strip():
                    all_parts.append(item.strip())
    # close the document handle
    doc_result.close()
    return "\n\n".join(all_parts)
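The traversal above handles a fixed nesting depth. A recursive variant, shown here as a simplified sketch that is not tied to docx2python's exact body shape, collects every non-empty string at any depth:

```python
def collect_strings(node, out=None):
    # Depth-first walk: keep stripped non-empty strings, recurse into lists
    if out is None:
        out = []
    if isinstance(node, str):
        if node.strip():
            out.append(node.strip())
    elif isinstance(node, list):
        for child in node:
            collect_strings(child, out)
    return out

# toy nested structure standing in for a parsed document body
body = [[["Title", ["cell A", "cell B"]], "  ", "Paragraph"]]
print(collect_strings(body))  # ['Title', 'cell A', 'cell B', 'Paragraph']
```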
def read_ppt(file_input: Union[str, bytes, BytesIO]) -> str:
    if not Presentation:
        return "[Error] python-pptx is not installed; cannot parse PPT files"
    # 1. normalise the input to a file stream (BytesIO)
    if isinstance(file_input, str):
        with open(file_input, 'rb') as f:
            ppt_stream = BytesIO(f.read())
    elif isinstance(file_input, bytes):
        ppt_stream = BytesIO(file_input)
    else:
        ppt_stream = file_input
    try:
        prs = Presentation(ppt_stream)
        full_text = []
        for i, slide in enumerate(prs.slides):
            page_content = []
            page_content.append(f"=== Slide {i+1} ===")
            for shape in slide.shapes:
                # A. plain text frames (shape.text covers the paragraphs inside)
                if hasattr(shape, "text") and shape.text.strip():
                    page_content.append(shape.text.strip())
                # B. tables (plain shape.text does not reach table cells)
                if shape.has_table:
                    table_texts = []
                    for row in shape.table.rows:
                        row_cells = [cell.text_frame.text.strip() for cell in row.cells if cell.text_frame.text.strip()]
                        if row_cells:
                            table_texts.append(" | ".join(row_cells))
                    if table_texts:
                        page_content.append("[Table]\n" + "\n".join(table_texts))
            # speaker notes often hold important context
            if slide.has_notes_slide:
                notes = slide.notes_slide.notes_text_frame.text
                if notes.strip():
                    page_content.append(f"[Notes]: {notes.strip()}")
            full_text.append("\n".join(page_content))
        return "\n\n".join(full_text)
    except Exception as e:
        return f"[PPT parse failed] {str(e)}"