File size: 4,569 Bytes
90b6afe
47a1b08
20d069f
50d0879
 
20d069f
50d0879
 
20d069f
 
50d0879
 
 
 
 
20d069f
50d0879
 
20d069f
 
 
 
 
 
00f7e20
 
 
 
50d0879
 
20d069f
 
 
 
 
 
 
 
50d0879
 
 
 
20d069f
 
38264cb
 
 
 
 
20d069f
50d0879
20d069f
38264cb
20d069f
 
 
 
 
38264cb
20d069f
 
 
38264cb
 
 
20d069f
 
38264cb
 
20d069f
 
 
 
 
 
 
 
 
 
 
 
 
 
38264cb
20d069f
 
38264cb
26208b0
 
20d069f
 
 
 
 
 
 
 
 
 
 
 
 
38264cb
20d069f
 
 
 
38264cb
20d069f
 
38264cb
20d069f
 
50d0879
00f7e20
20d069f
00f7e20
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def _editor_background(editor_value, description):
    """Return the ``"background"`` entry of an image-editor value, or raise ``gr.Error`` if it is missing."""
    if editor_value is None or editor_value["background"] is None:
        raise gr.Error(f'No {description} provided. Please upload it first!')
    return editor_value["background"]


def _editor_first_layer(editor_value, description):
    """Return the first entry of ``"layers"`` of an image-editor value, or raise ``gr.Error`` if there is none."""
    layers = None if editor_value is None else editor_value["layers"]
    if not layers:
        raise gr.Error(f'No mask drawn on the {description}. Please draw one first!')
    return layers[0]


