""" Entity extraction module using Gemini AI with fallback methods """ import re import logging from typing import List, Optional import google.generativeai as genai from services.appconfig import GEMINI_API_KEY, COMMON_TECH_ENTITIES, MAX_ENTITIES logger = logging.getLogger(__name__) class EntityExtractor: """Extract entities from text using Gemini AI or fallback methods""" def __init__(self, api_key: Optional[str] = None): """ Initialize EntityExtractor Args: api_key (str, optional): Gemini API key """ self.api_key = api_key or GEMINI_API_KEY self.model = None self._setup_gemini() def _setup_gemini(self) -> None: """Setup Gemini API""" if not self.api_key: logger.warning("No Gemini API key provided, using fallback method") return try: genai.configure(api_key=self.api_key) self.model = genai.GenerativeModel('gemini-2.0-flash-exp') logger.info("Gemini API initialized successfully") except Exception as e: logger.error(f"Failed to initialize Gemini API: {e}") self.model = None def extract_with_gemini(self, text: str) -> List[str]: """ Extract entities using Gemini AI Args: text (str): Input text Returns: List[str]: List of extracted entities """ if not self.model: raise Exception("Gemini model not available") prompt = """ Extract company names, product names, software names, tool names, and brand names from this text. Only return names that would have recognizable logos (like Microsoft, Adobe, React, etc.). Return as a simple list, one name per line, no bullet points or numbers. Avoid generic terms like "cloud" or "database". Text: {text} """.format(text=text) try: response = self.model.generate_content(prompt) if not response.text: return [] entities = [ line.strip() for line in response.text.strip().split('\n') if line.strip() and not line.strip().startswith('-') and len(line.strip()) > 1 ] # Filter out common words that aren't entities filtered_entities = [] for entity in entities: if self._is_valid_entity(entity): filtered_entities.append(entity) logger.info(f"Gemini extracted {len(filtered_entities)} entities") return filtered_entities[:MAX_ENTITIES] except Exception as e: logger.error(f"Gemini extraction failed: {e}") raise def extract_with_fallback(self, text: str) -> List[str]: """ Extract entities using fallback pattern matching Args: text (str): Input text Returns: List[str]: List of extracted entities """ entities = [] # Find common tech entities for tech_entity in COMMON_TECH_ENTITIES: if tech_entity.lower() in text.lower(): entities.append(tech_entity) # Find capitalized words (likely proper nouns) cap_words = re.findall(r'\b[A-Z][a-zA-Z]{2,}\b', text) for word in cap_words: if self._is_valid_entity(word) and word not in entities: entities.append(word) # Find words with specific patterns (e.g., Node.js, C++) pattern_words = re.findall(r'\b[A-Z][a-zA-Z]*\.[a-zA-Z]+\b', text) for word in pattern_words: if word not in entities: entities.append(word) # Remove duplicates while preserving order unique_entities = [] seen = set() for entity in entities: if entity.lower() not in seen: seen.add(entity.lower()) unique_entities.append(entity) logger.info(f"Fallback extracted {len(unique_entities)} entities") return unique_entities[:MAX_ENTITIES] def _is_valid_entity(self, entity: str) -> bool: """ Check if entity is valid for logo extraction Args: entity (str): Entity name Returns: bool: True if valid entity """ # Filter out common words that aren't brand names invalid_words = { 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'from', 'up', 'about', 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'between', 'among'} # 'cloud', 'database', # 'server', 'client', 'user', 'admin', 'data', 'system', 'network', # 'security', 'management', 'development', 'application', 'platform', # 'service', 'solution', 'technology', 'software', 'hardware', 'tool' # } entity_lower = entity.lower() # Check length if len(entity) < 2 or len(entity) > 50: return False # Check if it's a common invalid word if entity_lower in invalid_words: return False # Must contain at least one letter if not re.search(r'[a-zA-Z]', entity): return False return True def extract_entities(self, text: str) -> List[str]: """ Extract entities from text using available methods Args: text (str): Input text Returns: List[str]: List of extracted entities """ if not text or not text.strip(): return [] logger.info("Starting entity extraction...") # Try Gemini first if self.model: try: entities = self.extract_with_gemini(text) if entities: logger.info(f"Successfully extracted {len(entities)} entities with Gemini") return entities except Exception as e: logger.warning(f"Gemini extraction failed, using fallback: {e}") # Use fallback method entities = self.extract_with_fallback(text) logger.info(f"Extracted {len(entities)} entities using fallback method") return entities