Spaces:
Sleeping
Sleeping
""" | |
Entity extraction module using Gemini AI with fallback methods | |
""" | |
import re | |
import logging | |
from typing import List, Optional | |
import google.generativeai as genai | |
from services.appconfig import GEMINI_API_KEY, COMMON_TECH_ENTITIES, MAX_ENTITIES | |
logger = logging.getLogger(__name__) | |
class EntityExtractor:
    """Extract entity names (companies, products, tools) from text.

    Gemini is used when an API key is available and the client initializes;
    otherwise, or when Gemini raises or returns nothing, extraction falls
    back to keyword and regex matching.
    """

    # Model requested from Gemini when the caller does not choose one.
    DEFAULT_MODEL_NAME = 'gemini-2.0-flash-exp'

    # Function words that are never entity names (compared lower-cased).
    _INVALID_WORDS = frozenset({
        'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
        'by', 'from', 'up', 'about', 'into', 'through', 'during', 'before',
        'after', 'above', 'below', 'between', 'among',
    })

    # Inclusive bounds on the length of an accepted entity name.
    _MIN_ENTITY_LENGTH = 2
    _MAX_ENTITY_LENGTH = 50

    # Leading list decoration ("- ", "* ", bullet, "1. ", "2) ") that the model
    # may emit even though the prompt asks for a plain list.
    _LIST_MARKER_RE = re.compile(r'^(?:[-*\u2022]+\s*|\d+[.)]\s+)')
    # Capitalized words of three or more letters (likely proper nouns).
    _CAPITALIZED_WORD_RE = re.compile(r'\b[A-Z][a-zA-Z]{2,}\b')
    # Capitalized dotted names such as "Node.js".
    _DOTTED_NAME_RE = re.compile(r'\b[A-Z][a-zA-Z]*\.[a-zA-Z]+\b')
    _HAS_LETTER_RE = re.compile(r'[a-zA-Z]')

    def __init__(self, api_key: Optional[str] = None,
                 model_name: Optional[str] = None):
        """
        Initialize EntityExtractor

        Args:
            api_key (str, optional): Gemini API key. Defaults to GEMINI_API_KEY.
            model_name (str, optional): Gemini model to use. Defaults to
                DEFAULT_MODEL_NAME.
        """
        self.api_key = api_key or GEMINI_API_KEY
        self.model_name = model_name or self.DEFAULT_MODEL_NAME
        self.model = None
        self._setup_gemini()

    def _setup_gemini(self) -> None:
        """Configure the Gemini client; leave ``self.model`` as None on failure."""
        if not self.api_key:
            logger.warning("No Gemini API key provided, using fallback method")
            return
        try:
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel(self.model_name)
            logger.info("Gemini API initialized successfully")
        except Exception as e:
            # Best effort: an unusable client just means the fallback is used.
            logger.error("Failed to initialize Gemini API: %s", e)
            self.model = None

    def extract_with_gemini(self, text: str) -> List[str]:
        """
        Extract entities using Gemini AI

        Args:
            text (str): Input text

        Returns:
            List[str]: Valid, case-insensitively unique entities in the order
                returned by the model, capped at MAX_ENTITIES

        Raises:
            RuntimeError: If no Gemini model was initialized.
            Exception: Any error raised while calling Gemini or reading its
                response is logged and re-raised.
        """
        if not self.model:
            raise RuntimeError("Gemini model not available")
        prompt = """
        Extract company names, product names, software names, tool names, and brand names from this text.
        Only return names that would have recognizable logos (like Microsoft, Adobe, React, etc.).
        Return as a simple list, one name per line, no bullet points or numbers.
        Avoid generic terms like "cloud" or "database".
        Text: {text}
        """.format(text=text)
        try:
            response = self.model.generate_content(prompt)
            raw_text = response.text
            if not raw_text:
                return []
            entities = self._parse_entity_lines(raw_text)
            logger.info("Gemini extracted %d entities", len(entities))
            return entities[:MAX_ENTITIES]
        except Exception as e:
            logger.error("Gemini extraction failed: %s", e)
            raise

    def extract_with_fallback(self, text: str) -> List[str]:
        """
        Extract entities using fallback pattern matching

        Candidates are collected in this order: known names from
        COMMON_TECH_ENTITIES, capitalized words, then dotted names.

        Args:
            text (str): Input text

        Returns:
            List[str]: Case-insensitively unique entities in first-seen order,
                capped at MAX_ENTITIES
        """
        candidates = self._find_known_entities(text, COMMON_TECH_ENTITIES)
        candidates.extend(
            word for word in self._CAPITALIZED_WORD_RE.findall(text)
            if self._is_valid_entity(word)
        )
        # NOTE: "Node.js" also yields "Node" from the capitalized-word pass;
        # both are kept because they differ case-insensitively.
        candidates.extend(self._DOTTED_NAME_RE.findall(text))
        unique_entities = self._dedupe_preserving_order(candidates)
        logger.info("Fallback extracted %d entities", len(unique_entities))
        return unique_entities[:MAX_ENTITIES]

    def _parse_entity_lines(self, raw_text: str) -> List[str]:
        """
        Turn a one-name-per-line model reply into a clean entity list

        Leading list markers are stripped rather than causing the line to be
        dropped, so a bulleted or numbered reply still yields its names.

        Args:
            raw_text (str): Raw text returned by the model

        Returns:
            List[str]: Valid, case-insensitively unique names in reply order
        """
        names = []
        for line in raw_text.splitlines():
            name = self._LIST_MARKER_RE.sub('', line.strip()).strip()
            if self._is_valid_entity(name):
                names.append(name)
        return self._dedupe_preserving_order(names)

    @staticmethod
    def _find_known_entities(text: str, known_entities: List[str]) -> List[str]:
        """
        Return the known names that occur in text as standalone tokens

        Matching is case-insensitive and requires a non-alphanumeric character
        (or the string edge) on both sides, so "Git" does not match inside
        "digital" or "GitHub".

        Args:
            text (str): Input text
            known_entities (List[str]): Names to look for

        Returns:
            List[str]: Matching names, in the order given by known_entities
        """
        found = []
        for name in known_entities:
            if not name:
                continue
            pattern = (
                r'(?<![A-Za-z0-9])' + re.escape(name) + r'(?![A-Za-z0-9])'
            )
            if re.search(pattern, text, re.IGNORECASE):
                found.append(name)
        return found

    @staticmethod
    def _dedupe_preserving_order(names: List[str]) -> List[str]:
        """
        Drop case-insensitive duplicates, keeping the first spelling seen

        Args:
            names (List[str]): Candidate names

        Returns:
            List[str]: Unique names in first-seen order
        """
        seen = set()
        unique = []
        for name in names:
            key = name.lower()
            if key not in seen:
                seen.add(key)
                unique.append(name)
        return unique

    def _is_valid_entity(self, entity: str) -> bool:
        """
        Check if entity is valid for logo extraction

        Args:
            entity (str): Entity name

        Returns:
            bool: True if the name is 2-50 characters long, is not a common
                function word, and contains at least one ASCII letter
        """
        if not (self._MIN_ENTITY_LENGTH <= len(entity) <= self._MAX_ENTITY_LENGTH):
            return False
        if entity.lower() in self._INVALID_WORDS:
            return False
        return bool(self._HAS_LETTER_RE.search(entity))

    def extract_entities(self, text: str) -> List[str]:
        """
        Extract entities from text using available methods

        Gemini is tried first when a model is available; the fallback runs
        when Gemini is unavailable, raises, or returns no entities.

        Args:
            text (str): Input text

        Returns:
            List[str]: List of extracted entities (empty for blank input)
        """
        if not text or not text.strip():
            return []
        logger.info("Starting entity extraction...")
        if self.model:
            try:
                entities = self.extract_with_gemini(text)
                if entities:
                    logger.info(
                        "Successfully extracted %d entities with Gemini",
                        len(entities),
                    )
                    return entities
            except Exception as e:
                # Deliberate best effort: any Gemini failure degrades to fallback.
                logger.warning("Gemini extraction failed, using fallback: %s", e)
        entities = self.extract_with_fallback(text)
        logger.info("Extracted %d entities using fallback method", len(entities))
        return entities