# GroundBi / app.py
# Author: naonauno
# Last change: "Update app.py" (commit 0071020, verified)
import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
from model import UNet2DConditionModelEx
from pipeline import StableDiffusionControlLoraV3Pipeline
from PIL import Image
import os
from huggingface_hub import login
import spaces
import random
from pathlib import Path
import hashlib
import datetime
import json
from tqdm import tqdm
# Authenticate with the Hugging Face Hub only when a token is configured.
# login(token=None) falls back to an interactive prompt, which cannot be
# answered when the app runs unattended (e.g. inside a Space).
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Setup directories
HF_SPACE_ID = "naonauno/groundbi-factory"
# Absolute target used by save_to_space(); generate_image() additionally writes
# to the relative 'outputs' directory created below. NOTE(review): these are two
# different locations unless the working directory is /home/user — confirm intent.
OUTPUT_DIR = "/home/user/outputs"
metadata_dir = 'metadata'
os.makedirs('outputs', exist_ok=True)
os.makedirs(metadata_dir, exist_ok=True)
class AdvancedGenerationTracker:
    """Track per-step generation progress with a tqdm bar and log CUDA memory.

    Each call to ``update_progress`` advances the bar and, when CUDA is
    available, appends one memory snapshot dict to ``memory_usage_log``.
    """

    def __init__(self, total_steps):
        # Bar length is whatever step count the caller expects to report.
        self.progress_bar = tqdm(total=total_steps, desc="Image Generation")
        self.current_step = 0
        # Snapshots collected by _log_memory_usage(); stays empty without CUDA.
        self.memory_usage_log = []

    def update_progress(self, step_size=1):
        """Advance the step counter and the bar by ``step_size``, then log memory."""
        self.current_step = self.current_step + step_size
        self.progress_bar.update(step_size)
        self._log_memory_usage()

    def _log_memory_usage(self):
        """Record CUDA allocator statistics for the current step (no-op without CUDA)."""
        if not torch.cuda.is_available():
            return
        snapshot = dict(
            step=self.current_step,
            cuda_allocated=torch.cuda.memory_allocated(),
            cuda_reserved=torch.cuda.memory_reserved(),
            cuda_max_allocated=torch.cuda.max_memory_allocated(),
        )
        self.memory_usage_log.append(snapshot)

    def finalize(self):
        """Close the progress bar and hand back the collected memory snapshots."""
        self.progress_bar.close()
        return self.memory_usage_log
def setup_pipeline():
    """Build and return the ControlLoRA v3 Stable Diffusion pipeline.

    Steps: load the SD 1.5 UNet, attach the "ow-gbi-control-lora" extra
    conditions, wrap it in StableDiffusionControlLoraV3Pipeline, switch to a
    UniPC multistep scheduler, enable attention/VAE slicing, and finally load
    the "40kHalf.safetensors" LoRA weights from "models".
    """
    base_model = "runwayml/stable-diffusion-v1-5"

    conditioned_unet = UNet2DConditionModelEx.from_pretrained(base_model, subfolder="unet")
    conditioned_unet = conditioned_unet.add_extra_conditions("ow-gbi-control-lora")

    pipeline = StableDiffusionControlLoraV3Pipeline.from_pretrained(base_model, unet=conditioned_unet)

    # Performance optimizations: UniPC scheduler plus memory-saving slicing.
    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.enable_attention_slicing()
    pipeline.enable_vae_slicing()

    # NOTE(review): "models" is presumably a directory shipped with the app — confirm.
    pipeline.load_lora_weights("models", weight_name="40kHalf.safetensors")
    return pipeline
# Build the pipeline once at import time; generate_image() calls this module-level `pipe`.
pipe = setup_pipeline()
def save_to_space(image, filename):
    """Save ``image`` as ``OUTPUT_DIR/filename`` and return that path.

    Parent directories (including any sub-folders inside ``filename``) are
    created first if they do not exist.
    """
    destination = os.path.join(OUTPUT_DIR, filename)
    parent_dir = os.path.dirname(destination)
    os.makedirs(parent_dir, exist_ok=True)
    image.save(destination)
    return destination
def generate_advanced_filename(prompt, seed, style=None):
    """Return a filename stem ``[<style>_]<YYYYmmdd>_<HHMMSS>_<hash8>``.

    ``hash8`` is the first 8 hex digits of the MD5 of ``"<prompt>_<seed>"``;
    it only distinguishes names and is not used for security. The style
    prefix is added only when ``style`` is truthy.
    """
    digest = hashlib.md5(f"{prompt}_{seed}".encode()).hexdigest()
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    name = f"{stamp}_{digest[:8]}"
    if style:
        name = f"{style}_{name}"
    return name
def export_generation_metadata(metadata, output_path):
    """Serialize ``metadata`` as 2-space-indented JSON into ``output_path``.

    The file is created or truncated. Returns ``output_path`` unchanged.
    """
    destination = output_path
    with open(destination, 'w') as json_file:
        json.dump(metadata, json_file, indent=2)
    return destination
