import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
from model import UNet2DConditionModelEx
from pipeline import StableDiffusionControlLoraV3Pipeline
from PIL import Image
import os
from huggingface_hub import login
import spaces
import random
from pathlib import Path
import hashlib
import datetime
import json
from tqdm import tqdm

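# Authenticate with the Hugging Face Hub using the HF_TOKEN secret configured for the Space.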
login(token=os.environ.get("HF_TOKEN"))

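# Space identifier and the output location used by save_to_space below.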
HF_SPACE_ID = "naonauno/groundbi-factory"
OUTPUT_DIR = "/home/user/outputs"

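# Local working directories for generated images and their JSON metadata.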
os.makedirs('outputs', exist_ok=True)
os.makedirs('metadata', exist_ok=True)
metadata_dir = 'metadata'

class AdvancedGenerationTracker:
    """Tracks denoising progress with a tqdm bar and records CUDA memory usage per step."""

    def __init__(self, total_steps):
        self.progress_bar = tqdm(total=total_steps, desc="Image Generation")
        self.current_step = 0
        self.memory_usage_log = []

    def update_progress(self, step_size=1):
        self.current_step += step_size
        self.progress_bar.update(step_size)
        self._log_memory_usage()

    def _log_memory_usage(self):
        # Record allocator statistics only when a CUDA device is present.
        if torch.cuda.is_available():
            memory_info = {
                'step': self.current_step,
                'cuda_allocated': torch.cuda.memory_allocated(),
                'cuda_reserved': torch.cuda.memory_reserved(),
                'cuda_max_allocated': torch.cuda.max_memory_allocated()
            }
            self.memory_usage_log.append(memory_info)

    def finalize(self):
        self.progress_bar.close()
        return self.memory_usage_log

def setup_pipeline():
    # Load the extended UNet that accepts the extra condition used by the control LoRA.
    unet = UNet2DConditionModelEx.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet"
    )
    unet = unet.add_extra_conditions("ow-gbi-control-lora")

    pipe = StableDiffusionControlLoraV3Pipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        unet=unet
    )

    # UniPC scheduler plus attention/VAE slicing keeps memory usage modest.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_attention_slicing()
    pipe.enable_vae_slicing()

    # Control-LoRA weights shipped in the local "models" folder of the repository.
    pipe.load_lora_weights(
        "models",
        weight_name="40kHalf.safetensors"
    )
    return pipe

pipe = setup_pipeline()

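# Copy a generated image into OUTPUT_DIR so it is also available outside the per-run folder.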
def save_to_space(image, filename):
    path = os.path.join(OUTPUT_DIR, filename)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    image.save(path)
    return path

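# Build a unique, traceable filename from a short prompt/seed hash, a timestamp, and an optional style prefix.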
def generate_advanced_filename(prompt, seed, style=None):
    hash_input = f"{prompt}_{seed}"
    filename_hash = hashlib.md5(hash_input.encode()).hexdigest()[:8]
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    style_prefix = f"{style}_" if style else ""
    return f"{style_prefix}{timestamp}_{filename_hash}"

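# Write the generation settings to disk as a JSON file.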
def export_generation_metadata(metadata, output_path):
    with open(output_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    return output_path

@spaces.GPU(duration=180)
def generate_image(
    image,
    prompt,
    negative_prompt,
    guidance_scale,
    steps,
    seed,
    strength=0.8,
    num_images=1,
    progress=gr.Progress()
):
    if image is None:
        raise gr.Error("Please provide an input image!")

    try:
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_base_dir = os.path.join('outputs', timestamp)
        os.makedirs(output_base_dir, exist_ok=True)

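        # Use the provided seed when it parses as an integer; otherwise fall back to a random one.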
        if seed is not None and seed != "":
            try:
                generator = torch.Generator().manual_seed(int(seed))
                current_seed = int(seed)
            except ValueError:
                generator = torch.Generator()
                current_seed = random.randint(1, 1000000)
        else:
            generator = torch.Generator()
            current_seed = random.randint(1, 1000000)

        tracker = AdvancedGenerationTracker(steps)

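        # Invoked by the pipeline after each denoising step; updates both tqdm and the Gradio progress bar.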
        def callback_on_step_end(pipeline, step, timestep, callback_kwargs):
            tracker.update_progress()
            if progress is not None:
                progress(step / steps)
            return callback_kwargs

        progress(0.3, desc="Generating image...")
        with torch.no_grad():
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=int(steps),
                guidance_scale=float(guidance_scale),
                image=image,
                strength=strength,
                extra_condition_scale=1.0,
                generator=generator,
                num_images_per_prompt=num_images,
                callback_on_step_end=callback_on_step_end
            )

        generated_image = result.images[0]

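        # Save the image both to the per-run output folder and to OUTPUT_DIR.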
        filename = generate_advanced_filename(prompt, current_seed)
        image_path = os.path.join(output_base_dir, f"{filename}.png")
        generated_image.save(image_path)
        save_to_space(generated_image, f"{filename}.png")

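        # Record the full generation settings alongside the image for reproducibility.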
        generation_metadata = {
            "generation_timestamp": timestamp,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "seed": current_seed,
            "generation_parameters": {
                "guidance_scale": guidance_scale,
                "steps": steps,
                "strength": strength,
                "num_images": num_images
            },
            "image_file": os.path.basename(image_path)
        }

        metadata_path = os.path.join(metadata_dir, f"{filename}_metadata.json")
        export_generation_metadata(generation_metadata, metadata_path)

        # Close the tqdm bar; the collected memory log is currently not surfaced in the UI.
        memory_log = tracker.finalize()
        progress(1.0, desc="Done!")

        return generated_image

    except Exception as e:
        raise gr.Error(f"An error occurred: {str(e)}")

css = """ |
|
.container { max-width: 900px; margin: auto; } |
|
.parameter-hint { font-size: 0.8em; color: #666; margin-top: -5px; } |
|
""" |
|
|
|
|
|
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # Terrain Generator
        ⚠️ Warning: This is a demo running on ZeroGPU. Generation might take a few minutes.
        For best results, use 15-20 steps for generation.
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")

            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the terrain..."
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="What to avoid..."
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1,
                maximum=20,
                value=7.5,
                info="Higher = more prompt adherence, Lower = more creativity"
            )
            steps = gr.Slider(
                label="Steps",
                minimum=1,
                maximum=50,
                value=20,
                info="More steps = higher quality but slower"
            )
            seed = gr.Textbox(
                label="Seed (empty for random)",
                placeholder="Enter a number for reproducible results",
                info="Controls randomness. Same seed = same output."
            )
            generate = gr.Button("Generate")

        with gr.Column():
            result = gr.Image(label="Generated Image")

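    # strength, num_images, and progress keep their defaults; only these UI controls are wired in.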
    generate.click(
        fn=generate_image,
        inputs=[
            input_image,
            prompt,
            negative_prompt,
            guidance_scale,
            steps,
            seed
        ],
        outputs=result
    )

demo.queue()
demo.launch(share=True)