AkashDataScience's picture
Minor fix
18f12de
raw
history blame
5.62 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import torch
import clip
from PIL import Image
import torch.nn as nn
from model import Projections
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import gradio as gr
import librosa
# ---------------------------------------------------------------------------
# Runtime device: use the GPU when one is visible, otherwise the CPU.
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ---------------------------------------------------------------------------
# Projection layer: constructed as Projections(512, 3072) and restored from
# the training checkpoint. strict=False tolerates key mismatches between the
# checkpoint and the module.
# ---------------------------------------------------------------------------
projections = Projections(512, 3072)
projections.load_state_dict(
    torch.load(
        'checkpoint_dir/checkpoint-6000/projection_layer/pytorch_model.bin',
        map_location=device,
    ),
    strict=False,
)
projections = projections.to(device=device, dtype=torch.bfloat16)

# ---------------------------------------------------------------------------
# Language model: base checkpoint plus the fine-tuned PEFT adapter, merged
# into a single model and moved to the runtime device.
# ---------------------------------------------------------------------------
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
model_kwargs = {
    'use_cache': False,
    'trust_remote_code': True,
    'attn_implementation': 'eager',
    'torch_dtype': torch.bfloat16,
    'device_map': None,
}
base_model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
new_model = "checkpoint_dir/checkpoint-6000/phi_model"  # change to the path where your model is saved
model = PeftModel.from_pretrained(base_model, new_model).merge_and_unload()
model = model.to(device)

# ---------------------------------------------------------------------------
# Tokenizer configuration.
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
tokenizer.model_max_length = 2048
# Pad with the unk token rather than eos to prevent endless generation.
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'
# Chat template keyed on 'from'/'value' message fields with the roles
# 'system', 'human' and 'gpt'. The pieces below concatenate to one string.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['from'] == 'system' %}{{'<|system|>' + message['value'] + '<|end|>'}}"
    "{% elif message['from'] =='human' %}{{'<|user|>' + message['value'] + '<|end|>'}}"
    "{% elif message['from'] == 'gpt' %}{{'<|assistant|>' + message['value'] +'<|end|>'}}"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}"
)

# ---------------------------------------------------------------------------
# Image encoder (CLIP) and its matching preprocessing transform.
# ---------------------------------------------------------------------------
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# ---------------------------------------------------------------------------
# Speech-to-text: Whisper model and processor (left on the default device).
# ---------------------------------------------------------------------------
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
def _embed_user_turn(text):
    """Format *text* as one user turn and return its input embeddings on ``device``."""
    prompt = tokenizer.apply_chat_template(
        [{"from": "human", "value": text}],
        tokenize=False,
        add_generation_prompt=True,
    )
    prompt_tokens = tokenizer(
        prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt"
    )['input_ids']
    # The tokenizer returns CPU tensors; the model lives on ``device``.
    return model.get_input_embeddings()(prompt_tokens.to(device))


def infer(message, history):
    """Generate a reply for a multimodal chat message.

    Args:
        message: dict with a ``'text'`` entry (the typed prompt, possibly
            empty) and a ``'files'`` entry (a list of file paths). Paths ending
            in .jpg/.jpeg/.png are encoded with CLIP and projected; paths
            ending in .mp3/.wav are transcribed with Whisper and embedded as a
            user turn. Other files are ignored. If several files of one kind
            are given, only the last one is used.
        history: previous conversation turns supplied by the chat UI; unused.

    Returns:
        The generated text (greedy decoding, at most 100 new tokens), or a
        short hint when the message contains nothing usable.
    """
    max_generate_length = 100
    files = message.get('files') or []
    text = message.get('text')

    with torch.no_grad():
        projected_image_embeds = None
        audio_text_embeds = None
        for path in files:
            lowered = path.lower()  # accept upper-case extensions such as .JPG
            if lowered.endswith(('.jpg', '.png', '.jpeg')):
                with Image.open(path) as img:
                    image = clip_preprocess(img).unsqueeze(0).to(device)
                image_features = clip_model.encode_image(image)
                projected_image_embeds = projections(image_features.to(torch.bfloat16)).unsqueeze(0)
            elif lowered.endswith(('.mp3', '.wav')):
                # Load and preprocess the audio
                speech, rate = librosa.load(path, sr=16000)
                input_features = whisper_processor(
                    speech, return_tensors="pt", sampling_rate=16000
                ).input_features
                predicted_ids = whisper_model.generate(input_features.to(whisper_model.device))
                # batch_decode returns a list; the chat template needs a string.
                transcription = whisper_processor.batch_decode(
                    predicted_ids, skip_special_tokens=True
                )[0]
                audio_text_embeds = _embed_user_turn(transcription)

        # Compare against None explicitly: the truth value of a multi-element
        # tensor is ambiguous and raises a RuntimeError.
        combined_embeds = [
            embeds
            for embeds in (projected_image_embeds, audio_text_embeds)
            if embeds is not None
        ]
        if text:
            combined_embeds.append(_embed_user_turn(text))

        if not combined_embeds:
            # Nothing to condition on; torch.cat would fail on an empty list.
            return "Please provide a text prompt, an image, or an audio file."

        combined_embeds = torch.cat(combined_embeds, dim=1)

        # Stop as soon as the model emits an end marker instead of always
        # producing max_generate_length tokens.
        stop_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")}
        stop_ids.discard(None)

        generated_ids = []
        for _ in range(max_generate_length):
            phi_output_logits = model(inputs_embeds=combined_embeds)['logits']
            # Greedy choice from the last position; keepdim keeps shape (1, 1).
            predicted_word_token = torch.argmax(phi_output_logits[:, -1, :], dim=-1, keepdim=True)
            token_id = predicted_word_token.item()
            if token_id in stop_ids:
                break
            generated_ids.append(token_id)
            next_token_embeds = model.get_input_embeddings()(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)

    return tokenizer.decode(generated_ids, skip_special_tokens=True)
# Text-only starter prompts shown in the UI; each example attaches no files.
examples = [
    {'text': example_text, 'files': []}
    for example_text in (
        "I am planning to buy a dog and a cat. Suggest some breeds that get along with each other",
        "Explain biased coin flip",
        "I want to buy a house. Suggest some factors to consider while making final decision",
    )
]

# Build the multimodal chat UI around `infer` and start serving it.
gr.ChatInterface(
    fn=infer,
    chatbot=gr.Chatbot(height=600),
    theme="soft",
    examples=examples,
    title="Phi-3 Multimodel Assistant",
    multimodal=True,
).launch()