import gradio as gr
import torch
import numpy as np
import torch.nn.functional as F
import PIL
from threading import Thread
from transformers import AutoModel, AutoProcessor
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
from briarmbg import BriaRMBG
from PIL import Image
from typing import Tuple
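
# Two models are used below: briaai/RMBG-1.4 (via the local BriaRMBG class) for
# background removal, and unum-cloud/uform-gen2-qwen-500m for generating a short
# product description from the uploaded image.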
net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-qwen-500m", trust_remote_code=True)
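
# Custom stopping criterion for streamed generation. Token id 151645 is assumed to be
# the chat end-of-turn token (<|im_end|>) of the Qwen-based tokenizer behind
# uform-gen2-qwen-500m; generation stops as soon as it is produced.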
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [151645]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
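
# Builds the chat history shown in the UI from the generated caption. The commented-out
# block below is an unfinished sketch of a second stage that would call google/gemma-7b
# through InferenceClient; for now the caption from getImageDescription is used directly.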
def getProductDetails(history, image):
    product_description = getImageDescription(image)
    # clients = InferenceClient("google/gemma-7b")
    # rand_val = random.randint(1, 1111111111111111)
    if not history:
        history = []
    # generate_kwargs = dict(
    #     temperature=temp,
    #     max_new_tokens=tokens,
    #     top_p=top_p,
    #     repetition_penalty=rep_p,
    #     do_sample=True,
    #     seed=seed,
    # )
    # system_prompt = "you're a helpful e-commerce marketing assistant"
    # prompt = "Write me a poem"
    # formatted_prompt = self.format_prompt(f"{system_prompt}, {prompt}", history)
    # stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=stream_output, details=True, return_full_text=False)
    # output = ""
    # for response in stream:
    #     output += response.token.text
    #     yield [(prompt, output)]
    gr.Info('Description: ' + product_description)
    # history.append((prompt, output))
    history.append((None, product_description))
    return history
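
# Generates a product title for the image with uform-gen2-qwen-500m. The prompt is built
# with the model's chat template, using an "<image>" placeholder to mark where the image
# embedding is inserted, and the output is streamed token by token from a worker thread.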
def getImageDescription(image):
    message = "Generate a product title for the image"
    gr.Info('Starting...' + message)
    stop = StopOnTokens()
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    # for user_msg, assistant_msg in history:
    #     messages.append({"role": "user", "content": user_msg})
    #     messages.append({"role": "assistant", "content": assistant_msg})
    if len(messages) == 1:
        message = f" <image>{message}"
    messages.append({"role": "user", "content": message})
    model_inputs = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    image = (
        processor.feature_extractor(image)
        .unsqueeze(0)
    )
    attention_mask = torch.ones(
        1, model_inputs.shape[1] + processor.num_image_latents - 1
    )
    model_inputs = {
        "input_ids": model_inputs,
        "images": image,
        "attention_mask": attention_mask
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    streamer = TextIteratorStreamer(processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # history.append([message, ""])
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        # history[-1][1] = partial_response
        # yield history
    gr.Info('Got:' + partial_response)
    return partial_response
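
# RMBG-1.4 expects a fixed 1024x1024 RGB input; the predicted alpha matte is resized
# back to the original resolution before compositing onto a transparent canvas.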
def resize_image(image):
    image = image.convert('RGB')
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image
def process(image):
    # prepare input
    orig_image = image
    w, h = orig_im_size = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()
    # inference
    result = net(im_tensor)
    # post process
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    # image to pil
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))
    # paste the mask on the original image
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    # new_orig_image = orig_image.convert('RGBA')
    return new_im
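
# Standalone usage sketch (assumes a local "example.jpg", which is not part of this Space):
#     cutout = process(Image.open("example.jpg"))
#     cutout.save("cutout.png")  # RGBA image with the background removed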
title = """<h1 style="text-align: center;">Product description generator</h1>"""

css = """
div#col-container {
    margin: 0 auto;
    max-width: 840px;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column(elem_id="col-container"):
            image = gr.Image(type="pil")
            chat = gr.Chatbot(show_label=False)
            submit = gr.Button(value="Upload", variant="primary")
        with gr.Column():
            output = gr.Image(type="pil", interactive=False)

    response_handler = (
        getProductDetails,
        [chat, image],
        [chat]
    )
    background_remover_handler = (
        process,
        [image],
        [output]
    )
    # postresponse_handler = (
    #     lambda: (gr.Button(visible=False), gr.Button(visible=True)),
    #     None,
    #     [submit]
    # )

    event = submit.click(*response_handler)
    event2 = submit.click(*background_remover_handler)
    # event.then(*postresponse_handler)

demo.launch()