Spaces:
Running
Running
File size: 4,340 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import importlib
import importlib.util
import operator as op
from typing import Union
import importlib_metadata
from packaging.version import Version, parse
from finetrainers.logging import get_logger
# Module-level logger obtained from the `finetrainers.logging.get_logger` helper.
logger = get_logger()
# Maps the string spelling of a comparison operator to the matching function from the
# `operator` module; `compare_versions` uses it to validate and resolve its `operation` argument.
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
# This function was copied from: https://github.com/huggingface/diffusers/blob/5873377a660dac60a6bd86ef9b4fdfc385305977/src/diffusers/utils/import_utils.py#L57
def _is_package_available(pkg_name: str):
    """
    Checks whether a package can be located and whether its installed version can be read.

    Args:
        pkg_name (`str`):
            The name handed to both `importlib.util.find_spec` and `importlib_metadata.version`.

    Returns:
        A `(available, version)` tuple. `available` is `True` only when a module spec was found and the
        version lookup (together with the debug log call) completed without raising `ImportError` or
        `importlib_metadata.PackageNotFoundError`. `version` stays `"N/A"` unless the metadata lookup
        returned a value.
    """
    version = "N/A"

    # No module spec means there is nothing to look up metadata for.
    if importlib.util.find_spec(pkg_name) is None:
        return False, version

    try:
        version = importlib_metadata.version(pkg_name)
        logger.debug(f"Successfully imported {pkg_name} version {version}")
    except (ImportError, importlib_metadata.PackageNotFoundError):
        # A spec exists but the distribution metadata is unavailable: report the package as missing.
        return False, version

    return True, version
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
    """
    Evaluates `<library version> <operation> <requirement_version>`.

    Args:
        library_or_version (`str` or `packaging.version.Version`):
            Either an already-parsed version, or a library name. A `str` is always treated as a library name
            whose installed version is looked up with `importlib_metadata.version`.
        operation (`str`):
            The comparison to apply, written as a string; must be a key of `STR_OPERATION_TO_FUNC`
            (for example `">"` or `"<="`).
        requirement_version (`str`):
            The version string that the library version is compared against.

    Returns:
        The result of applying the selected comparison to the two parsed versions.

    Raises:
        ValueError: If `operation` is not a key of `STR_OPERATION_TO_FUNC`.
    """
    comparator = STR_OPERATION_TO_FUNC.get(operation)
    if comparator is None:
        raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")

    if isinstance(library_or_version, str):
        # A string names a library: resolve its installed version first.
        installed = parse(importlib_metadata.version(library_or_version))
    else:
        installed = library_or_version

    required = parse(requirement_version)
    return comparator(installed, required)
# Probe each optional dependency once, when this module is loaded. Every call yields an
# `(available, version)` pair; the version is the string "N/A" when no version could be read.
# The `is_*_available` and `is_*_version` helpers below read these module-level names.
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
_datasets_available, _datasets_version = _is_package_available("datasets")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_kornia_available, _kornia_version = _is_package_available("kornia")
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_torch_available, _torch_version = _is_package_available("torch")
_xformers_available, _xformers_version = _is_package_available("xformers")
def is_bitsandbytes_available() -> bool:
    """Returns the module-level `_bitsandbytes_available` flag, set by `_is_package_available("bitsandbytes")`."""
    return _bitsandbytes_available
def is_datasets_available() -> bool:
    """Returns the module-level `_datasets_available` flag, set by `_is_package_available("datasets")`."""
    return _datasets_available
def is_flash_attn_available() -> bool:
    """Returns the module-level `_flash_attn_available` flag, set by `_is_package_available("flash_attn")`."""
    return _flash_attn_available
def is_kornia_available() -> bool:
    """Returns the module-level `_kornia_available` flag, set by `_is_package_available("kornia")`."""
    return _kornia_available
def is_sageattention_available() -> bool:
    """Returns the module-level `_sageattention_available` flag, set by `_is_package_available("sageattention")`."""
    return _sageattention_available
