from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import torch
import clip
from PIL import Image
import torch.nn as nn
from model import Projections
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import gradio as gr
import librosa

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
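
# Projection layer (defined in the local model module) that maps 512-d CLIP image
# features into the 3072-d Phi-3 embedding space; weights come from the fine-tuning checkpoint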
projections = Projections(512, 3072)
projections.load_state_dict(torch.load('checkpoint_dir/checkpoint-6000/projection_layer/pytorch_model.bin', map_location=device), strict=False)
projections = projections.to(device)
projections = projections.to(torch.bfloat16)
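
# Base Phi-3-mini model, loaded in bfloat16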
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation='eager',
    torch_dtype=torch.bfloat16,
    device_map=None
)
base_model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
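
# Attach the fine-tuned LoRA adapter and merge it into the base weights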
new_model = "checkpoint_dir/checkpoint-6000/phi_model"

model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()
model = model.to(device)
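
# Tokenizer configured with the chat template used during fine-tuning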
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
tokenizer.model_max_length = 2048
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['from'] == 'system' %}{{'<|system|>' + message['value'] + '<|end|>'}}"
    "{% elif message['from'] == 'human' %}{{'<|user|>' + message['value'] + '<|end|>'}}"
    "{% elif message['from'] == 'gpt' %}{{'<|assistant|>' + message['value'] + '<|end|>'}}"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}"
)
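
# CLIP ViT-B/32 image encoder (512-d image features)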
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
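
# Whisper-small for transcribing audio inputs to text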
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
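

# Gradio ChatInterface callback: builds one embedding sequence from any uploaded
# image/audio files plus the typed text, then greedily decodes a response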
def infer(message, history):
    max_generate_length = 100
    combined_embeds = []

    with torch.no_grad():
        if message['files']:
            projected_image_embeds = None
            audio_text_embeds = None
            for path in message['files']:
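                # Image input: encode with CLIP and project into the Phi-3 embedding space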
                if path.endswith(('.jpg', '.png', '.jpeg')):
                    image = clip_preprocess(Image.open(path)).unsqueeze(0).to(device)
                    image_features = clip_model.encode_image(image)
                    projected_image_embeds = projections(image_features.to(torch.bfloat16)).unsqueeze(0)
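
                # Audio input: transcribe with Whisper, then embed the transcription as a user turn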
                elif path.endswith(('.mp3', '.wav')):
                    speech, rate = librosa.load(path, sr=16000)
                    input_features = whisper_processor(speech, return_tensors="pt", sampling_rate=16000).input_features
                    predicted_ids = whisper_model.generate(input_features)
                    # batch_decode returns a list of strings; keep the single transcription
                    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                    prompt = tokenizer.apply_chat_template([{"from": "human", "value": transcription}], tokenize=False, add_generation_prompt=True)
                    prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids'].to(device)
                    audio_text_embeds = model.get_input_embeddings()(prompt_tokens)
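
            # After the file loop, add whichever modality embeddings were produced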
            if projected_image_embeds is not None:
                combined_embeds.append(projected_image_embeds)

            if audio_text_embeds is not None:
                combined_embeds.append(audio_text_embeds)
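
        # Typed text prompt, wrapped in the same chat template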
        if message['text']:
            prompt = tokenizer.apply_chat_template([{"from": "human", "value": message['text']}], tokenize=False, add_generation_prompt=True)
            prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids'].to(device)
            text_embeds = model.get_input_embeddings()(prompt_tokens)
            combined_embeds.append(text_embeds)
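
        # Concatenate all modality embeddings along the sequence dimension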
        combined_embeds = torch.cat(combined_embeds, dim=1)
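
        # Buffer for generated token ids; every position is overwritten during decoding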
        predicted_caption = torch.full((1, max_generate_length), 50256).to(device)
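
        # Greedy decoding: pick the most likely next token, append its embedding, and repeat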
        for g in range(max_generate_length):
            print(g)  # debug: current generation step
            phi_output_logits = model(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, g] = predicted_word_token.view(-1)
            next_token_embeds = model.get_input_embeddings()(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
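
        # Decode the generated ids back to text, dropping <|end|>/<|endoftext|> markers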
        predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]

    return predicted_captions_decoded
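

# Text-only example prompts shown in the Gradio UI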
examples = [
    {'text': "I am planning to buy a dog and a cat. Suggest some breeds that get along with each other", 'files': []},
    {'text': "Explain biased coin flip", 'files': []},
    {'text': "I want to buy a house. Suggest some factors to consider while making final decision", 'files': []},
]

gr.ChatInterface(infer, chatbot=gr.Chatbot(height=600), theme="soft", examples=examples,
                 title="Phi-3 Multimodal Assistant", multimodal=True).launch()