import os

import gradio as gr
import torch
import torch.nn as nn
import whisperx
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer, CLIPVisionModel
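# Note: whisperx is a third-party package (pip install whisperx) and typically needs ffmpeg
# available on the system for audio decoding; the remaining dependencies come from
# gradio, torch, transformers, and peft.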
clip_model_name = "openai/clip-vit-base-patch32"
phi_model_name = "microsoft/phi-2"

# Tokenizer and processor: the tokenizer handles text for Phi-2, the processor preprocesses images for CLIP.
# Embedding sizes: clip_embed (768) is CLIP's hidden size and phi_embed (2560) is Phi-2's hidden size.
# Device: CUDA if a GPU is available, otherwise CPU.
# IMAGE_TOKEN_ID: token ID reserved as the image marker in the combined input sequence.
tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(clip_model_name)
tokenizer.pad_token = tokenizer.eos_token
IMAGE_TOKEN_ID = 23893  # token for the word "comment", repurposed as the image marker
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_embed = 768
phi_embed = 2560
compute_type = "float32"
audio_batch_size = 16
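# The pad token is set to <|endoftext|> (ID 50256) because Phi-2's tokenizer ships without a
# dedicated pad token. IMAGE_TOKEN_ID (23893) is assumed to be the same marker token used during
# fine-tuning, so it should not be changed independently of the checkpoints loaded below.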
# A simple residual block: LayerNorm followed by two linear layers with a GELU activation in between.
# The block applies a learned transformation to the projected embeddings, which helps stabilize
# learning and improve generalization.
class SimpleResBlock(nn.Module):
    def __init__(self, phi_embed):
        super().__init__()
        self.pre_norm = nn.LayerNorm(phi_embed)
        self.proj = nn.Sequential(
            nn.Linear(phi_embed, phi_embed),
            nn.GELU(),
            nn.Linear(phi_embed, phi_embed)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
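# Shape sketch (illustrative, assuming a 224x224 input image): CLIP ViT-B/32 emits 50 tokens
# (1 CLS + 49 patches); dropping the CLS token leaves a (1, 49, 768) tensor, the projection
# maps it to (1, 49, 2560), and SimpleResBlock preserves that shape while adding a learned
# residual refinement.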
# models
# CLIP vision model: pretrained on visual tasks, outputs image patch embeddings.
# Projection layer: maps clip_embed (768) to phi_embed (2560) so image features match Phi-2's embedding size.
# Residual block: the custom SimpleResBlock further refines the projected embeddings.
# Phi-2 model: the language model that handles text generation.
clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
resblock = SimpleResBlock(phi_embed).to(device)
phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name, trust_remote_code=True).to(device)
audio_model = whisperx.load_model(
    "tiny", device, compute_type=compute_type,
    asr_options={'max_new_tokens': 2048, 'clip_timestamps': True, 'hallucination_silence_threshold': 0.25}
)
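# Memory note (an assumption, not from the original code): Phi-2 has ~2.7B parameters, so the
# float32 load above needs roughly 11 GB of memory; on a suitable GPU, passing
# torch_dtype=torch.float16 to from_pretrained would roughly halve that, provided the projection,
# resblock, and image embeddings are kept in the same dtype.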
# load weights
# LoRA weights: the LoRA adapter trained during fine-tuning is loaded onto Phi-2 and merged into its base weights.
# Fine-tuned layers: the trained weights for the projection layer and residual block are loaded from local checkpoints.
model_to_merge = PeftModel.from_pretrained(phi_model, os.path.join(os.getcwd(), 'model_chkpt/lora_adaptor'))
merged_model = model_to_merge.merge_and_unload()
projection.load_state_dict(torch.load(os.path.join(os.getcwd(), 'model_chkpt/finetunned_projection.pth'), map_location=torch.device(device)))
resblock.load_state_dict(torch.load(os.path.join(os.getcwd(), 'model_chkpt/finetuned_resblock.pth'), map_location=torch.device(device)))
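# After merge_and_unload() the LoRA deltas are folded into the base weights, so merged_model is
# a plain Phi-2 causal LM with no adapter overhead at inference time. The projection and
# resblock state dicts are expected under model_chkpt/ alongside the LoRA adapter.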
# Image handling: extracts patch embeddings with CLIP and passes them through the projection layer and residual block.
# Audio handling: transcribes the audio with WhisperX, tokenizes the transcript, and embeds the tokens.
# Text handling: tokenizes the text query and embeds it.
# Response generation: the combined image/audio/text embeddings are fed to the merged model, which greedily
# predicts one token at a time for max_generate_length steps (no KV cache; the full sequence is re-run each step).
def model_generate_ans(img=None, img_audio=None, val_q=None):
    max_generate_length = 100
    val_combined_embeds = []
    with torch.no_grad():
        # image
        if img is not None:
            image_processed = processor(images=img, return_tensors="pt").to(device)
            clip_val_outputs = clip_model(**image_processed).last_hidden_state[:, 1:, :]  # drop the CLS token
            val_image_embeds = projection(clip_val_outputs)
            # cast to the LLM embedding dtype so the later concatenation and forward pass stay dtype-consistent
            val_image_embeds = resblock(val_image_embeds).to(merged_model.model.embed_tokens.weight.dtype)
            img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
            img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
            val_combined_embeds.append(val_image_embeds)
            val_combined_embeds.append(img_token_embeds)
        # audio
        if img_audio is not None:
            audio_result = audio_model.transcribe(img_audio, batch_size=audio_batch_size)
            audio_text = ''
            for seg in audio_result['segments']:
                audio_text += seg['text']
            audio_text = audio_text.strip()
            audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
            val_combined_embeds.append(audio_embeds)
        # text question
        if val_q:
            val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
            val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
            val_combined_embeds.append(val_q_embeds)
        val_combined_embeds = torch.cat(val_combined_embeds, dim=1)
        # greedy decoding: append the predicted token's embedding and re-run the model at each step
        predicted_caption = torch.full((1, max_generate_length), 50256).to(device)
        for g in range(max_generate_length):
            phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']  # (1, seq_len, 51200)
            predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)          # (1, 1, 51200)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)        # (1, 1)
            predicted_caption[:, g] = predicted_word_token.view(1, -1)
            next_token_embeds = merged_model.model.embed_tokens(predicted_word_token)       # (1, 1, 2560)
            val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
    predicted_captions_decoded = tokenizer.batch_decode(predicted_caption)[0]
    # keep only the text before the first <|endoftext|> marker
    result = predicted_captions_decoded.split('<|endoftext|>')[0]
    return result.strip()
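# Illustrative usage (hypothetical example, not part of the original app):
#   from PIL import Image
#   answer = model_generate_ans(img=Image.open("example.jpg"), img_audio=None, val_q="What is in this picture?")
# Generation is greedy and always runs for the full max_generate_length steps; anything the
# model emits after its first <|endoftext|> token is discarded by the split above.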
with gr.Blocks() as demo:
    # Add a custom CSS stylesheet within the Markdown component
    gr.Markdown(
        """
<style>
/* General Layout */
body {
    font-family: 'Arial', sans-serif;
    background-color: #ffe4e1;
    margin: 0;
    padding: 0;
}
/* Header */
h1, h2, h3 {
    text-align: center;
    color: #3a3a3a;
    font-weight: bold;
}
.gr-Markdown h1 {
    font-size: 28px;
    color: #a3d5d3; /* Soft pastel teal for the header */
}
/* Container and Columns */
.gr-row {
    display: flex;
    justify-content: center;
    margin: 20px 0;
}
.gr-column {
    flex: 1;
    margin: 0 10px;
    padding: 10px;
    box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.05);
    background-color: #f8f0fa; /* Pastel pink background for columns */
    border-radius: 8px;
}
/* Input Components */
.gr-Image, .gr-Audio, .gr-Text {
    width: 100%;
    margin-bottom: 15px;
    background-color: #fff5e1; /* Soft pastel yellow for inputs */
    border: 1px solid #e3e3e3;
    border-radius: 8px;
}
.gr-Image label, .gr-Audio label, .gr-Text label {
    font-size: 16px;
    font-weight: bold;
    color: #8b8b8b;
}
/* Submit Button */
.gr-Button {
    width: 100%;
    background-color: #b2c7e1; /* Pastel blue button */
    color: white;
    padding: 10px;
    font-size: 16px;
    border: none;
    border-radius: 5px;
    cursor: pointer;
    transition: background-color 0.3s ease;
}
.gr-Button:hover {
    background-color: #9db6d3; /* Darker pastel blue on hover */
}
/* Text Output */
.gr-Text {
    font-size: 16px;
    color: #333;
    min-height: 100px;
    padding: 10px;
    border: 1px solid #ddd;
    border-radius: 5px;
    background-color: #edf5e1; /* Light pastel green for the output text box */
}
/* Responsive Design */
@media (max-width: 768px) {
    .gr-row {
        flex-direction: column;
    }
    .gr-column {
        margin: 10px 0;
    }
}
</style>

# Engage with MultiModal GPT!
A seamless AI experience combining CLIP and Phi-2 models.
"""
    )
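    # Styling note (assumption): depending on the Gradio version, a <style> block inside
    # gr.Markdown may be stripped by HTML sanitization; passing the stylesheet through
    # gr.Blocks(css=...) is an alternative if the custom styles do not take effect.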
    # app GUI
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image', type="pil")
            img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
            img_question = gr.Text(label='Text Query')
        with gr.Column():
            img_answer = gr.Text(label='Answer')
    section_btn = gr.Button("Submit")
    section_btn.click(model_generate_ans, inputs=[img_input, img_audio, img_question], outputs=[img_answer])

demo.launch()