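# OneLLM Gradio demo (Hugging Face Space entry point).
# One model_worker process is launched per GPU; a single gradio_worker shows
# the WebUI and relays chat requests to the workers over multiprocessing
# queues, streaming the responses back.
# The pointnet2 CUDA extension is compiled below at startup so the
# point-cloud branch of the model can be imported.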
import os
current_file_dir = os.path.dirname(os.path.abspath(__file__))
import subprocess
msg = subprocess.check_output(
    f"cd {current_file_dir}/model/lib/pointnet2 && python setup.py install",
    shell=True)
print(msg)
import sys
import argparse
import multiprocessing as mp
import numpy as np
from typing import List, Optional

import torch
import torch.distributed as dist
from fairscale.nn.model_parallel import initialize as fs_init
import gradio as gr
from util.misc import setup_for_distributed
from util.misc import default_tensor_type
from model.meta import MetaModel
from data.conversation_lib import conv_templates, SeparatorStyle
from PIL import Image
import torchvision.transforms as transforms
from data.fintune_dataset import make_audio_features
from data import video_utils
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
import plotly.graph_objects as go
from data.fintune_dataset import pc_norm
from functools import partial
import glob
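# CLIP-style image preprocessing: random-resized crop to 224x224 followed by
# normalization with the CLIP mean/std.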
T_random_resized_crop = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333),
                                 interpolation=3, antialias=None),  # 3 is bicubic
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711])])
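# Per-modality loaders: each one turns an uploaded file into the tensor
# layout the model expects for that modality.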
def load_audio(audio_path):
    fbank = make_audio_features(audio_path, mel_bins=128)
    fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
    return fbank


def load_video(video_path):
    video_feats = video_utils.load_and_transform_video_data(
        video_path, video_path, clip_duration=1, clips_per_video=5)
    return video_feats[:, :, 0]


def load_point(point_path):
    point_feat = np.load(point_path)
    point_feat = torch.tensor(point_feat)
    point_feat = pc_norm(point_feat)
    return point_feat


def load_fmri(fmri_path):
    data = np.load(fmri_path)
    data = data.mean(axis=0)
    data = torch.tensor(data[None])
    return data
def model_worker(
    rank: int, args: argparse.Namespace, barrier: mp.Barrier,
    request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
) -> None:
    """
    The worker function that drives a GPU to run inference.

    Exactly n_gpu workers are started, each operating on a separate GPU.

    Args:
        rank (int): Distributed rank of the worker.
        args (argparse.Namespace): All command line arguments.
        barrier (multiprocessing.Barrier): A barrier used to delay the start
            of the Web UI until after the model has started.
    """
    world_size = len(args.gpu_ids)
    gpu_id = args.gpu_ids[rank]
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size,
        init_method=f"tcp://{args.master_addr}:{args.master_port}",
    )
    print(f"| distributed init on worker {rank}/{world_size}. "
          f"using gpu: {gpu_id}")
    fs_init.initialize_model_parallel(world_size)
    torch.cuda.set_device(gpu_id)

    torch.manual_seed(1)
    np.random.seed(1)

    # set the print behavior.
    setup_for_distributed(rank == 0)

    target_dtype = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }[args.dtype]
    with default_tensor_type(dtype=target_dtype, device="cuda"):
        model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
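    # the checkpoint is stored as num_ckpts shards; strict=False because each
    # shard contains only a subset of the model's parameters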
    for ckpt_id in range(args.num_ckpts):
        # ckpt_path = hf_hub_download(repo_id=args.pretrained_path, filename=args.ckpt_format.format(str(ckpt_id)))
        ckpt_path = os.path.join(args.pretrained_path, args.ckpt_format.format(str(ckpt_id)))
        print(f"Loading pretrained weights {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        msg = model.load_state_dict(checkpoint, strict=False)
        # print("load result:\n", msg)
    model.cuda()
    model.eval()
    print(f"Model = {str(model)}")

    barrier.wait()
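    # main serving loop: block until the Gradio front-end pushes a request,
    # preprocess the selected modality, then stream the generated text back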
    while True:
        img_path, audio_path, video_path, point_path, fmri_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
        if 'image' in modality and img_path is not None:
            image = Image.open(img_path).convert('RGB')
            inputs = T_random_resized_crop(image)
        elif 'video' in modality and video_path is not None:
            inputs = load_video(video_path)
        elif 'audio' in modality and audio_path is not None:
            inputs = load_audio(audio_path)
        elif 'point' in modality and point_path is not None:
            inputs = load_point(point_path)
        elif 'fmri' in modality and fmri_path is not None:
            inputs = load_fmri(fmri_path)
        else:
            inputs = None

        if inputs is not None:
            inputs = inputs[None].cuda().to(target_dtype)
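        # rebuild the conversation prompt from the chat history; the last bot
        # turn is still None, so the template ends with an open assistant slot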
        conv = conv_templates["v1"].copy()
        for user, bot in chatbot:
            conv.append_message(conv.roles[0], user)
            conv.append_message(conv.roles[1], bot)

        with torch.cuda.amp.autocast(dtype=target_dtype):
            print(conv.get_prompt())
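            # each stream_response carries the text generated so far; stop
            # once the conversation separator shows up in it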
            for stream_response in model.stream_generate(
                conv.get_prompt(), inputs,
                max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
                modal=modality,
            ):
                conv_sep = (
                    conv.sep
                    if conv.sep_style == SeparatorStyle.SINGLE
                    else conv.sep2
                )
                end_pos = stream_response["text"].find(conv_sep)
                if end_pos != -1:
                    stream_response["text"] = (
                        stream_response["text"][:end_pos].rstrip() + "\n"
                    )
                    stream_response["end_of_content"] = True

                # keep a few characters if not end_of_content to avoid sending
                # part of conv_sep before all of it is generated.
                if not stream_response["end_of_content"]:
                    if len(stream_response["text"]) < len(conv_sep):
                        continue
                    stream_response["text"] = (
                        stream_response["text"][:-len(conv_sep)]
                    )

                if response_queue is not None:
                    response_queue.put(stream_response)

                if stream_response["end_of_content"]:
                    break
def gradio_worker(
    request_queues: List[mp.Queue], response_queue: mp.Queue,
    args: argparse.Namespace, barrier: mp.Barrier,
) -> None:
    """
    The gradio worker is responsible for displaying the WebUI and relaying
    requests to the model workers. It should be launched only once.

    Args:
        request_queues (List[mp.Queue]): A list of request queues (one for
            each model worker).
        args (argparse.Namespace): All command line arguments.
        barrier (multiprocessing.Barrier): A barrier used to delay the start
            of the Web UI until after the model has started.
    """
    def show_user_input(msg, chatbot):
        return "", chatbot + [[msg, None]]
    def stream_model_output(img_path, audio_path, video_path, point_path, fmri_path, chatbot, max_gen_len, gen_t, top_p, modality):
        for queue in request_queues:
            queue.put((img_path, audio_path, video_path, point_path, fmri_path, chatbot, max_gen_len, gen_t, top_p, modality))
        while True:
            content_piece = response_queue.get()
            chatbot[-1][1] = content_piece["text"]
            yield chatbot
            if content_piece["end_of_content"]:
                break

    def undo(chatbot):
        if len(chatbot) > 0:
            chatbot = chatbot[:-1]
        return chatbot

    def clear():
        chatbot = []
        msg = ""
        return chatbot, msg
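    # render an uploaded .npy point cloud; each row is expected to be
    # (x, y, z, r, g, b), matching the columns used below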
    def show_point_cloud(file):
        point = load_point(file).numpy()
        fig = go.Figure(
            data=[
                go.Scatter3d(
                    x=point[:, 0], y=point[:, 1], z=point[:, 2],
                    mode='markers',
                    marker=dict(
                        size=1.2,
                        color=['rgb({},{},{})'.format(r, g, b)
                               for r, g, b in zip(point[:, 3], point[:, 4], point[:, 5])],
                    ),
                )
            ],
            layout=dict(
                scene=dict(
                    xaxis=dict(visible=False),
                    yaxis=dict(visible=False),
                    zaxis=dict(visible=False),
                )
            ),
        )
        return fig

    def change_modality(modal):
        return modal
    CSS = """
    .contain { display: flex; flex-direction: column; }
    #component-0 { height: 100%; }
    #chatbot { flex-grow: 1; overflow: auto;}
    """

    header = """
## OneLLM: One Framework to Align All Modalities with Language

[[Project Page](https://onellm.csuhan.com)] [[Paper](https://github.com/csuhan/OneLLM/blob/main/OneLLM_Arxiv.pdf)] [[Code](https://github.com/csuhan/OneLLM)]
"""
    with gr.Blocks(css=CSS, theme=gr.themes.Base()) as demo:
        gr.Markdown(header)
        with gr.Row(equal_height=True):
            modality = gr.Textbox(value='image', visible=False)
            with gr.Column(scale=1):
                with gr.Tab('Image') as img_tab:
                    img_path = gr.Image(label='Image Input', type='filepath')
                    gr.Examples(
                        examples=[
                            "examples/new_york.jpg",
                            "examples/food_menu.png",
                        ],
                        inputs=[img_path],
                    )
                with gr.Tab('Video') as video_tab:
                    video_path = gr.Video(label='Video Input')
                    gr.Examples(
                        examples=[
                            "examples/flower.mp4",
                            "examples/star_kun.mp4",
                        ],
                        inputs=[video_path],
                    )
                with gr.Tab('Audio') as audio_tab:
                    audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
                    gr.Examples(
                        examples=[
                            "examples/bell_ring.wav",
                            "examples/bird_audio.wav",
                        ],
                        inputs=[audio_path],
                    )
                with gr.Tab('Point Cloud') as point_tab:
                    point_path = gr.File(label='Point Cloud Input', elem_id="pointpath", elem_classes="")
                    point_vis = gr.Plot()
                    btn = gr.Button(value="Show Point Cloud")
                    btn.click(show_point_cloud, point_path, point_vis)
                    gr.Examples(
                        examples=glob.glob("examples/point/*.npy"),
                        inputs=[point_path],
                        examples_per_page=5,
                    )
                with gr.Tab('IMU') as imu_tab:
                    gr.Markdown('Coming soon🤗')
                with gr.Tab('fMRI') as fmri_tab:
                    fmri_path = gr.File(label='fMRI Input', elem_id="fmripath", elem_classes="")
                    fmri_image_path = gr.Image(interactive=False)
                    gr.Examples(
                        examples=[
                            [file.replace('.jpg', '.npy'), file]
                            for file in glob.glob("examples/fmri/*.jpg")
                        ],
                        inputs=[fmri_path, fmri_image_path],
                        examples_per_page=3,
                    )
                with gr.Tab('Depth Map') as depth_tab:
                    gr.Markdown('Coming soon🤗')
                with gr.Tab('Normal Map') as normal_tab:
                    gr.Markdown('Coming soon🤗')
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(elem_id="chatbot")
                msg = gr.Textbox()

                with gr.Row():
                    submit_button = gr.Button("Submit", variant="primary")
                    undo_button = gr.Button("Undo")
                    clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, point_path, fmri_path, point_vis])
                with gr.Row():
                    max_gen_len = gr.Slider(
                        minimum=1, maximum=args.model_max_seq_len // 2,
                        value=args.model_max_seq_len // 2, interactive=True,
                        label="Single-turn max response length",
                    )
                    gen_t = gr.Slider(
                        minimum=0, maximum=1, value=0.1, interactive=True,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0, maximum=1, value=0.75, interactive=True,
                        label="Top-p",
                    )
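        # selecting a tab writes its modality name into the hidden textbox so
        # the worker knows which loader to apply to the request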
        img_tab.select(partial(change_modality, 'image'), [], [modality])
        video_tab.select(partial(change_modality, 'video'), [], [modality])
        audio_tab.select(partial(change_modality, 'audio'), [], [modality])
        point_tab.select(partial(change_modality, 'point'), [], [modality])
        fmri_tab.select(partial(change_modality, 'fmri'), [], [modality])

        msg.submit(
            show_user_input, [msg, chatbot], [msg, chatbot],
        ).then(
            stream_model_output, [img_path, audio_path, video_path, point_path, fmri_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
        )
        submit_button.click(
            show_user_input, [msg, chatbot], [msg, chatbot],
        ).then(
            stream_model_output, [img_path, audio_path, video_path, point_path, fmri_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
        )
        undo_button.click(undo, chatbot, chatbot)
        # img_path.change(clear, [], [chatbot, msg])

    barrier.wait()
    demo.queue(api_open=True).launch(share=True, max_threads=1)
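# Demo configuration. pretrained_path / ckpt_format point at a locally split
# checkpoint (num_ckpts shards); the commented-out lines show the
# hf_hub_download alternative for pulling the weights from the Hub.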
class DemoConfig:
    gpu_ids = [0]
    tokenizer_path = "config/llama2/tokenizer.model"
    llama_type = "onellm"
    llama_config = "config/llama2/7B.json"
    model_max_seq_len = 2048
    # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth"
    # pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
    # pretrained_path = "csuhan/OneLLM-7B-hf"
    pretrained_path = "/home/pgao/jiaming/weights/7B_v20_splits/"
    ckpt_format = "consolidated.00-of-01.s{}.pth"
    num_ckpts = 10
    master_port = 23863
    master_addr = "127.0.0.1"
    dtype = "fp16"
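# Launch one model worker per GPU and run the Gradio UI in the main process.
# The barrier releases the UI only after every worker has finished loading.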
if __name__ == "__main__":
    args = DemoConfig()

    # using the default "fork" method messes up some imported libs (e.g.,
    # pandas)
    # mp.set_start_method("spawn")

    # set up the queues and start the model workers
    request_queues = []
    response_queue = mp.Queue()
    worker_processes = []
    barrier = mp.Barrier(len(args.gpu_ids) + 1)
    for rank, gpu_id in enumerate(args.gpu_ids):
        request_queue = mp.Queue()
        rank_response_queue = response_queue if rank == 0 else None
        process = mp.Process(
            target=model_worker,
            args=(rank, args, barrier, request_queue, rank_response_queue),
        )
        process.start()
        worker_processes.append(process)
        request_queues.append(request_queue)

    gradio_worker(request_queues, response_queue, args, barrier)