import os
import torch
import torch.nn as nn
import gradio as gr
import whisper  # required for whisper.load_model() in load_models()
import clip
import spaces
from PIL import Image
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MultimodalPhi(nn.Module):
    """Wraps the PEFT-adapted phi-1.5 model and prepends a projected CLIP image
    embedding as a single prefix token."""

    def __init__(self, phi_model):
        super().__init__()
        self.phi_model = phi_model
        # CLIP ViT-B/32 image features are 512-dimensional; project them to phi's hidden size.
        self.embedding_projection = nn.Linear(512, phi_model.config.hidden_size)

    def forward(self, image_embeddings, input_ids, attention_mask):
        # Keep dtypes consistent: CLIP features may be float16 on GPU while the projection is float32.
        image_embeddings = image_embeddings.to(self.embedding_projection.weight.dtype)
        projected_embeddings = self.embedding_projection(image_embeddings).unsqueeze(1)
        inputs_embeds = self.phi_model.get_input_embeddings()(input_ids)
        combined_embeds = torch.cat([projected_embeddings.to(inputs_embeds.dtype), inputs_embeds], dim=1)
        # Extend the attention mask with a 1 for the prepended image token.
        extended_attention_mask = torch.cat(
            [torch.ones(attention_mask.shape[0], 1).to(attention_mask.device), attention_mask], dim=1
        )
        outputs = self.phi_model(inputs_embeds=combined_embeds, attention_mask=extended_attention_mask)
        return outputs.logits[:, 1:, :]  # Exclude the image-token position from the output
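# Rough shape sketch for MultimodalPhi.forward (assuming batch size B, sequence length T,
# and hidden size H = phi_model.config.hidden_size):
#   image_embeddings: (B, 512)      -> projected_embeddings: (B, 1, H)
#   input_ids:        (B, T)        -> inputs_embeds:        (B, T, H)
#   combined_embeds:  (B, T + 1, H) with the image embedding as the first "token"
#   returned logits:  (B, T, vocab) after slicing off the image-token position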
def load_models():
    try:
        print("Loading models...")
        peft_model_name = "sagar007/phi-1_5-finetuned"

        # Manually load and create LoraConfig, ignoring unknown arguments
        config_dict = LoraConfig.from_pretrained(peft_model_name).to_dict()
        config_dict.pop('layer_replication', None)  # Remove 'layer_replication' if present
        lora_config = LoraConfig(**config_dict)
        print("PEFT config loaded")

        base_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-1_5",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        print("Base model loaded")

        phi_model = get_peft_model(base_model, lora_config)
        phi_model.load_state_dict(
            torch.load(peft_model_name + '/adapter_model.bin', map_location=device),
            strict=False,
        )
        print("PEFT model loaded")

        multimodal_model = MultimodalPhi(phi_model)
        multimodal_model.load_state_dict(torch.load('multimodal_phi_small_gpu.pth', map_location=device))
        multimodal_model.to(device)
        multimodal_model.eval()
        print("Multimodal model loaded")

        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
        tokenizer.pad_token = tokenizer.eos_token
        print("Tokenizer loaded")

        audio_model = whisper.load_model("base").to(device)
        print("Audio model loaded")

        clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
        print("CLIP model loaded")

        return multimodal_model, tokenizer, audio_model, clip_model, clip_preprocess
    except Exception as e:
        print(f"Error in load_models: {str(e)}")
        raise
model, tokenizer, audio_model, clip_model, clip_preprocess = load_models()
@spaces.GPU
def get_clip_embedding(image):
    # gr.Image(type="pil") already yields a PIL image; only open it when given a file path.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    image = clip_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
    return image_features.squeeze(0)
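# Note: clip.load("ViT-B/32") produces 512-dimensional image features, which is why
# MultimodalPhi.embedding_projection above uses in_features=512.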
@spaces.GPU
def process_text(text):
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding='max_length').to(device)
        dummy_image_embedding = torch.zeros(512).to(device)  # Dummy image embedding for text-only input
        with torch.no_grad():
            outputs = model(dummy_image_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
        return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
    except Exception as e:
        return f"Error in process_text: {str(e)}"
@spaces.GPU
def process_image(image):
    try:
        clip_embedding = get_clip_embedding(image)
        prompt = "Describe this image:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128, padding='max_length').to(device)
        with torch.no_grad():
            outputs = model(clip_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
        return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
    except Exception as e:
        return f"Error in process_image: {str(e)}"
@spaces.GPU
def process_audio(audio):
    try:
        result = audio_model.transcribe(audio)
        transcription = result["text"]
        return process_text(f"Transcription: {transcription}\nPlease respond to this:")
    except Exception as e:
        return f"Error in process_audio: {str(e)}"
def chat(message, image, audio):
    # Route to the audio, image, or text pipeline, in that order of priority.
    if audio is not None:
        return process_audio(audio)
    elif image is not None:
        return process_image(image)
    else:
        return process_text(message)
iface = gr.Interface(
    fn=chat,
    inputs=[
        gr.Textbox(placeholder="Enter text here..."),
        gr.Image(type="pil"),
        gr.Audio(type="filepath"),
    ],
    outputs="text",
    title="Multi-Modal Assistant",
    description="Chat with an AI using text, images, or audio!",
)
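# Quick local smoke test (hypothetical inputs; bypasses the Gradio UI):
#   print(process_text("What is the capital of France?"))
#   print(process_image("example.jpg"))   # a file path or a PIL.Image both work
#   print(process_audio("example.wav"))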
if __name__ == "__main__":
    print("Starting Gradio interface...")
    iface.launch(share=True)