Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -9,7 +9,139 @@ from huggingface_hub import snapshot_download 
     | 
|
| 9 | 
         
             
            from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
         
     | 
| 10 | 
         
             
            import math
         
     | 
| 11 | 
         
             
            from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 12 | 
         
             
            import spaces
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 13 | 
         | 
| 14 | 
         
             
            hf_token = os.getenv("HF_TOKEN")
         
     | 
| 15 | 
         | 
| 
         @@ -59,26 +191,31 @@ image_mask_list.sort() 
     | 
|
| 59 | 
         
             
            @spaces.GPU
         
     | 
| 60 | 
         
             
            def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option):
         
     | 
| 61 | 
         | 
| 
         | 
|
| 62 | 
         
             
                if base_mask_option == "Draw Mask":
         
     | 
| 63 | 
         
            -
                    tar_image = base_image[" 
     | 
| 64 | 
         
            -
                    tar_mask = base_image[" 
     | 
| 65 | 
         
             
                else:
         
     | 
| 66 | 
         
            -
                    tar_image = base_image[" 
     | 
| 67 | 
         
            -
                    tar_mask = base_mask
         
     | 
| 68 | 
         | 
| 69 | 
         
             
                if ref_mask_option == "Draw Mask":
         
     | 
| 70 | 
         
            -
                    ref_image = reference_image[" 
     | 
| 71 | 
         
            -
                    ref_mask = reference_image[" 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 72 | 
         
             
                else:
         
     | 
| 73 | 
         
            -
                    ref_image = reference_image[" 
     | 
| 74 | 
         
            -
                    ref_mask =  
     | 
| 75 | 
         
            -
             
     | 
| 76 | 
         | 
| 77 | 
         
             
                tar_image = tar_image.convert("RGB")
         
     | 
| 78 | 
         
             
                tar_mask = tar_mask.convert("L")
         
     | 
| 79 | 
         
             
                ref_image = ref_image.convert("RGB")
         
     | 
| 80 | 
         
             
                ref_mask = ref_mask.convert("L")
         
     | 
| 81 | 
         | 
| 
         | 
|
| 
         | 
|
| 82 | 
         
             
                tar_image = np.asarray(tar_image)
         
     | 
| 83 | 
         
             
                tar_mask = np.asarray(tar_mask)
         
     | 
| 84 | 
         
             
                tar_mask = np.where(tar_mask > 128, 1, 0).astype(np.uint8)
         
     | 
| 
         @@ -87,15 +224,20 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_ 
     | 
|
| 87 | 
         
             
                ref_mask = np.asarray(ref_mask)
         
     | 
| 88 | 
         
             
                ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
         
     | 
| 89 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 90 | 
         | 
| 91 | 
         
             
                ref_box_yyxx = get_bbox_from_mask(ref_mask)
         
     | 
| 92 | 
         
             
                ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)
         
     | 
| 93 | 
         
             
                masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1-ref_mask_3) 
         
     | 
| 94 | 
         
             
                y1,y2,x1,x2 = ref_box_yyxx
         
     | 
| 95 | 
         
            -
                masked_ref_image = masked_ref_image[y1:y2,x1:x2,:] 
     | 
| 96 | 
         
             
                ref_mask = ref_mask[y1:y2,x1:x2] 
         
     | 
| 97 | 
         
             
                ratio = 1.3
         
     | 
| 98 | 
         
            -
                masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio) 
     | 
| 99 | 
         | 
| 100 | 
         | 
| 101 | 
         
             
                masked_ref_image = pad_to_square(masked_ref_image, pad_value = 255, random = False) 
         
     | 
| 
         @@ -172,8 +314,10 @@ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_ 
     | 
|
| 172 | 
         
             
                edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop)) 
         
     | 
| 173 | 
         
             
                edited_image = Image.fromarray(edited_image)
         
     | 
| 174 | 
         | 
| 175 | 
         
            -
             
     | 
| 176 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 177 | 
         | 
| 178 | 
         
             
            def update_ui(option):
         
     | 
| 179 | 
         
             
                if option == "Draw Mask":
         
     | 
| 
         @@ -185,32 +329,37 @@ def update_ui(option): 
     | 
|
| 185 | 
         
             
            with gr.Blocks() as demo:
         
     | 
| 186 | 
         | 
| 187 | 
         | 
| 188 | 
         
            -
                gr.Markdown("# 
     | 
| 189 | 
         
            -
                gr.Markdown(" 
     | 
| 190 | 
         
            -
             
     | 
| 191 | 
         
            -
                gr.Markdown("### Only select one of these two methods. Don't forget to click the corresponding button!!")
         
     | 
| 192 | 
         | 
| 193 | 
         
             
                with gr.Row():
         
     | 
| 194 | 
         
            -
                    with gr.Column():
         
     | 
| 195 | 
         
             
                        with gr.Row():
         
     | 
| 196 | 
         
            -
                            base_image = gr. 
     | 
| 197 | 
         
            -
             
     | 
| 
         | 
|
| 198 | 
         | 
| 199 | 
         
            -
                            base_mask = gr. 
     | 
| 200 | 
         | 
| 201 | 
         
             
                        with gr.Row():
         
     | 
| 202 | 
         
             
                            base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option", value="Upload with Mask")
         
     | 
| 203 | 
         | 
| 204 | 
         
             
                        with gr.Row():
         
     | 
| 205 | 
         
            -
                            ref_image = gr. 
     | 
| 206 | 
         
            -
                                                 
     | 
| 
         | 
|
| 207 | 
         | 
| 208 | 
         
            -
                            ref_mask = gr. 
     | 
| 209 | 
         | 
| 210 | 
         
             
                        with gr.Row():
         
     | 
| 211 | 
         
            -
                            ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Reference Mask Input Option", value="Upload with Mask")
         
     | 
| 212 | 
         | 
| 213 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 214 | 
         
             
                        with gr.Accordion("Advanced Option", open=True):
         
     | 
| 215 | 
         
             
                            seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
         
     | 
| 216 | 
         
             
                            gr.Markdown("### Guidelines")
         
     | 
| 
         @@ -218,7 +367,6 @@ with gr.Blocks() as demo: 
     | 
|
| 218 | 
         | 
| 219 | 
         
             
                run_local_button = gr.Button(value="Run")
         
     | 
| 220 | 
         | 
| 221 | 
         
            -
             
     | 
| 222 | 
         
             
                # #### example #####
         
     | 
| 223 | 
         
             
                num_examples = len(image_list)
         
     | 
| 224 | 
         
             
                for i in range(num_examples):
         
     | 
| 
         @@ -234,12 +382,11 @@ with gr.Blocks() as demo: 
     | 
|
| 234 | 
         
             
                            gr.Examples([ref_list[i]], inputs=[ref_image], examples_per_page=1, label="")
         
     | 
| 235 | 
         
             
                            gr.Examples([ref_mask_list[i]], inputs=[ref_mask], examples_per_page=1, label="")
         
     | 
| 236 | 
         
             
                    if i < num_examples - 1:
         
     | 
| 237 | 
         
            -
                         
     | 
| 238 | 
         
            -
                            gr.HTML("<hr>")
         
     | 
| 239 | 
         
             
                # #### example #####
         
     | 
| 240 | 
         
            -
             
     | 
| 241 | 
         
            -
                run_local_button.click(fn=run_local, 
     | 
| 242 | 
         
            -
             
     | 
| 243 | 
         
            -
             
     | 
| 244 | 
         
            -
             
     | 
| 245 | 
         
             
            demo.launch()
         
     | 
| 
         | 
|
| 9 | 
         
             
            from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
         
     | 
| 10 | 
         
             
            import math
         
     | 
| 11 | 
         
             
            from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            import os,sys
         
     | 
| 14 | 
         
            +
            import subprocess


            def _run_setup_command(command):
                """Run one start-up command and report, rather than raise, on failure.

                Args:
                    command: Argument list handed to ``subprocess.run`` (no shell).

                Returns:
                    True when the command exited with status 0, False otherwise.
                """
                try:
                    status = subprocess.run(command, check=False).returncode
                except OSError as err:
                    # E.g. the executable is not installed. Setup stays best-effort,
                    # exactly as the previous os.system() calls were.
                    print(f"setup command {command!r} could not be started: {err}", file=sys.stderr)
                    return False
                if status != 0:
                    print(f"setup command {command!r} exited with status {status}", file=sys.stderr)
                    return False
                return True


            # Install the two local checkouts in editable mode. sys.executable is used
            # so the packages land in the environment of the interpreter running this
            # script, which is not guaranteed for whatever "python" resolves to on PATH.
            for _package_dir in ("segment_anything", "GroundingDINO"):
                _run_setup_command([sys.executable, "-m", "pip", "install", "-e", _package_dir])

            # Make both checkouts importable from the current working directory.
            sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
            sys.path.append(os.path.join(os.getcwd(), "segment_anything"))

            # Fetch the model checkpoints into the working directory. A checkpoint that
            # is already present is not fetched again.
            for _checkpoint_url in (
                "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swinb_cogcoor.pth",
                "https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth",
            ):
                _checkpoint_file = _checkpoint_url.rsplit("/", 1)[-1]
                if not os.path.exists(_checkpoint_file):
                    _run_setup_command(["wget", _checkpoint_url])
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            import torchvision
         
     | 
| 22 | 
         
            +
            from GroundingDINO.groundingdino.util.inference import load_model
         
     | 
| 23 | 
         
            +
            from segment_anything import build_sam, SamPredictor 
         
     | 
| 24 | 
         
             
            import spaces
         
     | 
| 25 | 
         
            +
            import GroundingDINO.groundingdino.datasets.transforms as T
         
     | 
| 26 | 
         
            +
            from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            # Paths of the GroundingDINO model config and weights, relative to the
            # working directory.
            GROUNDING_DINO_CONFIG_PATH = "./GroundingDINO_SwinB.cfg.py"
            GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swinb_cogcoor.pth"

            # Segment-Anything encoder name and weights path.
            # NOTE(review): SAM_ENCODER_VERSION is not referenced in the visible part
            # of this file; build_sam() below is called with the checkpoint only.
            SAM_ENCODER_VERSION = "vit_h"
            SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"

            # GroundingDINO inference model, created once at import time with
            # device="cpu".
            groundingdino_model = load_model(
                model_config_path=GROUNDING_DINO_CONFIG_PATH,
                model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
                device="cpu",
            )

            # SAM model and the predictor wrapping it; get_mask() uses sam_predictor.
            sam = build_sam(checkpoint=SAM_CHECKPOINT_PATH)
            sam_predictor = SamPredictor(sam)
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
            def transform_image(image_pil):
                """Turn a PIL image into the tensor that is fed to GroundingDINO.

                The image goes through three GroundingDINO transforms in order:
                ``RandomResize([800], max_size=1333)``, ``ToTensor()`` and
                ``Normalize`` with the mean/std constants listed below.

                Args:
                    image_pil: Input image (the caller in this file passes an RGB
                        PIL image).

                Returns:
                    The transformed image; the original comment describes its
                    layout as (3, h, w).
                """
                steps = [
                    T.RandomResize([800], max_size=1333),
                    T.ToTensor(),
                    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
                pipeline = T.Compose(steps)
                # The transforms take and return an (image, target) pair; there is no
                # target here, so None is passed and the second result is discarded.
                tensor, _unused_target = pipeline(image_pil, None)
                return tensor
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
            def get_grounding_output(model, image, caption, box_threshold=0.25, text_threshold=0.25, with_logits=True):
                """Run the grounding model on one image and keep the confident boxes.

                Args:
                    model: Callable invoked as ``model(batch, captions=[...])`` that
                        returns a mapping with ``"pred_logits"`` and ``"pred_boxes"``;
                        it must also expose a ``tokenizer`` attribute.
                    image: Image tensor without a batch axis (one is added here).
                    caption: Text prompt. It is lower-cased, stripped, and given a
                        trailing "." when it does not already end with one.
                    box_threshold: A query is kept when its highest sigmoid logit is
                        greater than this value.
                    text_threshold: Per-token threshold applied to a kept query's
                        logits before they are turned into a phrase.
                    with_logits: When true, each phrase is suffixed with the query
                        score in parentheses, cut to the first 4 characters.

                Returns:
                    A tuple ``(boxes, scores, phrases)``: the kept rows of
                    ``pred_boxes``, a ``torch.Tensor`` holding one score per kept
                    row, and a list with one phrase per kept row.
                """
                caption = caption.lower().strip()
                if not caption.endswith("."):
                    caption = caption + "."

                with torch.no_grad():
                    outputs = model(image[None], captions=[caption])
                logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
                boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

                # Boolean-mask indexing already yields new tensors, so the results can
                # be modified by callers without touching the model outputs.
                keep = logits.max(dim=1)[0] > box_threshold
                logits_filt = logits[keep]  # num_filt, 256
                boxes_filt = boxes[keep]  # num_filt, 4

                # Build one phrase and one score per kept query.
                tokenizer = model.tokenizer
                tokenized = tokenizer(caption)
                pred_phrases = []
                scores = []
                for logit in logits_filt:
                    phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
                    score = logit.max().item()
                    if with_logits:
                        phrase = phrase + f"({str(score)[:4]})"
                    pred_phrases.append(phrase)
                    scores.append(score)

                return boxes_filt, torch.Tensor(scores), pred_phrases
         
     | 
| 94 | 
         
            +
             
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
            def get_mask(image, label):
                """Segment the object named by *label* in *image*.

                The label is grounded to boxes with ``groundingdino_model``, the boxes
                are de-duplicated with NMS, and ``sam_predictor`` predicts one mask
                per remaining box. The mask of the first remaining box is returned.

                Args:
                    image: PIL image; it is converted to RGB before use.
                    label: Text describing the object to segment.

                Returns:
                    A PIL image built from the selected mask array.

                Raises:
                    gr.Error: If *label* is empty or no box is detected for it.
                """
                if not label or not label.strip():
                    raise gr.Error("Please enter a label describing the object in the reference image.")

                image_pil = image.convert("RGB")
                transformed_image = transform_image(image_pil)

                boxes_filt, scores, _ = get_grounding_output(
                    groundingdino_model, transformed_image, label
                )
                # Without this guard an empty detection only fails further down, with
                # an error that says nothing about the label.
                if boxes_filt.size(0) == 0:
                    raise gr.Error(f"No object matching the label '{label}' was found in the reference image.")

                # Scale the boxes by the image size, then turn (cx, cy, w, h) rows into
                # (x1, y1, x2, y2) corner rows.
                width, height = image_pil.size
                boxes_filt = boxes_filt * torch.Tensor([width, height, width, height])
                boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2
                boxes_filt[:, 2:] += boxes_filt[:, :2]
                boxes_filt = boxes_filt.cpu()

                # Drop boxes that overlap a higher-scoring one (IoU threshold 0.8).
                nms_idx = torchvision.ops.nms(boxes_filt, scores, 0.8).numpy().tolist()
                boxes_filt = boxes_filt[nms_idx]

                image_array = np.array(image_pil)
                sam_predictor.set_image(image_array)
                transformed_boxes = sam_predictor.transform.apply_boxes_torch(
                    boxes_filt, image_array.shape[:2]
                )
                masks, _, _ = sam_predictor.predict_torch(
                    point_coords=None,
                    point_labels=None,
                    boxes=transformed_boxes,
                    multimask_output=False,
                )

                # First box after NMS, first (only) mask predicted for it.
                result_mask = masks[0][0].cpu().numpy()
                return Image.fromarray(result_mask)
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         | 
| 146 | 
         
             
            hf_token = os.getenv("HF_TOKEN")
         
     | 
| 147 | 
         | 
| 
         | 
|
| 191 | 
         
             
            @spaces.GPU
         
     | 
| 192 | 
         
             
            def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option):
         
     | 
| 193 | 
         | 
| 194 | 
         
            +
             
     | 
| 195 | 
         
             
                if base_mask_option == "Draw Mask":
         
     | 
| 196 | 
         
            +
                    tar_image = base_image["background"]
         
     | 
| 197 | 
         
            +
                    tar_mask = base_image["layers"][0]
         
     | 
| 198 | 
         
             
                else:
         
     | 
| 199 | 
         
            +
                    tar_image = base_image["background"]
         
     | 
| 200 | 
         
            +
                    tar_mask = base_mask["background"]
         
     | 
| 201 | 
         | 
| 202 | 
         
             
                if ref_mask_option == "Draw Mask":
         
     | 
| 203 | 
         
            +
                    ref_image = reference_image["background"]
         
     | 
| 204 | 
         
            +
                    ref_mask = reference_image["layers"][0]
         
     | 
| 205 | 
         
            +
                elif ref_mask_option == "Upload with Mask":
         
     | 
| 206 | 
         
            +
                    ref_image = reference_image["background"]
         
     | 
| 207 | 
         
            +
                    ref_mask = ref_mask["background"]
         
     | 
| 208 | 
         
             
                else:
         
     | 
| 209 | 
         
            +
                    ref_image = reference_image["background"]
         
     | 
| 210 | 
         
            +
                    ref_mask = get_mask(ref_image, text_prompt)
         
     | 
| 
         | 
|
| 211 | 
         | 
| 212 | 
         
             
                tar_image = tar_image.convert("RGB")
         
     | 
| 213 | 
         
             
                tar_mask = tar_mask.convert("L")
         
     | 
| 214 | 
         
             
                ref_image = ref_image.convert("RGB")
         
     | 
| 215 | 
         
             
                ref_mask = ref_mask.convert("L")
         
     | 
| 216 | 
         | 
| 217 | 
         
            +
                return_ref_mask = ref_mask.copy()
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
             
                tar_image = np.asarray(tar_image)
         
     | 
| 220 | 
         
             
                tar_mask = np.asarray(tar_mask)
         
     | 
| 221 | 
         
             
                tar_mask = np.where(tar_mask > 128, 1, 0).astype(np.uint8)
         
     | 
| 
         | 
|
| 224 | 
         
             
                ref_mask = np.asarray(ref_mask)
         
     | 
| 225 | 
         
             
                ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
         
     | 
| 226 | 
         | 
| 227 | 
         
            +
                if tar_mask.sum() == 0:
         
     | 
| 228 | 
         
            +
                    raise gr.Error('No mask for the background image.Please check mask button!')
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
                if ref_mask.sum() == 0:
         
     | 
| 231 | 
         
            +
                    raise gr.Error('No mask for the reference image.Please check mask button!')
         
     | 
| 232 | 
         | 
| 233 | 
         
             
                ref_box_yyxx = get_bbox_from_mask(ref_mask)
         
     | 
| 234 | 
         
             
                ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)
         
     | 
| 235 | 
         
             
                masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1-ref_mask_3) 
         
     | 
| 236 | 
         
             
                y1,y2,x1,x2 = ref_box_yyxx
         
     | 
| 237 | 
         
            +
                masked_ref_image = masked_ref_image[y1:y2,x1:x2,:]
         
     | 
| 238 | 
         
             
                ref_mask = ref_mask[y1:y2,x1:x2] 
         
     | 
| 239 | 
         
             
                ratio = 1.3
         
     | 
| 240 | 
         
            +
                masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
         
     | 
| 241 | 
         | 
| 242 | 
         | 
| 243 | 
         
             
                masked_ref_image = pad_to_square(masked_ref_image, pad_value = 255, random = False) 
         
     | 
| 
         | 
|
| 314 | 
         
             
                edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop)) 
         
     | 
