turbo_fe / app_base.py
Sqxww's picture
modify brightness step
177db88
raw
history blame
5.69 kB
import spaces
import gradio as gr
import time
import torch
import os
import gc
from PIL import Image, ImageEnhance
from segment_utils import(
segment_image,
restore_result_and_save,
)
from enhance_utils import enhance_sd_image
from inversion_run_base import run as base_run
# Defaults pre-filled into the UI by create_demo(): the prompt describing the
# source image, the edit prompt, and the category that is passed to
# segment_image and restore_result_and_save.
DEFAULT_SRC_PROMPT, DEFAULT_EDIT_PROMPT, DEFAULT_CATEGORY = (
    "a person",
    "a person with perfect face",
    "face",
)
@spaces.GPU(duration=10)
@torch.inference_mode()
@torch.no_grad()
def image_to_image(
    input_image: Image.Image,
    input_image_prompt: str,
    edit_prompt: str,
    seed: int,
    w1: float,
    num_steps: int,
    start_step: int,
    guidance_scale: float,
    brightness: float = 1.0,
):
    """Edit an image with the base model, enhance it, then adjust brightness.

    Args:
        input_image: PIL image to edit (in create_demo this is the area image
            produced by segment_image).
        input_image_prompt: prompt describing ``input_image``; forwarded to the model.
        edit_prompt: prompt describing the desired edit; forwarded to the model.
        seed: forwarded to the model.
        w1: forwarded to the model (``w2`` is fixed at 1.0 here).
        num_steps: forwarded to the model.
        start_step: forwarded to the model.
        guidance_scale: forwarded to the model.
        brightness: factor for ``PIL.ImageEnhance.Brightness``; 1.0 keeps the
            image as produced by ``enhance_sd_image``.

    Returns:
        ``(image, time_cost_str)``: the final PIL image and a ``-->``-joined
        log of per-step durations in milliseconds (see ``get_time_cost``).
    """
    # Second weight for the model; fixed, not exposed in the UI.
    w2 = 1.0
    run_task_time, time_cost_str = get_time_cost(0, '')

    run_model = base_run
    try:
        res_image = run_model(
            input_image,
            input_image_prompt,
            edit_prompt,
            seed,
            w1,
            w2,
            num_steps,
            start_step,
            guidance_scale,
        )
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'run_sd_model done')
    finally:
        # Release cached GPU memory even if the model call raised.
        torch.cuda.empty_cache()
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'cuda_empty_cache done')

    enhanced_image = enhance_sd_image(res_image)
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'enhance_image done')
    torch.cuda.empty_cache()
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'cuda_empty_cache done')

    # os.getenv returns a string, so a plain truthiness test would treat
    # "0" or "false" as enabled; parse the flag explicitly instead.
    if os.getenv('ENABLE_GC', '').strip().lower() in ('1', 'true', 'yes', 'on'):
        gc.collect()
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str, 'gc_collect done')

    # A factor of 1.0 is the identity for ImageEnhance.Brightness; skip the work.
    if brightness != 1.0:
        enhanced_image = ImageEnhance.Brightness(enhanced_image).enhance(brightness)
    return enhanced_image, time_cost_str
def get_time_cost(
    run_task_time,
    time_cost_str,
    step: str = ''
):
    """Record the time elapsed since the previous checkpoint in a timing log.

    Args:
        run_task_time: millisecond timestamp returned by the previous call,
            or ``0`` to start a fresh log.
        time_cost_str: the log string accumulated so far.
        step: optional label appended after the elapsed time.

    Returns:
        ``(now_ms, log)``: the current wall-clock time in milliseconds (to be
        passed to the next call) and the updated log. When ``run_task_time``
        is 0 the log is reset to ``'start'`` and nothing else is recorded.
    """
    now_ms = int(time.time() * 1000)
    if run_task_time == 0:
        return now_ms, 'start'

    elapsed = f'{now_ms - run_task_time}'
    if time_cost_str == '':
        log = elapsed
    else:
        log = time_cost_str + '-->' + elapsed

    if step != '':
        log += f'-->{step}'
    return now_ms, log
def resize_image(image, target_size=1024):
    """Scale ``image`` so its longer side equals ``target_size``, keeping aspect ratio.

    Args:
        image: a PIL image (anything exposing ``.size`` as ``(width, height)``
            and ``.resize((width, height))``).
        target_size: length in pixels of the longer side after resizing.

    Returns:
        The resized image returned by ``image.resize``.
    """
    # PIL reports size as (width, height); the previous code unpacked it the
    # other way round, which swapped the axes of the output.
    width, height = image.size
    if height >= width:
        new_height = target_size
        # Clamp to 1 so extreme aspect ratios never yield a zero-sized side.
        new_width = max(1, int(width * target_size / height))
    else:
        new_width = target_size
        new_height = max(1, int(height * target_size / width))
    return image.resize((new_width, new_height))
def create_demo() -> gr.Blocks:
    """Build the Gradio UI for the segment -> edit -> restore pipeline.

    Clicking "Edit Image" runs three chained steps, each only if the previous
    one succeeded:
      1. ``segment_image`` produces the area image shown in
         "Origin Area Image" and the state kept in ``cropper``.
      2. ``image_to_image`` edits that area image and applies brightness.
      3. ``restore_result_and_save`` takes ``cropper`` and the edited image
         and produces the restored image plus a downloadable file.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        # Per-session state written by segment_image and read by
        # restore_result_and_save.
        cropper = gr.State()
        with gr.Row():
            with gr.Column():
                input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
                brightness = gr.Slider(minimum=0, maximum=2, value=1.0, step=0.05, label="Brightness")
                with gr.Accordion("Advanced Options", open=False):
                    category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
                    mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
                    mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
                    save_quality = gr.Slider(minimum=1, maximum=100, value=95, step=1, label="Save Quality")
            with gr.Column():
                num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
                start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
                g_btn = gr.Button("Edit Image")
                with gr.Accordion("Advanced Options", open=False):
                    guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
                    # gr.Number yields a float by default; precision=0 makes it
                    # deliver an int, which a seed and a pixel size require.
                    seed = gr.Number(label="Seed", value=8, precision=0)
                    w1 = gr.Number(label="W1", value=1.5)
                    generate_size = gr.Number(label="Generate Size", value=1024, precision=0)
        with gr.Row():
            with gr.Column():
                origin_area_image = gr.Image(label="Origin Area Image", format="png", type="pil", interactive=False)
                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            with gr.Column():
                enhanced_image = gr.Image(label="Enhanced Image", format="png", type="pil", interactive=False)
                restored_image = gr.Image(label="Restored Image", format="png", type="pil", interactive=False)
                download_path = gr.File(label="Download the output image", interactive=False)
                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)

        # .success() chains the next step only when the previous one did not raise.
        g_btn.click(
            fn=segment_image,
            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
            outputs=[origin_area_image, cropper],
        ).success(
            fn=image_to_image,
            inputs=[origin_area_image, input_image_prompt, edit_prompt, seed, w1, num_steps, start_step, guidance_scale, brightness],
            outputs=[enhanced_image, generated_cost],
        ).success(
            fn=restore_result_and_save,
            inputs=[cropper, category, enhanced_image, save_quality],
            outputs=[restored_image, download_path],
        )
    return demo