# baidu_segment.py (3.0 KB)
  1. """Baidu intelligent cutout fallback for transparent PNG assets."""
  2. import base64
  3. import io
  4. import json
  5. import time
  6. import urllib.parse
  7. import urllib.request
  8. import urllib.error
  9. from PIL import Image
  10. import config
# OAuth endpoint that _get_access_token requests access tokens from.
TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
# Segmentation endpoint that remove_background posts the image to.
SEGMENT_URL = "https://aip.baidubce.com/rest/2.0/image-process/v1/segment"
# Module-level token cache shared by all calls: "token" holds the last access
# token ("" when none has been fetched) and "expires_at" the time.time() value
# at which that token stops being valid.
_TOKEN_CACHE = {"token": "", "expires_at": 0.0}
  14. def _post_json(url, payload, timeout=60):
  15. data = json.dumps(payload).encode("utf-8")
  16. req = urllib.request.Request(
  17. url,
  18. data=data,
  19. headers={"Content-Type": "application/json; charset=UTF-8"},
  20. method="POST",
  21. )
  22. try:
  23. with urllib.request.urlopen(req, timeout=timeout) as res:
  24. return json.loads(res.read().decode("utf-8", "ignore"))
  25. except urllib.error.HTTPError as e:
  26. body = e.read().decode("utf-8", "ignore")[:800]
  27. raise RuntimeError(f"Baidu HTTP {e.code}: {body}")
  28. def _get_access_token():
  29. now = time.time()
  30. if _TOKEN_CACHE["token"] and _TOKEN_CACHE["expires_at"] > now + 60:
  31. return _TOKEN_CACHE["token"]
  32. api_key = config.get("BAIDU_IMAGEPROCESS_API_KEY", "")
  33. secret_key = config.get("BAIDU_IMAGEPROCESS_SECRET_KEY", "")
  34. if not api_key or not secret_key:
  35. raise RuntimeError("missing BAIDU_IMAGEPROCESS_API_KEY / BAIDU_IMAGEPROCESS_SECRET_KEY")
  36. query = urllib.parse.urlencode({
  37. "grant_type": "client_credentials",
  38. "client_id": api_key,
  39. "client_secret": secret_key,
  40. })
  41. with urllib.request.urlopen(TOKEN_URL + "?" + query, timeout=30) as res:
  42. data = json.loads(res.read().decode("utf-8", "ignore"))
  43. token = data.get("access_token")
  44. if not token:
  45. raise RuntimeError(f"Baidu token response missing access_token: {json.dumps(data, ensure_ascii=False)[:500]}")
  46. _TOKEN_CACHE["token"] = token
  47. _TOKEN_CACHE["expires_at"] = now + int(data.get("expires_in", 2592000))
  48. return token
  49. def remove_background(img, label="", log=print):
  50. """Return a PNG RGBA image with transparent background using Baidu segment."""
  51. token = _get_access_token()
  52. buf = io.BytesIO()
  53. img.convert("RGBA").save(buf, format="PNG")
  54. payload = {
  55. "image": base64.b64encode(buf.getvalue()).decode("ascii"),
  56. "method": "auto",
  57. "refine_mask": "true",
  58. "return_form": "rgba",
  59. }
  60. if log:
  61. log(f"🧠 [{label}] 百度智能抠图:提交自动主体抠图…")
  62. data = _post_json(SEGMENT_URL + "?access_token=" + urllib.parse.quote(token), payload, timeout=120)
  63. if data.get("error_code"):
  64. raise RuntimeError(f"Baidu segment error {data.get('error_code')}: {data.get('error_msg')}")
  65. raw_b64 = data.get("image")
  66. if not raw_b64:
  67. raise RuntimeError(f"Baidu segment response missing image: {json.dumps(data, ensure_ascii=False)[:500]}")
  68. out = Image.open(io.BytesIO(base64.b64decode(raw_b64))).convert("RGBA")
  69. if log:
  70. log(f"✅ [{label}] 百度智能抠图完成:{out.width}×{out.height}")
  71. return out