File size: 4,308 Bytes
78ec26d
c4ccad7
7ea27ba
c4ccad7
7ea27ba
868b112
c4ccad7
 
7ea27ba
 
bdd1e49
7ea27ba
bdd1e49
7ea27ba
bdd1e49
7ea27ba
bdd1e49
7ea27ba
bdd1e49
7ea27ba
bdd1e49
 
7ea27ba
3455f8c
7ea27ba
bdd1e49
7ea27ba
 
bdd1e49
7ea27ba
 
bdd1e49
 
 
 
7ea27ba
3455f8c
bdd1e49
7ea27ba
 
3455f8c
bdd1e49
f7bfc02
 
7ea27ba
 
bdd1e49
7ea27ba
 
 
 
 
 
3455f8c
bdd1e49
 
 
 
 
 
 
 
7ea27ba
 
 
 
bdd1e49
 
 
 
7ea27ba
bdd1e49
868b112
bdd1e49
7ea27ba
 
bdd1e49
 
7ea27ba
bdd1e49
7ea27ba
 
 
 
 
 
bdd1e49
 
7ea27ba
bdd1e49
7ea27ba
bdd1e49
7ea27ba
78ec26d
bdd1e49
 
 
 
 
 
 
 
7ea27ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
import torch
import random
from diffusers import DiffusionPipeline
from transformers import pipeline

# Pick the compute device once and derive everything else from it, so the
# device and dtype choices cannot drift apart. Half precision is used only on
# CUDA; float32 elsewhere (presumably because float16 is poorly supported on CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
# Largest seed the UI offers (unsigned 32-bit range); random seeds are drawn
# from [0, MAX_SEED].
MAX_SEED = 2**32 - 1

# --- Model registries: display name -> Hugging Face Hub repo id ---
# The "(light)"/"(heavy)" labels are rough hints only; the dicts are NOT
# strictly ordered by model size.
# NOTE(review): "runwayml/stable-diffusion-v1-5" was reportedly withdrawn from
#   the Hub (a mirror exists at "stable-diffusion-v1-5/stable-diffusion-v1-5") — verify.
# NOTE(review): BLIP Diffusion is a subject-driven pipeline that expects
#   reference-image inputs; generate_image passes only a prompt, so this entry
#   likely fails at inference time — confirm.
# NOTE(review): confirm "amused/muse-512-finetuned" exists on the Hub.
image_models = {
    "Stable Diffusion 1.5 (light)": "runwayml/stable-diffusion-v1-5",
    "Stable Diffusion 2.1": "stabilityai/stable-diffusion-2-1",
    "Dreamlike 2.0": "dreamlike-art/dreamlike-photoreal-2.0",
    "Playground v2": "playgroundai/playground-v2-1024px-aesthetic",
    "Muse 512": "amused/muse-512-finetuned",
    "PixArt": "PixArt-alpha/PixArt-LCM-XL-2-1024-MS",
    "Kandinsky 3": "kandinsky-community/kandinsky-3",
    "BLIP Diffusion": "Salesforce/blipdiffusion",
    "SDXL Base 1.0 (heavy)": "stabilityai/stable-diffusion-xl-base-1.0",
    "OpenJourney (heavy)": "prompthero/openjourney"
}

# NOTE(review): some of these checkpoints (e.g. XGen, BTLM, MPT) may ship custom
#   modeling code requiring trust_remote_code=True, which generate_text does not
#   pass (deliberately: it executes remote code). Llama-2 is a gated repo and
#   needs an accepted license plus an HF token — verify before relying on them.
text_models = {
    "GPT-2 (light)": "gpt2",
    "GPT-Neo 1.3B": "EleutherAI/gpt-neo-1.3B",
    "BLOOM 1.1B": "bigscience/bloom-1b1",
    "GPT-J 6B": "EleutherAI/gpt-j-6B",
    "Falcon 7B": "tiiuae/falcon-7b",
    "XGen 7B": "Salesforce/xgen-7b-8k-base",
    "BTLM 3B": "cerebras/btlm-3b-8k-base",
    "MPT 7B": "mosaicml/mpt-7b",
    "StableLM 2": "stabilityai/stablelm-2-1_6b",
    "LLaMA 2 7B (heavy)": "meta-llama/Llama-2-7b-hf"
}

# Lazily-populated pipeline caches keyed by display name; entries are created
# on first use by generate_image / generate_text and never evicted.
image_pipes = {}
text_pipes = {}

