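# Gradio demo for a multimodal (text / image / audio) assistant built on Phi-3-mini:
# CLIP image features are projected into the language model's embedding space by a
# trained projection layer, and audio is transcribed with Whisper before being fed
# to the model as a regular chat prompt.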
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel
import torch
import clip
from PIL import Image
import torch.nn as nn
from model import Projections
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import gradio as gr
import librosa
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
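# Projection layer that maps 512-d CLIP image features into Phi-3-mini's 3072-d embedding space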
projections = Projections(512, 3072)
projections.load_state_dict(torch.load('checkpoint_dir/checkpoint-6000/projection_layer/pytorch_model.bin', map_location=device), strict=False)
projections = projections.to(device)
projections = projections.to(torch.bfloat16)
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation='eager',
    torch_dtype=torch.bfloat16,
    device_map=None,
)
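# Load the Phi-3 base model, attach the fine-tuned LoRA adapter and merge it into the base weights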
base_model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
new_model = "checkpoint_dir/checkpoint-6000/phi_model" # change to the path where your model is saved
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
tokenizer.model_max_length = 2048
tokenizer.pad_token = tokenizer.unk_token # use unk rather than eos token to prevent endless generation
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'
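# Chat template that renders 'system' / 'human' / 'gpt' turns with Phi-3's
# <|system|> / <|user|> / <|assistant|> ... <|end|> tags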
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['from'] == 'system' %}{{ '<|system|>' + message['value'] + '<|end|>' }}"
    "{% elif message['from'] == 'human' %}{{ '<|user|>' + message['value'] + '<|end|>' }}"
    "{% elif message['from'] == 'gpt' %}{{ '<|assistant|>' + message['value'] + '<|end|>' }}"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}"
)
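# CLIP ViT-B/32 image encoder (produces the 512-d features consumed by the projection layer)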
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
# Load Whisper model and processor
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name).to(device)
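# Build a single sequence of input embeddings from the optional image, optional audio
# transcription and optional text prompt, then generate the reply token by token.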
def infer(message, history):
    max_generate_length = 100
    combined_embeds = []
    with torch.no_grad():
        if message['files']:
            projected_image_embeds = None
            audio_text_embeds = None
            for path in message['files']:
                if path.endswith(('.jpg', '.png', '.jpeg')):
                    # Encode the image with CLIP and project it into the language model's embedding space
                    image = clip_preprocess(Image.open(path)).unsqueeze(0).to(device)
                    image_features = clip_model.encode_image(image)
                    projected_image_embeds = projections(image_features.to(torch.bfloat16)).unsqueeze(0)
                elif path.endswith(('.mp3', '.wav')):
                    # Load the audio, transcribe it with Whisper, then embed the transcription as a chat prompt
                    speech, _ = librosa.load(path, sr=16000)
                    input_features = whisper_processor(speech, return_tensors="pt", sampling_rate=16000).input_features.to(whisper_model.device)
                    predicted_ids = whisper_model.generate(input_features)
                    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                    prompt = tokenizer.apply_chat_template([{"from": "human", "value": transcription}], tokenize=False, add_generation_prompt=True)
                    prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids'].to(device)
                    audio_text_embeds = model.get_input_embeddings()(prompt_tokens)
            if projected_image_embeds is not None:
                combined_embeds.append(projected_image_embeds)
            if audio_text_embeds is not None:
                combined_embeds.append(audio_text_embeds)
        if message['text']:
            prompt = tokenizer.apply_chat_template([{"from": "human", "value": message['text']}], tokenize=False, add_generation_prompt=True)
            prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048, return_tensors="pt")['input_ids'].to(device)
            text_embeds = model.get_input_embeddings()(prompt_tokens)
            combined_embeds.append(text_embeds)
        combined_embeds = torch.cat(combined_embeds, dim=1)
        # Greedy decoding: pick the most likely next token, append its embedding and repeat
        end_token_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids('<|end|>')}
        predicted_caption = torch.full((1, max_generate_length), tokenizer.pad_token_id).to(device)
        for g in range(max_generate_length):
            phi_output_logits = model(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            if predicted_word_token.item() in end_token_ids:
                break  # stop once the model emits an end-of-turn / end-of-text token
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            next_token_embeds = model.get_input_embeddings()(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]
    return predicted_captions_decoded
examples = [
    {'text': "I am planning to buy a dog and a cat. Suggest some breeds that get along with each other", 'files': []},
    {'text': "Explain biased coin flip", 'files': []},
    {'text': "I want to buy a house. Suggest some factors to consider while making final decision", 'files': []},
]
gr.ChatInterface(infer, chatbot=gr.Chatbot(height=600), theme="soft", examples=examples,
                 title="Phi-3 Multimodal Assistant", multimodal=True).launch()