Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Upload 4 files
Browse files- .gitattributes +1 -0
 - app (1).py +31 -0
 - delay_tyre.mp4 +3 -0
 - process.py +111 -0
 - requirements (1).txt +26 -0
 
    	
        .gitattributes
    CHANGED
    
    | 
         @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text 
     | 
|
| 33 | 
         
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         
     | 
| 34 | 
         
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         
     | 
| 35 | 
         
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 
         | 
| 
         | 
|
| 33 | 
         
             
            *.zip filter=lfs diff=lfs merge=lfs -text
         
     | 
| 34 | 
         
             
            *.zst filter=lfs diff=lfs merge=lfs -text
         
     | 
| 35 | 
         
             
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 36 | 
         
            +
            delay_tyre.mp4 filter=lfs diff=lfs merge=lfs -text
         
     | 
    	
        app (1).py
    ADDED
    
    | 
         @@ -0,0 +1,31 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import gradio as gr
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            from process import inference
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            def clickit(video, prompt):
         
     | 
| 6 | 
         
            +
                return inference(
         
     | 
| 7 | 
         
            +
                    video,
         
     | 
| 8 | 
         
            +
                    prompt
         
     | 
| 9 | 
         
            +
                )
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            with gr.Blocks() as blok:
         
     | 
| 12 | 
         
            +
                with gr.Row():
         
     | 
| 13 | 
         
            +
                    with gr.Column():
         
     | 
| 14 | 
         
            +
                        video = gr.Video(
         
     | 
| 15 | 
         
            +
                            label="video input",
         
     | 
| 16 | 
         
            +
                        )
         
     | 
| 17 | 
         
            +
                        prompt = gr.Text(
         
     | 
| 18 | 
         
            +
                            label="Prompt",
         
     | 
| 19 | 
         
            +
                            value="Please describe this video in detail."
         
     | 
| 20 | 
         
            +
                        )
         
     | 
| 21 | 
         
            +
                    with gr.Column():
         
     | 
| 22 | 
         
            +
                        button = gr.Button("Caption it", variant="primary")
         
     | 
| 23 | 
         
            +
                        text = gr.Text(label="Output")
         
     | 
| 24 | 
         
            +
                    
         
     | 
| 25 | 
         
            +
                    button.click(
         
     | 
| 26 | 
         
            +
                        fn=clickit,
         
     | 
| 27 | 
         
            +
                        inputs=[video, prompt],
         
     | 
| 28 | 
         
            +
                        outputs=[text]
         
     | 
| 29 | 
         
            +
                    )
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
            blok.launch()
         
     | 
    	
        delay_tyre.mp4
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:63a245902a9509f492fda6537c84ab53c3582f868503982b53419c01fee6e592
         
     | 
| 3 | 
         
            +
            size 7352910
         
     | 
    	
        process.py
    ADDED
    
    | 
         @@ -0,0 +1,111 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import io
         
     | 
| 2 | 
         
            +
            import spaces
         
     | 
| 3 | 
         
            +
            import argparse
         
     | 
| 4 | 
         
            +
            import numpy as np
         
     | 
| 5 | 
         
            +
            import torch
         
     | 
| 6 | 
         
            +
            from decord import cpu, VideoReader, bridge
         
     | 
| 7 | 
         
            +
            from transformers import AutoModelForCausalLM, AutoTokenizer
         
     | 
| 8 | 
         
            +
            from transformers import BitsAndBytesConfig
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
         
     | 
| 11 | 
         
            +
            DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
         
     | 
| 12 | 
         
            +
            TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
         
     | 
| 13 | 
         
            +
                0] >= 8 else torch.float16
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
         
     | 
| 16 | 
         
            +
            parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=4)
         
     | 
| 17 | 
         
            +
            args = parser.parse_args([])
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            def load_video(video_data, strategy='chat'):
         
     | 
| 20 | 
         
            +
                bridge.set_bridge('torch')
         
     | 
| 21 | 
         
            +
                mp4_stream = video_data
         
     | 
| 22 | 
         
            +
                num_frames = 24
         
     | 
