import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
import torch
from PIL import Image
import os

# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load model and tokenizer with optimizations for CPU deployment
def load_model():
    print("Loading model and tokenizer...")
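    # Note: repositories that define a custom architecture may additionally
    # require trust_remote_code=True in from_pretrained.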
    model = AutoModelForCausalLM.from_pretrained(
        "sagar007/Lava_phi",
        torch_dtype=torch.float32 if device == "cpu" else torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained("sagar007/Lava_phi")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
    print("Model and tokenizer loaded successfully!")
    return model, tokenizer, processor

# Load models
model, tokenizer, processor = load_model()
# For text-only generation
def generate_text(prompt, max_length=128):
    try:
        inputs = tokenizer(f"human: {prompt}\ngpt:", return_tensors="pt").to(device)

        # Generate with low memory footprint settings
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the model's response
        if "gpt:" in generated_text:
            generated_text = generated_text.split("gpt:", 1)[1].strip()

        return generated_text
    except Exception as e:
        return f"Error generating text: {str(e)}"
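# Illustrative call (not executed at import time); the exact output varies with sampling:
#   generate_text("What is artificial intelligence?", max_length=64)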
# For image and text processing
def process_image_and_prompt(image, prompt, max_length=128):
    try:
        if image is None:
            return "No image provided. Please upload an image."

        # Process image
        image_tensor = processor(images=image, return_tensors="pt").pixel_values.to(device)

        # Tokenize input with image token
        inputs = tokenizer(f"human: <image>\n{prompt}\ngpt:", return_tensors="pt").to(device)

        # Generate with memory optimizations
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                images=image_tensor,
                max_new_tokens=max_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the model's response
        if "gpt:" in generated_text:
            generated_text = generated_text.split("gpt:", 1)[1].strip()

        return generated_text
    except Exception as e:
        return f"Error processing image: {str(e)}"
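# Illustrative call (not executed at import time); "photo.jpg" is a hypothetical path:
#   img = Image.open("photo.jpg").convert("RGB")
#   process_image_and_prompt(img, "Describe this image in detail.")
# Note: passing images= to generate() relies on this checkpoint exposing a
# multimodal generate; a plain text-only causal LM would not accept it.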
# Create Gradio Interface
with gr.Blocks(title="LLaVA-Phi: Vision-Language Model") as demo:
    gr.Markdown("# LLaVA-Phi: Vision-Language Model")
    gr.Markdown("This model can generate text responses from text prompts or analyze images with text prompts.")

    with gr.Tab("Text Generation"):
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(label="Enter your prompt", lines=3, placeholder="What is artificial intelligence?")
                text_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length")
                text_button = gr.Button("Generate")
            text_output = gr.Textbox(label="Generated response", lines=8)
        text_button.click(
            fn=generate_text,
            inputs=[text_input, text_max_length],
            outputs=text_output
        )

    with gr.Tab("Image + Text Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload an image")
                image_text_input = gr.Textbox(label="Enter your prompt about the image",
                                              lines=2,
                                              placeholder="Describe this image in detail.")
                image_max_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Maximum response length")
                image_button = gr.Button("Analyze")
            image_output = gr.Textbox(label="Model response", lines=8)
        image_button.click(
            fn=process_image_and_prompt,
            inputs=[image_input, image_text_input, image_max_length],
            outputs=image_output
        )

    # Example inputs for each tab
    gr.Examples(
        examples=["What is the advantage of vision-language models?",
                  "Explain how multimodal AI models work.",
                  "Tell me a short story about robots."],
        inputs=text_input
    )

    # Add examples for image tab if you have example images
    # gr.Examples(
    #     examples=[["example1.jpg", "What's in this image?"]],
    #     inputs=[image_input, image_text_input]
    # )
# Launch the app with memory optimizations
if __name__ == "__main__":
    # Memory cleanup before launch
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Limit CPU thread usage to reduce memory
    # (most effective when set before torch spawns its thread pools)
    os.environ["OMP_NUM_THREADS"] = "4"

    # Enable request queuing and launch with minimal resource usage
    demo.queue()
    demo.launch(
        share=True,  # Set to False in production
        max_threads=4,
        show_error=True,
    )