kevinwang676 committed
Commit 69ab86f · verified · 1 Parent(s): fda7f35

Delete src

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. src/audio2exp_models/audio2exp.py +0 -41
  2. src/audio2exp_models/networks.py +0 -74
  3. src/audio2pose_models/audio2pose.py +0 -94
  4. src/audio2pose_models/audio_encoder.py +0 -64
  5. src/audio2pose_models/cvae.py +0 -149
  6. src/audio2pose_models/discriminator.py +0 -76
  7. src/audio2pose_models/networks.py +0 -140
  8. src/audio2pose_models/res_unet.py +0 -65
  9. src/config/auido2exp.yaml +0 -58
  10. src/config/auido2pose.yaml +0 -49
  11. src/config/facerender.yaml +0 -45
  12. src/config/facerender_still.yaml +0 -45
  13. src/config/similarity_Lm3D_all.mat +0 -0
  14. src/face3d/data/__init__.py +0 -116
  15. src/face3d/data/base_dataset.py +0 -125
  16. src/face3d/data/flist_dataset.py +0 -125
  17. src/face3d/data/image_folder.py +0 -66
  18. src/face3d/data/template_dataset.py +0 -75
  19. src/face3d/extract_kp_videos.py +0 -108
  20. src/face3d/extract_kp_videos_safe.py +0 -151
  21. src/face3d/models/__init__.py +0 -67
  22. src/face3d/models/arcface_torch/README.md +0 -164
  23. src/face3d/models/arcface_torch/backbones/__init__.py +0 -25
  24. src/face3d/models/arcface_torch/backbones/iresnet.py +0 -187
  25. src/face3d/models/arcface_torch/backbones/iresnet2060.py +0 -176
  26. src/face3d/models/arcface_torch/backbones/mobilefacenet.py +0 -130
  27. src/face3d/models/arcface_torch/configs/3millions.py +0 -23
  28. src/face3d/models/arcface_torch/configs/3millions_pfc.py +0 -23
  29. src/face3d/models/arcface_torch/configs/__init__.py +0 -0
  30. src/face3d/models/arcface_torch/configs/base.py +0 -56
  31. src/face3d/models/arcface_torch/configs/glint360k_mbf.py +0 -26
  32. src/face3d/models/arcface_torch/configs/glint360k_r100.py +0 -26
  33. src/face3d/models/arcface_torch/configs/glint360k_r18.py +0 -26
  34. src/face3d/models/arcface_torch/configs/glint360k_r34.py +0 -26
  35. src/face3d/models/arcface_torch/configs/glint360k_r50.py +0 -26
  36. src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py +0 -26
  37. src/face3d/models/arcface_torch/configs/ms1mv3_r18.py +0 -26
  38. src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py +0 -26
  39. src/face3d/models/arcface_torch/configs/ms1mv3_r34.py +0 -26
  40. src/face3d/models/arcface_torch/configs/ms1mv3_r50.py +0 -26
  41. src/face3d/models/arcface_torch/configs/speed.py +0 -23
  42. src/face3d/models/arcface_torch/dataset.py +0 -124
  43. src/face3d/models/arcface_torch/docs/eval.md +0 -31
  44. src/face3d/models/arcface_torch/docs/install.md +0 -51
  45. src/face3d/models/arcface_torch/docs/modelzoo.md +0 -0
  46. src/face3d/models/arcface_torch/docs/speed_benchmark.md +0 -93
  47. src/face3d/models/arcface_torch/eval/__init__.py +0 -0
  48. src/face3d/models/arcface_torch/eval/verification.py +0 -407
  49. src/face3d/models/arcface_torch/eval_ijbc.py +0 -483
  50. src/face3d/models/arcface_torch/inference.py +0 -35
src/audio2exp_models/audio2exp.py DELETED
@@ -1,41 +0,0 @@
- from tqdm import tqdm
- import torch
- from torch import nn
-
-
- class Audio2Exp(nn.Module):
-     def __init__(self, netG, cfg, device, prepare_training_loss=False):
-         super(Audio2Exp, self).__init__()
-         self.cfg = cfg
-         self.device = device
-         self.netG = netG.to(device)
-
-     def test(self, batch):
-
-         mel_input = batch['indiv_mels']                  # bs T 1 80 16
-         bs = mel_input.shape[0]
-         T = mel_input.shape[1]
-
-         exp_coeff_pred = []
-
-         for i in tqdm(range(0, T, 10), 'audio2exp:'):    # every 10 frames
-
-             current_mel_input = mel_input[:, i:i+10]
-
-             #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1))  #bs T 64
-             ref = batch['ref'][:, :, :64][:, i:i+10]
-             ratio = batch['ratio_gt'][:, i:i+10]         #bs T
-
-             audiox = current_mel_input.view(-1, 1, 80, 16)        # bs*T 1 80 16
-
-             curr_exp_coeff_pred = self.netG(audiox, ref, ratio)   # bs T 64
-
-             exp_coeff_pred += [curr_exp_coeff_pred]
-
-         # BS x T x 64
-         results_dict = {
-             'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
-         }
-         return results_dict
-
-
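The deleted `test` method slides over the mel spectrogram in windows of 10 frames and concatenates the per-window predictions along the time axis. A minimal sketch of driving it with dummy tensors matching the shapes in the comments; `DummyNetG` is a hypothetical stand-in for the real generator (`SimpleWrapperV2` from the next file) and is purely illustrative:

    import torch

    # Hypothetical generator: returns zero expression coeffs with the expected shape.
    class DummyNetG(torch.nn.Module):
        def forward(self, audiox, ref, ratio):
            return torch.zeros(ref.shape[0], ref.shape[1], 64)

    model = Audio2Exp(DummyNetG(), cfg=None, device='cpu')
    batch = {
        'indiv_mels': torch.randn(1, 200, 1, 80, 16),  # bs T 1 80 16
        'ref':        torch.randn(1, 200, 70),         # per-frame reference coeffs
        'ratio_gt':   torch.rand(1, 200),              # eye-blink ratio per frame
    }
    out = model.test(batch)
    print(out['exp_coeff_pred'].shape)  # torch.Size([1, 200, 64])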
src/audio2exp_models/networks.py DELETED
@@ -1,74 +0,0 @@
- import torch
- import torch.nn.functional as F
- from torch import nn
-
- class Conv2d(nn.Module):
-     def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.conv_block = nn.Sequential(
-             nn.Conv2d(cin, cout, kernel_size, stride, padding),
-             nn.BatchNorm2d(cout)
-         )
-         self.act = nn.ReLU()
-         self.residual = residual
-         self.use_act = use_act
-
-     def forward(self, x):
-         out = self.conv_block(x)
-         if self.residual:
-             out += x
-
-         if self.use_act:
-             return self.act(out)
-         else:
-             return out
-
- class SimpleWrapperV2(nn.Module):
-     def __init__(self) -> None:
-         super().__init__()
-         self.audio_encoder = nn.Sequential(
-             Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
-             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-
-             Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
-             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-
-             Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
-             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-
-             Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
-             Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-
-             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
-             Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
-         )
-
-         #### load the pre-trained audio_encoder
-         #self.audio_encoder = self.audio_encoder.to(device)
-         '''
-         wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
-         state_dict = self.audio_encoder.state_dict()
-
-         for k,v in wav2lip_state_dict.items():
-             if 'audio_encoder' in k:
-                 print('init:', k)
-                 state_dict[k.replace('module.audio_encoder.', '')] = v
-         self.audio_encoder.load_state_dict(state_dict)
-         '''
-
-         self.mapping1 = nn.Linear(512+64+1, 64)
-         #self.mapping2 = nn.Linear(30, 64)
-         #nn.init.constant_(self.mapping1.weight, 0.)
-         nn.init.constant_(self.mapping1.bias, 0.)
-
-     def forward(self, x, ref, ratio):
-         x = self.audio_encoder(x).view(x.size(0), -1)
-         ref_reshape = ref.reshape(x.size(0), -1)
-         ratio = ratio.reshape(x.size(0), -1)
-
-         y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
-         out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
-         return out
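The encoder stack reduces each 1x80x16 mel window to a 512-dim vector (the strided convolutions take 80x16 down to 1x1), which is concatenated with 64 reference coefficients and one blink-ratio scalar before `Linear(577, 64)`. A shape-check sketch mirroring how `Audio2Exp.test` calls it, with dummy tensors:

    import torch

    net = SimpleWrapperV2()
    bs, T = 2, 10
    audiox = torch.randn(bs * T, 1, 80, 16)   # one mel window per frame, time folded into batch
    ref    = torch.randn(bs, T, 64)           # reference expression coeffs
    ratio  = torch.rand(bs, T)                # blink ratio per frame
    out = net(audiox, ref, ratio)
    print(out.shape)   # torch.Size([2, 10, 64]): 512 audio + 64 ref + 1 ratio -> 64 coeffs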
src/audio2pose_models/audio2pose.py DELETED
@@ -1,94 +0,0 @@
- import torch
- from torch import nn
- from src.audio2pose_models.cvae import CVAE
- from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
- from src.audio2pose_models.audio_encoder import AudioEncoder
-
- class Audio2Pose(nn.Module):
-     def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
-         super().__init__()
-         self.cfg = cfg
-         self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
-         self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
-         self.device = device
-
-         self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
-         self.audio_encoder.eval()
-         for param in self.audio_encoder.parameters():
-             param.requires_grad = False
-
-         self.netG = CVAE(cfg)
-         self.netD_motion = PoseSequenceDiscriminator(cfg)
-
-
-     def forward(self, x):
-
-         batch = {}
-         coeff_gt = x['gt'].cuda().squeeze(0)            #bs frame_len+1 73
-         batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70]  #bs frame_len 6
-         batch['ref'] = coeff_gt[:, 0, 64:70]            #bs 6
-         batch['class'] = x['class'].squeeze(0).cuda()   # bs
-         indiv_mels = x['indiv_mels'].cuda().squeeze(0)  # bs seq_len+1 80 16
-
-         # forward
-         audio_emb_list = []
-         audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2))  #bs seq_len 512
-         batch['audio_emb'] = audio_emb
-         batch = self.netG(batch)
-
-         pose_motion_pred = batch['pose_motion_pred']    # bs frame_len 6
-         pose_gt = coeff_gt[:, 1:, 64:70].clone()        # bs frame_len 6
-         pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred  # bs frame_len 6
-
-         batch['pose_pred'] = pose_pred
-         batch['pose_gt'] = pose_gt
-
-         return batch
-
-     def test(self, x):
-
-         batch = {}
-         ref = x['ref']                                  #bs 1 70
-         batch['ref'] = x['ref'][:, 0, -6:]
-         batch['class'] = x['class']
-         bs = ref.shape[0]
-
-         indiv_mels = x['indiv_mels']                    # bs T 1 80 16
-         indiv_mels_use = indiv_mels[:, 1:]              # we regard the ref as the first frame
-         num_frames = x['num_frames']
-         num_frames = int(num_frames) - 1
-
-         #
-         div = num_frames // self.seq_len
-         re = num_frames % self.seq_len
-         audio_emb_list = []
-         pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
-                                              device=batch['ref'].device)]
-
-         for i in range(div):
-             z = torch.randn(bs, self.latent_dim).to(ref.device)
-             batch['z'] = z
-             audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len, :, :, :])  #bs seq_len 512
-             batch['audio_emb'] = audio_emb
-             batch = self.netG.test(batch)
-             pose_motion_pred_list.append(batch['pose_motion_pred'])  #list of bs seq_len 6
-
-         if re != 0:
-             z = torch.randn(bs, self.latent_dim).to(ref.device)
-             batch['z'] = z
-             audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:, :, :, :])  #bs seq_len 512
-             if audio_emb.shape[1] != self.seq_len:
-                 pad_dim = self.seq_len - audio_emb.shape[1]
-                 pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
-                 audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
-             batch['audio_emb'] = audio_emb
-             batch = self.netG.test(batch)
-             pose_motion_pred_list.append(batch['pose_motion_pred'][:, -1*re:, :])
-
-         pose_motion_pred = torch.cat(pose_motion_pred_list, dim=1)
-         batch['pose_motion_pred'] = pose_motion_pred
-
-         pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6
-
-         batch['pose_pred'] = pose_pred
-         return batch
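At inference time the pose sequence is generated in `seq_len`-frame chunks, each with a freshly sampled latent `z`; the last partial chunk is served by re-encoding the final `seq_len` frames and keeping only the trailing `re` predictions, so every frame is covered exactly once. The chunk arithmetic in isolation, assuming plain integer inputs:

    num_frames, seq_len = 100, 32
    div, re = num_frames // seq_len, num_frames % seq_len   # 3 full chunks, 4 left over
    # Full chunks cover frames [0, 96); the tail window spans frames [68, 100),
    # and only its last `re` = 4 predictions are appended.
    assert div * seq_len + re == num_frames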
src/audio2pose_models/audio_encoder.py DELETED
@@ -1,64 +0,0 @@
- import torch
- from torch import nn
- from torch.nn import functional as F
-
- class Conv2d(nn.Module):
-     def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.conv_block = nn.Sequential(
-             nn.Conv2d(cin, cout, kernel_size, stride, padding),
-             nn.BatchNorm2d(cout)
-         )
-         self.act = nn.ReLU()
-         self.residual = residual
-
-     def forward(self, x):
-         out = self.conv_block(x)
-         if self.residual:
-             out += x
-         return self.act(out)
-
- class AudioEncoder(nn.Module):
-     def __init__(self, wav2lip_checkpoint, device):
-         super(AudioEncoder, self).__init__()
-
-         self.audio_encoder = nn.Sequential(
-             Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
-             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-             Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-
-             Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
-             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-             Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-
-             Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
-             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-             Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-
-             Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
-             Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-
-             Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
-             Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
-
-         #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
-         # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
-         # state_dict = self.audio_encoder.state_dict()
-
-         # for k,v in wav2lip_state_dict.items():
-         #     if 'audio_encoder' in k:
-         #         state_dict[k.replace('module.audio_encoder.', '')] = v
-         # self.audio_encoder.load_state_dict(state_dict)
-
-
-     def forward(self, audio_sequences):
-         # audio_sequences = (B, T, 1, 80, 16)
-         B = audio_sequences.size(0)
-
-         audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
-
-         audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
-         dim = audio_embedding.shape[1]
-         audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
-
-         return audio_embedding.squeeze(-1).squeeze(-1)  #B seq_len+1 512
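`forward` folds the time axis into the batch so every 80x16 window passes through the 2-D encoder at once, then unfolds back to `(B, T, 512)`. Note the concatenation over `i` is time-major, so the trailing row-major `reshape((B, -1, dim, 1, 1))` recovers the original frame order exactly when `B == 1`; a shape sketch under that assumption (the checkpoint argument is unused here since the loading code above is commented out):

    import torch

    enc = AudioEncoder(wav2lip_checkpoint=None, device='cpu')
    mels = torch.randn(1, 33, 1, 80, 16)   # (B, T, 1, 80, 16), e.g. seq_len+1 windows
    emb = enc(mels)
    print(emb.shape)                       # torch.Size([1, 33, 512])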
src/audio2pose_models/cvae.py DELETED
@@ -1,149 +0,0 @@
- import torch
- import torch.nn.functional as F
- from torch import nn
- from src.audio2pose_models.res_unet import ResUnet
-
- def class2onehot(idx, class_num):
-
-     assert torch.max(idx).item() < class_num
-     onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
-     onehot.scatter_(1, idx, 1)
-     return onehot
-
- class CVAE(nn.Module):
-     def __init__(self, cfg):
-         super().__init__()
-         encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
-         decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
-         latent_size = cfg.MODEL.CVAE.LATENT_SIZE
-         num_classes = cfg.DATASET.NUM_CLASSES
-         audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
-         audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
-         seq_len = cfg.MODEL.CVAE.SEQ_LEN
-
-         self.latent_size = latent_size
-
-         self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
-                                audio_emb_in_size, audio_emb_out_size, seq_len)
-         self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
-                                audio_emb_in_size, audio_emb_out_size, seq_len)
-     def reparameterize(self, mu, logvar):
-         std = torch.exp(0.5 * logvar)
-         eps = torch.randn_like(std)
-         return mu + eps * std
-
-     def forward(self, batch):
-         batch = self.encoder(batch)
-         mu = batch['mu']
-         logvar = batch['logvar']
-         z = self.reparameterize(mu, logvar)
-         batch['z'] = z
-         return self.decoder(batch)
-
-     def test(self, batch):
-         '''
-         class_id = batch['class']
-         z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
-         batch['z'] = z
-         '''
-         return self.decoder(batch)
-
- class ENCODER(nn.Module):
-     def __init__(self, layer_sizes, latent_size, num_classes,
-                  audio_emb_in_size, audio_emb_out_size, seq_len):
-         super().__init__()
-
-         self.resunet = ResUnet()
-         self.num_classes = num_classes
-         self.seq_len = seq_len
-
-         self.MLP = nn.Sequential()
-         layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
-         for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
-             self.MLP.add_module(
-                 name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
-             self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
-
-         self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
-         self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
-         self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
-
-         self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
-
-     def forward(self, batch):
-         class_id = batch['class']
-         pose_motion_gt = batch['pose_motion_gt']    #bs seq_len 6
-         ref = batch['ref']                          #bs 6
-         bs = pose_motion_gt.shape[0]
-         audio_in = batch['audio_emb']               # bs seq_len audio_emb_in_size
-
-         #pose encode
-         pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))  #bs 1 seq_len 6
-         pose_emb = pose_emb.reshape(bs, -1)                   #bs seq_len*6
-
-         #audio mapping
-         print(audio_in.shape)
-         audio_out = self.linear_audio(audio_in)     # bs seq_len audio_emb_out_size
-         audio_out = audio_out.reshape(bs, -1)
-
-         class_bias = self.classbias[class_id]       #bs latent_size
-         x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1)  #bs seq_len*(audio_emb_out_size+6)+latent_size
-         x_out = self.MLP(x_in)
-
-         mu = self.linear_means(x_out)
-         logvar = self.linear_means(x_out)           #bs latent_size
-
-         batch.update({'mu':mu, 'logvar':logvar})
-         return batch
-
- class DECODER(nn.Module):
-     def __init__(self, layer_sizes, latent_size, num_classes,
-                  audio_emb_in_size, audio_emb_out_size, seq_len):
-         super().__init__()
-
-         self.resunet = ResUnet()
-         self.num_classes = num_classes
-         self.seq_len = seq_len
-
-         self.MLP = nn.Sequential()
-         input_size = latent_size + seq_len*audio_emb_out_size + 6
-         for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
-             self.MLP.add_module(
-                 name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
-             if i+1 < len(layer_sizes):
-                 self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
-             else:
-                 self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
-
-         self.pose_linear = nn.Linear(6, 6)
-         self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
-
-         self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
-
-     def forward(self, batch):
-
-         z = batch['z']                              #bs latent_size
-         bs = z.shape[0]
-         class_id = batch['class']
-         ref = batch['ref']                          #bs 6
-         audio_in = batch['audio_emb']               # bs seq_len audio_emb_in_size
-         #print('audio_in: ', audio_in[:, :, :10])
-
-         audio_out = self.linear_audio(audio_in)     # bs seq_len audio_emb_out_size
-         #print('audio_out: ', audio_out[:, :, :10])
-         audio_out = audio_out.reshape([bs, -1])     # bs seq_len*audio_emb_out_size
-         class_bias = self.classbias[class_id]       #bs latent_size
-
-         z = z + class_bias
-         x_in = torch.cat([ref, z, audio_out], dim=-1)
-         x_out = self.MLP(x_in)                      # bs layer_sizes[-1]
-         x_out = x_out.reshape((bs, self.seq_len, -1))
-
-         #print('x_out: ', x_out)
-
-         pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
-
-         pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
-
-         batch.update({'pose_motion_pred':pose_motion_pred})
-         return batch
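The encoder predicts `mu` and `logvar`, and `reparameterize` draws `z = mu + eps * sigma` so the sampling step stays differentiable. Two things in the deleted code worth noting: `ENCODER.forward` computes both `mu` and `logvar` from `self.linear_means` (the `linear_logvar` head defined in `__init__` is never used), and a debug `print(audio_in.shape)` was left in the forward pass. The reparameterization trick itself, as a standalone worked sketch:

    import torch

    # z = mu + eps * sigma: noise is sampled outside the graph, so gradients
    # flow back into mu and logvar.
    mu = torch.zeros(4, 64, requires_grad=True)
    logvar = torch.zeros(4, 64, requires_grad=True)
    std = torch.exp(0.5 * logvar)      # sigma = exp(logvar / 2)
    eps = torch.randn_like(std)
    z = mu + eps * std
    z.sum().backward()
    print(mu.grad.shape, logvar.grad.shape)   # both torch.Size([4, 64])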
src/audio2pose_models/discriminator.py DELETED
@@ -1,76 +0,0 @@
- import torch
- import torch.nn.functional as F
- from torch import nn
-
- class ConvNormRelu(nn.Module):
-     def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
-                  kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
-         super().__init__()
-         if kernel_size is None:
-             if downsample:
-                 kernel_size, stride, padding = 4, 2, 1
-             else:
-                 kernel_size, stride, padding = 3, 1, 1
-
-         if conv_type == '2d':
-             self.conv = nn.Conv2d(
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride,
-                 padding,
-                 bias=False,
-             )
-             if norm == 'BN':
-                 self.norm = nn.BatchNorm2d(out_channels)
-             elif norm == 'IN':
-                 self.norm = nn.InstanceNorm2d(out_channels)
-             else:
-                 raise NotImplementedError
-         elif conv_type == '1d':
-             self.conv = nn.Conv1d(
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride,
-                 padding,
-                 bias=False,
-             )
-             if norm == 'BN':
-                 self.norm = nn.BatchNorm1d(out_channels)
-             elif norm == 'IN':
-                 self.norm = nn.InstanceNorm1d(out_channels)
-             else:
-                 raise NotImplementedError
-         nn.init.kaiming_normal_(self.conv.weight)
-
-         self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
-
-     def forward(self, x):
-         x = self.conv(x)
-         if isinstance(self.norm, nn.InstanceNorm1d):
-             x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
-         else:
-             x = self.norm(x)
-         x = self.act(x)
-         return x
-
-
- class PoseSequenceDiscriminator(nn.Module):
-     def __init__(self, cfg):
-         super().__init__()
-         self.cfg = cfg
-         leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
-
-         self.seq = nn.Sequential(
-             ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64
-             ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),  # B, 512, 32
-             ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),  # B, 1024, 16
-             nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)  # B, 1, 16
-         )
-
-     def forward(self, x):
-         x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
-         x = self.seq(x)
-         x = x.squeeze(1)
-         return x
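The discriminator treats a `(B, T, 6)` pose sequence as a channels-first 1-D signal; each downsampling block halves the temporal length, so the output is one realism logit per remaining timestep. A sketch with a hypothetical minimal `cfg` carrying only the two fields the class actually reads:

    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(MODEL=SimpleNamespace(
        DISCRIMINATOR=SimpleNamespace(LEAKY_RELU=True, INPUT_CHANNELS=6)))

    disc = PoseSequenceDiscriminator(cfg)
    x = torch.randn(2, 64, 6)   # (B, T, 6): a batch of pose-motion sequences
    print(disc(x).shape)        # torch.Size([2, 16]): one logit per 4-frame window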
src/audio2pose_models/networks.py DELETED
@@ -1,140 +0,0 @@
- import torch.nn as nn
- import torch
-
-
- class ResidualConv(nn.Module):
-     def __init__(self, input_dim, output_dim, stride, padding):
-         super(ResidualConv, self).__init__()
-
-         self.conv_block = nn.Sequential(
-             nn.BatchNorm2d(input_dim),
-             nn.ReLU(),
-             nn.Conv2d(
-                 input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
-             ),
-             nn.BatchNorm2d(output_dim),
-             nn.ReLU(),
-             nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
-         )
-         self.conv_skip = nn.Sequential(
-             nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
-             nn.BatchNorm2d(output_dim),
-         )
-
-     def forward(self, x):
-
-         return self.conv_block(x) + self.conv_skip(x)
-
-
- class Upsample(nn.Module):
-     def __init__(self, input_dim, output_dim, kernel, stride):
-         super(Upsample, self).__init__()
-
-         self.upsample = nn.ConvTranspose2d(
-             input_dim, output_dim, kernel_size=kernel, stride=stride
-         )
-
-     def forward(self, x):
-         return self.upsample(x)
-
-
- class Squeeze_Excite_Block(nn.Module):
-     def __init__(self, channel, reduction=16):
-         super(Squeeze_Excite_Block, self).__init__()
-         self.avg_pool = nn.AdaptiveAvgPool2d(1)
-         self.fc = nn.Sequential(
-             nn.Linear(channel, channel // reduction, bias=False),
-             nn.ReLU(inplace=True),
-             nn.Linear(channel // reduction, channel, bias=False),
-             nn.Sigmoid(),
-         )
-
-     def forward(self, x):
-         b, c, _, _ = x.size()
-         y = self.avg_pool(x).view(b, c)
-         y = self.fc(y).view(b, c, 1, 1)
-         return x * y.expand_as(x)
-
-
- class ASPP(nn.Module):
-     def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
-         super(ASPP, self).__init__()
-
-         self.aspp_block1 = nn.Sequential(
-             nn.Conv2d(
-                 in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
-             ),
-             nn.ReLU(inplace=True),
-             nn.BatchNorm2d(out_dims),
-         )
-         self.aspp_block2 = nn.Sequential(
-             nn.Conv2d(
-                 in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
-             ),
-             nn.ReLU(inplace=True),
-             nn.BatchNorm2d(out_dims),
-         )
-         self.aspp_block3 = nn.Sequential(
-             nn.Conv2d(
-                 in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
-             ),
-             nn.ReLU(inplace=True),
-             nn.BatchNorm2d(out_dims),
-         )
-
-         self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
-         self._init_weights()
-
-     def forward(self, x):
-         x1 = self.aspp_block1(x)
-         x2 = self.aspp_block2(x)
-         x3 = self.aspp_block3(x)
-         out = torch.cat([x1, x2, x3], dim=1)
-         return self.output(out)
-
-     def _init_weights(self):
-         for m in self.modules():
-             if isinstance(m, nn.Conv2d):
-                 nn.init.kaiming_normal_(m.weight)
-             elif isinstance(m, nn.BatchNorm2d):
-                 m.weight.data.fill_(1)
-                 m.bias.data.zero_()
-
-
- class Upsample_(nn.Module):
-     def __init__(self, scale=2):
-         super(Upsample_, self).__init__()
-
-         self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
-
-     def forward(self, x):
-         return self.upsample(x)
-
-
- class AttentionBlock(nn.Module):
-     def __init__(self, input_encoder, input_decoder, output_dim):
-         super(AttentionBlock, self).__init__()
-
-         self.conv_encoder = nn.Sequential(
-             nn.BatchNorm2d(input_encoder),
-             nn.ReLU(),
-             nn.Conv2d(input_encoder, output_dim, 3, padding=1),
-             nn.MaxPool2d(2, 2),
-         )
-
-         self.conv_decoder = nn.Sequential(
-             nn.BatchNorm2d(input_decoder),
-             nn.ReLU(),
-             nn.Conv2d(input_decoder, output_dim, 3, padding=1),
-         )
-
-         self.conv_attn = nn.Sequential(
-             nn.BatchNorm2d(output_dim),
-             nn.ReLU(),
-             nn.Conv2d(output_dim, 1, 1),
-         )
-
-     def forward(self, x1, x2):
-         out = self.conv_encoder(x1) + self.conv_decoder(x2)
-         out = self.conv_attn(out)
-         return out * x2
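In `ASPP`, each 3x3 branch sets `padding` equal to `dilation`, so with kernel size 3 the spatial size is preserved (out = in + 2r - 2r) while the branches see 6-, 12-, and 18-pixel context; the concatenated branches are fused by a 1x1 convolution back to `out_dims`. A minimal shape check:

    import torch

    aspp = ASPP(in_dims=64, out_dims=32)   # default rates [6, 12, 18]
    x = torch.randn(2, 64, 40, 40)
    print(aspp(x).shape)                   # torch.Size([2, 32, 40, 40]): size preserved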
src/audio2pose_models/res_unet.py DELETED
@@ -1,65 +0,0 @@
- import torch
- import torch.nn as nn
- from src.audio2pose_models.networks import ResidualConv, Upsample
-
-
- class ResUnet(nn.Module):
-     def __init__(self, channel=1, filters=[32, 64, 128, 256]):
-         super(ResUnet, self).__init__()
-
-         self.input_layer = nn.Sequential(
-             nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
-             nn.BatchNorm2d(filters[0]),
-             nn.ReLU(),
-             nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
-         )
-         self.input_skip = nn.Sequential(
-             nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
-         )
-
-         self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
-         self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
-
-         self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
-
-         self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
-         self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
-
-         self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
-         self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
-
-         self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
-         self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
-
-         self.output_layer = nn.Sequential(
-             nn.Conv2d(filters[0], 1, 1, 1),
-             nn.Sigmoid(),
-         )
-
-     def forward(self, x):
-         # Encode
-         x1 = self.input_layer(x) + self.input_skip(x)
-         x2 = self.residual_conv_1(x1)
-         x3 = self.residual_conv_2(x2)
-         # Bridge
-         x4 = self.bridge(x3)
-
-         # Decode
-         x4 = self.upsample_1(x4)
-         x5 = torch.cat([x4, x3], dim=1)
-
-         x6 = self.up_residual_conv1(x5)
-
-         x6 = self.upsample_2(x6)
-         x7 = torch.cat([x6, x2], dim=1)
-
-         x8 = self.up_residual_conv2(x7)
-
-         x8 = self.upsample_3(x8)
-         x9 = torch.cat([x8, x1], dim=1)
-
-         x10 = self.up_residual_conv3(x9)
-
-         output = self.output_layer(x10)
-
-         return output
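Because every strided block uses `stride=(2, 1)`, the encoder halves only the sequence axis and leaves the 6 pose channels untouched: a `(bs, 1, seq_len, 6)` pose-motion map goes 32 → 16 → 8 → 4 along time and the transposed convolutions restore it, so `seq_len` should be divisible by 8. A shape check with dummy input:

    import torch

    net = ResUnet(channel=1)
    x = torch.randn(2, 1, 32, 6)   # (bs, 1, seq_len, 6) pose-motion map
    print(net(x).shape)            # torch.Size([2, 1, 32, 6]): time halved 3x, then restored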
src/config/auido2exp.yaml DELETED
@@ -1,58 +0,0 @@
- DATASET:
-   TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
-   EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
-   TRAIN_BATCH_SIZE: 32
-   EVAL_BATCH_SIZE: 32
-   EXP: True
-   EXP_DIM: 64
-   FRAME_LEN: 32
-   COEFF_LEN: 73
-   NUM_CLASSES: 46
-   AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
-   COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
-   LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
-   DEBUG: True
-   NUM_REPEATS: 2
-   T: 40
-
-
- MODEL:
-   FRAMEWORK: V2
-   AUDIOENCODER:
-     LEAKY_RELU: True
-     NORM: 'IN'
-   DISCRIMINATOR:
-     LEAKY_RELU: False
-     INPUT_CHANNELS: 6
-   CVAE:
-     AUDIO_EMB_IN_SIZE: 512
-     AUDIO_EMB_OUT_SIZE: 128
-     SEQ_LEN: 32
-     LATENT_SIZE: 256
-     ENCODER_LAYER_SIZES: [192, 1024]
-     DECODER_LAYER_SIZES: [1024, 192]
-
-
- TRAIN:
-   MAX_EPOCH: 300
-   GENERATOR:
-     LR: 2.0e-5
-   DISCRIMINATOR:
-     LR: 1.0e-5
-   LOSS:
-     W_FEAT: 0
-     W_COEFF_EXP: 2
-     W_LM: 1.0e-2
-     W_LM_MOUTH: 0
-     W_REG: 0
-     W_SYNC: 0
-     W_COLOR: 0
-     W_EXPRESSION: 0
-     W_LIPREADING: 0.01
-     W_LIPREADING_VV: 0
-     W_EYE_BLINK: 4
-
- TAG:
-   NAME: small_dataset
-
-
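These configs are consumed through dotted attribute access (e.g. `cfg.MODEL.CVAE.SEQ_LEN` in `Audio2Pose` above); the "auido" spelling in the filenames is the repository's own. A loading sketch assuming a yacs-style `CfgNode`, which the dotted access pattern suggests, though the repo may use a different config library:

    from yacs.config import CfgNode as CN

    with open('src/config/auido2exp.yaml') as f:
        cfg = CN.load_cfg(f)          # turns the YAML into dotted-access nodes
    print(cfg.MODEL.CVAE.SEQ_LEN)     # 32
    print(cfg.TRAIN.GENERATOR.LR)     # 2e-05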
src/config/auido2pose.yaml DELETED
@@ -1,49 +0,0 @@
- DATASET:
-   TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
-   EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
-   TRAIN_BATCH_SIZE: 64
-   EVAL_BATCH_SIZE: 1
-   EXP: True
-   EXP_DIM: 64
-   FRAME_LEN: 32
-   COEFF_LEN: 73
-   NUM_CLASSES: 46
-   AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
-   COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
-   DEBUG: True
-
-
- MODEL:
-   AUDIOENCODER:
-     LEAKY_RELU: True
-     NORM: 'IN'
-   DISCRIMINATOR:
-     LEAKY_RELU: False
-     INPUT_CHANNELS: 6
-   CVAE:
-     AUDIO_EMB_IN_SIZE: 512
-     AUDIO_EMB_OUT_SIZE: 6
-     SEQ_LEN: 32
-     LATENT_SIZE: 64
-     ENCODER_LAYER_SIZES: [192, 128]
-     DECODER_LAYER_SIZES: [128, 192]
-
-
- TRAIN:
-   MAX_EPOCH: 150
-   GENERATOR:
-     LR: 1.0e-4
-   DISCRIMINATOR:
-     LR: 1.0e-4
-   LOSS:
-     LAMBDA_REG: 1
-     LAMBDA_LANDMARKS: 0
-     LAMBDA_VERTICES: 0
-     LAMBDA_GAN_MOTION: 0.7
-     LAMBDA_GAN_COEFF: 0
-     LAMBDA_KL: 1
-
- TAG:
-   NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
-
-
src/config/facerender.yaml DELETED
@@ -1,45 +0,0 @@
- model_params:
-   common_params:
-     num_kp: 15
-     image_channel: 3
-     feature_channel: 32
-     estimate_jacobian: False   # True
-   kp_detector_params:
-     temperature: 0.1
-     block_expansion: 32
-     max_features: 1024
-     scale_factor: 0.25         # 0.25
-     num_blocks: 5
-     reshape_channel: 16384     # 16384 = 1024 * 16
-     reshape_depth: 16
-   he_estimator_params:
-     block_expansion: 64
-     max_features: 2048
-     num_bins: 66
-   generator_params:
-     block_expansion: 64
-     max_features: 512
-     num_down_blocks: 2
-     reshape_channel: 32
-     reshape_depth: 16          # 512 = 32 * 16
-     num_resblocks: 6
-     estimate_occlusion_map: True
-     dense_motion_params:
-       block_expansion: 32
-       max_features: 1024
-       num_blocks: 5
-       reshape_depth: 16
-       compress: 4
-   discriminator_params:
-     scales: [1]
-     block_expansion: 32
-     max_features: 512
-     num_blocks: 4
-     sn: True
-   mapping_params:
-     coeff_nc: 70
-     descriptor_nc: 1024
-     layer: 3
-     num_kp: 15
-     num_bins: 66
-
src/config/facerender_still.yaml DELETED
@@ -1,45 +0,0 @@
- model_params:
-   common_params:
-     num_kp: 15
-     image_channel: 3
-     feature_channel: 32
-     estimate_jacobian: False   # True
-   kp_detector_params:
-     temperature: 0.1
-     block_expansion: 32
-     max_features: 1024
-     scale_factor: 0.25         # 0.25
-     num_blocks: 5
-     reshape_channel: 16384     # 16384 = 1024 * 16
-     reshape_depth: 16
-   he_estimator_params:
-     block_expansion: 64
-     max_features: 2048
-     num_bins: 66
-   generator_params:
-     block_expansion: 64
-     max_features: 512
-     num_down_blocks: 2
-     reshape_channel: 32
-     reshape_depth: 16          # 512 = 32 * 16
-     num_resblocks: 6
-     estimate_occlusion_map: True
-     dense_motion_params:
-       block_expansion: 32
-       max_features: 1024
-       num_blocks: 5
-       reshape_depth: 16
-       compress: 4
-   discriminator_params:
-     scales: [1]
-     block_expansion: 32
-     max_features: 512
-     num_blocks: 4
-     sn: True
-   mapping_params:
-     coeff_nc: 73
-     descriptor_nc: 1024
-     layer: 3
-     num_kp: 15
-     num_bins: 66
-
src/config/similarity_Lm3D_all.mat DELETED
Binary file (994 Bytes)
 
src/face3d/data/__init__.py DELETED
@@ -1,116 +0,0 @@
- """This package includes all the modules related to data loading and preprocessing
-
- To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
- You need to implement four functions:
-     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-     -- <__len__>: return the size of dataset.
-     -- <__getitem__>: get a data point from data loader.
-     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
-
- Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
- See our template dataset class 'template_dataset.py' for more details.
- """
- import numpy as np
- import importlib
- import torch.utils.data
- from face3d.data.base_dataset import BaseDataset
-
-
- def find_dataset_using_name(dataset_name):
-     """Import the module "data/[dataset_name]_dataset.py".
-
-     In the file, the class called DatasetNameDataset() will
-     be instantiated. It has to be a subclass of BaseDataset,
-     and it is case-insensitive.
-     """
-     dataset_filename = "data." + dataset_name + "_dataset"
-     datasetlib = importlib.import_module(dataset_filename)
-
-     dataset = None
-     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
-     for name, cls in datasetlib.__dict__.items():
-         if name.lower() == target_dataset_name.lower() \
-            and issubclass(cls, BaseDataset):
-             dataset = cls
-
-     if dataset is None:
-         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
-
-     return dataset
-
-
- def get_option_setter(dataset_name):
-     """Return the static method <modify_commandline_options> of the dataset class."""
-     dataset_class = find_dataset_using_name(dataset_name)
-     return dataset_class.modify_commandline_options
-
-
- def create_dataset(opt, rank=0):
-     """Create a dataset given the option.
-
-     This function wraps the class CustomDatasetDataLoader.
-     This is the main interface between this package and 'train.py'/'test.py'
-
-     Example:
-         >>> from data import create_dataset
-         >>> dataset = create_dataset(opt)
-     """
-     data_loader = CustomDatasetDataLoader(opt, rank=rank)
-     dataset = data_loader.load_data()
-     return dataset
-
- class CustomDatasetDataLoader():
-     """Wrapper class of Dataset class that performs multi-threaded data loading"""
-
-     def __init__(self, opt, rank=0):
-         """Initialize this class
-
-         Step 1: create a dataset instance given the name [dataset_mode]
-         Step 2: create a multi-threaded data loader.
-         """
-         self.opt = opt
-         dataset_class = find_dataset_using_name(opt.dataset_mode)
-         self.dataset = dataset_class(opt)
-         self.sampler = None
-         print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
-         if opt.use_ddp and opt.isTrain:
-             world_size = opt.world_size
-             self.sampler = torch.utils.data.distributed.DistributedSampler(
-                 self.dataset,
-                 num_replicas=world_size,
-                 rank=rank,
-                 shuffle=not opt.serial_batches
-             )
-             self.dataloader = torch.utils.data.DataLoader(
-                 self.dataset,
-                 sampler=self.sampler,
-                 num_workers=int(opt.num_threads / world_size),
-                 batch_size=int(opt.batch_size / world_size),
-                 drop_last=True)
-         else:
-             self.dataloader = torch.utils.data.DataLoader(
-                 self.dataset,
-                 batch_size=opt.batch_size,
-                 shuffle=(not opt.serial_batches) and opt.isTrain,
-                 num_workers=int(opt.num_threads),
-                 drop_last=True
-             )
-
-     def set_epoch(self, epoch):
-         self.dataset.current_epoch = epoch
-         if self.sampler is not None:
-             self.sampler.set_epoch(epoch)
-
-     def load_data(self):
-         return self
-
-     def __len__(self):
-         """Return the number of data in the dataset"""
-         return min(len(self.dataset), self.opt.max_dataset_size)
-
-     def __iter__(self):
-         """Return a batch of data"""
-         for i, data in enumerate(self.dataloader):
-             if i * self.opt.batch_size >= self.opt.max_dataset_size:
-                 break
-             yield data
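`find_dataset_using_name` maps a `--dataset_mode` flag to a class purely by name: it imports `data/<name>_dataset.py` and scans it for a `BaseDataset` subclass whose lowercased name equals `<name>dataset` (so `'template'` resolves to `TemplateDataset` from `template_dataset.py`). A usage sketch, assuming the dataset modules are importable:

    import argparse

    dataset_class = find_dataset_using_name('template')   # -> TemplateDataset
    parser = argparse.ArgumentParser()
    parser = dataset_class.modify_commandline_options(parser, is_train=True)
    opt = parser.parse_args(['--new_dataset_option', '3.0'])
    print(opt.new_dataset_option)                          # 3.0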
src/face3d/data/base_dataset.py DELETED
@@ -1,125 +0,0 @@
- """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
-
- It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
- """
- import random
- import numpy as np
- import torch.utils.data as data
- from PIL import Image
- import torchvision.transforms as transforms
- from abc import ABC, abstractmethod
-
-
- class BaseDataset(data.Dataset, ABC):
-     """This class is an abstract base class (ABC) for datasets.
-
-     To create a subclass, you need to implement the following four functions:
-     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-     -- <__len__>: return the size of dataset.
-     -- <__getitem__>: get a data point.
-     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
-     """
-
-     def __init__(self, opt):
-         """Initialize the class; save the options in the class
-
-         Parameters:
-             opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
-         """
-         self.opt = opt
-         # self.root = opt.dataroot
-         self.current_epoch = 0
-
-     @staticmethod
-     def modify_commandline_options(parser, is_train):
-         """Add new dataset-specific options, and rewrite default values for existing options.
-
-         Parameters:
-             parser          -- original option parser
-             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
-
-         Returns:
-             the modified parser.
-         """
-         return parser
-
-     @abstractmethod
-     def __len__(self):
-         """Return the total number of images in the dataset."""
-         return 0
-
-     @abstractmethod
-     def __getitem__(self, index):
-         """Return a data point and its metadata information.
-
-         Parameters:
-             index - - a random integer for data indexing
-
-         Returns:
-             a dictionary of data with their names. It ususally contains the data itself and its metadata information.
-         """
-         pass
-
-
- def get_transform(grayscale=False):
-     transform_list = []
-     if grayscale:
-         transform_list.append(transforms.Grayscale(1))
-     transform_list += [transforms.ToTensor()]
-     return transforms.Compose(transform_list)
-
- def get_affine_mat(opt, size):
-     shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
-     w, h = size
-
-     if 'shift' in opt.preprocess:
-         shift_pixs = int(opt.shift_pixs)
-         shift_x = random.randint(-shift_pixs, shift_pixs)
-         shift_y = random.randint(-shift_pixs, shift_pixs)
-     if 'scale' in opt.preprocess:
-         scale = 1 + opt.scale_delta * (2 * random.random() - 1)
-     if 'rot' in opt.preprocess:
-         rot_angle = opt.rot_angle * (2 * random.random() - 1)
-         rot_rad = -rot_angle * np.pi/180
-     if 'flip' in opt.preprocess:
-         flip = random.random() > 0.5
-
-     shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
-     flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
-     shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
-     rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
-     scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
-     shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
-
-     affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
-     affine_inv = np.linalg.inv(affine)
-     return affine, affine_inv, flip
-
- def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
-     return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
-
- def apply_lm_affine(landmark, affine, flip, size):
-     _, h = size
-     lm = landmark.copy()
-     lm[:, 1] = h - 1 - lm[:, 1]
-     lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
-     lm = lm @ np.transpose(affine)
-     lm[:, :2] = lm[:, :2] / lm[:, 2:]
-     lm = lm[:, :2]
-     lm[:, 1] = h - 1 - lm[:, 1]
-     if flip:
-         lm_ = lm.copy()
-         lm_[:17] = lm[16::-1]
-         lm_[17:22] = lm[26:21:-1]
-         lm_[22:27] = lm[21:16:-1]
-         lm_[31:36] = lm[35:30:-1]
-         lm_[36:40] = lm[45:41:-1]
-         lm_[40:42] = lm[47:45:-1]
-         lm_[42:46] = lm[39:35:-1]
-         lm_[46:48] = lm[41:39:-1]
-         lm_[48:55] = lm[54:47:-1]
-         lm_[55:60] = lm[59:54:-1]
-         lm_[60:65] = lm[64:59:-1]
-         lm_[65:68] = lm[67:64:-1]
-         lm = lm_
-     return lm
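`get_affine_mat` composes the augmentation right-to-left in homogeneous coordinates: shift to the origin, optional horizontal flip, random shift, rotation, scale, then shift back to the image center; the inverse matrix warps the image while the forward matrix transforms landmarks. A numeric sketch checking that the center is a fixed point of rotation and scale when shift and flip are disabled:

    import numpy as np

    w = h = 224
    center = np.array([w // 2, h // 2, 1.0])
    rot_rad, scale = np.deg2rad(10), 1.1
    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1], dtype=float).reshape(3, 3)
    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0,
                        -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape(3, 3)
    scale_mat = np.diag([scale, scale, 1.0])
    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1], dtype=float).reshape(3, 3)
    affine = shift_to_center @ scale_mat @ rot_mat @ shift_to_origin
    print(affine @ center)   # [112. 112.   1.]: the image center stays put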
src/face3d/data/flist_dataset.py DELETED
@@ -1,125 +0,0 @@
- """This script defines the custom dataset for Deep3DFaceRecon_pytorch
- """
-
- import os.path
- from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
- from data.image_folder import make_dataset
- from PIL import Image
- import random
- import util.util as util
- import numpy as np
- import json
- import torch
- from scipy.io import loadmat, savemat
- import pickle
- from util.preprocess import align_img, estimate_norm
- from util.load_mats import load_lm3d
-
-
- def default_flist_reader(flist):
-     """
-     flist format: impath label\nimpath label\n ...(same to caffe's filelist)
-     """
-     imlist = []
-     with open(flist, 'r') as rf:
-         for line in rf.readlines():
-             impath = line.strip()
-             imlist.append(impath)
-
-     return imlist
-
- def jason_flist_reader(flist):
-     with open(flist, 'r') as fp:
-         info = json.load(fp)
-     return info
-
- def parse_label(label):
-     return torch.tensor(np.array(label).astype(np.float32))
-
-
- class FlistDataset(BaseDataset):
-     """
-     It requires one directories to host training images '/path/to/data/train'
-     You can train the model with the dataset flag '--dataroot /path/to/data'.
-     """
-
-     def __init__(self, opt):
-         """Initialize this dataset class.
-
-         Parameters:
-             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
-         """
-         BaseDataset.__init__(self, opt)
-
-         self.lm3d_std = load_lm3d(opt.bfm_folder)
-
-         msk_names = default_flist_reader(opt.flist)
-         self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
-
-         self.size = len(self.msk_paths)
-         self.opt = opt
-
-         self.name = 'train' if opt.isTrain else 'val'
-         if '_' in opt.flist:
-             self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
-
-
-     def __getitem__(self, index):
-         """Return a data point and its metadata information.
-
-         Parameters:
-             index (int) -- a random integer for data indexing
-
-         Returns a dictionary that contains A, B, A_paths and B_paths
-             img (tensor)       -- an image in the input domain
-             msk (tensor)       -- its corresponding attention mask
-             lm  (tensor)       -- its corresponding 3d landmarks
-             im_paths (str)     -- image paths
-             aug_flag (bool)    -- a flag used to tell whether its raw or augmented
-         """
-         msk_path = self.msk_paths[index % self.size]  # make sure index is within then range
-         img_path = msk_path.replace('mask/', '')
-         lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
-
-         raw_img = Image.open(img_path).convert('RGB')
-         raw_msk = Image.open(msk_path).convert('RGB')
-         raw_lm = np.loadtxt(lm_path).astype(np.float32)
-
-         _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
-
-         aug_flag = self.opt.use_aug and self.opt.isTrain
-         if aug_flag:
-             img, lm, msk = self._augmentation(img, lm, self.opt, msk)
-
-         _, H = img.size
-         M = estimate_norm(lm, H)
-         transform = get_transform()
-         img_tensor = transform(img)
-         msk_tensor = transform(msk)[:1, ...]
-         lm_tensor = parse_label(lm)
-         M_tensor = parse_label(M)
-
-
-         return {'imgs': img_tensor,
-                 'lms': lm_tensor,
-                 'msks': msk_tensor,
-                 'M': M_tensor,
-                 'im_paths': img_path,
-                 'aug_flag': aug_flag,
-                 'dataset': self.name}
-
-     def _augmentation(self, img, lm, opt, msk=None):
-         affine, affine_inv, flip = get_affine_mat(opt, img.size)
-         img = apply_img_affine(img, affine_inv)
-         lm = apply_lm_affine(lm, affine, flip, img.size)
-         if msk is not None:
-             msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
-         return img, lm, msk
-
-
-
-
-     def __len__(self):
-         """Return the total number of images in the dataset.
-         """
-         return self.size
src/face3d/data/image_folder.py DELETED
@@ -1,66 +0,0 @@
- """A modified image folder class
-
- We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
- so that this class can load images from both current directory and its subdirectories.
- """
- import numpy as np
- import torch.utils.data as data
-
- from PIL import Image
- import os
- import os.path
-
- IMG_EXTENSIONS = [
-     '.jpg', '.JPG', '.jpeg', '.JPEG',
-     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
-     '.tif', '.TIF', '.tiff', '.TIFF',
- ]
-
-
- def is_image_file(filename):
-     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
-
-
- def make_dataset(dir, max_dataset_size=float("inf")):
-     images = []
-     assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
-
-     for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
-         for fname in fnames:
-             if is_image_file(fname):
-                 path = os.path.join(root, fname)
-                 images.append(path)
-     return images[:min(max_dataset_size, len(images))]
-
-
- def default_loader(path):
-     return Image.open(path).convert('RGB')
-
-
- class ImageFolder(data.Dataset):
-
-     def __init__(self, root, transform=None, return_paths=False,
-                  loader=default_loader):
-         imgs = make_dataset(root)
-         if len(imgs) == 0:
-             raise(RuntimeError("Found 0 images in: " + root + "\n"
-                                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
-
-         self.root = root
-         self.imgs = imgs
-         self.transform = transform
-         self.return_paths = return_paths
-         self.loader = loader
-
-     def __getitem__(self, index):
-         path = self.imgs[index]
-         img = self.loader(path)
-         if self.transform is not None:
-             img = self.transform(img)
-         if self.return_paths:
-             return img, path
-         else:
-             return img
-
-     def __len__(self):
-         return len(self.imgs)
src/face3d/data/template_dataset.py DELETED
@@ -1,75 +0,0 @@
- """Dataset class template
-
- This module provides a template for users to implement custom datasets.
- You can specify '--dataset_mode template' to use this dataset.
- The class name should be consistent with both the filename and its dataset_mode option.
- The filename should be <dataset_mode>_dataset.py
- The class name should be <Dataset_mode>Dataset.py
- You need to implement the following functions:
-     -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
-     -- <__init__>: Initialize this dataset class.
-     -- <__getitem__>: Return a data point and its metadata information.
-     -- <__len__>: Return the number of images.
- """
- from data.base_dataset import BaseDataset, get_transform
- # from data.image_folder import make_dataset
- # from PIL import Image
-
-
- class TemplateDataset(BaseDataset):
-     """A template dataset class for you to implement custom datasets."""
-     @staticmethod
-     def modify_commandline_options(parser, is_train):
-         """Add new dataset-specific options, and rewrite default values for existing options.
-
-         Parameters:
-             parser          -- original option parser
-             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
-
-         Returns:
-             the modified parser.
-         """
-         parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
-         parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
-         return parser
-
-     def __init__(self, opt):
-         """Initialize this dataset class.
-
-         Parameters:
-             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
-
-         A few things can be done here.
-         - save the options (have been done in BaseDataset)
-         - get image paths and meta information of the dataset.
-         - define the image transformation.
-         """
-         # save the option and dataset root
-         BaseDataset.__init__(self, opt)
-         # get the image paths of your dataset;
-         self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
-         # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
-         self.transform = get_transform(opt)
-
-     def __getitem__(self, index):
-         """Return a data point and its metadata information.
-
-         Parameters:
-             index -- a random integer for data indexing
-
-         Returns:
-             a dictionary of data with their names. It usually contains the data itself and its metadata information.
-
-         Step 1: get a random image path: e.g., path = self.image_paths[index]
-         Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
-         Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
-         Step 4: return a data point as a dictionary.
-         """
-         path = 'temp'    # needs to be a string
-         data_A = None    # needs to be a tensor
-         data_B = None    # needs to be a tensor
-         return {'data_A': data_A, 'data_B': data_B, 'path': path}
-
-     def __len__(self):
-         """Return the total number of images."""
-         return len(self.image_paths)
src/face3d/extract_kp_videos.py DELETED
@@ -1,108 +0,0 @@
1
- import os
2
- import cv2
3
- import time
4
- import glob
5
- import argparse
6
- import face_alignment
7
- import numpy as np
8
- from PIL import Image
9
- from tqdm import tqdm
10
- from itertools import cycle
11
-
12
- from torch.multiprocessing import Pool, Process, set_start_method
13
-
14
- class KeypointExtractor():
15
- def __init__(self, device):
16
- self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
17
- device=device)
18
-
19
- def extract_keypoint(self, images, name=None, info=True):
20
- if isinstance(images, list):
21
- keypoints = []
22
- if info:
23
- i_range = tqdm(images,desc='landmark Det:')
24
- else:
25
- i_range = images
26
-
27
- for image in i_range:
28
- current_kp = self.extract_keypoint(image)
29
- if np.mean(current_kp) == -1 and keypoints:
30
- keypoints.append(keypoints[-1])
31
- else:
32
- keypoints.append(current_kp[None])
33
-
34
- keypoints = np.concatenate(keypoints, 0)
35
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
36
- return keypoints
37
- else:
38
- while True:
39
- try:
40
- keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
41
- break
42
- except RuntimeError as e:
43
- if str(e).startswith('CUDA'):
44
- print("Warning: out of memory, sleep for 1s")
45
- time.sleep(1)
46
- else:
47
- print(e)
48
- break
49
- except TypeError:
50
- print('No face detected in this image')
51
- shape = [68, 2]
52
- keypoints = -1. * np.ones(shape)
53
- break
54
- if name is not None:
55
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
56
- return keypoints
57
-
58
- def read_video(filename):
59
- frames = []
60
- cap = cv2.VideoCapture(filename)
61
- while cap.isOpened():
62
- ret, frame = cap.read()
63
- if ret:
64
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
- frame = Image.fromarray(frame)
66
- frames.append(frame)
67
- else:
68
- break
69
- cap.release()
70
- return frames
71
-
72
- def run(data):
73
- filename, opt, device = data
74
- os.environ['CUDA_VISIBLE_DEVICES'] = device
75
- kp_extractor = KeypointExtractor()
76
- images = read_video(filename)
77
- name = filename.split('/')[-2:]
78
- os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
79
- kp_extractor.extract_keypoint(
80
- images,
81
- name=os.path.join(opt.output_dir, name[-2], name[-1])
82
- )
83
-
84
- if __name__ == '__main__':
85
- set_start_method('spawn')
86
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87
- parser.add_argument('--input_dir', type=str, help='the folder of the input files')
88
- parser.add_argument('--output_dir', type=str, help='the folder of the output files')
89
- parser.add_argument('--device_ids', type=str, default='0,1')
90
- parser.add_argument('--workers', type=int, default=4)
91
-
92
- opt = parser.parse_args()
93
- filenames = list()
94
- VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
95
- VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
96
- extensions = VIDEO_EXTENSIONS
97
-
98
- for ext in extensions:
99
- # accumulate matches for every extension instead of overwriting the list
100
- print(f'{opt.input_dir}/*.{ext}')
101
- filenames.extend(sorted(glob.glob(f'{opt.input_dir}/*.{ext}')))
102
- print('Total number of videos:', len(filenames))
103
- pool = Pool(opt.workers)
104
- args_list = cycle([opt])
105
- device_ids = opt.device_ids.split(",")
106
- device_ids = cycle(device_ids)
107
- for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
108
- None
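The `__main__` block above fans videos out across GPUs by zipping the file list against cycled copies of the options and device ids, so `imap_unordered` hands each worker one `(filename, opt, device)` tuple. A minimal, self-contained sketch of the same dispatch pattern, using plain `multiprocessing` and a stub `work` function standing in for `run`:

```python
from itertools import cycle
from multiprocessing import Pool

def work(data):
    # stand-in for run(): unpack one (file, device) pair and process it
    filename, device = data
    return f'{filename} -> GPU {device}'

if __name__ == '__main__':
    files = ['a.mp4', 'b.mp4', 'c.mp4', 'd.mp4', 'e.mp4']
    with Pool(2) as pool:
        # cycle(['0', '1']) pairs each file with a GPU id round-robin,
        # exactly how the script zips filenames with cycled device_ids
        for msg in pool.imap_unordered(work, zip(files, cycle(['0', '1']))):
            print(msg)
```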
src/face3d/extract_kp_videos_safe.py DELETED
@@ -1,151 +0,0 @@
1
- import os
2
- import cv2
3
- import time
4
- import glob
5
- import argparse
6
- import numpy as np
7
- from PIL import Image
8
- import torch
9
- from tqdm import tqdm
10
- from itertools import cycle
11
- from torch.multiprocessing import Pool, Process, set_start_method
12
-
13
- from facexlib.alignment import landmark_98_to_68
14
- from facexlib.detection import init_detection_model
15
-
16
- from facexlib.utils import load_file_from_url
17
- from src.face3d.util.my_awing_arch import FAN
18
-
19
- def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
20
- if model_name == 'awing_fan':
21
- model = FAN(num_modules=4, num_landmarks=98, device=device)
22
- model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
23
- else:
24
- raise NotImplementedError(f'{model_name} is not implemented.')
25
-
26
- model_path = load_file_from_url(
27
- url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
28
- model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
29
- model.eval()
30
- model = model.to(device)
31
- return model
32
-
33
-
34
- class KeypointExtractor():
35
- def __init__(self, device='cuda'):
36
-
37
- ### gfpgan/weights
38
- try:
39
- import webui # in webui
40
- root_path = 'extensions/SadTalker/gfpgan/weights'
41
-
42
- except ImportError:
43
- root_path = 'gfpgan/weights'
44
-
45
- self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
46
- self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
47
-
48
- def extract_keypoint(self, images, name=None, info=True):
49
- if isinstance(images, list):
50
- keypoints = []
51
- if info:
52
- i_range = tqdm(images, desc='landmark Det:')
53
- else:
54
- i_range = images
55
-
56
- for image in i_range:
57
- current_kp = self.extract_keypoint(image)
58
- # current_kp = self.detector.get_landmarks(np.array(image))
59
- if np.mean(current_kp) == -1 and keypoints:
60
- keypoints.append(keypoints[-1])
61
- else:
62
- keypoints.append(current_kp[None])
63
-
64
- keypoints = np.concatenate(keypoints, 0)
65
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
66
- return keypoints
67
- else:
68
- while True:
69
- try:
70
- with torch.no_grad():
71
- # face detection -> face alignment.
72
- img = np.array(images)
73
- bboxes = self.det_net.detect_faces(images, 0.97)
74
-
75
- bboxes = bboxes[0]
76
- img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
77
-
78
- keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
79
-
80
- #### keypoints to the original location
81
- keypoints[:,0] += int(bboxes[0])
82
- keypoints[:,1] += int(bboxes[1])
83
-
84
- break
85
- except RuntimeError as e:
86
- if str(e).startswith('CUDA'):
87
- print("Warning: out of memory, sleep for 1s")
88
- time.sleep(1)
89
- else:
90
- print(e)
91
- break
92
- except TypeError:
93
- print('No face detected in this image')
94
- shape = [68, 2]
95
- keypoints = -1. * np.ones(shape)
96
- break
97
- if name is not None:
98
- np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
99
- return keypoints
100
-
101
- def read_video(filename):
102
- frames = []
103
- cap = cv2.VideoCapture(filename)
104
- while cap.isOpened():
105
- ret, frame = cap.read()
106
- if ret:
107
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
- frame = Image.fromarray(frame)
109
- frames.append(frame)
110
- else:
111
- break
112
- cap.release()
113
- return frames
114
-
115
- def run(data):
116
- filename, opt, device = data
117
- os.environ['CUDA_VISIBLE_DEVICES'] = device
118
- kp_extractor = KeypointExtractor()
119
- images = read_video(filename)
120
- name = filename.split('/')[-2:]
121
- os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
122
- kp_extractor.extract_keypoint(
123
- images,
124
- name=os.path.join(opt.output_dir, name[-2], name[-1])
125
- )
126
-
127
- if __name__ == '__main__':
128
- set_start_method('spawn')
129
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
130
- parser.add_argument('--input_dir', type=str, help='the folder of the input files')
131
- parser.add_argument('--output_dir', type=str, help='the folder of the output files')
132
- parser.add_argument('--device_ids', type=str, default='0,1')
133
- parser.add_argument('--workers', type=int, default=4)
134
-
135
- opt = parser.parse_args()
136
- filenames = list()
137
- VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
138
- VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
139
- extensions = VIDEO_EXTENSIONS
140
-
141
- for ext in extensions:
142
- # accumulate matches for every extension instead of overwriting the list
143
- print(f'{opt.input_dir}/*.{ext}')
144
- filenames.extend(sorted(glob.glob(f'{opt.input_dir}/*.{ext}')))
145
- print('Total number of videos:', len(filenames))
146
- pool = Pool(opt.workers)
147
- args_list = cycle([opt])
148
- device_ids = opt.device_ids.split(",")
149
- device_ids = cycle(device_ids)
150
- for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
151
- None
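Compared with the face_alignment version above, this "safe" extractor first detects a face box, predicts 98 WFLW landmarks on the crop, maps them to the 68-point layout, and then shifts them back into full-image coordinates by adding the box origin. A toy sketch of that coordinate round-trip with synthetic landmarks (no detector or FAN model involved):

```python
import numpy as np

img = np.zeros((480, 640, 3), np.uint8)          # stand-in frame
x1, y1, x2, y2 = 200, 120, 360, 320              # a detector box, as in bboxes[0]
crop = img[y1:y2, x1:x2]

# landmarks predicted in crop coordinates (here: random stand-ins)
kp_crop = np.random.rand(68, 2) * [crop.shape[1], crop.shape[0]]

# same shift as keypoints[:, 0] += int(bboxes[0]); keypoints[:, 1] += int(bboxes[1])
kp_full = kp_crop + [x1, y1]
assert (kp_full >= [x1, y1]).all() and (kp_full <= [x2, y2]).all()
```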
src/face3d/models/__init__.py DELETED
@@ -1,67 +0,0 @@
1
- """This package contains modules related to objective functions, optimizations, and network architectures.
2
-
3
- To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel that inherits from BaseModel.
4
- You need to implement the following five functions:
5
- -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
- -- <set_input>: unpack data from dataset and apply preprocessing.
7
- -- <forward>: produce intermediate results.
8
- -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
- -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
-
11
- In the function <__init__>, you need to define four lists:
12
- -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
- -- self.model_names (str list): define networks used in our training.
14
- -- self.visual_names (str list): specify the images that you want to display and save.
15
- -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
16
-
17
- Now you can use the model class by specifying flag '--model dummy'.
18
- See our template model class 'template_model.py' for more details.
19
- """
20
-
21
- import importlib
22
- from src.face3d.models.base_model import BaseModel
23
-
24
-
25
- def find_model_using_name(model_name):
26
- """Import the module "models/[model_name]_model.py".
27
-
28
- In the file, the class called [ModelName]Model() will
29
- be instantiated. It has to be a subclass of BaseModel,
30
- and it is case-insensitive.
31
- """
32
- model_filename = "face3d.models." + model_name + "_model"
33
- modellib = importlib.import_module(model_filename)
34
- model = None
35
- target_model_name = model_name.replace('_', '') + 'model'
36
- for name, cls in modellib.__dict__.items():
37
- if name.lower() == target_model_name.lower() \
38
- and issubclass(cls, BaseModel):
39
- model = cls
40
-
41
- if model is None:
42
- print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
- exit(0)
44
-
45
- return model
46
-
47
-
48
- def get_option_setter(model_name):
49
- """Return the static method <modify_commandline_options> of the model class."""
50
- model_class = find_model_using_name(model_name)
51
- return model_class.modify_commandline_options
52
-
53
-
54
- def create_model(opt):
55
- """Create a model given the option.
56
-
57
- This function wraps the chosen model class.
58
- This is the main interface between this package and 'train.py'/'test.py'
59
-
60
- Example:
61
- >>> from models import create_model
62
- >>> model = create_model(opt)
63
- """
64
- model = find_model_using_name(opt.model)
65
- instance = model(opt)
66
- print("model [%s] was created" % type(instance).__name__)
67
- return instance
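`find_model_using_name` only matches a class whose lowercased name equals `<model_name with '_' removed> + 'model'`, imported from `face3d.models.<model_name>_model`. A hypothetical `dummy` model that satisfies this convention would look like the sketch below (illustration only; `BaseModel`'s full interface is defined elsewhere in this package):

```python
# face3d/models/dummy_model.py  (hypothetical file name required by the convention)
from src.face3d.models.base_model import BaseModel

class DummyModel(BaseModel):  # 'dummymodel' == 'dummy'.replace('_', '') + 'model'
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names, self.model_names, self.visual_names = [], [], []
        self.optimizers = []

# create_model(opt) with opt.model == 'dummy' would then return DummyModel(opt)
```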
src/face3d/models/arcface_torch/README.md DELETED
@@ -1,164 +0,0 @@
1
- # Distributed Arcface Training in Pytorch
2
-
3
- This is a deep learning library that makes face recognition training efficient and effective, and that can train tens of
4
- millions of identities on a single server.
5
-
6
- ## Requirements
7
-
8
- - Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
9
- - `pip install -r requirements.txt`.
10
- - Download the dataset
11
- from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
12
-
13
-
14
- ## How to Train
15
-
16
- To train a model, run `train.py` with the path to the configs:
17
-
18
- ### 1. Single node, 8 GPUs:
19
-
20
- ```shell
21
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
22
- ```
23
-
24
- ### 2. Multiple nodes, each node 8 GPUs:
25
-
26
- Node 0:
27
-
28
- ```shell
29
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
30
- ```
31
-
32
- Node 1:
33
-
34
- ```shell
35
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
36
- ```
37
-
38
- ### 3. Training resnet2060 with 8 GPUs:
39
-
40
- ```shell
41
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
42
- ```
43
-
44
- ## Model Zoo
45
-
46
- - The models are available for non-commercial research purposes only.
47
- - All models can be found here:
48
- - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
49
- - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
50
-
51
- ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
52
-
53
- The ICCV2021-MFR testset consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
54
- recognition training sets such as MS1M and CASIA, which were mostly collected from online celebrities.
55
- As a result, we can evaluate the fairness of different algorithms.
56
-
57
- For the **ICCV2021-MFR-ALL** set, TAR is measured under the all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
58
- globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
59
-
60
- For the **ICCV2021-MFR-MASK** set, TAR is measured under the mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
61
- The mask testset contains 6,964 identities, 6,964 masked images, and 13,928 non-masked images.
62
- In total, there are 13,928 positive pairs and 96,983,824 negative pairs.
63
-
64
- | Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
65
- | :---: | :--- | :--- | :--- |:--- |:--- |
66
- | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
67
- | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
68
- | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
69
- | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
70
- | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
71
- | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
72
- | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
73
- | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
74
- | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
75
- | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
76
-
77
- ### Performance on IJB-C and Verification Datasets
78
-
79
- | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
80
- | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
81
- | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
82
- | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
83
- | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
84
- | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
85
- | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
86
- | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
87
- | Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
88
- | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
89
- | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
90
-
91
- [comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
92
-
93
-
94
- ## [Speed Benchmark](docs/speed_benchmark.md)
95
-
96
- **Arcface Torch** can train on large-scale face recognition training sets efficiently and quickly. When the number of
97
- classes in the training set is greater than 300K and training is sufficient, the Partial FC sampling strategy reaches the same
98
- accuracy with several times faster training and a smaller GPU memory footprint.
99
- Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
100
- sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
101
- sparse subset of the parameters is updated, which saves a great deal of GPU memory and computation. With Partial FC,
102
- we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
103
- training and mixed-precision training.
104
-
105
- ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
106
-
107
- For more details, see
108
- [speed_benchmark.md](docs/speed_benchmark.md) in docs.
109
-
110
- ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
111
-
112
- `-` means training failed because of gpu memory limitations.
113
-
114
- | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
115
- | :--- | :--- | :--- | :--- |
116
- |125000 | 4681 | 4824 | 5004 |
117
- |1400000 | **1672** | 3043 | 4738 |
118
- |5500000 | **-** | **1389** | 3975 |
119
- |8000000 | **-** | **-** | 3565 |
120
- |16000000 | **-** | **-** | 2679 |
121
- |29000000 | **-** | **-** | **1855** |
122
-
123
- ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
124
-
125
- | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
126
- | :--- | :--- | :--- | :--- |
127
- |125000 | 7358 | 5306 | 4868 |
128
- |1400000 | 32252 | 11178 | 6056 |
129
- |5500000 | **-** | 32188 | 9854 |
130
- |8000000 | **-** | **-** | 12310 |
131
- |16000000 | **-** | **-** | 19950 |
132
- |29000000 | **-** | **-** | 32324 |
133
-
134
- ## Evaluation ICCV2021-MFR and IJB-C
135
-
136
- For more details, see [eval.md](docs/eval.md) in docs.
137
-
138
- ## Test
139
-
140
- We tested many versions of PyTorch. Please create an issue if you are having trouble.
141
-
142
- - [x] torch 1.6.0
143
- - [x] torch 1.7.1
144
- - [x] torch 1.8.0
145
- - [x] torch 1.9.0
146
-
147
- ## Citation
148
-
149
- ```
150
- @inproceedings{deng2019arcface,
151
- title={Arcface: Additive angular margin loss for deep face recognition},
152
- author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
153
- booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
154
- pages={4690--4699},
155
- year={2019}
156
- }
157
- @inproceedings{an2020partical_fc,
158
- title={Partial FC: Training 10 Million Identities on a Single Machine},
159
- author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
160
- Zhang, Debing and Fu Ying},
161
- booktitle={Arxiv 2010.05222},
162
- year={2020}
163
- }
164
- ```
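The Partial FC description in the Speed Benchmark section above comes down to one step: per batch, keep every positive class center, pad with random negatives up to a sampling budget, and run the softmax over that subset only. A toy sketch of the sampling step (not the repository's actual implementation, which additionally shards the centers across GPUs):

```python
import torch

num_classes, emb_dim, sample_rate, batch = 1000, 512, 0.1, 64
centers = torch.randn(num_classes, emb_dim)        # full class-center matrix

labels = torch.randint(0, num_classes, (batch,))
positive = labels.unique()
budget = max(int(num_classes * sample_rate), positive.numel())

# keep every positive center, pad with random negatives up to the budget
perm = torch.randperm(num_classes)
negatives = perm[~torch.isin(perm, positive)][: budget - positive.numel()]
sampled = torch.cat([positive, negatives])

embeddings = torch.randn(batch, emb_dim)
logits = embeddings @ centers[sampled].t()         # softmax runs over ~10% of the classes
print(logits.shape)                                # torch.Size([64, 100])
```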
src/face3d/models/arcface_torch/backbones/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2
- from .mobilefacenet import get_mbf
3
-
4
-
5
- def get_model(name, **kwargs):
6
- # resnet
7
- if name == "r18":
8
- return iresnet18(False, **kwargs)
9
- elif name == "r34":
10
- return iresnet34(False, **kwargs)
11
- elif name == "r50":
12
- return iresnet50(False, **kwargs)
13
- elif name == "r100":
14
- return iresnet100(False, **kwargs)
15
- elif name == "r200":
16
- return iresnet200(False, **kwargs)
17
- elif name == "r2060":
18
- from .iresnet2060 import iresnet2060
19
- return iresnet2060(False, **kwargs)
20
- elif name == "mbf":
21
- fp16 = kwargs.get("fp16", False)
22
- num_features = kwargs.get("num_features", 512)
23
- return get_mbf(fp16=fp16, num_features=num_features)
24
- else:
25
- raise ValueError()
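A quick shape check for the factory above (weights are random, since every branch passes `pretrained=False`); this assumes the package is importable as `backbones`:

```python
import torch
from backbones import get_model   # assumes this package is on the import path

net = get_model('r50', fp16=False).eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))   # ArcFace backbones expect 112x112 crops
print(emb.shape)                              # torch.Size([1, 512])
```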
src/face3d/models/arcface_torch/backbones/iresnet.py DELETED
@@ -1,187 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
- __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
5
-
6
-
7
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8
- """3x3 convolution with padding"""
9
- return nn.Conv2d(in_planes,
10
- out_planes,
11
- kernel_size=3,
12
- stride=stride,
13
- padding=dilation,
14
- groups=groups,
15
- bias=False,
16
- dilation=dilation)
17
-
18
-
19
- def conv1x1(in_planes, out_planes, stride=1):
20
- """1x1 convolution"""
21
- return nn.Conv2d(in_planes,
22
- out_planes,
23
- kernel_size=1,
24
- stride=stride,
25
- bias=False)
26
-
27
-
28
- class IBasicBlock(nn.Module):
29
- expansion = 1
30
- def __init__(self, inplanes, planes, stride=1, downsample=None,
31
- groups=1, base_width=64, dilation=1):
32
- super(IBasicBlock, self).__init__()
33
- if groups != 1 or base_width != 64:
34
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35
- if dilation > 1:
36
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37
- self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
38
- self.conv1 = conv3x3(inplanes, planes)
39
- self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
40
- self.prelu = nn.PReLU(planes)
41
- self.conv2 = conv3x3(planes, planes, stride)
42
- self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
43
- self.downsample = downsample
44
- self.stride = stride
45
-
46
- def forward(self, x):
47
- identity = x
48
- out = self.bn1(x)
49
- out = self.conv1(out)
50
- out = self.bn2(out)
51
- out = self.prelu(out)
52
- out = self.conv2(out)
53
- out = self.bn3(out)
54
- if self.downsample is not None:
55
- identity = self.downsample(x)
56
- out += identity
57
- return out
58
-
59
-
60
- class IResNet(nn.Module):
61
- fc_scale = 7 * 7
62
- def __init__(self,
63
- block, layers, dropout=0, num_features=512, zero_init_residual=False,
64
- groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
65
- super(IResNet, self).__init__()
66
- self.fp16 = fp16
67
- self.inplanes = 64
68
- self.dilation = 1
69
- if replace_stride_with_dilation is None:
70
- replace_stride_with_dilation = [False, False, False]
71
- if len(replace_stride_with_dilation) != 3:
72
- raise ValueError("replace_stride_with_dilation should be None "
73
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74
- self.groups = groups
75
- self.base_width = width_per_group
76
- self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77
- self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
78
- self.prelu = nn.PReLU(self.inplanes)
79
- self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
80
- self.layer2 = self._make_layer(block,
81
- 128,
82
- layers[1],
83
- stride=2,
84
- dilate=replace_stride_with_dilation[0])
85
- self.layer3 = self._make_layer(block,
86
- 256,
87
- layers[2],
88
- stride=2,
89
- dilate=replace_stride_with_dilation[1])
90
- self.layer4 = self._make_layer(block,
91
- 512,
92
- layers[3],
93
- stride=2,
94
- dilate=replace_stride_with_dilation[2])
95
- self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
96
- self.dropout = nn.Dropout(p=dropout, inplace=True)
97
- self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
98
- self.features = nn.BatchNorm1d(num_features, eps=1e-05)
99
- nn.init.constant_(self.features.weight, 1.0)
100
- self.features.weight.requires_grad = False
101
-
102
- for m in self.modules():
103
- if isinstance(m, nn.Conv2d):
104
- nn.init.normal_(m.weight, 0, 0.1)
105
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106
- nn.init.constant_(m.weight, 1)
107
- nn.init.constant_(m.bias, 0)
108
-
109
- if zero_init_residual:
110
- for m in self.modules():
111
- if isinstance(m, IBasicBlock):
112
- nn.init.constant_(m.bn2.weight, 0)
113
-
114
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
115
- downsample = None
116
- previous_dilation = self.dilation
117
- if dilate:
118
- self.dilation *= stride
119
- stride = 1
120
- if stride != 1 or self.inplanes != planes * block.expansion:
121
- downsample = nn.Sequential(
122
- conv1x1(self.inplanes, planes * block.expansion, stride),
123
- nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
124
- )
125
- layers = []
126
- layers.append(
127
- block(self.inplanes, planes, stride, downsample, self.groups,
128
- self.base_width, previous_dilation))
129
- self.inplanes = planes * block.expansion
130
- for _ in range(1, blocks):
131
- layers.append(
132
- block(self.inplanes,
133
- planes,
134
- groups=self.groups,
135
- base_width=self.base_width,
136
- dilation=self.dilation))
137
-
138
- return nn.Sequential(*layers)
139
-
140
- def forward(self, x):
141
- with torch.cuda.amp.autocast(self.fp16):
142
- x = self.conv1(x)
143
- x = self.bn1(x)
144
- x = self.prelu(x)
145
- x = self.layer1(x)
146
- x = self.layer2(x)
147
- x = self.layer3(x)
148
- x = self.layer4(x)
149
- x = self.bn2(x)
150
- x = torch.flatten(x, 1)
151
- x = self.dropout(x)
152
- x = self.fc(x.float() if self.fp16 else x)
153
- x = self.features(x)
154
- return x
155
-
156
-
157
- def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
158
- model = IResNet(block, layers, **kwargs)
159
- if pretrained:
160
- raise ValueError()
161
- return model
162
-
163
-
164
- def iresnet18(pretrained=False, progress=True, **kwargs):
165
- return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
166
- progress, **kwargs)
167
-
168
-
169
- def iresnet34(pretrained=False, progress=True, **kwargs):
170
- return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
171
- progress, **kwargs)
172
-
173
-
174
- def iresnet50(pretrained=False, progress=True, **kwargs):
175
- return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
176
- progress, **kwargs)
177
-
178
-
179
- def iresnet100(pretrained=False, progress=True, **kwargs):
180
- return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
181
- progress, **kwargs)
182
-
183
-
184
- def iresnet200(pretrained=False, progress=True, **kwargs):
185
- return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
186
- progress, **kwargs)
187
-
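Two details worth noting in this backbone: the stem keeps stride 1 (so `layer1` starts the downsampling), and each of the four stages halves the resolution, which is why `fc_scale = 7 * 7` for a 112x112 input (112 -> 56 -> 28 -> 14 -> 7). A small sketch confirming the shapes, assuming this module is importable as `iresnet`:

```python
import torch
from iresnet import iresnet50   # assumes this module is on the import path

net = iresnet50(num_features=512).eval()
with torch.no_grad():
    feat = net.layer4(net.layer3(net.layer2(net.layer1(
        net.prelu(net.bn1(net.conv1(torch.randn(2, 3, 112, 112))))))))
print(feat.shape)   # torch.Size([2, 512, 7, 7]); bn2 + flatten + fc then map this to [2, 512]
```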
src/face3d/models/arcface_torch/backbones/iresnet2060.py DELETED
@@ -1,176 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
- assert tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 8), 'checkpoint_sequential here needs torch >= 1.8.1'  # plain string comparison misorders versions like 1.10
5
- from torch.utils.checkpoint import checkpoint_sequential
6
-
7
- __all__ = ['iresnet2060']
8
-
9
-
10
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
- """3x3 convolution with padding"""
12
- return nn.Conv2d(in_planes,
13
- out_planes,
14
- kernel_size=3,
15
- stride=stride,
16
- padding=dilation,
17
- groups=groups,
18
- bias=False,
19
- dilation=dilation)
20
-
21
-
22
- def conv1x1(in_planes, out_planes, stride=1):
23
- """1x1 convolution"""
24
- return nn.Conv2d(in_planes,
25
- out_planes,
26
- kernel_size=1,
27
- stride=stride,
28
- bias=False)
29
-
30
-
31
- class IBasicBlock(nn.Module):
32
- expansion = 1
33
-
34
- def __init__(self, inplanes, planes, stride=1, downsample=None,
35
- groups=1, base_width=64, dilation=1):
36
- super(IBasicBlock, self).__init__()
37
- if groups != 1 or base_width != 64:
38
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
- if dilation > 1:
40
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
- self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
- self.conv1 = conv3x3(inplanes, planes)
43
- self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
- self.prelu = nn.PReLU(planes)
45
- self.conv2 = conv3x3(planes, planes, stride)
46
- self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
- self.downsample = downsample
48
- self.stride = stride
49
-
50
- def forward(self, x):
51
- identity = x
52
- out = self.bn1(x)
53
- out = self.conv1(out)
54
- out = self.bn2(out)
55
- out = self.prelu(out)
56
- out = self.conv2(out)
57
- out = self.bn3(out)
58
- if self.downsample is not None:
59
- identity = self.downsample(x)
60
- out += identity
61
- return out
62
-
63
-
64
- class IResNet(nn.Module):
65
- fc_scale = 7 * 7
66
-
67
- def __init__(self,
68
- block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
- groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
- super(IResNet, self).__init__()
71
- self.fp16 = fp16
72
- self.inplanes = 64
73
- self.dilation = 1
74
- if replace_stride_with_dilation is None:
75
- replace_stride_with_dilation = [False, False, False]
76
- if len(replace_stride_with_dilation) != 3:
77
- raise ValueError("replace_stride_with_dilation should be None "
78
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
- self.groups = groups
80
- self.base_width = width_per_group
81
- self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
- self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
- self.prelu = nn.PReLU(self.inplanes)
84
- self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
- self.layer2 = self._make_layer(block,
86
- 128,
87
- layers[1],
88
- stride=2,
89
- dilate=replace_stride_with_dilation[0])
90
- self.layer3 = self._make_layer(block,
91
- 256,
92
- layers[2],
93
- stride=2,
94
- dilate=replace_stride_with_dilation[1])
95
- self.layer4 = self._make_layer(block,
96
- 512,
97
- layers[3],
98
- stride=2,
99
- dilate=replace_stride_with_dilation[2])
100
- self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
- self.dropout = nn.Dropout(p=dropout, inplace=True)
102
- self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
- self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
- nn.init.constant_(self.features.weight, 1.0)
105
- self.features.weight.requires_grad = False
106
-
107
- for m in self.modules():
108
- if isinstance(m, nn.Conv2d):
109
- nn.init.normal_(m.weight, 0, 0.1)
110
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
- nn.init.constant_(m.weight, 1)
112
- nn.init.constant_(m.bias, 0)
113
-
114
- if zero_init_residual:
115
- for m in self.modules():
116
- if isinstance(m, IBasicBlock):
117
- nn.init.constant_(m.bn2.weight, 0)
118
-
119
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
- downsample = None
121
- previous_dilation = self.dilation
122
- if dilate:
123
- self.dilation *= stride
124
- stride = 1
125
- if stride != 1 or self.inplanes != planes * block.expansion:
126
- downsample = nn.Sequential(
127
- conv1x1(self.inplanes, planes * block.expansion, stride),
128
- nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
- )
130
- layers = []
131
- layers.append(
132
- block(self.inplanes, planes, stride, downsample, self.groups,
133
- self.base_width, previous_dilation))
134
- self.inplanes = planes * block.expansion
135
- for _ in range(1, blocks):
136
- layers.append(
137
- block(self.inplanes,
138
- planes,
139
- groups=self.groups,
140
- base_width=self.base_width,
141
- dilation=self.dilation))
142
-
143
- return nn.Sequential(*layers)
144
-
145
- def checkpoint(self, func, num_seg, x):
146
- if self.training:
147
- return checkpoint_sequential(func, num_seg, x)
148
- else:
149
- return func(x)
150
-
151
- def forward(self, x):
152
- with torch.cuda.amp.autocast(self.fp16):
153
- x = self.conv1(x)
154
- x = self.bn1(x)
155
- x = self.prelu(x)
156
- x = self.layer1(x)
157
- x = self.checkpoint(self.layer2, 20, x)
158
- x = self.checkpoint(self.layer3, 100, x)
159
- x = self.layer4(x)
160
- x = self.bn2(x)
161
- x = torch.flatten(x, 1)
162
- x = self.dropout(x)
163
- x = self.fc(x.float() if self.fp16 else x)
164
- x = self.features(x)
165
- return x
166
-
167
-
168
- def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
- model = IResNet(block, layers, **kwargs)
170
- if pretrained:
171
- raise ValueError()
172
- return model
173
-
174
-
175
- def iresnet2060(pretrained=False, progress=True, **kwargs):
176
- return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
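The only functional difference from `iresnet.py` is that the two enormous stages are run through `checkpoint_sequential` during training, trading recomputation in the backward pass for activation memory. A standalone sketch of the same trick on a generic `nn.Sequential`:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

blocks = nn.Sequential(*[nn.Sequential(nn.Linear(256, 256), nn.ReLU()) for _ in range(64)])
x = torch.randn(8, 256, requires_grad=True)

# split the 64 blocks into 4 segments: only segment-boundary activations are stored,
# everything inside a segment is recomputed when backward() needs it
y = checkpoint_sequential(blocks, 4, x)
y.sum().backward()
```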
src/face3d/models/arcface_torch/backbones/mobilefacenet.py DELETED
@@ -1,130 +0,0 @@
1
- '''
2
- Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
- Original author cavalleria
4
- '''
5
-
6
- import torch.nn as nn
7
- from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
- import torch
9
-
10
-
11
- class Flatten(Module):
12
- def forward(self, x):
13
- return x.view(x.size(0), -1)
14
-
15
-
16
- class ConvBlock(Module):
17
- def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
- super(ConvBlock, self).__init__()
19
- self.layers = nn.Sequential(
20
- Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
- BatchNorm2d(num_features=out_c),
22
- PReLU(num_parameters=out_c)
23
- )
24
-
25
- def forward(self, x):
26
- return self.layers(x)
27
-
28
-
29
- class LinearBlock(Module):
30
- def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
- super(LinearBlock, self).__init__()
32
- self.layers = nn.Sequential(
33
- Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
- BatchNorm2d(num_features=out_c)
35
- )
36
-
37
- def forward(self, x):
38
- return self.layers(x)
39
-
40
-
41
- class DepthWise(Module):
42
- def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
- super(DepthWise, self).__init__()
44
- self.residual = residual
45
- self.layers = nn.Sequential(
46
- ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
- ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
- LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
- )
50
-
51
- def forward(self, x):
52
- short_cut = None
53
- if self.residual:
54
- short_cut = x
55
- x = self.layers(x)
56
- if self.residual:
57
- output = short_cut + x
58
- else:
59
- output = x
60
- return output
61
-
62
-
63
- class Residual(Module):
64
- def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
- super(Residual, self).__init__()
66
- modules = []
67
- for _ in range(num_block):
68
- modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
- self.layers = Sequential(*modules)
70
-
71
- def forward(self, x):
72
- return self.layers(x)
73
-
74
-
75
- class GDC(Module):
76
- def __init__(self, embedding_size):
77
- super(GDC, self).__init__()
78
- self.layers = nn.Sequential(
79
- LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
- Flatten(),
81
- Linear(512, embedding_size, bias=False),
82
- BatchNorm1d(embedding_size))
83
-
84
- def forward(self, x):
85
- return self.layers(x)
86
-
87
-
88
- class MobileFaceNet(Module):
89
- def __init__(self, fp16=False, num_features=512):
90
- super(MobileFaceNet, self).__init__()
91
- scale = 2
92
- self.fp16 = fp16
93
- self.layers = nn.Sequential(
94
- ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95
- ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96
- DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97
- Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98
- DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99
- Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100
- DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101
- Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102
- )
103
- self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104
- self.features = GDC(num_features)
105
- self._initialize_weights()
106
-
107
- def _initialize_weights(self):
108
- for m in self.modules():
109
- if isinstance(m, nn.Conv2d):
110
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111
- if m.bias is not None:
112
- m.bias.data.zero_()
113
- elif isinstance(m, nn.BatchNorm2d):
114
- m.weight.data.fill_(1)
115
- m.bias.data.zero_()
116
- elif isinstance(m, nn.Linear):
117
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118
- if m.bias is not None:
119
- m.bias.data.zero_()
120
-
121
- def forward(self, x):
122
- with torch.cuda.amp.autocast(self.fp16):
123
- x = self.layers(x)
124
- x = self.conv_sep(x.float() if self.fp16 else x)
125
- x = self.features(x)
126
- return x
127
-
128
-
129
- def get_mbf(fp16, num_features):
130
- return MobileFaceNet(fp16, num_features)
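A shape check for this MobileFaceNet variant (`scale = 2` widens every stage; GDC's 7x7 depthwise convolution again assumes the 112x112 crops used throughout this repo):

```python
import torch
from mobilefacenet import get_mbf   # assumes this module is on the import path

net = get_mbf(fp16=False, num_features=512).eval()
with torch.no_grad():
    print(net(torch.randn(1, 3, 112, 112)).shape)   # torch.Size([1, 512])
```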
src/face3d/models/arcface_torch/configs/3millions.py DELETED
@@ -1,23 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # configs for test speed
4
-
5
- config = edict()
6
- config.loss = "arcface"
7
- config.network = "r50"
8
- config.resume = False
9
- config.output = None
10
- config.embedding_size = 512
11
- config.sample_rate = 1.0
12
- config.fp16 = True
13
- config.momentum = 0.9
14
- config.weight_decay = 5e-4
15
- config.batch_size = 128
16
- config.lr = 0.1 # batch size is 512
17
-
18
- config.rec = "synthetic"
19
- config.num_classes = 300 * 10000
20
- config.num_epoch = 30
21
- config.warmup_epoch = -1
22
- config.decay_epoch = [10, 16, 22]
23
- config.val_targets = []
src/face3d/models/arcface_torch/configs/3millions_pfc.py DELETED
@@ -1,23 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # configs for test speed
4
-
5
- config = edict()
6
- config.loss = "arcface"
7
- config.network = "r50"
8
- config.resume = False
9
- config.output = None
10
- config.embedding_size = 512
11
- config.sample_rate = 0.1
12
- config.fp16 = True
13
- config.momentum = 0.9
14
- config.weight_decay = 5e-4
15
- config.batch_size = 128
16
- config.lr = 0.1 # batch size is 512
17
-
18
- config.rec = "synthetic"
19
- config.num_classes = 300 * 10000
20
- config.num_epoch = 30
21
- config.warmup_epoch = -1
22
- config.decay_epoch = [10, 16, 22]
23
- config.val_targets = []
src/face3d/models/arcface_torch/configs/__init__.py DELETED
File without changes
src/face3d/models/arcface_torch/configs/base.py DELETED
@@ -1,56 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "arcface"
9
- config.network = "r50"
10
- config.resume = False
11
- config.output = "ms1mv3_arcface_r50"
12
-
13
- config.dataset = "ms1m-retinaface-t1"
14
- config.embedding_size = 512
15
- config.sample_rate = 1
16
- config.fp16 = False
17
- config.momentum = 0.9
18
- config.weight_decay = 5e-4
19
- config.batch_size = 128
20
- config.lr = 0.1 # batch size is 512
21
-
22
- if config.dataset == "emore":
23
- config.rec = "/train_tmp/faces_emore"
24
- config.num_classes = 85742
25
- config.num_image = 5822653
26
- config.num_epoch = 16
27
- config.warmup_epoch = -1
28
- config.decay_epoch = [8, 14, ]
29
- config.val_targets = ["lfw", ]
30
-
31
- elif config.dataset == "ms1m-retinaface-t1":
32
- config.rec = "/train_tmp/ms1m-retinaface-t1"
33
- config.num_classes = 93431
34
- config.num_image = 5179510
35
- config.num_epoch = 25
36
- config.warmup_epoch = -1
37
- config.decay_epoch = [11, 17, 22]
38
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
39
-
40
- elif config.dataset == "glint360k":
41
- config.rec = "/train_tmp/glint360k"
42
- config.num_classes = 360232
43
- config.num_image = 17091657
44
- config.num_epoch = 20
45
- config.warmup_epoch = -1
46
- config.decay_epoch = [8, 12, 15, 18]
47
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
48
-
49
- elif config.dataset == "webface":
50
- config.rec = "/train_tmp/faces_webface_112x112"
51
- config.num_classes = 10572
52
- config.num_image = "forget"
53
- config.num_epoch = 34
54
- config.warmup_epoch = -1
55
- config.decay_epoch = [20, 28, 32]
56
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/glint360k_mbf.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "cosface"
9
- config.network = "mbf"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 0.1
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 2e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/glint360k"
21
- config.num_classes = 360232
22
- config.num_image = 17091657
23
- config.num_epoch = 20
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [8, 12, 15, 18]
26
src/face3d/models/arcface_torch/configs/glint360k_r100.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "cosface"
9
- config.network = "r100"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 1.0
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 5e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/glint360k"
21
- config.num_classes = 360232
22
- config.num_image = 17091657
23
- config.num_epoch = 20
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [8, 12, 15, 18]
26
src/face3d/models/arcface_torch/configs/glint360k_r18.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "cosface"
9
- config.network = "r18"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 1.0
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 5e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/glint360k"
21
- config.num_classes = 360232
22
- config.num_image = 17091657
23
- config.num_epoch = 20
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [8, 12, 15, 18]
26
src/face3d/models/arcface_torch/configs/glint360k_r34.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "cosface"
9
- config.network = "r34"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 1.0
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 5e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/glint360k"
21
- config.num_classes = 360232
22
- config.num_image = 17091657
23
- config.num_epoch = 20
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [8, 12, 15, 18]
26
src/face3d/models/arcface_torch/configs/glint360k_r50.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "cosface"
9
- config.network = "r50"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 1.0
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 5e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/glint360k"
21
- config.num_classes = 360232
22
- config.num_image = 17091657
23
- config.num_epoch = 20
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [8, 12, 15, 18]
26
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "arcface"
9
- config.network = "mbf"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 1.0
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 2e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/ms1m-retinaface-t1"
21
- config.num_classes = 93431
22
- config.num_image = 5179510
23
- config.num_epoch = 30
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [10, 20, 25]
26
src/face3d/models/arcface_torch/configs/ms1mv3_r18.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "arcface"
9
- config.network = "r18"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 1.0
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 5e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/ms1m-retinaface-t1"
21
- config.num_classes = 93431
22
- config.num_image = 5179510
23
- config.num_epoch = 25
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [10, 16, 22]
26
src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "arcface"
9
- config.network = "r2060"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 1.0
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 5e-4
17
- config.batch_size = 64
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/ms1m-retinaface-t1"
21
- config.num_classes = 93431
22
- config.num_image = 5179510
23
- config.num_epoch = 25
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [10, 16, 22]
26
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/ms1mv3_r34.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "arcface"
9
- config.network = "r34"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 1.0
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 5e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/ms1m-retinaface-t1"
21
- config.num_classes = 93431
22
- config.num_image = 5179510
23
- config.num_epoch = 25
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [10, 16, 22]
26
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/ms1mv3_r50.py DELETED
@@ -1,26 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # make training faster
4
- # our RAM is 256G
5
- # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
-
7
- config = edict()
8
- config.loss = "arcface"
9
- config.network = "r50"
10
- config.resume = False
11
- config.output = None
12
- config.embedding_size = 512
13
- config.sample_rate = 1.0
14
- config.fp16 = True
15
- config.momentum = 0.9
16
- config.weight_decay = 5e-4
17
- config.batch_size = 128
18
- config.lr = 0.1 # batch size is 512
19
-
20
- config.rec = "/train_tmp/ms1m-retinaface-t1"
21
- config.num_classes = 93431
22
- config.num_image = 5179510
23
- config.num_epoch = 25
24
- config.warmup_epoch = -1
25
- config.decay_epoch = [10, 16, 22]
26
- config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/speed.py DELETED
@@ -1,23 +0,0 @@
1
- from easydict import EasyDict as edict
2
-
3
- # configs for test speed
4
-
5
- config = edict()
6
- config.loss = "arcface"
7
- config.network = "r50"
8
- config.resume = False
9
- config.output = None
10
- config.embedding_size = 512
11
- config.sample_rate = 1.0
12
- config.fp16 = True
13
- config.momentum = 0.9
14
- config.weight_decay = 5e-4
15
- config.batch_size = 128
16
- config.lr = 0.1 # batch size is 512
17
-
18
- config.rec = "synthetic"
19
- config.num_classes = 100 * 10000
20
- config.num_epoch = 30
21
- config.warmup_epoch = -1
22
- config.decay_epoch = [10, 16, 22]
23
src/face3d/models/arcface_torch/dataset.py DELETED
@@ -1,124 +0,0 @@
1
- import numbers
2
- import os
3
- import queue as Queue
4
- import threading
5
-
6
- import mxnet as mx
7
- import numpy as np
8
- import torch
9
- from torch.utils.data import DataLoader, Dataset
10
- from torchvision import transforms
11
-
12
-
13
- class BackgroundGenerator(threading.Thread):
14
- def __init__(self, generator, local_rank, max_prefetch=6):
15
- super(BackgroundGenerator, self).__init__()
16
- self.queue = Queue.Queue(max_prefetch)
17
- self.generator = generator
18
- self.local_rank = local_rank
19
- self.daemon = True
20
- self.start()
21
-
22
- def run(self):
23
- torch.cuda.set_device(self.local_rank)
24
- for item in self.generator:
25
- self.queue.put(item)
26
- self.queue.put(None)
27
-
28
- def next(self):
29
- next_item = self.queue.get()
30
- if next_item is None:
31
- raise StopIteration
32
- return next_item
33
-
34
- def __next__(self):
35
- return self.next()
36
-
37
- def __iter__(self):
38
- return self
39
-
40
-
41
- class DataLoaderX(DataLoader):
42
-
43
- def __init__(self, local_rank, **kwargs):
44
- super(DataLoaderX, self).__init__(**kwargs)
45
- self.stream = torch.cuda.Stream(local_rank)
46
- self.local_rank = local_rank
47
-
48
- def __iter__(self):
49
- self.iter = super(DataLoaderX, self).__iter__()
50
- self.iter = BackgroundGenerator(self.iter, self.local_rank)
51
- self.preload()
52
- return self
53
-
54
- def preload(self):
55
- self.batch = next(self.iter, None)
56
- if self.batch is None:
57
- return None
58
- with torch.cuda.stream(self.stream):
59
- for k in range(len(self.batch)):
60
- self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
61
-
62
- def __next__(self):
63
- torch.cuda.current_stream().wait_stream(self.stream)
64
- batch = self.batch
65
- if batch is None:
66
- raise StopIteration
67
- self.preload()
68
- return batch
69
-
70
-
71
- class MXFaceDataset(Dataset):
72
- def __init__(self, root_dir, local_rank):
73
- super(MXFaceDataset, self).__init__()
74
- self.transform = transforms.Compose(
75
- [transforms.ToPILImage(),
76
- transforms.RandomHorizontalFlip(),
77
- transforms.ToTensor(),
78
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
79
- ])
80
- self.root_dir = root_dir
81
- self.local_rank = local_rank
82
- path_imgrec = os.path.join(root_dir, 'train.rec')
83
- path_imgidx = os.path.join(root_dir, 'train.idx')
84
- self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
85
- s = self.imgrec.read_idx(0)
86
- header, _ = mx.recordio.unpack(s)
87
- if header.flag > 0:
88
- self.header0 = (int(header.label[0]), int(header.label[1]))
89
- self.imgidx = np.array(range(1, int(header.label[0])))
90
- else:
91
- self.imgidx = np.array(list(self.imgrec.keys))
92
-
93
- def __getitem__(self, index):
94
- idx = self.imgidx[index]
95
- s = self.imgrec.read_idx(idx)
96
- header, img = mx.recordio.unpack(s)
97
- label = header.label
98
- if not isinstance(label, numbers.Number):
99
- label = label[0]
100
- label = torch.tensor(label, dtype=torch.long)
101
- sample = mx.image.imdecode(img).asnumpy()
102
- if self.transform is not None:
103
- sample = self.transform(sample)
104
- return sample, label
105
-
106
- def __len__(self):
107
- return len(self.imgidx)
108
-
109
-
110
- class SyntheticDataset(Dataset):
111
- def __init__(self, local_rank):
112
- super(SyntheticDataset, self).__init__()
113
- img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
114
- img = np.transpose(img, (2, 0, 1))
115
- img = torch.from_numpy(img).squeeze(0).float()
116
- img = ((img / 255) - 0.5) / 0.5
117
- self.img = img
118
- self.label = 1
119
-
120
- def __getitem__(self, index):
121
- return self.img, self.label
122
-
123
- def __len__(self):
124
- return 1000000
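`DataLoaderX` layers two overlapping tricks on top of a plain `DataLoader`: `BackgroundGenerator` keeps up to six batches queued on a daemon thread, and `preload` copies the next batch to the GPU on a side CUDA stream while the current batch is being consumed (`wait_stream` in `__next__` is the synchronization point). A hypothetical single-GPU usage, assuming mxnet and the `train.rec`/`train.idx` pair exist:

```python
# hypothetical usage on one GPU (local_rank 0)
dataset = MXFaceDataset(root_dir='/train_tmp/ms1m-retinaface-t1', local_rank=0)
loader = DataLoaderX(local_rank=0, dataset=dataset, batch_size=128,
                     num_workers=2, pin_memory=True, drop_last=True)
for img, label in loader:   # batches arrive already on cuda:0, copied on the side stream
    pass
```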
src/face3d/models/arcface_torch/docs/eval.md DELETED
@@ -1,31 +0,0 @@
1
- ## Eval on ICCV2021-MFR
2
-
3
- coming soon.
4
-
5
-
6
- ## Eval IJBC
7
- You can evaluate IJB-C with either PyTorch or ONNX.
8
-
9
-
10
- 1. Eval IJBC With Onnx
11
- ```shell
12
- CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
13
- ```
14
-
15
- 2. Eval IJBC With Pytorch
16
- ```shell
17
- CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
18
- --model-prefix ms1mv3_arcface_r50/backbone.pth \
19
- --image-path IJB_release/IJBC \
20
- --result-dir ms1mv3_arcface_r50 \
21
- --batch-size 128 \
22
- --job ms1mv3_arcface_r50 \
23
- --target IJBC \
24
- --network iresnet50
25
- ```
26
-
27
- ## Inference
28
-
29
- ```shell
30
- python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
31
- ```
src/face3d/models/arcface_torch/docs/install.md DELETED
@@ -1,51 +0,0 @@
- ## v1.8.0
- ### Linux and Windows
- ```shell
- # CUDA 11.0
- pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
-
- # CUDA 10.2
- pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
-
- # CPU only
- pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
-
- ```
-
-
- ## v1.7.1
- ### Linux and Windows
- ```shell
- # CUDA 11.0
- pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
-
- # CUDA 10.2
- pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
-
- # CUDA 10.1
- pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
-
- # CUDA 9.2
- pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
-
- # CPU only
- pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
- ```
-
-
- ## v1.6.0
-
- ### Linux and Windows
- ```shell
- # CUDA 10.2
- pip install torch==1.6.0 torchvision==0.7.0
-
- # CUDA 10.1
- pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
-
- # CUDA 9.2
- pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
-
- # CPU only
- pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
- ```
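A quick way to confirm that the wheel you picked above actually matches your CUDA setup is a short check with the plain PyTorch API (nothing repo-specific):

```python
import torch

print(torch.__version__)          # e.g. "1.8.0+cu111" for the CUDA 11 wheel
print(torch.cuda.is_available())  # False means you got a CPU-only build
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```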
src/face3d/models/arcface_torch/docs/modelzoo.md DELETED
File without changes
src/face3d/models/arcface_torch/docs/speed_benchmark.md DELETED
@@ -1,93 +0,0 @@
- ## Test Training Speed
-
- - Test Commands
-
- You need to use the following two commands to test Partial FC training performance.
- The number of identities is **3 million** (synthetic data), mixed-precision training is enabled, the backbone is resnet50,
- and the batch size is 1024.
- ```shell
- # Model Parallel
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
- # Partial FC 0.1
- python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
- ```
-
- - GPU Memory
-
- ```
- # (Model Parallel) gpustat -i
- [0] Tesla V100-SXM2-32GB | 64'C,  94 % | 30338 / 32510 MB
- [1] Tesla V100-SXM2-32GB | 60'C,  99 % | 28876 / 32510 MB
- [2] Tesla V100-SXM2-32GB | 60'C,  99 % | 28872 / 32510 MB
- [3] Tesla V100-SXM2-32GB | 69'C,  99 % | 28872 / 32510 MB
- [4] Tesla V100-SXM2-32GB | 66'C,  99 % | 28888 / 32510 MB
- [5] Tesla V100-SXM2-32GB | 60'C,  99 % | 28932 / 32510 MB
- [6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
- [7] Tesla V100-SXM2-32GB | 65'C,  99 % | 28860 / 32510 MB
-
- # (Partial FC 0.1) gpustat -i
- [0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB
- [1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB
- [2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB
- [3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB
- [4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB
- [5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB
- [6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB
- [7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB
- ```
-
- - Training Speed
-
- ```
- # (Model Parallel) training.log
- Training: Speed 2271.33 samples/sec   Loss 1.1624   LearningRate 0.2000   Epoch: 0   Global Step: 100
- Training: Speed 2269.94 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 150
- Training: Speed 2272.67 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 200
- Training: Speed 2266.55 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 250
- Training: Speed 2272.54 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 300
-
- # (Partial FC 0.1) training.log
- Training: Speed 5299.56 samples/sec   Loss 1.0965   LearningRate 0.2000   Epoch: 0   Global Step: 100
- Training: Speed 5296.37 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 150
- Training: Speed 5304.37 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 200
- Training: Speed 5274.43 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 250
- Training: Speed 5300.10 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 300
- ```
-
- In this test case, Partial FC 0.1 uses only about 1/3 of the GPU memory of model parallel,
- and trains about 2.3 times faster (about 5300 vs 2270 samples/sec in the logs above).
-
-
- ## Speed Benchmark
-
- 1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
-
- | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
- | :--- | :--- | :--- | :--- |
- |125000 | 4681 | 4824 | 5004 |
- |250000 | 4047 | 4521 | 4976 |
- |500000 | 3087 | 4013 | 4900 |
- |1000000 | 2090 | 3449 | 4803 |
- |1400000 | 1672 | 3043 | 4738 |
- |2000000 | - | 2593 | 4626 |
- |4000000 | - | 1748 | 4208 |
- |5500000 | - | 1389 | 3975 |
- |8000000 | - | - | 3565 |
- |16000000 | - | - | 2679 |
- |29000000 | - | - | 1855 |
-
- 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
-
- | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
- | :--- | :--- | :--- | :--- |
- |125000 | 7358 | 5306 | 4868 |
- |250000 | 9940 | 5826 | 5004 |
- |500000 | 14220 | 7114 | 5202 |
- |1000000 | 23708 | 9966 | 5620 |
- |1400000 | 32252 | 11178 | 6056 |
- |2000000 | - | 13978 | 6472 |
- |4000000 | - | 23238 | 8284 |
- |5500000 | - | 32188 | 9854 |
- |8000000 | - | - | 12310 |
- |16000000 | - | - | 19950 |
- |29000000 | - | - | 32324 |
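The trade-off in the two tables can be read off directly; a small sketch using the 1,000,000-identity row (all numbers copied from the tables above):

```python
# samples/sec and MB per GPU at 1M identities, Tesla V100 32GB x 8
speed = {"data parallel": 2090, "model parallel": 3449, "partial fc 0.1": 4803}
mem = {"data parallel": 23708, "model parallel": 9966, "partial fc 0.1": 5620}

for k in speed:
    print(f"{k}: {speed[k] / speed['data parallel']:.2f}x throughput, "
          f"{mem[k] / mem['data parallel']:.2f}x memory vs data parallel")
# partial fc 0.1: 2.30x the throughput at 0.24x the memory of data parallel
```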
src/face3d/models/arcface_torch/eval/__init__.py DELETED
File without changes
src/face3d/models/arcface_torch/eval/verification.py DELETED
@@ -1,407 +0,0 @@
- """Helper for evaluation on the Labeled Faces in the Wild dataset
- """
-
- # MIT License
- #
- # Copyright (c) 2016 David Sandberg
- #
- # Permission is hereby granted, free of charge, to any person obtaining a copy
- # of this software and associated documentation files (the "Software"), to deal
- # in the Software without restriction, including without limitation the rights
- # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- # copies of the Software, and to permit persons to whom the Software is
- # furnished to do so, subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included in all
- # copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- # SOFTWARE.
-
-
- import datetime
- import os
- import pickle
-
- import mxnet as mx
- import numpy as np
- import sklearn
- import torch
- from mxnet import ndarray as nd
- from scipy import interpolate
- from sklearn.decomposition import PCA
- from sklearn.model_selection import KFold
-
-
- class LFold:
-     def __init__(self, n_splits=2, shuffle=False):
-         self.n_splits = n_splits
-         if self.n_splits > 1:
-             self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
-
-     def split(self, indices):
-         if self.n_splits > 1:
-             return self.k_fold.split(indices)
-         else:
-             return [(indices, indices)]
-
-
- def calculate_roc(thresholds,
-                   embeddings1,
-                   embeddings2,
-                   actual_issame,
-                   nrof_folds=10,
-                   pca=0):
-     assert (embeddings1.shape[0] == embeddings2.shape[0])
-     assert (embeddings1.shape[1] == embeddings2.shape[1])
-     nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
-     nrof_thresholds = len(thresholds)
-     k_fold = LFold(n_splits=nrof_folds, shuffle=False)
-
-     tprs = np.zeros((nrof_folds, nrof_thresholds))
-     fprs = np.zeros((nrof_folds, nrof_thresholds))
-     accuracy = np.zeros((nrof_folds))
-     indices = np.arange(nrof_pairs)
-
-     if pca == 0:
-         diff = np.subtract(embeddings1, embeddings2)
-         dist = np.sum(np.square(diff), 1)
-
-     for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
-         if pca > 0:
-             print('doing pca on', fold_idx)
-             embed1_train = embeddings1[train_set]
-             embed2_train = embeddings2[train_set]
-             _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
-             pca_model = PCA(n_components=pca)
-             pca_model.fit(_embed_train)
-             embed1 = pca_model.transform(embeddings1)
-             embed2 = pca_model.transform(embeddings2)
-             embed1 = sklearn.preprocessing.normalize(embed1)
-             embed2 = sklearn.preprocessing.normalize(embed2)
-             diff = np.subtract(embed1, embed2)
-             dist = np.sum(np.square(diff), 1)
-
-         # Find the best threshold for the fold
-         acc_train = np.zeros((nrof_thresholds))
-         for threshold_idx, threshold in enumerate(thresholds):
-             _, _, acc_train[threshold_idx] = calculate_accuracy(
-                 threshold, dist[train_set], actual_issame[train_set])
-         best_threshold_index = np.argmax(acc_train)
-         for threshold_idx, threshold in enumerate(thresholds):
-             tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
-                 threshold, dist[test_set],
-                 actual_issame[test_set])
-         _, _, accuracy[fold_idx] = calculate_accuracy(
-             thresholds[best_threshold_index], dist[test_set],
-             actual_issame[test_set])
-
-     tpr = np.mean(tprs, 0)
-     fpr = np.mean(fprs, 0)
-     return tpr, fpr, accuracy
-
-
- def calculate_accuracy(threshold, dist, actual_issame):
-     predict_issame = np.less(dist, threshold)
-     tp = np.sum(np.logical_and(predict_issame, actual_issame))
-     fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
-     tn = np.sum(
-         np.logical_and(np.logical_not(predict_issame),
-                        np.logical_not(actual_issame)))
-     fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
-
-     tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
-     fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
-     acc = float(tp + tn) / dist.size
-     return tpr, fpr, acc
-
-
- def calculate_val(thresholds,
-                   embeddings1,
-                   embeddings2,
-                   actual_issame,
-                   far_target,
-                   nrof_folds=10):
-     assert (embeddings1.shape[0] == embeddings2.shape[0])
-     assert (embeddings1.shape[1] == embeddings2.shape[1])
-     nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
-     nrof_thresholds = len(thresholds)
-     k_fold = LFold(n_splits=nrof_folds, shuffle=False)
-
-     val = np.zeros(nrof_folds)
-     far = np.zeros(nrof_folds)
-
-     diff = np.subtract(embeddings1, embeddings2)
-     dist = np.sum(np.square(diff), 1)
-     indices = np.arange(nrof_pairs)
-
-     for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
-
-         # Find the threshold that gives FAR = far_target
-         far_train = np.zeros(nrof_thresholds)
-         for threshold_idx, threshold in enumerate(thresholds):
-             _, far_train[threshold_idx] = calculate_val_far(
-                 threshold, dist[train_set], actual_issame[train_set])
-         if np.max(far_train) >= far_target:
-             f = interpolate.interp1d(far_train, thresholds, kind='slinear')
-             threshold = f(far_target)
-         else:
-             threshold = 0.0
-
-         val[fold_idx], far[fold_idx] = calculate_val_far(
-             threshold, dist[test_set], actual_issame[test_set])
-
-     val_mean = np.mean(val)
-     far_mean = np.mean(far)
-     val_std = np.std(val)
-     return val_mean, val_std, far_mean
-
-
- def calculate_val_far(threshold, dist, actual_issame):
-     predict_issame = np.less(dist, threshold)
-     true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
-     false_accept = np.sum(
-         np.logical_and(predict_issame, np.logical_not(actual_issame)))
-     n_same = np.sum(actual_issame)
-     n_diff = np.sum(np.logical_not(actual_issame))
-     # print(true_accept, false_accept)
-     # print(n_same, n_diff)
-     val = float(true_accept) / float(n_same)
-     far = float(false_accept) / float(n_diff)
-     return val, far
-
-
- def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
-     # Calculate evaluation metrics
-     thresholds = np.arange(0, 4, 0.01)
-     embeddings1 = embeddings[0::2]
-     embeddings2 = embeddings[1::2]
-     tpr, fpr, accuracy = calculate_roc(thresholds,
-                                        embeddings1,
-                                        embeddings2,
-                                        np.asarray(actual_issame),
-                                        nrof_folds=nrof_folds,
-                                        pca=pca)
-     thresholds = np.arange(0, 4, 0.001)
-     val, val_std, far = calculate_val(thresholds,
-                                       embeddings1,
-                                       embeddings2,
-                                       np.asarray(actual_issame),
-                                       1e-3,
-                                       nrof_folds=nrof_folds)
-     return tpr, fpr, accuracy, val, val_std, far
-
- @torch.no_grad()
- def load_bin(path, image_size):
-     try:
-         with open(path, 'rb') as f:
-             bins, issame_list = pickle.load(f)  # py2
-     except UnicodeDecodeError as e:
-         with open(path, 'rb') as f:
-             bins, issame_list = pickle.load(f, encoding='bytes')  # py3
-     data_list = []
-     for flip in [0, 1]:
-         data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
-         data_list.append(data)
-     for idx in range(len(issame_list) * 2):
-         _bin = bins[idx]
-         img = mx.image.imdecode(_bin)
-         if img.shape[1] != image_size[0]:
-             img = mx.image.resize_short(img, image_size[0])
-         img = nd.transpose(img, axes=(2, 0, 1))
-         for flip in [0, 1]:
-             if flip == 1:
-                 img = mx.ndarray.flip(data=img, axis=2)
-             data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
-         if idx % 1000 == 0:
-             print('loading bin', idx)
-     print(data_list[0].shape)
-     return data_list, issame_list
-
- @torch.no_grad()
- def test(data_set, backbone, batch_size, nfolds=10):
-     print('testing verification..')
-     data_list = data_set[0]
-     issame_list = data_set[1]
-     embeddings_list = []
-     time_consumed = 0.0
-     for i in range(len(data_list)):
-         data = data_list[i]
-         embeddings = None
-         ba = 0
-         while ba < data.shape[0]:
-             bb = min(ba + batch_size, data.shape[0])
-             count = bb - ba
-             _data = data[bb - batch_size: bb]
-             time0 = datetime.datetime.now()
-             img = ((_data / 255) - 0.5) / 0.5
-             net_out: torch.Tensor = backbone(img)
-             _embeddings = net_out.detach().cpu().numpy()
-             time_now = datetime.datetime.now()
-             diff = time_now - time0
-             time_consumed += diff.total_seconds()
-             if embeddings is None:
-                 embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
-             embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
-             ba = bb
-         embeddings_list.append(embeddings)
-
-     _xnorm = 0.0
-     _xnorm_cnt = 0
-     for embed in embeddings_list:
-         for i in range(embed.shape[0]):
-             _em = embed[i]
-             _norm = np.linalg.norm(_em)
-             _xnorm += _norm
-             _xnorm_cnt += 1
-     _xnorm /= _xnorm_cnt
-
-     acc1 = 0.0
-     std1 = 0.0
-     embeddings = embeddings_list[0] + embeddings_list[1]
-     embeddings = sklearn.preprocessing.normalize(embeddings)
-     print(embeddings.shape)
-     print('infer time', time_consumed)
-     _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
-     acc2, std2 = np.mean(accuracy), np.std(accuracy)
-     return acc1, std1, acc2, std2, _xnorm, embeddings_list
-
-
- def dumpR(data_set,
-           backbone,
-           batch_size,
-           name='',
-           data_extra=None,
-           label_shape=None):
-     # NOTE: _label, _data_extra and model below are never defined in this
-     # file; this helper is unused legacy MXNet code and raises NameError
-     # if called.
-     print('dump verification embedding..')
-     data_list = data_set[0]
-     issame_list = data_set[1]
-     embeddings_list = []
-     time_consumed = 0.0
-     for i in range(len(data_list)):
-         data = data_list[i]
-         embeddings = None
-         ba = 0
-         while ba < data.shape[0]:
-             bb = min(ba + batch_size, data.shape[0])
-             count = bb - ba
-
-             _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
-             time0 = datetime.datetime.now()
-             if data_extra is None:
-                 db = mx.io.DataBatch(data=(_data,), label=(_label,))
-             else:
-                 db = mx.io.DataBatch(data=(_data, _data_extra),
-                                      label=(_label,))
-             model.forward(db, is_train=False)
-             net_out = model.get_outputs()
-             _embeddings = net_out[0].asnumpy()
-             time_now = datetime.datetime.now()
-             diff = time_now - time0
-             time_consumed += diff.total_seconds()
-             if embeddings is None:
-                 embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
-             embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
-             ba = bb
-         embeddings_list.append(embeddings)
-     embeddings = embeddings_list[0] + embeddings_list[1]
-     embeddings = sklearn.preprocessing.normalize(embeddings)
-     actual_issame = np.asarray(issame_list)
-     outname = os.path.join('temp.bin')
-     with open(outname, 'wb') as f:
-         pickle.dump((embeddings, issame_list),
-                     f,
-                     protocol=pickle.HIGHEST_PROTOCOL)
-
-
- # if __name__ == '__main__':
- #
- #     parser = argparse.ArgumentParser(description='do verification')
- #     # general
- #     parser.add_argument('--data-dir', default='', help='')
- #     parser.add_argument('--model',
- #                         default='../model/softmax,50',
- #                         help='path to load model.')
- #     parser.add_argument('--target',
- #                         default='lfw,cfp_ff,cfp_fp,agedb_30',
- #                         help='test targets.')
- #     parser.add_argument('--gpu', default=0, type=int, help='gpu id')
- #     parser.add_argument('--batch-size', default=32, type=int, help='')
- #     parser.add_argument('--max', default='', type=str, help='')
- #     parser.add_argument('--mode', default=0, type=int, help='')
- #     parser.add_argument('--nfolds', default=10, type=int, help='')
- #     args = parser.parse_args()
- #     image_size = [112, 112]
- #     print('image_size', image_size)
- #     ctx = mx.gpu(args.gpu)
- #     nets = []
- #     vec = args.model.split(',')
- #     prefix = args.model.split(',')[0]
- #     epochs = []
- #     if len(vec) == 1:
- #         pdir = os.path.dirname(prefix)
- #         for fname in os.listdir(pdir):
- #             if not fname.endswith('.params'):
- #                 continue
- #             _file = os.path.join(pdir, fname)
- #             if _file.startswith(prefix):
- #                 epoch = int(fname.split('.')[0].split('-')[1])
- #                 epochs.append(epoch)
- #         epochs = sorted(epochs, reverse=True)
- #         if len(args.max) > 0:
- #             _max = [int(x) for x in args.max.split(',')]
- #             assert len(_max) == 2
- #             if len(epochs) > _max[1]:
- #                 epochs = epochs[_max[0]:_max[1]]
- #
- #     else:
- #         epochs = [int(x) for x in vec[1].split('|')]
- #     print('model number', len(epochs))
- #     time0 = datetime.datetime.now()
- #     for epoch in epochs:
- #         print('loading', prefix, epoch)
- #         sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
- #         # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
- #         all_layers = sym.get_internals()
- #         sym = all_layers['fc1_output']
- #         model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
- #         # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
- #         model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
- #                                           image_size[1]))])
- #         model.set_params(arg_params, aux_params)
- #         nets.append(model)
- #     time_now = datetime.datetime.now()
- #     diff = time_now - time0
- #     print('model loading time', diff.total_seconds())
- #
- #     ver_list = []
- #     ver_name_list = []
- #     for name in args.target.split(','):
- #         path = os.path.join(args.data_dir, name + ".bin")
- #         if os.path.exists(path):
- #             print('loading.. ', name)
- #             data_set = load_bin(path, image_size)
- #             ver_list.append(data_set)
- #             ver_name_list.append(name)
- #
- #     if args.mode == 0:
- #         for i in range(len(ver_list)):
- #             results = []
- #             for model in nets:
- #                 acc1, std1, acc2, std2, xnorm, embeddings_list = test(
- #                     ver_list[i], model, args.batch_size, args.nfolds)
- #                 print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
- #                 print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
- #                 print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
- #                 results.append(acc2)
- #             print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
- #     elif args.mode == 1:
- #         raise ValueError
- #     else:
- #         model = nets[0]
- #         dumpR(ver_list[0], model, args.batch_size, args.target)
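A minimal smoke test of the deleted `evaluate()` helper; this assumes the file is importable as `verification` and that its mxnet/scipy/scikit-learn dependencies are installed. The labels and embeddings are random, so the cross-validated accuracy should land near chance:

```python
import numpy as np
from verification import evaluate  # the module deleted above

rng = np.random.default_rng(0)
n_pairs = 60
# evaluate() expects interleaved pairs: embeddings[0::2] vs embeddings[1::2].
embeddings = rng.normal(size=(2 * n_pairs, 512))
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
issame = rng.integers(0, 2, size=n_pairs).astype(bool)

tpr, fpr, accuracy, val, val_std, far = evaluate(embeddings, issame, nrof_folds=5)
print(accuracy.mean())  # roughly 0.5 on random data
```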
src/face3d/models/arcface_torch/eval_ijbc.py DELETED
@@ -1,483 +0,0 @@
- # coding: utf-8
-
- import os
- import pickle
-
- import matplotlib
- import pandas as pd
-
- matplotlib.use('Agg')
- import matplotlib.pyplot as plt
- import timeit
- import sklearn
- import argparse
- import cv2
- import numpy as np
- import torch
- from skimage import transform as trans
- from backbones import get_model
- from sklearn.metrics import roc_curve, auc
-
- from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
- from prettytable import PrettyTable
- from pathlib import Path
-
- import sys
- import warnings
-
- sys.path.insert(0, "../")
- warnings.filterwarnings("ignore")
-
- parser = argparse.ArgumentParser(description='do ijb test')
- # general
- parser.add_argument('--model-prefix', default='', help='path to load model.')
- parser.add_argument('--image-path', default='', type=str, help='')
- parser.add_argument('--result-dir', default='.', type=str, help='')
- parser.add_argument('--batch-size', default=128, type=int, help='')
- parser.add_argument('--network', default='iresnet50', type=str, help='')
- parser.add_argument('--job', default='insightface', type=str, help='job name')
- parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
- args = parser.parse_args()
-
- target = args.target
- model_path = args.model_prefix
- image_path = args.image_path
- result_dir = args.result_dir
- gpu_id = None
- use_norm_score = True  # if True, TestMode(N1)
- use_detector_score = True  # if True, TestMode(D1)
- use_flip_test = True  # if True, TestMode(F1)
- job = args.job
- batch_size = args.batch_size
-
-
- class Embedding(object):
-     def __init__(self, prefix, data_shape, batch_size=1):
-         image_size = (112, 112)
-         self.image_size = image_size
-         weight = torch.load(prefix)
-         resnet = get_model(args.network, dropout=0, fp16=False).cuda()
-         resnet.load_state_dict(weight)
-         model = torch.nn.DataParallel(resnet)
-         self.model = model
-         self.model.eval()
-         src = np.array([
-             [30.2946, 51.6963],
-             [65.5318, 51.5014],
-             [48.0252, 71.7366],
-             [33.5493, 92.3655],
-             [62.7299, 92.2041]], dtype=np.float32)
-         src[:, 0] += 8.0
-         self.src = src
-         self.batch_size = batch_size
-         self.data_shape = data_shape
-
-     def get(self, rimg, landmark):
-
-         assert landmark.shape[0] == 68 or landmark.shape[0] == 5
-         assert landmark.shape[1] == 2
-         if landmark.shape[0] == 68:
-             landmark5 = np.zeros((5, 2), dtype=np.float32)
-             landmark5[0] = (landmark[36] + landmark[39]) / 2
-             landmark5[1] = (landmark[42] + landmark[45]) / 2
-             landmark5[2] = landmark[30]
-             landmark5[3] = landmark[48]
-             landmark5[4] = landmark[54]
-         else:
-             landmark5 = landmark
-         tform = trans.SimilarityTransform()
-         tform.estimate(landmark5, self.src)
-         M = tform.params[0:2, :]
-         img = cv2.warpAffine(rimg,
-                              M, (self.image_size[1], self.image_size[0]),
-                              borderValue=0.0)
-         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-         img_flip = np.fliplr(img)
-         img = np.transpose(img, (2, 0, 1))  # 3*112*112, RGB
-         img_flip = np.transpose(img_flip, (2, 0, 1))
-         input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
-         input_blob[0] = img
-         input_blob[1] = img_flip
-         return input_blob
-
-     @torch.no_grad()
-     def forward_db(self, batch_data):
-         imgs = torch.Tensor(batch_data).cuda()
-         imgs.div_(255).sub_(0.5).div_(0.5)
-         feat = self.model(imgs)
-         feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
-         return feat.cpu().numpy()
-
-
- # Split a list into n roughly equal parts; len(result) == n, and if n is
- # larger than the number of elements, the extra parts are empty lists [].
- def divideIntoNstrand(listTemp, n):
-     twoList = [[] for i in range(n)]
-     for i, e in enumerate(listTemp):
-         twoList[i % n].append(e)
-     return twoList
-
-
- def read_template_media_list(path):
-     # ijb_meta = np.loadtxt(path, dtype=str)
-     ijb_meta = pd.read_csv(path, sep=' ', header=None).values
-     templates = ijb_meta[:, 1].astype(np.int)
-     medias = ijb_meta[:, 2].astype(np.int)
-     return templates, medias
-
-
- # In[ ]:
-
-
- def read_template_pair_list(path):
-     # pairs = np.loadtxt(path, dtype=str)
-     pairs = pd.read_csv(path, sep=' ', header=None).values
-     # print(pairs.shape)
-     # print(pairs[:, 0].astype(np.int))
-     t1 = pairs[:, 0].astype(np.int)
-     t2 = pairs[:, 1].astype(np.int)
-     label = pairs[:, 2].astype(np.int)
-     return t1, t2, label
-
-
- # In[ ]:
-
-
- def read_image_feature(path):
-     with open(path, 'rb') as fid:
-         img_feats = pickle.load(fid)
-     return img_feats
-
-
- # In[ ]:
-
-
- def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
-     batch_size = args.batch_size
-     data_shape = (3, 112, 112)
-
-     files = files_list
-     print('files:', len(files))
-     rare_size = len(files) % batch_size
-     faceness_scores = []
-     batch = 0
-     img_feats = np.empty((len(files), 1024), dtype=np.float32)
-
-     batch_data = np.empty((2 * batch_size, 3, 112, 112))
-     embedding = Embedding(model_path, data_shape, batch_size)
-     for img_index, each_line in enumerate(files[:len(files) - rare_size]):
-         name_lmk_score = each_line.strip().split(' ')
-         img_name = os.path.join(img_path, name_lmk_score[0])
-         img = cv2.imread(img_name)
-         lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
-                        dtype=np.float32)
-         lmk = lmk.reshape((5, 2))
-         input_blob = embedding.get(img, lmk)
-
-         batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
-         batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
-         if (img_index + 1) % batch_size == 0:
-             print('batch', batch)
-             img_feats[batch * batch_size:batch * batch_size +
-                       batch_size][:] = embedding.forward_db(batch_data)
-             batch += 1
-         faceness_scores.append(name_lmk_score[-1])
-
-     batch_data = np.empty((2 * rare_size, 3, 112, 112))
-     embedding = Embedding(model_path, data_shape, rare_size)
-     for img_index, each_line in enumerate(files[len(files) - rare_size:]):
-         name_lmk_score = each_line.strip().split(' ')
-         img_name = os.path.join(img_path, name_lmk_score[0])
-         img = cv2.imread(img_name)
-         lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
-                        dtype=np.float32)
-         lmk = lmk.reshape((5, 2))
-         input_blob = embedding.get(img, lmk)
-         batch_data[2 * img_index][:] = input_blob[0]
-         batch_data[2 * img_index + 1][:] = input_blob[1]
-         if (img_index + 1) % rare_size == 0:
-             print('batch', batch)
-             img_feats[len(files) -
-                       rare_size:][:] = embedding.forward_db(batch_data)
-             batch += 1
-         faceness_scores.append(name_lmk_score[-1])
-     faceness_scores = np.array(faceness_scores).astype(np.float32)
-     # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
-     # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
-     return img_feats, faceness_scores
-
-
- # In[ ]:
-
-
- def image2template_feature(img_feats=None, templates=None, medias=None):
-     # ==========================================================
-     # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
-     # 2. compute media feature.
-     # 3. compute template feature.
-     # ==========================================================
-     unique_templates = np.unique(templates)
-     template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
-
-     for count_template, uqt in enumerate(unique_templates):
-
-         (ind_t,) = np.where(templates == uqt)
-         face_norm_feats = img_feats[ind_t]
-         face_medias = medias[ind_t]
-         unique_medias, unique_media_counts = np.unique(face_medias,
-                                                        return_counts=True)
-         media_norm_feats = []
-         for u, ct in zip(unique_medias, unique_media_counts):
-             (ind_m,) = np.where(face_medias == u)
-             if ct == 1:
-                 media_norm_feats += [face_norm_feats[ind_m]]
-             else:  # image features from the same video will be aggregated into one feature
-                 media_norm_feats += [
-                     np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
-                 ]
-         media_norm_feats = np.array(media_norm_feats)
-         # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
-         template_feats[count_template] = np.sum(media_norm_feats, axis=0)
-         if count_template % 2000 == 0:
-             print('Finish Calculating {} template features.'.format(
-                 count_template))
-     # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
-     template_norm_feats = sklearn.preprocessing.normalize(template_feats)
-     # print(template_norm_feats.shape)
-     return template_norm_feats, unique_templates
-
-
- # In[ ]:
-
-
- def verification(template_norm_feats=None,
-                  unique_templates=None,
-                  p1=None,
-                  p2=None):
-     # ==========================================================
-     # Compute set-to-set Similarity Score.
-     # ==========================================================
-     template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
-     for count_template, uqt in enumerate(unique_templates):
-         template2id[uqt] = count_template
-
-     score = np.zeros((len(p1),))  # save cosine distance between pairs
-
-     total_pairs = np.array(range(len(p1)))
-     batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
-     sublists = [
-         total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
-     ]
-     total_sublists = len(sublists)
-     for c, s in enumerate(sublists):
-         feat1 = template_norm_feats[template2id[p1[s]]]
-         feat2 = template_norm_feats[template2id[p2[s]]]
-         similarity_score = np.sum(feat1 * feat2, -1)
-         score[s] = similarity_score.flatten()
-         if c % 10 == 0:
-             print('Finish {}/{} pairs.'.format(c, total_sublists))
-     return score
-
-
- # In[ ]:
- def verification2(template_norm_feats=None,
-                   unique_templates=None,
-                   p1=None,
-                   p2=None):
-     template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
-     for count_template, uqt in enumerate(unique_templates):
-         template2id[uqt] = count_template
-     score = np.zeros((len(p1),))  # save cosine distance between pairs
-     total_pairs = np.array(range(len(p1)))
-     batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limitation
-     sublists = [
-         total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
-     ]
-     total_sublists = len(sublists)
-     for c, s in enumerate(sublists):
-         feat1 = template_norm_feats[template2id[p1[s]]]
-         feat2 = template_norm_feats[template2id[p2[s]]]
-         similarity_score = np.sum(feat1 * feat2, -1)
-         score[s] = similarity_score.flatten()
-         if c % 10 == 0:
-             print('Finish {}/{} pairs.'.format(c, total_sublists))
-     return score
-
-
- def read_score(path):
-     with open(path, 'rb') as fid:
-         img_feats = pickle.load(fid)
-     return img_feats
-
-
- # # Step 1: Load Meta Data
-
- # In[ ]:
-
- assert target == 'IJBC' or target == 'IJBB'
-
- # =============================================================
- # load image and template relationships for template feature embedding
- # tid --> template id,  mid --> media id
- # format:
- #     image_name tid mid
- # =============================================================
- start = timeit.default_timer()
- templates, medias = read_template_media_list(
-     os.path.join('%s/meta' % image_path,
-                  '%s_face_tid_mid.txt' % target.lower()))
- stop = timeit.default_timer()
- print('Time: %.2f s. ' % (stop - start))
-
- # In[ ]:
-
- # =============================================================
- # load template pairs for template-to-template verification
- # tid : template id,  label : 1/0
- # format:
- #     tid_1 tid_2 label
- # =============================================================
- start = timeit.default_timer()
- p1, p2, label = read_template_pair_list(
-     os.path.join('%s/meta' % image_path,
-                  '%s_template_pair_label.txt' % target.lower()))
- stop = timeit.default_timer()
- print('Time: %.2f s. ' % (stop - start))
-
- # # Step 2: Get Image Features
-
- # In[ ]:
-
- # =============================================================
- # load image features
- # format:
- #     img_feats: [image_num x feats_dim] (227630, 512)
- # =============================================================
- start = timeit.default_timer()
- img_path = '%s/loose_crop' % image_path
- img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
- img_list = open(img_list_path)
- files = img_list.readlines()
- # files_list = divideIntoNstrand(files, rank_size)
- files_list = files
-
- # img_feats
- # for i in range(rank_size):
- img_feats, faceness_scores = get_image_feature(img_path, files_list,
-                                                model_path, 0, gpu_id)
- stop = timeit.default_timer()
- print('Time: %.2f s. ' % (stop - start))
- print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
-                                           img_feats.shape[1]))
-
- # # Step 3: Get Template Features
-
- # In[ ]:
-
- # =============================================================
- # compute template features from image features.
- # =============================================================
- start = timeit.default_timer()
- # ==========================================================
- # Norm feature before aggregation into template feature?
- # Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
- # ==========================================================
- # 1. FaceScore (Feature Norm)
- # 2. FaceScore (Detector)
-
- if use_flip_test:
-     # concat --- F1
-     # img_input_feats = img_feats
-     # add --- F2
-     img_input_feats = img_feats[:, 0:img_feats.shape[1] //
-                                 2] + img_feats[:, img_feats.shape[1] // 2:]
- else:
-     img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
-
- if use_norm_score:
-     img_input_feats = img_input_feats
- else:
-     # normalise features to remove norm information
-     img_input_feats = img_input_feats / np.sqrt(
-         np.sum(img_input_feats ** 2, -1, keepdims=True))
-
- if use_detector_score:
-     print(img_input_feats.shape, faceness_scores.shape)
-     img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
- else:
-     img_input_feats = img_input_feats
-
- template_norm_feats, unique_templates = image2template_feature(
-     img_input_feats, templates, medias)
- stop = timeit.default_timer()
- print('Time: %.2f s. ' % (stop - start))
-
- # # Step 4: Get Template Similarity Scores
-
- # In[ ]:
-
- # =============================================================
- # compute verification scores between template pairs.
- # =============================================================
- start = timeit.default_timer()
- score = verification(template_norm_feats, unique_templates, p1, p2)
- stop = timeit.default_timer()
- print('Time: %.2f s. ' % (stop - start))
-
- # In[ ]:
- save_path = os.path.join(result_dir, args.job)
- # save_path = result_dir + '/%s_result' % target
-
- if not os.path.exists(save_path):
-     os.makedirs(save_path)
-
- score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
- np.save(score_save_file, score)
-
- # # Step 5: Get ROC Curves and TPR@FPR Table
-
- # In[ ]:
-
- files = [score_save_file]
- methods = []
- scores = []
- for file in files:
-     methods.append(Path(file).stem)
-     scores.append(np.load(file))
-
- methods = np.array(methods)
- scores = dict(zip(methods, scores))
- colours = dict(
-     zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
- x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
- tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
- fig = plt.figure()
- for method in methods:
-     fpr, tpr, _ = roc_curve(label, scores[method])
-     roc_auc = auc(fpr, tpr)
-     fpr = np.flipud(fpr)
-     tpr = np.flipud(tpr)  # select largest tpr at same fpr
-     plt.plot(fpr,
-              tpr,
-              color=colours[method],
-              lw=1,
-              label=('[%s (AUC = %0.4f %%)]' %
-                     (method.split('-')[-1], roc_auc * 100)))
-     tpr_fpr_row = []
-     tpr_fpr_row.append("%s-%s" % (method, target))
-     for fpr_iter in np.arange(len(x_labels)):
-         _, min_index = min(
-             list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
-         tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
-     tpr_fpr_table.add_row(tpr_fpr_row)
- plt.xlim([10 ** -6, 0.1])
- plt.ylim([0.3, 1.0])
- plt.grid(linestyle='--', linewidth=1)
- plt.xticks(x_labels)
- plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
- plt.xscale('log')
- plt.xlabel('False Positive Rate')
- plt.ylabel('True Positive Rate')
- plt.title('ROC on IJB')
- plt.legend(loc="lower right")
- fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))
- print(tpr_fpr_table)
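The heart of the script is `image2template_feature`: features are averaged per media, media features are summed per template, and the sum is L2-normalized. A toy numpy sketch of that aggregation for one hypothetical template holding four images across two media:

```python
import numpy as np

feats = np.arange(12, dtype=np.float64).reshape(4, 3)  # 4 images, 3-d features
medias = np.array([0, 0, 1, 1])                        # media id per image

# mean within each media, then sum across media, then L2-normalize
media_feats = [feats[medias == m].mean(axis=0) for m in np.unique(medias)]
template = np.sum(media_feats, axis=0)
template /= np.linalg.norm(template)
print(template)
```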
src/face3d/models/arcface_torch/inference.py DELETED
@@ -1,35 +0,0 @@
- import argparse
-
- import cv2
- import numpy as np
- import torch
-
- from backbones import get_model
-
-
- @torch.no_grad()
- def inference(weight, name, img):
-     if img is None:
-         img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
-     else:
-         img = cv2.imread(img)
-         img = cv2.resize(img, (112, 112))
-
-     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-     img = np.transpose(img, (2, 0, 1))
-     img = torch.from_numpy(img).unsqueeze(0).float()
-     img.div_(255).sub_(0.5).div_(0.5)
-     net = get_model(name, fp16=False)
-     net.load_state_dict(torch.load(weight))
-     net.eval()
-     feat = net(img).numpy()
-     print(feat)
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
-     parser.add_argument('--network', type=str, default='r50', help='backbone network')
-     parser.add_argument('--weight', type=str, default='')
-     parser.add_argument('--img', type=str, default=None)
-     args = parser.parse_args()
-     inference(args.weight, args.network, args.img)
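Typical use of the deleted script, either from the command line as documented in `docs/eval.md` above, or programmatically; the checkpoint path below is the one the docs use and is assumed to exist:

```python
from inference import inference  # the module deleted above

# Passing img=None embeds a random 112x112 image, which doubles as a quick
# smoke test that the checkpoint loads and the forward pass runs.
inference("ms1mv3_arcface_r50/backbone.pth", "r50", None)
```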