import gradio as gr
import os
import threading
import base64
from io import BytesIO
from groq import Groq
from diffusers import StableDiffusionPipeline
import torch
# 🔹 Initialize Groq API client (free tier)
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
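# A minimal fail-fast sketch (an addition, not part of the original flow):
# surface a missing key at startup instead of at the first API call.
if not os.getenv("GROQ_API_KEY"):
    raise RuntimeError("GROQ_API_KEY is not set; add it as a Space secret.")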
# 🔹 Load text-to-image models hosted on the Hub (gr.load wraps them as callable interfaces)
model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
model2 = gr.load("models/Purz/face-projection")

# ✅ Load Stable Diffusion properly with diffusers (note the correct repo id is
# "stabilityai/stable-diffusion-2-1", and the device check makes the comment
# "if available" actually true instead of hard-requiring CUDA)
device = "cuda" if torch.cuda.is_available() else "cpu"
model3 = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
# 🔹 Stop event for threading (cleared at the start of each generation;
# see the stop-aware generation sketch further down)
stop_event = threading.Event()
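# An illustrative helper (an assumption on my part, not wired to any button
# in this UI) showing how a "Stop" control could set the event:
def stop_generation():
    stop_event.set()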
# 🔹 Convert a PIL image to base64
def pil_to_base64(pil_image, image_format="jpeg"):
    buffered = BytesIO()
    # JPEG cannot encode an alpha channel, so normalize to RGB first
    pil_image.convert("RGB").save(buffered, format=image_format)
    base64_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return base64_string, image_format
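# Usage sketch (the file name is hypothetical):
#
#     from PIL import Image
#     b64, fmt = pil_to_base64(Image.open("photo.png"))
#     data_url = f"data:image/{fmt};base64,{b64}"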
# 🔹 Visual question answering (VQA) via the Groq chat API.
# NOTE: "mixtral-8x7b-32768" is a text-only model on Groq; sending image
# content requires swapping in a vision-capable model from Groq's model list.
def answer_question(text, image, temperature=0.0, max_tokens=1024):
    base64_string, file_format = pil_to_base64(image)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                # OpenAI-compatible APIs expect the data URL nested under "url"
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/{file_format};base64,{base64_string}"},
                },
            ],
        }
    ]
    chat_response = groq_client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return chat_response.choices[0].message.content
# 🔹 Generate three images from one prompt (multi-output)
def generate_images(prompt):
    stop_event.clear()
    # Interfaces returned by gr.load are callable like plain functions
    img1 = model1(prompt)
    img2 = model2(prompt)
    # ✅ Diffusers pipelines return a result object; take the first image
    img3 = model3(prompt).images[0]
    return img1, img2, img3
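# A minimal sketch of a stop-aware variant (my assumption about how the
# otherwise-unused stop_event was meant to be used): check the event between
# the sequential model calls so stop_generation() can abort a run early,
# returning None for any image not yet produced.
def generate_images_stoppable(prompt):
    stop_event.clear()
    results = []
    for model in (model1, model2):
        if stop_event.is_set():
            return tuple(results + [None] * (3 - len(results)))
        results.append(model(prompt))
    if stop_event.is_set():
        return tuple(results + [None])
    results.append(model3(prompt).images[0])
    return tuple(results)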
# 🔹 Clear all fields (question, image, answer, and the three generated images)
def clear_all():
    return "", None, "", None, None, None
# 🔹 Set up the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🚀 AI Tutor, VQA & Image Generation")

    # 🔹 Section 1: visual question answering (Groq)
    gr.Markdown("## 🖼️ Visual Question Answering (Mixtral-8x7B)")
    with gr.Row():
        with gr.Column(scale=2):
            question = gr.Textbox(placeholder="Ask about the image...", lines=2)
            image = gr.Image(type="pil")
            with gr.Row():
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=2048, value=1024, step=128)
        with gr.Column(scale=3):
            output_text = gr.Textbox(lines=10, label="Mixtral VQA Response")
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                submit_btn_vqa = gr.Button("Submit", variant="primary")

    # 🔹 Section 2: image generation (3 outputs)
    gr.Markdown("## 🎨 AI-Generated Images (3 Variations)")
    with gr.Row():
        prompt = gr.Textbox(placeholder="Describe the image you want...", lines=2)
        generate_btn = gr.Button("Generate Images", variant="primary")
    with gr.Row():
        image1 = gr.Image(label="Image 1")
        image2 = gr.Image(label="Image 2")
        image3 = gr.Image(label="Image 3")
    # 🔹 VQA processing
    submit_btn_vqa.click(
        fn=answer_question,
        inputs=[question, image, temperature, max_tokens],
        outputs=[output_text],
    )

    # 🔹 Image generation processing
    generate_btn.click(
        fn=generate_images,
        inputs=[prompt],
        outputs=[image1, image2, image3],
    )

    # 🔹 Clear all inputs
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[question, image, output_text, image1, image2, image3],
    )
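# Optional sketch: on Spaces, long-running generations behave better behind
# Gradio's request queue; max_size here is an illustrative value.
# demo.queue(max_size=20)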
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)