AkashDataScience's picture
Added inference
b52bed7
raw
history blame
5.84 kB
import warnings

import clip
import gradio as gr
import librosa
import torch
import torch.nn as nn
from peft import PeftModel
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import WhisperProcessor, WhisperForConditionalGeneration

from model import Projections
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Projection head that maps CLIP ViT-B/32 image features (512-d) into the
# language model's embedding space (3072-d); infer() feeds it encode_image() output.
projections = Projections(512, 3072)

# NOTE(review): torch.load unpickles arbitrary Python objects — only load trusted
# checkpoints (consider weights_only=True on torch >= 1.13).
_projection_state = torch.load(
    'checkpoint_dir/checkpoint-6000/projection_layer/pytorch_model.bin',
    map_location=device,
)
# strict=False tolerates key mismatches, but any missing key means that layer keeps
# its random initialisation, so report mismatches instead of ignoring them silently.
_load_result = projections.load_state_dict(_projection_state, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "Projection checkpoint does not match Projections(512, 3072): "
        f"missing={_load_result.missing_keys}, unexpected={_load_result.unexpected_keys}"
    )
del _projection_state, _load_result

projections = projections.to(device=device, dtype=torch.bfloat16)
# Freshly constructed nn.Modules are in training mode; switch to inference mode so
# layers such as dropout/batch-norm (if Projections has any) behave deterministically.
projections.eval()
# Base language model: Phi-3-mini (4k context) in bfloat16. device_map=None, so
# placement is done explicitly with .to(device) after the adapter merge below.
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
model_kwargs = {
    "use_cache": False,
    "trust_remote_code": True,
    "attn_implementation": 'eager',
    "torch_dtype": torch.bfloat16,
    "device_map": None,
}
base_model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)

# Fine-tuned PEFT adapter directory — change to the path where your model is saved.
new_model = "checkpoint_dir/checkpoint-6000/phi_model"

# Load the adapter, fold its weights into the base model (so inference runs on a
# plain model without the PEFT wrapper), then move the result to the target device.
model = (
    PeftModel.from_pretrained(base_model, new_model)
    .merge_and_unload()
    .to(device)
)
# Tokenizer of the base checkpoint, configured to match the fine-tuning setup.
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
tokenizer.padding_side = 'right'
tokenizer.model_max_length = 2048

# Pad with the unk token instead of eos, so padding is never confused with the
# end-of-sequence marker (which would otherwise lead to endless generation).
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

# Custom conversation format: each message is a dict with 'from' in
# {'system', 'human', 'gpt'} and a 'value' string, rendered as
# <|system|>/<|user|>/<|assistant|> ... <|end|> turns. With add_generation_prompt
# the output ends with an open '<|assistant|>' turn; otherwise it ends with eos_token.
tokenizer.chat_template = "{% for message in messages %}{% if message['from'] == 'system' %}{{'<|system|>' + message['value'] + '<|end|>'}}{% elif message['from'] ==\
'human' %}{{'<|user|>' + message['value'] + '<|end|>'}}{% elif message['from'] == 'gpt' %}{{'<|assistant|>' + message['value'] +\
'<|end|>'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}"
# Speech front end: Whisper (small) processor + model, used by infer() to transcribe
# uploaded audio. The model is not moved to `device` here.
whisper_model_name = "openai/whisper-small"
whisper_processor, whisper_model = (
    WhisperProcessor.from_pretrained(whisper_model_name),
    WhisperForConditionalGeneration.from_pretrained(whisper_model_name),
)

# Vision front end: CLIP ViT-B/32 loaded on `device`, plus its matching image transform.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
def _embed_prompt(text):
    """Wrap *text* in one human turn of the chat template and return its embeddings.

    Returns the LLM input embeddings of the tokenized prompt, shape
    (1, seq_len, hidden), on ``device``.
    """
    prompt = tokenizer.apply_chat_template(
        [{"from": "human", "value": text}], tokenize=False, add_generation_prompt=True
    )
    prompt_tokens = tokenizer(
        prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt"
    )['input_ids'].to(device)  # the model lives on `device`; token ids must too
    return model.get_input_embeddings()(prompt_tokens)


def _embed_image(path):
    """Encode the image at *path* with CLIP and project it into the LLM embedding space."""
    # Context manager closes the underlying file once preprocessing has read the pixels.
    with Image.open(path) as img:
        pixels = clip_preprocess(img).unsqueeze(0).to(device)
    image_features = clip_model.encode_image(pixels)
    return projections(image_features.to(torch.bfloat16)).unsqueeze(0)


