Spaces:
Running
Running
Add other style types
Browse files
app.py
CHANGED
|
@@ -129,15 +129,22 @@ def postprocess(tensor: torch.Tensor) -> PIL.Image.Image:
|
|
| 129 |
@torch.inference_mode()
|
| 130 |
def run(
|
| 131 |
image,
|
| 132 |
-
|
|
|
|
| 133 |
dlib_landmark_model,
|
| 134 |
encoder: nn.Module,
|
| 135 |
-
|
| 136 |
-
|
| 137 |
transform: Callable,
|
| 138 |
device: torch.device,
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
stylename = list(exstyles.keys())[style_id]
|
| 142 |
|
| 143 |
image = align_face(filepath=image.name, predictor=dlib_landmark_model)
|
|
@@ -181,7 +188,11 @@ def run(
|
|
| 181 |
img_gen1 = postprocess(img_gen[1])
|
| 182 |
img_gen2 = postprocess(img_gen2[0])
|
| 183 |
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
return image, style_image, img_rec, img_gen0, img_gen1, img_gen2
|
| 187 |
|
|
@@ -192,43 +203,60 @@ def main():
|
|
| 192 |
args = parse_args()
|
| 193 |
device = torch.device(args.device)
|
| 194 |
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
download_cartoon_images()
|
| 199 |
dlib_landmark_model = create_dlib_landmark_model()
|
| 200 |
encoder = load_encoder(device)
|
| 201 |
-
generator = load_generator(style_type, device)
|
| 202 |
-
exstyles = load_exstylecode(style_type)
|
| 203 |
transform = create_transform()
|
| 204 |
|
| 205 |
func = functools.partial(run,
|
| 206 |
dlib_landmark_model=dlib_landmark_model,
|
| 207 |
encoder=encoder,
|
| 208 |
-
|
| 209 |
-
|
| 210 |
transform=transform,
|
| 211 |
-
device=device
|
| 212 |
-
style_image_dir=style_image_dir)
|
| 213 |
func = functools.update_wrapper(func, run)
|
| 214 |
|
| 215 |
repo_url = 'https://github.com/williamyang1991/DualStyleGAN'
|
| 216 |
title = 'williamyang1991/DualStyleGAN'
|
| 217 |
description = f"""A demo for {repo_url}
|
| 218 |
|
| 219 |
-
You can select style images from the table below.
|
| 220 |
"""
|
| 221 |
article = ''
|
| 222 |
|
| 223 |
image_paths = sorted(pathlib.Path('images').glob('*'))
|
| 224 |
-
examples = [[path.as_posix(), 26] for path in image_paths]
|
| 225 |
|
| 226 |
gr.Interface(
|
| 227 |
func,
|
| 228 |
[
|
| 229 |
-
gr.inputs.Image(type='file', label='Image'),
|
| 230 |
-
gr.inputs.
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
],
|
| 233 |
[
|
| 234 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|
|
|
|
| 129 |
@torch.inference_mode()
|
| 130 |
def run(
|
| 131 |
image,
|
| 132 |
+
style_type: str,
|
| 133 |
+
style_id: float,
|
| 134 |
dlib_landmark_model,
|
| 135 |
encoder: nn.Module,
|
| 136 |
+
generator_dict: dict[str, nn.Module],
|
| 137 |
+
exstyle_dict: dict[str, dict[str, np.ndarray]],
|
| 138 |
transform: Callable,
|
| 139 |
device: torch.device,
|
| 140 |
+
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image,
|
| 141 |
+
PIL.Image, PIL.Image]:
|
| 142 |
+
generator = generator_dict[style_type]
|
| 143 |
+
exstyles = exstyle_dict[style_type]
|
| 144 |
+
|
| 145 |
+
style_id = int(style_id)
|
| 146 |
+
style_id = min(max(0, style_id), len(exstyles) - 1)
|
| 147 |
+
|
| 148 |
stylename = list(exstyles.keys())[style_id]
|
| 149 |
|
| 150 |
image = align_face(filepath=image.name, predictor=dlib_landmark_model)
|
|
|
|
| 188 |
img_gen1 = postprocess(img_gen[1])
|
| 189 |
img_gen2 = postprocess(img_gen2[0])
|
| 190 |
|
| 191 |
+
try:
|
| 192 |
+
style_image_dir = pathlib.Path(style_type)
|
| 193 |
+
style_image = PIL.Image.open(style_image_dir / stylename)
|
| 194 |
+
except Exception:
|
| 195 |
+
style_image = None
|
| 196 |
|
| 197 |
return image, style_image, img_rec, img_gen0, img_gen1, img_gen2
|
| 198 |
|
|
|
|
| 203 |
args = parse_args()
|
| 204 |
device = torch.device(args.device)
|
| 205 |
|
| 206 |
+
style_types = [
|
| 207 |
+
'cartoon',
|
| 208 |
+
'caricature',
|
| 209 |
+
'anime',
|
| 210 |
+
'arcane',
|
| 211 |
+
'comic',
|
| 212 |
+
'pixar',
|
| 213 |
+
'slamdunk',
|
| 214 |
+
]
|
| 215 |
+
generator_dict = {
|
| 216 |
+
style_type: load_generator(style_type, device)
|
| 217 |
+
for style_type in style_types
|
| 218 |
+
}
|
| 219 |
+
exstyle_dict = {
|
| 220 |
+
style_type: load_exstylecode(style_type)
|
| 221 |
+
for style_type in style_types
|
| 222 |
+
}
|
| 223 |
|
| 224 |
download_cartoon_images()
|
| 225 |
dlib_landmark_model = create_dlib_landmark_model()
|
| 226 |
encoder = load_encoder(device)
|
|
|
|
|
|
|
| 227 |
transform = create_transform()
|
| 228 |
|
| 229 |
func = functools.partial(run,
|
| 230 |
dlib_landmark_model=dlib_landmark_model,
|
| 231 |
encoder=encoder,
|
| 232 |
+
generator_dict=generator_dict,
|
| 233 |
+
exstyle_dict=exstyle_dict,
|
| 234 |
transform=transform,
|
| 235 |
+
device=device)
|
|
|
|
| 236 |
func = functools.update_wrapper(func, run)
|
| 237 |
|
| 238 |
repo_url = 'https://github.com/williamyang1991/DualStyleGAN'
|
| 239 |
title = 'williamyang1991/DualStyleGAN'
|
| 240 |
description = f"""A demo for {repo_url}
|
| 241 |
|
| 242 |
+
You can select style images for cartoon from the table below.
|
| 243 |
"""
|
| 244 |
article = ''
|
| 245 |
|
| 246 |
image_paths = sorted(pathlib.Path('images').glob('*'))
|
| 247 |
+
examples = [[path.as_posix(), 'cartoon', 26] for path in image_paths]
|
| 248 |
|
| 249 |
gr.Interface(
|
| 250 |
func,
|
| 251 |
[
|
| 252 |
+
gr.inputs.Image(type='file', label='Input Image'),
|
| 253 |
+
gr.inputs.Radio(
|
| 254 |
+
style_types,
|
| 255 |
+
type='value',
|
| 256 |
+
default='cartoon',
|
| 257 |
+
label='Style Type',
|
| 258 |
+
),
|
| 259 |
+
gr.inputs.Number(default=26, label='Style Image Index'),
|
| 260 |
],
|
| 261 |
[
|
| 262 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|