def is_torch_available() -> bool:
    """Returns the module-level `_torch_available` flag, set by `_is_package_available("torch")`."""
    return _torch_available
def is_xformers_available() -> bool:
    """Returns the module-level `_xformers_available` flag, set by `_is_package_available("xformers")`."""
    return _xformers_available
def is_bitsandbytes_version(operation: str, version: str):
    """
    Checks the detected `bitsandbytes` version against a requirement.

    Args:
        operation (`str`):
            A comparison operator in string form (e.g. `">="`), forwarded to `compare_versions`.
        version (`str`):
            The version string to compare the detected `bitsandbytes` version against.

    Returns:
        `False` when `bitsandbytes` was not detected as available; otherwise the result of `compare_versions`.
    """
    if _bitsandbytes_available:
        detected = parse(_bitsandbytes_version)
        return compare_versions(detected, operation, version)
    return False
def is_datasets_version(operation: str, version: str):
    """
    Checks the detected `datasets` version against a requirement.

    Args:
        operation (`str`):
            A comparison operator in string form (e.g. `">="`), forwarded to `compare_versions`.
        version (`str`):
            The version string to compare the detected `datasets` version against.

    Returns:
        `False` when `datasets` was not detected as available; otherwise the result of `compare_versions`.
    """
    if _datasets_available:
        detected = parse(_datasets_version)
        return compare_versions(detected, operation, version)
    return False
def is_flash_attn_version(operation: str, version: str):
    """
    Checks the detected `flash_attn` version against a requirement.

    Args:
        operation (`str`):
            A comparison operator in string form (e.g. `">="`), forwarded to `compare_versions`.
        version (`str`):
            The version string to compare the detected `flash_attn` version against.

    Returns:
        `False` when `flash_attn` was not detected as available; otherwise the result of `compare_versions`.
    """
    if _flash_attn_available:
        detected = parse(_flash_attn_version)
        return compare_versions(detected, operation, version)
    return False
def is_kornia_version(operation: str, version: str):
    """
    Checks the detected `kornia` version against a requirement.

    Args:
        operation (`str`):
            A comparison operator in string form (e.g. `">="`), forwarded to `compare_versions`.
        version (`str`):
            The version string to compare the detected `kornia` version against.

    Returns:
        `False` when `kornia` was not detected as available; otherwise the result of `compare_versions`.
    """
    if _kornia_available:
        detected = parse(_kornia_version)
        return compare_versions(detected, operation, version)
    return False
def is_sageattention_version(operation: str, version: str):
    """
    Checks the detected `sageattention` version against a requirement.

    Args:
        operation (`str`):
            A comparison operator in string form (e.g. `">="`), forwarded to `compare_versions`.
        version (`str`):
            The version string to compare the detected `sageattention` version against.

    Returns:
        `False` when `sageattention` was not detected as available; otherwise the result of
        `compare_versions`.
    """
    if _sageattention_available:
        detected = parse(_sageattention_version)
        return compare_versions(detected, operation, version)
    return False
def is_torch_version(operation: str, version: str):
    """
    Checks the detected `torch` version against a requirement.

    Args:
        operation (`str`):
            A comparison operator in string form (e.g. `">="`), forwarded to `compare_versions`.
        version (`str`):
            The version string to compare the detected `torch` version against.

    Returns:
        `False` when `torch` was not detected as available; otherwise the result of `compare_versions`.
    """
    if _torch_available:
        detected = parse(_torch_version)
        return compare_versions(detected, operation, version)
    return False
def is_xformers_version(operation: str, version: str):
    """
    Checks the detected `xformers` version against a requirement.

    Args:
        operation (`str`):
            A comparison operator in string form (e.g. `">="`), forwarded to `compare_versions`.
        version (`str`):
            The version string to compare the detected `xformers` version against.

    Returns:
        `False` when `xformers` was not detected as available; otherwise the result of `compare_versions`.
    """
    if _xformers_available:
        detected = parse(_xformers_version)
        return compare_versions(detected, operation, version)
    return False
|