File size: 3,882 Bytes
5fd09d2
91efa54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
044f8b8
dc9ce74
91efa54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fd09d2
d7b8b8a
5fd09d2
 
dc9ce74
 
 
 
 
 
5fd09d2
 
 
 
91efa54
 
d7b8b8a
 
dc9ce74
91efa54
8a1a2c6
91efa54
 
dc1a6d9
91efa54
 
dc9ce74
91efa54
dc9ce74
91efa54
 
 
 
dc9ce74
91efa54
 
 
 
dc9ce74
91efa54
 
dc9ce74
 
91efa54
dc9ce74
91efa54
dc9ce74
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import spaces
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
from PIL import Image
import torch
import numpy as np
import cv2
import gradio as gr
from torchvision import transforms 

# Recoloring ControlNet, loaded in half precision and placed on the GPU.
controlnet = ControlNetModel.from_pretrained(
    "briaai/BRIA-2.2-ControlNet-Recoloring", torch_dtype=torch.float16
)
controlnet = controlnet.to('cuda')

# BRIA 2.2 SDXL-style backbone wired to the ControlNet above, also in fp16.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "briaai/BRIA-2.2",
    controlnet=controlnet,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    offload_state_dict=True,
)
pipe = pipe.to('cuda')
pipe = pipe.to(torch.float16)

# Replace the default scheduler with an explicitly configured Euler-ancestral one.
pipe.scheduler = EulerAncestralDiscreteScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    steps_offset=1,
)
# NOTE(review): presumably keeps empty prompts encoded by the text encoder
# instead of substituting zero embeddings — confirm against diffusers docs.
pipe.force_zeros_for_empty_prompt = False

def resize_image(image):
    """Center-crop *image* to a square and scale it to 1024x1024.

    The image is converted to RGB first. The square crop uses the shorter
    side, so the aspect ratio is not distorted by the final resize.

    Args:
        image: A PIL image (any mode).

    Returns:
        The cropped and resized RGB image, 1024x1024.
    """
    rgb = image.convert('RGB')
    # The largest centered square fits inside the shorter dimension.
    side = min(rgb.size)
    square = transforms.functional.center_crop(rgb, (side, side))
    return transforms.functional.resize(square, (1024, 1024))

@spaces.GPU
def generate_(prompt, negative_prompt, grayscale_image, num_steps, controlnet_conditioning_scale, seed):
    """Run the recoloring pipeline once and return the generated images.

    Args:
        prompt: Text prompt for the generation.
        negative_prompt: Text prompt describing what to avoid.
        grayscale_image: Control image passed to the pipeline as ``image``.
        num_steps: Number of inference steps. May arrive as a float (e.g. from
            a UI slider), so it is coerced to ``int``.
        controlnet_conditioning_scale: ControlNet conditioning scale, passed
            to the pipeline as a ``float``.
        seed: RNG seed. Coerced to ``int`` because
            ``torch.Generator.manual_seed`` does not accept floats.

    Returns:
        The ``.images`` list produced by ``pipe``.
    """
    # A dedicated CUDA generator makes results reproducible for a given seed.
    generator = torch.Generator("cuda").manual_seed(int(seed))
    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        image=grayscale_image,
        num_inference_steps=int(num_steps),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
    ).images
    return images

@spaces.GPU
def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
    """UI callback: recolor *input_image* according to *prompt*.

    The input is center-cropped and resized to 1024x1024, then converted to a
    grayscale (3-channel) control image. The pipeline then generates the
    recolored result from that control image.

    Args:
        input_image: Uploaded PIL image, or ``None`` if nothing was uploaded.
        prompt: Text prompt for the generation.
        negative_prompt: Text prompt describing what to avoid.
        num_steps: Number of inference steps.
        controlnet_conditioning_scale: ControlNet conditioning scale.
        seed: RNG seed.

    Returns:
        A ``(grayscale_control_image, generated_image)`` tuple.

    Raises:
        gr.Error: If no input image was provided, so the UI shows a readable
            message instead of an ``AttributeError`` from ``resize_image``.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    input_image = resize_image(input_image)
    # Drop color information but keep 3 channels for the ControlNet input.
    grayscale_image = input_image.convert('L').convert('RGB')
    images = generate_(prompt, negative_prompt, grayscale_image, num_steps, controlnet_conditioning_scale, seed)
    return grayscale_image, images[0]

# Gradio UI: inputs (image, prompts, sampling settings) in the left column,
# the grayscale control image and the generated output in the right column.
# Components are laid out in the order they are created inside the contexts.
block = gr.Blocks()

with block:
    gr.Markdown("## BRIA 2.2 ControlNet Recoloring")
    gr.HTML('''
      <p style="margin-bottom: 10px; font-size: 94%">
        This is a demo for ControlNet Recoloring that uses
        <a href="https://huggingface.co/briaai/BRIA-2.2" target="_blank">BRIA 2.2 text-to-image model</a> as backbone. 
        Trained on licensed data, BRIA 2.2 provides full legal liability coverage for copyright and privacy infringement.
      </p>
    ''')
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            # type="pil" hands process() a PIL image (or None when empty).
            input_image = gr.Image(sources=None, type="pil")
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
            num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
            controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
            # NOTE(review): randomize=True presumably picks a random initial
            # seed when the page loads — confirm against gr.Slider docs.
            seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
            run_button = gr.Button(value="Run")
        # Right column: outputs.
        with gr.Column():
            preview = gr.Image(label="Control (Grayscale)", type="pil")
            result = gr.Image(label="Output Image", type="pil")
    # Order must match process()'s positional parameters.
    ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
    # process() returns (grayscale_image, generated_image) -> (preview, result).
    run_button.click(fn=process, inputs=ips, outputs=[preview, result])

# Launched at module import time (there is no __main__ guard).
block.launch(debug=True)