File size: 2,696 Bytes
f9cfd2a c55c1eb 5e05b9e f9cfd2a c55c1eb 1104805 c55c1eb e928816 9a6d4b3 e928816 c55c1eb f9cfd2a 314fa7f f9cfd2a c55c1eb f9cfd2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import base64
import datetime
import os
import uuid
from io import BytesIO
from typing import Any, Dict, List

import boto3
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Inference device: the first CUDA device when torch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler():
    """Background-removal inference handler that uploads its result to Cloudflare R2.

    Each call runs the segmentation model on an input image and attaches the
    predicted mask as the image's alpha channel. It then uploads the result
    as WEBP and returns the public URL of the uploaded object.

    R2 credentials are read from the environment (``R2_ACCOUNT_ID``,
    ``R2_ACCESS_KEY_ID``, ``R2_SECRET_ACCESS_KEY``) rather than stored in
    source. NOTE(review): the keys that used to be hard-coded in this class
    were committed to source control and must be revoked and rotated.
    """

    BUCKET_NAME = 'popwear-assets'
    BUCKET_PREFIX_PATH = 'removebg'
    PUBLIC_BASE_URL = 'https://assets.popwear.ai'

    def __init__(self, path=""):
        """Load the segmentation model and build the input transform.

        Args:
            path: Unused. The model is always loaded from the
                ``whlzy/remove_bg_api`` Hub repository. The parameter is
                kept so callers that pass a path keep working.
        """
        self.model = AutoModelForImageSegmentation.from_pretrained(
            'whlzy/remove_bg_api',
            trust_remote_code=True,
            token=os.environ.get("HUGGINGFACE_TOKEN"),
        )
        self.model.to(device)
        self.model.eval()
        # The R2 client is created on first upload (see _get_s3_client).
        self._s3_client = None
        image_size = (1024, 1024)
        self.transform_image = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _require_env(name):
        """Return the non-empty value of environment variable *name*.

        Raises:
            RuntimeError: if the variable is unset or empty.
        """
        value = os.environ.get(name)
        if not value:
            raise RuntimeError(
                f"Environment variable {name} must be set to upload results to R2"
            )
        return value

    @staticmethod
    def _build_object_key(prefix):
        """Return a unique ``<prefix>/<timestamp>_<uuid>.webp`` object key."""
        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        # The uuid suffix prevents requests landing in the same second from
        # overwriting each other. The extension matches the WEBP bytes that
        # update_to_s3 writes (it used to say .jpg).
        return f"{prefix}/{timestamp}_{uuid.uuid4().hex}.webp"

    @staticmethod
    def _strip_data_uri(image_string):
        """Return the payload of a ``data:...,<payload>`` string; other input is returned unchanged."""
        # b64decode would otherwise fold prefix characters such as '/' (a
        # valid base64 symbol) into the payload and corrupt the image.
        if isinstance(image_string, str) and image_string.startswith("data:"):
            return image_string.partition(",")[2]
        return image_string

    def _get_s3_client(self):
        """Return the cached R2 (S3-compatible) client, creating it on first use."""
        # getattr keeps this working on instances that bypassed __init__.
        client = getattr(self, "_s3_client", None)
        if client is None:
            account_id = self._require_env("R2_ACCOUNT_ID")
            client = boto3.client(
                's3',
                endpoint_url=f'https://{account_id}.r2.cloudflarestorage.com',
                aws_access_key_id=self._require_env("R2_ACCESS_KEY_ID"),
                aws_secret_access_key=self._require_env("R2_SECRET_ACCESS_KEY"),
                region_name='auto',
            )
            self._s3_client = client
        return client

    def update_to_s3(self, image):
        """Encode *image* as WEBP, upload it to the R2 bucket and return its public URL.

        Args:
            image: PIL image to upload. The alpha channel is kept, since WEBP
                supports transparency.

        Returns:
            ``https://assets.popwear.ai/removebg/<timestamp>_<uuid>.webp``.

        Raises:
            RuntimeError: if the R2 credential environment variables are missing.
        """
        object_key = self._build_object_key(self.BUCKET_PREFIX_PATH)
        output_buffer = BytesIO()
        image.save(output_buffer, format='WEBP', quality=85, method=4)
        output_buffer.seek(0)
        self._get_s3_client().upload_fileobj(
            output_buffer,
            self.BUCKET_NAME,
            object_key,
            # Without an explicit type the object is not served as an image.
            ExtraArgs={'ContentType': 'image/webp'},
        )
        return f"{self.PUBLIC_BASE_URL}/{object_key}"

    def __call__(self, data: Any) -> str:
        """Remove the background from an image and return the uploaded result's URL.

        Args:
            data: Either a dict whose ``"inputs"`` entry holds the image, or
                the image itself. The image may be a PIL image, or base64
                ``str``/``bytes`` (optionally a ``data:`` URI).

        Returns:
            Public URL of the uploaded RGBA WEBP image.

        Raises:
            ValueError: if a dict payload has no ``"inputs"`` entry.
        """
        image = data.pop("inputs", data) if isinstance(data, dict) else data
        if isinstance(image, dict):
            raise ValueError('Request payload must contain an "inputs" image')
        if isinstance(image, (str, bytes)):
            image = self.decode_base64_image(image)
        # Normalize is configured for exactly 3 channels, so L/P/RGBA inputs
        # would fail without this conversion.
        image = image.convert("RGB")
        # Use the module-level device so CPU-only hosts work (was hard-coded 'cuda').
        input_images = self.transform_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            # Assumes the model's last output holds the mask logits; TODO confirm
            # against the remote model code.
            preds = self.model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        # The mask comes out at model resolution. Scale it back to the input
        # size before using it as alpha.
        mask = transforms.ToPILImage()(pred).resize(image.size)
        image.putalpha(mask)
        return self.update_to_s3(image)

    def decode_base64_image(self, image_string):
        """Decode base64 ``str``/``bytes`` (optionally a ``data:`` URI) into a PIL image."""
        raw_bytes = base64.b64decode(self._strip_data_uri(image_string))
        return Image.open(BytesIO(raw_bytes))
|