Upload 5 files

- aworld/utils/async_func.py +66 -0
- aworld/utils/common.py +352 -0
- aworld/utils/import_package.py +256 -0
- aworld/utils/json_encoder.py +12 -0
- aworld/utils/oss.py +628 -0
aworld/utils/async_func.py
ADDED
@@ -0,0 +1,66 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.

import asyncio
from functools import wraps
from typing import Callable, Optional, Union, Any, Dict


class Functionable:
    def __init__(self, function: Callable[..., Any], *args: Any, **kwargs: Dict[str, Any]) -> None:
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.done: bool = False
        self.error: bool = False
        self.result: Optional[Any] = None
        self.exception: Optional[Exception] = None

    def __call__(self) -> None:
        try:
            self.result = self.function(*self.args, **self.kwargs)
        except Exception as e:
            self.error = True
            self.exception = e
        self.done = True

    def call(self):
        self.__call__()


def async_decorator(*func, delay: Optional[Union[int, float]] = 0.5) -> Callable:
    def wrapper(function: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(function)
        async def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
            sleep_time = 0 if delay is None else delay
            task = Functionable(function, *args, **kwargs)
            # TODO: Use thread pool to process task
            task.call()
            if task.error:
                raise task.exception
            await asyncio.sleep(sleep_time)
            return task.result

        return inner_wrapper

    if not func:
        return wrapper
    else:
        if asyncio.iscoroutinefunction(func[0]):
            # coroutine function, return itself
            return func[0]
        return wrapper(func[0])


def async_func(function: Callable[..., Any]) -> Callable[..., Any]:
    @wraps(function)
    async def inner_wrapper(*args: Any, **kwargs: Any) -> Any:
        task = Functionable(function, *args, **kwargs)
        task.call()
        if task.error:
            raise task.exception
        return task.result

    if asyncio.iscoroutinefunction(function):
        # coroutine function, return itself
        return function
    return inner_wrapper
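
A minimal usage sketch for the two wrappers above (not part of the diff); the `add` and `greet` functions and the 0.1-second delay are illustrative only:

    import asyncio
    from aworld.utils.async_func import async_decorator, async_func

    @async_decorator(delay=0.1)      # wraps the sync function and sleeps 0.1s after it runs
    def add(a, b):
        return a + b

    @async_func                      # plain wrapper, no extra sleep
    def greet(name):
        return f"hello {name}"

    async def main():
        print(await add(1, 2))        # -> 3
        print(await greet("aworld"))  # -> hello aworld

    asyncio.run(main())

Both wrappers still execute the wrapped callable synchronously inside the event-loop thread; the TODO in `async_decorator` points at offloading the call to a thread pool.
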
aworld/utils/common.py
ADDED
@@ -0,0 +1,352 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import asyncio
import inspect
import os
import pkgutil
import re
import socket
import sys
import threading
import time

from functools import wraps
from types import FunctionType
from typing import Callable, Any, Tuple, List, Iterator, Dict, Union

from aworld.logs.util import logger


def convert_to_snake(name: str) -> str:
    """Class name convert to snake."""
    if '_' not in name:
        name = re.sub(r'([a-z])([A-Z])', r'\1_\2', name)
    return name.lower()


def snake_to_camel(snake):
    words = snake.split('_')
    return ''.join([w.capitalize() for w in words])


def is_abstract_method(cls, method_name):
    method = getattr(cls, method_name)
    return (hasattr(method, '__isabstractmethod__') and method.__isabstractmethod__) or (
        isinstance(method, FunctionType) and hasattr(
            method, '__abstractmethods__') and method in method.__abstractmethods__)


def override_in_subclass(name: str, sub_cls: object, base_cls: object) -> bool:
    """Judge whether a subclass overrides a specified method.

    Args:
        name: The method name of sub class and base class
        sub_cls: Specify subclasses of the base class.
        base_cls: The parent class of the subclass.

    Returns:
        Overwrite as true in subclasses, vice versa.
    """
    if not issubclass(sub_cls, base_cls):
        logger.warning(f"{sub_cls} is not sub class of {base_cls}")
        return False

    if sub_cls == base_cls and hasattr(sub_cls, name) and not is_abstract_method(sub_cls, name):
        return True

    this_method = getattr(sub_cls, name)
    base_method = getattr(base_cls, name)
    return this_method is not base_method


def convert_to_subclass(obj, subclass):
    obj.__class__ = subclass
    return obj


def _walk_to_root(path: str) -> Iterator[str]:
    """Yield directories starting from the given directory up to the root."""
    if not os.path.exists(path):
        yield ''

    if os.path.isfile(path):
        path = os.path.dirname(path)

    last_dir = None
    current_dir = os.path.abspath(path)
    while last_dir != current_dir:
        yield current_dir
        parent_dir = os.path.abspath(os.path.join(current_dir, os.path.pardir))
        last_dir, current_dir = current_dir, parent_dir


def find_file(filename: str) -> str:
    """Find file from the folders for the given file.

    NOTE: Current running path priority, followed by the execution file path, and finally the aworld package path.

    Args:
        filename: The file name that you want to search.
    """

    def run_dir():
        try:
            main = __import__('__main__', None, None, fromlist=['__file__'])
            return os.path.dirname(main.__file__)
        except ModuleNotFoundError:
            return os.getcwd()

    path = os.getcwd()
    if os.path.exists(os.path.join(path, filename)):
        path = os.getcwd()
    elif os.path.exists(os.path.join(run_dir(), filename)):
        path = run_dir()
    else:
        frame = inspect.currentframe()
        current_file = __file__

        while frame.f_code.co_filename == current_file or not os.path.exists(
            frame.f_code.co_filename
        ):
            assert frame.f_back is not None
            frame = frame.f_back
        frame_filename = frame.f_code.co_filename
        path = os.path.dirname(os.path.abspath(frame_filename))

    for dirname in _walk_to_root(path):
        if not dirname:
            continue
        check_path = os.path.join(dirname, filename)
        if os.path.isfile(check_path):
            return check_path

    return ''


def search_in_module(module: object, base_classes: List[type]) -> List[Tuple[str, type]]:
    """Find all classes that inherit from a specific base class in the module."""
    results = []
    for name, obj in inspect.getmembers(module, inspect.isclass):
        for base_class in base_classes:
            if issubclass(obj, base_class) and obj is not base_class:
                results.append((name, obj))
    return results


def _scan_package(package_name: str, base_classes: List[type], results: List[Tuple[str, type]] = []):
    try:
        package = sys.modules[package_name]
    except:
        return

    try:
        for sub_package, name, is_pkg in pkgutil.walk_packages(package.__path__):
            try:
                __import__(f"{package_name}.{name}")
            except:
                continue

            if is_pkg:
                _scan_package(package_name + "." + name, base_classes, results)
            try:
                module = __import__(f"{package_name}.{name}", fromlist=[name])
                results.extend(search_in_module(module, base_classes))
            except:
                continue
    except:
        pass


def scan_packages(package: str, base_classes: List[type]) -> List[Tuple[str, type]]:
    results = []
    _scan_package(package, base_classes, results)
    return results


class ReturnThread(threading.Thread):
    def __init__(self, func, *args, **kwargs):
        threading.Thread.__init__(self)
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.result = None
        self.daemon = True

    def run(self):
        self.result = asyncio.run(self.func(*self.args, **self.kwargs))


def asyncio_loop():
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    return loop


def sync_exec(async_func: Callable[..., Any], *args, **kwargs):
    """Async function to sync execution."""
    if not asyncio.iscoroutinefunction(async_func):
        return async_func(*args, **kwargs)

    loop = asyncio_loop()
    if loop and loop.is_running():
        thread = ReturnThread(async_func, *args, **kwargs)
        thread.start()
        thread.join()
        result = thread.result
    else:
        loop = asyncio.get_event_loop()
        result = loop.run_until_complete(async_func(*args, **kwargs))
    return result


def nest_dict_counter(usage: Dict[str, Union[int, Dict[str, int]]],
                      other: Dict[str, Union[int, Dict[str, int]]],
                      ignore_zero: bool = True):
    """Add counts from two dicts or nest dicts."""
    result = {}
    for elem, count in usage.items():
        # nest dict
        if isinstance(count, Dict):
            res = nest_dict_counter(usage[elem], other.get(elem, {}))
            result[elem] = res
            continue

        newcount = count + other.get(elem, 0)
        if not ignore_zero or newcount > 0:
            result[elem] = newcount

    for elem, count in other.items():
        if elem not in usage and not ignore_zero:
            result[elem] = count
    return result


def get_class(module_class: str):
    import importlib

    assert module_class
    module_class = module_class.strip()
    idx = module_class.rfind('.')
    if idx != -1:
        module = importlib.import_module(module_class[0:idx])
        class_names = module_class[idx + 1:].split(":")
        cls_obj = getattr(module, class_names[0])
        for inner_class_name in class_names[1:]:
            cls_obj = getattr(cls_obj, inner_class_name)
        return cls_obj
    else:
        raise Exception("{} can not find!".format(module_class))


def new_instance(module_class: str, *args, **kwargs):
    """Create module class instance based on module name."""
    return get_class(module_class)(*args, **kwargs)


def retryable(tries: int = 3, delay: int = 1):
    def inner_retry(f):
        @wraps(f)
        def f_retry(*args, **kwargs):
            mtries, mdelay = tries, delay
            while mtries > 0:
                try:
                    return f(*args, **kwargs)
                except Exception as e:
                    msg = f"{str(e)}, Retrying in {mdelay} seconds..."
                    logger.warning(msg)
                    time.sleep(mdelay)
                    mtries -= 1
            return f(*args, **kwargs)

        return f_retry

    return inner_retry


def get_local_ip():
    try:
        # build UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # connect to an external address (no need to connect)
        s.connect(("8.8.8.8", 80))
        # get local IP
        local_ip = s.getsockname()[0]
        s.close()
        return local_ip
    except Exception:
        return "127.0.0.1"


def replace_env_variables(config) -> Any:
    """Replace environment variables in configuration.

    Environment variables should be in the format ${ENV_VAR_NAME}.

    Args:
        config: Configuration to process (dict, list, or other value)

    Returns:
        Processed configuration with environment variables replaced
    """
    if isinstance(config, dict):
        for key, value in config.items():
            if isinstance(value, str):
                if "${" in value and "}" in value:
                    pattern = r'\${([^}]+)}'
                    matches = re.findall(pattern, value)
                    result = value
                    for env_var_name in matches:
                        env_var_value = os.getenv(env_var_name, f"${{{env_var_name}}}")
                        result = result.replace(f"${{{env_var_name}}}", env_var_value)
                    config[key] = result
                    logger.info(f"Replaced {value} with {config[key]}")
            if isinstance(value, dict) or isinstance(value, list):
                replace_env_variables(value)
    elif isinstance(config, list):
        for index, item in enumerate(config):
            if isinstance(item, str):
                if "${" in item and "}" in item:
                    pattern = r'\${([^}]+)}'
                    matches = re.findall(pattern, item)
                    result = item
                    for env_var_name in matches:
                        env_var_value = os.getenv(env_var_name, f"${{{env_var_name}}}")
                        result = result.replace(f"${{{env_var_name}}}", env_var_value)
                    config[index] = result
                    logger.info(f"Replaced {item} with {config[index]}")
            if isinstance(item, dict) or isinstance(item, list):
                replace_env_variables(item)
    return config


def get_local_hostname():
    """
    Get the local hostname.
    First try `socket.gethostname()`, if it fails or returns an invalid value,
    then try reverse DNS lookup using local IP.
    """
    try:
        hostname = socket.gethostname()
        # Simple validation - if hostname contains '.', consider it a valid FQDN (Fully Qualified Domain Name)
        if hostname and '.' in hostname:
            return hostname

        # If hostname is not qualified, try reverse lookup via IP
        local_ip = get_local_ip()
        if local_ip:
            try:
                # Get hostname from IP
                hostname, _, _ = socket.gethostbyaddr(local_ip)
                return hostname
            except (socket.herror, socket.gaierror):
                # Reverse lookup failed, return original hostname or IP
                pass

        # If all methods fail, return original gethostname() result or IP
        return hostname if hostname else local_ip

    except Exception:
        # Final fallback strategy
        return "localhost"
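
A brief usage sketch for a few of the helpers above (illustrative only, not part of the diff); `AgentConfig`, `flaky_call`, and the config dict are made-up names:

    import os
    from aworld.utils.common import (convert_to_snake, replace_env_variables,
                                     retryable, sync_exec)

    print(convert_to_snake("AgentConfig"))            # -> agent_config

    # Substitute ${VAR} placeholders from the environment, in place
    os.environ["API_KEY"] = "secret"
    cfg = replace_env_variables({"llm": {"api_key": "${API_KEY}"}})
    print(cfg)                                        # -> {'llm': {'api_key': 'secret'}}

    # Retry a flaky callable up to 3 times, sleeping 1 second between attempts
    @retryable(tries=3, delay=1)
    def flaky_call():
        return "ok"

    # Drive a coroutine from synchronous code
    async def compute():
        return 42

    print(flaky_call())                               # -> ok
    print(sync_exec(compute))                         # -> 42

Note that `retryable` invokes the wrapped function one final, unguarded time after the retry budget is exhausted, so the last exception propagates to the caller.
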
aworld/utils/import_package.py
ADDED
@@ -0,0 +1,256 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import os.path
import time
import sys
import importlib
import subprocess
from importlib import metadata
from aworld.logs.util import logger


class ModuleAlias:
    def __init__(self, module):
        self.module = module

    def __getattr__(self, name):
        return getattr(self.module, name)


def is_package_installed(package_name: str, version: str = "") -> bool:
    """
    Check if package is already installed and matches version if specified.

    Args:
        package_name: Name of the package to check
        version: Required version of the package

    Returns:
        bool: True if package is installed (and version matches if specified), False otherwise
    """
    try:
        dist = metadata.distribution(package_name)

        if version and dist.version != version:
            logger.info(f"Package {package_name} is installed but version {dist.version} "
                        f"does not match required version {version}")
            return False

        logger.info(f"Package {package_name} is already installed (version: {dist.version})")
        return True

    except metadata.PackageNotFoundError:
        logger.info(f"Package {package_name} is not installed")
        return False
    except Exception as e:
        logger.warning(f"Error checking if {package_name} is installed: {str(e)}")
        return False


def import_packages(packages: list[str]) -> dict:
    """
    Import and install multiple packages.

    Args:
        packages: List of packages to import

    Returns:
        dict: Dictionary mapping package names to imported modules
    """
    modules = {}
    for package in packages:
        package_ = import_package(package)
        if package_:
            modules[package] = package_
    return modules


def import_package(
        package_name: str,
        alias: str = '',
        install_name: str = '',
        version: str = '',
        installer: str = 'pip',
        timeout: int = 300,
        retry_count: int = 3,
        retry_delay: int = 5
) -> object:
    """
    Import and install package if not available.

    Args:
        package_name: Name of the package to import
        alias: Alias to use for the imported module
        install_name: Name of the package to install (if different from import name)
        version: Required version of the package
        installer: Package installer to use ('pip' or 'conda')
        timeout: Installation timeout in seconds
        retry_count: Number of installation retries if install fails
        retry_delay: Delay between retries in seconds

    Returns:
        Imported module

    Raises:
        ValueError: If input parameters are invalid
        ImportError: If package cannot be imported or installed
        TimeoutError: If installation exceeds timeout
    """
    # Validate input parameters
    if not package_name:
        raise ValueError("Package name cannot be empty")

    if installer not in ['pip', 'conda']:
        raise ValueError(f"Unsupported installer: {installer}")

    # Use package_name as install_name if not provided
    real_install_name = install_name if install_name else package_name

    # First, check if we need to install the package
    need_install = False

    # Try to import the module first
    try:
        logger.debug(f"Attempting to import {package_name}")
        module = importlib.import_module(package_name)
        logger.debug(f"Successfully imported {package_name}")

        # If we successfully imported the module, check version if specified
        if version:
            try:
                # For packages with different import and install names,
                # we need to check the install name for version info
                installed_version = metadata.version(real_install_name)
                if installed_version != version:
                    logger.warning(
                        f"Package {real_install_name} version mismatch. "
                        f"Required: {version}, Installed: {installed_version}"
                    )
                    need_install = True
            except metadata.PackageNotFoundError:
                logger.warning(f"Could not determine version for {real_install_name}")

        # If no need to reinstall for version mismatch, return the module
        if not need_install:
            return ModuleAlias(module) if alias else module

    except ImportError as import_err:
        logger.info(f"Could not import {package_name}: {str(import_err)}")
        # Check if the package is installed
        if not is_package_installed(real_install_name, version):
            need_install = True
        else:
            # If package is installed but import failed, there might be an issue with dependencies
            # or the package itself. Still, let's try to reinstall it.
            logger.warning(f"Package {real_install_name} is installed but import of {package_name} failed. "
                           f"Will attempt reinstallation.")
            need_install = True

    # Install the package if needed
    if need_install:
        logger.info(f"Installation needed for {real_install_name}")

        # Attempt installation with retries
        for attempt in range(retry_count):
            try:
                cmd = _get_install_command(installer, real_install_name, version)
                logger.info(f"Installing {real_install_name} with command: {' '.join(cmd)}")
                _execute_install_command(cmd, timeout)

                # Break out of retry loop if installation succeeds
                break

            except (ImportError, TimeoutError, subprocess.SubprocessError) as e:
                if attempt < retry_count - 1:
                    logger.warning(
                        f"Installation attempt {attempt + 1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    logger.error(f"All installation attempts failed for {real_install_name}")
                    raise ImportError(f"Failed to install {real_install_name} after {retry_count} attempts: {str(e)}")

    # Try importing after installation
    try:
        logger.debug(f"Attempting to import {package_name} after installation")
        module = importlib.import_module(package_name)
        logger.debug(f"Successfully imported {package_name}")
        return ModuleAlias(module) if alias else module
    except ImportError as e:
        error_msg = f"Failed to import {package_name} even after installation of {real_install_name}: {str(e)}"
        logger.error(error_msg)
        raise ImportError(error_msg)


def _get_install_command(installer: str, package_name: str, version: str = "") -> list:
    """
    Generate installation command based on specified installer.

    Args:
        installer: Package installer to use ('pip' or 'conda')
        package_name: Name of the package to install
        version: Required version of the package

    Returns:
        list: Command as a list of strings

    Raises:
        ValueError: If unsupported installer is specified
    """
    if installer == 'pip':
        # Use sys.executable to ensure the right Python interpreter is used
        cmd = [sys.executable, '-m', 'pip', 'install', '--upgrade']
        if version:
            cmd.append(f'{package_name}=={version}')
        else:
            cmd.append(package_name)
    elif installer == 'conda':
        if version:
            cmd = ['conda', 'install', '-y', f'{package_name}={version}']
        else:
            cmd = ['conda', 'install', '-y', package_name]
    else:
        raise ValueError(f"Unsupported installer: {installer}")

    return cmd


def _execute_install_command(cmd: list, timeout: int) -> None:
    """
    Execute package installation command.

    Args:
        cmd: Installation command as list of strings
        timeout: Installation timeout in seconds

    Raises:
        TimeoutError: If installation exceeds timeout
        ImportError: If installation fails
    """
    logger.info(f"Executing: {' '.join(cmd)}")

    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE
    )

    try:
        stdout, stderr = process.communicate(timeout=timeout)

        # Log installation output for debugging
        if stdout:
            logger.debug(f"Installation stdout: {stdout.decode()}")
        if stderr:
            logger.debug(f"Installation stderr: {stderr.decode()}")

    except subprocess.TimeoutExpired:
        process.kill()
        error_msg = f"Package installation timed out after {timeout} seconds"
        logger.error(error_msg)
        raise TimeoutError(error_msg)

    if process.returncode != 0:
        error_msg = f"Installation failed with code {process.returncode}: {stderr.decode()}"
        logger.error(error_msg)
        raise ImportError(error_msg)

    logger.info("Installation completed successfully")
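
A short usage sketch for the importer above (not part of the diff); whether an install actually runs depends on the local environment:

    from aworld.utils.import_package import import_package, import_packages

    # Import numpy, installing it with pip first if it is missing
    np = import_package("numpy")

    # Import name and PyPI name differ, so pass install_name explicitly
    cv2 = import_package("cv2", install_name="opencv-python")

    # Batch form: returns a {package_name: module} dict for everything that imported
    modules = import_packages(["requests", "numpy"])
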
aworld/utils/json_encoder.py
ADDED
@@ -0,0 +1,12 @@
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import json

import numpy as np


class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
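
A one-liner showing the encoder above in use; the `payload` dict is illustrative:

    import json
    import numpy as np
    from aworld.utils.json_encoder import NumpyEncoder

    payload = {"scores": np.array([0.1, 0.9])}
    print(json.dumps(payload, cls=NumpyEncoder))      # -> {"scores": [0.1, 0.9]}
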
aworld/utils/oss.py
ADDED
@@ -0,0 +1,628 @@
# coding: utf-8
"""
oss.py
Utility class for OSS (Object Storage Service) operations.
Provides simple methods for data operations: upload, read, delete, update.
"""
import os
import json
import tempfile
from typing import Optional, Dict, List, Any, Tuple, Union, BinaryIO, TextIO, IO, AnyStr

from aworld.utils import import_package
from aworld.logs.util import logger


class OSSClient:
    """
    A utility class for OSS (Object Storage Service) operations.
    Provides methods for data operations: upload, read, delete, update.
    """

    def __init__(self,
                 access_key_id: Optional[str] = None,
                 access_key_secret: Optional[str] = None,
                 endpoint: Optional[str] = None,
                 bucket_name: Optional[str] = None,
                 enable_export: Optional[bool] = None):
        """
        Initialize OSSClient with credentials.

        Args:
            access_key_id: OSS access key ID. If None, will try to get from environment variable OSS_ACCESS_KEY_ID
            access_key_secret: OSS access key secret. If None, will try to get from environment variable OSS_ACCESS_KEY_SECRET
            endpoint: OSS endpoint. If None, will try to get from environment variable OSS_ENDPOINT
            bucket_name: OSS bucket name. If None, will try to get from environment variable OSS_BUCKET_NAME
            enable_export: Whether to enable OSS export. If None, will try to get from environment variable EXPORT_REPLAY_TRACE_TO_OSS
        """
        self.access_key_id = access_key_id or os.getenv('OSS_ACCESS_KEY_ID')
        self.access_key_secret = access_key_secret or os.getenv('OSS_ACCESS_KEY_SECRET')
        self.endpoint = endpoint or os.getenv('OSS_ENDPOINT')
        self.bucket_name = bucket_name or os.getenv('OSS_BUCKET_NAME')
        self.enable_export = enable_export if enable_export is not None else os.getenv(
            "EXPORT_REPLAY_TRACE_TO_OSS", "false").lower() == "true"
        self.bucket = None
        self._initialized = False

    def initialize(self) -> bool:
        """
        Initialize the OSS client with the provided or environment credentials.

        Returns:
            bool: True if initialization is successful, False otherwise
        """
        if self._initialized:
            return True

        if not self.enable_export:
            logger.info("OSS export is disabled. Set EXPORT_REPLAY_TRACE_TO_OSS=true to enable.")
            return False

        if not all([self.access_key_id, self.access_key_secret, self.endpoint, self.bucket_name]):
            logger.warn(
                "Missing required OSS credentials. Please provide all required parameters or set environment variables.")
            return False

        try:
            import_package("oss2")
            import oss2
            auth = oss2.Auth(self.access_key_id, self.access_key_secret)
            self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name)
            self._initialized = True
            return True
        except ImportError:
            logger.warn("Failed to import oss2 module. Please install it with 'pip install oss2'.")
            return False
        except Exception as e:
            logger.warn(f"Failed to initialize OSS client. Error: {str(e)}")
            return False

    # ---- Basic Data Operation Methods ----

    def upload_data(self, data: Union[IO[AnyStr], str, bytes, dict], oss_key: str) -> Optional[str]:
        """
        Upload data to OSS. Supports various types of data:
        - In-memory file objects (IO[AnyStr])
        - Strings (str)
        - Bytes (bytes)
        - Dictionaries (dict), will be automatically converted to JSON
        - File paths (str)

        Args:
            data: Data to upload, can be a file object or other supported types
            oss_key: The key (path) in OSS where the data will be stored

        Returns:
            str: The OSS key if successful, None otherwise
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return None

        try:
            # Handle file objects
            if hasattr(data, 'read'):
                content = data.read()
                if isinstance(content, str):
                    content = content.encode('utf-8')
                self.bucket.put_object(oss_key, content)
                logger.info(f"Successfully uploaded memory file to OSS: {oss_key}")
                return oss_key

            # Handle dictionaries
            if isinstance(data, dict):
                content = json.dumps(data, ensure_ascii=False).encode('utf-8')
                self.bucket.put_object(oss_key, content)
                return oss_key

            # Handle strings
            if isinstance(data, str):
                # Check if it's a file path
                if os.path.isfile(data):
                    self.bucket.put_object_from_file(oss_key, data)
                    logger.info(f"Successfully uploaded file {data} to OSS: {oss_key}")
                    return oss_key
                # Otherwise treat as string content
                content = data.encode('utf-8')
                self.bucket.put_object(oss_key, content)
                return oss_key

            # Handle bytes
            self.bucket.put_object(oss_key, data)
            logger.info(f"Successfully uploaded data to OSS: {oss_key}")
            return oss_key
        except Exception as e:
            logger.warn(f"Failed to upload data to OSS: {str(e)}")
            return None

    def read_data(self, oss_key: str, as_json: bool = False) -> Union[bytes, dict, str, None]:
        """
        Read data from OSS.

        Args:
            oss_key: The key (path) in OSS of the data to read
            as_json: If True, parse the data as JSON and return a dict

        Returns:
            The data as bytes, dict (if as_json=True), or None if failed
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return None

        try:
            # Read data
            result = self.bucket.get_object(oss_key)
            data = result.read()

            # Convert to string or JSON if requested
            if as_json:
                return json.loads(data)

            return data
        except Exception as e:
            logger.warn(f"Failed to read data from OSS: {str(e)}")
            return None

    def read_text(self, oss_key: str) -> Optional[str]:
        """
        Read text data from OSS.

        Args:
            oss_key: The key (path) in OSS of the text to read

        Returns:
            str: The text data, or None if failed
        """
        data = self.read_data(oss_key)
        if data is not None:
            try:
                return data.decode('utf-8')
            except Exception as e:
                logger.warn(f"Failed to decode data as UTF-8: {str(e)}")
        return None

    def delete_data(self, oss_key: str) -> bool:
        """
        Delete data from OSS.

        Args:
            oss_key: The key (path) in OSS of the data to delete

        Returns:
            bool: True if successful, False otherwise
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return False

        try:
            self.bucket.delete_object(oss_key)
            logger.info(f"Successfully deleted data from OSS: {oss_key}")
            return True
        except Exception as e:
            logger.warn(f"Failed to delete data from OSS: {str(e)}")
            return False

    def update_data(self, oss_key: str, data: Union[IO[AnyStr], str, bytes, dict]) -> Optional[str]:
        """
        Update data in OSS (delete and upload).

        Args:
            oss_key: The key (path) in OSS of the data to update
            data: New data to upload, can be a file object or other supported types

        Returns:
            str: The OSS key if successful, None otherwise
        """
        # For OSS, update is the same as upload (it overwrites)
        return self.upload_data(data, oss_key)

    def update_json(self, oss_key: str, update_dict: dict) -> Optional[str]:
        """
        Update JSON data in OSS by merging with existing data.

        Args:
            oss_key: The key (path) in OSS of the JSON data to update
            update_dict: Dictionary with fields to update

        Returns:
            str: The OSS key if successful, None otherwise
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return None

        try:
            # Read existing data
            existing_data = self.read_data(oss_key, as_json=True)
            if existing_data is None:
                existing_data = {}

            # Update data
            if isinstance(existing_data, dict):
                existing_data.update(update_dict)
            else:
                logger.warn(f"Existing data is not a dictionary: {oss_key}")
                return None

            # Upload updated data
            return self.upload_data(existing_data, oss_key)
        except Exception as e:
            logger.warn(f"Failed to update JSON data in OSS: {str(e)}")
            return None

    # ---- File Operation Methods ----

    def upload_file(self, local_file: str, oss_key: Optional[str] = None) -> Optional[str]:
        """
        Upload a local file to OSS.

        Args:
            local_file: Path to the local file
            oss_key: The key (path) in OSS where the file will be stored.
                If None, will use the basename of the local file

        Returns:
            str: The OSS key if successful, None otherwise
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return None

        try:
            if not os.path.exists(local_file):
                logger.warn(f"Local file {local_file} does not exist")
                return None

            if oss_key is None:
                oss_key = f"uploads/{os.path.basename(local_file)}"

            self.bucket.put_object_from_file(oss_key, local_file)
            logger.info(f"Successfully uploaded {local_file} to OSS: {oss_key}")
            return oss_key
        except Exception as e:
            logger.warn(f"Failed to upload {local_file} to OSS: {str(e)}")
            return None

    def download_file(self, oss_key: str, local_file: str) -> bool:
        """
        Download a file from OSS to local.

        Args:
            oss_key: The key (path) in OSS of the file to download
            local_file: Path where the downloaded file will be saved

        Returns:
            bool: True if successful, False otherwise
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return False

        try:
            # Ensure the directory exists
            os.makedirs(os.path.dirname(os.path.abspath(local_file)), exist_ok=True)

            # Download the file
            self.bucket.get_object_to_file(oss_key, local_file)
            logger.info(f"Successfully downloaded {oss_key} to {local_file}")
            return True
        except Exception as e:
            logger.warn(f"Failed to download {oss_key} from OSS: {str(e)}")
            return False

    def list_objects(self, prefix: str = "", delimiter: str = "") -> List[Dict[str, Any]]:
        """
        List objects in the OSS bucket with the given prefix.

        Args:
            prefix: Prefix to filter objects
            delimiter: Delimiter for hierarchical listing

        Returns:
            List of objects with their properties
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return []

        try:
            result = []
            for obj in self.bucket.list_objects(prefix=prefix, delimiter=delimiter).object_list:
                result.append({
                    'key': obj.key,
                    'size': obj.size,
                    'last_modified': obj.last_modified
                })
            return result
        except Exception as e:
            logger.warn(f"Failed to list objects with prefix {prefix}: {str(e)}")
            return []

    # ---- Advanced Operation Methods ----

    def exists(self, oss_key: str) -> bool:
        """
        Check if an object exists in OSS.

        Args:
            oss_key: The key (path) in OSS to check

        Returns:
            bool: True if the object exists, False otherwise
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return False

        try:
            # Use head_object to check if the object exists
            self.bucket.head_object(oss_key)
            return True
        except:
            return False

    def copy_object(self, source_key: str, target_key: str) -> bool:
        """
        Copy an object within the same bucket.

        Args:
            source_key: The source object key
            target_key: The target object key

        Returns:
            bool: True if successful, False otherwise
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return False

        try:
            self.bucket.copy_object(self.bucket_name, source_key, target_key)
            logger.info(f"Successfully copied {source_key} to {target_key}")
            return True
        except Exception as e:
            logger.warn(f"Failed to copy {source_key} to {target_key}: {str(e)}")
            return False

    def get_object_url(self, oss_key: str, expires: int = 3600) -> Optional[str]:
        """
        Generate a temporary URL for accessing an object.

        Args:
            oss_key: The key (path) in OSS of the object
            expires: URL expiration time in seconds

        Returns:
            str: The signed URL, or None if failed
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return None

        try:
            url = self.bucket.sign_url('GET', oss_key, expires)
            return url
        except Exception as e:
            logger.warn(f"Failed to generate URL for {oss_key}: {str(e)}")
            return None

    def upload_directory(self, local_dir: str, oss_prefix: str = "") -> Tuple[bool, List[str]]:
        """
        Upload an entire directory to OSS.

        Args:
            local_dir: Path to the local directory
            oss_prefix: Prefix to prepend to all uploaded files

        Returns:
            Tuple of (success, list of uploaded files)
        """
        if not self.initialize():
            logger.warn("OSS client not initialized or export is disabled")
            return False, []

        if not os.path.isdir(local_dir):
            logger.warn(f"Local directory {local_dir} does not exist or is not a directory")
            return False, []

        uploaded_files = []
        errors = []

        for root, _, files in os.walk(local_dir):
            for file in files:
                local_file = os.path.join(root, file)
                rel_path = os.path.relpath(local_file, local_dir)
                oss_key = os.path.join(oss_prefix, rel_path).replace("\\", "/")

                result = self.upload_file(local_file, oss_key)
                if result:
                    uploaded_files.append(result)
                else:
                    errors.append(local_file)

        if errors:
            logger.warn(f"Failed to upload {len(errors)} files")
            return False, uploaded_files
        return True, uploaded_files


def get_oss_client(access_key_id: Optional[str] = None,
                   access_key_secret: Optional[str] = None,
                   endpoint: Optional[str] = None,
                   bucket_name: Optional[str] = None,
                   enable_export: Optional[bool] = None) -> OSSClient:
    """
    Factory function to create and initialize an OSSClient.

    Args:
        access_key_id: OSS access key ID
        access_key_secret: OSS access key secret
        endpoint: OSS endpoint
        bucket_name: OSS bucket name
        enable_export: Whether to enable OSS export

    Returns:
        OSSClient: An initialized OSSClient instance
    """
    client = OSSClient(
        access_key_id=access_key_id,
        access_key_secret=access_key_secret,
        endpoint=endpoint,
        bucket_name=bucket_name,
        enable_export=enable_export
    )
    client.initialize()
    return client


# ---- Test Cases ----
if __name__ == "__main__":
    """
    OSS tool class test cases
    Note: Before running the tests, you need to set the following environment variables,
    or provide the parameters directly in the test code:
    - OSS_ACCESS_KEY_ID
    - OSS_ACCESS_KEY_SECRET
    - OSS_ENDPOINT
    - OSS_BUCKET_NAME
    - EXPORT_REPLAY_TRACE_TO_OSS=true
    """
    import io
    import time

    # Test configuration
    TEST_PREFIX = f"test/oss_utils_123"  # Use timestamp to avoid conflicts

    # Initialize client
    # Method 1: Using environment variables
    # oss_client = get_oss_client(enable_export=True)

    # Method 2: Provide parameters directly
    oss_client = get_oss_client(
        access_key_id="",  # Replace with your actual access key ID
        access_key_secret="",  # Replace with your actual access key secret
        endpoint="",  # Replace with your actual OSS endpoint
        bucket_name="",  # Replace with your actual bucket name
        enable_export=True
    )

    # Test 1: Upload string data
    print("\nTest 1: Upload string data")
    text_key = f"{TEST_PREFIX}/text.txt"
    result = oss_client.upload_data("This is a test text", text_key)
    print(f"Upload string data: {'Success: ' + result if result else 'Failed'}")

    # Test 2: Upload dictionary data (automatically converted to JSON)
    print("\nTest 2: Upload dictionary data")
    json_key = f"{TEST_PREFIX}/data.json"
    data = {
        "name": "Test data",
        "values": [1, 2, 3],
        "nested": {
            "key": "value"
        }
    }
    result = oss_client.upload_data(data, json_key)
    print(f"Upload dictionary data: {'Success: ' + result if result else 'Failed'}")

    # Test 3: Upload in-memory binary file object
    print("\nTest 3: Upload in-memory binary file object")
    binary_key = f"{TEST_PREFIX}/binary.dat"
    binary_data = io.BytesIO(b"\x00\x01\x02\x03\x04")
    result = oss_client.upload_data(binary_data, binary_key)
    print(f"Upload binary file object: {'Success: ' + result if result else 'Failed'}")

    # Test 4: Upload in-memory text file object
    print("\nTest 4: Upload in-memory text file object")
    text_file_key = f"{TEST_PREFIX}/text_file.txt"
    text_file = io.StringIO("This is the content of an in-memory text file")
    result = oss_client.upload_data(text_file, text_file_key)
    print(f"Upload text file object: {'Success: ' + result if result else 'Failed'}")

    # Test 5: Create and upload temporary file
    print("\nTest 5: Create and upload temporary file")
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"This is the content of a temporary file")
        tmp_path = tmp.name

    file_key = f"{TEST_PREFIX}/temp_file.txt"
    result = oss_client.upload_file(tmp_path, file_key)
    print(f"Upload temporary file: {'Success: ' + result if result else 'Failed'}")
    os.unlink(tmp_path)  # Delete temporary file

    # Test 6: Read text data
    print("\nTest 6: Read text data")
    content = oss_client.read_text(text_key)
    print(f"Read text data: {content}")

    # Test 7: Read JSON data
    print("\nTest 7: Read JSON data")
    json_content = oss_client.read_data(json_key, as_json=True)
    print(f"Read JSON data: {json_content}")

    # Test 8: Update JSON data (merge method)
    print("\nTest 8: Update JSON data")
    update_data = {"updated": True, "timestamp": time.time()}
    result = oss_client.update_json(json_key, update_data)
    print(f"Update JSON data: {'Success: ' + result if result else 'Failed'}")

    # View updated JSON data
    updated_json = oss_client.read_data(json_key, as_json=True)
    print(f"Updated JSON data: {updated_json}")

    # Test 9: Overwrite existing data
    print("\nTest 9: Overwrite existing data")
    result = oss_client.upload_data("This is the overwritten text", text_key)
    print(f"Overwrite existing data: {'Success: ' + result if result else 'Failed'}")

    # View overwritten data
    new_content = oss_client.read_text(text_key)
    print(f"Overwritten text data: {new_content}")

    # Test 10: List objects
    print("\nTest 10: List objects")
    objects = oss_client.list_objects(prefix=TEST_PREFIX)
    print(f"Found {len(objects)} objects:")
    for obj in objects:
        print(f"  - {obj['key']} (Size: {obj['size']} bytes, Modified: {obj['last_modified']})")

    # Test 11: Generate temporary URL
    print("\nTest 11: Generate temporary URL")
    url = oss_client.get_object_url(text_key, expires=300)  # 5 minutes expiration
    print(f"Temporary URL: {url}")

    # Test 12: Copy object
    print("\nTest 12: Copy object")
    copy_key = f"{TEST_PREFIX}/copy_of_text.txt"
    result = oss_client.copy_object(text_key, copy_key)
    print(f"Copy object: {'Success: ' + copy_key if result else 'Failed'}")

    # Test 13: Check if object exists
    print("\nTest 13: Check if object exists")
    exists = oss_client.exists(text_key)
    print(f"Object {text_key} exists: {exists}")

    non_existent_key = f"{TEST_PREFIX}/non_existent.txt"
    exists = oss_client.exists(non_existent_key)
    print(f"Object {non_existent_key} exists: {exists}")

    # Test 14: Delete objects
    print("\nTest 14: Delete objects")
    for obj in objects:
        success = oss_client.delete_data(obj['key'])
        print(f"Delete object {obj['key']}: {'Success' if success else 'Failed'}")

    # Cleanup: Delete copied object (may not be included in the previous list)
    oss_client.delete_data(copy_key)

    print("\nTests completed!")
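
Besides the parameter-based setup in the test block above, the client can be driven entirely by environment variables (Method 1 in the comments); a minimal sketch with placeholder values only:

    import os
    from aworld.utils.oss import get_oss_client

    # Placeholder credentials; substitute your real OSS account values
    os.environ.update({
        "OSS_ACCESS_KEY_ID": "your-key-id",
        "OSS_ACCESS_KEY_SECRET": "your-key-secret",
        "OSS_ENDPOINT": "https://oss-cn-hangzhou.aliyuncs.com",
        "OSS_BUCKET_NAME": "your-bucket",
        "EXPORT_REPLAY_TRACE_TO_OSS": "true",
    })

    client = get_oss_client()
    key = client.upload_data({"run_id": 1}, "traces/run_1.json")
    print(client.read_data(key, as_json=True))        # -> {'run_id': 1}
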