Jelajahi Sumber

Add Baidu cutout fallback for alpha assets

bang 2 minggu lalu
induk
melakukan
6159cee6b8
3 file diubah dengan 96 penambahan dan 2 penghapusan
  1. 82 0
      baidu_segment.py
  2. 4 1
      local_config.example.json
  3. 10 1
      pipeline.py

+ 82 - 0
baidu_segment.py

@@ -0,0 +1,82 @@
+"""Baidu intelligent cutout fallback for transparent PNG assets."""
+
+import base64
+import io
+import json
+import time
+import urllib.parse
+import urllib.request
+import urllib.error
+
+from PIL import Image
+
+import config
+
+
+# Endpoint queried by _get_access_token() to obtain an access token.
+TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
+# Endpoint that remove_background() posts the base64-encoded image to.
+SEGMENT_URL = "https://aip.baidubce.com/rest/2.0/image-process/v1/segment"
+# Module-level token cache; "expires_at" is compared against time.time().
+_TOKEN_CACHE = {"token": "", "expires_at": 0.0}
+
+
+def _post_json(url, payload, timeout=60):
+    """POST *payload* as a JSON body to *url* and return the decoded JSON reply.
+
+    Args:
+        url: Full request URL, including any query string.
+        payload: JSON-serializable object sent as the request body.
+        timeout: Socket timeout in seconds handed to ``urlopen``.
+
+    Returns:
+        The response body parsed with ``json.loads``.
+
+    Raises:
+        RuntimeError: On an HTTP error status (the message carries the status
+            code and up to 800 characters of the error body), or when a
+            successful response body is not valid JSON.
+    """
+    data = json.dumps(payload).encode("utf-8")
+    req = urllib.request.Request(
+        url,
+        data=data,
+        headers={"Content-Type": "application/json; charset=UTF-8"},
+        method="POST",
+    )
+    try:
+        with urllib.request.urlopen(req, timeout=timeout) as res:
+            raw = res.read().decode("utf-8", "ignore")
+    except urllib.error.HTTPError as e:
+        body = e.read().decode("utf-8", "ignore")[:800]
+        # Chain explicitly so the traceback marks the HTTP error as the cause.
+        raise RuntimeError(f"Baidu HTTP {e.code}: {body}") from e
+    try:
+        return json.loads(raw)
+    except ValueError as e:
+        # json.JSONDecodeError subclasses ValueError; include the start of the
+        # body so a non-JSON reply is diagnosable from the message alone.
+        raise RuntimeError(f"Baidu returned a non-JSON response: {raw[:800]}") from e
+
+
+def _get_access_token():
+    """Return a Baidu access token, reusing the cached one while it is valid.
+
+    The credentials are read from ``config`` under
+    ``BAIDU_IMAGEPROCESS_API_KEY`` and ``BAIDU_IMAGEPROCESS_SECRET_KEY``.
+    A freshly fetched token is stored in ``_TOKEN_CACHE`` together with its
+    expiry timestamp.
+
+    Returns:
+        The access token string.
+
+    Raises:
+        RuntimeError: If either credential is missing, if the token endpoint
+            answers with an HTTP error status, or if its reply has no
+            ``access_token`` field.
+    """
+    now = time.time()
+    # Reuse the cached token only if it stays valid for more than another 60 s.
+    if _TOKEN_CACHE["token"] and _TOKEN_CACHE["expires_at"] > now + 60:
+        return _TOKEN_CACHE["token"]
+    api_key = config.get("BAIDU_IMAGEPROCESS_API_KEY", "")
+    secret_key = config.get("BAIDU_IMAGEPROCESS_SECRET_KEY", "")
+    if not api_key or not secret_key:
+        raise RuntimeError("missing BAIDU_IMAGEPROCESS_API_KEY / BAIDU_IMAGEPROCESS_SECRET_KEY")
+    query = urllib.parse.urlencode({
+        "grant_type": "client_credentials",
+        "client_id": api_key,
+        "client_secret": secret_key,
+    })
+    try:
+        with urllib.request.urlopen(TOKEN_URL + "?" + query, timeout=30) as res:
+            raw = res.read().decode("utf-8", "ignore")
+    except urllib.error.HTTPError as e:
+        # Same reporting as _post_json: surface the status and error body
+        # instead of letting a bare HTTPError escape without the response text.
+        body = e.read().decode("utf-8", "ignore")[:800]
+        raise RuntimeError(f"Baidu HTTP {e.code}: {body}") from e
+    data = json.loads(raw)
+    token = data.get("access_token")
+    if not token:
+        raise RuntimeError(f"Baidu token response missing access_token: {json.dumps(data, ensure_ascii=False)[:500]}")
+    _TOKEN_CACHE["token"] = token
+    # Fall back to 2592000 s (30 days) when the reply omits "expires_in".
+    _TOKEN_CACHE["expires_at"] = now + int(data.get("expires_in", 2592000))
+    return token
+
+
+def remove_background(img, label="", log=print):
+    """Cut the subject out of *img* with Baidu segment and return it as RGBA.
+
+    Args:
+        img: Source image; it is converted to RGBA and serialized as PNG
+            before upload (presumably a PIL image - confirm against callers).
+        label: Tag inserted into the progress messages.
+        log: Callable receiving progress messages; pass a falsy value to
+            silence them.
+
+    Returns:
+        The image decoded from the service reply, converted to RGBA.
+
+    Raises:
+        RuntimeError: If the reply carries an ``error_code`` or lacks the
+            ``image`` field. Errors from ``_get_access_token`` and
+            ``_post_json`` propagate unchanged.
+    """
+    access_token = _get_access_token()
+    png_buffer = io.BytesIO()
+    img.convert("RGBA").save(png_buffer, format="PNG")
+    encoded_png = base64.b64encode(png_buffer.getvalue()).decode("ascii")
+    request_body = {
+        "image": encoded_png,
+        "method": "auto",
+        "refine_mask": "true",
+        "return_form": "rgba",
+    }
+    if log:
+        log(f"🧠 [{label}] 百度智能抠图:提交自动主体抠图…")
+    # The token travels in the query string; the image goes in the JSON body.
+    endpoint = SEGMENT_URL + "?access_token=" + urllib.parse.quote(access_token)
+    reply = _post_json(endpoint, request_body, timeout=120)
+    error_code = reply.get("error_code")
+    if error_code:
+        raise RuntimeError(f"Baidu segment error {error_code}: {reply.get('error_msg')}")
+    result_b64 = reply.get("image")
+    if not result_b64:
+        raise RuntimeError(f"Baidu segment response missing image: {json.dumps(reply, ensure_ascii=False)[:500]}")
+    decoded_bytes = base64.b64decode(result_b64)
+    cutout = Image.open(io.BytesIO(decoded_bytes)).convert("RGBA")
+    if log:
+        log(f"✅ [{label}] 百度智能抠图完成:{cutout.width}×{cutout.height}")
+    return cutout

