hylee commited on
Commit
46ba26b
·
1 Parent(s): d54e3d6
Files changed (2) hide show
  1. app.py +73 -35
  2. requirements.txt +10 -8
app.py CHANGED
@@ -22,13 +22,13 @@ import shutil
22
  from options.test_options import TestOptions
23
  from data import CreateDataLoader
24
  from models import create_model
25
-
 
26
  from util import html
27
 
28
  import ntpath
29
  from util import util
30
 
31
-
32
  ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'
33
  TITLE = 'yiranran/APDrawingGAN2'
34
  DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.
@@ -38,9 +38,9 @@ ARTICLE = """
38
 
39
  """
40
 
41
-
42
  MODEL_REPO = 'hylee/apdrawing_model'
43
 
 
44
  def parse_args() -> argparse.Namespace:
45
  parser = argparse.ArgumentParser()
46
  parser.add_argument('--device', type=str, default='cpu')
@@ -59,14 +59,15 @@ def parse_args() -> argparse.Namespace:
59
  def load_checkpoint():
60
  dir = 'checkpoint'
61
  checkpoint_path = huggingface_hub.hf_hub_download(MODEL_REPO,
62
- 'checkpoints.zip',
63
- force_filename='checkpoints.zip')
64
  print(checkpoint_path)
65
  shutil.unpack_archive(checkpoint_path, extract_dir=dir)
66
 
67
- print(os.listdir(dir+'/checkpoints'))
 
 
68
 
69
- return dir+'/checkpoints'
70
 
71
  # save image to the disk
72
  def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
@@ -76,7 +77,7 @@ def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
76
  imgs = []
77
 
78
  for label, im_data in visuals.items():
79
- im = util.tensor2im(im_data)#tensor to numpy array [-1,1]->[0,1]->[0,255]
80
  image_name = '%s_%s.png' % (name, label)
81
  save_path = os.path.join(image_dir, image_name)
82
  h, w, _ = im.shape
@@ -91,6 +92,8 @@ def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
91
 
92
 
93
  SAFEHASH = [x for x in "0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ"]
 
 
94
  def compress_UUID():
95
  '''
96
  根据http://www.ietf.org/rfc/rfc1738.txt,由uuid编码扩bai大字符域生成du串
@@ -108,13 +111,29 @@ def compress_UUID():
108
  return safe_code
109
 
110
 
111
- def run(
112
- image,
113
- model,
114
- opt,
115
- ) -> tuple[PIL.Image.Image]:
116
 
117
- dataroot = 'images/'+compress_UUID()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  opt.dataroot = os.path.join(dataroot, 'src/')
119
  os.makedirs(opt.dataroot, exist_ok=True)
120
  opt.results_dir = os.path.join(dataroot, 'results/')
@@ -127,25 +146,40 @@ def run(
127
 
128
  shutil.copy(image.name, opt.dataroot)
129
 
130
- data_loader = CreateDataLoader(opt)
131
- dataset = data_loader.load_data()
132
 
133
- imgs = [image.name]
134
- # test
135
- # model.eval()
136
- for i, data in enumerate(dataset):
137
- if i >= opt.how_many: # test code only supports batch_size = 1, how_many means how many test images to run
138
- break
139
- model.set_input(data)
140
- model.test()
141
- visuals = model.get_current_visuals() # in test the loadSize is set to the same as fineSize
142
- img_path = model.get_image_paths()
143
- # if i % 5 == 0:
144
- # print('processing (%04d)-th image... %s' % (i, img_path))
145
- imgs = save_images2(opt.results_dir, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
146
 
147
- print(imgs)
148
- return PIL.Image.open(imgs[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
 
151
  def main():
@@ -178,14 +212,18 @@ def main():
178
 
179
  opt.checkpoints_dir = checkpoint_dir
180
 
181
-
182
  model = create_model(opt)
183
  model.setup(opt)
184
 
185
- func = functools.partial(run, model=model, opt=opt)
 
 
 
 
 
 
186
  func = functools.update_wrapper(func, run)
187
 
188
-
189
  gr.Interface(
190
  func,
191
  [
@@ -196,7 +234,7 @@ def main():
196
  type='pil',
197
  label='Result'),
198
  ],
199
- #examples=examples,
200
  theme=args.theme,
201
  title=TITLE,
202
  description=DESCRIPTION,
 
22
  from options.test_options import TestOptions
23
  from data import CreateDataLoader
24
  from models import create_model
25
+ import dlib
26
+ import preprocess.get_partmask
27
  from util import html
28
 
29
  import ntpath
30
  from util import util
31
 
 
32
  ORIGINAL_REPO_URL = 'https://github.com/yiranran/APDrawingGAN2'
33
  TITLE = 'yiranran/APDrawingGAN2'
34
  DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.
 
38
 
39
  """
40
 
 
41
  MODEL_REPO = 'hylee/apdrawing_model'
42
 
43
+
44
  def parse_args() -> argparse.Namespace:
45
  parser = argparse.ArgumentParser()
46
  parser.add_argument('--device', type=str, default='cpu')
 
59
  def load_checkpoint():
60
  dir = 'checkpoint'
61
  checkpoint_path = huggingface_hub.hf_hub_download(MODEL_REPO,
62
+ 'checkpoints.zip',
63
+ force_filename='checkpoints.zip')
64
  print(checkpoint_path)
65
  shutil.unpack_archive(checkpoint_path, extract_dir=dir)
66
 
67
+ print(os.listdir(dir + '/checkpoints'))
68
+
69
+ return dir + '/checkpoints'
70
 
 
71
 
72
  # save image to the disk
73
  def save_images2(image_dir, visuals, image_path, aspect_ratio=1.0, width=256):
 
77
  imgs = []
78
 
79
  for label, im_data in visuals.items():
80
+ im = util.tensor2im(im_data) # tensor to numpy array [-1,1]->[0,1]->[0,255]
81
  image_name = '%s_%s.png' % (name, label)
82
  save_path = os.path.join(image_dir, image_name)
83
  h, w, _ = im.shape
 
92
 
93
 
94
  SAFEHASH = [x for x in "0123456789-abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ"]
95
+
96
+
97
  def compress_UUID():
98
  '''
