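# Gradio demo app for Sa2VA: takes an image or an mp4 video plus a text
# instruction, runs the model's predict_forward, and shows the answer text
# together with the predicted segmentation masks overlaid on the input.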
import gradio as gr
import sys
from projects.llava_sam2.gradio.app_utils import \
    process_markdown, show_mask_pred, description, preprocess_video, \
    show_mask_pred_video, image2video_and_save
import torch
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, CLIPImageProcessor,
                          CLIPVisionModel, GenerationConfig)
import argparse
import os

TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')


def parse_args(args):
    parser = argparse.ArgumentParser(description="Sa2VA Demo")
    parser.add_argument('hf_path', help='Sa2VA hf path.')
    return parser.parse_args(args)
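
# Example launch (the script filename is assumed for illustration; the single
# positional argument is the Sa2VA Hugging Face path defined in parse_args):
#   python app.py <Sa2VA hf path>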


def inference(image, video, follow_up, input_str):
    input_image = image
    # A fresh (non follow-up) query must provide exactly one of image / video.
    if image is not None and (video is not None and os.path.exists(video)):
        return image, video, "Error: Please input only an image or a video!"
    if image is None and (video is None or not os.path.exists(video)) and not follow_up:
        return image, video, "Error: Please input an image or a video!"

    if not follow_up:
        # New conversation: reset the history and cache the fresh inputs.
        print('Log: History responses have been removed!')
        global_infos.n_turn = 0
        global_infos.inputs = ''

        text = input_str
        image = input_image
        global_infos.image_for_show = image
        global_infos.image = image
        global_infos.video = video
        if image is not None:
            global_infos.input_type = "image"
        else:
            global_infos.input_type = "video"
    else:
        # Follow-up question: reuse the cached image/video from the last turn.
        text = input_str
        image = global_infos.image
        video = global_infos.video

    input_type = global_infos.input_type
    if input_type == "video":
        video = preprocess_video(video, global_infos.inputs + input_str)

    past_text = global_infos.inputs
    if past_text == "" and "<image>" not in text:
        text = "<image>" + text

    if input_type == "image":
        input_dict = {
            'image': image,
            'text': text,
            'past_text': past_text,
            'mask_prompts': None,
            'tokenizer': tokenizer,
        }
    else:
        input_dict = {
            'video': video,
            'text': text,
            'past_text': past_text,
            'mask_prompts': None,
            'tokenizer': tokenizer,
        }

    return_dict = sa2va_model.predict_forward(**input_dict)
    global_infos.inputs = return_dict["past_text"]
    print(return_dict['past_text'])

    # Overlay the predicted masks on the input, if the model returned any.
    if 'prediction_masks' in return_dict.keys() and return_dict['prediction_masks'] and len(
            return_dict['prediction_masks']) != 0:
        if input_type == "image":
            image_mask_show, selected_colors = show_mask_pred(
                global_infos.image_for_show, return_dict['prediction_masks'])
            video_mask_show = global_infos.video
        else:
            image_mask_show = None
            video_mask_show, selected_colors = show_mask_pred_video(
                video, return_dict['prediction_masks'])
            video_mask_show = image2video_and_save(
                video_mask_show, save_path="./ret_video.mp4")
    else:
        image_mask_show = global_infos.image_for_show
        video_mask_show = global_infos.video
        selected_colors = []

    predict = return_dict['prediction'].strip()
    global_infos.n_turn += 1

    predict = process_markdown(predict, selected_colors)
    return image_mask_show, video_mask_show, predict


def init_models(args):
    model_path = args.hf_path
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        trust_remote_code=True,
    ).eval().cuda()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    return model, tokenizer
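
# Note: init_models loads the checkpoint in bfloat16 with flash attention and
# moves it to CUDA, so a GPU is required to run this demo.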


class global_infos:
    # Module-level state shared across Gradio calls so that a "Follow up
    # Question" can reuse the previous image/video and conversation history.
    inputs = ''
    n_turn = 0
    image_width = 0
    image_height = 0

    image_for_show = None
    image = None
    video = None

    input_type = "image"  # "image" or "video"


if __name__ == "__main__":
    # Parse args and set up the model and tokenizer.
    args = parse_args(sys.argv[1:])
    sa2va_model, tokenizer = init_models(args)

    demo = gr.Interface(
        inference,
        inputs=[
            gr.Image(type="pil", label="Upload Image", height=360),
            gr.Video(sources=["upload", "webcam"], label="Upload mp4 video", height=360),
            gr.Checkbox(label="Follow up Question"),
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
        ],
        outputs=[
            gr.Image(type="pil", label="Output Image"),
            gr.Video(label="Output Video", show_download_button=True, format='mp4'),
            gr.Markdown(),
        ],
        theme=gr.themes.Soft(), allow_flagging="auto", description=description,
        title='Sa2VA',
    )

    demo.queue()
    demo.launch(share=True)