import sys
sys.path.append('./')

import gradio as gr
import spaces
import os
import subprocess
import random

import numpy as np
from PIL import Image
import cv2
import torch

# Install the bundled controlnet_aux package before importing its annotators
os.system("pip install -e ./controlnet_aux")
from controlnet_aux import OpenposeDetector  # , CannyDetector
from depth_anything_v2.dpt import DepthAnythingV2

from huggingface_hub import hf_hub_download, login

# Authenticate with the Hugging Face Hub (token provided as a Space secret)
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)
MAX_SEED = np.iinfo(np.int32).max
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
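
# Depth Anything V2 encoder configurations; the ViT-L variant is loaded below for the "depth" condition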
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}

ratios_map = {
    0.5: {"width": 704, "height": 1408},
    0.57: {"width": 768, "height": 1344},
    0.68: {"width": 832, "height": 1216},
    0.72: {"width": 832, "height": 1152},
    0.78: {"width": 896, "height": 1152},
    0.82: {"width": 896, "height": 1088},
    0.88: {"width": 960, "height": 1088},
    0.94: {"width": 960, "height": 1024},
    1.00: {"width": 1024, "height": 1024},
    1.13: {"width": 1088, "height": 960},
    1.21: {"width": 1088, "height": 896},
    1.29: {"width": 1152, "height": 896},
    1.38: {"width": 1152, "height": 832},
    1.46: {"width": 1216, "height": 832},
    1.67: {"width": 1280, "height": 768},
    1.75: {"width": 1344, "height": 768},
    2.00: {"width": 1408, "height": 704},
}
ratios = np.array(list(ratios_map.keys()))
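
# Load the Depth Anything V2 ViT-L checkpoint from the Hugging Face Hub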
encoder = 'vitl'
model = DepthAnythingV2(**model_configs[encoder])
filepath = hf_hub_download(repo_id="depth-anything/Depth-Anything-V2-Large", filename="depth_anything_v2_vitl.pth", repo_type="model")
state_dict = torch.load(filepath, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
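
# BRIA-4B-Adapt base model plus its Union ControlNet, wrapped in a multi-ControlNet container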
from diffusers.utils import load_image
from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
from pipeline_bria_controlnet import BriaControlNetPipeline
base_model = 'briaai/BRIA-4B-Adapt'
controlnet_model = 'briaai/BRIA-4B-Adapt-ControlNet-Union'
controlnet = BriaControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
controlnet = BriaMultiControlNetModel([controlnet])
pipe = BriaControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16, trust_remote_code=True)
pipe.to("cuda")
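
# Mode indices expected by the Union ControlNet; strength_mapping is not referenced below (the UI slider sets the strength)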
mode_mapping = {
    "depth": 0,
    "canny": 1,
    "colorgrid": 2,
    "recolor": 3,
    "tile": 4,
    "pose": 5,
}

strength_mapping = {
    "depth": 1.0,
    "canny": 1.0,
    "colorgrid": 1.0,
    "recolor": 1.0,
    "tile": 1.0,
    "pose": 1.0,
}
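
# OpenPose annotator for the "pose" mode; TF32 matmuls and model CPU offload reduce GPU memory pressure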
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
torch.backends.cuda.matmul.allow_tf32 = True
pipe.enable_model_cpu_offload() # for saving memory
def convert_from_image_to_cv2(img: Image) -> np.ndarray:
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)


def convert_from_cv2_to_image(img: np.ndarray) -> Image:
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
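
# Run Depth Anything V2 and normalize the prediction to an 8-bit RGB depth map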
def extract_depth(image):
    image = np.asarray(image)
    depth = model.infer_image(image[:, :, ::-1])
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)
    gray_depth = Image.fromarray(depth).convert('RGB')
    return gray_depth
def extract_openpose(img):
    processed_image_open_pose = open_pose(img, hand_and_face=True)
    return processed_image_open_pose
def extract_canny(input_image):
    image = np.array(input_image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)
    return canny_image
def convert_to_grayscale(image):
    image = convert_from_image_to_cv2(image)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Expand back to three channels so the pipeline receives an RGB PIL image
    gray_image = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))
    return gray_image
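
# Legacy resolution-based tiling helper (not referenced by the current UI)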
def tile_old(input_image, resolution=768):
    input_image = convert_from_image_to_cv2(input_image)
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 16.0)) * 16
    W = int(np.round(W / 16.0)) * 16
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    img = convert_from_cv2_to_image(img)
    return img
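
# Pixelate by downscaling and nearest-neighbor upscaling; used for the "tile" and "colorgrid" conditions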
def tile(downscale_factor, input_image):
    control_image = input_image.resize(
        (input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)
    ).resize(input_image.size, Image.NEAREST)
    return control_image
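
# Snap the input image to the closest supported aspect-ratio bucket from ratios_map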
def get_size(init_image):
    w, h = init_image.size
    curr_ratio = w / h
    ind = np.argmin(np.abs(curr_ratio - ratios))
    ratio = ratios[ind]
    chosen_ratio = ratios_map[ratio]
    w, h = chosen_ratio['width'], chosen_ratio['height']
    return w, h
def resize_img(image):
    image = image.convert('RGB')
    w, h = get_size(image)
    resized_image = image.resize((w, h))
    return resized_image
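
# Main generation entry point; runs on ZeroGPU with a 180-second allocation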
@spaces.GPU(duration=180)
def infer(cond_in, image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed, progress=gr.Progress(track_tqdm=True)):
    control_mode_num = mode_mapping[control_mode]

    if cond_in is None:
        if image_in is not None:
            image_in = resize_img(load_image(image_in))
            if control_mode == "canny":
                control_image = extract_canny(image_in)
            elif control_mode == "depth":
                control_image = extract_depth(image_in)
            elif control_mode == "pose":
                control_image = extract_openpose(image_in)
            elif control_mode == "colorgrid":
                control_image = tile(64, image_in)
            elif control_mode == "recolor":
                control_image = convert_to_grayscale(image_in)
            elif control_mode == "tile":
                control_image = tile(16, image_in)
    else:
        control_image = resize_img(load_image(cond_in))

    width, height = control_image.size

    image = pipe(
        prompt,
        control_image=[control_image],
        control_mode=[control_mode_num],
        width=width,
        height=height,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.manual_seed(seed),
    ).images[0]
    torch.cuda.empty_cache()
    # The click handler binds two outputs: the generated image and the preprocessed condition
    return image, control_image
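
# Gradio UI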
css="""
#col-container{
margin: 0 auto;
max-width: 1080px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# BRIA-4B-Adapt-ControlNet-Union
A unified ControlNet for the BRIA-4B-Adapt model from Bria.ai. BRIA-4B-Adapt improves the generation of humans and illustrations compared to BRIA 2.3 while still being trained on licensed data, and so provides full legal liability coverage for copyright and privacy infringement. Model card: [BRIA-4B-Adapt-ControlNet-Union](https://huggingface.co/briaai/BRIA-4B-Adapt-ControlNet-Union). <br />
""")
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row(equal_height=True):
                        cond_in = gr.Image(label="Upload a processed control image", sources=["upload"], type="filepath")
                        image_in = gr.Image(label="Extract condition from a reference image (Optional)", sources=["upload"], type="filepath")
                    prompt = gr.Textbox(label="Prompt", value="best quality")
                    with gr.Accordion("ControlNet"):
                        control_mode = gr.Radio(
                            ["depth", "canny", "colorgrid", "recolor", "tile", "pose"],
                            label="Mode",
                            value="canny",
                            info="Select the control mode; a single Union ControlNet handles all modes",
                        )
                        control_strength = gr.Slider(
                            label="Control strength",
                            minimum=0,
                            maximum=1.0,
                            step=0.05,
                            value=0.9,
                        )
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42,
                        )
                        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Accordion("Advanced settings", open=False):
                        with gr.Column():
                            with gr.Row():
                                inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=24)
                                guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=3.5)
                    submit_btn = gr.Button("Submit")
                with gr.Column():
                    result = gr.Image(label="Result")
                    processed_cond = gr.Image(label="Preprocessed Cond")

    submit_btn.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=infer,
        inputs=[cond_in, image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed],
        outputs=[result, processed_cond],
        show_api=False,
    )
demo.queue(api_open=False)
demo.launch()