99
  根据http://www.ietf.org/rfc/rfc1738.txt,由uuid编码扩bai大字符域生成du串
 
111
  return safe_code
112
 
113
 
 
 
 
 
 
114
 
115
+ def get_68lm(imgfile, savepath, detector, predictor):
116
+ image = cv2.imread(imgfile)
117
+ rgbImg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
118
+ rects = detector(rgbImg, 1)
119
+ for (i, rect) in enumerate(rects):
120
+ landmarks = predictor(rgbImg, rect)
121
+ landmarks = shape_to_np(landmarks)
122
+ f = open(savepath, 'w')
123
+ for i in range(len(landmarks)):
124
+ lm = landmarks[i]
125
+ print(lm[0], lm[1], file=f)
126
+ f.close()
127
+
128
+
129
+ def run(
130
+ image,
131
+ model,
132
+ opt,
133
+ detector,
134
+ predictor,
135
+ ) -> tuple[PIL.Image.Image,PIL.Image.Image,PIL.Image.Image,PIL.Image.Image]:
136
+ dataroot = 'images/' + compress_UUID()
137
  opt.dataroot = os.path.join(dataroot, 'src/')
138
  os.makedirs(opt.dataroot, exist_ok=True)
139
  opt.results_dir = os.path.join(dataroot, 'results/')
 
146
 
147
  shutil.copy(image.name, opt.dataroot)
148
 
149
+ fullname = os.path.basename(image.name)
150
+ name = fullname.split(".")[0]
151
 
152
+ imgfile = os.path.join(opt.dataroot, fullname)
153
+ lmfile = os.path.join(opt.lm_dir, name+'.txt')
154
+ # 预处理数据
155
+ get_68lm(imgfile, lmfile, detector, predictor)
 
 
 
 
 
 
 
 
 
156
 
157
+ imgs = []
158
+ for part in ['eyel', 'eyer', 'nose', 'mouth']:
159
+ savepath = os.path.join(opt.bg_dir + part, name+'.png')
160
+ get_partmask.get_partmask(imgfile, part, lmfile, savepath)
161
+ imgs.append(savepath)
162
+
163
+ # data_loader = CreateDataLoader(opt)
164
+ # dataset = data_loader.load_data()
165
+ #
166
+ # imgs = [image.name]
167
+ # # test
168
+ # # model.eval()
169
+ # for i, data in enumerate(dataset):
170
+ # if i >= opt.how_many: # test code only supports batch_size = 1, how_many means how many test images to run
171
+ # break
172
+ # model.set_input(data)
173
+ # model.test()
174
+ # visuals = model.get_current_visuals() # in test the loadSize is set to the same as fineSize
175
+ # img_path = model.get_image_paths()
176
+ # # if i % 5 == 0:
177
+ # # print('processing (%04d)-th image... %s' % (i, img_path))
178
+ # imgs = save_images2(opt.results_dir, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
179
+ #
180
+ # print(imgs)
181
+
182
+ return PIL.Image.open(imgs[0]),PIL.Image.open(imgs[1]),PIL.Image.open(imgs[2]),PIL.Image.open(imgs[3])
183
 
184
 
185
  def main():
 
212
 
213
  opt.checkpoints_dir = checkpoint_dir
214
 
 
215
  model = create_model(opt)
216
  model.setup(opt)
217
 
218
+ '''
219
+ 预处理数据
220
+ '''
221
+ detector = dlib.get_frontal_face_detector()
222
+ predictor = dlib.shape_predictor(checkpoint_dir + '/shape_predictor_68_face_landmarks.dat')
223
+
224
+ func = functools.partial(run, model=model, opt=opt, detector=detector, predictor=predictor)
225
  func = functools.update_wrapper(func, run)
226
 
 
227
  gr.Interface(
228
  func,
229
  [
 
234
  type='pil',
235
  label='Result'),
236
  ],
237
+ # examples=examples,
238
  theme=args.theme,
239
  title=TITLE,
240
  description=DESCRIPTION,
requirements.txt CHANGED
@@ -1,8 +1,10 @@
1
- torch>=0.4.0
2
- torchvision>=0.2.1
3
- dominate>=2.3.1
4
- visdom>=0.1.8.3
5
- scipy>=1.1.0
6
- numpy>=1.14.1
7
- Pillow>=5.0.0
8
- opencv-python>=3.4.2
 
 
 
1
+ torch==1.1.0
2
+ torchvision==0.4.0
3
+ dominate==2.4.0
4
+ visdom==0.1.8.9
5
+ scipy==1.1.0
6
+ numpy==1.16.4
7
+ Pillow==4.3.0
8
+ opencv-python==4.1.0.25
9
+ dlib==19.18.0
10
+ shapely==1.7.0