| 315 | 
         
             
                edited_image = Image.fromarray(edited_image)
         
     | 
| 316 | 
         | 
| 317 | 
         
            +
                if ref_mask_option != "Label to Mask":
         
     | 
| 318 | 
         
            +
                    return [edited_image]
         
     | 
| 319 | 
         
            +
                else:
         
     | 
| 320 | 
         
            +
                    return [return_ref_mask, edited_image]   
         
     | 
| 321 | 
         | 
| 322 | 
         
             
            def update_ui(option):
         
     | 
| 323 | 
         
             
                if option == "Draw Mask":
         
     | 
| 
         | 
|
| 329 | 
         
             
            with gr.Blocks() as demo:
         
     | 
| 330 | 
         | 
| 331 | 
         | 
| 332 | 
         
            +
                gr.Markdown("# Insert-Anything")
         
     | 
| 333 | 
         
            +
                gr.Markdown("### Draw mask or upload mask.Only select one of these two methods. Don't forget to click the corresponding button!!")
         
     | 
| 334 | 
         
            +
             
     | 
| 
         | 
|
| 335 | 
         | 
| 336 | 
         
             
                with gr.Row():
         
     | 
| 337 | 
         
            +
                    with gr.Column(scale=1):
         
     | 
| 338 | 
         
             
                        with gr.Row():
         
     | 
