Spaces:
Running
Running
| import gradio as gr # pyright: ignore[reportMissingTypeStubs] | |
| import pillow_heif # pyright: ignore[reportMissingTypeStubs] | |
| import spaces # pyright: ignore[reportMissingTypeStubs] | |
| import torch | |
| from PIL import Image | |
| from refiners.fluxion.utils import manual_seed, no_grad | |
| from utils import LightingPreference, load_ic_light, resize_modulo_8 | |
# let PIL open HEIC/HEIF and AVIF uploads
pillow_heif.register_heif_opener()  # pyright: ignore[reportUnknownMemberType]
pillow_heif.register_avif_opener()  # pyright: ignore[reportUnknownMemberType]

TITLE = """
# IC-Light with Refiners
"""

DEVICE_CPU = torch.device("cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Check `is_available()` first: on older torch versions `is_bf16_supported()` queries
# the current CUDA device unconditionally and raises when there is none.
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32

# initialize the IC-Light model on the cpu
ic_light = load_ic_light(device=DEVICE_CPU, dtype=DTYPE)

# "move" the model to the gpu, this is handled/intercepted by Zero GPU;
# the model's device/dtype attributes and its solver are updated to match
ic_light.to(device=DEVICE, dtype=DTYPE)
ic_light.device = DEVICE
ic_light.dtype = DTYPE
ic_light.solver = ic_light.solver.to(device=DEVICE, dtype=DTYPE)
def _denoise_pass(
    x: torch.Tensor,
    *,
    strength: float,
    num_inference_steps: int,
    clip_text_embedding: torch.Tensor,
    condition_scale: float,
    renoise: bool,
) -> torch.Tensor:
    """Run one denoising pass of `ic_light` over latents `x` and return the result.

    The solver schedule is stretched to `round(num_inference_steps / strength)` steps
    and the first `(1 - strength)` fraction is skipped, so roughly
    `num_inference_steps` steps are actually run. When `renoise` is true, fresh
    noise is added to `x` at the first step of the configured schedule before
    denoising. `strength` must be > 0.
    """
    num_steps = round(num_inference_steps / strength)
    first_step = int(num_steps * (1 - strength))
    ic_light.set_inference_steps(num_steps, first_step)
    if renoise:
        x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=first_step)
    for step in ic_light.steps:
        x = ic_light(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )
    return x


@spaces.GPU
@no_grad()
def process(
    image: Image.Image,
    light_pref: str,
    prompt: str,
    negative_prompt: str,
    strength_first_pass: float,
    strength_second_pass: float,
    condition_scale: float,
    num_inference_steps: int,
    seed: int,
) -> Image.Image:
    """Relight an image with IC-Light in two denoising passes.

    The alpha channel of `image` is passed to IC-Light as the condition mask. The
    first pass starts from noise, or from a noised light-preference image when one
    is produced for `light_pref`. The second pass re-noises the first-pass result
    and partially denoises it again.

    Args:
        image: Input image. Images that are not RGBA are converted to RGBA
            (fully opaque alpha).
        light_pref: Light direction preference, parsed by `LightingPreference.from_str`.
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        strength_first_pass: Strength of the first pass, in [0, 1]. It is forced to
            1.0 when no light-preference image is produced, and must be > 0 otherwise.
        strength_second_pass: Strength of the second pass, in [0, 1]. A value of 0
            skips the second pass.
        condition_scale: Scale forwarded to `ic_light` at every denoising step.
        num_inference_steps: Approximate number of denoising steps run per pass (> 0).
        seed: Non-negative random seed.

    Returns:
        The image decoded from the final latents.

    Raises:
        gr.Error: If no image is given or a parameter is out of range.
    """
    # Validate with user-facing errors rather than asserts, which are stripped
    # under `python -O` and show up as opaque failures in the UI.
    if image is None:
        raise gr.Error("Please upload an input image.")
    if not 0 <= strength_first_pass <= 1:
        raise gr.Error("The strength of the first pass must be between 0 and 1.")
    if not 0 <= strength_second_pass <= 1:
        raise gr.Error("The strength of the second pass must be between 0 and 1.")
    # slider values may arrive as floats
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)
    if num_inference_steps <= 0:
        raise gr.Error("The number of inference steps must be positive.")
    if seed < 0:
        raise gr.Error("The seed must be non-negative.")

    # the alpha channel is required below; images without one become fully opaque
    if image.mode != "RGBA":
        image = image.convert("RGBA")

    # set the seed
    manual_seed(seed)

    # resize image to ~768x768
    image = resize_modulo_8(image, 768)

    # split RGB and alpha channel; the alpha channel is the IC-Light condition mask
    mask = image.getchannel("A")
    image = image.convert("RGB")

    # compute embeddings
    clip_text_embedding = ic_light.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    ic_light.set_ic_light_condition(image=image, mask=mask)

    # get the light_pref_image
    light_pref_image = LightingPreference.from_str(value=light_pref).get_init_image(
        width=image.width,
        height=image.height,
        interval=(0.2, 0.8),
    )

    if light_pref_image is None:
        # no light preference: start from pure noise and do a full strength first pass
        x = torch.randn_like(ic_light._ic_light_condition)  # pyright: ignore[reportPrivateUsage]
        strength_first_pass = 1.0
    else:
        if strength_first_pass == 0:
            raise gr.Error("The strength of the first pass must be greater than 0 when a light preference is set.")
        x = ic_light.lda.image_to_latents(light_pref_image)
        # NOTE(review): noise is added at step=0 here, not at the first step that the
        # first pass configures afterwards, so the initial noise level does not depend
        # on strength_first_pass. Confirm this is intended.
        x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=0)

    # first pass: denoise the initial latents as they are
    x = _denoise_pass(
        x,
        strength=strength_first_pass,
        num_inference_steps=num_inference_steps,
        clip_text_embedding=clip_text_embedding,
        condition_scale=condition_scale,
        renoise=False,
    )

    # second pass: re-noise the first-pass result and partially denoise it again
    # (a strength of 0 means "no second pass" and would otherwise divide by zero)
    if strength_second_pass > 0:
        x = _denoise_pass(
            x,
            strength=strength_second_pass,
            num_inference_steps=num_inference_steps,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
            renoise=True,
        )

    return ic_light.lda.latents_to_image(x)
