Upload handler.py
Browse files- handler.py +24 -1
handler.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Dict, List, Any
|
|
6 |
import base64
|
7 |
from io import BytesIO
|
8 |
import os
|
|
|
9 |
|
10 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
11 |
|
@@ -25,6 +26,27 @@ class EndpointHandler():
|
|
25 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
26 |
])
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
29 |
image = data.pop("inputs", data)
|
30 |
# image = self.decode_base64_image(image)
|
@@ -35,7 +57,8 @@ class EndpointHandler():
|
|
35 |
pred_pil = transforms.ToPILImage()(pred)
|
36 |
mask = pred_pil.resize(image.size)
|
37 |
image.putalpha(mask)
|
38 |
-
|
|
|
39 |
|
40 |
def decode_base64_image(self, image_string):
|
41 |
base64_image = base64.b64decode(image_string)
|
|
|
6 |
import base64
|
7 |
from io import BytesIO
|
8 |
import os
|
9 |
+
import boto3
|
10 |
|
11 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
|
|
|
26 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
27 |
])
|
28 |
|
29 |
+
def update_to_s3(self, image):
|
30 |
+
BUCKET_NAME = 'popwear-assets'
|
31 |
+
BUCKET_PREFIX_PATH = 'removebg'
|
32 |
+
ACCOUNT_ID = '18cc2282d0ee72171c1ea322ed22983c'
|
33 |
+
ACCESS_KEY_ID = '007f1852a377a2df43a21d5c8d54542e'
|
34 |
+
SECRET_ACCESS_KEY = 'db2658e2429950bb05e15afb6c53c8b7fd23ab9e1bf79cd42604c89f276068e4'
|
35 |
+
ENDPOINT_URL = f'https://{ACCOUNT_ID}.r2.cloudflarestorage.com'
|
36 |
+
bucket_postfix_path = f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jpg"
|
37 |
+
|
38 |
+
image_url = f"https://assets.popwear.ai/{bucket_prefix_path}/{bucket_postfix_path}"
|
39 |
+
s3 = boto3.client(
|
40 |
+
's3',
|
41 |
+
endpoint_url=ENDPOINT_URL,
|
42 |
+
aws_access_key_id=ACCESS_KEY_ID,
|
43 |
+
aws_secret_access_key=SECRET_ACCESS_KEY,
|
44 |
+
region_name='auto'
|
45 |
+
)
|
46 |
+
image.seek(0)
|
47 |
+
s3.upload_fileobj(image, BUCKET_NAME, f"{BUCKET_PREFIX_PATH}/{bucket_postfix_path}")
|
48 |
+
return image_url
|
49 |
+
|
50 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
51 |
image = data.pop("inputs", data)
|
52 |
# image = self.decode_base64_image(image)
|
|
|
57 |
pred_pil = transforms.ToPILImage()(pred)
|
58 |
mask = pred_pil.resize(image.size)
|
59 |
image.putalpha(mask)
|
60 |
+
image_url = self.update_to_s3(image)
|
61 |
+
return image_url
|
62 |
|
63 |
def decode_base64_image(self, image_string):
|
64 |
base64_image = base64.b64decode(image_string)
|