| 23 | 
         
            +
                decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))
         
     | 
| 24 | 
         
            +
                frame_id_list = None
         
     | 
| 25 | 
         
            +
                total_frames = len(decord_vr)
         
     | 
| 26 | 
         
            +
                
         
     | 
| 27 | 
         
            +
                if strategy == 'base':
         
     | 
| 28 | 
         
            +
                    clip_end_sec = 60
         
     | 
| 29 | 
         
            +
                    clip_start_sec = 0
         
     | 
| 30 | 
         
            +
                    start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
         
     | 
| 31 | 
         
            +
                    end_frame = min(total_frames,
         
     | 
| 32 | 
         
            +
                                    int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames
         
     | 
| 33 | 
         
            +
                    frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
         
     | 
| 34 | 
         
            +
                elif strategy == 'chat':
         
     | 
| 35 | 
         
            +
                    timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
         
     | 
| 36 | 
         
            +
                    timestamps = [i[0] for i in timestamps]
         
     | 
| 37 | 
         
            +
                    max_second = round(max(timestamps)) + 1
         
     | 
| 38 | 
         
            +
                    frame_id_list = []
         
     | 
| 39 | 
         
            +
                    for second in range(max_second):
         
     | 
| 40 | 
         
            +
                        closest_num = min(timestamps, key=lambda x: abs(x - second))
         
     | 
| 41 | 
         
            +
                        index = timestamps.index(closest_num)
         
     | 
| 42 | 
         
            +
                        frame_id_list.append(index)
         
     | 
| 43 | 
         
            +
                        if len(frame_id_list) >= num_frames:
         
     | 
| 44 | 
         
            +
                            break
         
     | 
| 45 | 
         
            +
                            
         
     | 
| 46 | 
         
            +
                video_data = decord_vr.get_batch(frame_id_list)
         
     | 
| 47 | 
         
            +
                video_data = video_data.permute(3, 0, 1, 2)
         
     | 
| 48 | 
         
            +
                return video_data
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
            # Configure quantization
         
     | 
| 51 | 
         
            +
            quantization_config = BitsAndBytesConfig(
         
     | 
| 52 | 
         
            +
                load_in_4bit=True,
         
     | 
| 53 | 
         
            +
                bnb_4bit_compute_dtype=TORCH_TYPE,
         
     | 
| 54 | 
         
            +
                bnb_4bit_use_double_quant=True,
         
     | 
| 55 | 
         
            +
                bnb_4bit_quant_type="nf4"
         
     | 
| 56 | 
         
            +
            )
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            tokenizer = AutoTokenizer.from_pretrained(
         
     | 
| 59 | 
         
            +
                MODEL_PATH,
         
     | 
| 60 | 
         
            +
                trust_remote_code=True,
         
     | 
| 61 | 
         
            +
            )
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            model = AutoModelForCausalLM.from_pretrained(
         
     | 
| 64 | 
         
            +
                MODEL_PATH,
         
     | 
| 65 | 
         
            +
                torch_dtype=TORCH_TYPE,
         
     | 
| 66 | 
         
            +
                trust_remote_code=True,
         
     | 
| 67 | 
         
            +
                quantization_config=quantization_config,
         
     | 
| 68 | 
         
            +
                device_map="auto"
         
     | 
| 69 | 
         
            +
            ).eval()
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
            @spaces.GPU
         
     | 
| 72 | 
         
            +
            def predict(prompt, video_data, temperature):
         
     | 
| 73 | 
         
            +
                strategy = 'chat'
         
     | 
| 74 | 
         
            +
                video = load_video(video_data, strategy=strategy)
         
     | 
| 75 | 
         
            +
                history = []
         
     | 
| 76 | 
         
            +
                query = prompt
         
     | 
| 77 | 
         
            +
                inputs = model.build_conversation_input_ids(
         
     | 
| 78 | 
         
            +
                    tokenizer=tokenizer,
         
     | 
| 79 | 
         
            +
                    query=query,
         
     | 
| 80 | 
         
            +
                    images=[video],
         
     | 
| 81 | 
         
            +
                    history=history,
         
     | 
| 82 | 
         
            +
                    template_version=strategy
         
     | 
| 83 | 
         
            +
                )
         
     | 
| 84 | 
         
            +
                
         
     | 
| 85 | 
         
            +
                inputs = {
         
     | 
| 86 | 
         
            +
                    'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
         
     | 
| 87 | 
         
            +
                    'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
         
     | 
| 88 | 
         
            +
                    'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
         
     | 
| 89 | 
         
            +
                    'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
         
     | 
| 90 | 
         
            +
                }
         
     | 
| 91 | 
         
            +
                
         
     | 
| 92 | 
         
            +
                gen_kwargs = {
         
     | 
| 93 | 
         
            +
                    "max_new_tokens": 2048,
         
     | 
| 94 | 
         
            +
                    "pad_token_id": 128002,
         
     | 
| 95 | 
         
            +
                    "top_k": 1,
         
     | 
| 96 | 
         
            +
                    "do_sample": False,
         
     | 
| 97 | 
         
            +
                    "top_p": 0.1,
         
     | 
| 98 | 
         
            +
                    "temperature": temperature,
         
     | 
| 99 | 
         
            +
                }
         
     | 
| 100 | 
         
            +
                
         
     | 
| 101 | 
         
            +
                with torch.no_grad():
         
     | 
| 102 | 
         
            +
                    outputs = model.generate(**inputs, **gen_kwargs)
         
     | 
| 103 | 
         
            +
                    outputs = outputs[:, inputs['input_ids'].shape[1]:]
         
     | 
| 104 | 
         
            +
                    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         
     | 
| 105 | 
         
            +
                    return response
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
            def inference(video, prompt):
         
     | 
| 108 | 
         
            +
                temperature = 0.1
         
     | 
| 109 | 
         
            +
                video_data = open(video, 'rb').read()
         
     | 
| 110 | 
         
            +
                response = predict(prompt, video_data, temperature)
         
     | 
| 111 | 
         
            +
                return response
         
     | 
    	
        requirements (1).txt
    ADDED
    
    | 
         @@ -0,0 +1,26 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            decord>=0.6.0
         
     | 
| 2 | 
         
            +
            #鏍规嵁https://download.pytorch.org/whl/torch/锛宲ython鐗堟湰涓篬3.8,3.11]
         
     | 
| 3 | 
         
            +
            torch==2.1.0
         
     | 
| 4 | 
         
            +
            torchvision== 0.16.0
         
     | 
| 5 | 
         
            +
            pytorchvideo==0.1.5
         
     | 
| 6 | 
         
            +
            xformers
         
     | 
| 7 | 
         
            +
            transformers==4.42.4
         
     | 
| 8 | 
         
            +
            #git+https://github.com/huggingface/transformers.git
         
     | 
| 9 | 
         
            +
            huggingface-hub>=0.23.0
         
     | 
| 10 | 
         
            +
            pillow
         
     | 
| 11 | 
         
            +
            chainlit>=1.0
         
     | 
| 12 | 
         
            +
            pydantic>=2.7.1
         
     | 
| 13 | 
         
            +
            timm>=0.9.16
         
     | 
| 14 | 
         
            +
            openai>=1.30.1
         
     | 
| 15 | 
         
            +
            loguru>=0.7.2
         
     | 
| 16 | 
         
            +
            pydantic>=2.7.1
         
     | 
| 17 | 
         
            +
            einops
         
     | 
| 18 | 
         
            +
            sse-starlette>=2.1.0
         
     | 
| 19 | 
         
            +
            flask
         
     | 
| 20 | 
         
            +
            gunicorn
         
     | 
| 21 | 
         
            +
            gevent
         
     | 
| 22 | 
         
            +
            requests
         
     | 
| 23 | 
         
            +
            gradio
         
     | 
| 24 | 
         
            +
            accelerate
         
     | 
| 25 | 
         
            +
            bitsandbytes>=0.39.0
         
     | 
| 26 | 
         
            +
            spaces
         
     |