# NOTE: scrape residue from the hosting page ("Spaces: Sleeping"); not part of the module.
"""
Main Logo Downloader class that orchestrates the entire process.
"""
import os | |
import zipfile | |
import logging | |
from pathlib import Path | |
from typing import List, Tuple, Dict, Optional | |
from services.appconfig import DOWNLOADS_DIR, DEFAULT_LOGOS_PER_ENTITY | |
from utils.utils import create_safe_filename, create_directory, format_file_size | |
from .entity_extractor import EntityExtractor | |
from .image_downloader import ImageDownloader | |
# Module-level logger, named after this module so log records can be filtered by origin.
logger = logging.getLogger(__name__)
class LogoDownloader:
    """Main class for downloading logos based on extracted entities.

    Workflow (see ``process_text``): extract entity names from free text,
    download logos for each entity into its own sub-folder of ``output_dir``,
    then bundle the whole output directory into a single ZIP archive.

    Attributes:
        output_dir (Path): Root directory that receives one folder per entity.
        entity_extractor (EntityExtractor): Extracts entity names from text.
        image_downloader (ImageDownloader): Fetches logo images per entity.
        stats (dict): Counters for the most recent ``process_text`` run.
    """

    def __init__(self, gemini_api_key: str, output_dir: Optional[str] = None):
        """
        Initialize LogoDownloader

        Args:
            gemini_api_key (str): Gemini API key for entity extraction
            output_dir (str): Directory to save downloads; falls back to
                DOWNLOADS_DIR when empty or None
        """
        # Coerce the fallback too: the rest of the class relies on Path
        # operations (`/`, `.name`, `.parent`) on self.output_dir.
        self.output_dir = Path(output_dir) if output_dir else Path(DOWNLOADS_DIR)
        self.entity_extractor = EntityExtractor(gemini_api_key)
        self.image_downloader = ImageDownloader()
        self._reset_stats()

        # Create output directory. A falsy result is treated as failure, the
        # same way _process_single_entity interprets it; here it is only
        # reported so that construction keeps succeeding as before.
        if not create_directory(self.output_dir):
            logger.warning(f"Could not create output directory: {self.output_dir}")

    @staticmethod
    def _empty_stats() -> Dict:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            'total_entities': 0,
            'total_downloads': 0,
            'successful_entities': 0,
            'failed_entities': 0,
        }

    def process_text(self, text: str, logos_per_entity: int = DEFAULT_LOGOS_PER_ENTITY) -> Dict:
        """
        Main processing function: extract entities and download logos

        Args:
            text (str): Input text containing entity references
            logos_per_entity (int): Number of logos to download per entity

        Returns:
            Dict: Processing results and statistics (see ``_get_results``).
                When entities were found it also carries ``entities``,
                ``results`` (one dict per entity) and ``zip_path`` (None when
                nothing was downloaded).

        Raises:
            Exception: Whatever ``_create_zip_package`` raises; per-entity
                failures are recorded in ``results`` instead of raised.
        """
        logger.info("Starting logo download process...")
        self._reset_stats()

        entities = self.entity_extractor.extract_entities(text)
        if not entities:
            logger.warning("No entities found in text")
            return self._get_results("No entities found in the provided text")

        total = len(entities)
        self.stats['total_entities'] = total
        logger.info(f"Found {total} entities: {', '.join(entities)}")

        results = []
        for i, entity in enumerate(entities, 1):
            logger.info(f"Processing [{i}/{total}]: {entity}")
            try:
                result = self._process_single_entity(entity, logos_per_entity)
                downloaded = result['downloaded_count']
                succeeded = downloaded > 0
            except Exception as e:
                # Boundary handler: one bad entity must not abort the run.
                # logger.exception keeps the traceback for diagnosis.
                logger.exception(f"Failed to process entity {entity}: {e}")
                self.stats['failed_entities'] += 1
                results.append({
                    'entity': entity,
                    'downloaded_count': 0,
                    'files': [],
                    'error': str(e),
                })
                continue

            # Append only after the result has been validated above, so a
            # malformed result can never yield two entries for one entity.
            results.append(result)
            if succeeded:
                self.stats['successful_entities'] += 1
                self.stats['total_downloads'] += downloaded
            else:
                self.stats['failed_entities'] += 1

        # Create zip package if we have downloads
        zip_path = None
        if self.stats['total_downloads'] > 0:
            zip_path = self._create_zip_package()

        return self._get_results(
            "Processing completed successfully",
            entities=entities,
            results=results,
            zip_path=zip_path,
        )

    def _process_single_entity(self, entity: str, logos_per_entity: int) -> Dict:
        """
        Process a single entity: create folder and download logos

        Args:
            entity (str): Entity name
            logos_per_entity (int): Number of logos to download

        Returns:
            Dict: Processing result for this entity with keys ``entity``,
                ``safe_name``, ``downloaded_count``, ``files`` and ``folder``

        Raises:
            OSError: If the entity folder could not be created.
        """
        safe_name = create_safe_filename(entity)
        entity_folder = self.output_dir / safe_name

        if not create_directory(entity_folder):
            # OSError is an Exception subclass, so the handler in
            # process_text keeps catching it.
            raise OSError(f"Failed to create directory for {entity}")

        downloaded_count, downloaded_files = self.image_downloader.download_logos_for_entity(
            entity, str(entity_folder), logos_per_entity
        )

        return {
            'entity': entity,
            'safe_name': safe_name,
            'downloaded_count': downloaded_count,
            'files': downloaded_files,
            'folder': str(entity_folder),
        }

    def _create_zip_package(self) -> str:
        """
        Create ZIP package of all downloaded logos

        The archive is named ``<output_dir name>_logos.zip``, is written next
        to ``output_dir`` (in its parent) and contains every file found under
        ``output_dir``, stored with paths relative to ``output_dir``.

        Returns:
            str: Path to created ZIP file

        Raises:
            Exception: Any error raised while writing the archive is logged
                and re-raised unchanged.
        """
        zip_path = self.output_dir.parent / f"{self.output_dir.name}_logos.zip"
        logger.info(f"Creating ZIP package: {zip_path}")

        try:
            with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                for root, _dirs, files in os.walk(self.output_dir):
                    for file in files:
                        file_path = os.path.join(root, file)
                        # output_dir.parent can equal output_dir (e.g. "." or
                        # a filesystem root); the archive then lives inside
                        # the tree being walked and must not be added to
                        # itself.
                        if Path(file_path) == zip_path:
                            continue
                        arcname = os.path.relpath(file_path, self.output_dir)
                        zipf.write(file_path, arcname)

            file_size = os.path.getsize(zip_path)
            logger.info(f"ZIP package created: {zip_path} ({format_file_size(file_size)})")
            return str(zip_path)
        except Exception as e:
            logger.error(f"Failed to create ZIP package: {e}")
            raise

    def _reset_stats(self) -> None:
        """Reset processing statistics"""
        self.stats = self._empty_stats()

    def _get_results(self, message: str, **kwargs) -> Dict:
        """
        Get formatted results dictionary

        Args:
            message (str): Status message
            **kwargs: Additional result data, merged into the top level

        Returns:
            Dict: Formatted results; ``status`` is 'success' when at least one
                logo was downloaded in the current run, otherwise 'warning'.
                ``stats`` is a copy, so later runs do not mutate it.
        """
        status = 'success' if self.stats['total_downloads'] > 0 else 'warning'
        return {
            'status': status,
            'message': message,
            'stats': self.stats.copy(),
            **kwargs,
        }

    def get_stats_summary(self) -> str:
        """
        Get human-readable stats summary

        Returns:
            str: Stats summary for the most recent run; the average is
                computed over successful entities only.
        """
        total_entities = self.stats['total_entities']
        if total_entities == 0:
            return "No entities processed"

        successful = self.stats['successful_entities']
        total_downloads = self.stats['total_downloads']
        # Guard against division by zero when every entity failed.
        avg_downloads = total_downloads / successful if successful > 0 else 0

        return (
            f"Processed {total_entities} entities. "
            f"Successfully downloaded {total_downloads} logos "
            f"({avg_downloads:.1f} average per entity). "
            f"Success rate: {successful}/{total_entities}"
        )
def download_logos(text: str, gemini_api_key: str, logos_per_entity: int = DEFAULT_LOGOS_PER_ENTITY) -> Dict:
    """
    One-shot helper: build a LogoDownloader with the default output
    directory and run it on the given text.

    Args:
        text (str): Text containing entity references
        gemini_api_key (str): Gemini API key
        logos_per_entity (int): Number of logos per entity

    Returns:
        Dict: Processing results, exactly as returned by
            LogoDownloader.process_text
    """
    return LogoDownloader(gemini_api_key).process_text(text, logos_per_entity)