Ahsen Khaliq committed
Commit 2a30e2f · Parent: c616c46

use facexlib

Files changed (1): app.py (+27 −9)
app.py CHANGED
@@ -3,7 +3,11 @@ from PIL import Image
 import torch
 import gradio as gr
 os.system("pip install gradio==2.5.3")
-os.system("pip install autocrop")
+
+os.system("pip install facexlib")
+
+from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+#os.system("pip install autocrop")
 #os.system("pip install dlib")
 from autocrop import Cropper
 import torch
@@ -34,7 +38,15 @@ os.makedirs('models', exist_ok=True)
 #os.system("bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2")
 #os.system("mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat")
 
-cropper = Cropper(face_percent=80)
+#cropper = Cropper(face_percent=80)
+
+face_helper = FaceRestoreHelper(
+    upscale_factor=0,
+    face_size=512,
+    crop_ratio=(1, 1),
+    det_model='retinaface_resnet50',
+    save_ext='png',
+    device='cpu')
 
 device = 'cpu'
 
@@ -85,15 +97,21 @@ generatorjinx.load_state_dict(ckptjinx["g"], strict=False)
 
 
 def inference(img, model):
+    face_helper.clean_all()
     #aligned_face = align_face(img)
-    cropped_array = cropper.crop(img[:,:,::-1])
+    #cropped_array = cropper.crop(img[:,:,::-1])
 
-    if cropped_array.any():
-        aligned_face = Image.fromarray(cropped_array)
-    else:
-        aligned_face = Image.fromarray(img[:,:,::-1])
-
-    my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)
+    #if cropped_array.any():
+    #aligned_face = Image.fromarray(cropped_array)
+    #else:
+    #aligned_face = Image.fromarray(img[:,:,::-1])
+
+    face_helper.read_image(img)
+    face_helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
+    face_helper.align_warp_face(save_cropped_path=".")
+    pilimg = Image.open("./_00.png")
+
+    my_w = e4e_projection(pilimg, "test.pt", device).unsqueeze(0)
     if model == 'JoJo':
        with torch.no_grad():
            my_sample = generatorjojo(my_w, input_is_latent=True)
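
For context, below is a minimal standalone sketch of the cropping step this commit switches to, reusing the facexlib calls that appear in the diff (FaceRestoreHelper, clean_all, read_image, get_face_landmarks_5, align_warp_face). The input file name "input.jpg", the cv2.imread call, and the final save are illustrative assumptions and are not part of the commit.

# Sketch of the facexlib-based face cropping introduced by this commit.
# Assumed (not from the commit): the image comes from "input.jpg" via OpenCV
# and the first aligned crop is re-saved as "aligned.png".
import cv2
from PIL import Image
from facexlib.utils.face_restoration_helper import FaceRestoreHelper

# Same constructor arguments as in the commit: RetinaFace detection,
# 512x512 crops, CPU inference.
face_helper = FaceRestoreHelper(
    upscale_factor=0,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='png',
    device='cpu')

img = cv2.imread('input.jpg')  # BGR array, the layout facexlib works with

face_helper.clean_all()        # reset any state left over from a previous image
face_helper.read_image(img)
face_helper.get_face_landmarks_5(only_center_face=False, eye_dist_threshold=5)
face_helper.align_warp_face(save_cropped_path='.')  # saves crops as ./_00.png, ./_01.png, ...

aligned_face = Image.open('./_00.png')  # first detected face, as in the commit
aligned_face.save('aligned.png')

Going through the saved file mirrors the diff exactly; facexlib also keeps the crops in memory on face_helper.cropped_faces (a list of BGR arrays), which would avoid the disk round-trip, but the commit opts for the file-based path.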