|
|
|
|
|
|
|
import sys |
|
import torchvision |
|
|
|
def fix_torchvision_functional_tensor(): |
|
""" |
|
Fix torchvision.transforms.functional_tensor import issue |
|
""" |
|
try: |
|
|
|
import torchvision.transforms.functional_tensor |
|
print("torchvision.transforms.functional_tensor is available") |
|
return True |
|
except ImportError: |
|
print("torchvision.transforms.functional_tensor not found, applying compatibility fix...") |
|
|
|
try: |
|
|
|
import torch |
|
import torchvision.transforms.functional as F |
|
|
|
class FunctionalTensorMock: |
|
"""Mock module to replace functional_tensor""" |
|
|
|
@staticmethod |
|
def rgb_to_grayscale(img, num_output_channels=1): |
|
"""Convert RGB image to grayscale""" |
|
if hasattr(F, 'rgb_to_grayscale'): |
|
return F.rgb_to_grayscale(img, num_output_channels) |
|
else: |
|
|
|
if len(img.shape) == 4: |
|
weights = torch.tensor([0.299, 0.587, 0.114], |
|
device=img.device, dtype=img.dtype).view(1, 3, 1, 1) |
|
else: |
|
weights = torch.tensor([0.299, 0.587, 0.114], |
|
device=img.device, dtype=img.dtype).view(3, 1, 1) |
|
|
|
grayscale = torch.sum(img * weights, dim=-3, keepdim=True) |
|
|
|
if num_output_channels == 3: |
|
if len(img.shape) == 4: |
|
grayscale = grayscale.repeat(1, 3, 1, 1) |
|
else: |
|
grayscale = grayscale.repeat(3, 1, 1) |
|
|
|
return grayscale |
|
|
|
@staticmethod |
|
def resize(img, size, interpolation=2, antialias=None): |
|
"""Resize function wrapper""" |
|
try: |
|
from torchvision.transforms.v2.functional import resize as v2_resize |
|
return v2_resize(img, size, |
|
interpolation=interpolation, |
|
antialias=antialias) |
|
except ImportError: |
|
if hasattr(F, 'resize'): |
|
return F.resize(img, size, interpolation=interpolation) |
|
else: |
|
import torch.nn.functional as torch_F |
|
if isinstance(size, int): |
|
size = (size, size) |
|
return torch_F.interpolate( |
|
img.unsqueeze(0) if len(img.shape) == 3 else img, |
|
size=size, mode='bilinear', align_corners=False) |
|
|
|
def __getattr__(self, name): |
|
"""Fallback to regular functional module""" |
|
if hasattr(F, name): |
|
return getattr(F, name) |
|
else: |
|
|
|
try: |
|
import torchvision.transforms.v2.functional as v2_F |
|
if hasattr(v2_F, name): |
|
return getattr(v2_F, name) |
|
except ImportError: |
|
pass |
|
|
|
raise AttributeError(f"'{name}' not found in functional_tensor mock") |
|
|
|
|
|
functional_tensor_mock = FunctionalTensorMock() |
|
|
|
|
|
sys.modules['torchvision.transforms.functional_tensor'] = functional_tensor_mock |
|
print("Applied compatibility fix: created functional_tensor mock module") |
|
return True |
|
|
|
except Exception as e: |
|
print(f"Failed to create functional_tensor mock: {e}") |
|
return False |
|
|
|
def apply_fix(): |
|
"""Apply the torchvision compatibility fix""" |
|
print(f"Torchvision version: {torchvision.__version__}") |
|
return fix_torchvision_functional_tensor() |
|
|
|
if __name__ == "__main__": |
|
apply_fix() |
|
|