Project initialization

@ -0,0 +1,14 @@
[project]
entrypoint = "src/main.py"
requires = ["python-3.12"]

[dev]
build = ["bash", "scripts/setup.sh"]
run = ["bash", "/workspace/projects/scripts/http_run.sh", "-p 5000"]
pack = ["bash", "/workspace/projects/scripts/pack.sh"]
deps = ["git"] # -> apt install git

[deploy]
build = ["bash", "scripts/setup.sh"]
run = ["bash", "scripts/http_run.sh", "-p 5000"]
deps = ["git"] # -> apt install git

@ -0,0 +1,48 @@
node_modules
dist
.DS_Store
*.swp
.git/*
__pycache__/
*.pyc
.venv/*
*.o
*.a
/vendor
*.egg-info/
/.env
.env

# Video files
*.mp4
*.avi
*.mov
*.wmv
*.flv
*.mkv
*.webm
*.m4v
*.mpeg
*.mpg
*.3gp
*.f4v
*.rmvb
*.vob

# Archive files
*.iso
*.dmg
*.rar
*.zip
*.gz

# Documents are also ignored by default
*.pdf
*.docx
*.doc
*.xlsx
*.xls
*.ppt
*.pptx
*.csv

@ -0,0 +1,391 @@
## Project Overview
- **Name**: Middle-school physics homework grading workflow
- **Function**: Upload several homework images and a Word answer file; the workflow recognizes student answers, extracts the standard answers, grades precisely, and returns the grading result as JSON

### Node List
| Node | File | Type | Description | Branching | Config |
|-------|---------|------|---------|---------|---------|
| doc_extract | `nodes/doc_extract_node.py` | agent | Extracts question stems and standard answers from the Word file (.docx); returns an empty list when no URL is given | - | `config/doc_extract_llm_cfg.json` |
| process_images | `nodes/process_images_node.py` | looparray | Invokes the subgraph in a loop for every homework image and builds the final grading result | - | - |

**Type legend**: task (plain task node) / agent (LLM node) / condition (conditional branch) / looparray (list loop) / loopcond (conditional loop)

## Subgraph List
| Subgraph | File | Description | Calling node |
|-------|---------|------|---------|
| single_image_subgraph | `graphs/loop_graph.py` | Full grading pipeline for one image (preprocess → recognize & grade → merge → wrap) | process_images |

### Subgraph Nodes
| Node | File | Type | Description |
|-------|---------|------|---------|
| image_preprocess | `nodes/image_preprocess_node.py` | task | Downloads the image, auto-rotates (landscape → portrait), scales to a fixed width of 1000px, uploads to object storage |
| recognize_and_correct | `nodes/recognize_and_correct_node.py` | agent | **Unified recognition and grading**: merges question recognition and grading into one LLM call |
| result_merge | `nodes/result_merge_node.py` | task | Merges recognition and grading output into the final annotations |
| wrap_result | `graphs/loop_graph.py` | task | Wraps the subgraph result into a SingleImageResult output |

## Skills Used
- Node `recognize_and_correct` uses the LLM skill (multimodal; recognition + grading merged)
  - Model: `doubao-seed-2-0-pro-260215` (flagship vision model: strong reasoning, concise output)
- Node `doc_extract` uses the LLM skill
  - Model: `doubao-seed-2-0-pro-260215` (flagship model: strong at complex reasoning)
  - Uses python-docx to parse the Word document

## Workflow (multi-image grading architecture)

```
┌─────────────────────┐
│     doc_extract     │
│ (Word answer parse) │
└──────────┬──────────┘
           │
           ▼
┌─────────────────────┐
│   process_images    │
│ (loop over images)  │
│ builds final result │
└─────────────────────┘
```

### Subgraph Flow (single image, optimized)

```
┌─────────────────────┐
│  image_preprocess   │
│ (image preprocess)  │
└──────────┬──────────┘
           │
           ▼
┌─────────────────────┐
│recognize_and_correct│ ← merged node: recognition + grading in one call
│ (recognize + grade) │
└──────────┬──────────┘
           │
           ▼
┌─────────────────────┐
│    result_merge     │
│   (merge results)   │
└──────────┬──────────┘
           │
           ▼
┌─────────────────────┐
│     wrap_result     │
│    (wrap output)    │
└─────────────────────┘
```

## Core Feature: Multi-Image Grading

### Input Parameters
- `homework_images`: uploaded homework images (List[File]; multiple images supported)
- `answer_doc_url`: URL of the Word file with the correct answers (.docx, **optional**)
- `comment_max_length`: maximum comment length (default 100 characters, **optional**)
- `max_concurrent`: maximum number of images graded in parallel (default 10, **optional**)
- `grade_standards`: grading-level standards (**optional**; defaults below)
```json
{
  "A+": {"min_percentage": 95, "description": "All answers correct; steps complete and well-presented; rigorous logic; neat writing with no typos or omissions; 100% completion, careful work of high quality"},
  "A": {"min_percentage": 90, "description": "All answers correct with no errors; sensible steps and proper formatting, no fundamental problems; 100% completion, meets every requirement"},
  "B": {"min_percentage": 80, "description": "A few non-critical errors, or slightly incomplete steps; overall approach basically correct with only minor issues in details, formatting, or arithmetic; most content completed, acceptable but not rigorous"},
  "C": {"min_percentage": 70, "description": "Many errors; some core questions answered incorrectly; incomplete steps and unclear logic; average completion, with clearly careless or skipped answers"},
  "D": {"min_percentage": 0, "description": "Widespread errors; core concepts not mastered; many blanks, perfunctory work, or copying; below the basic completion requirement"}
}
```
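
A minimal sketch of how these standards can be applied to a score rate (an illustrative helper; the actual grading code may differ, and `pick_grade` is not a name from the project):

```python
def pick_grade(score_rate: float, grade_standards: dict) -> str:
    """Return the highest grade whose min_percentage the score rate reaches.

    Sketch only: assumes score_rate is a percentage in [0, 100] and that
    grade_standards maps grade -> {"min_percentage": ..., "description": ...}.
    """
    # Check the highest threshold first so that e.g. 96 resolves to "A+".
    for grade, spec in sorted(grade_standards.items(),
                              key=lambda kv: kv[1]["min_percentage"],
                              reverse=True):
        if score_rate >= spec["min_percentage"]:
            return grade
    return "D"  # below every threshold

# e.g. pick_grade(92.5, defaults) -> "A"
```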

### Outputs
- `final_result`: the final grading result JSON (covers all images)
  - `total_images`: total number of images
  - `image_results`: per-image grading results
  - `overall_comment`: overall comment (generated from the score rate)
  - `total_score`: total score achieved
  - `full_score`: total possible score
  - `grade`: grade level

### Grading Priority (strictly in this order)
1. **Highest priority**: grade against the standard answers from the Word document
   - Applies when `answer_doc_url` is provided and the question is found in the document
   - The student's answer is judged strictly against the standard answer
2. **Fallback**: grade as a professional physics teacher (see the sketch below)
   - Case 1: no `answer_doc_url` was provided
   - Case 2: a URL was provided but the question is not in the document
   - Correctness is judged independently, using a professional physics teacher's expertise
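
A sketch of that mode selection (illustrative names; the actual nodes may implement it differently):

```python
def grading_mode(question_id: str, standard_answers: dict) -> str:
    """Pick the grading mode for one question.

    Returns "standard" when a Word-derived answer exists for this question,
    otherwise "teacher" for the professional-teacher fallback.
    """
    if standard_answers and question_id in standard_answers:
        return "standard"  # grade strictly against the extracted answer
    return "teacher"       # no usable answer: rely on the LLM's own expertise
```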

### Feature Notes
1. **Multi-image support**: several homework images can be uploaded; the system grades them in parallel (concurrency capped by `max_concurrent`, default 10)
2. **Word answer extraction**: question stems and standard answers are extracted from the .docx file
3. **Subgraph loop**: the single-image pipeline is wrapped in a subgraph that the main graph invokes per image
4. **Grading result JSON**: returns structured JSON covering the grading results of every image
5. **Smart fallback**: switches to professional-teacher mode automatically when no standard answer is available

## Optimization Log

### 2026-03-26 Fill-in-the-blank splitting (important)
**Problem**: when one question had several blanks, they were merged into a single answer, so correction marks could not be placed precisely

**Fixes**:
1. **Prompt update**:
   - Explicit rule: when a question has several blanks, **each blank is recognized as a separate item**
   - Question-ID format: "3(1)第一空", "3(1)第二空", "4(2)第一空", "4(2)第二空" (i.e. "question 3(1), blank 1", and so on)
   - Each blank is graded and scored on its own
2. **Example**:
```
❌ Wrong: 3(1) → "4、1" (merged)
✅ Right: 3(1)第一空 → "4"
         3(1)第二空 → "1"
```
3. **Parameter passing**:
   - comment_max_length is now passed correctly into the Jinja2 template
   - Ensures the LLM keeps comments within the length limit

**Result**:
- Recognized items increased from 9 to 13
- Every blank gets its own correction mark
- Marks land precisely on each answer's position

### 2026-03-26 JSON parsing hardening (important)
**Problem**: the LLM output can be incomplete (truncated at max_completion_tokens), making JSON parsing fail

**Fixes**:
1. **New fix_incomplete_json function**:
   - Detects missing closing brackets (} and ])
   - Appends the missing brackets to complete the JSON
   - Example: `{"results": [{"id": 1}` → completed to `{"results": [{"id": 1}]}`
2. **Hardened parsing pipeline**:
   - Step 1: try parsing directly
   - Step 2: try repairing incomplete JSON (append brackets)
   - Step 3: try extracting the JSON object
   - Step 4: try repairing the extracted JSON
3. **Removed the faulty truncation logic**:
   - Comments are no longer truncated after parsing (truncation could break escape sequences)
   - Comment length now relies entirely on the LLM honoring comment_max_length
   - The prompt instructs the LLM explicitly to keep comments within the limit
4. **Correct parameter passing**:
   - comment_max_length is passed into the prompt correctly
   - The LLM generates comments of the requested length

**Result**:
- JSON parse success rate improved substantially
- Incomplete JSON output can now be handled
- Comment length is controlled by the LLM, so truncation no longer breaks the format

### 2026-03-26 Recognition fix: no marks on apparatus diagrams (important)
**Problems**:
1. Correction bubbles were placed on experimental apparatus diagrams (spring scales, beakers, etc.)
2. Coordinate placement was not precise enough

**Fixes**:
1. **Prompt update**:
   - Explicitly forbids marking apparatus diagrams, schematics, and circuit diagrams
   - Explicitly forbids marking printed letters in figures (e.g. A, B, C, D, E, F, G)
   - Emphasizes marking handwritten student answers only
2. **Engineering conventions**:
   - sp and up are read from the config file (per project conventions)
   - Prompts are rendered with Jinja2 templates
   - Code builds only the dynamic parts (standard answers, image size, etc.)
3. **Recognition flow**:
   - Find the question number → find the student answer → box the answer → judge correctness
   - Emphasizes what student answers look like: handwritten, filled into blanks, computed results

**Result**: apparatus diagrams are no longer mislabeled; only handwritten student answers get marks

### 2026-03-26 New concurrency-control parameter
**Before**: the concurrency limit was hard-coded to 3, which was inflexible
**After**: added a max_concurrent parameter (default 10) that users can set

**Details**:
1. **New parameter**: `max_concurrent` (optional, default 10)
2. **Changed locations**:
   - `GraphInput.max_concurrent: Optional[int] = 10`
   - `GlobalState.max_concurrent: int = 10`
   - `ProcessImagesInput.max_concurrent: int`
3. **Usage**:
```json
{
  "homework_images": [...],
  "max_concurrent": 5  // grade at most 5 images at once
}
```

**Result**: users can tune concurrency to their server capacity and network conditions

### 2026-03-26 Subject change
**Change**: replaced every occurrence of "math" (数学) with "physics" (物理)
- Node descriptions: math homework → physics homework
- Subject references in prompts: math → physics
- Config file descriptions updated

### 2026-03-25 Parallel multi-image processing
**Before**: images were processed serially; total time = per-image time × image count
**After**: images are processed in parallel (concurrency capped at 3), cutting total time substantially

**Details**:
1. **Parallel architecture**: uses `ThreadPoolExecutor` to run the subgraph for each image in parallel
   - At most 3 images processed at once
   - Results are re-sorted by `image_index` so the original order is preserved
2. **Performance**:
   - 3 images: about 66% less time (3 units of time → 1)
   - 5 images: about 60% less time (5 units of time → about 2, in two parallel batches)
3. **Quality**:
   - Each image is processed independently, with no interference
   - Recognition and grading logic are unchanged, so quality is unaffected

### 2026-03-26 Coordinate placement fix (important)
**Problem**: coordinates were far off, so correction marks landed in the wrong places
**Cause**: the Y-coordinate correction logic was wrong and scaled coordinates incorrectly

**Fixes**:
1. **Coordinate system rework**: switched from absolute coordinates to a relative (0-1000) system
   - The AI returns relative coordinates (0-1000); (0,0) is the image's top-left corner, (1000,1000) its bottom-right
   - The code converts relative to absolute coordinates: `absolute X = relative X × width / 1000`, `absolute Y = relative Y × height / 1000` (sketched in code after this entry)
2. **Prompt update**:
   - Explicitly requires the AI to return relative (0-1000) coordinates
   - Adds a coordinate-system explanation with examples
3. **Conversion fix**:
   - Removed the wrong Y correction (`Y × height_ratio`)
   - Implemented the correct relative-to-absolute conversion

**Result**: coordinates are accurate and correction marks land in the right place
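
That conversion in code form (a sketch; it assumes the `[x1, y1, x2, y2]` relative bbox convention used elsewhere in this document):

```python
def to_absolute(bbox: list[int], width: int, height: int) -> list[int]:
    """Map a relative (0-1000) bbox onto absolute pixel coordinates."""
    x1, y1, x2, y2 = bbox
    return [
        round(x1 * width / 1000), round(y1 * height / 1000),
        round(x2 * width / 1000), round(y2 * height / 1000),
    ]
```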

### 2026-03-26 Question/answer recognition fix (important)
**Problems**:
1. Could not reliably tell the question stem apart from the student's answer
2. Correction bubbles were not placed at the student's answer
3. Question stems were mislabeled as answers

**Fixes**:
1. **Prompt rewrite**:
   - Defines "question stem" vs. "student answer" explicitly
   - Emphasizes marking only handwritten student answers, never printed stems
   - Adds step-by-step recognition guidance
2. **Mark placement**:
   - mark_position is computed automatically: 30 pixels to the right of the answer box, vertically centered (sketched in code after this entry)
   - Bounds checking keeps the mark inside the image
   - The AI-returned mark_position is no longer trusted (it can be inaccurate)
3. **Recognition guidance**:
   - Question numbers: e.g. 1, 2, 3, (1), (2)
   - Answer location: the student's handwriting (not printed text)
   - bbox: box the student's answer region accurately

**Result**: question stems and answers are distinguished more reliably, and correction bubbles are placed more precisely
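
A sketch of that placement rule, assuming an absolute-pixel bbox `[x1, y1, x2, y2]` (the helper name is illustrative):

```python
def mark_position(bbox: list[int], width: int, height: int) -> tuple[int, int]:
    """Place the mark 30px right of the answer box, vertically centered, clamped."""
    x1, y1, x2, y2 = bbox
    x = min(x2 + 30, width - 1)                  # keep the mark inside the image
    y = max(0, min((y1 + y2) // 2, height - 1))  # vertical center of the box
    return x, y
```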

### 2026-03-25 Grading speed optimization
**Before**: each image needed 3 LLM calls (recognition + grading + overall comment)
**After**: each image needs only 1 LLM call

**Details**:
1. **Merged recognition and grading**: combined `homework_recognize` and `correction_judge` into the `recognize_and_correct` node
   - Questions, student answers, and coordinates are recognized in the same call that grades them
   - One fewer LLM call, roughly 50% faster
2. **Simplified the overall comment**: the overall evaluation no longer calls the LLM
   - The comment is generated directly by rules
   - A personalized comment is derived from the score rate and the error count
   - One fewer LLM call
3. **Slimmer subgraph**: reduced from 5 nodes to 4
   - Removed: homework_recognize, correction_judge
   - Added: recognize_and_correct (the merged node)
   - Kept: image_preprocess, result_merge, wrap_result

**Result**:
- LLM calls per image: down from 3 to 1
- Estimated grading time reduced by about 60%

### 2026-03-25 New input-parameter controls
1. **New `comment_max_length` parameter**: caps the comment length, default 100 characters
2. **New `grade_standards` parameter**: customizable grading-level standards
   - The minimum score-rate percentage of each level is configurable
   - The description of each level is configurable
   - Defaults: A+ (≥95%), A (≥90%), B (≥80%), C (≥70%), D (<70%)
3. **Usage**:
```json
{
  "homework_images": [...],
  "comment_max_length": 50,  // comments capped at 50 characters
  "grade_standards": {
    "A+": {"min_percentage": 98, "description": "Perfect"},
    "A": {"min_percentage": 90, "description": "Excellent"},
    ...
  }
}
```

### 2026-03-25 Comment quality and overall evaluation
1. **Concrete comments**: grading comments must state why an answer is right or wrong
   - Correct: explain why it is correct
   - Wrong: point out the cause of the error and give the correct answer
   - Partially correct: say which parts are right and which are wrong
   - Length: within 50 characters, never more than 100
   - Output only the conclusion, never the reasoning process
2. **Comment examples**:
   - Multiple choice, correct: the answer B matches the standard answer; correct.
   - Fill-in, wrong: the answer should be 8+√7 and 8-√7; the student wrote only one, so it is incomplete.
   - Free response, correct: complete solution, clear steps, correct result.
   - Calculation, wrong: the calculation is wrong; the correct answer is m=2; check the term-moving step.
3. **Overall evaluation**: a short overall comment generated from all grading content
   - Generated via an LLM call for personalization
   - At most 50 characters
   - Mentions the main problems or strengths
   - Gives a brief suggestion
4. **HTML report**: the overall evaluation is shown right after the summary statistics

### 2026-03-25 Auto-rotation
1. **New auto-rotation for landscape images**: if an uploaded image is wider than it is tall (landscape), it is automatically rotated -90° into portrait
2. **When**: during image preprocessing, after download and before scaling
3. **Direction**: 90° clockwise (PIL's rotate(-90)), which keeps the text upright
4. **Logging**: detailed rotation logs were added for debugging

### 2026-03-25 Multi-image grading
1. **Multi-image support**: upgraded from single-image to batch multi-image grading
2. **Subgraph architecture**: added `loop_graph.py` to wrap the single-image pipeline
3. **Loop node**: added `process_images_node.py` to invoke the subgraph per image
4. **State refactor**:
   - `GraphInput.homework_image` → `homework_images: List[File]`
   - Added the `SubgraphState`, `SubgraphInput`, `SubgraphOutput` subgraph states
   - Added `SingleImageResult` for single-image grading output
   - Added `FinalResult.image_results` for the multi-image result list
5. **HTML generation refactor**: the report now includes the annotations of all images
6. **Main-graph simplification**: a three-node linear flow (doc_extract → process_images → html_generate)

### 2026-03-25 Dual-mode grading
1. **Smart fallback**: prefer the standard answers; switch automatically to professional-teacher mode when none exist
2. **state.py change**: `answer_doc_url` became optional, supporting runs without an answer URL
3. **correction_judge_node upgrade**: questions are split into with-answer and without-answer groups and handled separately
4. **Prompt update**: the grading node supports both modes (standard-answer mode + professional-teacher mode)
5. **doc_extract_node change**: returns an empty list when no URL is given instead of aborting the workflow

### 2026-03-25 Word answer parsing
1. Added the `doc_extract_node` node: extracts question stems and standard answers from the Word file (.docx)
2. Uses python-docx to read the Word document content
3. Parallel architecture: image recognition and answer parsing run simultaneously
4. Grading is done precisely against the standard answers from the Word file

### 2026-03-25 OCR accuracy
1. Problem: the recognition node read a student's "8" as "9", causing a wrong judgment
2. Fix: added OCR-specific hints to the recognition prompt, stressing the distinction between look-alike characters such as 8/9, 6/0, and 1/7
3. Result: question 7 is now recognized correctly as "8+√7, 8-√7" and receives full marks

### 2026-03-25 Grading capability upgrade
1. Upgraded the grading node's model: `doubao-seed-1-6-vision-250815` → `doubao-seed-2-0-pro-260215`
2. Reason: the smaller model was not accurate enough on multiple-choice questions
3. Result: much better multiple-choice accuracy and more rigorous reasoning

### 2026-03-24 Refactor (following the Doubao app's approach)
1. Simplified from 8 nodes to 5 (now adjusted into a subgraph + main-graph architecture)
2. Unified recognition: the AI recognizes answer_bbox and the code computes mark_position
3. Precise coordinate math: the mark's Y is aligned with the answer's vertical center

## TODO
- Run the full pipeline on real multi-page homework images
- Improve the image layout of the HTML report
- Support PDF answer documents

@ -0,0 +1,12 @@
# Project Structure Notes

# Local Runs
## Run the whole flow
bash scripts/local_run.sh -m flow

## Run a single node
bash scripts/local_run.sh -m node -n node_name

# Start the HTTP service
bash scripts/http_run.sh -p 5000

After Width: | Height: | Size: 629 KiB |
After Width: | Height: | Size: 787 KiB |
After Width: | Height: | Size: 462 KiB |
After Width: | Height: | Size: 462 KiB |
After Width: | Height: | Size: 462 KiB |
After Width: | Height: | Size: 473 KiB |
After Width: | Height: | Size: 1.1 MiB |

@ -0,0 +1,17 @@
{
    "config": {
        "temperature": 0.1,
        "frequency_penalty": 0,
        "top_p": 0.9,
        "max_tokens": 4096,
        "max_completion_tokens": 8192,
        "thinking_type": "enabled",
        "reasoning_effort": "medium",
        "response_format": "text",
        "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
        "model": "doubao-seed-2-0-pro-260215"
    },
    "tools": [],
    "sp": "你是一位专业的文档识别专家,擅长从作业图片中提取答案区域的位置信息。",
    "up": "请根据题目位置信息,识别每道题对应的答案区域边界框。"
}

@ -0,0 +1,17 @@
{
    "config": {
        "temperature": 0.1,
        "frequency_penalty": 0,
        "top_p": 0.9,
        "max_tokens": 4096,
        "max_completion_tokens": 8192,
        "thinking_type": "enabled",
        "reasoning_effort": "medium",
        "response_format": "text",
        "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
        "model": "doubao-seed-2-0-pro-260215"
    },
    "tools": [],
    "sp": "你是一位专业的OCR文字识别专家,擅长从图片中识别手写和印刷的文字内容。",
    "up": "请识别答案区域中的文字内容,返回准确的识别结果。"
}

@ -0,0 +1,12 @@
{
    "config": {
        "model": "doubao-seed-1-6-vision-250815",
        "temperature": 0.1,
        "top_p": 0.95,
        "max_completion_tokens": 16384,
        "thinking": "disabled"
    },
    "tools": [],
    "sp": "你是一位专业的初中物理教师,负责批改学生的物理作业。",
    "up": "请按照要求完成作业批改任务。"
}

@ -0,0 +1,17 @@
{
    "config": {
        "temperature": 0,
        "frequency_penalty": 0,
        "top_p": 0.95,
        "max_tokens": 8192,
        "max_completion_tokens": 16384,
        "thinking_type": "enabled",
        "reasoning_effort": "high",
        "response_format": "text",
        "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
        "model": "doubao-seed-2-0-pro-260215"
    },
    "tools": [],
    "sp": "你是一位资深的初中物理特级教师,拥有20年以上教学经验,擅长精准批改学生的物理作业。\n\n【核心能力】\n1. **精确判断能力**:对选择题、填空题、解答题都能做出准确的正误判断\n2. **严谨推理能力**:能够逐步验证学生的计算过程和结论\n3. **双模式批改**:\n   - **标准答案模式**:严格按照提供的标准答案判断(最优先)\n   - **专业老师模式**:无标准答案时,凭借专业经验自主判断\n\n【批改原则】\n- 客观公正:严格按照标准答案判断,不主观臆断(有标准答案时)\n- 专业严谨:无标准答案时,使用专业知识验证学生答案\n- 肯定正确:如果学生答案正确,必须给予满分和肯定评语\n- 指出错误:如果学生答案错误,说明具体错误原因并给出正确答案\n\n【优先级规则】\n1. 最优先:使用提供的标准答案批改\n2. 降级:标准答案中未找到对应题目时,使用专业老师批改",
    "up": "请批改以下学生的物理作业,判断每道题答案的正误并给出详细评语。"
}

@ -0,0 +1,17 @@
{
    "config": {
        "temperature": 0.1,
        "frequency_penalty": 0,
        "top_p": 0.95,
        "max_tokens": 8192,
        "max_completion_tokens": 16384,
        "thinking_type": "enabled",
        "reasoning_effort": "high",
        "response_format": "text",
        "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
        "model": "doubao-seed-2-0-pro-260215"
    },
    "tools": [],
    "sp": "你是一位资深的初中物理教师,擅长从试卷中提取题目和标准答案。你的核心能力:\n\n1. **题目识别能力**:能够准确识别试卷中的所有题目,包括大题和小题\n2. **答案提取能力**:能够准确提取每道题的标准答案\n3. **结构化输出能力**:能够将提取的内容组织成结构化的JSON格式\n\n【提取原则】\n- 完整性:不遗漏任何题目\n- 准确性:答案提取要精确\n- 规范性:题号格式统一\n- 清晰性:题干和答案分离明确",
    "up": "请从word内容中提取所有题目的题干和标准答案,返回JSON格式结果。"
}

@ -0,0 +1,12 @@
{
    "config": {
        "model": "doubao-seed-1-6-vision-250815",
        "temperature": 0.1,
        "top_p": 0.95,
        "max_completion_tokens": 8192,
        "thinking": "disabled"
    },
    "tools": [],
    "sp": "# 角色定义\n你是一位专业的初中物理作业批改助手,具有丰富的物理教学经验和精准的视觉识别能力。你能够准确识别作业图片中的题目内容、学生答案,并判断答案的正确性。\n\n# 任务目标\n分析上传的初中物理作业图片,识别每道题目及其学生答案,判断答案是否正确,并输出结构化的批改结果JSON。\n\n# 工作流上下文\n- **Input**:作业图片(图片URL)\n- **Process**:\n  1. 仔细识别图片中的所有题目,包括题号、题目内容\n  2. 识别每道题的学生答案,注意区分小题(如(1)(2)(3))\n  3. 判断每个答案的正确性,对于解答题需要检查计算过程和结果\n  4. 为每个批改标记确定在原图上的相对坐标位置(批改标记应放置在答案末尾右侧)\n  5. 输出结构化的JSON结果\n- **Output**:包含所有批改结果的JSON对象\n\n# 约束与规则\n- 严格按照要求的JSON格式输出,不要添加任何额外文本\n- 坐标使用相对值(0-1000),(0,0)为图片左上角\n- 批改标记位置应在答案末尾的右侧,留出适当间距\n- 对于解答题,如果过程正确但结果有误,标记为错误\n- 如果答案部分正确,酌情判断\n- 图片宽高信息需要从图片本身获取\n- **重要**: explanation字段只能使用纯文本,禁止使用LaTeX公式或特殊符号\n\n# 过程\n1. 识别题目结构:扫描图片,定位所有题目,记录题号和小题号\n2. 答案识别:逐题识别学生的作答内容\n3. 正确性判断:\n   - 对于计算题:检查计算过程和结果\n   - 对于证明题:检查证明逻辑是否完整\n   - 对于作图题:检查图形是否正确\n4. 坐标定位:确定每道题答案末尾的坐标位置\n5. 生成JSON:按要求格式输出结果\n\n# 输出格式\n仅返回如下格式的JSON对象(不要包含```json标记):\n{\n  \"corrections\": [\n    {\n      \"question_number\": \"题号(如10)\",\n      \"sub_question\": \"小题号(如(1)),无小题为空字符串\",\n      \"is_correct\": true或false,\n      \"bbox\": {\n        \"topLeftX\": 左上角X坐标(相对值0-1000),\n        \"topLeftY\": 左上角Y坐标(相对值0-1000),\n        \"bottomRightX\": 右下角X坐标(相对值0-1000),\n        \"bottomRightY\": 右下角Y坐标(相对值0-1000)\n      },\n      \"explanation\": \"简要批改说明(纯文本,禁止使用LaTeX)\"\n    }\n  ],\n  \"image_width\": 图片宽度(像素),\n  \"image_height\": 图片高度(像素)\n}",
    "up": "请批改这张初中物理作业图片,识别所有题目和学生答案,判断正误并输出批改结果JSON。注意:explanation字段只能使用纯文本,禁止使用LaTeX公式。图片URL:{{image_url}}"
}

@ -0,0 +1,12 @@
{
    "config": {
        "model": "doubao-seed-2-0-pro-260215",
        "temperature": 0.0,
        "top_p": 0.9,
        "max_completion_tokens": 8192,
        "thinking": "disabled"
    },
    "tools": [],
    "sp": "# 角色\n你是物理作业批改助手。\n\n# 禁止标注\n- 印刷体文字、实验装置图、图中字母\n\n# 需要标注\n- 学生手写答案\n\n# 坐标\n- 相对坐标(0-1000),answer_bbox: [x1, y1, x2, y2]\n\n# ⚠️ 重要:拆分填空题\n- 一道题有多个填空时,**每个空单独识别为一个题目**\n- 题号格式:\"3(1)第一空\"、\"3(1)第二空\"、\"4(2)第一空\"、\"4(2)第二空\"\n- 每个空单独批改,单独打分\n- 示例:\n  - 题目(1)有两个空 → 识别为\"3(1)第一空\"和\"3(1)第二空\"两个题目\n  - 题目(2)有一个空 → 识别为\"3(2)\"一个题目\n\n# 输出格式\n{\"results\": [{\"question_id\": \"题号\", \"student_answer\": \"答案\", \"answer_bbox\": [x1, y1, x2, y2], \"status\": \"correct或incorrect\", \"score\": 分数, \"full_score\": 满分, \"comment\": \"结论\"}]}\n\ncomment格式:\"正确\"或\"错误,应为X\"(不超过50字)",
    "up": "批改物理作业。**每个填空单独识别**。输出JSON,comment不超过{{comment_max_length}}字。图片:{{image_url}}"
}

@ -0,0 +1,17 @@
{
    "config": {
        "temperature": 0.1,
        "frequency_penalty": 0,
        "top_p": 0.9,
        "max_tokens": 4096,
        "max_completion_tokens": 8192,
        "thinking_type": "enabled",
        "reasoning_effort": "medium",
        "response_format": "text",
        "json_schema": "{\"name\":\"\",\"description\":\"\",\"strict\":false,\"schema\":{}}",
        "model": "doubao-seed-2-0-pro-260215"
    },
    "tools": [],
    "sp": "你是一位专业的初中物理作业识别专家,擅长从作业图片中定位题目位置和提取答案区域。",
    "up": "请识别这张作业图片中的所有题目位置,返回准确的边界框坐标。"
}

@ -0,0 +1,156 @@
alembic==1.16.5
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
APScheduler==3.11.2
astroid==3.1.0
Authlib==1.6.9
beautifulsoup4==4.14.3
boto3==1.40.61
botocore==1.40.61
cachetools==6.2.6
certifi==2026.2.25
cffi==2.0.0
chardet==5.2.0
charset-normalizer==3.4.4
click==8.3.1
coverage==7.13.4
coze-coding-dev-sdk==0.5.11
coze-coding-utils==0.2.4
coze-workload-identity==0.1.4
cozeloop==0.1.25
cryptography==46.0.5
cssselect==1.4.0
cssutils==2.11.1
dbus-python==1.3.2
deprecation==2.1.0
dill==0.4.1
distro==1.9.0
docx2python==3.5.0
et_xmlfile==2.0.0
fastapi==0.121.2
fsspec==2026.2.0
gitdb==4.0.12
gitignore_parser==0.1.13
GitPython==3.1.45
greenlet==3.3.2
h11==0.16.0
h2==4.3.0
hpack==4.1.0
httpcore==1.0.9
httpx==0.28.1
httpx-ws==0.8.2
hyperframe==6.1.0
idna==3.11
inflect==7.5.0
iniconfig==2.3.0
isort==5.13.2
Jinja2==3.1.6
jiter==0.13.0
jmespath==1.1.0
jsonpatch==1.33
jsonpointer==3.0.0
langchain==1.0.3
langchain-core==1.0.2
langchain-openai==1.0.1
langgraph==1.0.2
langgraph-checkpoint==3.0.0
langgraph-checkpoint-postgres==3.0.1
langgraph-prebuilt==1.0.2
langgraph-sdk==0.2.9
langsmith==0.4.39
lxml==6.0.2
Mako==1.3.10
Markdown==3.10.2
markdown-it-py==4.0.0
MarkupSafe==3.0.3
mccabe==0.7.0
mdurl==0.1.2
mmh3==5.2.1
more-itertools==10.8.0
multidict==6.7.1
numpy==2.2.6
openai==2.24.0
opencv-python==4.12.0.88
openpyxl==3.1.5
orjson==3.11.7
ormsgpack==1.12.2
packaging==25.0
pandas==2.2.2
paragraphs==1.0.1
pillow==10.3.0
platformdirs==4.9.2
pluggy==1.6.0
postgrest==2.27.3
propcache==0.4.1
psutil==7.1.3
psycopg==3.3.0
psycopg-binary==3.3.0
psycopg-pool==3.3.0
psycopg2-binary==2.9.9
pycparser==3.0
pydantic==2.12.3
pydantic_core==2.41.4
Pygments==2.19.2
PyGObject==3.48.2
pyiceberg==0.11.1
PyJWT==2.10.1
pylint==3.1.0
pyparsing==3.3.2
pypdf==6.4.1
pyroaring==1.0.3
pytest==9.0.1
pytest-asyncio==1.3.0
pytest-cov==7.0.0
pytest-mock==3.15.1
python-dateutil==2.9.0.post0
python-docx==1.2.0
python-dotenv==1.2.1
python-pptx==1.0.2
pytz==2026.1.post1
PyYAML==6.0.3
realtime==2.27.3
regex==2026.2.28
reportlab==4.4.10
requests==2.32.5
requests-toolbelt==1.0.0
rich==14.2.0
s3transfer==0.14.0
setuptools==68.1.2
six==1.17.0
smmap==5.0.2
sniffio==1.3.1
soupsieve==2.8.3
sqlacodegen==4.0.2
SQLAlchemy==2.0.44
starlette==0.49.3
storage3==2.27.3
StrEnum==0.4.15
strictyaml==1.7.3
supabase==2.27.3
supabase-auth==2.27.3
supabase-functions==2.27.3
tenacity==9.1.4
tiktoken==0.12.0
tomlkit==0.14.0
tqdm==4.67.3
typeguard==4.5.1
types-html5lib==1.1.11.20251117
types-lxml==2026.2.16
types-webencodings==0.5.0.20251108
typing-inspection==0.4.2
typing_extensions==4.15.0
tzdata==2025.3
tzlocal==5.3.1
Unidecode==1.4.0
urllib3==2.6.3
uvicorn==0.38.0
watchdog==6.0.0
websockets==15.0.1
wheel==0.42.0
wsproto==1.3.2
xlrd==2.0.2
xlsxwriter==3.2.9
xxhash==3.6.0
yarl==1.23.0
zstandard==0.25.0

@ -0,0 +1,31 @@
#!/bin/bash

set -e
# Export environment variables

WORK_DIR="${COZE_WORKSPACE_PATH:-.}"
PORT=8000

usage() {
    echo "Usage: $0 -p <port>"
}

while getopts "p:h" opt; do
    case "$opt" in
        p)
            PORT="$OPTARG"
            ;;
        h)
            usage
            exit 0
            ;;
        \?)
            echo "Invalid option: -$OPTARG"
            usage
            exit 1
            ;;
    esac
done


python "${WORK_DIR}/src/main.py" -m http -p "$PORT"

@ -0,0 +1,35 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
加载项目环境变量脚本
|
||||||
|
通过 coze_workload_identity.Client 获取项目环境变量并输出 export 语句
|
||||||
|
使用方式: eval $(python load_env.py)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# 添加 app 目录到 Python 路径
|
||||||
|
workspace_path = os.getenv("COZE_WORKSPACE_PATH", "/workspace/projects")
|
||||||
|
app_dir = os.path.join(workspace_path, 'src')
|
||||||
|
if app_dir not in sys.path:
|
||||||
|
sys.path.insert(0, app_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from coze_workload_identity import Client
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
env_vars = client.get_project_env_vars()
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
# 输出 export 语句格式的环境变量
|
||||||
|
for env_var in env_vars:
|
||||||
|
# 转义特殊字符
|
||||||
|
value = env_var.value.replace("'", "'\\''")
|
||||||
|
print(f"export {env_var.key}='{value}'")
|
||||||
|
|
||||||
|
# 输出成功消息到 stderr,不影响 eval
|
||||||
|
print(f"# Successfully loaded {len(env_vars)} environment variables", file=sys.stderr)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"# Error loading environment variables: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# 加载项目环境变量脚本
|
||||||
|
# 使用方式: source ./load_env.sh 或 . ./load_env.sh
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
|
||||||
|
eval $(python3 "$SCRIPT_DIR/load_env.py")
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
mode=""
|
||||||
|
node=""
|
||||||
|
input=""
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
WORK_DIR="${COZE_WORKSPACE_PATH:-$(dirname "$SCRIPT_DIR")}"
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
echo "用法: $0 -m <模式> [-n <节点ID>] [-i <输入JSON>]"
|
||||||
|
echo ""
|
||||||
|
echo "参数说明:"
|
||||||
|
echo " -m <模式> 运行模式: http, flow, node, agent"
|
||||||
|
echo " -n <节点ID> 节点ID (仅在 node 模式下需要)"
|
||||||
|
echo " -i <输入JSON> 输入数据,支持 JSON 字符串或纯文本"
|
||||||
|
echo " -h 显示帮助信息"
|
||||||
|
echo ""
|
||||||
|
echo "示例:"
|
||||||
|
echo " $0 -m flow"
|
||||||
|
echo " $0 -m flow -i '{\"text\": \"你好\"}'"
|
||||||
|
echo " $0 -m flow -i '你好'"
|
||||||
|
echo " $0 -m node -n node_1 -i '{\"text\": \"测试\"}'"
|
||||||
|
}
|
||||||
|
|
||||||
|
while getopts "m:n:i:h" opt; do
|
||||||
|
case "$opt" in
|
||||||
|
m)
|
||||||
|
mode="$OPTARG"
|
||||||
|
;;
|
||||||
|
n)
|
||||||
|
node="$OPTARG"
|
||||||
|
;;
|
||||||
|
i)
|
||||||
|
input="$OPTARG"
|
||||||
|
;;
|
||||||
|
h)
|
||||||
|
usage
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
\?)
|
||||||
|
echo "无效选项: -$OPTARG"
|
||||||
|
usage
|
||||||
|
exit -1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ -z "$mode" ]; then
|
||||||
|
echo "错误: 必须指定 -m 参数"
|
||||||
|
usage
|
||||||
|
exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
if [ -f "${SCRIPT_DIR}/load_env.sh" ]; then
|
||||||
|
echo "Loading environment variables..."
|
||||||
|
source "${SCRIPT_DIR}/load_env.sh"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build python command
|
||||||
|
cmd="python ${WORK_DIR}/src/main.py -m \"$mode\""
|
||||||
|
|
||||||
|
if [ -n "$node" ]; then
|
||||||
|
cmd="$cmd -n \"$node\""
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -n "$input" ]; then
|
||||||
|
cmd="$cmd -i '$input'"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Execute command
|
||||||
|
echo "Executing: $cmd"
|
||||||
|
eval $cmd
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
pip freeze --exclude watchdog > requirements.txt
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
# 初始化目录
|
||||||
|
if [ "$COZE_PROJECT_ENV" = "DEV" ]; then
|
||||||
|
if [ ! -d "${COZE_WORKSPACE_PATH}/assets" ]; then
|
||||||
|
mkdir -p "${COZE_WORKSPACE_PATH}/assets"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# 安装Python三方包依赖
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""初中物理作业批改工作流主图编排 - 支持多图片批改"""
|
||||||
|
from langgraph.graph import StateGraph, END
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from graphs.state import (
|
||||||
|
GlobalState,
|
||||||
|
GraphInput,
|
||||||
|
GraphOutput
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入节点
|
||||||
|
from graphs.nodes.doc_extract_node import doc_extract_node
|
||||||
|
from graphs.nodes.process_images_node import process_images_node
|
||||||
|
|
||||||
|
|
||||||
|
# 创建状态图
|
||||||
|
builder = StateGraph(
|
||||||
|
GlobalState,
|
||||||
|
input_schema=GraphInput,
|
||||||
|
output_schema=GraphOutput
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加节点
|
||||||
|
builder.add_node("doc_extract", doc_extract_node, metadata={"type": "agent", "llm_cfg": "config/doc_extract_llm_cfg.json"})
|
||||||
|
builder.add_node("process_images", process_images_node, metadata={"type": "looparray"})
|
||||||
|
|
||||||
|
# 设置入口点
|
||||||
|
builder.set_entry_point("doc_extract")
|
||||||
|
|
||||||
|
# 添加边 - 线性流程
|
||||||
|
# 1. 先解析Word答案(如果提供了URL)
|
||||||
|
builder.add_edge("doc_extract", "process_images")
|
||||||
|
|
||||||
|
# 2. 循环处理每张图片并生成最终结果
|
||||||
|
builder.add_edge("process_images", END)
|
||||||
|
|
||||||
|
# 编译图
|
||||||
|
main_graph = builder.compile()
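
# --- Usage sketch (illustrative, not part of the original file) ---
# Assuming GraphInput and File are defined in graphs/state.py as the README
# describes (homework_images: List[File], answer_doc_url optional):
#
#   result = main_graph.invoke(GraphInput(
#       homework_images=[File(url="https://example.com/page1.jpg")],
#       answer_doc_url="https://example.com/answers.docx",
#   ))
#   print(result["final_result"])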
@ -0,0 +1,66 @@
|
||||||
|
"""单图片处理子图:封装完整的单张图片批改流程(优化版:合并识别和批改)"""
|
||||||
|
from langgraph.graph import StateGraph, END
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from coze_coding_utils.runtime_ctx.context import Context
|
||||||
|
|
||||||
|
from graphs.state import (
|
||||||
|
SubgraphState,
|
||||||
|
SubgraphInput,
|
||||||
|
SubgraphOutput,
|
||||||
|
SingleImageResult,
|
||||||
|
ImageInfo
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导入子图需要的节点
|
||||||
|
from graphs.nodes.image_preprocess_node import image_preprocess_node
|
||||||
|
from graphs.nodes.recognize_and_correct_node import recognize_and_correct_node
|
||||||
|
from graphs.nodes.result_merge_node import result_merge_node
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_result_node(
|
||||||
|
state: SubgraphState,
|
||||||
|
config: RunnableConfig,
|
||||||
|
runtime: Runtime[Context]
|
||||||
|
) -> SubgraphOutput:
|
||||||
|
"""
|
||||||
|
title: 包装子图结果
|
||||||
|
desc: 将子图的中间状态包装为SingleImageResult输出
|
||||||
|
integrations:
|
||||||
|
"""
|
||||||
|
from graphs.state import SingleImageResult
|
||||||
|
|
||||||
|
image_result = SingleImageResult(
|
||||||
|
image_index=state.image_index,
|
||||||
|
image_info=state.image_info,
|
||||||
|
image_url=state.image_url,
|
||||||
|
annotations=state.annotations
|
||||||
|
)
|
||||||
|
|
||||||
|
return SubgraphOutput(image_result=image_result)
|
||||||
|
|
||||||
|
|
||||||
|
# 创建单图片处理子图
|
||||||
|
subgraph_builder = StateGraph(
|
||||||
|
SubgraphState,
|
||||||
|
input_schema=SubgraphInput,
|
||||||
|
output_schema=SubgraphOutput
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加节点(优化后只需4个节点)
|
||||||
|
subgraph_builder.add_node("image_preprocess", image_preprocess_node)
|
||||||
|
subgraph_builder.add_node("recognize_and_correct", recognize_and_correct_node, metadata={"type": "agent", "llm_cfg": "config/homework_recognize_llm_cfg.json"})
|
||||||
|
subgraph_builder.add_node("result_merge", result_merge_node)
|
||||||
|
subgraph_builder.add_node("wrap_result", wrap_result_node)
|
||||||
|
|
||||||
|
# 设置入口点
|
||||||
|
subgraph_builder.set_entry_point("image_preprocess")
|
||||||
|
|
||||||
|
# 添加边(优化后流程:预处理 -> 识别+批改 -> 整合 -> 包装)
|
||||||
|
subgraph_builder.add_edge("image_preprocess", "recognize_and_correct")
|
||||||
|
subgraph_builder.add_edge("recognize_and_correct", "result_merge")
|
||||||
|
subgraph_builder.add_edge("result_merge", "wrap_result")
|
||||||
|
subgraph_builder.add_edge("wrap_result", END)
|
||||||
|
|
||||||
|
# 编译子图
|
||||||
|
single_image_subgraph = subgraph_builder.compile()
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
"""工作流节点模块 - 学习豆包APP的一体化识别方式"""
|
||||||
|
from graphs.nodes.image_preprocess_node import image_preprocess_node
|
||||||
|
from graphs.nodes.result_merge_node import result_merge_node
|
||||||
|
from graphs.nodes.recognize_and_correct_node import recognize_and_correct_node
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"image_preprocess_node",
|
||||||
|
"result_merge_node",
|
||||||
|
"recognize_and_correct_node",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,280 @@
|
||||||
|
"""Word答案解析节点:从.docx文件中提取题干和标准答案"""
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
import requests
|
||||||
|
from typing import List
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from coze_coding_utils.runtime_ctx.context import Context
|
||||||
|
from coze_coding_dev_sdk import LLMClient
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from docx import Document
|
||||||
|
|
||||||
|
from graphs.state import (
|
||||||
|
DocExtractInput,
|
||||||
|
DocExtractOutput,
|
||||||
|
CorrectAnswer
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_json_string(text: str) -> str:
|
||||||
|
"""清理JSON字符串中的无效转义和控制字符"""
|
||||||
|
text = ''.join(char if ord(char) >= 32 or char in '\n\r\t' else ' ' for char in text)
|
||||||
|
result = []
|
||||||
|
i = 0
|
||||||
|
in_string = False
|
||||||
|
while i < len(text):
|
||||||
|
char = text[i]
|
||||||
|
if char == '"' and (i == 0 or (i > 0 and result and result[-1] != '\\')):
|
||||||
|
in_string = not in_string
|
||||||
|
result.append(char)
|
||||||
|
elif in_string and char == '\\':
|
||||||
|
if i + 1 < len(text):
|
||||||
|
next_char = text[i + 1]
|
||||||
|
if next_char in ['"', '\\', '/', 'b', 'f', 'n', 'r', 't']:
|
||||||
|
result.append(char)
|
||||||
|
result.append(next_char)
|
||||||
|
i += 1
|
||||||
|
elif next_char == 'u':
|
||||||
|
result.append(char)
|
||||||
|
result.append(next_char)
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
result.append('\\\\')
|
||||||
|
else:
|
||||||
|
result.append('\\\\')
|
||||||
|
elif in_string and char in '\n\r':
|
||||||
|
result.append('\\n' if char == '\n' else '\\r')
|
||||||
|
else:
|
||||||
|
result.append(char)
|
||||||
|
i += 1
|
||||||
|
return ''.join(result)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_json_from_text(text: str, key: str = "answers") -> dict:
|
||||||
|
"""从文本中提取JSON对象,多层回退策略"""
|
||||||
|
import orjson
|
||||||
|
|
||||||
|
# 先尝试直接解析完整JSON
|
||||||
|
try:
|
||||||
|
result = json.loads(text)
|
||||||
|
if key in result:
|
||||||
|
return result
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
logger.debug(f"Direct JSON parse failed: {e}")
|
||||||
|
|
||||||
|
# 尝试找到JSON对象的边界
|
||||||
|
try:
|
||||||
|
# 找到第一个 { 和最后一个 }
|
||||||
|
start = text.find('{')
|
||||||
|
end = text.rfind('}')
|
||||||
|
if start != -1 and end != -1 and end > start:
|
||||||
|
json_str = text[start:end+1]
|
||||||
|
result = json.loads(json_str)
|
||||||
|
if key in result:
|
||||||
|
return result
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
logger.debug(f"Boundary JSON parse failed: {e}")
|
||||||
|
|
||||||
|
# 使用orjson尝试
|
||||||
|
try:
|
||||||
|
start = text.find('{')
|
||||||
|
end = text.rfind('}')
|
||||||
|
if start != -1 and end != -1 and end > start:
|
||||||
|
json_str = text[start:end+1]
|
||||||
|
result = orjson.loads(json_str)
|
||||||
|
if key in result:
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Orjson parse failed: {e}")
|
||||||
|
|
||||||
|
# 尝试清理后解析
|
||||||
|
try:
|
||||||
|
cleaned = sanitize_json_string(text)
|
||||||
|
start = cleaned.find('{')
|
||||||
|
end = cleaned.rfind('}')
|
||||||
|
if start != -1 and end != -1 and end > start:
|
||||||
|
json_str = cleaned[start:end+1]
|
||||||
|
result = json.loads(json_str)
|
||||||
|
if key in result:
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Cleaned JSON parse failed: {e}")
|
||||||
|
|
||||||
|
logger.warning(f"Failed to extract JSON with key '{key}' from text (length: {len(text)})")
|
||||||
|
return {key: []}
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_extract_docx(url: str) -> str:
|
||||||
|
"""下载Word文件并提取文本内容"""
|
||||||
|
logger.info(f"Downloading Word document from: {url}")
|
||||||
|
|
||||||
|
# 下载文件
|
||||||
|
response = requests.get(url, timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 保存到临时文件
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.docx', delete=False) as tmp_file:
|
||||||
|
tmp_file.write(response.content)
|
||||||
|
tmp_path = tmp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用python-docx解析
|
||||||
|
doc = Document(tmp_path)
|
||||||
|
|
||||||
|
# 提取所有段落文本
|
||||||
|
paragraphs = []
|
||||||
|
for para in doc.paragraphs:
|
||||||
|
text = para.text.strip()
|
||||||
|
if text:
|
||||||
|
paragraphs.append(text)
|
||||||
|
|
||||||
|
# 提取表格中的文本
|
||||||
|
for table in doc.tables:
|
||||||
|
for row in table.rows:
|
||||||
|
row_text = []
|
||||||
|
for cell in row.cells:
|
||||||
|
cell_text = cell.text.strip()
|
||||||
|
if cell_text:
|
||||||
|
row_text.append(cell_text)
|
||||||
|
if row_text:
|
||||||
|
paragraphs.append(" | ".join(row_text))
|
||||||
|
|
||||||
|
doc_text = "\n".join(paragraphs)
|
||||||
|
logger.info(f"Extracted Word document text length: {len(doc_text)}")
|
||||||
|
|
||||||
|
return doc_text
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
def doc_extract_node(
|
||||||
|
state: DocExtractInput,
|
||||||
|
config: RunnableConfig,
|
||||||
|
runtime: Runtime[Context]
|
||||||
|
) -> DocExtractOutput:
|
||||||
|
"""
|
||||||
|
title: Word答案解析
|
||||||
|
desc: 从正确答案Word文件(.docx)中提取题干和标准答案,用于后续批改;如果未提供URL则返回空列表
|
||||||
|
integrations: 大语言模型
|
||||||
|
"""
|
||||||
|
ctx = runtime.context
|
||||||
|
|
||||||
|
# 检查是否提供了答案文档URL
|
||||||
|
if not state.answer_doc_url or not state.answer_doc_url.strip():
|
||||||
|
logger.info("No answer document URL provided, will use teacher mode for all questions")
|
||||||
|
return DocExtractOutput(correct_answers=[])
|
||||||
|
|
||||||
|
# 1. 下载并提取Word文档内容
|
||||||
|
try:
|
||||||
|
doc_text = download_and_extract_docx(state.answer_doc_url)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download/extract Word document: {e}")
|
||||||
|
return DocExtractOutput(correct_answers=[])
|
||||||
|
|
||||||
|
if not doc_text.strip():
|
||||||
|
logger.error("No text content extracted from Word document")
|
||||||
|
return DocExtractOutput(correct_answers=[])
|
||||||
|
|
||||||
|
logger.info(f"Word document content preview: {doc_text[:500]}")
|
||||||
|
|
||||||
|
# 2. 使用LLM解析题目和答案
|
||||||
|
cfg_file = os.path.join(os.getenv("COZE_WORKSPACE_PATH", ""), config["metadata"]["llm_cfg"])
|
||||||
|
with open(cfg_file, "r", encoding="utf-8") as fd:
|
||||||
|
_cfg = json.load(fd)
|
||||||
|
|
||||||
|
llm_config = _cfg.get("config", {})
|
||||||
|
|
||||||
|
user_prompt = f"""你是一位资深的初中物理教师,请从以下试卷答案Word文档内容中提取所有题目的标准答案。
|
||||||
|
|
||||||
|
【Word文档内容】
|
||||||
|
{doc_text[:20000]}
|
||||||
|
|
||||||
|
【提取要求】
|
||||||
|
1. 识别所有题号,包括大题和小题
|
||||||
|
2. 只提取每道题的标准答案,不需要提取题干
|
||||||
|
3. 答案格式:
|
||||||
|
- 选择题:单个字母(A/B/C/D)
|
||||||
|
- 填空题:数值或表达式
|
||||||
|
- 解答题:关键结果
|
||||||
|
4. 如果有多个空,用逗号分隔
|
||||||
|
|
||||||
|
【题号格式规范】
|
||||||
|
- 大题题号:直接使用数字,如"1"、"2"、"10"
|
||||||
|
- 小题题号:使用"大题.小题"格式,如"10.1"、"10.2"
|
||||||
|
|
||||||
|
【重要】
|
||||||
|
- 保持答案简洁,不要包含题干
|
||||||
|
- 如果文档中有分值,请提取;否则默认3分
|
||||||
|
|
||||||
|
请返回简洁的JSON格式:
|
||||||
|
{{
|
||||||
|
"answers": [
|
||||||
|
{{
|
||||||
|
"question_id": "1",
|
||||||
|
"correct_answer": "B",
|
||||||
|
"full_score": 3
|
||||||
|
}},
|
||||||
|
{{
|
||||||
|
"question_id": "2",
|
||||||
|
"correct_answer": "2a-b+c",
|
||||||
|
"full_score": 3
|
||||||
|
}}
|
||||||
|
]
|
||||||
|
}}"""
|
||||||
|
|
||||||
|
client = LLMClient(ctx=ctx)
|
||||||
|
response = client.invoke(
|
||||||
|
messages=[HumanMessage(content=user_prompt)],
|
||||||
|
model=llm_config.get("model", "doubao-seed-2-0-pro-260215"),
|
||||||
|
temperature=llm_config.get("temperature", 0.1),
|
||||||
|
max_completion_tokens=llm_config.get("max_completion_tokens", 8192)
|
||||||
|
)
|
||||||
|
|
||||||
|
response_text = response.content if isinstance(response.content, str) else " ".join(
|
||||||
|
item.get("text", "") if isinstance(item, dict) else str(item)
|
||||||
|
for item in response.content
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
logger.info(f"Word extract LLM response: {response_text[:1500]}")
|
||||||
|
|
||||||
|
# 清理markdown标记
|
||||||
|
for prefix in ["```json", "```JSON", "```"]:
|
||||||
|
if response_text.startswith(prefix):
|
||||||
|
response_text = response_text[len(prefix):]
|
||||||
|
for suffix in ["```"]:
|
||||||
|
if response_text.endswith(suffix):
|
||||||
|
response_text = response_text[:-3]
|
||||||
|
|
||||||
|
# 解析JSON
|
||||||
|
result_dict = extract_json_from_text(response_text.strip(), "answers")
|
||||||
|
|
||||||
|
correct_answers: List[CorrectAnswer] = []
|
||||||
|
for ans in result_dict.get("answers", []):
|
||||||
|
try:
|
||||||
|
correct_answers.append(CorrectAnswer(
|
||||||
|
question_id=str(ans.get("question_id", "")),
|
||||||
|
parent_id=str(ans.get("parent_id", "")),
|
||||||
|
is_sub_question=bool(ans.get("is_sub_question", False)),
|
||||||
|
question_text=str(ans.get("question_text", "")),
|
||||||
|
correct_answer=str(ans.get("correct_answer", "")),
|
||||||
|
full_score=int(ans.get("full_score", 3) if ans.get("full_score") is not None else 3),
|
||||||
|
answer_analysis=str(ans.get("answer_analysis", ""))
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse correct answer: {ans}, error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"Parsed {len(correct_answers)} correct answers from Word document")
|
||||||
|
|
||||||
|
# 打印解析结果供调试
|
||||||
|
for ans in correct_answers:
|
||||||
|
logger.info(f" Question {ans.question_id}: {ans.correct_answer} ({ans.full_score}分)")
|
||||||
|
|
||||||
|
return DocExtractOutput(correct_answers=correct_answers)
|
||||||
|
|
@ -0,0 +1,159 @@
|
||||||
|
"""1. 图像预处理节点:下载图片、自动旋转、缩放到固定宽度1000、上传对象存储"""
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import urllib.request
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Tuple
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from coze_coding_utils.runtime_ctx.context import Context
|
||||||
|
from PIL import Image
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from graphs.state import (
|
||||||
|
ImagePreprocessInput,
|
||||||
|
ImagePreprocessOutput,
|
||||||
|
ImageInfo
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 固定宽度常量
|
||||||
|
FIXED_WIDTH = 1000
|
||||||
|
|
||||||
|
|
||||||
|
def download_and_process_image(image_url: str) -> Tuple[Image.Image, int, int, int]:
|
||||||
|
"""下载图片并返回PIL Image对象和原始尺寸信息"""
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(image_url, timeout=30) as response:
|
||||||
|
img_data = response.read()
|
||||||
|
img = Image.open(BytesIO(img_data))
|
||||||
|
|
||||||
|
# 转换为RGB模式(处理PNG透明通道)
|
||||||
|
if img.mode in ('RGBA', 'P'):
|
||||||
|
img = img.convert('RGB')
|
||||||
|
|
||||||
|
original_width, original_height = img.size
|
||||||
|
dpi = img.info.get('dpi', (72, 72))
|
||||||
|
if isinstance(dpi, tuple):
|
||||||
|
dpi = dpi[0] if dpi[0] > 0 else 72
|
||||||
|
else:
|
||||||
|
dpi = 72
|
||||||
|
|
||||||
|
return img, original_width, original_height, dpi
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"下载图片失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def auto_rotate_image(img: Image.Image) -> Tuple[Image.Image, bool]:
|
||||||
|
"""
|
||||||
|
自动旋转图片:如果宽度大于高度(横向图片),则旋转-90度使其变为纵向
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img: PIL Image对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(旋转后的图片, 是否进行了旋转)
|
||||||
|
"""
|
||||||
|
width, height = img.size
|
||||||
|
|
||||||
|
# 如果宽度大于高度,需要旋转
|
||||||
|
if width > height:
|
||||||
|
logger.info(f"检测到横向图片(宽{width} > 高{height}),正在旋转-90度...")
|
||||||
|
# rotate(-90) 表示逆时针旋转90度,使横向变为纵向
|
||||||
|
rotated_img = img.rotate(-90, expand=True)
|
||||||
|
new_width, new_height = rotated_img.size
|
||||||
|
logger.info(f"旋转完成,新尺寸:宽{new_width} x 高{new_height}")
|
||||||
|
return rotated_img, True
|
||||||
|
|
||||||
|
logger.info(f"图片为纵向(宽{width} <= 高{height}),无需旋转")
|
||||||
|
return img, False
|
||||||
|
|
||||||
|
|
||||||
|
def resize_image_to_fixed_width(img: Image.Image, target_width: int = FIXED_WIDTH) -> Tuple[Image.Image, float]:
|
||||||
|
"""将图片缩放到固定宽度,高度等比例缩放"""
|
||||||
|
original_width, original_height = img.size
|
||||||
|
|
||||||
|
if original_width == target_width:
|
||||||
|
return img, 1.0
|
||||||
|
|
||||||
|
# 计算缩放比例
|
||||||
|
scale_ratio = target_width / original_width
|
||||||
|
new_height = int(original_height * scale_ratio)
|
||||||
|
|
||||||
|
# 使用高质量重采样
|
||||||
|
# BICUBIC = 3, LANCZOS = 4 (PIL内部常量值)
|
||||||
|
resized_img = img.resize((target_width, new_height), 3) # BICUBIC
|
||||||
|
|
||||||
|
return resized_img, scale_ratio
|
||||||
|
|
||||||
|
|
||||||
|
def upload_image_to_storage(img: Image.Image, ctx) -> str:
|
||||||
|
"""将图片上传到对象存储并返回URL"""
|
||||||
|
from coze_coding_dev_sdk.s3 import S3SyncStorage
|
||||||
|
|
||||||
|
# 转换为字节流
|
||||||
|
img_buffer = BytesIO()
|
||||||
|
img.save(img_buffer, format='JPEG', quality=95)
|
||||||
|
img_bytes = img_buffer.getvalue()
|
||||||
|
|
||||||
|
# 上传到对象存储
|
||||||
|
storage = S3SyncStorage(
|
||||||
|
endpoint_url=os.getenv("COZE_BUCKET_ENDPOINT_URL"),
|
||||||
|
access_key="",
|
||||||
|
secret_key="",
|
||||||
|
bucket_name=os.getenv("COZE_BUCKET_NAME"),
|
||||||
|
region="cn-beijing",
|
||||||
|
)
|
||||||
|
|
||||||
|
file_key = storage.upload_file(
|
||||||
|
file_content=img_bytes,
|
||||||
|
file_name=f"homework_resized_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jpg",
|
||||||
|
content_type="image/jpeg",
|
||||||
|
)
|
||||||
|
|
||||||
|
result_url = storage.generate_presigned_url(key=file_key, expire_time=86400)
|
||||||
|
return result_url
|
||||||
|
|
||||||
|
|
||||||
|
def image_preprocess_node(
|
||||||
|
state: ImagePreprocessInput,
|
||||||
|
config: RunnableConfig,
|
||||||
|
runtime: Runtime[Context]
|
||||||
|
) -> ImagePreprocessOutput:
|
||||||
|
"""
|
||||||
|
title: 图像预处理
|
||||||
|
desc: 下载图片、自动旋转(横向→纵向)、缩放到固定宽度1000px、上传对象存储
|
||||||
|
integrations: 对象存储
|
||||||
|
"""
|
||||||
|
ctx = runtime.context
|
||||||
|
|
||||||
|
# 1. 下载原始图片
|
||||||
|
img, original_width, original_height, dpi = download_and_process_image(
|
||||||
|
state.homework_image.url
|
||||||
|
)
|
||||||
|
logger.info(f"原始图片尺寸:宽{original_width} x 高{original_height}")
|
||||||
|
|
||||||
|
# 2. 自动旋转:如果宽度大于高度,旋转-90度使其变为纵向
|
||||||
|
img, was_rotated = auto_rotate_image(img)
|
||||||
|
after_rotate_width, after_rotate_height = img.size
|
||||||
|
if was_rotated:
|
||||||
|
logger.info(f"旋转后尺寸:宽{after_rotate_width} x 高{after_rotate_height}")
|
||||||
|
|
||||||
|
# 3. 缩放到固定宽度1000px,高度等比例缩放
|
||||||
|
resized_img, scale_ratio = resize_image_to_fixed_width(img, FIXED_WIDTH)
|
||||||
|
new_width, new_height = resized_img.size
|
||||||
|
logger.info(f"缩放后尺寸:宽{new_width} x 高{new_height},缩放比例:{scale_ratio:.3f}")
|
||||||
|
|
||||||
|
# 4. 上传处理后的图片到对象存储
|
||||||
|
processed_image_url = upload_image_to_storage(resized_img, ctx)
|
||||||
|
|
||||||
|
# 5. 返回处理后的图片信息(AI基于这个尺寸计算坐标)
|
||||||
|
return ImagePreprocessOutput(
|
||||||
|
image_info=ImageInfo(
|
||||||
|
width=new_width, # 缩放后的宽度:1000
|
||||||
|
height=new_height, # 缩放后的高度
|
||||||
|
dpi=dpi
|
||||||
|
),
|
||||||
|
image_url=processed_image_url
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,167 @@
|
||||||
|
"""多图片处理循环节点:并行调用子图处理每张作业图片"""
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
from coze_coding_utils.runtime_ctx.context import Context
|
||||||
|
|
||||||
|
from graphs.state import (
|
||||||
|
ProcessImagesInput,
|
||||||
|
ProcessImagesOutput,
|
||||||
|
SingleImageResult,
|
||||||
|
SubgraphInput,
|
||||||
|
FinalResult,
|
||||||
|
ImageInfo
|
||||||
|
)
|
||||||
|
from graphs.loop_graph import single_image_subgraph
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def process_images_node(
|
||||||
|
state: ProcessImagesInput,
|
||||||
|
config: RunnableConfig,
|
||||||
|
runtime: Runtime[Context]
|
||||||
|
) -> ProcessImagesOutput:
|
||||||
|
"""
|
||||||
|
title: 多图片批改处理
|
||||||
|
desc: 并行调用子图处理每张作业图片,生成最终批改结果
|
||||||
|
integrations:
|
||||||
|
"""
|
||||||
|
ctx = runtime.context
|
||||||
|
|
||||||
|
# 获取并发数限制(从参数获取,默认10)
|
||||||
|
max_concurrent = getattr(state, 'max_concurrent', 10)
|
||||||
|
|
||||||
|
logger.info(f"Starting to process {len(state.homework_images)} images (concurrent={max_concurrent})")
|
||||||
|
|
||||||
|
# 定义处理单张图片的函数
|
||||||
|
def process_single_image(idx: int, homework_image) -> SingleImageResult:
|
||||||
|
logger.info(f"Processing image {idx + 1}/{len(state.homework_images)}")
|
||||||
|
try:
|
||||||
|
# 构建子图输入
|
||||||
|
subgraph_input = SubgraphInput(
|
||||||
|
homework_image=homework_image,
|
||||||
|
correct_answers=state.correct_answers,
|
||||||
|
image_index=idx,
|
||||||
|
comment_max_length=state.comment_max_length
|
||||||
|
)
|
||||||
|
            # Invoke the subgraph
            subgraph_output = single_image_subgraph.invoke(subgraph_input, config)

            # Extract the result (the subgraph may return either a dict or a Pydantic object)
            image_result = None
            if subgraph_output:
                if isinstance(subgraph_output, dict):
                    result_data = subgraph_output.get("image_result")
                    if result_data:
                        if isinstance(result_data, dict):
                            image_result = SingleImageResult(**result_data)
                        else:
                            image_result = result_data
                elif hasattr(subgraph_output, 'image_result'):
                    image_result = subgraph_output.image_result

            if image_result:
                logger.info(f"Image {idx + 1} processed successfully: {len(image_result.annotations)} annotations")
                return image_result
            else:
                logger.warning(f"Image {idx + 1} subgraph returned invalid output")
                return SingleImageResult(
                    image_index=idx,
                    image_info=ImageInfo(width=0, height=0, dpi=72),
                    annotations=[]
                )
        except Exception as e:
            logger.error(f"Failed to process image {idx + 1}: {e}", exc_info=True)
            return SingleImageResult(
                image_index=idx,
                image_info=ImageInfo(width=0, height=0, dpi=72),
                annotations=[]
            )

    # Process all images in parallel
    image_results: List[SingleImageResult] = []

    with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
        # Submit one task per image
        future_to_idx = {
            executor.submit(process_single_image, idx, img): idx
            for idx, img in enumerate(state.homework_images)
        }

        # Collect the results as tasks complete
        results_dict = {}
        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                result = future.result()
                results_dict[idx] = result
            except Exception as e:
                logger.error(f"Future failed for image {idx}: {e}")
                results_dict[idx] = SingleImageResult(
                    image_index=idx,
                    image_info=ImageInfo(width=0, height=0, dpi=72),
                    annotations=[]
                )

    # Re-order the results into the original image order
    for idx in range(len(state.homework_images)):
        image_results.append(results_dict[idx])

    logger.info(f"Completed processing {len(image_results)} images")

    # Build the final result
    total_score = 0
    full_score = 0
    total_questions = 0
    correct_count = 0
    incorrect_count = 0

    for result in image_results:
        for annotation in result.annotations:
            total_score += annotation.score
            full_score += annotation.full_score
            total_questions += 1
            if annotation.status == "correct":
                correct_count += 1
            elif annotation.status == "incorrect":
                incorrect_count += 1

    # Compute the score rate
    score_rate = (total_score / full_score * 100) if full_score > 0 else 0

    # Generate the overall comment
    if total_questions == 0:
        overall_comment = "未识别到题目内容"
        grade = "D"
    elif score_rate >= 95:
        overall_comment = f"优秀!{total_questions}题全部正确,掌握扎实,继续保持。"
        grade = "A+"
    elif score_rate >= 90:
        overall_comment = f"良好!得分率{score_rate:.0f}%,错{incorrect_count}题,注意细节。"
        grade = "A"
    elif score_rate >= 80:
        overall_comment = f"合格。得分率{score_rate:.0f}%,错{incorrect_count}题,需加强练习。"
        grade = "B"
    elif score_rate >= 70:
        overall_comment = f"及格。错{incorrect_count}题,部分知识点掌握不牢,建议复习。"
        grade = "C"
    else:
        overall_comment = f"需努力。得分率{score_rate:.0f}%,建议认真复习,多做练习。"
        grade = "D"

    final_result = FinalResult(
        total_images=len(image_results),
        image_results=image_results,
        overall_comment=overall_comment,
        total_score=total_score,
        full_score=full_score,
        grade=grade
    )

    logger.info(f"Final result: {total_score}/{full_score}, grade={grade}")

    return ProcessImagesOutput(final_result=final_result)

@ -0,0 +1,340 @@
"""Combined recognition-and-correction node: merges recognition and grading into a single LLM call for speed."""
import os
import json
import re
import logging
from typing import List, Dict, Any
from jinja2 import Template
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context
from coze_coding_dev_sdk import LLMClient
from langchain_core.messages import HumanMessage

from graphs.state import (
    RecognizeAndCorrectInput,
    RecognizeAndCorrectOutput,
    QuestionItem,
    CorrectionResult,
    MarkPosition,
    CorrectAnswer
)

logger = logging.getLogger(__name__)


def fix_incomplete_json(text: str) -> str:
    """Try to repair an incomplete JSON string by closing unbalanced brackets."""
    # Count unbalanced brackets
    brace_count = text.count('{') - text.count('}')
    bracket_count = text.count('[') - text.count(']')

    # Append the missing closing brackets (arrays first, then objects)
    if bracket_count > 0:
        text += ']' * bracket_count
    if brace_count > 0:
        text += '}' * brace_count

    return text
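

# Usage sketch (doctest-style, with a hypothetical truncated LLM output): the helper
# only appends missing closing brackets, so it repairs tail truncation but not other damage.
#
#   >>> fix_incomplete_json('{"results": [{"question_id": "1", "score": 8}')
#   '{"results": [{"question_id": "1", "score": 8}]}'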


def extract_json_from_text(text: str, key: str = "results") -> dict:
    """Extract a JSON object from free-form text, with extra robustness."""
    import orjson  # lazy import; only needed as a fallback parser

    # Strip markdown code-fence markers
    text = text.strip()
    for prefix in ["```json", "```JSON", "```"]:
        if text.startswith(prefix):
            text = text[len(prefix):]
            break
    if text.endswith("```"):
        text = text[:-3]

    text = text.strip()

    # Try to parse directly (handles well-formed, pretty-printed JSON)
    for parser in [json.loads, orjson.loads]:
        try:
            result = parser(text)
            if isinstance(result, dict) and key in result:
                logger.info("Successfully parsed JSON directly")
                return result
        except Exception as e:
            logger.debug(f"Direct parse failed: {e}")

    # Try to repair an incomplete JSON string
    try:
        fixed_text = fix_incomplete_json(text)
        for parser in [json.loads, orjson.loads]:
            try:
                result = parser(fixed_text)
                if isinstance(result, dict) and key in result:
                    logger.info("Successfully parsed fixed JSON")
                    return result
            except Exception:
                pass
    except Exception:
        pass

    # Try to extract the first complete JSON object around the target key
    key_pattern = rf'"{re.escape(key)}"\s*:\s*\['
    match = re.search(key_pattern, text)

    if match:
        start_pos = text.rfind('{', 0, match.start())
        if start_pos == -1:
            start_pos = 0

        # Starting from start_pos, scan for the matching closing brace
        brace_count = 0
        bracket_count = 0
        in_string = False
        escape = False
        end_pos = start_pos

        for i in range(start_pos, len(text)):
            char = text[i]

            if escape:
                escape = False
                continue

            if char == '\\':
                escape = True
                continue

            if char == '"':
                in_string = not in_string
                continue

            if in_string:
                continue

            if char == '{':
                brace_count += 1
            elif char == '}':
                brace_count -= 1
                if brace_count == 0 and bracket_count == 0:
                    end_pos = i + 1
                    break
            elif char == '[':
                bracket_count += 1
            elif char == ']':
                bracket_count -= 1

        if end_pos > start_pos:
            json_str = text[start_pos:end_pos]

            for parser in [json.loads, orjson.loads]:
                try:
                    result = parser(json_str)
                    if isinstance(result, dict) and key in result:
                        logger.info("Successfully extracted and parsed JSON object")
                        return result
                except Exception as e:
                    logger.debug(f"Failed to parse extracted JSON: {e}")

        # If no complete object was found, try to repair the incomplete tail
        if end_pos <= start_pos:
            # Take everything from the JSON start position onward
            json_str = text[start_pos:]

            # Attempt the repair
            fixed_json = fix_incomplete_json(json_str)
            for parser in [json.loads, orjson.loads]:
                try:
                    result = parser(fixed_json)
                    if isinstance(result, dict) and key in result:
                        logger.info("Successfully parsed fixed incomplete JSON")
                        return result
                except Exception:
                    pass

    logger.warning(f"Failed to extract JSON with key '{key}' from text length {len(text)}")
    return {key: []}
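

# Usage sketch (hypothetical fenced LLM output): the extractor peels the markdown
# fence and then parses the object that contains the requested key.
#
#   >>> extract_json_from_text('```json\n{"results": [{"question_id": "1"}]}\n```')
#   {'results': [{'question_id': '1'}]}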


def build_dynamic_prompt(
    image_width: int,
    image_height: int,
    correct_answers: List[CorrectAnswer],
    comment_max_length: int
) -> str:
    """Build the dynamic part of the prompt (prompt text itself stays in Chinese for the model)."""

    if correct_answers:
        answers_text = json.dumps(
            [{"question_id": a.question_id, "correct_answer": a.correct_answer}
             for a in correct_answers],
            ensure_ascii=False
        )
        answer_hint = f"""
【标准答案】
{answers_text}"""
    else:
        answer_hint = "\n【批改模式】无标准答案,请根据物理知识判断。"

    return f"""
【图片尺寸】{image_width}×{image_height}像素
【评语限制】{comment_max_length}字以内
{answer_hint}

⚠️ 必须输出完整JSON,不要输出思考过程"""
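

# Sample rendered output (hypothetical 1000x1414 page with one known answer;
# the 【...】 markers come from the template above):
#
#   【图片尺寸】1000×1414像素
#   【评语限制】100字以内
#
#   【标准答案】
#   [{"question_id": "10", "correct_answer": "3.2 m/s"}]
#
#   ⚠️ 必须输出完整JSON,不要输出思考过程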


def recognize_and_correct_node(
    state: RecognizeAndCorrectInput,
    config: RunnableConfig,
    runtime: Runtime[Context]
) -> RecognizeAndCorrectOutput:
    """
    title: 一体化识别批改
    desc: 合并识别和批改为一次LLM调用,提升批改速度
    integrations: 大语言模型
    """
    ctx = runtime.context

    # Load the LLM configuration
    cfg_file = os.path.join(os.getenv("COZE_WORKSPACE_PATH", ""), config["metadata"]["llm_cfg"])
    with open(cfg_file, "r", encoding="utf-8") as fd:
        _cfg = json.load(fd)

    llm_config = _cfg.get("config", {})
    sp = _cfg.get("sp", "")
    up = _cfg.get("up", "")

    # Read the node parameters
    image_url = state.image_url
    image_info = state.image_info
    correct_answers = state.correct_answers
    comment_max_length = getattr(state, 'comment_max_length', 100)

    # Render the user prompt with Jinja2
    up_tpl = Template(up)
    user_prompt = up_tpl.render({
        "image_url": image_url,
        "comment_max_length": comment_max_length
    })

    # Build the dynamic part
    dynamic_prompt = build_dynamic_prompt(
        image_info.width,
        image_info.height,
        correct_answers,
        comment_max_length
    )

    # Assemble the full prompt
    full_prompt = f"{sp}\n\n{user_prompt}\n{dynamic_prompt}"

    # Invoke the LLM
    messages = [
        HumanMessage(content=[
            {"type": "text", "text": full_prompt},
            {"type": "image_url", "image_url": {"url": image_url}}
        ])
    ]

    client = LLMClient(ctx=ctx)
    response = client.invoke(
        messages=messages,
        model=llm_config.get("model", "doubao-seed-2-0-pro-260215"),
        temperature=llm_config.get("temperature", 0.0),
        max_completion_tokens=llm_config.get("max_completion_tokens", 4096)
    )

    response_text = response.content if isinstance(response.content, str) else " ".join(
        item.get("text", "") if isinstance(item, dict) else str(item)
        for item in response.content
    ).strip()

    # Log only the first 500 characters to keep log size down
    logger.info(f"LLM response (first 500 chars): {response_text[:500]}")

    # Parse the result
    result_dict = extract_json_from_text(response_text, "results")

    # Convert relative coordinates (0-1000) to absolute pixel coordinates
    width_scale = image_info.width / 1000.0 if image_info.width > 0 else 1.0
    height_scale = image_info.height / 1000.0 if image_info.height > 0 else 1.0
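
    # Worked example (hypothetical): on a 2000x3000 px image, width_scale=2.0 and
    # height_scale=3.0, so a model bbox of [100, 200, 300, 400] on the 0-1000 grid
    # maps to [200, 600, 600, 1200] in absolute pixels.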

    question_items: List[QuestionItem] = []
    correction_results: List[CorrectionResult] = []

    results_list = result_dict.get("results", [])
    if not isinstance(results_list, list):
        results_list = []

    for r in results_list:
        try:
            if not isinstance(r, dict):
                continue

            # Parse the bbox (relative coordinates, 0-1000)
            answer_bbox = r.get("answer_bbox", [0, 0, 0, 0])
            if not isinstance(answer_bbox, list) or len(answer_bbox) != 4:
                answer_bbox = [0, 0, 0, 0]

            # Convert to absolute coordinates
            answer_bbox_abs = [
                int(answer_bbox[0] * width_scale),
                int(answer_bbox[1] * height_scale),
                int(answer_bbox[2] * width_scale),
                int(answer_bbox[3] * height_scale)
            ]

            # Derive mark_position automatically
            bbox_width = answer_bbox_abs[2] - answer_bbox_abs[0]
            bbox_height = answer_bbox_abs[3] - answer_bbox_abs[1]

            if bbox_width > 0 and bbox_height > 0:
                # Place the mark to the right of the answer, vertically centered
                mark_x = answer_bbox_abs[2] + 30
                mark_y = answer_bbox_abs[1] + int(bbox_height * 0.5)

                if mark_x > image_info.width - 50:
                    mark_x = image_info.width - 50
            else:
                # Degenerate bbox: fall back to a fixed position
                mark_x = 500
                mark_y = 500

            mark_position = MarkPosition(x=mark_x, y=mark_y)

            # Build the QuestionItem
            question_items.append(QuestionItem(
                question_id=str(r.get("question_id", "")),
                parent_id=str(r.get("parent_id", "")),
                is_sub_question=bool(r.get("is_sub_question", False)),
                question_text=str(r.get("question_text", "")),
                student_answer=str(r.get("student_answer", "")),
                answer_bbox=answer_bbox_abs,
                mark_position=mark_position,
                full_score=int(r.get("full_score", 10) if r.get("full_score") is not None else 10)
            ))

            # Build the CorrectionResult
            status = str(r.get("status", "incorrect"))
            if status not in ["correct", "incorrect", "partial"]:
                status = "incorrect"

            # Use the raw comment without truncation (length is controlled by the LLM)
            comment = str(r.get("comment", ""))

            correction_results.append(CorrectionResult(
                question_id=str(r.get("question_id", "")),
                parent_id=str(r.get("parent_id", "")),
                status=status,
                score=int(r.get("score", 0) if r.get("score") is not None else 0),
                full_score=int(r.get("full_score", 10) if r.get("full_score") is not None else 10),
                comment=comment
            ))

        except Exception as e:
            logger.warning(f"Failed to parse result: {r}, error: {e}")
            continue

    logger.info(f"Parsed {len(question_items)} questions and {len(correction_results)} correction results")

    return RecognizeAndCorrectOutput(
        question_items=question_items,
        correction_results=correction_results
    )

@ -0,0 +1,65 @@
"""4. Result-merge node: merges recognition results and correction results into the final annotations"""
from typing import List
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from coze_coding_utils.runtime_ctx.context import Context

from graphs.state import (
    ResultMergeInput,
    ResultMergeOutput,
    Annotation,
    QuestionItem,
    CorrectionResult
)


def result_merge_node(
    state: ResultMergeInput,
    config: RunnableConfig,
    runtime: Runtime[Context]
) -> ResultMergeOutput:
    """
    title: 结果整合
    desc: 将识别结果和批改结果合并为最终批注列表
    integrations:
    """
    ctx = runtime.context

    annotations: List[Annotation] = []

    # Index the correction results by question_id for fast lookup
    correction_dict = {c.question_id: c for c in state.correction_results}

    for q in state.question_items:
        # Look up the matching correction result
        correction = correction_dict.get(q.question_id)

        if correction is None:
            # No correction result found: default to "incorrect"
            annotations.append(Annotation(
                question_id=q.question_id,
                parent_id=q.parent_id,
                status="incorrect",
                question_text=q.question_text,
                student_answer=q.student_answer,
                answer_bbox=q.answer_bbox,
                comment="未找到批改结果",
                mark_position=q.mark_position,
                score=0,
                full_score=q.full_score
            ))
        else:
            annotations.append(Annotation(
                question_id=q.question_id,
                parent_id=q.parent_id,
                status=correction.status,
                question_text=q.question_text,
                student_answer=q.student_answer,
                answer_bbox=q.answer_bbox,
                comment=correction.comment,
                mark_position=q.mark_position,
                score=correction.score,
                full_score=correction.full_score
            ))

    return ResultMergeOutput(annotations=annotations)
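

# Merge-behavior sketch (hypothetical data): a recognized question whose id is
# missing from correction_results is annotated as incorrect with score 0, while
# matched ids inherit status/score/comment from the correction result. So for
# question_items with ids ["10.1", "10.2"] and a correction only for "10.1",
# the node emits two annotations: "10.1" carrying its correction, and "10.2"
# marked incorrect with comment "未找到批改结果".
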
@ -0,0 +1,232 @@
"""State definitions for the middle-school physics homework correction workflow - supports multi-image correction"""
from typing import List, Optional, Literal
from pydantic import BaseModel, Field
from utils.file.file import File


# === Basic data structures ===

class ImageInfo(BaseModel):
    """Image information"""
    width: int = Field(..., description="Image width in pixels")
    height: int = Field(..., description="Image height in pixels")
    dpi: int = Field(default=72, description="Image DPI")


class MarkPosition(BaseModel):
    """Position of a correction mark"""
    x: int = Field(..., description="Mark center X coordinate (absolute pixels)")
    y: int = Field(..., description="Mark center Y coordinate (absolute pixels)")


class QuestionItem(BaseModel):
    """Question item (combined recognition result)"""
    question_id: str = Field(..., description="Question number (e.g. 10, 10.1, 10.2)")
    parent_id: str = Field(default="", description="Parent question number (set for sub-questions)")
    is_sub_question: bool = Field(default=False, description="Whether this is a sub-question")
    question_text: str = Field(default="", description="Question text")
    student_answer: str = Field(default="", description="Student's answer")
    answer_bbox: List[int] = Field(..., description="Bounding box of the student's answer area [x1, y1, x2, y2] (used to place annotations)")
    mark_position: MarkPosition = Field(..., description="Coordinates where the correction mark should be drawn")
    full_score: int = Field(default=10, description="Full score")


class CorrectAnswer(BaseModel):
    """Standard answer item (parsed from the Word document)"""
    question_id: str = Field(..., description="Question number (e.g. 10, 10.1, 10.2)")
    parent_id: str = Field(default="", description="Parent question number (set for sub-questions)")
    is_sub_question: bool = Field(default=False, description="Whether this is a sub-question")
    question_text: str = Field(default="", description="Question text / stem")
    correct_answer: str = Field(..., description="Standard answer")
    full_score: int = Field(default=10, description="Full score")
    answer_analysis: str = Field(default="", description="Answer analysis / solution steps")


class CorrectionResult(BaseModel):
    """Correction result"""
    question_id: str = Field(..., description="Question number")
    parent_id: str = Field(default="", description="Parent question number")
    status: Literal["correct", "incorrect", "partial"] = Field(..., description="Correction status")
    score: int = Field(default=0, description="Score awarded")
    full_score: int = Field(default=10, description="Full score")
    comment: str = Field(default="", description="Correction comment")


class Annotation(BaseModel):
    """Complete annotation"""
    question_id: str = Field(..., description="Question number")
    parent_id: str = Field(default="", description="Parent question number")
    status: Literal["correct", "incorrect", "partial"] = Field(..., description="Correction status")
    question_text: str = Field(default="", description="Question text")
    student_answer: str = Field(default="", description="Student's answer")
    answer_bbox: List[int] = Field(default=[], description="Bounding box of the answer area")
    comment: str = Field(default="", description="Correction comment")
    mark_position: MarkPosition = Field(..., description="Correction mark position")
    score: int = Field(default=0, description="Score awarded")
    full_score: int = Field(default=10, description="Full score")


class SingleImageResult(BaseModel):
    """Correction result for a single image"""
    image_index: int = Field(..., description="Image index (0-based)")
    image_info: ImageInfo = Field(..., description="Image information")
    image_url: str = Field(default="", description="URL of the processed image")
    annotations: List[Annotation] = Field(default=[], description="Annotations for this image")


class FinalResult(BaseModel):
    """Final correction result (aggregated over all images)"""
    total_images: int = Field(..., description="Total number of images")
    image_results: List[SingleImageResult] = Field(default=[], description="Per-image correction results")
    overall_comment: str = Field(default="", description="Overall comment")
    total_score: int = Field(default=0, description="Total score")
    full_score: int = Field(default=100, description="Full score")
    grade: str = Field(default="", description="Grade")
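

# Illustrative serialized instance (hypothetical values), as produced by
# FinalResult.model_dump() for a one-image, one-question submission:
#
#   {
#       "total_images": 1,
#       "image_results": [
#           {
#               "image_index": 0,
#               "image_info": {"width": 1000, "height": 1414, "dpi": 72},
#               "image_url": "https://example.com/processed.png",
#               "annotations": [
#                   {
#                       "question_id": "10",
#                       "parent_id": "",
#                       "status": "correct",
#                       "question_text": "…",
#                       "student_answer": "…",
#                       "answer_bbox": [120, 300, 540, 380],
#                       "comment": "…",
#                       "mark_position": {"x": 570, "y": 340},
#                       "score": 10,
#                       "full_score": 10
#                   }
#               ]
#           }
#       ],
#       "overall_comment": "…",
#       "total_score": 10,
#       "full_score": 10,
#       "grade": "A+"
#   }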


# === Global state ===
class GlobalState(BaseModel):
    """Global workflow state"""
    # Input parameters
    homework_images: List[File] = Field(default=[], description="Uploaded homework images")
    answer_doc_url: str = Field(default="", description="URL of the Word file with correct answers (.docx)")
    comment_max_length: int = Field(default=100, description="Maximum comment length in characters")
    max_concurrent: int = Field(default=10, description="Maximum number of images corrected in parallel")
    grade_standards: dict = Field(default={}, description="Grading standards")
    # Intermediate state
    correct_answers: List[CorrectAnswer] = Field(default=[], description="Standard answers parsed from the Word file")
    image_results: List[SingleImageResult] = Field(default=[], description="Per-image correction results")
    # Final result
    final_result: Optional[FinalResult] = Field(default=None, description="Final aggregated result")


# === Graph input/output ===
class GraphInput(BaseModel):
    """Workflow input"""
    homework_images: List[File] = Field(..., description="Uploaded homework images")
    answer_doc_url: str = Field(default="", description="URL of the Word file with correct answers (.docx, optional)")
    comment_max_length: int = Field(default=100, description="Maximum comment length in characters, default 100")
    max_concurrent: int = Field(default=10, description="Maximum number of images corrected in parallel, default 10")
    grade_standards: dict = Field(
        default={
            "A+": {"min_percentage": 95, "description": "答案全部正确,步骤完整规范,逻辑严谨;书写/格式整洁,无错别字、无遗漏;完成度100%,态度认真,质量上乘"},
            "A": {"min_percentage": 90, "description": "答案完全正确,无任何错误;步骤合理、格式规范,无原则性问题;完成度100%,满足全部要求"},
            "B": {"min_percentage": 80, "description": "存在少量非关键性错误,或步骤略有缺失;整体思路基本正确,仅细节、格式、计算等小问题;完成大部分内容,整体合格但不够严谨"},
            "C": {"min_percentage": 70, "description": "错误较多,部分核心题目作答错误;步骤不完整、逻辑不够清晰;完成度一般,有明显应付、漏答情况"},
            "D": {"min_percentage": 0, "description": "大面积错误,核心知识点未掌握;大量空白、敷衍、抄袭;未达到基本完成要求"}
        },
        description="Grading standards: minimum score-rate percentage and description for each grade"
    )


class GraphOutput(BaseModel):
    """Workflow output"""
    final_result: FinalResult = Field(..., description="Final correction result JSON (covers all images)")


# === Document answer-extraction node ===
class DocExtractInput(BaseModel):
    """Document answer-extraction node input"""
    answer_doc_url: str = Field(default="", description="URL of the Word file with correct answers (.docx, optional)")


class DocExtractOutput(BaseModel):
    """Document answer-extraction node output"""
    correct_answers: List[CorrectAnswer] = Field(default=[], description="Parsed standard answers")


# === Subgraph state definitions (single-image pipeline) ===
class SubgraphState(BaseModel):
    """Subgraph global state (processes a single image)"""
    # Input
    homework_image: File = Field(..., description="Homework image currently being processed")
    correct_answers: List[CorrectAnswer] = Field(default=[], description="Standard answers")
    image_index: int = Field(default=0, description="Image index")
    comment_max_length: int = Field(default=100, description="Maximum comment length in characters")
    # Intermediate state
    image_info: ImageInfo = Field(default_factory=lambda: ImageInfo(width=0, height=0, dpi=72), description="Image information")
    image_url: str = Field(default="", description="URL of the processed image")
    question_items: List[QuestionItem] = Field(default=[], description="Recognized question items")
    correction_results: List[CorrectionResult] = Field(default=[], description="Correction results")
    annotations: List[Annotation] = Field(default=[], description="Final annotations")


class SubgraphInput(BaseModel):
    """Subgraph input"""
    homework_image: File = Field(..., description="Homework image currently being processed")
    correct_answers: List[CorrectAnswer] = Field(default=[], description="Standard answers")
    image_index: int = Field(default=0, description="Image index")
    comment_max_length: int = Field(default=100, description="Maximum comment length in characters")


class SubgraphOutput(BaseModel):
    """Subgraph output"""
    image_result: SingleImageResult = Field(..., description="Correction result for a single image")


# === Subgraph node inputs/outputs ===

# 1. Image preprocessing node
class ImagePreprocessInput(BaseModel):
    """Image preprocessing node input"""
    homework_image: File = Field(..., description="Uploaded homework image")


class ImagePreprocessOutput(BaseModel):
    """Image preprocessing node output"""
    image_info: ImageInfo = Field(..., description="Image information")
    image_url: str = Field(..., description="URL of the processed image")


# 2. Combined recognition-and-correction node (recognition + grading merged)
class RecognizeAndCorrectInput(BaseModel):
    """Combined recognition-and-correction node input"""
    image_url: str = Field(..., description="Image URL")
    image_info: ImageInfo = Field(..., description="Image information")
    correct_answers: List[CorrectAnswer] = Field(default=[], description="Standard answers")
    comment_max_length: int = Field(default=100, description="Maximum comment length in characters")


class RecognizeAndCorrectOutput(BaseModel):
    """Combined recognition-and-correction node output"""
    question_items: List[QuestionItem] = Field(default=[], description="Recognized question items")
    correction_results: List[CorrectionResult] = Field(default=[], description="Correction results")


# 3. Correction-judging node
class CorrectionJudgeInput(BaseModel):
    """Correction-judging node input"""
    question_items: List[QuestionItem] = Field(default=[], description="Question items")
    correct_answers: List[CorrectAnswer] = Field(default=[], description="Standard answers")
    comment_max_length: int = Field(default=100, description="Maximum comment length in characters")


class CorrectionJudgeOutput(BaseModel):
    """Correction-judging node output"""
    correction_results: List[CorrectionResult] = Field(default=[], description="Correction results")


# 4. Result-merge node
class ResultMergeInput(BaseModel):
    """Result-merge node input"""
    question_items: List[QuestionItem] = Field(default=[], description="Question items")
    correction_results: List[CorrectionResult] = Field(default=[], description="Correction results")


class ResultMergeOutput(BaseModel):
    """Result-merge node output"""
    annotations: List[Annotation] = Field(default=[], description="Final annotations")


# === Loop node ===
class ProcessImagesInput(BaseModel):
    """Multi-image loop node input"""
    homework_images: List[File] = Field(default=[], description="Homework images")
    correct_answers: List[CorrectAnswer] = Field(default=[], description="Standard answers")
    comment_max_length: int = Field(default=100, description="Maximum comment length in characters")
    max_concurrent: int = Field(default=10, description="Maximum number of images corrected in parallel")


class ProcessImagesOutput(BaseModel):
    """Multi-image loop node output"""
    final_result: FinalResult = Field(..., description="Final correction result")

@ -0,0 +1,546 @@
import argparse
import asyncio
import json
import threading
import traceback
import logging
from typing import Any, Dict, Iterable, AsyncIterable, AsyncGenerator, Optional
import cozeloop
import uvicorn
import time
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph
from coze_coding_utils.runtime_ctx.context import new_context, Context
from coze_coding_utils.helper import graph_helper
from coze_coding_utils.log.node_log import LOG_FILE
from coze_coding_utils.log.write_log import setup_logging, request_context
from coze_coding_utils.log.config import LOG_LEVEL
from coze_coding_utils.error.classifier import ErrorClassifier, classify_error
from coze_coding_utils.helper.stream_runner import AgentStreamRunner, WorkflowStreamRunner, agent_stream_handler, workflow_stream_handler, RunOpt

setup_logging(
    log_file=LOG_FILE,
    max_bytes=100 * 1024 * 1024,  # 100MB
    backup_count=5,
    log_level=LOG_LEVEL,
    use_json_format=True,
    console_output=True
)

logger = logging.getLogger(__name__)
from coze_coding_utils.helper.agent_helper import to_stream_input
from coze_coding_utils.openai.handler import OpenAIChatHandler
from coze_coding_utils.log.parser import LangGraphParser
from coze_coding_utils.log.err_trace import extract_core_stack
from coze_coding_utils.log.loop_trace import init_run_config, init_agent_config


# Timeout configuration
TIMEOUT_SECONDS = 900  # 15 minutes

class GraphService:
    def __init__(self):
        # Track running tasks (as asyncio.Task) so they can be cancelled by run_id
        self.running_tasks: Dict[str, asyncio.Task] = {}
        # Error classifier
        self.error_classifier = ErrorClassifier()
        # Stream runners
        self._agent_stream_runner = AgentStreamRunner()
        self._workflow_stream_runner = WorkflowStreamRunner()
        self._graph = None
        self._graph_lock = threading.Lock()

    def _get_graph(self, ctx=Context):
        if graph_helper.is_agent_proj():
            return graph_helper.get_agent_instance("agents.agent", ctx)

        if self._graph is not None:
            return self._graph
        with self._graph_lock:
            if self._graph is not None:
                return self._graph
            self._graph = graph_helper.get_graph_instance("graphs.graph")
        return self._graph

    @staticmethod
    def _sse_event(data: Any, event_id: Any = None) -> str:
        id_line = f"id: {event_id}\n" if event_id else ""
        return f"{id_line}event: message\ndata: {json.dumps(data, ensure_ascii=False, default=str)}\n\n"

    def _get_stream_runner(self):
        if graph_helper.is_agent_proj():
            return self._agent_stream_runner
        else:
            return self._workflow_stream_runner

    # Streaming run (raw iterator): used by local invocations
    def stream(self, payload: Dict[str, Any], run_config: RunnableConfig, ctx=Context) -> Iterable[Any]:
        graph = self._get_graph(ctx)
        stream_runner = self._get_stream_runner()
        for chunk in stream_runner.stream(payload, graph, run_config, ctx):
            yield chunk

    # Synchronous run: shared by local and HTTP entry points
    async def run(self, payload: Dict[str, Any], ctx=None) -> Dict[str, Any]:
        if ctx is None:
            ctx = new_context("run")

        run_id = ctx.run_id
        logger.info(f"Starting run with run_id: {run_id}")

        try:
            graph = self._get_graph(ctx)
            # Custom tracer
            run_config = init_run_config(graph, ctx)
            run_config["configurable"] = {"thread_id": ctx.run_id}

            # Invoke directly; LangGraph executes in the current task context,
            # so cancelling this task also cancels the LangGraph execution
            return await graph.ainvoke(payload, config=run_config, context=ctx)

        except asyncio.CancelledError:
            logger.info(f"Run {run_id} was cancelled")
            return {"status": "cancelled", "run_id": run_id, "message": "Execution was cancelled"}
        except Exception as e:
            # Classify the error
            err = self.error_classifier.classify(e, {"node_name": "run", "run_id": run_id})
            # Log the detailed error and stack trace
            logger.error(
                f"Error in GraphService.run: [{err.code}] {err.message}\n"
                f"Category: {err.category.name}\n"
                f"Traceback:\n{extract_core_stack()}"
            )
            # Re-raise with the original stack so callers see the true error location
            raise
        finally:
            # Clean up the task record
            self.running_tasks.pop(run_id, None)

    # Streaming run (SSE-formatted): used by the HTTP routes
    async def stream_sse(self, payload: Dict[str, Any], ctx=None, run_opt: Optional[RunOpt] = None) -> AsyncGenerator[str, None]:
        if ctx is None:
            ctx = new_context(method="stream_sse")
        if run_opt is None:
            run_opt = RunOpt()

        run_id = ctx.run_id
        logger.info(f"Starting stream with run_id: {run_id}")
        graph = self._get_graph(ctx)
        if graph_helper.is_agent_proj():
            run_config = init_agent_config(graph, ctx)
        else:
            run_config = init_run_config(graph, ctx)  # vibeflow

        is_workflow = not graph_helper.is_agent_proj()

        try:
            async for chunk in self.astream(payload, graph, run_config=run_config, ctx=ctx, run_opt=run_opt):
                if is_workflow and isinstance(chunk, tuple):
                    event_id, data = chunk
                    yield self._sse_event(data, event_id)
                else:
                    yield self._sse_event(chunk)
        finally:
            # Clean up the task record
            self.running_tasks.pop(run_id, None)
            cozeloop.flush()

    # Cancel an execution - using asyncio's standard mechanism
    def cancel_run(self, run_id: str, ctx: Optional[Context] = None) -> Dict[str, Any]:
        """
        Cancel the execution identified by run_id.

        Uses asyncio.Task.cancel(), the standard Python async cancellation mechanism.
        LangGraph checks for CancelledError between nodes, so cancellation is graceful.
        """
        logger.info(f"Attempting to cancel run_id: {run_id}")

        # Look up the task
        if run_id in self.running_tasks:
            task = self.running_tasks[run_id]
            if not task.done():
                # Standard asyncio cancellation:
                # raises CancelledError at the next await point
                task.cancel()
                logger.info(f"Cancellation requested for run_id: {run_id}")
                return {
                    "status": "success",
                    "run_id": run_id,
                    "message": "Cancellation signal sent, task will be cancelled at next await point"
                }
            else:
                logger.info(f"Task already completed for run_id: {run_id}")
                return {
                    "status": "already_completed",
                    "run_id": run_id,
                    "message": "Task has already completed"
                }
        else:
            logger.warning(f"No active task found for run_id: {run_id}")
            return {
                "status": "not_found",
                "run_id": run_id,
                "message": "No active task found with this run_id. Task may have already completed or run_id is invalid."
            }

    # Run a single node: shared by local and HTTP entry points
    async def run_node(self, node_id: str, payload: Dict[str, Any], ctx=None) -> Any:
        if ctx is None or ctx.run_id == "":
            ctx = new_context(method="node_run")

        _graph = self._get_graph()
        node_func, input_cls, output_cls = graph_helper.get_graph_node_func_with_inout(_graph.get_graph(), node_id)
        if node_func is None or input_cls is None:
            raise KeyError(f"node_id '{node_id}' not found")

        parser = LangGraphParser(_graph)
        metadata = parser.get_node_metadata(node_id) or {}

        # Build a throwaway one-node graph around the target node
        _g = StateGraph(input_cls, input_schema=input_cls, output_schema=output_cls)
        _g.add_node("sn", node_func, metadata=metadata)
        _g.set_entry_point("sn")
        _g.add_edge("sn", END)
        _graph = _g.compile()

        run_config = init_run_config(_graph, ctx)
        return await _graph.ainvoke(payload, config=run_config)

    def graph_inout_schema(self) -> Any:
        if graph_helper.is_agent_proj():
            return {"input_schema": {}, "output_schema": {}}
        graph = self._get_graph()
        builder = getattr(graph, 'builder', None)
        if builder is not None:
            input_cls = getattr(builder, 'input_schema', None) or graph.get_input_schema()
            output_cls = getattr(builder, 'output_schema', None) or graph.get_output_schema()
        else:
            logger.warning("No builder input schema found for graph_inout_schema, using graph input schema instead")
            input_cls = graph.get_input_schema()
            output_cls = graph.get_output_schema()

        return {
            "input_schema": input_cls.model_json_schema(),
            "output_schema": output_cls.model_json_schema(),
            "code": 0,
            "msg": ""
        }

    async def astream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx=Context, run_opt: Optional[RunOpt] = None) -> AsyncIterable[Any]:
        stream_runner = self._get_stream_runner()
        async for chunk in stream_runner.astream(payload, graph, run_config, ctx, run_opt):
            yield chunk


service = GraphService()
app = FastAPI()

# Handler for the OpenAI-compatible endpoint
openai_handler = OpenAIChatHandler(service)


HEADER_X_RUN_ID = "x-run-id"

@app.post("/run")
async def http_run(request: Request) -> Dict[str, Any]:
    raw_body = await request.body()
    try:
        body_text = raw_body.decode("utf-8")
    except Exception as e:
        body_text = str(raw_body)
        raise HTTPException(status_code=400,
                            detail=f"Invalid request body: {body_text}, traceback: {traceback.format_exc()}, error: {e}")

    ctx = new_context(method="run", headers=request.headers)
    # Prefer the upstream-specified run_id so that cancel requests can match exactly
    upstream_run_id = request.headers.get(HEADER_X_RUN_ID)
    if upstream_run_id:
        ctx.run_id = upstream_run_id
    run_id = ctx.run_id
    request_context.set(ctx)

    logger.info(
        f"Received request for /run: "
        f"run_id={run_id}, "
        f"query={dict(request.query_params)}, "
        f"body={body_text}"
    )

    try:
        payload = await request.json()

        # Create and register the task - this is what makes cancellation by run_id possible
        task = asyncio.create_task(service.run(payload, ctx))
        service.running_tasks[run_id] = task

        try:
            result = await asyncio.wait_for(task, timeout=float(TIMEOUT_SECONDS))
        except asyncio.TimeoutError:
            logger.error(f"Run execution timeout after {TIMEOUT_SECONDS}s for run_id: {run_id}")
            task.cancel()
            try:
                result = await task
            except asyncio.CancelledError:
                return {
                    "status": "timeout",
                    "run_id": run_id,
                    "message": f"Execution timeout: exceeded {TIMEOUT_SECONDS} seconds"
                }

        if not result:
            result = {}
        if isinstance(result, dict):
            result["run_id"] = run_id
        return result

    except json.JSONDecodeError as e:
        logger.error(f"JSON decode error in http_run: {e}, traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON format, {extract_core_stack()}")

    except asyncio.CancelledError:
        logger.info(f"Request cancelled for run_id: {run_id}")
        result = {"status": "cancelled", "run_id": run_id, "message": "Execution was cancelled"}
        return result

    except Exception as e:
        # Build the error response via the classifier
        error_response = service.error_classifier.get_error_response(e, {"node_name": "http_run", "run_id": run_id})
        logger.error(
            f"Unexpected error in http_run: [{error_response['error_code']}] {error_response['error_message']}, "
            f"traceback: {traceback.format_exc()}", exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail={
                "error_code": error_response["error_code"],
                "error_message": error_response["error_message"],
                "stack_trace": extract_core_stack(),
            }
        )
    finally:
        cozeloop.flush()
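
# Invocation sketch (hypothetical host and payload; the body mirrors GraphInput
# from graphs/state.py, with the homework_images file objects elided):
#
#   curl -X POST http://localhost:5000/run \
#        -H "Content-Type: application/json" \
#        -H "x-run-id: demo-run-1" \
#        -d '{"homework_images": [...], "answer_doc_url": "", "comment_max_length": 100}'
#
# The response echoes the final state and always carries the resolved run_id.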


HEADER_X_WORKFLOW_STREAM_MODE = "x-workflow-stream-mode"


def _register_task(run_id: str, task: asyncio.Task):
    service.running_tasks[run_id] = task


@app.post("/stream_run")
async def http_stream_run(request: Request):
    ctx = new_context(method="stream_run", headers=request.headers)
    # Prefer the upstream-specified run_id so that cancel requests can match exactly
    upstream_run_id = request.headers.get(HEADER_X_RUN_ID)
    if upstream_run_id:
        ctx.run_id = upstream_run_id
    workflow_stream_mode = request.headers.get(HEADER_X_WORKFLOW_STREAM_MODE, "").lower()
    workflow_debug = workflow_stream_mode == "debug"
    request_context.set(ctx)
    raw_body = await request.body()
    try:
        body_text = raw_body.decode("utf-8")
    except Exception as e:
        body_text = str(raw_body)
        raise HTTPException(status_code=400,
                            detail=f"Invalid request body: {body_text}, traceback: {extract_core_stack()}, error: {e}")
    run_id = ctx.run_id
    is_agent = graph_helper.is_agent_proj()
    logger.info(
        f"Received request for /stream_run: "
        f"run_id={run_id}, "
        f"is_agent_project={is_agent}, "
        f"query={dict(request.query_params)}, "
        f"body={body_text}"
    )
    try:
        payload = await request.json()
    except json.JSONDecodeError as e:
        logger.error(f"JSON decode error in http_stream_run: {e}, traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON format: {extract_core_stack()}")

    if is_agent:
        stream_generator = agent_stream_handler(
            payload=payload,
            ctx=ctx,
            run_id=run_id,
            stream_sse_func=service.stream_sse,
            sse_event_func=service._sse_event,
            error_classifier=service.error_classifier,
            register_task_func=_register_task,
        )
    else:
        stream_generator = workflow_stream_handler(
            payload=payload,
            ctx=ctx,
            run_id=run_id,
            stream_sse_func=service.stream_sse,
            sse_event_func=service._sse_event,
            error_classifier=service.error_classifier,
            register_task_func=_register_task,
            run_opt=RunOpt(workflow_debug=workflow_debug),
        )

    response = StreamingResponse(stream_generator, media_type="text/event-stream")
    return response
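
# Streaming sketch (hypothetical host): -N disables curl buffering so SSE frames
# arrive as they are emitted; set "x-workflow-stream-mode: debug" for debug events.
#
#   curl -N -X POST http://localhost:5000/stream_run \
#        -H "Content-Type: application/json" \
#        -H "x-run-id: demo-run-1" \
#        -d '{"homework_images": [...]}'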

@app.post("/cancel/{run_id}")
async def http_cancel(run_id: str, request: Request):
    """
    Cancel the execution identified by run_id.

    Implemented with asyncio.Task.cancel(), the standard Python async cancellation
    mechanism; LangGraph checks for CancelledError at the await points between nodes,
    so the cancellation is graceful.
    """
    ctx = new_context(method="cancel", headers=request.headers)
    request_context.set(ctx)
    logger.info(f"Received cancel request for run_id: {run_id}")
    result = service.cancel_run(run_id, ctx)
    return result
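
# Cancellation sketch (hypothetical host; the run_id must match the one passed via
# the x-run-id header when the run was started):
#
#   curl -X POST http://localhost:5000/cancel/demo-run-1
#
# Possible "status" values in the response: "success", "already_completed", "not_found".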


@app.post(path="/node_run/{node_id}")
async def http_node_run(node_id: str, request: Request):
    raw_body = await request.body()
    try:
        body_text = raw_body.decode("utf-8")
    except UnicodeDecodeError:
        body_text = str(raw_body)
        raise HTTPException(status_code=400, detail=f"Invalid request body: {body_text}")
    ctx = new_context(method="node_run", headers=request.headers)
    request_context.set(ctx)
    logger.info(
        f"Received request for /node_run/{node_id}: "
        f"query={dict(request.query_params)}, "
        f"body={body_text}",
    )

    try:
        payload = await request.json()
    except json.JSONDecodeError as e:
        logger.error(f"JSON decode error in http_node_run: {e}, traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=400, detail=f"Invalid JSON format: {extract_core_stack()}")
    try:
        return await service.run_node(node_id, payload, ctx)
    except KeyError:
        raise HTTPException(status_code=404,
                            detail=f"node_id '{node_id}' not found or input misses required fields, traceback: {extract_core_stack()}")
    except Exception as e:
        # Build the error response via the classifier
        error_response = service.error_classifier.get_error_response(e, {"node_name": node_id})
        logger.error(
            f"Unexpected error in http_node_run: [{error_response['error_code']}] {error_response['error_message']}, "
            f"traceback: {traceback.format_exc()}", exc_info=True
        )
        raise HTTPException(
            status_code=500,
            detail={
                "error_code": error_response["error_code"],
                "error_message": error_response["error_message"],
                "stack_trace": extract_core_stack(),
            }
        )
    finally:
        cozeloop.flush()


@app.post("/v1/chat/completions")
async def openai_chat_completions(request: Request):
    """OpenAI Chat Completions API compatible endpoint"""
    ctx = new_context(method="openai_chat", headers=request.headers)
    request_context.set(ctx)

    logger.info(f"Received request for /v1/chat/completions: run_id={ctx.run_id}")

    try:
        payload = await request.json()
        return await openai_handler.handle(payload, ctx)
    except json.JSONDecodeError as e:
        logger.error(f"JSON decode error in openai_chat_completions: {e}")
        raise HTTPException(status_code=400, detail="Invalid JSON format")
    finally:
        cozeloop.flush()


@app.get("/health")
async def health_check():
    try:
        # More health-check logic can be added here
        return {
            "status": "ok",
            "message": "Service is running",
        }
    except Exception as e:
        raise HTTPException(status_code=503, detail=str(e))


@app.get(path="/graph_parameter")
async def http_graph_inout_parameter(request: Request):
    return service.graph_inout_schema()


def parse_args():
    parser = argparse.ArgumentParser(description="Start FastAPI server")
    parser.add_argument("-m", type=str, default="http", help="Run mode, supports http, flow, node, agent")
    parser.add_argument("-n", type=str, default="", help="Node ID for single node run")
    parser.add_argument("-p", type=int, default=5000, help="HTTP server port")
    parser.add_argument("-i", type=str, default="", help="Input JSON string for flow/node mode")
    return parser.parse_args()


def parse_input(input_str: str) -> Dict[str, Any]:
    """Parse input string, supporting both JSON strings and plain text"""
    if not input_str:
        return {"text": "你好"}

    # Try to parse as JSON first
    try:
        return json.loads(input_str)
    except json.JSONDecodeError:
        # If not valid JSON, treat as plain text
        return {"text": input_str}


def start_http_server(port):
    workers = 1
    reload = graph_helper.is_dev_env()

    logger.info(f"Start HTTP Server, Port: {port}, Workers: {workers}")
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload, workers=workers)


if __name__ == "__main__":
    args = parse_args()
    if args.m == "http":
        start_http_server(args.p)
    elif args.m == "flow":
        payload = parse_input(args.i)
        result = asyncio.run(service.run(payload))
        print(json.dumps(result, ensure_ascii=False, indent=2))
    elif args.m == "node" and args.n:
        payload = parse_input(args.i)
        result = asyncio.run(service.run_node(args.n, payload))
        print(json.dumps(result, ensure_ascii=False, indent=2))
    elif args.m == "agent":
        agent_ctx = new_context(method="agent")
        for chunk in service.stream(
            {
                "type": "query",
                "session_id": "1",
                "message": "你好",
                "content": {
                    "query": {
                        "prompt": [
                            {
                                "type": "text",
                                "content": {"text": "现在几点了?请调用工具获取当前时间"},
                            }
                        ]
                    }
                },
            },
            run_config={"configurable": {"session_id": "1"}},
            ctx=agent_ctx,
        ):
            print(chunk)

@ -0,0 +1,94 @@
import os
import time
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import OperationalError
import logging
logger = logging.getLogger(__name__)

MAX_RETRY_TIME = 20  # maximum total retry time for the connection check, in seconds
# Load environment variables from .env if present
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    pass

def get_db_url() -> str:
    """Build database URL from environment."""
    url = os.getenv("PGDATABASE_URL") or ""
    if url:
        return url
    from coze_workload_identity import Client
    try:
        client = Client()
        env_vars = client.get_project_env_vars()
        client.close()
        for env_var in env_vars:
            if env_var.key == "PGDATABASE_URL":
                url = env_var.value.replace("'", "'\\''")
                return url
    except Exception as e:
        logger.error(f"Error loading PGDATABASE_URL: {e}")
        raise e
    finally:
        if url is None or url == "":
            logger.error("PGDATABASE_URL is not set")
    return url

_engine = None
_SessionLocal = None

def _create_engine_with_retry():
    url = get_db_url()
    if url is None or url == "":
        logger.error("PGDATABASE_URL is not set")
        raise ValueError("PGDATABASE_URL is not set")
    size = 100
    overflow = 100
    recycle = 1800
    timeout = 30
    engine = create_engine(
        url,
        pool_size=size,
        max_overflow=overflow,
        pool_pre_ping=True,
        pool_recycle=recycle,
        pool_timeout=timeout,
    )
    # Verify the connection, with retries
    start_time = time.time()
    last_error = None
    while time.time() - start_time < MAX_RETRY_TIME:
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            return engine
        except OperationalError as e:
            last_error = e
            elapsed = time.time() - start_time
            logger.warning(f"Database connection failed, retrying... (elapsed: {elapsed:.1f}s)")
            time.sleep(min(1, MAX_RETRY_TIME - elapsed))
    logger.error(f"Database connection failed after {MAX_RETRY_TIME}s: {last_error}")
    raise last_error  # pyright: ignore [reportGeneralTypeIssues]

def get_engine():
    global _engine
    if _engine is None:
        _engine = _create_engine_with_retry()
    return _engine

def get_sessionmaker():
    global _SessionLocal
    if _SessionLocal is None:
        _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=get_engine())
    return _SessionLocal

def get_session():
    return get_sessionmaker()()
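

# Usage sketch (assumes PGDATABASE_URL points at a reachable Postgres instance):
#
#   from storage.database.db import get_session
#   from sqlalchemy import text
#
#   session = get_session()
#   try:
#       row = session.execute(text("SELECT 1")).scalar_one()
#   finally:
#       session.close()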

__all__ = [
    "get_db_url",
    "get_engine",
    "get_sessionmaker",
    "get_session",
]

@ -0,0 +1,8 @@
from sqlalchemy import BigInteger, DateTime, Identity, Index, Integer, JSON, PrimaryKeyConstraint, Text, text
from typing import Optional
import datetime

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

@ -0,0 +1,135 @@
import psycopg
from psycopg_pool import AsyncConnectionPool
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.base import BaseCheckpointSaver
from typing import Optional, Union
import logging
import time

logger = logging.getLogger(__name__)

# Database connection timeout in seconds: 15 seconds per attempt, 2 attempts total
DB_CONNECTION_TIMEOUT = 15
DB_MAX_RETRIES = 2


class MemoryManager:
    """Singleton memory manager"""

    _instance: Optional['MemoryManager'] = None
    _checkpointer: Optional[Union[AsyncPostgresSaver, MemorySaver]] = None
    _pool: Optional[AsyncConnectionPool] = None
    _setup_done: bool = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def _connect_with_retry(self, db_url: str) -> Optional[psycopg.Connection]:
        """Database connection with retries: 15-second timeout per attempt, 2 attempts total"""
        last_error = None
        for attempt in range(1, DB_MAX_RETRIES + 1):
            try:
                logger.info(f"Attempting database connection (attempt {attempt}/{DB_MAX_RETRIES})")
                conn = psycopg.connect(db_url, autocommit=True, connect_timeout=DB_CONNECTION_TIMEOUT)
                logger.info(f"Database connection established on attempt {attempt}")
                return conn
            except Exception as e:
                last_error = e
                logger.warning(f"Database connection attempt {attempt} failed: {e}")
                if attempt < DB_MAX_RETRIES:
                    time.sleep(1)  # brief pause before retrying
        logger.error(f"All {DB_MAX_RETRIES} database connection attempts failed, last error: {last_error}")
        return None

    def _setup_schema_and_tables(self, db_url: str) -> bool:
        """Synchronously create the schema and tables (runs only once); returns whether it succeeded"""
        if self._setup_done:
            return True

        conn = self._connect_with_retry(db_url)
        if conn is None:
            return False

        try:
            with conn.cursor() as cur:
                cur.execute("CREATE SCHEMA IF NOT EXISTS memory")
            conn.execute("SET search_path TO memory")
            PostgresSaver(conn).setup()
            self._setup_done = True
            logger.info("Memory schema and tables created")
            return True
        except Exception as e:
            logger.warning(f"Failed to setup schema/tables: {e}")
            return False
        finally:
            conn.close()

    def _get_db_url_safe(self) -> Optional[str]:
        """Fetch db_url safely; returns None on failure"""
        try:
            from storage.database.db import get_db_url
            db_url = get_db_url()
            if db_url and db_url.strip():
                return db_url
            logger.warning("db_url is empty, will fallback to MemorySaver")
            return None
        except Exception as e:
            logger.warning(f"Failed to get db_url: {e}, will fallback to MemorySaver")
            return None

    def _create_fallback_checkpointer(self) -> MemorySaver:
        """Create the in-memory fallback checkpointer"""
        self._checkpointer = MemorySaver()
        logger.warning("Using MemorySaver as fallback checkpointer (data will not persist across restarts)")
        return self._checkpointer

    def get_checkpointer(self) -> BaseCheckpointSaver:
        """Get the checkpointer; prefers PostgresSaver and falls back to MemorySaver on failure"""
        if self._checkpointer is not None:
            return self._checkpointer

        # 1. Try to fetch db_url
        db_url = self._get_db_url_safe()
        if not db_url:
            return self._create_fallback_checkpointer()

        # 2. Try to connect and create the schema/tables (with retries)
        if not self._setup_schema_and_tables(db_url):
            return self._create_fallback_checkpointer()

        # 3. Append the search_path to the connection string
        if "?" in db_url:
            db_url = f"{db_url}&options=-csearch_path%3Dmemory"
        else:
            db_url = f"{db_url}?options=-csearch_path%3Dmemory"

        # 4. Try to create the connection pool and the checkpointer
        try:
            self._pool = AsyncConnectionPool(
                conninfo=db_url,
                timeout=DB_CONNECTION_TIMEOUT,
                min_size=1,
                max_idle=300,
                check=AsyncConnectionPool.check_connection,
            )
            self._checkpointer = AsyncPostgresSaver(self._pool)
            logger.info("AsyncPostgresSaver initialized successfully")
        except Exception as e:
            logger.warning(f"Failed to create AsyncPostgresSaver: {e}, will fallback to MemorySaver")
            return self._create_fallback_checkpointer()

        return self._checkpointer

_memory_manager: Optional[MemoryManager] = None


def get_memory_saver() -> BaseCheckpointSaver:
    """Get the checkpointer; prefers PostgresSaver, falling back to MemorySaver when db_url is unavailable or the connection fails"""
    global _memory_manager
    if _memory_manager is None:
        _memory_manager = MemoryManager()
    return _memory_manager.get_checkpointer()
|
||||||
|
|
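# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of wiring the checkpointer into a LangGraph graph, assuming
# langgraph is installed; the `echo` node and `State` schema below are made up
# for illustration. Because the default path yields AsyncPostgresSaver, the
# async API (`ainvoke`) is used.
async def _example_checkpointer_usage():
    from typing import TypedDict
    from langgraph.graph import StateGraph, START, END

    class State(TypedDict):
        text: str

    def echo(state: State) -> State:
        return {"text": state["text"]}

    builder = StateGraph(State)
    builder.add_node("echo", echo)
    builder.add_edge(START, "echo")
    builder.add_edge("echo", END)

    # get_memory_saver() transparently returns AsyncPostgresSaver or MemorySaver.
    graph = builder.compile(checkpointer=get_memory_saver())
    # thread_id selects which conversation's checkpoints are read and written.
    return await graph.ainvoke({"text": "hello"}, config={"configurable": {"thread_id": "demo"}})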
@ -0,0 +1,424 @@
import os
import re
from pathlib import Path
from typing import Optional, Any, Dict, List, TypedDict, Iterable
from uuid import uuid4

import boto3
from botocore.exceptions import ClientError
from boto3.s3.transfer import TransferConfig
import logging

logger = logging.getLogger(__name__)

# Allowed character set for file names (constraint on user-facing input)
FILE_NAME_ALLOWED_RE = re.compile(r"^[A-Za-z0-9._\-/]+$")


class ListFilesResult(TypedDict):
    # Return structure of list_files
    keys: List[str]
    is_truncated: bool
    next_continuation_token: Optional[str]


class S3SyncStorage:
    """S3-compatible storage implementation."""

    def __init__(self, *, endpoint_url: Optional[str] = None, access_key: str, secret_key: str, bucket_name: str, region: str = "cn-beijing"):
        self.endpoint_url = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or endpoint_url or ''
        self.access_key = access_key
        self.secret_key = secret_key
        self.bucket_name = bucket_name
        self.region = region
        self._client = None

    def _get_client(self):
        if self._client is None:
            endpoint = self.endpoint_url
            if not endpoint:
                try:
                    from coze_workload_identity import Client as CozeEnvClient
                    coze_env_client = CozeEnvClient()
                    env_vars = coze_env_client.get_project_env_vars()
                    coze_env_client.close()
                    for env_var in env_vars:
                        if env_var.key == "COZE_BUCKET_ENDPOINT_URL":
                            endpoint = env_var.value
                            self.endpoint_url = endpoint
                            break
                except Exception as e:
                    logger.error(f"Error loading COZE_BUCKET_ENDPOINT_URL: {e}")
                    # Fall through to the validation below instead of aborting here
            if not endpoint:
                logger.error("Storage endpoint not configured: please set endpoint_url")
                raise ValueError("Storage endpoint not configured: please set endpoint_url")

            client = boto3.client(
                "s3",
                endpoint_url=endpoint,
                aws_access_key_id=self.access_key,
                aws_secret_access_key=self.secret_key,
                region_name=self.region,
            )

            # Register a before-call hook that injects the x-storage-token header
            # into every outgoing S3 request. Token injection is best-effort:
            # on failure the request proceeds without the header.
            def _inject_header(**kwargs):
                try:
                    from coze_workload_identity import Client as CozeClient
                    coze_client = CozeClient()
                    try:
                        token = coze_client.get_access_token()
                    finally:
                        coze_client.close()
                    params = kwargs.get("params", {})
                    headers = params.setdefault("headers", {})
                    headers["x-storage-token"] = token
                except Exception as e:
                    logger.error("Error loading COZE_WORKLOAD_IDENTITY_TOKEN: %s", e)
            client.meta.events.register("before-call.s3", _inject_header)
            self._client = client
        return self._client
    def _generate_object_key(self, *, original_name: str) -> str:
        suffix = Path(original_name).suffix.lower()
        stem = Path(original_name).stem
        uniq = uuid4().hex[:8]
        return f"{stem}_{uniq}{suffix}"

    def _extract_logid(self, e: Exception) -> Optional[str]:
        """Extract x-tt-logid from a ClientError."""
        if isinstance(e, ClientError):
            headers = (e.response or {}).get("ResponseMetadata", {}).get("HTTPHeaders", {})
            return headers.get("x-tt-logid")
        return None

    def _error_msg(self, msg: str, e: Exception) -> str:
        """Build an error message that includes the logid when available."""
        logid = self._extract_logid(e)
        if logid:
            return f"{msg}: {e} (x-tt-logid: {logid})"
        return f"{msg}: {e}"

    def _resolve_bucket(self, bucket: Optional[str]) -> str:
        """Resolve the bucket from all sources, ensuring a valid bucket name."""
        target_bucket = bucket or os.environ.get("COZE_BUCKET_NAME") or self.bucket_name
        if not target_bucket:
            raise ValueError("Bucket not configured: pass bucket, set COZE_BUCKET_NAME, or provide bucket_name at construction time")
        return target_bucket

    def _validate_file_name(self, name: str) -> None:
        """Validate S3 object naming: length <= 1024 bytes; allowed [A-Za-z0-9._-/]; must not start or end with / or contain //."""
        msg = (
            "file name invalid: the file name must satisfy these S3 object naming rules: "
            "1) length 1-1024 bytes; "
            "2) only letters, digits, dot (.), underscore (_), hyphen (-), and the directory separator (/); "
            "3) no spaces and none of these special characters: ? # & % { } ^ [ ] ` \\ < > ~ | \" ' + = : ;; "
            "4) must not start or end with / and must not contain consecutive //; "
            "examples: report_2025-12-11.pdf, images/photo-01.png."
        )

        if not name or not name.strip():
            raise ValueError(msg + " (reason: file name is empty)")

        # S3 caps object keys at 1024 bytes; apply the same limit to the input file name
        if len(name.encode("utf-8")) > 1024:
            raise ValueError(msg + " (reason: longer than 1024 bytes)")

        if name.startswith("/") or name.endswith("/"):
            raise ValueError(msg + " (reason: starts or ends with /)")
        if "//" in name:
            raise ValueError(msg + " (reason: contains consecutive //)")

        # Allowed character set check
        if not FILE_NAME_ALLOWED_RE.match(name):
            bad = re.findall(r"[^A-Za-z0-9._\-/]", name)
            example = bad[0] if bad else "illegal character"
            raise ValueError(msg + f" (reason: contains an illegal character, e.g. {example})")
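    # How the rules above apply in practice (illustrative inputs):
    #   _validate_file_name("images/photo-01.png")  -> passes
    #   _validate_file_name("/photo.png")           -> ValueError (starts with /)
    #   _validate_file_name("a//b.png")             -> ValueError (consecutive //)
    #   _validate_file_name("报告.pdf")              -> ValueError (character outside [A-Za-z0-9._-/])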
    def upload_file(self, *, file_content: bytes, file_name: str, content_type: str = "application/octet-stream", bucket: Optional[str] = None) -> str:
        # Validate the input file name first to avoid generating an invalid object key
        self._validate_file_name(file_name)
        try:
            client = self._get_client()
            object_key = self._generate_object_key(original_name=file_name)
            target_bucket = self._resolve_bucket(bucket)
            client.put_object(Bucket=target_bucket, Key=object_key, Body=file_content, ContentType=content_type)
            return object_key
        except Exception as e:
            logger.error(self._error_msg("Error uploading file to S3", e))
            raise e

    def delete_file(self, *, file_key: str, bucket: Optional[str] = None) -> bool:
        try:
            client = self._get_client()
            target_bucket = self._resolve_bucket(bucket)
            client.delete_object(Bucket=target_bucket, Key=file_key)
            return True
        except Exception as e:
            logger.error(self._error_msg("Error deleting file from S3", e))
            raise e

    def file_exists(self, *, file_key: str, bucket: Optional[str] = None) -> bool:
        try:
            client = self._get_client()
            target_bucket = self._resolve_bucket(bucket)
            client.head_object(Bucket=target_bucket, Key=file_key)
            return True
        except ClientError as e:
            code = (e.response or {}).get("Error", {}).get("Code", "")
            if code in {"404", "NoSuchKey", "NotFound"}:
                return False
            logger.error(self._error_msg("Error checking file existence in S3", e))
            return False
        except Exception as e:
            logger.error(self._error_msg("Error checking file existence in S3", e))
            return False

    def read_file(self, *, file_key: str, bucket: Optional[str] = None) -> bytes:
        try:
            client = self._get_client()
            target_bucket = self._resolve_bucket(bucket)
            resp = client.get_object(Bucket=target_bucket, Key=file_key)
            body = resp.get("Body")
            if body is None:
                raise RuntimeError("S3 get_object returned no Body")
            try:
                return body.read()
            finally:
                try:
                    body.close()
                except Exception as ce:
                    # A failed close does not affect the read result; log it for debugging
                    logger.debug("Failed to close S3 response body: %s", ce)
        except Exception as e:
            logger.error(self._error_msg("Error reading file from S3", e))
            raise e

    def list_files(self, *, prefix: Optional[str] = None, bucket: Optional[str] = None, max_keys: int = 1000, continuation_token: Optional[str] = None) -> ListFilesResult:
        """List objects with optional prefix filtering and pagination; returns keys/is_truncated/next_continuation_token."""
        try:
            client = self._get_client()
            target_bucket = self._resolve_bucket(bucket)
            if max_keys <= 0 or max_keys > 1000:
                raise ValueError("max_keys must be between 1 and 1000")

            kwargs: Dict[str, Any] = {
                "Bucket": target_bucket,
                "MaxKeys": max_keys,
                "Prefix": prefix,
                "ContinuationToken": continuation_token,
            }
            kwargs = {k: v for k, v in kwargs.items() if v is not None}

            resp = client.list_objects_v2(**kwargs)
            contents = resp.get("Contents", []) or []
            keys: List[str] = [item.get("Key") for item in contents if isinstance(item, dict) and item.get("Key")]
            return {
                "keys": keys,
                "is_truncated": bool(resp.get("IsTruncated")),
                "next_continuation_token": resp.get("NextContinuationToken"),
            }
        except ClientError as e:
            code = (e.response or {}).get("Error", {}).get("Code", "")
            logger.error(self._error_msg(f"Error listing files in S3 (code={code})", e))
            raise e
        except Exception as e:
            logger.error(self._error_msg("Error listing files in S3", e))
            raise e
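    # A minimal pagination sketch over list_files (illustrative; `storage` is
    # assumed to be an S3SyncStorage instance):
    #
    #     token = None
    #     while True:
    #         page = storage.list_files(prefix="images/", continuation_token=token)
    #         for key in page["keys"]:
    #             ...  # process each object key
    #         if not page["is_truncated"]:
    #             break
    #         token = page["next_continuation_token"]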
    def generate_presigned_url(self, *, key: str, bucket: Optional[str] = None, expire_time: int = 1800) -> str:
        """Generate a signed URL via the S3 proxy."""
        import json
        import urllib.request as urllib_request
        try:
            from coze_workload_identity import Client as CozeClient
            coze_client = CozeClient()
            try:
                token = coze_client.get_access_token()
            finally:
                try:
                    coze_client.close()
                except Exception:
                    # A failed close does not affect the rest of the flow
                    pass
        except Exception as e:
            logger.error(f"Error loading x-storage-token: {e}")
            raise RuntimeError(f"Failed to obtain x-storage-token: {e}")
        try:
            sign_base = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or self.endpoint_url
            if not sign_base:
                raise ValueError("Signing endpoint not configured: set COZE_BUCKET_ENDPOINT_URL or pass endpoint_url")
            sign_url_endpoint = sign_base.rstrip("/") + "/sign-url"

            headers = {
                "Content-Type": "application/json",
                "x-storage-token": token,
            }

            target_bucket = self._resolve_bucket(bucket)
            payload = {"bucket_name": target_bucket, "path": key, "expire_time": expire_time}
            data = json.dumps(payload).encode("utf-8")
            request = urllib_request.Request(sign_url_endpoint, data=data, headers=headers, method="POST")
        except Exception as e:
            logger.error(f"Error creating request for sign-url: {e}")
            raise RuntimeError(f"Failed to build the sign-url request: {e}")

        try:
            with urllib_request.urlopen(request) as resp:
                resp_bytes = resp.read()
                content_type = resp.headers.get("Content-Type", "")
                text = resp_bytes.decode("utf-8", errors="replace")
                if "application/json" in content_type or text.strip().startswith("{"):
                    try:
                        obj = json.loads(text)
                    except Exception:
                        return text
                    data = obj.get("data")
                    if isinstance(data, dict) and "url" in data:
                        return data["url"]
                    url_value = obj.get("url") or obj.get("signed_url") or obj.get("presigned_url")
                    if url_value:
                        return url_value
                    raise ValueError("Signing service response is missing the data.url/url field")
                return text
        except Exception as e:
            raise RuntimeError(f"Failed to generate the signed URL: {e}")
    def stream_upload_file(
        self,
        *,
        fileobj,
        file_name: str,
        content_type: str = "application/octet-stream",
        bucket: Optional[str] = None,
        multipart_chunksize: int = 5 * 1024 * 1024,
        multipart_threshold: int = 5 * 1024 * 1024,
        max_concurrency: int = 1,
        use_threads: bool = False,
    ) -> str:
        """Streaming upload from a file object.

        - fileobj: any object with a read() method (e.g. the result of open(..., 'rb'), io.BytesIO, etc.)
        - file_name: original file name, used to generate a unique key
        - content_type: MIME type
        - bucket: target bucket; falls back to the environment variable or the instance default when empty
        - multipart_chunksize: part size (default 5MB, to fit the proxy layer's limits)
        - multipart_threshold: size threshold that triggers multipart upload (default 5MB)
        - max_concurrency: number of concurrent part uploads (default 1, to avoid proxy-layer throttling)
        - use_threads: whether to enable threaded concurrency (default False)
        Returns: the object key that was written
        """
        try:
            client = self._get_client()
            target_bucket = self._resolve_bucket(bucket)
            key = self._generate_object_key(original_name=file_name)

            extra_args = {"ContentType": content_type} if content_type else {}
            # Use boto3's high-level helper for the multipart upload (TransferConfig controls part size)
            config = TransferConfig(
                multipart_chunksize=multipart_chunksize,
                multipart_threshold=multipart_threshold,
                max_concurrency=max_concurrency,
                use_threads=use_threads,
            )
            client.upload_fileobj(Fileobj=fileobj, Bucket=target_bucket, Key=key, ExtraArgs=extra_args, Config=config)
            return key
        except Exception as e:
            logger.error(self._error_msg("Error streaming upload (fileobj) to S3", e))
            raise e
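    # A minimal usage sketch ("backup.tar.gz" is a made-up file for illustration):
    #
    #     with open("backup.tar.gz", "rb") as f:
    #         key = storage.stream_upload_file(fileobj=f, file_name="backup.tar.gz",
    #                                          content_type="application/gzip")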
    def upload_from_url(
        self,
        *,
        url: str,
        bucket: Optional[str] = None,
        timeout: int = 30,
    ) -> str:
        """Stream a download from a URL straight into S3.

        - url: source file URL
        - bucket: target bucket; falls back to the environment variable or the instance default when empty
        - timeout: HTTP request timeout in seconds (default 30)
        Returns: the object key that was written
        """
        import urllib.request as urllib_request
        from urllib.parse import urlparse, unquote
        try:
            request = urllib_request.Request(url)
            with urllib_request.urlopen(request, timeout=timeout) as resp:
                parsed = urlparse(url)
                file_name = Path(unquote(parsed.path)).name or "file"
                content_type = resp.headers.get("Content-Type", "application/octet-stream")
                return self.stream_upload_file(
                    fileobj=resp,
                    file_name=file_name,
                    content_type=content_type,
                    bucket=bucket,
                )
        except Exception as e:
            logger.error(self._error_msg("Error uploading from URL to S3", e))
            raise e
    def trunk_upload_file(self, *, chunk_iter: Iterable[bytes], file_name: str,
                          content_type: str = "application/octet-stream", bucket: Optional[str] = None,
                          part_size: int = 5 * 1024 * 1024) -> str:
        """Streaming upload from a byte iterator (explicit multipart upload).

        - chunk_iter: an iterable yielding bytes chunk by chunk; chunks may vary in size (they are
          accumulated internally up to part_size before each part is uploaded), and the last part may be smaller than 5MB
        - file_name: original file name, used to generate a unique key
        - content_type: MIME type
        - bucket: target bucket; falls back to the environment variable or the instance default when empty
        - part_size: minimum size of each part (except the last); default 5MB
        Returns: the object key that was written
        """
        client = self._get_client()
        target_bucket = self._resolve_bucket(bucket)
        key = self._generate_object_key(original_name=file_name)

        # Initiate the multipart upload
        try:
            init_resp = client.create_multipart_upload(Bucket=target_bucket, Key=key, ContentType=content_type)
            upload_id = init_resp["UploadId"]
        except Exception as e:
            logger.error(self._error_msg("create_multipart_upload failed", e))
            raise e

        parts = []
        part_number = 1
        buffer = bytearray()
        try:
            for chunk in chunk_iter:
                if not chunk:
                    continue
                buffer.extend(chunk)
                while len(buffer) >= part_size:
                    data = bytes(buffer[:part_size])
                    buffer = buffer[part_size:]
                    resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number,
                                              Body=data)
                    parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
                    part_number += 1

            # Upload whatever remainder is smaller than part_size
            if len(buffer) > 0:
                resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number,
                                          Body=bytes(buffer))
                parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})

            # Complete the multipart upload
            client.complete_multipart_upload(
                Bucket=target_bucket,
                Key=key,
                UploadId=upload_id,
                MultipartUpload={"Parts": parts},
            )
            return key
        except Exception as e:
            logger.error(self._error_msg("multipart upload failed", e))
            try:
                client.abort_multipart_upload(Bucket=target_bucket, Key=key, UploadId=upload_id)
            except Exception as ae:
                logger.error(self._error_msg("abort_multipart_upload failed", ae))
            raise e
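
# --- Illustrative usage sketch (assumes real credentials and a reachable endpoint;
# the names below are placeholders, not values from this repo) ---
def _example_s3_usage():
    storage = S3SyncStorage(
        access_key="AK_PLACEHOLDER",
        secret_key="SK_PLACEHOLDER",
        bucket_name="demo-bucket",
    )
    # Simple whole-buffer upload
    key = storage.upload_file(file_content=b"hello", file_name="notes/hello.txt", content_type="text/plain")

    # Explicit multipart upload from a chunk generator
    def chunks():
        for _ in range(3):
            yield b"\x00" * (2 * 1024 * 1024)  # 2MB chunks; parts are assembled up to part_size
    big_key = storage.trunk_upload_file(chunk_iter=chunks(), file_name="blobs/big.bin")

    # Short-lived download link for either object
    url = storage.generate_presigned_url(key=key, expire_time=600)
    return key, big_key, url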
@ -0,0 +1,325 @@
import os
import requests
import uuid
import chardet
from io import BytesIO
from typing import Literal, Callable, Any, Optional, Union
from pydantic import BaseModel, Field, field_validator, PrivateAttr, ConfigDict
from urllib.parse import urlparse

# python-pptx is optional; read_ppt below degrades gracefully when it is missing
try:
    from pptx import Presentation
except ImportError:
    Presentation = None

MAX_FILE_SIZE = 100 * 1024 * 1024


class File(BaseModel):
    """
    Generic file object with automatic type inference and path management.
    """
    url: str = Field(..., description="File URL (http/https) or local path")
    file_type: Literal['image', 'video', 'audio', 'document', 'default'] = Field(
        default="default",
        description="File type"
    )
    _local_path: Optional[str] = PrivateAttr(default=None)
    model_config = ConfigDict(
        json_schema_extra={
            "x-component": "file-upload",  # the frontend renders this with a file-upload component
        }
    )

    def set_cache_path(self, path: str):
        """Set the cache path."""
        self._local_path = path

    def get_cache_path(self) -> Optional[str]:
        """Get the cache path (if the file actually exists)."""
        return self._local_path

    @property
    def is_remote(self) -> bool:
        """Whether the URL is a network URL rather than a local file."""
        return self.url.startswith(('http://', 'https://'))
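# How the model behaves in practice (illustrative values):
#   File(url="https://example.com/a.png").is_remote  -> True
#   File(url="/tmp/a.png").is_remote                 -> False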
def infer_file_category(path_or_url: str) -> tuple[str, str]:
    """
    Infer the file type from the suffix of a path or URL.

    Logic:
    1. Parse the URL, dropping query parameters (?id=...), and extract the path
    2. Take the file name and suffix from the last segment of the path
    3. Look the suffix up in the mapping table; anything unmatched returns 'default'

    Returns:
    - category: image, video, audio, document, default
    - suffix, e.g. .pdf
    """

    # === Steps 1 & 2: extract a clean suffix ===
    # urlparse handles both local paths (treated as the path) and network URLs
    parsed = urlparse(path_or_url)
    path = parsed.path  # keep only the path portion, dropping http://... and ?query=...

    # File name (e.g. /a/b/test.jpg -> test.jpg)
    filename = os.path.basename(path)

    # Split off the suffix (test.jpg -> .jpg)
    _, ext_with_dot = os.path.splitext(filename)

    # No suffix: fall back immediately
    if not ext_with_dot:
        return 'default', ""

    # Strip the dot and lowercase (e.g. .JPG -> jpg)
    ext = ext_with_dot.lstrip('.').lower()

    # === Step 3: table lookup ===
    # Mapping of common suffixes
    TYPE_MAPPING = {
        'image': {
            'apng', 'avif', 'bmp', 'gif', 'heic', 'ico', 'jpg', 'jpeg', 'png', 'svg', 'tiff', 'webp'
        },
        'video': {
            'mp4', 'avi', 'mov', 'mkv', 'flv', 'wmv', 'webm', 'm4v', '3gp'
        },
        'audio': {
            'mp3', 'wav', 'flac', 'aac', 'ogg', 'wma', 'm4a'
        },
        'document': {
            'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx',
            'txt', 'md', 'csv', 'json', 'xml', 'html', 'htm'
        },
    }

    for category, extensions in TYPE_MAPPING.items():
        if ext in extensions:
            return category, ext_with_dot

    return 'default', ext_with_dot
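# Example lookups (illustrative; note the returned suffix keeps its original case):
#   infer_file_category("https://cdn.example.com/a/b/Photo.JPG?id=1")  -> ('image', '.JPG')
#   infer_file_category("/tmp/report.pdf")                             -> ('document', '.pdf')
#   infer_file_category("README")                                      -> ('default', '')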
class FileOps:
    DOWNLOAD_DIR = "/tmp"

    @staticmethod
    def _get_bytes_stream(file_obj: File) -> tuple[bytes, str]:
        """
        Fetch the file content and suffix, enforcing the size limit; raises when exceeded.
        """
        _, ext = infer_file_category(file_obj.url)

        if file_obj.is_remote:
            try:
                # stream=True: only the headers are fetched here; the connection stays open and the body is not downloaded yet
                with requests.get(file_obj.url, stream=True, timeout=60) as resp:
                    resp.raise_for_status()

                    content_length = resp.headers.get('Content-Length')
                    if content_length and int(content_length) > MAX_FILE_SIZE:
                        raise Exception(
                            f"File size ({int(content_length)} bytes) exceeds the 100MB limit; download aborted."
                        )

                    # Covers a missing Content-Length header or a server lying in its headers
                    downloaded_content = BytesIO()
                    current_size = 0

                    # Read in 8KB chunks
                    for chunk in resp.iter_content(chunk_size=8192):
                        if chunk:
                            current_size += len(chunk)
                            if current_size > MAX_FILE_SIZE:
                                raise Exception("File exceeds 100MB; download interrupted.")
                            downloaded_content.write(chunk)

                    # Return the complete bytes
                    return downloaded_content.getvalue(), ext

            except requests.RequestException as e:
                raise RuntimeError(f"Network request failed: {e}")

        else:
            if not os.path.exists(file_obj.url):
                raise FileNotFoundError(f"Local file does not exist: {file_obj.url}")

            # Local size check, currently disabled:
            # file_size = os.path.getsize(file_obj.url)
            # if file_size > MAX_FILE_SIZE:
            #     raise Exception(f"Local file size ({file_size} bytes) exceeds the 100MB limit")

            with open(file_obj.url, 'rb') as f:
                return f.read(), ext

    @staticmethod
    def save_to_local(file_obj: File, filename: str) -> str:
        """
        Save the content of the file object to a local path and return that path.
        If the URL is already a local path, return it directly.
        """
        if not file_obj.is_remote:
            if os.path.exists(file_obj.url):
                return file_obj.url

            raise FileNotFoundError(f"Local file not found: {file_obj.url}")

        try:
            os.makedirs(FileOps.DOWNLOAD_DIR, exist_ok=True)

            # Simple file-name strategy (in production, consider hashing the URL to avoid re-downloads)
            # ext = os.path.splitext(file_obj.url.split('?')[0])[1] or ".tmp"
            # filename = f"{uuid.uuid4().hex}{ext}"
            local_path = os.path.join(FileOps.DOWNLOAD_DIR, filename)

            headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
            with requests.get(file_obj.url, headers=headers, stream=True, timeout=120) as r:
                r.raise_for_status()
                with open(local_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)

            return local_path
        except Exception as e:
            raise RuntimeError(f"Download failed for {file_obj.url}: {str(e)}")
    @staticmethod
    def read_bytes(file_obj: File) -> bytes:
        """
        Return the raw binary content of the file.
        Use cases: uploading to OSS, saving locally, feeding image-processing libraries.
        """
        content, _ = FileOps._get_bytes_stream(file_obj)
        return content

    @staticmethod
    def extract_text(file_obj: File) -> str:
        """
        Extract text content.
        Use cases: RAG, HTML parsing, document analysis.
        """
        try:
            content, ext = FileOps._get_bytes_stream(file_obj)

            if ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']:
                return FileOps._parse_document_bytes(file_obj, content, ext)

            # Default: decode directly, detecting the encoding first
            # (chardet may report encoding=None, so fall back to utf-8)
            charset = chardet.detect(content)
            encoding = charset.get('encoding') or 'utf-8'
            return content.decode(encoding, errors='replace')

        except Exception as e:
            return f"[FileOps Error] Failed to read content: {str(e)}"

    @staticmethod
    def _parse_document_bytes(file_obj: File, content: bytes, ext: str) -> str:
        stream = BytesIO(content)
        text_result = ""

        try:
            if ext == '.pdf':
                import pypdf
                reader = pypdf.PdfReader(stream)
                for page in reader.pages:
                    text_result += page.extract_text() + "\n"
            elif ext in ['.docx', '.doc']:
                text_result = read_docx(stream)
            elif ext in ['.xlsx', '.xls', '.csv']:
                import pandas as pd
                if ext == '.csv':
                    df = pd.read_csv(stream)
                else:
                    df = pd.read_excel(stream)
                text_result = df.to_string()
            elif ext in ['.ppt', '.pptx']:
                text_result = read_ppt(stream)
            else:
                text_result = f"[Parsing this document format is not supported yet: {ext}]"
        except ImportError as e:
            text_result = f"[Missing parsing library] {e}"
        except Exception as e:
            text_result = f"[Parsing failed] {e}"

        return text_result
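# A minimal extraction sketch (illustrative; the URL below is a placeholder):
def _example_extract_text():
    doc = File(url="https://example.com/answers.docx")
    # Dispatches to read_docx via _parse_document_bytes based on the .docx suffix
    return FileOps.extract_text(doc)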
def read_docx(cont_stream) -> str:
    """
    Read content in document order using docx2python.
    """
    from docx2python import docx2python
    doc_result = docx2python(cont_stream)

    # Collect the document structure
    all_parts = []

    # docx2python returns content as nested lists
    # Walk the document body
    for section in doc_result.body:
        if isinstance(section, list):
            for item in section:
                if isinstance(item, list):
                    # May be a table or deeper nesting
                    for sub_item in item:
                        if isinstance(sub_item, str) and sub_item.strip():
                            all_parts.append(sub_item.strip())
                        elif isinstance(sub_item, list):
                            # Table row
                            row_text = "\n".join([str(cell).strip() for cell in sub_item if str(cell).strip()])
                            if row_text:
                                all_parts.append(row_text)
                elif isinstance(item, str) and item.strip():
                    all_parts.append(item.strip())

    # Close the document
    doc_result.close()

    return "\n\n".join(all_parts)
def read_ppt(file_input: Union[str, bytes, BytesIO]) -> str:
    if not Presentation:
        return "[Error] python-pptx is not installed; cannot parse PPT files"

    # 1. Normalize the input into a file stream (BytesIO)
    if isinstance(file_input, str):
        with open(file_input, 'rb') as f:
            ppt_stream = BytesIO(f.read())
    elif isinstance(file_input, bytes):
        ppt_stream = BytesIO(file_input)
    else:
        ppt_stream = file_input

    try:
        prs = Presentation(ppt_stream)
        full_text = []

        for i, slide in enumerate(prs.slides):
            page_content = []
            page_content.append(f"=== Slide {i + 1} ===")

            # shape.text_frame holds the text paragraphs inside a shape
            for shape in slide.shapes:
                # Extract plain text boxes
                if hasattr(shape, "text") and shape.text.strip():
                    page_content.append(shape.text.strip())

                # B. Extract table content (plain shape.text cannot reach text inside tables)
                if shape.has_table:
                    table_texts = []
                    for row in shape.table.rows:
                        row_cells = [cell.text_frame.text.strip() for cell in row.cells if cell.text_frame.text.strip()]
                        if row_cells:
                            table_texts.append(" | ".join(row_cells))
                    if table_texts:
                        page_content.append("[Table]\n" + "\n".join(table_texts))

            # A lot of important information hides in the speaker notes
            if slide.has_notes_slide:
                notes = slide.notes_slide.notes_text_frame.text
                if notes.strip():
                    page_content.append(f"[Notes]: {notes.strip()}")

            full_text.append("\n".join(page_content))

        return "\n\n".join(full_text)

    except Exception as e:
        return f"[PPT parsing failed] {str(e)}"