import torch
import clip
import librosa
import gradio as gr
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
from model import Projections

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
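# Trained projection layer: maps 512-d CLIP image features into Phi-3's 3072-d embedding space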
projections = Projections(512, 3072)
projections.load_state_dict(torch.load('checkpoint_dir/checkpoint-6000/projection_layer/pytorch_model.bin', map_location=device), strict=False)
projections = projections.to(device)
projections = projections.to(torch.bfloat16)

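# Load the Phi-3-mini-4k-instruct base model in bfloat16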
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation='eager',
    torch_dtype=torch.bfloat16,
    device_map=None
)
base_model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)

new_model = "checkpoint_dir/checkpoint-6000/phi_model"  # change to the path where your model is saved

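# Attach the fine-tuned LoRA adapter and merge it into the base weights for inference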
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
tokenizer.model_max_length = 2048
tokenizer.pad_token = tokenizer.unk_token  # use unk rather than eos token to prevent endless generation
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'
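# Chat template maps the dataset's system / human / gpt roles onto Phi-3's <|system|> / <|user|> / <|assistant|> tags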
tokenizer.chat_template = "{% for message in messages %}{% if message['from'] == 'system' %}{{'<|system|>' + message['value'] + '<|end|>'}}{% elif message['from'] ==\
 'human' %}{{'<|user|>' + message['value'] + '<|end|>'}}{% elif message['from'] == 'gpt' %}{{'<|assistant|>' + message['value'] +\
 '<|end|>'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}"

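# CLIP ViT-B/32 image encoder; its 512-d features feed the projection layer above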
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# Load Whisper model and processor
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

def infer(message, history):
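    """Encode any image/audio attachments and the text prompt into a single
    embedding sequence, then greedily decode a reply with the merged Phi-3 model."""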
    max_generate_length = 100
    combined_embeds = []
    
    with torch.no_grad():
        if message['files']:
            projected_image_embeds = None
            audio_text_embeds = None
            for path in message['files']:

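                # Image attachment: encode with CLIP and project into the language model's embedding space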
                if path.endswith(('.jpg', '.png', '.jpeg')):
                    image = clip_preprocess(Image.open(path)).unsqueeze(0).to(device)
                    image_features = clip_model.encode_image(image)
                    projected_image_embeds = projections(image_features.to(torch.bfloat16)).unsqueeze(0)
            
                elif path.endswith(('.mp3', '.wav')):
                    # Load and preprocess the audio
                    speech, rate = librosa.load(path, sr=16000)
                    input_features = whisper_processor(speech, return_tensors="pt", sampling_rate=16000).input_features 
                    predicted_ids = whisper_model.generate(input_features)
                    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
                    # batch_decode returns a list; take the single transcription string
                    prompt = tokenizer.apply_chat_template([{"from": "human", "value": transcription[0]}], tokenize=False, add_generation_prompt=True)
                    prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids'].to(device)
                    audio_text_embeds = model.get_input_embeddings()(prompt_tokens)

            # explicit None checks: truth-testing a multi-element tensor raises an error
            if projected_image_embeds is not None:
                combined_embeds.append(projected_image_embeds)

            if audio_text_embeds is not None:
                combined_embeds.append(audio_text_embeds)
        
        if message['text']:
            prompt = tokenizer.apply_chat_template([{"from": "human", "value": message['text']}], tokenize=False, add_generation_prompt=True)
            prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids'].to(device)
            text_embeds = model.get_input_embeddings()(prompt_tokens)
            combined_embeds.append(text_embeds)

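        # Concatenate image, audio and text embeddings along the sequence dimension to form the prompt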
        combined_embeds = torch.cat(combined_embeds,dim=1)

        # Placeholder-filled buffer; every position that is decoded gets overwritten below
        predicted_caption = torch.full((1, max_generate_length), 50256).to(device)

        for g in range(max_generate_length):
            phi_output_logits = model(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            # stop once the model emits its end-of-turn or end-of-sequence token
            if predicted_word_token.item() in (tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids('<|end|>')):
                predicted_caption = predicted_caption[:, :g + 1]
                break
            next_token_embeds = model.get_input_embeddings()(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)

        # ignore_index is not a batch_decode argument; skip special tokens instead
        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]

    return predicted_captions_decoded


examples = [{'text': "I am planning to buy a dog and a cat. Suggest some breeds that get along with each other", 'files': []},
            {'text': "Explain a biased coin flip", 'files': []},
            {'text': "I want to buy a house. Suggest some factors to consider while making the final decision", 'files': []}]

gr.ChatInterface(infer, chatbot=gr.Chatbot(height=600), theme="soft", examples=examples,
                 title="Phi-3 Multimodal Assistant", multimodal=True).launch()