def generate_image(prompt, model_name, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
    """Generate one 512x512 image from *prompt* with the selected diffusion model.

    The pipeline for *model_name* is loaded on first use, moved to ``device``
    and cached in ``image_pipes``.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        model_name: Key into ``image_models`` selecting the Hub repo to load.
        seed: Seed for sampling; ignored when *randomize_seed* is true.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        progress: Gradio progress tracker; ``track_tqdm=True`` also forwards
            the pipeline's tqdm bars to the UI.

    Returns:
        ``(image, seed)``: the first generated image and the integer seed
        actually used, so the UI can display it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver floats; normalize so the seed returned to the
    # UI is a plain int.
    seed = int(seed)
    # A dedicated generator yields the same sampling sequence as the default
    # CPU generator would after torch.manual_seed(seed), without reseeding the
    # process-wide RNG shared by every other request.
    generator = torch.Generator(device="cpu").manual_seed(seed)

    progress(0, desc="Loading model...")
    if model_name not in image_pipes:
        # NOTE(review): cached pipelines are never evicted; cycling through
        # several large models may exhaust GPU memory.
        image_pipes[model_name] = DiffusionPipeline.from_pretrained(
            image_models[model_name],
            torch_dtype=torch_dtype
        ).to(device)
    pipe = image_pipes[model_name]

    # gr.Progress expects a fraction in [0, 1], not a percentage.
    progress(0.25, desc="Running inference (step 1/3)...")
    result = pipe(prompt=prompt, generator=generator, num_inference_steps=30, width=512, height=512)

    progress(1.0, desc="Done.")
    return result.images[0], seed

def generate_text(prompt, model_name, progress=gr.Progress(track_tqdm=True)):
    """Continue *prompt* with the selected causal language model.

    The ``text-generation`` pipeline for *model_name* is built on first use
    and cached in ``text_pipes``.

    Args:
        prompt: Text to continue.
        model_name: Key into ``text_models`` selecting the Hub repo to load.
        progress: Gradio progress tracker; ``track_tqdm=True`` also forwards
            tqdm bars to the UI.

    Returns:
        The pipeline's ``generated_text`` for the first sample; as written
        this includes the prompt followed by up to 100 sampled new tokens.
    """
    progress(0, desc="Loading model...")
    if model_name not in text_pipes:
        text_pipes[model_name] = pipeline(
            "text-generation",
            model=text_models[model_name],
            device=0 if device == "cuda" else -1,
            # Same precision policy as the image pipelines; without it the
            # weights typically load in float32, which is prohibitive for the
            # 6-7B checkpoints.
            torch_dtype=torch_dtype,
        )
    pipe = text_pipes[model_name]

    # gr.Progress expects a fraction in [0, 1], not a percentage.
    progress(0.5, desc="Generating text...")
    # max_new_tokens bounds only the continuation. The previous max_length=100
    # also counted prompt tokens, so long prompts left little or no room.
    outputs = pipe(prompt, max_new_tokens=100, do_sample=True)
    progress(1.0, desc="Done.")
    return outputs[0]['generated_text']

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# ๐Ÿง  Multi-Model AI Playground with Progress")

    with gr.Tabs():
        # ๐Ÿ–ผ๏ธ Image Gen Tab
        with gr.Tab("๐Ÿ–ผ๏ธ Image Generation"):
            img_prompt = gr.Textbox(label="Prompt")
            img_model = gr.Dropdown(choices=list(image_models.keys()), value="Stable Diffusion 1.5 (light)", label="Image Model")
            img_seed = gr.Slider(0, MAX_SEED, value=42, label="Seed")
            img_rand = gr.Checkbox(label="Randomize seed", value=True)
            img_btn = gr.Button("Generate Image")
            img_out = gr.Image()
            img_btn.click(fn=generate_image, inputs=[img_prompt, img_model, img_seed, img_rand], outputs=[img_out, img_seed])

        # ๐Ÿ“ Text Gen Tab
        with gr.Tab("๐Ÿ“ Text Generation"):
            txt_prompt = gr.Textbox(label="Prompt")
            txt_model = gr.Dropdown(choices=list(text_models.keys()), value="GPT-2 (light)", label="Text Model")
            txt_btn = gr.Button("Generate Text")
            txt_out = gr.Textbox(label="Output Text")
            txt_btn.click(fn=generate_text, inputs=[txt_prompt, txt_model], outputs=txt_out)

        # ๐ŸŽฅ Video Gen Tab (placeholder)
        with gr.Tab("๐ŸŽฅ Video Generation (Placeholder)"):
            gr.Markdown("โš ๏ธ Video generation is placeholder only. Models require special setup.")
            vid_prompt = gr.Textbox(label="Prompt")
            vid_btn = gr.Button("Pretend to Generate")
            vid_out = gr.Textbox(label="Result")
            vid_btn.click(lambda x: f"Fake video output for: {x}", inputs=[vid_prompt], outputs=[vid_out])

demo.launch(show_error=True)