+ 4 - 1
local_config.example.json

@@ -2,5 +2,8 @@
   "ANIM_STUDIO_BASE_URL": "https://x.long.bid/v1",
   "ANIM_STUDIO_API_KEY": "replace-with-your-key",
   "ANIM_STUDIO_IMAGE_MODEL": "gpt-image-2",
-  "ANIM_STUDIO_TEXT_MODEL": "gpt-5.4-mini"
+  "ANIM_STUDIO_TEXT_MODEL": "gpt-5.4-mini",
+  "BAIDU_IMAGEPROCESS_APP_ID": "replace-with-your-app-id",
+  "BAIDU_IMAGEPROCESS_API_KEY": "replace-with-your-api-key",
+  "BAIDU_IMAGEPROCESS_SECRET_KEY": "replace-with-your-secret-key"
 }

+ 10 - 1
pipeline.py

@@ -9,6 +9,7 @@ import providers
 import spine_builder
 import particle_builder
 import tween_builder
+import baidu_segment
 
 HERE = os.path.dirname(os.path.abspath(__file__))
 
@@ -85,7 +86,15 @@ def run(manifest, out_root, creds=None, log=print):
             log_alpha(label, img, require_alpha)
             if not require_alpha or has_alpha(img):
                 return img
-        raise RuntimeError(f"模型连续 {len(retry_suffixes)} 次没有返回真实 Alpha 透明通道;请换支持透明输出的图像模型或稍后重试")
+        log(f"🧠 [{label}] 模型连续 {len(retry_suffixes)} 次没有真实 Alpha,改用百度智能抠图兜底…")
+        try:
+            fixed = baidu_segment.remove_background(last, label=label, log=log)
+        except Exception as e:
+            raise RuntimeError(f"模型连续 {len(retry_suffixes)} 次没有返回真实 Alpha,百度智能抠图也失败:{e}")
+        log_alpha(label, fixed, True)
+        if has_alpha(fixed):
+            return fixed
+        raise RuntimeError("百度智能抠图返回结果仍没有真实 Alpha 透明通道")
 
     # ---- A. 角色(Spine)----
     for i, c in enumerate(manifest.get("characters", [])):