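"""Gradio demo for the sagar007/Lava_phi (LLaVA-Phi) vision-language model.

Provides two tabs: text-only generation and image + text analysis.
Dependencies (inferred from the imports below): gradio, transformers, torch, Pillow.
"""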
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
import torch
from PIL import Image
import os
# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load model, tokenizer, and CLIP image processor with optimizations for CPU deployment
def load_model():
    print("Loading model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(
        "sagar007/Lava_phi",
        torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained("sagar007/Lava_phi")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print("Model and tokenizer loaded successfully!")
    return model, tokenizer, processor

# Load models
model, tokenizer, processor = load_model()
# For text-only generation
def generate_text(prompt, max_length=128):
    try:
        inputs = tokenizer(f"human: {prompt}\ngpt:", return_tensors="pt").to(device)

        # Generate with low memory footprint settings
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the model's response (text after the "gpt:" marker)
        if "gpt:" in generated_text:
            generated_text = generated_text.split("gpt:", 1)[1].strip()

        return generated_text
    except Exception as e:
        return f"Error generating text: {str(e)}"
# For image and text processing
def process_image_and_prompt(image, prompt, max_length=128):
    try:
        if image is None:
            return "No image provided. Please upload an image."

        # Encode the image with the CLIP processor
        image_tensor = processor(images=image, return_tensors="pt").pixel_values.to(device)

        # Tokenize input with the image placeholder token
        inputs = tokenizer(f"human: <image>\n{prompt}\ngpt:", return_tensors="pt").to(device)

        # Generate with memory optimizations
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                images=image_tensor,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the model's response (text after the "gpt:" marker)
        if "gpt:" in generated_text:
            generated_text = generated_text.split("gpt:", 1)[1].strip()

        return generated_text
    except Exception as e:
        return f"Error processing image: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="LLaVA-Phi: Vision-Language Model") as demo:
    gr.Markdown("# LLaVA-Phi: Vision-Language Model")
    gr.Markdown("This model can generate text responses from text prompts or analyze images with text prompts.")

    with gr.Tab("Text Generation"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Enter your prompt", lines=3, placeholder="What is artificial intelligence?")
                text_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length")
                text_button = gr.Button("Generate")
        text_output = gr.Textbox(label="Generated response", lines=8)
        text_button.click(
            fn=generate_text,
            inputs=[text_input, text_max_length],
            outputs=text_output
        )

    with gr.Tab("Image + Text Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload an image")
                image_text_input = gr.Textbox(
                    label="Enter your prompt about the image",
                    lines=2,
                    placeholder="Describe this image in detail.",
                )
                image_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length")
                image_button = gr.Button("Analyze")
        image_output = gr.Textbox(label="Model response", lines=8)
        image_button.click(
            fn=process_image_and_prompt,
            inputs=[image_input, image_text_input, image_max_length],
            outputs=image_output
        )

    # Example inputs for the text tab
    gr.Examples(
        examples=["What is the advantage of vision-language models?",
                  "Explain how multimodal AI models work.",
                  "Tell me a short story about robots."],
        inputs=text_input
    )

    # Add examples for the image tab if you have example images
    # gr.Examples(
    #     examples=[["example1.jpg", "What's in this image?"]],
    #     inputs=[image_input, image_text_input]
    # )
# Launch the app with memory optimizations
if __name__ == "__main__":
    # Memory cleanup before launch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Limit CPU threads to reduce memory usage (ideally set before importing torch)
    os.environ["OMP_NUM_THREADS"] = "4"

    # Launch with minimal resource usage
    # Note: enable_queue is a Gradio 3.x launch argument; newer Gradio versions use demo.queue() instead
    demo.launch(
        share=True,  # Set to False in production
        enable_queue=True,
        max_threads=4,
        show_error=True
    )