# Strings shared by the examples below.
EXAMPLE_PROMPT_CYBERPUNK = "blue purple neon light, cyberpunk city background, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF"
EXAMPLE_PROMPT_ICE_CAVERN = "floor is blue ice cavern, stalactite, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF"
EXAMPLE_NEGATIVE_PROMPT = "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white"

with gr.Blocks() as demo:
    gr.Markdown(TITLE)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image", image_mode="RGBA")
            run_button = gr.Button(value="Relight Image")
        with gr.Column():
            output_image = gr.Image(label="Result")
    with gr.Accordion("Advanced Settings", open=True):
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="bright green neon light, best quality, highres",
        )
        neg_prompt = gr.Textbox(
            label="Negative Prompt",
            placeholder="worst quality, low quality, normal quality",
        )
        light_pref = gr.Radio(
            choices=["None", "Left", "Right", "Top", "Bottom"],
            label="Light direction preference",
            value="None",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=100_000,
            value=69_420,
            step=1,
        )
        condition_scale = gr.Slider(
            label="Condition scale",
            minimum=0.5,
            maximum=2,
            value=1.25,
            step=0.05,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=50,
            value=25,
            step=1,
        )
        # The strength minimums are kept above 0 because the per-pass step count
        # is derived by dividing by the strength.
        with gr.Row():
            strength_first_pass = gr.Slider(
                label="Strength of the first pass",
                minimum=0.1,
                maximum=1,
                value=0.9,
                step=0.1,
            )
            strength_second_pass = gr.Slider(
                label="Strength of the second pass",
                minimum=0.1,
                maximum=1,
                value=0.5,
                step=0.1,
            )

    # Components in the positional order of process()'s parameters,
    # shared by the button and the examples.
    process_inputs = [
        input_image,
        light_pref,
        prompt,
        neg_prompt,
        strength_first_pass,
        strength_second_pass,
        condition_scale,
        num_inference_steps,
        seed,
    ]

    run_button.click(
        fn=process,
        inputs=process_inputs,
        outputs=output_image,
    )

    gr.Examples(  # pyright: ignore[reportUnknownMemberType]
        examples=[
            ["examples/plant.png", "None", EXAMPLE_PROMPT_CYBERPUNK, EXAMPLE_NEGATIVE_PROMPT, 0.9, 0.5, 1.25, 25, 69_420],
            ["examples/plant.png", "Right", EXAMPLE_PROMPT_CYBERPUNK, EXAMPLE_NEGATIVE_PROMPT, 0.9, 0.5, 1.25, 25, 69_420],
            ["examples/plant.png", "Left", EXAMPLE_PROMPT_ICE_CAVERN, EXAMPLE_NEGATIVE_PROMPT, 0.9, 0.5, 1.25, 25, 69_420],
        ],
        inputs=process_inputs,
        outputs=output_image,
        fn=process,
        cache_examples="lazy",  # type: ignore
        run_on_click=False,
    )

if __name__ == "__main__":
    demo.launch()