# File size: 7,166 Bytes
# 8125207
import logging
import time
from typing import Callable, Dict, List, Optional

import gradio as gr
class MultiModelImageGenerator:
    """
    Gradio front end that sends one prompt to several image models side by side.

    Each entry of ``models`` is loaded once, at construction time, through
    ``gr.Interface.load`` and stored under a 1-based index. A model whose load
    raises is logged and replaced by a placeholder interface that always
    returns ``None``, so the UI still renders a slot for it.

    The UI built by ``create_gradio_interface`` has one prompt box, a Generate
    and a Clear button, and one image output per model. Clicking Generate (or
    pressing Enter in the prompt box) queues one generation event per model;
    Clear resets the prompt and images and cancels those pending events.
    """

    # Seconds after a Generate/submit before the hidden task timer flags a timeout.
    TASK_TIMEOUT_SECONDS = 60

    def __init__(
        self,
        models: List[str],
        default_model_path: str = 'models/'
    ):
        """
        Initialize the generator and eagerly load every model.

        Args:
            models (List[str]): Model identifiers. Each one is appended to
                ``default_model_path`` to form the name passed to
                ``gr.Interface.load``.
            default_model_path (str): Prefix prepended verbatim (no separator
                is inserted) to every model identifier. The default
                ``'models/'`` is Gradio's prefix for Hugging Face Hub models.
        """
        self.models = models
        self.default_model_path = default_model_path
        # 1-based model index -> callable Gradio interface for that model.
        self.model_functions: Dict[int, Callable] = {}
        self._initialize_models()

    def _initialize_models(self):
        """
        Load each configured model, substituting a no-op interface on failure.

        Populates ``self.model_functions`` with keys ``1..len(self.models)``
        in the order of ``self.models``. A model whose load raises is logged
        as a warning and replaced by an interface whose function always
        returns ``None``. One unavailable model therefore does not stop the
        UI from starting.
        """
        def no_image(txt):
            # Placeholder generator for models that could not be loaded.
            return None

        for model_idx, model_path in enumerate(self.models, 1):
            source = f"{self.default_model_path}{model_path}"
            try:
                model_fn = gr.Interface.load(
                    source,
                    live=False,
                    preprocess=True,
                    postprocess=False
                )
            except Exception as error:
                # Deliberate best-effort: keep the UI usable, but make the
                # failure visible instead of silently swallowing it.
                logging.getLogger(__name__).warning(
                    "Could not load model %r; using a no-op fallback: %s",
                    source,
                    error,
                )
                model_fn = gr.Interface(
                    fn=no_image,
                    inputs=["text"],
                    outputs=["image"]
                )
            self.model_functions[model_idx] = model_fn

    def generate_with_model(
        self,
        model_idx: int,
        prompt: str
    ) -> "Optional[gr.Image]":
        """
        Run one model on ``prompt``, falling back to model 1 if it is missing.

        Args:
            model_idx (int): 1-based index into ``self.models``. Floats (as
                produced by ``gr.Number``) are truncated with ``int``; values
                that cannot be converted select the fallback model.
            prompt (str): Generation prompt.

        Returns:
            Whatever the selected model interface returns for the prompt, or
            ``None`` if no model is registered or the model call raises.
        """
        try:
            key = int(model_idx)
        except (TypeError, ValueError, OverflowError):
            key = None
        # model_functions is keyed by int; a str key would never match.
        selected_model = self.model_functions.get(key)
        if selected_model is None:
            selected_model = self.model_functions.get(1)
        if selected_model is None:
            return None
        try:
            return selected_model(prompt)
        except Exception:
            # UI-callback boundary: a failing model should only blank its own
            # output slot, as the documented contract promises.
            logging.getLogger(__name__).exception(
                "Generation failed for model index %r", model_idx
            )
            return None

    def _model_runner(self, model_idx: int) -> Callable[[str], object]:
        """Return a one-argument callback that runs model ``model_idx`` on a prompt."""
        def run(prompt: str):
            return self.generate_with_model(model_idx, prompt)
        return run

    def create_gradio_interface(self) -> "gr.Blocks":
        """
        Build the Gradio Blocks UI for side-by-side generation.

        Layout: a prompt textbox, Generate and Clear buttons, then a row of
        image outputs, one per model and labelled with its identifier.
        Generate (or Enter in the prompt box) queues one generation event per
        model. Clear empties the prompt and images and cancels those events.

        Returns:
            The (not yet launched) ``gr.Blocks`` instance.
        """
        timeout_seconds = self.TASK_TIMEOUT_SECONDS
        with gr.Blocks(title="Multi-Model Stable Diffusion", theme="Nymbo/Nymbo_Theme") as interface:
            with gr.Column(scale=12):
                with gr.Row():
                    primary_prompt = gr.Textbox(label="Generation Prompt", value="")
                with gr.Row():
                    run_btn = gr.Button("Generate", variant="primary")
                    clear_btn = gr.Button("Clear")

            # Dynamic output image grid: one slot per model, keyed by 1-based index.
            sd_outputs = {}
            with gr.Row():
                for model_idx, model_path in enumerate(self.models, 1):
                    with gr.Column(scale=3, min_width=320):
                        with gr.Box():
                            sd_outputs[model_idx] = gr.Image(label=model_path)

            # Hidden task-timer state.
            with gr.Row(visible=False):
                start_box = gr.Number(interactive=False)
                end_box = gr.Number(interactive=False)
                task_status_box = gr.Textbox(value=0, interactive=False)

            def start_task():
                """Stamp both timer boxes with the current time and reset the status flag."""
                t_stamp = time.time()
                return (
                    gr.update(value=t_stamp),
                    gr.update(value=t_stamp),
                    gr.update(value=0)
                )

            def check_task_status(cnt, t_stamp):
                """
                Advance the timer, or stop it and flag a timeout.

                ``cnt`` is start_box's value (0 when idle) and ``t_stamp`` is
                end_box's value (the time ``start_task`` ran).
                """
                current_time = time.time()
                # Test t_stamp before doing arithmetic on it: the Number box
                # may hold None before the first start.
                if t_stamp and current_time > t_stamp + timeout_seconds:
                    return gr.update(value=0), gr.update(value=1)
                return (
                    gr.update(value=current_time if cnt else 0),
                    gr.update(value=0)
                )

            def clear_interface():
                """Blank the prompt and every image output."""
                return tuple([None] * (1 + len(sd_outputs)))

            timer_outputs = [start_box, end_box, task_status_box]
            # Each tick writes a new time into start_box, which presumably
            # re-fires .change; every=1 asks Gradio to re-run it each second.
            # NOTE(review): nothing in this method reads task_status_box, so
            # the timeout flag does not stop running generations.
            start_box.change(
                check_task_status,
                [start_box, end_box],
                [start_box, task_status_box],
                every=1,
                show_progress=False
            )
            primary_prompt.submit(start_task, None, timer_outputs)
            run_btn.click(start_task, None, timer_outputs)

            # One generation event per model for each trigger. The callback
            # closes over the model index, so no extra input component (which
            # would render visibly inside the Blocks context) is needed.
            generation_tasks = []
            for model_idx, output in sd_outputs.items():
                runner = self._model_runner(model_idx)
                generation_tasks.append(
                    run_btn.click(runner, inputs=[primary_prompt], outputs=[output])
                )
                generation_tasks.append(
                    primary_prompt.submit(runner, inputs=[primary_prompt], outputs=[output])
                )

            clear_btn.click(
                clear_interface,
                None,
                [primary_prompt, *sd_outputs.values()],
                cancels=generation_tasks
            )
        return interface

    def launch(self, **kwargs):
        """
        Build the UI, enable the request queue, and start the server.

        Args:
            **kwargs: Forwarded unchanged to ``gr.Blocks.launch``.
        """
        interface = self.create_gradio_interface()
        # The queue is what makes `every=` polling and `cancels=` work.
        # NOTE: `concurrency_count` is a Gradio 3.x queue() argument (removed in 4.0).
        interface.queue(concurrency_count=600, status_update_rate=0.1)
        interface.launch(**kwargs)
def main():
    """
    Launch a demo of MultiModelImageGenerator over five fixed models.

    The identifiers below are combined with the generator's default
    ``'models/'`` prefix when they are loaded.
    """
    demo_models = [
        "doohickey/neopian-diffusion",
        "dxli/duck_toy",
        "dxli/bear_plushie",
        "haor/Evt_V4-preview",
        "Yntec/Dreamscapes_n_Dragonfire_v2",
    ]
    launch_options = {"inline": True, "show_api": False}
    MultiModelImageGenerator(demo_models).launch(**launch_options)
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()