from transformers import AutoModelForCausalLM, AutoTokenizer, WhisperProcessor, WhisperForConditionalGeneration
from peft import PeftModel
import torch
import clip
from PIL import Image
from model import Projections
import gradio as gr
import librosa

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Projection layer that maps CLIP image embeddings (512-d) into the Phi-3-mini embedding space (3072-d)
projections = Projections(512, 3072)
projections.load_state_dict(
    torch.load('checkpoint_dir/checkpoint-6000/projection_layer/pytorch_model.bin', map_location=device),
    strict=False)
projections = projections.to(device)
projections = projections.to(torch.bfloat16)

# Base language model plus the fine-tuned adapter, merged for inference
checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
model_kwargs = dict(
    use_cache=False,
    trust_remote_code=True,
    attn_implementation='eager',
    torch_dtype=torch.bfloat16,
    device_map=None
)
base_model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
new_model = "checkpoint_dir/checkpoint-6000/phi_model"  # change to the path where your model is saved
model = PeftModel.from_pretrained(base_model, new_model)
model = model.merge_and_unload()
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
tokenizer.model_max_length = 2048
tokenizer.pad_token = tokenizer.unk_token  # use unk rather than eos token to prevent endless generation
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['from'] == 'system' %}{{'<|system|>' + message['value'] + '<|end|>'}}"
    "{% elif message['from'] == 'human' %}{{'<|user|>' + message['value'] + '<|end|>'}}"
    "{% elif message['from'] == 'gpt' %}{{'<|assistant|>' + message['value'] + '<|end|>'}}"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}"
)

# CLIP image encoder
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# Load Whisper model and processor for audio transcription
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)


def infer(message, history):
    max_generate_length = 100
    combined_embeds = []
    with torch.no_grad():
        if message['files']:
            projected_image_embeds = None
            audio_text_embeds = None
            for path in message['files']:
                if path.endswith(('.jpg', '.png', '.jpeg')):
                    # Encode the image with CLIP and project it into the Phi-3 embedding space
                    image = clip_preprocess(Image.open(path)).unsqueeze(0).to(device)
                    image_features = clip_model.encode_image(image)
                    projected_image_embeds = projections(image_features.to(torch.bfloat16)).unsqueeze(0)
                elif path.endswith(('.mp3', '.wav')):
                    # Load and transcribe the audio with Whisper, then embed the transcription as a user turn
                    speech, rate = librosa.load(path, sr=16000)
                    input_features = whisper_processor(speech, return_tensors="pt", sampling_rate=16000).input_features
                    predicted_ids = whisper_model.generate(input_features)
                    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
                    prompt = tokenizer.apply_chat_template([{"from": "human", "value": transcription}],
                                                           tokenize=False, add_generation_prompt=True)
                    prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048,
                                              return_tensors="pt")['input_ids'].to(device)
                    audio_text_embeds = model.get_input_embeddings()(prompt_tokens)
            if projected_image_embeds is not None:
                combined_embeds.append(projected_image_embeds)
            if audio_text_embeds is not None:
                combined_embeds.append(audio_text_embeds)
        if message['text']:
            prompt = tokenizer.apply_chat_template([{"from": "human", "value": message['text']}],
                                                   tokenize=False, add_generation_prompt=True)
            prompt_tokens = tokenizer(prompt, padding=True, truncation=True, max_length=2048,
                                      return_tensors="pt")['input_ids'].to(device)
            text_embeds = model.get_input_embeddings()(prompt_tokens)
            combined_embeds.append(text_embeds)

        # Concatenate image / audio / text embeddings along the sequence dimension
        combined_embeds = torch.cat(combined_embeds, dim=1)

        # Greedy token-by-token decoding; unfilled slots keep the pad token id and are skipped at decode time
        end_token_id = tokenizer.convert_tokens_to_ids('<|end|>')
        predicted_caption = torch.full((1, max_generate_length), tokenizer.pad_token_id).to(device)
        for g in range(max_generate_length):
            phi_output_logits = model(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            if predicted_word_token.item() in (end_token_id, tokenizer.eos_token_id):
                break  # stop once the model emits an end-of-turn or eos token
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            next_token_embeds = model.get_input_embeddings()(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)

    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]
    return predicted_captions_decoded


examples = [{'text': "I am planning to buy a dog and a cat. Suggest some breeds that get along with each other", 'files': []},
            {'text': "Explain biased coin flip", 'files': []},
            {'text': "I want to buy a house. Suggest some factors to consider while making final decision", 'files': []}]

gr.ChatInterface(infer,
                 chatbot=gr.Chatbot(height=600),
                 theme="soft",
                 examples=examples,
                 title="Phi-3 Multimodal Assistant",
                 multimodal=True).launch()
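
# ---------------------------------------------------------------------------
# NOTE: `model.py` (which defines `Projections`) is not part of this script.
# The class below is only a minimal reference sketch of the interface that
# `Projections(512, 3072)` implies -- mapping CLIP ViT-B/32 image features
# (512-d) into the Phi-3-mini embedding space (3072-d). It is an assumption
# for illustration; the trained projection layer loaded from the checkpoint
# above may use a different architecture (e.g. a deeper MLP with norms).
# ---------------------------------------------------------------------------
import torch.nn as nn


class Projections(nn.Module):  # hypothetical sketch, not the author's module
    def __init__(self, clip_embed_dim, phi_embed_dim):
        super().__init__()
        self.proj = nn.Linear(clip_embed_dim, phi_embed_dim)

    def forward(self, image_features):
        # (batch, 512) CLIP features -> (batch, 3072) pseudo word embeddings
        return self.proj(image_features)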