Jagrut Thakare commited on
Commit
ba58b0f
·
1 Parent(s): 30d0b05

v8- Trying to install torch in build time and downloading dependacies model

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -0
  2. app.py +145 -37
  3. requirements.txt +2 -7
Dockerfile CHANGED
@@ -28,6 +28,8 @@ WORKDIR /home/user/app
28
  COPY --chown=user:user . /home/user/app
29
  RUN pip install --user pydantic==2.8.2 gradio
30
 
 
 
31
  EXPOSE 7860
32
  CMD ["python", "app.py"]
33
 
 
28
  COPY --chown=user:user . /home/user/app
29
  RUN pip install --user pydantic==2.8.2 gradio
30
 
31
+ RUN pip install --user torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
32
+
33
  EXPOSE 7860
34
  CMD ["python", "app.py"]
35
 
app.py CHANGED
@@ -1,48 +1,156 @@
1
  import os, sys, subprocess
2
  import gradio as gr
 
 
 
 
 
 
 
3
 
4
- def setup_gpu_deps():
 
5
  pkg_dir = os.path.expanduser("~/.local/gpu_packages")
6
  os.makedirs(pkg_dir, exist_ok=True)
7
- subprocess.run([
8
- sys.executable, "-m", "pip", "install",
9
- "--upgrade", "--target", pkg_dir,
10
- "torch", "torchvision", "torchaudio",
11
- "--extra-index-url", "https://download.pytorch.org/whl/cu118", "mxnet-cu112", "onnxruntime-gpu==1.12", "Cython", "insightface==0.2.1", "kornia==0.5.4", "dill"
12
- ], check=True)
13
  sys.path.insert(0, pkg_dir)
14
 
15
- setup_gpu_deps()
 
 
 
 
16
 
17
- def infer_faceswap(input_source, input_target):
18
- return input_source
 
 
 
 
19
 
 
 
 
 
 
20
 
21
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- with gr.Column():
24
-
25
- with gr.Row():
26
- with gr.Column():
27
- with gr.Row(equal_height=True):
28
- input_source = gr.Image(
29
- type="pil",
30
- label="Input Source"
31
- )
32
- input_target = gr.Image(
33
- type="pil",
34
- label="Input Target"
35
- )
36
- run_button = gr.Button("Generate")
37
-
38
- with gr.Column():
39
- result = gr.Image(type='pil', label='Image Output')
40
-
41
- run_button.click(
42
- fn=infer_faceswap,
43
- inputs=[input_source, input_target],
44
- outputs=[result]
45
- )
46
-
47
-
48
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os, sys, subprocess
2
  import gradio as gr
3
+ import argparse
4
+ import cv2
5
+ import time
6
+ def setup_dependancies():
7
+ if not os.path.exists("./download_models.sh"):
8
+ print("Error: download_models.sh script not found.")
9
+ sys.exit(1)
10
 
11
+ subprocess.run(["./download_models.sh"])
12
+
13
  pkg_dir = os.path.expanduser("~/.local/gpu_packages")
14
  os.makedirs(pkg_dir, exist_ok=True)
15
+ # subprocess.run([
16
+ # sys.executable, "-m", "pip", "install",
17
+ # "--upgrade", "--target", pkg_dir,
18
+ # "torch", "torchvision", "torchaudio",
19
+ # "--extra-index-url", "https://download.pytorch.org/whl/cu118", "mxnet-cu112", "onnxruntime-gpu==1.12", "Cython", "insightface==0.2.1", "kornia==0.5.4", "dill", "numpy"
20
+ # ], check=True)
21
  sys.path.insert(0, pkg_dir)
22
 
23
+ def init_models(args):
24
+
25
+ # model for face cropping
26
+ app = Face_detect_crop(name='antelope', root='./insightface_func/models')
27
+ app.prepare(ctx_id= 0, det_thresh=0.6, det_size=(640,640))
28
 
29
+ # main model for generation
30
+ G = AEI_Net(args.backbone, num_blocks=args.num_blocks, c_id=512)
31
+ G.eval()
32
+ G.load_state_dict(torch.load(args.G_path, map_location=torch.device('cpu')))
33
+ G = G.cuda()
34
+ G = G.half()
35
 
36
+ # arcface model to get face embedding
37
+ netArc = iresnet100(fp16=False)
38
+ netArc.load_state_dict(torch.load('arcface_model/backbone.pth'))
39
+ netArc=netArc.cuda()
40
+ netArc.eval()
41
 
