"""Inference endpoint handler that removes an image background and uploads the result."""

import base64
import datetime
import os
import uuid
from io import BytesIO
from typing import Any, Dict, List

import boto3
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Single source of truth for the compute device; used for both the model and
# the input tensors so the handler also works on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EndpointHandler():
    """Callable handler: segment the foreground, add it as alpha, upload to R2."""

    def __init__(self, path=""):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            path: Accepted for interface compatibility; not used, the model is
                always loaded from the 'whlzy/remove_bg_api' repository.
        """
        self.model = AutoModelForImageSegmentation.from_pretrained(
            'whlzy/remove_bg_api',
            trust_remote_code=True,
            token=os.environ.get("HUGGINGFACE_TOKEN"),
        )
        self.model.to(device)
        self.model.eval()

        image_size = (1024, 1024)
        self.transform_image = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def update_to_s3(self, image):
        """Encode ``image`` as WEBP, upload it to the R2 bucket, return its URL.

        Args:
            image: PIL image to upload.

        Returns:
            The public ``https://assets.popwear.ai/...`` URL of the uploaded object.
        """
        BUCKET_NAME = 'popwear-assets'
        BUCKET_PREFIX_PATH = 'removebg'
        # SECURITY: credentials are read from the environment first. The literal
        # fallbacks are kept only so existing deployments keep working; they are
        # exposed in source control and should be rotated and removed.
        ACCOUNT_ID = os.environ.get(
            "R2_ACCOUNT_ID", '18cc2282d0ee72171c1ea322ed22983c')
        ACCESS_KEY_ID = os.environ.get(
            "R2_ACCESS_KEY_ID", '007f1852a377a2df43a21d5c8d54542e')
        SECRET_ACCESS_KEY = os.environ.get(
            "R2_SECRET_ACCESS_KEY",
            'db2658e2429950bb05e15afb6c53c8b7fd23ab9e1bf79cd42604c89f276068e4',
        )
        ENDPOINT_URL = f'https://{ACCOUNT_ID}.r2.cloudflarestorage.com'

        # A timestamp alone has one-second resolution, so concurrent requests
        # would overwrite each other; a random suffix makes the key unique.
        # The extension matches the WEBP payload that is actually written.
        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        bucket_postfix_path = f"{timestamp}_{uuid.uuid4().hex[:8]}.webp"
        object_key = f"{BUCKET_PREFIX_PATH}/{bucket_postfix_path}"
        image_url = f"https://assets.popwear.ai/{object_key}"

        s3 = boto3.client(
            's3',
            endpoint_url=ENDPOINT_URL,
            aws_access_key_id=ACCESS_KEY_ID,
            aws_secret_access_key=SECRET_ACCESS_KEY,
            region_name='auto',
        )

        output_buffer = BytesIO()
        image.save(output_buffer, format='WEBP', quality=85, method=4)
        output_buffer.seek(0)
        # Without an explicit ContentType the object is stored with a generic
        # binary type and may be downloaded instead of rendered.
        s3.upload_fileobj(
            output_buffer,
            BUCKET_NAME,
            object_key,
            ExtraArgs={'ContentType': 'image/webp'},
        )
        return image_url

    def __call__(self, data: Any) -> str:
        """Remove the background from the input image and return the result URL.

        Args:
            data: Either a dict whose ``"inputs"`` entry holds the image, or the
                image itself. The image may be a PIL image or a base64-encoded
                string.

        Returns:
            URL of the uploaded image with the predicted mask as alpha channel.
        """
        image = data.pop("inputs", data) if isinstance(data, dict) else data
        if isinstance(image, str):
            image = self.decode_base64_image(image)

        # Normalize() is configured for three channels; convert() also returns
        # a copy, so the caller's image is not modified by putalpha() below.
        image = image.convert("RGB")

        input_images = self.transform_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            preds = self.model(input_images)[-1].sigmoid().cpu()

        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        # The mask is predicted at the model's input size; scale it back to the
        # original image dimensions before applying it.
        mask = pred_pil.resize(image.size)
        image.putalpha(mask)

        return self.update_to_s3(image)

    def decode_base64_image(self, image_string):
        """Decode a base64-encoded image payload into a PIL image."""
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image