import gradio as gr
import os
import threading
import base64
from io import BytesIO
from groq import Groq

# 🔹 Initialize Groq API Client (FREE)
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# 🔹 Load Text-to-Image Models (Restoring Multi-Image Generation)
model1 = gr.load("models/prithivMLmods/SD3.5-Turbo-Realism-2.0-LoRA")
model2 = gr.load("models/Purz/face-projection")
model3 = gr.load("models/stabilityai/stable-diffusion-xl-base-1.0")

# 🔹 Stop Event for Threading (cleared before each generation; not set anywhere yet)
stop_event = threading.Event()

# 🔹 Convert PIL image to Base64
def pil_to_base64(pil_image, image_format="jpeg"):
    buffered = BytesIO()
    # JPEG has no alpha channel, so normalize to RGB before saving
    pil_image.convert("RGB").save(buffered, format=image_format)
    base64_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return base64_string, image_format

# 🔹 Function for Visual Question Answering (VQA) with Mixtral-8x7B
# NOTE: Mixtral-8x7B is text-only; a vision-capable Groq model is needed for image inputs.
def answer_question(text, image, temperature=0.0, max_tokens=1024):
    base64_string, file_format = pil_to_base64(image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                # OpenAI-compatible image payload: the data URL sits under an explicit "url" key
                {"type": "image_url", "image_url": {"url": f"data:image/{file_format};base64,{base64_string}"}}
            ]
        }
    ]

    chat_response = groq_client.chat.completions.create(
        model="mixtral-8x7b-32768",
        messages=messages,
        temperature=temperature,
        max_tokens=int(max_tokens)  # sliders may deliver floats; the API expects an integer
    )

    return chat_response.choices[0].message.content

# 🔹 Function to Generate Three Images (Multi-Output)
def generate_images(prompt):
    stop_event.clear()
    # Models loaded with gr.load() are callable like ordinary functions
    img1 = model1(prompt)
    img2 = model2(prompt)
    img3 = model3(prompt)
    return img1, img2, img3

# 🔹 Clear All Fields
def clear_all():
    return "", None, "", None, None, None

# 🔹 Set up Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎓 AI Tutor, VQA & Image Generation")

    # 🔹 Section 1: Visual Question Answering (Groq)
    gr.Markdown("## 🖼️ Visual Question Answering (Mixtral-8x7B)")
    with gr.Row():
        with gr.Column(scale=2):
            question = gr.Textbox(placeholder="Ask about the image...", lines=2)
            image = gr.Image(type="pil")
            with gr.Row():
                temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=2048, value=1024, step=128)
        
        with gr.Column(scale=3):
            output_text = gr.Textbox(lines=10, label="Mixtral VQA Response")
    
    with gr.Row():
        clear_btn = gr.Button("Clear", variant="secondary")
        submit_btn_vqa = gr.Button("Submit", variant="primary")

    # 🔹 Section 2: Image Generation (3 Outputs)
    gr.Markdown("## 🎨 AI-Generated Images (3 Variations)")
    with gr.Row():
        prompt = gr.Textbox(placeholder="Describe the image you want...", lines=2)
        generate_btn = gr.Button("Generate Images", variant="primary")

    with gr.Row():
        image1 = gr.Image(label="Image 1")
        image2 = gr.Image(label="Image 2")
        image3 = gr.Image(label="Image 3")

    # 🔹 VQA Processing
    submit_btn_vqa.click(
        fn=answer_question,
        inputs=[question, image, temperature, max_tokens],
        outputs=[output_text]
    )

    # 🔹 Image Generation Processing
    generate_btn.click(
        fn=generate_images,
        inputs=[prompt],
        outputs=[image1, image2, image3]
    )

    # 🔹 Clear All Inputs
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[question, image, output_text, image1, image2, image3]
    )

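# Usage note: set the GROQ_API_KEY environment variable before launching;
# the app then serves on http://0.0.0.0:7860 (open http://localhost:7860 locally).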
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)