|
import torch |
|
from torchvision.models.optical_flow import raft_large |
|
from modules.flow_models.raft.rfr_new import RAFT |
|
|
|
# Frozen RAFT instances keyed by (data_domain, device). Building the network
# and uploading it to the device is expensive, so it is done once per key
# instead of on every raft_flow() call. Cached models keep their device
# memory alive for the lifetime of the process.
_RAFT_MODELS: dict = {}

_RAFT_DOMAINS = ("animation", "photorealism")


def _get_raft(data_domain: str, device: str) -> torch.nn.Module:
    """Return the cached, frozen, eval-mode RAFT model for *data_domain*."""
    key = (data_domain, str(device))
    model = _RAFT_MODELS.get(key)
    if model is None:
        if data_domain == "animation":
            # NOTE(review): RAFT is project-local and its constructor is not
            # visible here; presumably it loads its own checkpoint -- confirm.
            model = RAFT()
        else:
            # raft_large() without `weights` builds a randomly initialised
            # network, whose flow output is meaningless; request the
            # pretrained checkpoint explicitly (downloaded on first use).
            # Requires the torchvision weights API (torchvision >= 0.13).
            from torchvision.models.optical_flow import Raft_Large_Weights

            model = raft_large(weights=Raft_Large_Weights.DEFAULT)
        model = model.requires_grad_(False).eval().to(device)
        _RAFT_MODELS[key] = model
    return model


def raft_flow(
    I0: torch.Tensor,
    I1: torch.Tensor,
    data_domain: str = "animation",
    device: str = 'cuda'
) -> tuple[torch.Tensor, torch.Tensor]:
    """Estimate optical flow between two frames with a frozen RAFT model.

    Args:
        I0: First frame. Cast to float32 and moved to *device*.
        I1: Second frame. Cast to float32 and moved to *device*.
        data_domain: ``"animation"`` runs the project-local ``RAFT`` model;
            ``"photorealism"`` runs torchvision's ``raft_large``.
        device: Device on which the model and both inputs are placed.

    Returns:
        For ``"animation"``, whatever ``RAFT()(I0, I1)`` returns, unchanged.
        For ``"photorealism"``, the last element of the model output (per the
        torchvision docs, the final and most refined flow prediction).
        NOTE(review): the annotated tuple type therefore looks accurate only
        for the animation branch -- confirm against RAFT and the callers.

    Raises:
        ValueError: If *data_domain* is not one of the two supported values.

    NOTE(review): no input normalisation is applied here; the pretrained
    torchvision weights presumably expect inputs prepared like
    ``Raft_Large_Weights.DEFAULT.transforms()`` -- confirm callers do this.
    """
    # Fail fast, before any tensor is copied or any model is built.
    if data_domain not in _RAFT_DOMAINS:
        raise ValueError("data_domain must be either 'animation' or 'photorealism'")

    # The model lives on `device`, so the inputs must as well; otherwise the
    # forward pass fails with a device mismatch. `.to` is a no-op for tensors
    # that already have the requested device and dtype.
    I0 = I0.to(device=device, dtype=torch.float32)
    I1 = I1.to(device=device, dtype=torch.float32)

    raft = _get_raft(data_domain, device)
    output = raft(I0, I1)
    return output if data_domain == "animation" else output[-1]