File size: 2,696 Bytes
f9cfd2a
 
 
 
 
 
 
 
c55c1eb
5e05b9e
f9cfd2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c55c1eb
 
 
 
 
 
 
 
1104805
c55c1eb
 
 
 
 
 
 
e928816
9a6d4b3
e928816
 
c55c1eb
 
f9cfd2a
 
314fa7f
f9cfd2a
 
 
 
 
 
 
c55c1eb
 
f9cfd2a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import base64
import datetime
import os
import uuid
from io import BytesIO
from typing import  Dict, List, Any

import boto3
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Run on CUDA when a GPU is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class EndpointHandler():
    """Background-removal handler for a Hugging Face Inference Endpoint.

    Each call segments the foreground of the input image, uses the predicted
    mask as the image's alpha channel, uploads the result as WEBP to a
    Cloudflare R2 bucket and returns the uploaded object's URL.
    """

    # Upload destination. Credentials are deliberately NOT stored in code;
    # they are read from the environment variables listed in _R2_ENV_VARS.
    BUCKET_NAME = 'popwear-assets'
    BUCKET_PREFIX_PATH = 'removebg'
    PUBLIC_BASE_URL = 'https://assets.popwear.ai'
    _R2_ENV_VARS = ('R2_ACCOUNT_ID', 'R2_ACCESS_KEY_ID', 'R2_SECRET_ACCESS_KEY')

    def __init__(self, path=""):
        """Load the segmentation model onto `device` and build the input transform.

        Args:
            path: Accepted for the handler constructor signature but not used;
                the weights are always loaded from the 'whlzy/remove_bg_api'
                Hub repository (authenticated with HUGGINGFACE_TOKEN if set).
        """
        # trust_remote_code executes the model code shipped in that repository.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            'whlzy/remove_bg_api',
            trust_remote_code=True,
            token=os.environ.get("HUGGINGFACE_TOKEN")
        )
        self.model.to(device)
        self.model.eval()
        image_size = (1024, 1024)
        # Resize to 1024x1024, convert to a tensor and normalise with the
        # standard ImageNet mean/std.
        self.transform_image = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def _s3_client(self):
        """Return a boto3 S3 client for the Cloudflare R2 endpoint, created once.

        Credentials come from the R2_ACCOUNT_ID, R2_ACCESS_KEY_ID and
        R2_SECRET_ACCESS_KEY environment variables.

        Raises:
            RuntimeError: if any of those variables is unset or empty.
        """
        # getattr so instances built without __init__ still work.
        client = getattr(self, '_s3', None)
        if client is None:
            missing = [name for name in self._R2_ENV_VARS if not os.environ.get(name)]
            if missing:
                raise RuntimeError(
                    'Missing R2 credentials in environment: ' + ', '.join(missing)
                )
            account_id, access_key_id, secret_access_key = (
                os.environ[name] for name in self._R2_ENV_VARS
            )
            client = boto3.client(
                's3',
                endpoint_url=f'https://{account_id}.r2.cloudflarestorage.com',
                aws_access_key_id=access_key_id,
                aws_secret_access_key=secret_access_key,
                region_name='auto',
            )
            self._s3 = client
        return client

    @staticmethod
    def _make_object_key(prefix):
        """Return a unique key of the form '<prefix>/<YYYYmmdd_HHMMSS>_<hex>.webp'.

        The random uuid4 suffix stops two uploads made within the same second
        from overwriting each other; the extension matches the WEBP payload.
        """
        timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        return f"{prefix}/{timestamp}_{uuid.uuid4().hex}.webp"

    def update_to_s3(self, image):
        """Encode `image` as WEBP, upload it to the R2 bucket and return its URL.

        Args:
            image: PIL image to store (in __call__ this is an RGBA image).

        Returns:
            The object's URL under PUBLIC_BASE_URL.

        Raises:
            RuntimeError: if the R2 credentials are not configured.
        """
        key = self._make_object_key(self.BUCKET_PREFIX_PATH)
        output_buffer = BytesIO()
        # WEBP rather than JPEG: JPEG cannot store the alpha channel.
        image.save(output_buffer, format='WEBP', quality=85, method=4)
        output_buffer.seek(0)
        self._s3_client().upload_fileobj(
            output_buffer,
            self.BUCKET_NAME,
            key,
            # Without this the object is stored with a generic binary type.
            ExtraArgs={'ContentType': 'image/webp'},
        )
        return f"{self.PUBLIC_BASE_URL}/{key}"

    def __call__(self, data: Any) -> str:
        """Remove the background from the request image and return the result's URL.

        Args:
            data: Request payload. For a dict, the "inputs" entry is used
                (falling back to the dict itself, as before); any other value
                is treated as the image. The image may be a PIL image or a
                base64 string (optionally a ``data:...;base64,`` URL).

        Returns:
            URL of the uploaded RGBA WEBP image (see update_to_s3).
        """
        image = data.pop("inputs", data) if isinstance(data, dict) else data
        if isinstance(image, str):
            image = self.decode_base64_image(image)
        # The transform normalises exactly 3 channels; RGBA/L/P inputs would
        # fail there, so work on an RGB copy (also avoids mutating the input).
        image = image.convert("RGB")
        # Must match the model's device; hard-coding 'cuda' broke CPU hosts.
        input_images = self.transform_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            # The last element of the model output is taken as the mask logits.
            preds = self.model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        mask = transforms.ToPILImage()(pred).resize(image.size)
        image.putalpha(mask)
        return self.update_to_s3(image)

    def decode_base64_image(self, image_string):
        """Decode a base64-encoded image into a PIL image.

        Args:
            image_string: Base64 data as str or bytes. A str may also be a
                ``data:<mime>;base64,<payload>`` URL; the header is stripped.

        Returns:
            The opened PIL image.
        """
        if isinstance(image_string, str) and image_string.startswith("data:") and "," in image_string:
            image_string = image_string.split(",", 1)[1]
        return Image.open(BytesIO(base64.b64decode(image_string)))