#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video Upscaler for Blissful Tuner Extension

License: Apache 2.0
Created on Wed Apr 23 10:19:19 2025
@author: blyss
"""

from typing import List
import torch
import numpy as np
from tqdm import tqdm
from rich.traceback import install as install_rich_tracebacks
from swinir.network_swinir import SwinIR
from spandrel import ImageModelDescriptor, ModelLoader
from video_processing_common import BlissfulVideoProcessor, set_seed, setup_parser_video_common
from utils import setup_compute_context, load_torch_file, BlissfulLogger
logger = BlissfulLogger(__name__, "#8e00ed")
install_rich_tracebacks()


def upscale_frames_swin(
    model: torch.nn.Module,
    frames: List[np.ndarray],
    VideoProcessor: BlissfulVideoProcessor
) -> None:
    """
    Upscale a list of RGB frames using a SwinIR model (optionally
    torch.compile'd, hence the CUDA graph step marker below).

    Args:
        model: Loaded SwinIR upsampler.
        frames: List of H×W×3 float32 RGB arrays in [0,1].
        VideoProcessor: BlissfulVideoProcessor used to convert frames to
            tensors and write each upscaled result to a PNG.

    Returns:
        None. Upscaled frames are written to disk via VideoProcessor.
    """
    window_size = 8
    for img in tqdm(frames, desc="Upscaling SwinIR"):
        # Mark step for CUDA graph capture if enabled
        torch.compiler.cudagraph_mark_step_begin()

        # Convert HWC RGB → CHW tensor
        tensor = VideoProcessor.np_image_to_tensor(img)

        # Pad to window multiple
        _, _, h, w = tensor.shape
        h_pad = ((h + window_size - 1) // window_size) * window_size - h
        w_pad = ((w + window_size - 1) // window_size) * window_size - w
        tensor = torch.cat([tensor, torch.flip(tensor, [2])], 2)[:, :, : h + h_pad, :]
        tensor = torch.cat([tensor, torch.flip(tensor, [3])], 3)[:, :, :, : w + w_pad]
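        # e.g. h=718 → h_pad=2: each spatial dim is reflect-padded up to the
        # next multiple of window_size (720 here), since SwinIR's window
        # attention requires dimensions divisible by the window size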

        # Inference
        with torch.no_grad():
            out = model(tensor)

        # Crop off the reflected padding; the SwinIR model built in
        # load_swin_model upscales 4x, so the valid region is 4*h × 4*w
        out = out[..., : h * 4, : w * 4]

        # Post-process: NCHW → HWC BGR uint8 and write to PNG
        VideoProcessor.write_np_or_tensor_to_png(out)


def load_swin_model(
    model_path: str,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Instantiate and load weights into a SwinIR model.

    Args:
        model_path: Path to checkpoint (.pth or safetensors).
        device: torch device.
        dtype: torch dtype.
    Returns:
        SwinIR model in eval() on device and dtype.
    """
    logger.info(f"Loading SwinIR model ({dtype})…")
    model = SwinIR(
        upscale=4,
        in_chans=3,
        img_size=64,
        window_size=8,
        img_range=1.0,
        depths=[6] * 9,
        embed_dim=240,
        num_heads=[8] * 9,
        mlp_ratio=2,
        upsampler='nearest+conv',
        resi_connection='3conv',
    )
    ckpt = load_torch_file(model_path)
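    # Some releases (e.g. EMA-averaged checkpoints) nest weights under 'params_ema'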
    key = 'params_ema' if 'params_ema' in ckpt else None
    model.load_state_dict(ckpt[key] if key else ckpt, strict=True)
    model.to(device, dtype).eval()
    return model


def load_esrgan_model(
    model_path: str,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.nn.Module:
    """
    Load an ESRGAN (or RRDBNet) style model via Spandrel loader.

    Args:
        model_path: Path to ESRGAN checkpoint.
        device: torch device.
        dtype: torch dtype.
    Returns:
        Model ready for inference.
    """
    logger.info(f"Loading ESRGAN model ({dtype})…")
    descriptor = ModelLoader().load_from_file(model_path)
    assert isinstance(descriptor, ImageModelDescriptor)
    model = descriptor.model.eval().to(device, dtype)
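    # Note: spandrel's ImageModelDescriptor also exposes the model's native
    # scale factor (descriptor.scale), which could be sanity-checked against
    # the --scale flag if desired.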
    return model


def main() -> None:
    """
    Parse CLI args, load input, model, and run upscaling pipeline.
    """
    parser = setup_parser_video_common(description="Video upscaling using SwinIR or ESRGAN models")
    parser.add_argument(
        "--scale", type=float, default=2,
        help="Final scale multiplier for output resolution"
    )
    parser.add_argument(
        "--mode", choices=["swinir", "esrgan"], default="swinir",
        help="Model architecture to use"
    )
    args = parser.parse_args()
    # Resolve the compute device and torch.dtype from the CLI dtype string
    device, dtype = setup_compute_context(None, args.dtype)
    VideoProcessor = BlissfulVideoProcessor(device, dtype)
    VideoProcessor.prepare_files_and_path(args.input, args.output, args.mode.upper())

    frames, fps, w, h = VideoProcessor.load_frames(make_rgb=True)
    set_seed(args.seed)
    # Load and run model
    if args.mode == "swinir":
        model = load_swin_model(args.model, device, dtype)
        upscale_frames_swin(model, frames, VideoProcessor)
    else:
        model = load_esrgan_model(args.model, device, dtype)
        logger.info("Processing with ESRGAN...")
        for frame in tqdm(frames, desc="Upscaling ESRGAN"):
            inp = VideoProcessor.np_image_to_tensor(frame)
            with torch.no_grad():
                sr = model(inp)
            VideoProcessor.write_np_or_tensor_to_png(sr)

    # Write video
    logger.info("Encoding output video...")
    out_w, out_h = int(w * args.scale), int(h * args.scale)
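    # write_buffered_frames_to_output is assumed to scale the buffered frames
    # to (out_w, out_h) while encoding, absorbing any difference between the
    # model's native upscale factor and --scale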
    VideoProcessor.write_buffered_frames_to_output(fps, args.keep_pngs, (out_w, out_h))
    logger.info("Done!")


if __name__ == "__main__":
    main()