"""
Patch for torch module to make it compatible with newer diffusers versions
while using PyTorch 2.0.1
"""
import functools
import sys
import types
import warnings

import torch

# Check if the attributes already exist
if not hasattr(torch, 'float8_e4m3fn'):
    # Add missing attributes for compatibility.
    # These won't actually function, but they'll allow imports to succeed.
    torch.float8_e4m3fn = torch.float16  # Use float16 as a placeholder type
    warnings.warn(
        "Added placeholder for torch.float8_e4m3fn. Actual 8-bit operations won't work, "
        "but imports should succeed. Using PyTorch 2.0.1 with newer diffusers."
    )
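
# A quick sanity check of the placeholder (illustrative only, not part of the
# original patch): because the float8 names are mere aliases of torch.float16,
# tensors "created" with them are silently plain float16:
#   t = torch.zeros(2, dtype=torch.float8_e4m3fn)
#   assert t.dtype is torch.float16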

if not hasattr(torch, 'float8_e5m2'):
    torch.float8_e5m2 = torch.float16  # Use float16 as a placeholder type

# Add other missing torch types that might be referenced.
# Note: bfloat16 already exists in PyTorch 2.0.1, so the hasattr guard simply
# skips it; the remaining names are newer or uncommon dtypes.
for type_name in ['bfloat16', 'bfloat8', 'float8_e4m3fnuz']:
    if not hasattr(torch, type_name):
        setattr(torch, type_name, torch.float16)

# Create a placeholder for torch._dynamo if it doesn't exist
if not hasattr(torch, '_dynamo'):
    torch._dynamo = types.ModuleType('torch._dynamo')
    sys.modules['torch._dynamo'] = torch._dynamo
    # Add common attributes/functions used by torch._dynamo
    torch._dynamo.config = types.SimpleNamespace(suppress_errors=True)
    torch._dynamo.optimize = lambda *args, **kwargs: lambda f: f
    # The real disable() also accepts a function (decorator use), so return it
    # unchanged instead of swallowing it.
    torch._dynamo.disable = lambda fn=None, *args, **kwargs: fn
    torch._dynamo.reset_repro_cache = lambda: None
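
# Illustration of what the stubs permit (hypothetical caller code, not part of
# the patch): optimize(...) hands back an identity decorator, so decorated
# functions simply run eagerly, unchanged:
#   @torch._dynamo.optimize("inductor")
#   def f(x):
#       return x + 1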

# Add torch.compile if it doesn't exist (it ships with PyTorch 2.0.x, so this
# is a safety net for older builds)
if not hasattr(torch, 'compile'):
    # Return the model/function unchanged; also cover the decorator-factory
    # form torch.compile(**kwargs) with no positional argument.
    torch.compile = lambda fn=None, **kwargs: fn if fn is not None else (lambda f: f)
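
# Under the stub (hypothetical usage, shown for clarity):
#   compiled = torch.compile(model)    # returns model unchanged
#   @torch.compile(mode="default")     # identity decorator
#   def g(x):
#       return x * 2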

# Create a placeholder for torch.cuda.amp if it doesn't exist
if not hasattr(torch.cuda, 'amp'):
    torch.cuda.amp = types.ModuleType('torch.cuda.amp')
    sys.modules['torch.cuda.amp'] = torch.cuda.amp
    # Mock autocast: usable as a context manager or as a decorator, but it
    # performs no mixed-precision casting.
    class MockAutocast:
        def __init__(self, *args, **kwargs):
            pass

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

        def __call__(self, func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

    torch.cuda.amp.autocast = MockAutocast
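
# With the mock in place (hypothetical usage), existing autocast code still
# runs, just without any dtype switching:
#   with torch.cuda.amp.autocast():
#       y = model(x)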
print("PyTorch patched for compatibility with newer diffusers - using latest diffusers with PyTorch 2.0.1") |