| 339 | 
         
            +
                            base_image = gr.ImageEditor(label="Background Image", sources="upload", type="pil", brush=gr.Brush(colors=["#FFFFFF"],default_size = 30,color_mode = "fixed"),
         
     | 
| 340 | 
         
            +
                                                layers = False,
         
     | 
| 341 | 
         
            +
                                                interactive=True)
         
     | 
| 342 | 
         | 
| 343 | 
         
            +
                            base_mask = gr.ImageEditor(label="Background Mask", sources="upload", type="pil", layers = False, brush=False, eraser=False)
         
     | 
| 344 | 
         | 
| 345 | 
         
             
                        with gr.Row():
         
     | 
| 346 | 
         
             
                            base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option", value="Upload with Mask")
         
     | 
| 347 | 
         | 
| 348 | 
         
             
                        with gr.Row():
         
     | 
| 349 | 
         
            +
                            ref_image = gr.ImageEditor(label="Reference Image", sources="upload", type="pil", brush=gr.Brush(colors=["#FFFFFF"],default_size = 30,color_mode = "fixed"),
         
     | 
| 350 | 
         
            +
                                                layers = False,
         
     | 
| 351 | 
         
            +
                                                interactive=True)
         
     | 
| 352 | 
         | 
| 353 | 
         
            +
                            ref_mask = gr.ImageEditor(label="Reference Mask", sources="upload", type="pil", layers = False, brush=False, eraser=False)
         
     | 
