From 3a145075eff1419ec4dda4240eaf7956773a6cbd Mon Sep 17 00:00:00 2001 From: zhangquan Date: Mon, 30 Mar 2026 20:53:36 +0800 Subject: [PATCH] 1 --- src/storage/s3/__init__.py | 0 src/storage/s3/s3_storage.py | 424 ----------------------------------- 2 files changed, 424 deletions(-) delete mode 100644 src/storage/s3/__init__.py delete mode 100644 src/storage/s3/s3_storage.py diff --git a/src/storage/s3/__init__.py b/src/storage/s3/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/storage/s3/s3_storage.py b/src/storage/s3/s3_storage.py deleted file mode 100644 index 56ba3bc..0000000 --- a/src/storage/s3/s3_storage.py +++ /dev/null @@ -1,424 +0,0 @@ -import os -import re -from pathlib import Path -from typing import Optional, Any, Dict, List, TypedDict, Iterable -from uuid import uuid4 - -import boto3 -from botocore.exceptions import ClientError -from boto3.s3.transfer import TransferConfig -import logging -logger = logging.getLogger(__name__) - -# 允许的文件名字符集(面向用户输入的约束) -FILE_NAME_ALLOWED_RE = re.compile(r"^[A-Za-z0-9._\-/]+$") - - -class ListFilesResult(TypedDict): - # list_files 的返回结构类型 - keys: List[str] - is_truncated: bool - next_continuation_token: Optional[str] - -class S3SyncStorage: - """S3兼容存储实现""" - - def __init__(self, *, endpoint_url: Optional[str] = None, access_key: str, secret_key: str, bucket_name: str, region: str = "cn-beijing"): - self.endpoint_url = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or endpoint_url or '' - self.access_key = access_key - self.secret_key = secret_key - self.bucket_name = bucket_name - self.region = region - self._client = None - - def _get_client(self): - if self._client is None: - endpoint = self.endpoint_url - if endpoint is None or endpoint == "": - try: - from coze_workload_identity import Client as CozeEnvClient - coze_env_client = CozeEnvClient() - env_vars = coze_env_client.get_project_env_vars() - coze_env_client.close() - for env_var in env_vars: - if env_var.key == "COZE_BUCKET_ENDPOINT_URL": - endpoint = env_var.value.replace("'", "'\\''") - self.endpoint_url = endpoint - break - except Exception as e: - logger.error(f"Error loading COZE_BUCKET_ENDPOINT_URL: {e}") - # 保持向下校验逻辑,避免在此处中断 - if endpoint is None or endpoint == "": - logger.error("未配置存储端点:请设置endpoint_url") - raise ValueError("未配置存储端点:请设置endpoint_url") - - client = boto3.client( - "s3", - endpoint_url=endpoint, - aws_access_key_id=self.access_key, - aws_secret_access_key=self.secret_key, - region_name=self.region, - ) - - # 注册 before-call 钩子,发送前注入 x-storage-token 头 - def _inject_header(**kwargs): - try: - from coze_workload_identity import Client as CozeClient - coze_client = CozeClient() - try: - token = coze_client.get_access_token() - except Exception as e: - logger.error("Error loading COZE_WORKLOAD_IDENTITY_TOKEN: %s", e) - token = None - raise e - finally: - coze_client.close() - params = kwargs.get("params", {}) - headers = params.setdefault("headers", {}) - headers["x-storage-token"] = token - except Exception as e: - logger.error("Error loading COZE_WORKLOAD_IDENTITY_TOKEN: %s", e) - pass - client.meta.events.register("before-call.s3", _inject_header) - self._client = client - return self._client - - def _generate_object_key(self, *, original_name: str) -> str: - suffix = Path(original_name).suffix.lower() - stem = Path(original_name).stem - uniq = uuid4().hex[:8] - return f"{stem}_{uniq}{suffix}" - - def _extract_logid(self, e: Exception) -> Optional[str]: - """从 ClientError 中提取 x-tt-logid""" - if isinstance(e, ClientError): - headers = (e.response or {}).get("ResponseMetadata", {}).get("HTTPHeaders", {}) - return headers.get("x-tt-logid") - return None - - def _error_msg(self, msg: str, e: Exception) -> str: - """构建带 logid 的错误信息""" - logid = self._extract_logid(e) - if logid: - return f"{msg}: {e} (x-tt-logid: {logid})" - return f"{msg}: {e}" - - def _resolve_bucket(self, bucket: Optional[str]) -> str: - """统一解析 bucket 来源,确保得到有效桶名。""" - target_bucket = bucket or os.environ.get("COZE_BUCKET_NAME") or self.bucket_name - if not target_bucket: - raise ValueError("未配置 bucket:请传入 bucket 或设置 COZE_BUCKET_NAME,或在实例化时提供 bucket_name") - return target_bucket - - def _validate_file_name(self, name: str) -> None: - """校验 S3 对象命名:长度≤1024;允许 [A-Za-z0-9._-/];不以 / 起止且不含 //。""" - msg = ( - "file name invalid: 文件名需满足以下 S3 对象命名规范:" - "1) 长度 1–1024 字节;" - "2) 仅允许字母、数字、点(.)、下划线(_)、短横(-)、目录分隔符(/);" - "3) 不允许空格或以下特殊字符:? # & % { } ^ [ ] ` \\ < > ~ | \" ' + = : ;;" - "4) 不以 / 开头或结尾,且不包含连续的 //;" - "示例:report_2025-12-11.pdf、images/photo-01.png。" - ) - - if not name or not name.strip(): - raise ValueError(msg + "(原因:文件名为空)") - - # S3 限制对象 key 最大 1024 字节,这里沿用到输入文件名 - if len(name.encode("utf-8")) > 1024: - raise ValueError(msg + "(原因:长度超过 1024 字节)") - - if name.startswith("/") or name.endswith("/"): - raise ValueError(msg + "(原因:以 / 开头或结尾)") - if "//" in name: - raise ValueError(msg + "(原因:包含连续的 //)") - - # 允许字符集校验 - if not FILE_NAME_ALLOWED_RE.match(name): - bad = re.findall(r"[^A-Za-z0-9._\-/]", name) - example = bad[0] if bad else "非法字符" - raise ValueError(msg + f"(原因:包含非法字符,例如:{example})") - - def upload_file(self, *, file_content: bytes, file_name: str, content_type: str = "application/octet-stream", bucket: Optional[str] = None) -> str: - # 先对输入文件名做规范校验,避免生成无效对象 key - self._validate_file_name(file_name) - try: - client = self._get_client() - object_key = self._generate_object_key(original_name=file_name) - target_bucket = self._resolve_bucket(bucket) - client.put_object(Bucket=target_bucket, Key=object_key, Body=file_content, ContentType=content_type) - return object_key - except Exception as e: - logger.error(self._error_msg("Error uploading file to S3", e)) - raise e - - def delete_file(self, *, file_key: str, bucket: Optional[str] = None) -> bool: - try: - client = self._get_client() - target_bucket = self._resolve_bucket(bucket) - client.delete_object(Bucket=target_bucket, Key=file_key) - return True - except Exception as e: - logger.error(self._error_msg("Error deleting file from S3", e)) - raise e - - def file_exists(self, *, file_key: str, bucket: Optional[str] = None) -> bool: - try: - client = self._get_client() - target_bucket = self._resolve_bucket(bucket) - client.head_object(Bucket=target_bucket, Key=file_key) - return True - except ClientError as e: - code = (e.response or {}).get("Error", {}).get("Code", "") - if code in {"404", "NoSuchKey", "NotFound"}: - return False - logger.error(self._error_msg("Error checking file existence in S3", e)) - return False - except Exception as e: - logger.error(self._error_msg("Error checking file existence in S3", e)) - return False - - def read_file(self, *, file_key: str, bucket: Optional[str] = None) -> bytes: - try: - client = self._get_client() - target_bucket = self._resolve_bucket(bucket) - resp = client.get_object(Bucket=target_bucket, Key=file_key) - body = resp.get("Body") - if body is None: - raise RuntimeError("S3 get_object returned no Body") - try: - return body.read() - finally: - try: - body.close() - except Exception as ce: - # 资源关闭失败不影响读取结果,仅记录以便排查 - logger.debug("Failed to close S3 response body: %s", ce) - except Exception as e: - logger.error(self._error_msg("Error reading file from S3", e)) - raise e - - def list_files(self, *, prefix: Optional[str] = None, bucket: Optional[str] = None, max_keys: int = 1000, continuation_token: Optional[str] = None) -> ListFilesResult: - """列出对象,支持前缀过滤与分页;返回 keys/is_truncated/next_continuation_token。""" - try: - client = self._get_client() - target_bucket = self._resolve_bucket(bucket) - if max_keys <= 0 or max_keys > 1000: - raise ValueError("max_keys 必须在 1 到 1000 之间") - - kwargs: Dict[str, Any] = { - "Bucket": target_bucket, - "MaxKeys": max_keys, - "Prefix": prefix, - "ContinuationToken": continuation_token, - } - kwargs = {k: v for k, v in kwargs.items() if v is not None} - - resp = client.list_objects_v2(**kwargs) - contents = resp.get("Contents", []) or [] - keys: List[str] = [item.get("Key") for item in contents if isinstance(item, dict) and item.get("Key")] - return { - "keys": keys, - "is_truncated": bool(resp.get("IsTruncated")), - "next_continuation_token": resp.get("NextContinuationToken"), - } - except ClientError as e: - code = (e.response or {}).get("Error", {}).get("Code", "") - logger.error(self._error_msg(f"Error listing files in S3 (code={code})", e)) - raise e - except Exception as e: - logger.error(self._error_msg("Error listing files in S3", e)) - raise e - - def generate_presigned_url(self, *, key: str, bucket: Optional[str] = None, expire_time: int = 1800) -> str: - """通过 S3 Proxy 生成签名 URL。""" - import json - import urllib.request as urllib_request - try: - from coze_workload_identity import Client as CozeClient - coze_client = CozeClient() - try: - token = coze_client.get_access_token() - finally: - try: - coze_client.close() - except Exception: - # 资源释放失败不影响后续流程 - pass - except Exception as e: - logger.error(f"Error loading x-storage-token: {e}") - raise RuntimeError(f"获取 x-storage-token 失败: {e}") - try: - sign_base = os.environ.get("COZE_BUCKET_ENDPOINT_URL") or self.endpoint_url - if not sign_base: - raise ValueError("未配置签名端点:请设置 COZE_BUCKET_ENDPOINT_URL 或传入 endpoint_url") - sign_url_endpoint = sign_base.rstrip("/") + "/sign-url" - - headers = { - "Content-Type": "application/json", - "x-storage-token": token, - } - - target_bucket = self._resolve_bucket(bucket) - payload = {"bucket_name": target_bucket, "path": key, "expire_time": expire_time} - data = json.dumps(payload).encode("utf-8") - request = urllib_request.Request(sign_url_endpoint, data=data, headers=headers, method="POST") - except Exception as e: - logger.error(f"Error creating request for sign-url: {e}") - raise RuntimeError(f"创建 sign-url 请求失败: {e}") - - try: - with urllib_request.urlopen(request) as resp: - resp_bytes = resp.read() - content_type = resp.headers.get("Content-Type", "") - text = resp_bytes.decode("utf-8", errors="replace") - if "application/json" in content_type or text.strip().startswith("{"): - try: - obj = json.loads(text) - except Exception: - return text - data = obj.get("data") - if isinstance(data, dict) and "url" in data: - return data["url"] - url_value = obj.get("url") or obj.get("signed_url") or obj.get("presigned_url") - if url_value: - return url_value - raise ValueError("签名服务返回缺少 data.url/url 字段") - return text - except Exception as e: - raise RuntimeError(f"生成签名URL失败: {e}") - - def stream_upload_file( - self, - *, - fileobj, - file_name: str, - content_type: str = "application/octet-stream", - bucket: Optional[str] = None, - multipart_chunksize: int = 5 * 1024 * 1024, - multipart_threshold: int = 5 * 1024 * 1024, - max_concurrency: int = 1, - use_threads: bool = False, - ) -> str: - """流式上传(文件对象) - - fileobj: 任何带有 read() 方法的文件对象(如 open(..., 'rb') 返回的对象、io.BytesIO 等) - - file_name: 原始文件名,用于生成唯一 key - - content_type: MIME 类型 - - bucket: 目标桶;为空时取环境变量或实例默认值 - - multipart_chunksize: 分片大小(默认 5MB,以适配代理层限制) - - multipart_threshold: 触发分片上传的阈值(默认 5MB) - - max_concurrency: 并发分片上传的并发数(默认 1,避免代理层节流影响) - - use_threads: 是否启用线程并发(默认 False) - 返回:最终写入的对象 key - """ - try: - client = self._get_client() - target_bucket = self._resolve_bucket(bucket) - key = self._generate_object_key(original_name=file_name) - - extra_args = {"ContentType": content_type} if content_type else {} - # 使用 boto3 的高阶方法执行多段上传(传入 TransferConfig 控制分片大小) - - config = TransferConfig( - multipart_chunksize=multipart_chunksize, - multipart_threshold=multipart_threshold, - max_concurrency=max_concurrency, - use_threads=use_threads, - ) - client.upload_fileobj(Fileobj=fileobj, Bucket=target_bucket, Key=key, ExtraArgs=extra_args, Config=config) - return key - except Exception as e: - logger.error(self._error_msg("Error streaming upload (fileobj) to S3", e)) - raise e - - def upload_from_url( - self, - *, - url: str, - bucket: Optional[str] = None, - timeout: int = 30, - ) -> str: - """从 URL 流式下载并上传到 S3 - - url: 源文件 URL - - bucket: 目标桶;为空时取环境变量或实例默认值 - - timeout: HTTP 请求超时时间(秒,默认 30) - 返回:最终写入的对象 key - """ - import urllib.request as urllib_request - from urllib.parse import urlparse, unquote - try: - request = urllib_request.Request(url) - with urllib_request.urlopen(request, timeout=timeout) as resp: - parsed = urlparse(url) - file_name = Path(unquote(parsed.path)).name or "file" - content_type = resp.headers.get("Content-Type", "application/octet-stream") - return self.stream_upload_file( - fileobj=resp, - file_name=file_name, - content_type=content_type, - bucket=bucket, - ) - except Exception as e: - logger.error(self._error_msg("Error uploading from URL to S3", e)) - raise e - - def trunk_upload_file(self, *, chunk_iter: Iterable[bytes], file_name: str, - content_type: str = "application/octet-stream", bucket: Optional[str] = None, - part_size: int = 5 * 1024 * 1024) -> str: - """流式上传(字节迭代器,显式分片 Multipart Upload) - - chunk_iter: 可迭代对象,逐块产生 bytes;每块大小可变(内部累积到 part_size 再上传),最后一块可小于 5MB - - file_name: 原始文件名,用于生成唯一 key - - content_type: MIME 类型 - - bucket: 目标桶;为空时取环境或实例默认值 - - part_size: 每个 part 的最小大小(除最后一个);默认 5MB - 返回:最终写入的对象 key - """ - client = self._get_client() - target_bucket = self._resolve_bucket(bucket) - key = self._generate_object_key(original_name=file_name) - - # 初始化分片上传 - try: - init_resp = client.create_multipart_upload(Bucket=target_bucket, Key=key, ContentType=content_type) - upload_id = init_resp["UploadId"] - except Exception as e: - logger.error(self._error_msg("create_multipart_upload failed", e)) - raise e - - parts = [] - part_number = 1 - buffer = bytearray() - try: - for chunk in chunk_iter: - if not chunk: - continue - buffer.extend(chunk) - while len(buffer) >= part_size: - data = bytes(buffer[:part_size]) - buffer = buffer[part_size:] - resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number, - Body=data) - parts.append({"PartNumber": part_number, "ETag": resp["ETag"]}) - part_number += 1 - - # 上传最后不足 part_size 的余量 - if len(buffer) > 0: - resp = client.upload_part(Bucket=target_bucket, Key=key, UploadId=upload_id, PartNumber=part_number, - Body=bytes(buffer)) - parts.append({"PartNumber": part_number, "ETag": resp["ETag"]}) - - # 完成分片 - client.complete_multipart_upload( - Bucket=target_bucket, - Key=key, - UploadId=upload_id, - MultipartUpload={"Parts": parts}, - ) - return key - except Exception as e: - logger.error(self._error_msg("multipart upload failed", e)) - try: - client.abort_multipart_upload(Bucket=target_bucket, Key=key, UploadId=upload_id) - except Exception as ae: - logger.error(self._error_msg("abort_multipart_upload failed", ae)) - raise e