|
import json |
|
import requests |
|
from typing import List, Dict, Any, Optional |
|
from config.settings import Config |
|
|
|
class LLMExtractor:
    """Extract knowledge-graph entities and relationships from text via OpenRouter.

    Sends an extraction prompt to an LLM (primary model with automatic
    fallback to a backup model), then parses and filters the model's JSON
    reply into ``{"entities": [...], "relationships": [...]}`` dictionaries.
    """

    def __init__(self) -> None:
        # Config supplies the API key, model names, base URL, and filter limits.
        self.config = Config()
        # Built once at construction. If OPENROUTER_API_KEY is unset this reads
        # "Bearer None", but _call_openrouter_api refuses to send in that case.
        self.headers = {
            "Authorization": f"Bearer {self.config.OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
        }

    def extract_entities_and_relationships(self, text: str) -> Dict[str, Any]:
        """Extract entities and relationships from *text* using the LLM.

        Tries the primary extraction model first; on any failure retries once
        with the backup model. Never raises: if both attempts fail, returns a
        dict with empty lists plus an "error" key describing both failures.
        """
        prompt = self._create_extraction_prompt(text)
        try:
            response = self._call_openrouter_api(prompt, self.config.EXTRACTION_MODEL)
            return self._parse_extraction_response(response)
        # Broad catch is deliberate: any primary failure triggers the fallback.
        except Exception as e:
            try:
                response = self._call_openrouter_api(prompt, self.config.BACKUP_MODEL)
                return self._parse_extraction_response(response)
            except Exception as backup_e:
                return {
                    "entities": [],
                    "relationships": [],
                    "error": f"Primary: {str(e)}, Backup: {str(backup_e)}",
                }

    def _create_extraction_prompt(self, text: str) -> str:
        """Build the extraction prompt that embeds *text* and demands JSON-only output."""
        return f"""
You are an expert knowledge graph extraction system. Analyze the following text and extract:

1. ENTITIES: Important people, organizations, locations, concepts, events, objects, etc.
2. RELATIONSHIPS: How these entities relate to each other
3. IMPORTANCE SCORES: Rate each entity's importance from 0.0 to 1.0 based on how central it is to the text

For each entity, provide:
- name: The entity name (standardized/canonical form)
- type: The entity type (PERSON, ORGANIZATION, LOCATION, CONCEPT, EVENT, OBJECT, etc.)
- importance: Score from 0.0 to 1.0
- description: Brief description of the entity's role/significance

For each relationship, provide:
- source: Source entity name
- target: Target entity name
- relationship: Type of relationship (works_at, located_in, part_of, causes, etc.)
- description: Brief description of the relationship

Only respond with a valid JSON object with this structure and nothing else. Your response must be valid, parsable JSON!!
=== JSON STRUCTURE FOR RESPONSE / RESPONSE FORMAT ===
{{
  "entities": [
    {{
      "name": "entity_name",
      "type": "ENTITY_TYPE",
      "importance": 0.8,
      "description": "Brief description"
    }}
  ],
  "relationships": [
    {{
      "source": "entity1",
      "target": "entity2",
      "relationship": "relationship_type",
      "description": "Brief description"
    }}
  ]
}}
=== END OF JSON STRUCTURE FOR RESPONSE / END OF RESPONSE FORMAT ===

TEXT TO ANALYZE:
{text}

Reply in valid json using the format above!
JSON OUTPUT:
"""

    def _call_openrouter_api(self, prompt: str, model: str) -> str:
        """POST *prompt* to OpenRouter's chat-completions endpoint as *model*.

        Returns the assistant message content string.

        Raises:
            ValueError: when no API key is configured.
            RuntimeError: on a non-200 HTTP response or an unexpected payload shape.
        """
        if not self.config.OPENROUTER_API_KEY:
            raise ValueError("OpenRouter API key not configured")

        payload = {
            "model": model,
            "messages": [
                {"role": "user", "content": prompt}
            ],
            "max_tokens": 2048,
            # Low temperature: we want deterministic, schema-shaped output.
            "temperature": 0.1,
        }

        response = requests.post(
            f"{self.config.OPENROUTER_BASE_URL}/chat/completions",
            headers=self.headers,
            json=payload,
            timeout=60,
        )

        if response.status_code != 200:
            raise RuntimeError(f"API call failed: {response.status_code} - {response.text}")

        result = response.json()
        if "choices" not in result or not result["choices"]:
            raise RuntimeError("Invalid API response format")

        return result["choices"][0]["message"]["content"]

    def _parse_extraction_response(self, response: str) -> Dict[str, Any]:
        """Parse the LLM reply into ``{"entities": [...], "relationships": [...]}``.

        Extracts the outermost ``{...}`` span (models often wrap the JSON in
        prose or code fences), fills in missing keys, drops entities below the
        configured importance threshold, and truncates both lists to their
        configured maxima. Never raises: malformed input yields a dict with
        empty lists and an "error" message.
        """
        try:
            # Grab the outermost brace-delimited span to tolerate extra prose.
            start_idx = response.find("{")
            end_idx = response.rfind("}") + 1
            if start_idx == -1 or end_idx == 0:
                raise ValueError("No JSON found in response")

            data = json.loads(response[start_idx:end_idx])

            # Guarantee both keys exist even if the model omitted one.
            data.setdefault("entities", [])
            data.setdefault("relationships", [])

            # Keep only sufficiently important entities, then apply size caps.
            data["entities"] = [
                entity for entity in data["entities"]
                if entity.get("importance", 0) >= self.config.ENTITY_IMPORTANCE_THRESHOLD
            ][: self.config.MAX_ENTITIES]
            data["relationships"] = data["relationships"][: self.config.MAX_RELATIONSHIPS]

            return data

        except json.JSONDecodeError as e:
            return {
                "entities": [],
                "relationships": [],
                "error": f"JSON parsing error: {str(e)}",
            }
        except Exception as e:
            return {
                "entities": [],
                "relationships": [],
                "error": f"Response parsing error: {str(e)}",
            }

    def process_chunks(self, chunks: List[str]) -> Dict[str, Any]:
        """Run extraction over multiple text chunks and merge the results.

        Per-chunk failures are collected into "errors" (None when every chunk
        succeeds) instead of aborting the whole batch.
        """
        all_entities: List[Dict[str, Any]] = []
        all_relationships: List[Dict[str, Any]] = []
        errors: List[str] = []

        for i, chunk in enumerate(chunks):
            try:
                result = self.extract_entities_and_relationships(chunk)

                if "error" in result:
                    errors.append(f"Chunk {i+1}: {result['error']}")
                    continue

                all_entities.extend(result.get("entities", []))
                all_relationships.extend(result.get("relationships", []))

            except Exception as e:
                errors.append(f"Chunk {i+1}: {str(e)}")

        unique_entities = self._deduplicate_entities(all_entities)
        # Relationships may reference entities that were deduplicated away or
        # importance-filtered; keep only those whose endpoints survived.
        valid_relationships = self._validate_relationships(all_relationships, unique_entities)

        return {
            "entities": unique_entities,
            "relationships": valid_relationships,
            "errors": errors if errors else None,
        }

    def _deduplicate_entities(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Drop entities whose normalized (lower-cased, stripped) name was already seen.

        The first occurrence wins; survivors are sorted by importance
        (descending) and capped at MAX_ENTITIES.
        """
        seen_names = set()
        unique_entities = []

        for entity in entities:
            name = entity.get("name", "").lower().strip()
            if name and name not in seen_names:
                seen_names.add(name)
                unique_entities.append(entity)

        unique_entities.sort(key=lambda e: e.get("importance", 0), reverse=True)
        return unique_entities[: self.config.MAX_ENTITIES]

    def _validate_relationships(self, relationships: List[Dict[str, Any]], entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Keep only relationships whose source and target are known entities.

        Names are normalized with lower().strip() — the same normalization
        _deduplicate_entities uses — so stray whitespace in LLM output does
        not cause otherwise-valid relationships to be dropped.
        """
        entity_names = {entity.get("name", "").lower().strip() for entity in entities}
        valid_relationships = []

        for rel in relationships:
            source = rel.get("source", "").lower().strip()
            target = rel.get("target", "").lower().strip()
            if source in entity_names and target in entity_names:
                valid_relationships.append(rel)

        return valid_relationships[: self.config.MAX_RELATIONSHIPS]
|
|