Fabrice-TIERCELIN commited on
Commit
ca2205c
·
verified ·
1 Parent(s): 520d777
Files changed (1) hide show
  1. infer.py +85 -85
infer.py CHANGED
@@ -1,86 +1,86 @@
1
- from PIL import Image
2
- import cv2 as cv
3
- import torch
4
- from RealESRGAN import RealESRGAN
5
- import tempfile
6
- import numpy as np
7
- import tqdm
8
- import ffmpeg
9
- import spaces
10
-
11
-
12
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
-
14
- @spaces.GPU(duration=60)
15
- def infer_image(img: Image.Image, size_modifier: int ) -> Image.Image:
16
- if img is None:
17
- raise Exception("Image not uploaded")
18
-
19
- width, height = img.size
20
-
21
- if width >= 5000 or height >= 5000:
22
- raise Exception("The image is too large.")
23
-
24
- model = RealESRGAN(device, scale=size_modifier)
25
- model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
26
-
27
- result = model.predict(img.convert('RGB'))
28
- print(f"Image size ({device}): {size_modifier} ... OK")
29
- return result
30
-
31
- @spaces.GPU(duration=300)
32
- def infer_video(video_filepath: str, size_modifier: int) -> str:
33
- model = RealESRGAN(device, scale=size_modifier)
34
- model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
35
-
36
- cap = cv.VideoCapture(video_filepath)
37
-
38
- tmpfile = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
39
- vid_output = tmpfile.name
40
- tmpfile.close()
41
-
42
- # Check if the input video has an audio stream
43
- probe = ffmpeg.probe(video_filepath)
44
- has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
45
-
46
- if has_audio:
47
- # Extract audio from the input video
48
- audio_file = video_filepath.replace(".mp4", ".wav")
49
- ffmpeg.input(video_filepath).output(audio_file, format='wav', ac=1).run(overwrite_output=True)
50
-
51
- vid_writer = cv.VideoWriter(
52
- vid_output,
53
- fourcc=cv.VideoWriter.fourcc(*'mp4v'),
54
- fps=cap.get(cv.CAP_PROP_FPS),
55
- frameSize=(int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) * size_modifier, int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) * size_modifier)
56
- )
57
-
58
- n_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
59
-
60
- for _ in tqdm.tqdm(range(n_frames)):
61
- ret, frame = cap.read()
62
- if not ret:
63
- break
64
-
65
- frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
66
- frame = Image.fromarray(frame)
67
-
68
- upscaled_frame = model.predict(frame.convert('RGB'))
69
-
70
- upscaled_frame = np.array(upscaled_frame)
71
- upscaled_frame = cv.cvtColor(upscaled_frame, cv.COLOR_RGB2BGR)
72
-
73
- vid_writer.write(upscaled_frame)
74
-
75
- vid_writer.release()
76
-
77
- if has_audio:
78
- # Re-encode the video with the modified audio
79
- ffmpeg.input(vid_output).output(video_filepath.replace(".mp4", "_upscaled.mp4"), vcodec='libx264', acodec='aac', audio_bitrate='320k').run(overwrite_output=True)
80
-
81
- # Replace the original audio with the upscaled audio
82
- ffmpeg.input(audio_file).output(video_filepath.replace(".mp4", "_upscaled.mp4"), acodec='aac', audio_bitrate='320k').run(overwrite_output=True)
83
-
84
- print(f"Video file : {video_filepath}")
85
-
86
  return vid_output.replace(".mp4", "_upscaled.mp4") if has_audio else vid_output
 
1
+ from PIL import Image
2
+ import cv2 as cv
3
+ import torch
4
+ from RealESRGAN import RealESRGAN
5
+ import tempfile
6
+ import numpy as np
7
+ import tqdm
8
+ import ffmpeg
9
+ import spaces
10
+
11
+
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+
14
+ @spaces.GPU(duration=60)
15
+ def infer_image(img: Image.Image, size_modifier: int ) -> Image.Image:
16
+ if img is None:
17
+ raise Exception("Image not uploaded")
18
+
19
+ width, height = img.size
20
+
21
+ if width >= 5000 or height >= 5000:
22
+ raise Exception("The image is too large.")
23
+
24
+ model = RealESRGAN(device, scale=size_modifier)
25
+ model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
26
+
27
+ result = model.predict(img.convert('RGB'))
28
+ print(f"Image size ({device}): {size_modifier} ... OK")
29
+ return result
30
+
31
+ @spaces.GPU(duration=180)
32
+ def infer_video(video_filepath: str, size_modifier: int) -> str:
33
+ model = RealESRGAN(device, scale=size_modifier)
34
+ model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
35
+
36
+ cap = cv.VideoCapture(video_filepath)
37
+
38
+ tmpfile = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
39
+ vid_output = tmpfile.name
40
+ tmpfile.close()
41
+
42
+ # Check if the input video has an audio stream
43
+ probe = ffmpeg.probe(video_filepath)
44
+ has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
45
+
46
+ if has_audio:
47
+ # Extract audio from the input video
48
+ audio_file = video_filepath.replace(".mp4", ".wav")
49
+ ffmpeg.input(video_filepath).output(audio_file, format='wav', ac=1).run(overwrite_output=True)
50
+
51
+ vid_writer = cv.VideoWriter(
52
+ vid_output,
53
+ fourcc=cv.VideoWriter.fourcc(*'mp4v'),
54
+ fps=cap.get(cv.CAP_PROP_FPS),
55
+ frameSize=(int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) * size_modifier, int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) * size_modifier)
56
+ )
57
+
58
+ n_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
59
+
60
+ for _ in tqdm.tqdm(range(n_frames)):
61
+ ret, frame = cap.read()
62
+ if not ret:
63
+ break
64
+
65
+ frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
66
+ frame = Image.fromarray(frame)
67
+
68
+ upscaled_frame = model.predict(frame.convert('RGB'))
69
+
70
+ upscaled_frame = np.array(upscaled_frame)
71
+ upscaled_frame = cv.cvtColor(upscaled_frame, cv.COLOR_RGB2BGR)
72
+
73
+ vid_writer.write(upscaled_frame)
74
+
75
+ vid_writer.release()
76
+
77
+ if has_audio:
78
+ # Re-encode the video with the modified audio
79
+ ffmpeg.input(vid_output).output(video_filepath.replace(".mp4", "_upscaled.mp4"), vcodec='libx264', acodec='aac', audio_bitrate='320k').run(overwrite_output=True)
80
+
81
+ # Replace the original audio with the upscaled audio
82
+ ffmpeg.input(audio_file).output(video_filepath.replace(".mp4", "_upscaled.mp4"), acodec='aac', audio_bitrate='320k').run(overwrite_output=True)
83
+
84
+ print(f"Video file : {video_filepath}")
85
+
86
  return vid_output.replace(".mp4", "_upscaled.mp4") if has_audio else vid_output