def _transcribe(path):
    """Transcribe the audio file at *path* to text with Whisper (resampled to 16 kHz)."""
    speech, _rate = librosa.load(path, sr=16000)
    input_features = whisper_processor(
        speech, return_tensors="pt", sampling_rate=16000
    ).input_features.to(whisper_model.device)
    predicted_ids = whisper_model.generate(input_features)
    # batch_decode returns one string per batch item; there is exactly one item here.
    return whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]


def _greedy_generate(inputs_embeds, max_new_tokens):
    """Greedily decode up to *max_new_tokens* tokens continuing *inputs_embeds*.

    Stops early when the model emits the tokenizer's eos token or the ``<|end|>``
    turn terminator used by the chat template. Returns the generated token ids,
    excluding the stop token.
    """
    stop_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids('<|end|>')}
    embed_tokens = model.get_input_embeddings()
    generated = []
    for _ in range(max_new_tokens):
        # The model is loaded with use_cache=False, so the full sequence is re-run each step.
        logits = model(inputs_embeds=inputs_embeds)['logits']
        next_token = torch.argmax(logits[:, -1, :], dim=-1)  # shape (1,)
        token_id = next_token.item()
        if token_id in stop_ids:
            break
        generated.append(token_id)
        # Feed back the embedding of the token just predicted (shape (1, 1, hidden)).
        inputs_embeds = torch.cat(
            [inputs_embeds, embed_tokens(next_token.unsqueeze(0))], dim=1
        )
    return generated


def infer(message, history):
    """Gradio ChatInterface callback: answer a multimodal message with the Phi-3 model.

    Args:
        message: dict from the multimodal textbox with a 'text' string and a list of
            uploaded files under 'files' ('file' is also accepted). Each file entry may
            be a path string or a dict with a 'path' key. A plain string is treated as
            text-only input.
        history: previous chat turns; unused, each reply depends only on *message*.

    Returns:
        The decoded reply. If the message has no text and no supported file, a short
        prompt asking for input is returned instead.

    The prompt is assembled as [image embeds][audio-transcript prompt][text prompt].
    When several images (or audio files) are attached, only the last of each is used.
    """
    max_generate_length = 100
    if isinstance(message, str):
        message = {'text': message}
    files = message.get('files') or message.get('file') or []

    with torch.no_grad():
        image_embeds = None
        audio_embeds = None
        for item in files:
            path = item['path'] if isinstance(item, dict) else item
            suffix_check = path.lower()  # accept upper-case extensions such as .JPG
            if suffix_check.endswith(('.jpg', '.png', '.jpeg')):
                image_embeds = _embed_image(path)
            elif suffix_check.endswith(('.mp3', '.wav')):
                audio_embeds = _embed_prompt(_transcribe(path))

        segments = [e for e in (image_embeds, audio_embeds) if e is not None]
        text = message.get('text')
        if text:
            segments.append(_embed_prompt(text))
        if not segments:
            # torch.cat([]) would raise; nothing to answer.
            return "Please send some text, an image, or an audio file."

        combined_embeds = torch.cat(segments, dim=1)
        generated = _greedy_generate(combined_embeds, max_generate_length)
    return tokenizer.decode(generated, skip_special_tokens=True)
# Example prompts shown under the chat box. Multimodal examples use the same
# {"text": ..., "files": [...]} shape as the multimodal textbox value.
examples = [
    {'text': "I am planning to buy a dog and a cat. Suggest some breeds that get along with each other", 'files': []},
    {'text': "Explain biased coin flip", 'files': []},
    {'text': "I want to buy a house. Suggest some factors to consider while making final decision", 'files': []},
]

# multimodal=True needs a MultimodalTextbox: a plain Textbox offers no file upload
# and would hand infer() a bare string instead of a {'text', 'files'} dict.
demo = gr.ChatInterface(
    infer,
    chatbot=gr.Chatbot(height=600),
    textbox=gr.MultimodalTextbox(
        placeholder="How can I help you today",
        container=False,
        scale=7,
        # Only the formats infer() knows how to handle.
        file_types=['.jpg', '.jpeg', '.png', '.mp3', '.wav'],
    ),
    theme="soft",
    examples=examples,
    title="Phi-3 Multimodal Assistant",
    multimodal=True,
)
demo.launch()