import os
import torch
import torch.nn as nn
import gradio as gr
import whisper  # required by whisper.load_model() below; missing from the original imports
import clip
import spaces  # Hugging Face Spaces helper; imported but not otherwise used in this file
from PIL import Image
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MultimodalPhi(nn.Module):
    def __init__(self, phi_model):
        super().__init__()
        self.phi_model = phi_model
        self.embedding_projection = nn.Linear(512, phi_model.config.hidden_size)

    def forward(self, image_embeddings, input_ids, attention_mask):
        projected_embeddings = self.embedding_projection(image_embeddings).unsqueeze(1)
        inputs_embeds = self.phi_model.get_input_embeddings()(input_ids)
        combined_embeds = torch.cat([projected_embeddings, inputs_embeds], dim=1)
        extended_attention_mask = torch.cat(
            [torch.ones(attention_mask.shape[0], 1).to(attention_mask.device), attention_mask], dim=1
        )
        outputs = self.phi_model(inputs_embeds=combined_embeds, attention_mask=extended_attention_mask)
        return outputs.logits[:, 1:, :]  # Exclude the image token from the output
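
# Shape walk-through for forward(), assuming batch size B, text length T and
# hidden size H = phi_model.config.hidden_size:
#   image_embeddings: (B, 512)  -> projected_embeddings: (B, 1, H)
#   input_ids:        (B, T)    -> inputs_embeds:        (B, T, H)
#   combined_embeds:  (B, T + 1, H), with the projected CLIP vector acting as a
#   single "soft" image token prepended to the text embeddings, so the returned
#   logits are sliced back down to the T text positions.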
def load_models():
    try:
        print("Loading models...")
        peft_model_name = "sagar007/phi-1_5-finetuned"

        # Manually rebuild the LoraConfig, dropping arguments that older
        # versions of peft do not recognise.
        config_dict = LoraConfig.from_pretrained(peft_model_name).to_dict()
        config_dict.pop('layer_replication', None)  # Remove 'layer_replication' if present
        lora_config = LoraConfig(**config_dict)
        print("PEFT config loaded")

        base_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-1_5",
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        print("Base model loaded")

        phi_model = get_peft_model(base_model, lora_config)
        phi_model.load_state_dict(
            torch.load(peft_model_name + '/adapter_model.bin', map_location=device), strict=False
        )
        print("PEFT model loaded")

        multimodal_model = MultimodalPhi(phi_model)
        multimodal_model.load_state_dict(torch.load('multimodal_phi_small_gpu.pth', map_location=device))
        multimodal_model.to(device)
        multimodal_model.eval()
        print("Multimodal model loaded")

        tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
        tokenizer.pad_token = tokenizer.eos_token
        print("Tokenizer loaded")

        audio_model = whisper.load_model("base").to(device)
        print("Audio model loaded")

        clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
        print("CLIP model loaded")

        return multimodal_model, tokenizer, audio_model, clip_model, clip_preprocess
    except Exception as e:
        print(f"Error in load_models: {str(e)}")
        raise

model, tokenizer, audio_model, clip_model, clip_preprocess = load_models()
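# The models are loaded once at import time, so every Gradio request below
# reuses the same weights instead of reloading them per call.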

def get_clip_embedding(image):
    # gr.Image(type="pil") already hands us a PIL image; only open the input
    # if a file path was passed instead.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    image = clip_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(image)
    return image_features.squeeze(0)
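
# ViT-B/32 produces 512-dimensional image features, which matches the input
# size expected by MultimodalPhi.embedding_projection (nn.Linear(512, hidden_size)).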

def process_text(text):
    try:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding='max_length').to(device)
        dummy_image_embedding = torch.zeros(512).to(device)  # Dummy image embedding for text-only input
        with torch.no_grad():
            outputs = model(dummy_image_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
        return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
    except Exception as e:
        return f"Error in process_text: {str(e)}"

def process_image(image):
    try:
        clip_embedding = get_clip_embedding(image)
        prompt = "Describe this image:"
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128, padding='max_length').to(device)
        with torch.no_grad():
            outputs = model(clip_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
        return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
    except Exception as e:
        return f"Error in process_image: {str(e)}"

def process_audio(audio):
    try:
        result = audio_model.transcribe(audio)
        transcription = result["text"]
        return process_text(f"Transcription: {transcription}\nPlease respond to this:")
    except Exception as e:
        return f"Error in process_audio: {str(e)}"
def chat(message, image, audio):
    if audio is not None:
        return process_audio(audio)
    elif image is not None:
        return process_image(image)
    else:
        return process_text(message)

iface = gr.Interface(
    fn=chat,
    inputs=[
        gr.Textbox(placeholder="Enter text here..."),
        gr.Image(type="pil"),
        gr.Audio(type="filepath"),
    ],
    outputs="text",
    title="Multi-Modal Assistant",
    description="Chat with an AI using text, images, or audio!",
)

if __name__ == "__main__":
    print("Starting Gradio interface...")
    iface.launch(share=True)