42
+ # model to get face landmarks
43
+ handler = Handler('./coordinate_reg/model/2d106det', 0, ctx_id=0, det_size=640)
44
+
45
+ # model to make superres of face, set use_sr=True if you want to use super resolution or use_sr=False if you don't
46
+ if args.use_sr:
47
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
48
+ torch.backends.cudnn.benchmark = True
49
+ opt = TestOptions()
50
+ #opt.which_epoch ='10_7'
51
+ model = Pix2PixModel(opt)
52
+ model.netG.train()
53
+ else:
54
+ model = None
55
+
56
+ return app, G, netArc, handler, model
57
+
58
+
59
+ def infer_faceswap(src, tgt):
60
+ app, G, netArc, handler, model = init_models(args)
61
+
62
+ # get crops from source images
63
+ print('List of source paths: ',args.source_paths)
64
+ source = []
65
+ img = cv2.imread(src)
66
+ img = crop_face(img, app, args.crop_size)[0]
67
+ source.append(img[:, :, ::-1])
68
+
69
+
70
+ target = []
71
+ img = cv2.imread(tgt)
72
+ img = crop_face(img, app, args.crop_size)[0]
73
+ target.append(img)
74
 
75
+ start = time.time()
76
+ final_frames_list, crop_frames_list, full_frames, tfm_array_list = model_inference(full_frames,
77
+ source,
78
+ target,
79
+ netArc,
80
+ G,
81
+ app,
82
+ True,
83
+ similarity_th=args.similarity_th,
84
+ crop_size=args.crop_size,
85
+ BS=args.batch_size)
86
+ result = get_final_image(final_frames_list, crop_frames_list, full_frames[0], tfm_array_list, handler)
87
+ cv2.imwrite(args.out_image_name, result)
88
+ print(f'Swapped Image saved with path {args.out_image_name}')
89
+
90
+ print('Total time: ', time.time()-start)
91
+
92
+
93
+
94
+
95
+ return result
96
+
97
+
98
+ if __name__ == "__main__":
99
+
100
+ setup_dependancies()
101
+
102
+ parser = argparse.ArgumentParser()
103
+
104
+ # Generator params
105
+ parser.add_argument('--G_path', default='weights/G_unet_2blocks.pth', type=str, help='Path to weights for G')
106
+ parser.add_argument('--backbone', default='unet', const='unet', nargs='?', choices=['unet', 'linknet', 'resnet'], help='Backbone for attribute encoder')
107
+ parser.add_argument('--num_blocks', default=2, type=int, help='Numbers of AddBlocks at AddResblock')
108
+
109
+ parser.add_argument('--batch_size', default=40, type=int)
110
+ parser.add_argument('--crop_size', default=224, type=int, help="Don't change this")
111
+ parser.add_argument('--use_sr', default=False, type=bool, help='True for super resolution on swap images')
112
+ parser.add_argument('--similarity_th', default=0.15, type=float, help='Threshold for selecting a face similar to the target')
113
+
114
+ parser.add_argument('--source_paths', default=['examples/images/mark.jpg', 'examples/images/elon_musk.jpg'], nargs='+')
115
+ parser.add_argument('--target_faces_paths', default=[], nargs='+', help="It's necessary to set the face/faces in the video to which the source face/faces is swapped. You can skip this parametr, and then any face is selected in the target video for swap.")
116
+
117
+ # parameters for image to video
118
+ parser.add_argument('--target_video', default='examples/videos/nggyup.mp4', type=str, help="It's necessary for image to video swap")
119
+ parser.add_argument('--out_video_name', default='examples/results/result.mp4', type=str, help="It's necessary for image to video swap")
120
+
121
+ # parameters for image to image
122
+ parser.add_argument('--image_to_image', default=True, type=bool, help='True for image to image swap, False for swap on video')
123
+ parser.add_argument('--target_image', default='examples/images/beckham.jpg', type=str, help="It's necessary for image to image swap")
124
+ parser.add_argument('--out_image_name', default='examples/results/result.png', type=str,help="It's necessary for image to image swap")
125
+
126
+ args = parser.parse_args()
127
+
128
+
129
+ with gr.Blocks() as demo:
130
+
131
+ with gr.Column():
132
+
133
+ with gr.Row():
134
+ with gr.Column():
135
+ with gr.Row(equal_height=True):
136
+ input_source = gr.Image(
137
+ type="pil",
138
+ label="Input Source"
139
+ )
140
+ input_target = gr.Image(
141
+ type="pil",
142
+ label="Input Target"
143
+ )
144
+ run_button = gr.Button("Generate")
145
+
146
+ with gr.Column():
147
+ result = gr.Image(type='pil', label='Image Output')
148
+
149
+ run_button.click(
150
+ fn=infer_faceswap,
151
+ inputs=[input_source, input_target],
152
+ outputs=[result]
153
+ )
154
+
155
+
156
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True)
requirements.txt CHANGED
@@ -1,12 +1,7 @@
1
- numpy
2
- -f https://download.pytorch.org/whl/torch_stable.html
3
- torch==1.6.0+cu101
4
- -f https://download.pytorch.org/whl/torch_stable.html
5
- torchvision==0.7.0+cu101
6
  opencv-python
7
- onnx==1.9.0
8
  onnxruntime-gpu==1.12
9
- mxnet-cu101mkl
10
  scikit-image
11
  insightface==0.2.1
12
  requests==2.25.1
 
 
 
 
 
 
1
  opencv-python
2
+ onnx
3
  onnxruntime-gpu==1.12
4
+ mxnet-cu112
5
  scikit-image
6
  insightface==0.2.1
7
  requests==2.25.1