@spaces.GPU
def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option, text_prompt):
    """Insert the masked reference object into the masked region of the background image.

    Args:
        base_image: Image-editor value for the background; its ``"background"``
            entry is the target image and, when ``base_mask_option`` is
            ``"Draw Mask"``, its first ``"layers"`` entry is the target mask.
        base_mask: Image-editor value whose ``"background"`` entry is the
            target mask when ``base_mask_option`` is not ``"Draw Mask"``.
        reference_image: Image-editor value for the reference; used like
            ``base_image`` for the reference image / drawn mask.
        ref_mask: Image-editor value whose ``"background"`` entry is the
            reference mask when ``ref_mask_option`` is ``"Upload with Mask"``.
        seed: Seed for the CUDA random generator passed to ``pipe``.
        base_mask_option: ``"Draw Mask"`` or anything else (uploaded mask).
        ref_mask_option: ``"Draw Mask"``, ``"Upload with Mask"``, or anything
            else, in which case the mask comes from ``get_mask(ref_image, text_prompt)``.
        text_prompt: Forwarded to ``get_mask`` in that last case only.

    Returns:
        ``[diptych_preview, edited_image, target_mask, reference_mask]`` when
        ``ref_mask_option != "Label to Mask"``; otherwise the same list with an
        extra copy of the reference mask prepended. All items are PIL images.

    Raises:
        gr.Error: If a required image/mask is missing, or if either mask is
            empty after thresholding.
    """
    # --- Pick the target (background) image and mask -------------------------
    tar_image = _editor_background(base_image, "background image")
    if base_mask_option == "Draw Mask":
        tar_mask = _editor_first_layer(base_image, "background image")
    else:
        tar_mask = _editor_background(base_mask, "background mask")

    # --- Pick the reference image and mask -----------------------------------
    ref_image = _editor_background(reference_image, "reference image")
    if ref_mask_option == "Draw Mask":
        ref_mask = _editor_first_layer(reference_image, "reference image")
    elif ref_mask_option == "Upload with Mask":
        ref_mask = _editor_background(ref_mask, "reference mask")
    else:
        ref_mask = get_mask(ref_image, text_prompt)

    tar_image = tar_image.convert("RGB")
    tar_mask = tar_mask.convert("L")
    ref_image = ref_image.convert("RGB")
    ref_mask = ref_mask.convert("L")

    # Keep PIL copies of the masks as received; the names are rebound to
    # numpy arrays below.
    received_tar_mask = tar_mask.copy()
    received_ref_mask = ref_mask.copy()

    # Separate copy that is returned first in the "Label to Mask" case.
    return_ref_mask = ref_mask.copy()

    # Binarise both masks: pixels brighter than 128 become 1, the rest 0.
    tar_image = np.asarray(tar_image)
    tar_mask = np.asarray(tar_mask)
    tar_mask = np.where(tar_mask > 128, 1, 0).astype(np.uint8)

    ref_image = np.asarray(ref_image)
    ref_mask = np.asarray(ref_mask)
    ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)

    if tar_mask.sum() == 0:
        raise gr.Error('No mask for the background image.Please check mask button!')
    if ref_mask.sum() == 0:
        raise gr.Error('No mask for the reference image.Please check mask button!')

    # --- Reference: white-out everything outside the mask, crop to its bbox --
    ref_box_yyxx = get_bbox_from_mask(ref_mask)
    ref_mask_3 = np.stack([ref_mask, ref_mask, ref_mask], -1)
    masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1 - ref_mask_3)
    y1, y2, x1, x2 = ref_box_yyxx
    masked_ref_image = masked_ref_image[y1:y2, x1:x2, :]
    ref_mask = ref_mask[y1:y2, x1:x2]
    ratio = 1.3
    masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)

    masked_ref_image = pad_to_square(masked_ref_image, pad_value=255, random=False)

    # --- Target: dilate the mask (7x7 kernel, two passes) --------------------
    kernel = np.ones((7, 7), np.uint8)
    iterations = 2
    tar_mask = cv2.dilate(tar_mask, kernel, iterations=iterations)

    # zoom in
    tar_box_yyxx = get_bbox_from_mask(tar_mask)
    tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=1.2)

    tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=2)
    tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop)  # crop box
    y1, y2, x1, x2 = tar_box_yyxx_crop

    # The uncropped target is needed later to paste the result back.
    old_tar_image = tar_image.copy()
    tar_image = tar_image[y1:y2, x1:x2, :]
    tar_mask = tar_mask[y1:y2, x1:x2]

    # Crop size before padding (H1, W1) and after padding (H2, W2) are both
    # handed to crop_back.
    H1, W1 = tar_image.shape[0], tar_image.shape[1]

    tar_mask = pad_to_square(tar_mask, pad_value=0)
    tar_mask = cv2.resize(tar_mask, size)

    masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), size).astype(np.uint8)
    pipe_prior_output = redux(Image.fromarray(masked_ref_image))

    tar_image = pad_to_square(tar_image, pad_value=255)
    H2, W2 = tar_image.shape[0], tar_image.shape[1]
    tar_image = cv2.resize(tar_image, size)

    # --- Build the diptych: reference on the left, target on the right -------
    diptych_ref_tar = np.concatenate([masked_ref_image, tar_image], axis=1)

    tar_mask = np.stack([tar_mask, tar_mask, tar_mask], -1)
    # The left (reference) half of the mask is all zeros.
    mask_black = np.zeros_like(tar_image)
    mask_diptych = np.concatenate([mask_black, tar_mask], axis=1)

    show_diptych_ref_tar = create_highlighted_mask(diptych_ref_tar, mask_diptych)
    show_diptych_ref_tar = Image.fromarray(show_diptych_ref_tar)

    diptych_ref_tar = Image.fromarray(diptych_ref_tar)
    # Scale the 0/1 mask to 0/255 before converting it to an image.
    mask_diptych[mask_diptych == 1] = 255
    mask_diptych = Image.fromarray(mask_diptych)

    # manual_seed needs an int; UI widgets may deliver the seed as a float.
    generator = torch.Generator("cuda").manual_seed(int(seed))
    edited_image = pipe(
        image=diptych_ref_tar,
        mask_image=mask_diptych,
        height=mask_diptych.size[1],
        width=mask_diptych.size[0],
        max_sequence_length=512,
        generator=generator,
        **pipe_prior_output,
    ).images[0]

    # Keep only the right half of the generated diptych (the target side).
    width, height = edited_image.size
    left = width // 2
    edited_image = edited_image.crop((left, 0, width, height))

    edited_image = np.array(edited_image)
    edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
    edited_image = Image.fromarray(edited_image)

    if ref_mask_option != "Label to Mask":
        return [show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]
    else:
        return [return_ref_mask, show_diptych_ref_tar, edited_image, received_tar_mask, received_ref_mask]