File size: 6,640 Bytes
21d27b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
Entity extraction module using Gemini AI with fallback methods
"""
import re
import logging
from typing import List, Optional
import google.generativeai as genai

from services.appconfig import GEMINI_API_KEY, COMMON_TECH_ENTITIES, MAX_ENTITIES

# Module-level logger (named after this module) used by EntityExtractor below.
logger = logging.getLogger(__name__)


class EntityExtractor:
    """Extract entity names (companies, products, tools, brands) from text.

    A Gemini model is used when one could be configured. Otherwise, or when
    Gemini fails or returns nothing, a regex/keyword fallback is used.
    """

    # Gemini model used when no explicit ``model_name`` is passed.
    DEFAULT_MODEL_NAME = 'gemini-2.0-flash-exp'

    # Words never accepted as entity names (compared lowercased).
    # Generic tech terms such as 'cloud' or 'database' are deliberately not
    # filtered here; the Gemini prompt asks the model to avoid them instead.
    _INVALID_WORDS = frozenset({
        'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
        'by', 'from', 'up', 'about', 'into', 'through', 'during', 'before',
        'after', 'above', 'below', 'between', 'among',
    })

    # Capitalized words of 3+ letters, treated as likely proper nouns.
    _CAPITALIZED_WORD_RE = re.compile(r'\b[A-Z][a-zA-Z]*\b'.replace('*', '{2,}', 1))
    # Dotted names starting with a capital letter, e.g. "Node.js", "Vue.js".
    _DOTTED_NAME_RE = re.compile(r'\b[A-Z][a-zA-Z]*\.[a-zA-Z]+\b')
    # A leading list marker the model may emit despite the prompt:
    # "-", "*", "•" bullets (also markdown bold "**") or "1." / "2)" numbering.
    _LIST_MARKER_RE = re.compile(r'^(?:[-*•]+\s*|\d+[.)]\s+)')

    def __init__(self, api_key: Optional[str] = None,
                 model_name: Optional[str] = None):
        """
        Initialize EntityExtractor

        Args:
            api_key (str, optional): Gemini API key; falls back to
                GEMINI_API_KEY from services.appconfig when not given.
            model_name (str, optional): Gemini model name; defaults to
                DEFAULT_MODEL_NAME.
        """
        self.api_key = api_key or GEMINI_API_KEY
        self.model_name = model_name or self.DEFAULT_MODEL_NAME
        self.model = None
        self._setup_gemini()

    def _setup_gemini(self) -> None:
        """Configure Gemini; leaves ``self.model`` as None if that is not possible."""
        if not self.api_key:
            logger.warning("No Gemini API key provided, using fallback method")
            return

        try:
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel(self.model_name)
            logger.info("Gemini API initialized successfully")
        except Exception as e:
            # Best effort: any setup failure just disables the Gemini path.
            logger.error("Failed to initialize Gemini API: %s", e)
            self.model = None

    def extract_with_gemini(self, text: str) -> List[str]:
        """
        Extract entities using Gemini AI

        The reply is split into lines. A leading list marker ("- ", "* ",
        "1. ") and surrounding markdown emphasis are stripped from each line.
        Entries rejected by _is_valid_entity are dropped, and duplicates are
        removed case-insensitively (first occurrence wins).

        Args:
            text (str): Input text

        Returns:
            List[str]: At most MAX_ENTITIES entities, in reply order

        Raises:
            RuntimeError: If no Gemini model is configured.
            Exception: Any error from the Gemini call is logged and re-raised.
        """
        if not self.model:
            raise RuntimeError("Gemini model not available")

        prompt = """
        Extract company names, product names, software names, tool names, and brand names from this text.
        Only return names that would have recognizable logos (like Microsoft, Adobe, React, etc.).
        Return as a simple list, one name per line, no bullet points or numbers.
        Avoid generic terms like "cloud" or "database".
        
        Text: {text}
        """.format(text=text)

        try:
            response = self.model.generate_content(prompt)
            # .text itself can raise (e.g. when the reply has no usable content).
            reply = response.text
        except Exception as e:
            logger.error("Gemini extraction failed: %s", e)
            raise

        if not reply:
            return []

        cleaned = [self._clean_model_line(line) for line in reply.strip().splitlines()]
        entities = self._dedupe([c for c in cleaned if self._is_valid_entity(c)])

        logger.info("Gemini extracted %d entities", len(entities))
        return entities[:MAX_ENTITIES]

    @classmethod
    def _clean_model_line(cls, line: str) -> str:
        """Strip whitespace, a leading list marker and markdown emphasis from one reply line."""
        without_marker = cls._LIST_MARKER_RE.sub('', line.strip(), count=1)
        return without_marker.strip().strip('*`').strip()

    @staticmethod
    def _dedupe(entities: List[str]) -> List[str]:
        """Return *entities* without case-insensitive duplicates, keeping first occurrences."""
        seen = set()
        unique = []
        for entity in entities:
            key = entity.lower()
            if key not in seen:
                seen.add(key)
                unique.append(entity)
        return unique

    @staticmethod
    def _mentions(lowered_text: str, name: str) -> bool:
        """Return True if *name* occurs in the already-lowercased text as a standalone word.

        The match is anchored with lookarounds on word characters rather than
        a plain word boundary. This lets names beginning or ending in
        punctuation ("C++", ".NET") still match, while short names such as
        "Go" no longer match inside longer words like "google".
        """
        pattern = r'(?<!\w)' + re.escape(name.lower()) + r'(?!\w)'
        return re.search(pattern, lowered_text) is not None

    def extract_with_fallback(self, text: str) -> List[str]:
        """
        Extract entities using fallback pattern matching

        Candidates are gathered in this order, then de-duplicated
        case-insensitively (first occurrence wins):
          1. entries of COMMON_TECH_ENTITIES found in the text as whole words
             (case-insensitive);
          2. capitalized words of 3+ letters accepted by _is_valid_entity;
          3. dotted names starting with a capital letter (e.g. "Node.js").

        Args:
            text (str): Input text

        Returns:
            List[str]: At most MAX_ENTITIES extracted entities
        """
        lowered = text.lower()

        candidates = [
            known for known in COMMON_TECH_ENTITIES
            if known and self._mentions(lowered, known)
        ]
        candidates.extend(
            word for word in self._CAPITALIZED_WORD_RE.findall(text)
            if self._is_valid_entity(word)
        )
        candidates.extend(self._DOTTED_NAME_RE.findall(text))

        unique_entities = self._dedupe(candidates)
        logger.info("Fallback extracted %d entities", len(unique_entities))
        return unique_entities[:MAX_ENTITIES]

    def _is_valid_entity(self, entity: str) -> bool:
        """
        Check if entity is valid for logo extraction

        Args:
            entity (str): Entity name

        Returns:
            bool: True if the name is 2-50 characters long, is not one of
            _INVALID_WORDS (case-insensitive) and contains an ASCII letter
        """
        if len(entity) < 2 or len(entity) > 50:
            return False
        if entity.lower() in self._INVALID_WORDS:
            return False
        return re.search(r'[a-zA-Z]', entity) is not None

    def extract_entities(self, text: str) -> List[str]:
        """
        Extract entities from text using available methods

        Gemini is tried first when configured. An error or an empty result
        falls through to extract_with_fallback.

        Args:
            text (str): Input text

        Returns:
            List[str]: List of extracted entities ([] for blank input)
        """
        if not text or not text.strip():
            return []

        logger.info("Starting entity extraction...")

        if self.model:
            try:
                entities = self.extract_with_gemini(text)
            except Exception as e:
                logger.warning("Gemini extraction failed, using fallback: %s", e)
            else:
                if entities:
                    logger.info("Successfully extracted %d entities with Gemini", len(entities))
                    return entities

        entities = self.extract_with_fallback(text)
        logger.info("Extracted %d entities using fallback method", len(entities))
        return entities