| 354 | 
         | 
| 355 | 
         
             
                        with gr.Row():
         
     | 
| 356 | 
         
            +
                            ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask", "Label to Mask"], label="Reference Mask Input Option", value="Upload with Mask")
         
     | 
| 357 | 
         | 
| 358 | 
         
            +
                        with gr.Row():
         
     | 
| 359 | 
         
            +
                            text_prompt = gr.Textbox(label="Label")
         
     | 
| 360 | 
         
            +
             
     | 
| 361 | 
         
            +
                    with gr.Column(scale=1):
         
     | 
| 362 | 
         
            +
                        baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=701, columns=1)
         
     | 
| 363 | 
         
             
                        with gr.Accordion("Advanced Option", open=True):
         
     | 
| 364 | 
         
             
                            seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
         
     | 
| 365 | 
         
             
                            gr.Markdown("### Guidelines")
         
     | 
| 
         | 
|
| 367 | 
         | 
| 368 | 
         
             
                run_local_button = gr.Button(value="Run")
         
     | 
| 369 | 
         | 
| 
         | 
|
| 370 | 
         
             
                # #### example #####
         
     | 
| 371 | 
         
             
                num_examples = len(image_list)
         
     | 
| 372 | 
         
             
                for i in range(num_examples):
         
     | 
| 
         | 
|
| 382 | 
         
             
                            gr.Examples([ref_list[i]], inputs=[ref_image], examples_per_page=1, label="")
         
     | 
| 383 | 
         
             
                            gr.Examples([ref_mask_list[i]], inputs=[ref_mask], examples_per_page=1, label="")
         
     | 
| 384 | 
         
             
                    if i < num_examples - 1:
         
     | 
| 385 | 
         
            +
                        gr.HTML("<hr>")
         
     | 
| 
         | 
|
| 386 | 
         
             
                # #### example #####
         
     | 
| 387 | 
         
            +
             
     | 
| 388 | 
         
            +
                run_local_button.click(fn=run_local,
         
     | 
| 389 | 
         
            +
                                        inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt],
         
     | 
| 390 | 
         
            +
                                        outputs=[baseline_gallery]
         
     | 
| 391 | 
         
            +
                                        )
         
     | 
| 392 | 
         
             
            demo.launch()
         
     |