Spaces:
Running
Running
import ast | |
from abc import ABC, abstractmethod | |
from app.config import config | |
from app.models import const | |
# Base class for state management
class BaseState(ABC):
    """Interface every task-state backend must implement.

    A backend keeps one record per ``task_id``; the record holds a ``state``
    code, an integer ``progress`` and any extra fields passed as keyword
    arguments to :meth:`update_task`.
    """

    @abstractmethod
    def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
        """Create or update the record for ``task_id``.

        Args:
            task_id: Identifier of the task.
            state: Task state code.
            progress: Completion percentage.
            **kwargs: Extra fields to store with the record.
        """

    @abstractmethod
    def get_task(self, task_id: str):
        """Return the record for ``task_id``, or ``None`` if there is none."""

    @abstractmethod
    def delete_task(self, task_id: str):
        """Remove the record for ``task_id`` if one exists."""
# Memory state management
class MemoryState(BaseState):
    """Task-state backend that keeps records in an in-process dictionary.

    Records live only as long as this instance and are not shared with
    other processes.
    """

    def __init__(self):
        # task_id -> {"state": ..., "progress": ..., **extra fields}
        self._tasks = {}

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Store a new record for ``task_id``.

        The new record replaces whatever was stored under ``task_id`` before;
        fields from earlier calls that are not passed again are dropped.

        Args:
            task_id: Identifier of the task.
            state: Task state code; defaults to ``const.TASK_STATE_PROCESSING``.
            progress: Completion percentage. It is coerced with ``int()`` and
                clamped into the range 0..100.
            **kwargs: Extra fields stored alongside ``state`` and ``progress``.
        """
        # Clamp at both ends so out-of-range inputs never reach readers.
        progress = max(0, min(int(progress), 100))
        self._tasks[task_id] = {
            "state": state,
            "progress": progress,
            **kwargs,
        }

    def get_task(self, task_id: str):
        """Return the record stored for ``task_id``, or ``None`` if absent."""
        return self._tasks.get(task_id)

    def delete_task(self, task_id: str):
        """Remove the record for ``task_id``; unknown ids are ignored."""
        self._tasks.pop(task_id, None)
# Redis state management
class RedisState(BaseState):
    """Task-state backend that stores each task as a Redis hash keyed by ``task_id``.

    Every field value is written as ``str(value)`` and converted back on read
    by ``_convert_to_original_type``. Only values whose ``str()`` form is a
    Python literal (ints, lists, dicts, ...) come back as their original
    type; anything else comes back as a string.
    """

    def __init__(self, host="localhost", port=6379, db=0, password=None):
        # Imported here rather than at module level, so the redis package is
        # only needed when a RedisState is actually constructed.
        import redis

        self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)

    def update_task(
        self,
        task_id: str,
        state: int = const.TASK_STATE_PROCESSING,
        progress: int = 0,
        **kwargs,
    ):
        """Write ``state``, ``progress`` and any extra fields into the task hash.

        Fields are merged into the existing hash. Fields not passed in this
        call keep their previously stored values.

        Args:
            task_id: Identifier of the task (used as the Redis key).
            state: Task state code; defaults to ``const.TASK_STATE_PROCESSING``.
            progress: Completion percentage. It is coerced with ``int()`` and
                clamped into the range 0..100.
            **kwargs: Extra fields to store; each value is saved as ``str(value)``.
        """
        # Clamp at both ends so out-of-range inputs never reach readers.
        progress = max(0, min(int(progress), 100))
        fields = {
            "state": state,
            "progress": progress,
            **kwargs,
        }
        # A single HSET with all fields means one round-trip, and readers
        # never see a half-applied update. The mapping= keyword needs
        # redis-py >= 3.5.
        self._redis.hset(
            task_id,
            mapping={field: str(value) for field, value in fields.items()},
        )

    def get_task(self, task_id: str):
        """Return the task hash with values converted back to Python objects.

        Returns ``None`` when no hash exists for ``task_id``.
        """
        task_data = self._redis.hgetall(task_id)
        if not task_data:
            return None
        # The client is created without decode_responses, so keys and values
        # arrive as bytes.
        return {
            key.decode("utf-8"): self._convert_to_original_type(value)
            for key, value in task_data.items()
        }

    def delete_task(self, task_id: str):
        """Delete the hash stored for ``task_id``; unknown ids are ignored."""
        self._redis.delete(task_id)

    @staticmethod
    def _convert_to_original_type(value):
        """Convert a stored byte string back to the value it was written from.

        The conversion is tried in this order:

        1. ``ast.literal_eval``, which handles numbers, lists, dicts, tuples,
           booleans and ``None``. It only evaluates literals and does not run
           arbitrary code.
        2. A plain decimal integer such as ``"007"``, which ``literal_eval``
           rejects because of the leading zeros.
        3. Otherwise, the decoded string unchanged.

        Extend this method to handle other data types as needed.
        """
        value_str = value.decode("utf-8")
        try:
            # e.g. turns the stored text "[1, 2]" back into a list
            return ast.literal_eval(value_str)
        except (ValueError, SyntaxError, TypeError):
            # ValueError/SyntaxError: not a literal (e.g. a plain word).
            # TypeError: a literal that cannot be built, e.g. "{[1]: 2}".
            pass
        # isdecimal() rather than isdigit(): isdigit() accepts characters such
        # as "²" that int() cannot parse.
        if value_str.isdecimal():
            return int(value_str)
        # Add more conversions here if needed
        return value_str
# Global state: the backend is chosen once, at import time, from config.app.
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)

if _enable_redis:
    # Redis-backed store; constructing it imports the redis package.
    state = RedisState(
        host=_redis_host,
        port=_redis_port,
        db=_redis_db,
        password=_redis_password,
    )
else:
    # Default: records kept in a dictionary inside this process.
    state = MemoryState()