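"""FastAPI entry point for running LangGraph workflows and agents.

Exposes HTTP endpoints for synchronous runs, SSE streaming, single-node runs,
cancellation, an OpenAI-compatible chat endpoint, and a health check, plus a
small CLI for local flow/node/agent runs.
"""
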
import argparse
import asyncio
import json
import threading
import traceback
import logging
from typing import Any, Dict, Iterable, AsyncIterable, AsyncGenerator, Optional

import cozeloop
import uvicorn
import time

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph

from coze_coding_utils.runtime_ctx.context import new_context, Context
from coze_coding_utils.helper import graph_helper
from coze_coding_utils.log.node_log import LOG_FILE
from coze_coding_utils.log.write_log import setup_logging, request_context
from coze_coding_utils.log.config import LOG_LEVEL
from coze_coding_utils.error.classifier import ErrorClassifier, classify_error
from coze_coding_utils.helper.stream_runner import AgentStreamRunner, WorkflowStreamRunner, agent_stream_handler, workflow_stream_handler, RunOpt

setup_logging(
    log_file=LOG_FILE,
    max_bytes=100 * 1024 * 1024,  # 100MB
    backup_count=5,
    log_level=LOG_LEVEL,
    use_json_format=True,
    console_output=True
)

logger = logging.getLogger(__name__)

from coze_coding_utils.helper.agent_helper import to_stream_input
from coze_coding_utils.openai.handler import OpenAIChatHandler
from coze_coding_utils.log.parser import LangGraphParser
from coze_coding_utils.log.err_trace import extract_core_stack
from coze_coding_utils.log.loop_trace import init_run_config, init_agent_config


# Timeout configuration constant
TIMEOUT_SECONDS = 900  # 15 minutes


class GraphService:
    def __init__(self):
        # Track running tasks (as asyncio.Task) so they can be cancelled by run_id
        self.running_tasks: Dict[str, asyncio.Task] = {}
        # Error classifier
        self.error_classifier = ErrorClassifier()
        # Stream runners
        self._agent_stream_runner = AgentStreamRunner()
        self._workflow_stream_runner = WorkflowStreamRunner()
        self._graph = None
        self._graph_lock = threading.Lock()

    def _get_graph(self, ctx: Optional[Context] = None):
        if graph_helper.is_agent_proj():
            return graph_helper.get_agent_instance("agents.agent", ctx)

        # Double-checked locking: build the workflow graph once and cache it
        if self._graph is not None:
            return self._graph
        with self._graph_lock:
            if self._graph is not None:
                return self._graph
            self._graph = graph_helper.get_graph_instance("graphs.graph")
            return self._graph

    @staticmethod
    def _sse_event(data: Any, event_id: Any = None) -> str:
        id_line = f"id: {event_id}\n" if event_id else ""
        return f"{id_line}event: message\ndata: {json.dumps(data, ensure_ascii=False, default=str)}\n\n"

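    # Illustrative output of _sse_event (the data payload shape depends on the graph):
    #   id: 3
    #   event: message
    #   data: {"node": "start", "output": "..."}
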
    def _get_stream_runner(self):
        if graph_helper.is_agent_proj():
            return self._agent_stream_runner
        else:
            return self._workflow_stream_runner

    # Streaming run (raw iterator): used for local invocation
    def stream(self, payload: Dict[str, Any], run_config: RunnableConfig, ctx: Optional[Context] = None) -> Iterable[Any]:
        graph = self._get_graph(ctx)
        stream_runner = self._get_stream_runner()
        for chunk in stream_runner.stream(payload, graph, run_config, ctx):
            yield chunk

    # Synchronous run: shared by local and HTTP callers
    async def run(self, payload: Dict[str, Any], ctx=None) -> Dict[str, Any]:
        if ctx is None:
            ctx = new_context("run")

        run_id = ctx.run_id
        logger.info(f"Starting run with run_id: {run_id}")

        try:
            graph = self._get_graph(ctx)
            # Custom tracer
            run_config = init_run_config(graph, ctx)
            run_config["configurable"] = {"thread_id": ctx.run_id}

            # Invoke directly; LangGraph executes in the current task context,
            # so cancelling the current task also cancels the LangGraph run.
            return await graph.ainvoke(payload, config=run_config, context=ctx)

        except asyncio.CancelledError:
            logger.info(f"Run {run_id} was cancelled")
            return {"status": "cancelled", "run_id": run_id, "message": "Execution was cancelled"}
        except Exception as e:
            # Classify the error with the error classifier
            err = self.error_classifier.classify(e, {"node_name": "run", "run_id": run_id})
            # Log detailed error information and the stack trace
            logger.error(
                f"Error in GraphService.run: [{err.code}] {err.message}\n"
                f"Category: {err.category.name}\n"
                f"Traceback:\n{extract_core_stack()}"
            )
            # Re-raise with the original traceback so callers see the real error location
            raise
        finally:
            # Clean up the task record
            self.running_tasks.pop(run_id, None)

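    # Programmatic usage (illustrative; the payload keys depend on the
    # compiled graph's input schema):
    #   result = asyncio.run(service.run({"text": "hi"}))
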
    # Streaming run (SSE formatted): used by HTTP routes
    async def stream_sse(self, payload: Dict[str, Any], ctx=None, run_opt: Optional[RunOpt] = None) -> AsyncGenerator[str, None]:
        if ctx is None:
            ctx = new_context(method="stream_sse")
        if run_opt is None:
            run_opt = RunOpt()

        run_id = ctx.run_id
        logger.info(f"Starting stream with run_id: {run_id}")
        graph = self._get_graph(ctx)
        if graph_helper.is_agent_proj():
            run_config = init_agent_config(graph, ctx)
        else:
            run_config = init_run_config(graph, ctx)  # vibeflow

        is_workflow = not graph_helper.is_agent_proj()

        try:
            async for chunk in self.astream(payload, graph, run_config=run_config, ctx=ctx, run_opt=run_opt):
                if is_workflow and isinstance(chunk, tuple):
                    event_id, data = chunk
                    yield self._sse_event(data, event_id)
                else:
                    yield self._sse_event(chunk)
        finally:
            # Clean up the task record
            self.running_tasks.pop(run_id, None)
            cozeloop.flush()

    # Cancel execution using asyncio's standard mechanism
    def cancel_run(self, run_id: str, ctx: Optional[Context] = None) -> Dict[str, Any]:
        """
        Cancel the execution identified by run_id.

        Uses asyncio.Task.cancel(), the standard Python async cancellation
        mechanism. LangGraph checks for CancelledError between nodes, which
        allows a graceful cancellation.
        """
        logger.info(f"Attempting to cancel run_id: {run_id}")

        # Look up the corresponding task
        if run_id in self.running_tasks:
            task = self.running_tasks[run_id]
            if not task.done():
                # Standard asyncio cancellation: raises CancelledError
                # at the task's next await point
                task.cancel()
                logger.info(f"Cancellation requested for run_id: {run_id}")
                return {
                    "status": "success",
                    "run_id": run_id,
                    "message": "Cancellation signal sent, task will be cancelled at next await point"
                }
            else:
                logger.info(f"Task already completed for run_id: {run_id}")
                return {
                    "status": "already_completed",
                    "run_id": run_id,
                    "message": "Task has already completed"
                }
        else:
            logger.warning(f"No active task found for run_id: {run_id}")
            return {
                "status": "not_found",
                "run_id": run_id,
                "message": "No active task found with this run_id. Task may have already completed or run_id is invalid."
            }

    # Run a single node: shared by local and HTTP callers
    async def run_node(self, node_id: str, payload: Dict[str, Any], ctx=None) -> Any:
        if ctx is None or ctx.run_id == "":
            ctx = new_context(method="node_run")

        _graph = self._get_graph(ctx)
        node_func, input_cls, output_cls = graph_helper.get_graph_node_func_with_inout(_graph.get_graph(), node_id)
        if node_func is None or input_cls is None:
            raise KeyError(f"node_id '{node_id}' not found")

        parser = LangGraphParser(_graph)
        metadata = parser.get_node_metadata(node_id) or {}

        # Wrap the single node in a minimal one-node graph and run it
        _g = StateGraph(input_cls, input_schema=input_cls, output_schema=output_cls)
        _g.add_node("sn", node_func, metadata=metadata)
        _g.set_entry_point("sn")
        _g.add_edge("sn", END)
        _graph = _g.compile()

        run_config = init_run_config(_graph, ctx)
        return await _graph.ainvoke(payload, config=run_config)

    def graph_inout_schema(self) -> Any:
        if graph_helper.is_agent_proj():
            return {"input_schema": {}, "output_schema": {}}
        graph = self._get_graph()
        builder = getattr(graph, 'builder', None)
        if builder is not None:
            input_cls = getattr(builder, 'input_schema', None) or graph.get_input_schema()
            output_cls = getattr(builder, 'output_schema', None) or graph.get_output_schema()
        else:
            logger.warning("No builder input schema found for graph_inout_schema, using graph input schema instead")
            input_cls = graph.get_input_schema()
            output_cls = graph.get_output_schema()

        return {
            "input_schema": input_cls.model_json_schema(),
            "output_schema": output_cls.model_json_schema(),
            "code": 0,
            "msg": ""
        }

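    # Illustrative response shape (the actual schemas come from the graph's
    # Pydantic models):
    #   {"input_schema": {...}, "output_schema": {...}, "code": 0, "msg": ""}
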
    async def astream(self, payload: Dict[str, Any], graph: CompiledStateGraph, run_config: RunnableConfig, ctx: Optional[Context] = None, run_opt: Optional[RunOpt] = None) -> AsyncIterable[Any]:
        stream_runner = self._get_stream_runner()
        async for chunk in stream_runner.astream(payload, graph, run_config, ctx, run_opt):
            yield chunk


service = GraphService()
app = FastAPI()

# Handler for the OpenAI-compatible endpoint
openai_handler = OpenAIChatHandler(service)


HEADER_X_RUN_ID = "x-run-id"
@app.post("/run")
|
||
async def http_run(request: Request) -> Dict[str, Any]:
|
||
global result
|
||
raw_body = await request.body()
|
||
try:
|
||
body_text = raw_body.decode("utf-8")
|
||
except Exception as e:
|
||
body_text = str(raw_body)
|
||
raise HTTPException(status_code=400,
|
||
detail=f"Invalid JSON format: {body_text}, traceback: {traceback.format_exc()}, error: {e}")
|
||
|
||
ctx = new_context(method="run", headers=request.headers)
|
||
# 优先使用上游指定的 run_id,保证 cancel 能精确匹配
|
||
upstream_run_id = request.headers.get(HEADER_X_RUN_ID)
|
||
if upstream_run_id:
|
||
ctx.run_id = upstream_run_id
|
||
run_id = ctx.run_id
|
||
request_context.set(ctx)
|
||
|
||
logger.info(
|
||
f"Received request for /run: "
|
||
f"run_id={run_id}, "
|
||
f"query={dict(request.query_params)}, "
|
||
f"body={body_text}"
|
||
)
|
||
|
||
try:
|
||
payload = await request.json()
|
||
|
||
# 创建任务并记录 - 这是关键,让我们可以通过run_id取消任务
|
||
task = asyncio.create_task(service.run(payload, ctx))
|
||
service.running_tasks[run_id] = task
|
||
|
||
try:
|
||
result = await asyncio.wait_for(task, timeout=float(TIMEOUT_SECONDS))
|
||
except asyncio.TimeoutError:
|
||
logger.error(f"Run execution timeout after {TIMEOUT_SECONDS}s for run_id: {run_id}")
|
||
task.cancel()
|
||
try:
|
||
result = await task
|
||
except asyncio.CancelledError:
|
||
return {
|
||
"status": "timeout",
|
||
"run_id": run_id,
|
||
"message": f"Execution timeout: exceeded {TIMEOUT_SECONDS} seconds"
|
||
}
|
||
|
||
if not result:
|
||
result = {}
|
||
if isinstance(result, dict):
|
||
result["run_id"] = run_id
|
||
return result
|
||
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"JSON decode error in http_run: {e}, traceback: {traceback.format_exc()}")
|
||
raise HTTPException(status_code=400, detail=f"Invalid JSON format, {extract_core_stack()}")
|
||
|
||
except asyncio.CancelledError:
|
||
logger.info(f"Request cancelled for run_id: {run_id}")
|
||
result = {"status": "cancelled", "run_id": run_id, "message": "Execution was cancelled"}
|
||
return result
|
||
|
||
except Exception as e:
|
||
# 使用错误分类器获取错误信息
|
||
error_response = service.error_classifier.get_error_response(e, {"node_name": "http_run", "run_id": run_id})
|
||
logger.error(
|
||
f"Unexpected error in http_run: [{error_response['error_code']}] {error_response['error_message']}, "
|
||
f"traceback: {traceback.format_exc()}", exc_info=True
|
||
)
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail={
|
||
"error_code": error_response["error_code"],
|
||
"error_message": error_response["error_message"],
|
||
"stack_trace": extract_core_stack(),
|
||
}
|
||
)
|
||
finally:
|
||
cozeloop.flush()
|
||
|
||
|
||
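# Example request (illustrative; assumes the default port 5000 and a graph
# that accepts a "text" field):
#   curl -X POST http://localhost:5000/run \
#        -H 'Content-Type: application/json' \
#        -H 'x-run-id: my-run-1' \
#        -d '{"text": "hello"}'
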
HEADER_X_WORKFLOW_STREAM_MODE = "x-workflow-stream-mode"


def _register_task(run_id: str, task: asyncio.Task):
    service.running_tasks[run_id] = task


@app.post("/stream_run")
|
||
async def http_stream_run(request: Request):
|
||
ctx = new_context(method="stream_run", headers=request.headers)
|
||
# 优先使用上游指定的 run_id,保证 cancel 能精确匹配
|
||
upstream_run_id = request.headers.get(HEADER_X_RUN_ID)
|
||
if upstream_run_id:
|
||
ctx.run_id = upstream_run_id
|
||
workflow_stream_mode = request.headers.get(HEADER_X_WORKFLOW_STREAM_MODE, "").lower()
|
||
workflow_debug = workflow_stream_mode == "debug"
|
||
request_context.set(ctx)
|
||
raw_body = await request.body()
|
||
try:
|
||
body_text = raw_body.decode("utf-8")
|
||
except Exception as e:
|
||
body_text = str(raw_body)
|
||
raise HTTPException(status_code=400,
|
||
detail=f"Invalid JSON format: {body_text}, traceback: {extract_core_stack()}, error: {e}")
|
||
run_id = ctx.run_id
|
||
is_agent = graph_helper.is_agent_proj()
|
||
logger.info(
|
||
f"Received request for /stream_run: "
|
||
f"run_id={run_id}, "
|
||
f"is_agent_project={is_agent}, "
|
||
f"query={dict(request.query_params)}, "
|
||
f"body={body_text}"
|
||
)
|
||
try:
|
||
payload = await request.json()
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"JSON decode error in http_stream_run: {e}, traceback: {traceback.format_exc()}")
|
||
raise HTTPException(status_code=400, detail=f"Invalid JSON format:{extract_core_stack()}")
|
||
|
||
if is_agent:
|
||
stream_generator = agent_stream_handler(
|
||
payload=payload,
|
||
ctx=ctx,
|
||
run_id=run_id,
|
||
stream_sse_func=service.stream_sse,
|
||
sse_event_func=service._sse_event,
|
||
error_classifier=service.error_classifier,
|
||
register_task_func=_register_task,
|
||
)
|
||
else:
|
||
stream_generator = workflow_stream_handler(
|
||
payload=payload,
|
||
ctx=ctx,
|
||
run_id=run_id,
|
||
stream_sse_func=service.stream_sse,
|
||
sse_event_func=service._sse_event,
|
||
error_classifier=service.error_classifier,
|
||
register_task_func=_register_task,
|
||
run_opt=RunOpt(workflow_debug=workflow_debug),
|
||
)
|
||
|
||
response = StreamingResponse(stream_generator, media_type="text/event-stream")
|
||
return response
|
||
|
||
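# Example request (illustrative; assumes the default port 5000; -N keeps the
# SSE stream open):
#   curl -N -X POST http://localhost:5000/stream_run \
#        -H 'Content-Type: application/json' \
#        -H 'x-workflow-stream-mode: debug' \
#        -d '{"text": "hello"}'
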
@app.post("/cancel/{run_id}")
|
||
async def http_cancel(run_id: str, request: Request):
|
||
"""
|
||
取消指定run_id的执行
|
||
|
||
使用asyncio.Task.cancel()实现取消,这是Python标准的异步任务取消机制。
|
||
LangGraph会在节点之间的await点检查CancelledError,实现优雅取消。
|
||
"""
|
||
ctx = new_context(method="cancel", headers=request.headers)
|
||
request_context.set(ctx)
|
||
logger.info(f"Received cancel request for run_id: {run_id}")
|
||
result = service.cancel_run(run_id, ctx)
|
||
return result
|
||
|
||
|
||
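# Example request (illustrative; assumes the default port 5000):
#   curl -X POST http://localhost:5000/cancel/my-run-1
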
@app.post(path="/node_run/{node_id}")
|
||
async def http_node_run(node_id: str, request: Request):
|
||
raw_body = await request.body()
|
||
try:
|
||
body_text = raw_body.decode("utf-8")
|
||
except UnicodeDecodeError:
|
||
body_text = str(raw_body)
|
||
raise HTTPException(status_code=400, detail=f"Invalid JSON format: {body_text}")
|
||
ctx = new_context(method="node_run", headers=request.headers)
|
||
request_context.set(ctx)
|
||
logger.info(
|
||
f"Received request for /node_run/{node_id}: "
|
||
f"query={dict(request.query_params)}, "
|
||
f"body={body_text}",
|
||
)
|
||
|
||
try:
|
||
payload = await request.json()
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"JSON decode error in http_node_run: {e}, traceback: {traceback.format_exc()}")
|
||
raise HTTPException(status_code=400, detail=f"Invalid JSON format:{extract_core_stack()}")
|
||
try:
|
||
return await service.run_node(node_id, payload, ctx)
|
||
except KeyError:
|
||
raise HTTPException(status_code=404,
|
||
detail=f"node_id '{node_id}' not found or input miss required fields, traceback: {extract_core_stack()}")
|
||
except Exception as e:
|
||
# 使用错误分类器获取错误信息
|
||
error_response = service.error_classifier.get_error_response(e, {"node_name": node_id})
|
||
logger.error(
|
||
f"Unexpected error in http_node_run: [{error_response['error_code']}] {error_response['error_message']}, "
|
||
f"traceback: {traceback.format_exc()}", exc_info=True
|
||
)
|
||
raise HTTPException(
|
||
status_code=500,
|
||
detail={
|
||
"error_code": error_response["error_code"],
|
||
"error_message": error_response["error_message"],
|
||
"stack_trace": extract_core_stack(),
|
||
}
|
||
)
|
||
finally:
|
||
cozeloop.flush()
|
||
|
||
|
||
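# Example request (illustrative; the node id and payload depend on the graph):
#   curl -X POST http://localhost:5000/node_run/my_node \
#        -H 'Content-Type: application/json' \
#        -d '{"text": "hello"}'
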
@app.post("/v1/chat/completions")
|
||
async def openai_chat_completions(request: Request):
|
||
"""OpenAI Chat Completions API 兼容接口"""
|
||
ctx = new_context(method="openai_chat", headers=request.headers)
|
||
request_context.set(ctx)
|
||
|
||
logger.info(f"Received request for /v1/chat/completions: run_id={ctx.run_id}")
|
||
|
||
try:
|
||
payload = await request.json()
|
||
return await openai_handler.handle(payload, ctx)
|
||
except json.JSONDecodeError as e:
|
||
logger.error(f"JSON decode error in openai_chat_completions: {e}")
|
||
raise HTTPException(status_code=400, detail="Invalid JSON format")
|
||
finally:
|
||
cozeloop.flush()
|
||
|
||
|
||
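# Example request (illustrative; assumes a standard OpenAI-style payload,
# which OpenAIChatHandler is expected to accept):
#   curl -X POST http://localhost:5000/v1/chat/completions \
#        -H 'Content-Type: application/json' \
#        -d '{"messages": [{"role": "user", "content": "hello"}], "stream": false}'
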
@app.get("/health")
|
||
async def health_check():
|
||
try:
|
||
# 这里可以添加更多的健康检查逻辑
|
||
return {
|
||
"status": "ok",
|
||
"message": "Service is running",
|
||
}
|
||
except Exception as e:
|
||
raise HTTPException(status_code=503, detail=str(e))
|
||
|
||
|
||
@app.get(path="/graph_parameter")
|
||
async def http_graph_inout_parameter(request: Request):
|
||
return service.graph_inout_schema()
|
||
|
||
def parse_args():
    parser = argparse.ArgumentParser(description="Start FastAPI server")
    parser.add_argument("-m", type=str, default="http", help="Run mode: http, flow, node, or agent")
    parser.add_argument("-n", type=str, default="", help="Node ID for single node run")
    parser.add_argument("-p", type=int, default=5000, help="HTTP server port")
    parser.add_argument("-i", type=str, default="", help="Input JSON string for flow/node mode")
    return parser.parse_args()

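# Example invocations (illustrative; "my_node" is a placeholder node id):
#   python main.py -m http -p 5000
#   python main.py -m flow -i '{"text": "hello"}'
#   python main.py -m node -n my_node -i '{"text": "hello"}'
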
def parse_input(input_str: str) -> Dict[str, Any]:
    """Parse an input string, supporting both JSON strings and plain text"""
    if not input_str:
        return {"text": "你好"}

    # Try to parse as JSON first
    try:
        return json.loads(input_str)
    except json.JSONDecodeError:
        # If not valid JSON, treat as plain text
        return {"text": input_str}

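# Example behavior (illustrative):
#   parse_input('{"a": 1}')  -> {"a": 1}
#   parse_input("hi")        -> {"text": "hi"}
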
def start_http_server(port):
    workers = 1
    reload = False
    if graph_helper.is_dev_env():
        reload = True

    logger.info(f"Start HTTP Server, Port: {port}, Workers: {workers}")
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=reload, workers=workers)


if __name__ == "__main__":
|
||
args = parse_args()
|
||
if args.m == "http":
|
||
start_http_server(args.p)
|
||
elif args.m == "flow":
|
||
payload = parse_input(args.i)
|
||
result = asyncio.run(service.run(payload))
|
||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||
elif args.m == "node" and args.n:
|
||
payload = parse_input(args.i)
|
||
result = asyncio.run(service.run_node(args.n, payload))
|
||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||
elif args.m == "agent":
|
||
agent_ctx = new_context(method="agent")
|
||
for chunk in service.stream(
|
||
{
|
||
"type": "query",
|
||
"session_id": "1",
|
||
"message": "你好",
|
||
"content": {
|
||
"query": {
|
||
"prompt": [
|
||
{
|
||
"type": "text",
|
||
"content": {"text": "现在几点了?请调用工具获取当前时间"},
|
||
}
|
||
]
|
||
}
|
||
},
|
||
},
|
||
run_config={"configurable": {"session_id": "1"}},
|
||
ctx=agent_ctx,
|
||
):
|
||
print(chunk)
|