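# Gradio demo app for briaai/BRIA-4B-Adapt-ControlNet-Union (a Hugging Face Space running on ZeroGPU).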
import sys
sys.path.append('./')
import gradio as gr
import spaces
import os
import subprocess
import numpy as np
from PIL import Image
import cv2
import torch
import random

os.system("pip install -e ./controlnet_aux")
from controlnet_aux import OpenposeDetector  # , CannyDetector

from depth_anything_v2.dpt import DepthAnythingV2
from huggingface_hub import hf_hub_download, login

hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)
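
# Seed handling: when "Randomize seed" is checked, a fresh random seed is drawn for each run.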
MAX_SEED = np.iinfo(np.int32).max

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
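
# Depth Anything V2 setup: pick the device and define the encoder configurations (ViT-S/B/L/G).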
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
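
# Supported aspect-ratio buckets: inputs are resized to the closest width/height pair (all multiples of 64).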
ratios_map = {
    0.5: {"width": 704, "height": 1408},
    0.57: {"width": 768, "height": 1344},
    0.68: {"width": 832, "height": 1216},
    0.72: {"width": 832, "height": 1152},
    0.78: {"width": 896, "height": 1152},
    0.82: {"width": 896, "height": 1088},
    0.88: {"width": 960, "height": 1088},
    0.94: {"width": 960, "height": 1024},
    1.00: {"width": 1024, "height": 1024},
    1.13: {"width": 1088, "height": 960},
    1.21: {"width": 1088, "height": 896},
    1.29: {"width": 1152, "height": 896},
    1.38: {"width": 1152, "height": 832},
    1.46: {"width": 1216, "height": 832},
    1.67: {"width": 1280, "height": 768},
    1.75: {"width": 1344, "height": 768},
    2.00: {"width": 1408, "height": 704}
}
ratios = np.array(list(ratios_map.keys()))
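
# Load the Depth Anything V2 Large (ViT-L) checkpoint from the Hub and move it to the selected device.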
encoder = 'vitl'
model = DepthAnythingV2(**model_configs[encoder])
filepath = hf_hub_download(repo_id="depth-anything/Depth-Anything-V2-Large", filename="depth_anything_v2_vitl.pth", repo_type="model")
state_dict = torch.load(filepath, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
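
# Load the BRIA-4B-Adapt base model and its Union ControlNet, wrapped as a multi-ControlNet with a single entry.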
from diffusers.utils import load_image
from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
from pipeline_bria_controlnet import BriaControlNetPipeline

base_model = 'briaai/BRIA-4B-Adapt'
controlnet_model = 'briaai/BRIA-4B-Adapt-ControlNet-Union'

controlnet = BriaControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
controlnet = BriaMultiControlNetModel([controlnet])
pipe = BriaControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16, trust_remote_code=True)
pipe.to("cuda")
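
# Union ControlNet mode indices expected by the pipeline, plus per-mode default conditioning strengths.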
mode_mapping = {
    "depth": 0,
    "canny": 1,
    "colorgrid": 2,
    "recolor": 3,
    "tile": 4,
    "pose": 5,
}

strength_mapping = {
    "depth": 1.0,
    "canny": 1.0,
    "colorgrid": 1.0,
    "recolor": 1.0,
    "tile": 1.0,
    "pose": 1.0,
}
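
# OpenPose annotator for the "pose" mode; TF32 matmul is enabled for speed and model CPU offload for memory.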
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")

torch.backends.cuda.matmul.allow_tf32 = True
pipe.enable_model_cpu_offload()  # for saving memory
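
# PIL <-> OpenCV conversion helpers (OpenCV uses BGR channel order, PIL uses RGB).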
def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
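
# Condition extractors: Depth Anything depth map, OpenPose skeleton, and Canny edges.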
def extract_depth(image):
    image = np.asarray(image)
    depth = model.infer_image(image[:, :, ::-1])
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)
    gray_depth = Image.fromarray(depth).convert('RGB')
    return gray_depth

def extract_openpose(img):
    processed_image_open_pose = open_pose(img, hand_and_face=True)
    return processed_image_open_pose

def extract_canny(input_image):
    image = np.array(input_image)
    image = cv2.Canny(image, 100, 200)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    canny_image = Image.fromarray(image)
    return canny_image
def convert_to_grayscale(image):
    image = convert_from_image_to_cv2(image)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Convert back to 3 channels so the BGR->RGB conversion in convert_from_cv2_to_image succeeds.
    gray_image = convert_from_cv2_to_image(cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR))
    return gray_image
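
# Tile / colorgrid conditioning: downscale then upscale with nearest-neighbor to get a blocky color reference.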
def tile_old(input_image, resolution=768):
    input_image = convert_from_image_to_cv2(input_image)
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 16.0)) * 16
    W = int(np.round(W / 16.0)) * 16
    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    img = convert_from_cv2_to_image(img)
    return img

def tile(downscale_factor, input_image):
    control_image = input_image.resize((input_image.size[0] // downscale_factor, input_image.size[1] // downscale_factor)).resize(input_image.size, Image.NEAREST)
    return control_image
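
# Resize inputs to the closest supported aspect-ratio bucket from ratios_map.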
def get_size(init_image):
    w, h = init_image.size
    curr_ratio = w / h
    ind = np.argmin(np.abs(curr_ratio - ratios))
    ratio = ratios[ind]
    chosen_ratio = ratios_map[ratio]
    w, h = chosen_ratio['width'], chosen_ratio['height']
    return w, h

def resize_img(image):
    image = image.convert('RGB')
    w, h = get_size(image)
    resized_image = image.resize((w, h))
    return resized_image
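
# Inference entry point. NOTE (assumption): this Space runs on ZeroGPU and imports `spaces`,
# so the handler is expected to carry the @spaces.GPU decorator to request a GPU per call;
# the decorator does not appear in the extracted source and is re-added here on that assumption.
@spaces.GPU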
def infer(cond_in, image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed, progress=gr.Progress(track_tqdm=True)):
    control_mode_num = mode_mapping[control_mode]

    if cond_in is None:
        if image_in is not None:
            image_in = resize_img(load_image(image_in))
            if control_mode == "canny":
                control_image = extract_canny(image_in)
            elif control_mode == "depth":
                control_image = extract_depth(image_in)
            elif control_mode == "pose":
                control_image = extract_openpose(image_in)
            elif control_mode == "colorgrid":
                control_image = tile(64, image_in)
            elif control_mode == "recolor":
                control_image = convert_to_grayscale(image_in)
            elif control_mode == "tile":
                control_image = tile(16, image_in)
    else:
        control_image = resize_img(load_image(cond_in))

    width, height = control_image.size

    image = pipe(
        prompt,
        control_image=[control_image],
        control_mode=[control_mode_num],
        width=width,
        height=height,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=inference_steps,
        guidance_scale=guidance_scale,
        generator=torch.manual_seed(seed),
    ).images[0]

    torch.cuda.empty_cache()
    # Return only the two values wired to the output components (result image and preprocessed condition).
    return image, control_image
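
# Gradio UI: upload a preprocessed control image (or a reference image to preprocess),
# choose the control mode and strength, and view the generated result next to the condition used.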
css=""" | |
#col-container{ | |
margin: 0 auto; | |
max-width: 1080px; | |
} | |
""" | |
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # BRIA-4B-Adapt-ControlNet-Union
        A unified ControlNet for the BRIA-4B-Adapt model from Bria.ai. BRIA-4B-Adapt improves the generation of humans and illustrations compared to BRIA 2.3, while still being trained on licensed data, and so provides full legal liability coverage for copyright and privacy infringement. Model card: [BRIA-4B-Adapt-ControlNet-Union](https://huggingface.co/briaai/BRIA-4B-Adapt-ControlNet-Union). <br />
        """)
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    with gr.Row(equal_height=True):
                        cond_in = gr.Image(label="Upload a processed control image", sources=["upload"], type="filepath")
                        image_in = gr.Image(label="Extract condition from a reference image (Optional)", sources=["upload"], type="filepath")
                    prompt = gr.Textbox(label="Prompt", value="best quality")

                    with gr.Accordion("Controlnet"):
                        control_mode = gr.Radio(
                            ["depth", "canny", "colorgrid", "recolor", "tile", "pose"], label="Mode", value="canny",
                            info="Select the control mode (a single Union ControlNet handles all of them)."
                        )

                        control_strength = gr.Slider(
                            label="Control strength",
                            minimum=0,
                            maximum=1.0,
                            step=0.05,
                            value=0.9,
                        )

                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42,
                        )

                        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    with gr.Accordion("Advanced settings", open=False):
                        with gr.Column():
                            with gr.Row():
                                inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=24)
                                guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=3.5)

                    submit_btn = gr.Button("Submit")
                with gr.Column():
                    result = gr.Image(label="Result")
                    processed_cond = gr.Image(label="Preprocessed Cond")

    submit_btn.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False
    ).then(
        fn=infer,
        inputs=[cond_in, image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed],
        outputs=[result, processed_cond],
        show_api=False
    )

demo.queue(api_open=False)
demo.launch()