# app.py -- scraped file-viewer header: "SameerArz's picture", "Update app.py",
# commit 781c41b (verified), 4.08 kB.
import gradio as gr
import os
import threading
import base64
from io import BytesIO
from groq import Groq
from diffusers import StableDiffusionPipeline
import torch
# 🔹 Initialize Groq API Client (FREE)
# Groq API client; the key is read from the GROQ_API_KEY environment variable.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Text-to-image models loaded through Gradio's model loader.
model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
model2 = gr.load("models/Purz/face-projection")

# Stable Diffusion pipeline loaded locally through diffusers.
# Use the GPU with half precision only when CUDA is actually present; otherwise
# fall back to CPU with full precision instead of crashing on `.to("cuda")`.
# NOTE(review): the repo id below is kept exactly as it was; confirm that it
# resolves on the Hub (the Stability AI release is usually referenced as
# "stabilityai/stable-diffusion-2-1").
_sd_device = "cuda" if torch.cuda.is_available() else "cpu"
_sd_dtype = torch.float16 if _sd_device == "cuda" else torch.float32
model3 = StableDiffusionPipeline.from_pretrained(
    "stablediffusion/stable-diffusion-2-1",
    torch_dtype=_sd_dtype,
).to(_sd_device)

# Event that is cleared at the start of every image-generation request.
stop_event = threading.Event()
# 🔹 Convert PIL image to Base64
def pil_to_base64(pil_image, image_format="jpeg"):
    """Serialize an image to a base64-encoded string.

    Args:
        pil_image: Image object exposing ``mode``, ``convert()`` and
            ``save(fp, format=...)`` (a ``PIL.Image.Image``).
        image_format: Format name handed to ``save``; defaults to ``"jpeg"``.

    Returns:
        A ``(base64_string, image_format)`` tuple, where ``base64_string`` is
        the UTF-8 decoded base64 encoding of the saved image bytes.
    """
    # Modes the JPEG writer can encode directly. Anything else (e.g. RGBA or
    # palette images from PNG uploads) makes ``save`` raise, so convert first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image_format.lower() == "jpeg" and pil_image.mode not in jpeg_modes:
        pil_image = pil_image.convert("RGB")

    buffered = BytesIO()
    pil_image.save(buffered, format=image_format)
    base64_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return base64_string, image_format
# 🔹 Function for Visual Question Answering (VQA) with Mixtral-8x7B
def answer_question(text, image, temperature=0.0, max_tokens=1024,
                    model="mixtral-8x7b-32768"):
    """Send a question, optionally with an image, to the Groq chat API.

    Args:
        text: The user's question.
        image: Image to attach, or ``None`` to send the question alone.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion token limit; coerced to ``int`` because UI
            sliders may deliver a float.
        model: Groq model id. The default is the id this function has always
            used.

    Returns:
        The text content of the first choice in the API response.
    """
    # NOTE(review): the ``image_url`` part needs a vision-capable model;
    # confirm the default model id accepts image input or pass another one.
    content = [{"type": "text", "text": text}]

    # Submitting without an upload yields ``image is None``; only attach the
    # image part when there is something to encode.
    if image is not None:
        base64_string, file_format = pil_to_base64(image)
        data_url = f"data:image/{file_format};base64,{base64_string}"
        # The chat-completions schema expects an object with a "url" field
        # here, not a bare string.
        content.append({"type": "image_url", "image_url": {"url": data_url}})

    chat_response = groq_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": content}],
        temperature=temperature,
        max_tokens=int(max_tokens),
    )
    return chat_response.choices[0].message.content
# 🔹 Function to Generate Three Images (Multi-Output)
def generate_images(prompt):
    """Produce one image from each of the three loaded models.

    Args:
        prompt: Text description passed unchanged to every model.

    Returns:
        A 3-tuple ``(img1, img2, img3)``: the results of ``model1`` and
        ``model2``, followed by the first image of the ``model3`` pipeline
        output.
    """
    # Reset the module-level stop flag before starting a new request.
    stop_event.clear()

    img1 = model1.predict(prompt)
    img2 = model2.predict(prompt)

    # The pipeline call returns an object carrying an ``images`` list; only
    # the first entry is used.
    img3 = model3(prompt).images[0]

    return img1, img2, img3
# 🔹 Clear All Fields
def clear_all():
    """Return reset values for the components wired to the Clear button.

    Returns:
        A 6-tuple: an empty string for each of the two text fields and
        ``None`` for each of the four image slots, in the order
        ``("", None, "", None, None, None)``.
    """
    return "", None, "", None, None, None
# 🔹 Set up Gradio Interface
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a source whose indentation had been flattened; confirm the intended
# row/column structure against the running UI.
with gr.Blocks() as demo:
    gr.Markdown("# 🎓 AI Tutor, VQA & Image Generation")

    # Section 1: visual question answering (Groq).
    gr.Markdown("## 🖼️ Visual Question Answering (Mixtral-8x7B)")
    with gr.Row():
        # Left column: question, image upload and sampling controls.
        with gr.Column(scale=2):
            question = gr.Textbox(placeholder="Ask about the image...", lines=2)
            image = gr.Image(type="pil")
            with gr.Row():
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=2048, value=1024, step=128)
        # Right column: model answer.
        with gr.Column(scale=3):
            output_text = gr.Textbox(lines=10, label="Mixtral VQA Response")
    with gr.Row():
        clear_btn = gr.Button("Clear", variant="secondary")
        submit_btn_vqa = gr.Button("Submit", variant="primary")

    # Section 2: image generation (three outputs).
    gr.Markdown("## 🎨 AI-Generated Images (3 Variations)")
    with gr.Row():
        prompt = gr.Textbox(placeholder="Describe the image you want...", lines=2)
        generate_btn = gr.Button("Generate Images", variant="primary")
    with gr.Row():
        image1 = gr.Image(label="Image 1")
        image2 = gr.Image(label="Image 2")
        image3 = gr.Image(label="Image 3")

    # VQA: question + image + sampling settings -> answer text.
    submit_btn_vqa.click(
        fn=answer_question,
        inputs=[question, image, temperature, max_tokens],
        outputs=[output_text]
    )

    # Image generation: prompt -> three image slots.
    generate_btn.click(
        fn=generate_images,
        inputs=[prompt],
        outputs=[image1, image2, image3]
    )

    # Clear: resets the six listed components (the order must match the
    # tuple returned by clear_all).
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[question, image, output_text, image1, image2, image3]
    )
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)