def _resolve_seed(seed):
    """Turn the user-supplied seed into ``(generator, seed_used)``.

    ``seed`` is the value from the seed textbox. ``None``, ``""`` or anything
    ``int()`` rejects selects a random seed in ``[1, 1000000]``. The returned
    CPU ``torch.Generator`` is seeded with ``seed_used`` in every branch: an
    unseeded ``torch.Generator()`` always starts from the same default seed,
    so "random" runs would repeat and the recorded seed would not reproduce
    the image.
    """
    seed_used = None
    if seed is not None and seed != "":
        try:
            seed_used = int(seed)
        except (TypeError, ValueError):
            seed_used = None
    if seed_used is None:
        seed_used = random.randint(1, 1000000)
    return torch.Generator().manual_seed(seed_used), seed_used


@spaces.GPU(duration=180)
def generate_image(
    image,
    prompt,
    negative_prompt,
    guidance_scale,
    steps,
    seed,
    strength=0.8,
    num_images=1,
    progress=gr.Progress()
):
    """Run the ControlLoRA pipeline on ``image`` and persist the result.

    Args:
        image: Conditioning image (the UI supplies a PIL image).
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        guidance_scale: Guidance scale, cast to float.
        steps: Number of inference steps, cast to int.
        seed: Seed text; empty or unparseable means a random seed.
        strength: Forwarded unchanged as the pipeline's ``strength``.
        num_images: Forwarded as ``num_images_per_prompt``; only the first
            image is returned and saved.
        progress: Gradio progress reporter.

    Returns:
        The first generated image. It is also saved to
        ``outputs/<timestamp>/<name>.png`` and via ``save_to_space``, and the
        generation parameters are written to ``metadata/<name>_metadata.json``.

    Raises:
        gr.Error: if ``image`` is None, or wrapping any exception raised
            during generation or saving.
    """
    if image is None:
        raise gr.Error("Please provide an input image!")
    tracker = None
    try:
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_base_dir = os.path.join('outputs', timestamp)
        os.makedirs(output_base_dir, exist_ok=True)

        generator, current_seed = _resolve_seed(seed)
        num_steps = int(steps)
        tracker = AdvancedGenerationTracker(num_steps)

        def callback_on_step_end(pipeline, step, timestep, callback_kwargs):
            tracker.update_progress()
            if progress is not None:
                progress(step / num_steps)
            # No callback tensors are modified.
            return {}

        progress(0.3, desc="Generating image...")
        with torch.no_grad():
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_steps,
                guidance_scale=float(guidance_scale),
                image=image,
                strength=strength,
                extra_condition_scale=1.0,
                generator=generator,
                num_images_per_prompt=num_images,
                callback_on_step_end=callback_on_step_end
            )
        generated_image = result.images[0]

        # Save the image to the per-run folder and to OUTPUT_DIR.
        filename = generate_advanced_filename(prompt, current_seed)
        image_path = os.path.join(output_base_dir, f"{filename}.png")
        generated_image.save(image_path)
        save_to_space(generated_image, f"{filename}.png")

        # Save metadata (seed is the one actually used to seed the generator).
        generation_metadata = {
            "generation_timestamp": timestamp,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "seed": current_seed,
            "generation_parameters": {
                "guidance_scale": guidance_scale,
                "steps": num_steps,
                "strength": strength,
                "num_images": num_images
            },
            "image_file": os.path.basename(image_path)
        }
        metadata_path = os.path.join(metadata_dir, f"{filename}_metadata.json")
        export_generation_metadata(generation_metadata, metadata_path)

        progress(1.0, desc="Done!")
        return generated_image
    except Exception as e:
        raise gr.Error(f"An error occurred: {str(e)}") from e
    finally:
        # Close the tqdm bar even when generation or saving fails.
        if tracker is not None:
            tracker.finalize()
# Custom CSS passed to gr.Blocks(css=css) for the app's page.
# NOTE(review): no component in this file sets elem_classes to "container" or
# "parameter-hint"; whether Gradio's own markup matches these selectors is not
# visible here — verify these rules actually apply.
css = """
.container { max-width: 900px; margin: auto; }
.parameter-hint { font-size: 0.8em; color: #666; margin-top: -5px; }
"""
# Create the Gradio interface.
# Left column: input image, prompts, sampling controls and the Generate button.
# Right column: the generated image. The button calls generate_image with the
# six inputs below; strength and num_images keep their defaults.
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# Terrain Generator
⚠️ Warning: This is a demo running on ZeroGPU. Generation might take a few minutes.
For best results, use 15-20 steps for generation.
"""
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe the terrain..."
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="What to avoid..."
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1,
                maximum=20,
                value=7.5,
                info="Higher = more prompt adherence, Lower = more creativity"
            )
            # Inference steps are a whole number. Without an explicit step,
            # Gradio picks an increment from the slider range, which allows
            # fractional values that generate_image would truncate.
            steps = gr.Slider(
                label="Steps",
                minimum=1,
                maximum=50,
                value=20,
                step=1,
                info="More steps = higher quality but slower"
            )
            seed = gr.Textbox(
                label="Seed (empty for random)",
                placeholder="Enter a number for reproducible results",
                info="Controls randomness. Same seed = same output."
            )
            generate = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Generated Image")
    generate.click(
        fn=generate_image,
        inputs=[
            input_image,
            prompt,
            negative_prompt,
            guidance_scale,
            steps,
            seed
        ],
        outputs=result
    )
demo.queue()
demo.launch(share=True)