yonnel
Refactor vector saving to write metadata via a temporary file with appropriate permissions; improve error handling during cleanup
5e003fe
from datasets import Dataset, load_dataset
from huggingface_hub import HfApi, create_repo
import numpy as np
import json
import logging
from typing import Dict, List, Tuple, Optional
import os
from datetime import datetime
import tempfile
logger = logging.getLogger(__name__)
class HFVectorStorage:
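    """Persists movie embedding vectors and their metadata in a private Hugging Face dataset repository."""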
def __init__(self):
self.hf_token = os.getenv('HF_TOKEN')
self.repo_name = os.getenv('HF_DATASET_REPO')
# Configure HF cache directory to a writable location
self._setup_hf_cache()
if self.hf_token and self.repo_name:
self.api = HfApi(token=self.hf_token)
            # Create the repo if it does not already exist
try:
create_repo(
repo_id=self.repo_name,
repo_type="dataset",
token=self.hf_token,
private=True,
exist_ok=True
)
except Exception as e:
logger.warning(f"Repo creation warning: {e}")
else:
self.api = None
logger.warning("HF_TOKEN or HF_DATASET_REPO not configured")
def _setup_hf_cache(self):
"""Setup HF cache directory to avoid permission issues"""
try:
# Try to use a writable cache directory
cache_dirs = [
os.getenv('HF_HOME'),
os.getenv('XDG_CACHE_HOME'),
os.path.expanduser('~/.cache/huggingface'),
'/tmp/hf_cache',
                os.path.join(tempfile.gettempdir(), 'hf_cache')
]
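            # Candidates are tried in order; the first directory that accepts a test write wins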
for cache_dir in cache_dirs:
if cache_dir:
try:
os.makedirs(cache_dir, exist_ok=True)
# Test write permission
test_file = os.path.join(cache_dir, 'test_write')
with open(test_file, 'w') as f:
f.write('test')
os.remove(test_file)
# Set environment variables for HF
os.environ['HF_HOME'] = cache_dir
os.environ['HUGGINGFACE_HUB_CACHE'] = cache_dir
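                        # huggingface_hub and datasets resolve their cache location from these variables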
logger.info(f"Using HF cache directory: {cache_dir}")
return
except (OSError, PermissionError):
continue
logger.warning("Could not find writable cache directory, using default")
except Exception as e:
logger.warning(f"Error setting up HF cache: {e}")
def save_vectors(self, embeddings: np.ndarray, movies_data: List[Dict],
id_map: Dict, metadata: Dict) -> bool:
"""Sauvegarde les vecteurs sur HF Dataset Hub"""
try:
if not self.hf_token or not self.repo_name:
logger.error("HF_TOKEN or HF_DATASET_REPO not configured")
return False
            # Prepare the columns for the dataset
dataset_dict = {
'movie_id': [movie['id'] for movie in movies_data],
'title': [movie['title'] for movie in movies_data],
'overview': [movie.get('overview', '') for movie in movies_data],
'genres': [movie.get('genres', []) for movie in movies_data],
'release_date': [movie.get('release_date', '') for movie in movies_data],
'embedding': embeddings.tolist(),
'tmdb_data': [json.dumps(movie) for movie in movies_data]
}
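            # One row per movie; the full TMDB record is kept as JSON in 'tmdb_data' so load_vectors
            # can rebuild movies_data (and the id -> row index map) without extra API calls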
            # Build the dataset from the columns above
dataset = Dataset.from_dict(dataset_dict)
            # Push the dataset to the HF Hub
dataset.push_to_hub(
self.repo_name,
token=self.hf_token,
commit_message=f"Update vectors - {datetime.now().isoformat()}"
)
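            # push_to_hub publishes the rows as the default 'train' split, which load_vectors reads back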
            # Save the metadata via a temporary file in a writable directory
metadata_with_timestamp = {
**metadata,
'last_updated': datetime.now().isoformat(),
'total_movies': len(movies_data)
}
            # Use a temporary directory with appropriate permissions
temp_file = os.path.join(tempfile.gettempdir(), f'karl_metadata_{os.getpid()}.json')
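            # The PID suffix avoids collisions if several processes write metadata at the same time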
try:
with open(temp_file, 'w') as f:
json.dump(metadata_with_timestamp, f, indent=2)
self.api.upload_file(
path_or_fileobj=temp_file,
path_in_repo='metadata.json',
repo_id=self.repo_name,
repo_type='dataset',
token=self.hf_token,
commit_message=f"Update metadata - {datetime.now().isoformat()}"
)
logger.info(f"Successfully saved {len(movies_data)} movie vectors to HF Hub")
return True
finally:
                # Clean up the temporary file
try:
if os.path.exists(temp_file):
os.remove(temp_file)
except Exception as cleanup_error:
logger.warning(f"Could not remove temp file {temp_file}: {cleanup_error}")
except Exception as e:
logger.error(f"Error saving vectors to HF Hub: {e}")
return False
def load_vectors(self) -> Optional[Tuple[np.ndarray, List[Dict], Dict, Dict]]:
"""Charge les vecteurs depuis HF Dataset Hub"""
try:
if not self.hf_token or not self.repo_name:
logger.error("HF_TOKEN or HF_DATASET_REPO not configured")
return None
            # Load the dataset
dataset = load_dataset(self.repo_name, token=self.hf_token)['train']
            # Extract the embeddings, the movie payloads, and the id -> row index map
embeddings = np.array(dataset['embedding'])
movies_data = []
id_map = {}
for i, movie_id in enumerate(dataset['movie_id']):
movie_data = json.loads(dataset['tmdb_data'][i])
movies_data.append(movie_data)
id_map[movie_id] = i
            # Load the metadata file
try:
metadata_file = self.api.hf_hub_download(
repo_id=self.repo_name,
filename='metadata.json',
repo_type='dataset',
token=self.hf_token
)
with open(metadata_file, 'r') as f:
metadata = json.load(f)
            except Exception as e:
                logger.warning(f"Could not load metadata.json, using defaults: {e}")
                metadata = {'last_updated': None}
logger.info(f"Successfully loaded {len(movies_data)} movie vectors from HF Hub")
return embeddings, movies_data, id_map, metadata
except Exception as e:
logger.error(f"Error loading vectors from HF Hub: {e}")
return None
def check_update_needed(self) -> bool:
"""Vérifie si une mise à jour est nécessaire"""
try:
update_interval = int(os.getenv('UPDATE_INTERVAL_HOURS', 24))
            # Load the current metadata
result = self.load_vectors()
if not result:
return True
_, _, _, metadata = result
if not metadata.get('last_updated'):
return True
last_update = datetime.fromisoformat(metadata['last_updated'])
hours_since_update = (datetime.now() - last_update).total_seconds() / 3600
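            # Refresh once the last update is older than UPDATE_INTERVAL_HOURS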
return hours_since_update >= update_interval
except Exception as e:
logger.error(f"Error checking update status: {e}")
return True
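# Example usage (illustrative sketch; assumes HF_TOKEN and HF_DATASET_REPO are set in the
# environment and that embeddings, movies_data, id_map and metadata come from the indexing pipeline):
#
#   storage = HFVectorStorage()
#   if storage.check_update_needed():
#       storage.save_vectors(embeddings, movies_data, id_map, metadata)
#   loaded = storage.load_vectors()
#   if loaded is not None:
#       embeddings, movies_data, id_map, metadata = loaded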