|
import json
import re
import os
import asyncio
import logging
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime

from openai import AsyncOpenAI
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
@dataclass |
|
class ComplexityMetrics: |
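    """Structural statistics computed from a JSON schema, used to pick a plan tier."""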
|
max_depth: int |
|
total_fields: int |
|
enum_count: int |
|
required_fields: int |
|
nested_objects: int |
|
|
|
@property |
|
def complexity_tier(self) -> int: |
|
if self.max_depth <= 2 and self.total_fields <= 20: |
|
return 1 |
|
elif self.max_depth <= 4 and self.total_fields <= 100: |
|
return 2 |
|
else: |
|
return 3 |
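
    # Illustrative mapping: depth 2 / 15 fields -> tier 1; depth 4 / 60 fields
    # -> tier 2; anything deeper than 4 levels or wider than 100 fields -> tier 3.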
|
|
|
@dataclass |
|
class ExtractionStage: |
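    """One unit of extraction work: a named subset of schema fields."""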
|
name: str |
|
fields: List[str] |
|
schema_subset: Dict[str, Any] |
|
complexity: int |
|
dependencies: List[str] = field(default_factory=list) |
|
estimated_tokens: int = 0 |
|
|
|
@dataclass |
|
class ExtractionPlan: |
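    """Ordered stages plus model assignments and rough cost/time estimates."""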
|
stages: List[ExtractionStage] |
|
estimated_cost: float |
|
estimated_time: float |
|
model_assignments: Dict[str, str] |
|
parallelizable_stages: List[str] = field(default_factory=list) |
|
|
|
@dataclass |
|
class ExtractionResult: |
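    """Merged extraction output with per-field confidence and per-stage metadata."""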
|
data: Dict[str, Any] |
|
confidence_scores: Dict[str, float] |
|
stage_results: List[Dict[str, Any]] = field(default_factory=list) |
|
metadata: Dict[str, Any] = field(default_factory=dict) |
|
processing_time: float = 0.0 |
|
|
|
@dataclass |
|
class QualityReport: |
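    """Aggregate quality signals used to decide whether a human should review the output."""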
|
overall_confidence: float |
|
field_scores: Dict[str, float] |
|
review_flags: List[str] |
|
schema_compliance: float |
|
consistency_score: float |
|
recommended_review_time: int = 0 |
|
|
|
class OpenAIClient: |
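    """Thin async wrapper around a single OpenAI chat model."""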
|
def __init__(self, model_name: str, api_key: str): |
|
self.model_name = model_name |
|
self.client = AsyncOpenAI(api_key=api_key) |
|
        # Illustrative USD prices per 1K input tokens (assumed values; verify
        # against current OpenAI pricing). Currently informational only.
        self.cost_per_1k_tokens = {
            "gpt-4o-mini": 0.00015,
            "gpt-4o": 0.005,
            "gpt-4-turbo": 0.003
        }
|
|
|
async def complete(self, prompt: str, max_tokens: int = 4000) -> Tuple[str, float]: |
|
try: |
|
response = await self.client.chat.completions.create( |
|
model=self.model_name, |
|
messages=[{"role": "user", "content": prompt}], |
|
max_tokens=max_tokens, |
|
temperature=0.1 |
|
) |
|
|
|
            content = response.choices[0].message.content or ""

            # Heuristic confidence prior, not a real estimate: match the model
            # name exactly, since "gpt-4o-mini" also contains "gpt-4o" as a
            # substring and would otherwise get the larger model's score.
            confidence = 0.92 if self.model_name == "gpt-4o" else 0.85

            return content, confidence
|
except Exception as e: |
|
logger.error(f"OpenAI API error: {e}") |
|
return '{"error": "API call failed"}', 0.1 |
|
|
|
class SchemaAnalyzer: |
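    """Measures schema complexity and turns it into a staged extraction plan."""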
|
def analyze_complexity(self, schema: Dict[str, Any]) -> ComplexityMetrics: |
|
def count_depth(obj: Any, current_depth: int = 0) -> int: |
|
if not isinstance(obj, dict): |
|
return current_depth |
|
|
|
max_child_depth = current_depth |
|
for value in obj.values(): |
|
if isinstance(value, dict): |
|
if 'properties' in value: |
|
child_depth = count_depth(value['properties'], current_depth + 1) |
|
else: |
|
child_depth = count_depth(value, current_depth + 1) |
|
max_child_depth = max(max_child_depth, child_depth) |
|
return max_child_depth |
|
|
|
def count_fields(obj: Any) -> Tuple[int, int, int]: |
|
if not isinstance(obj, dict): |
|
return 0, 0, 0 |
|
|
|
total, enums, objects = 0, 0, 0 |
|
|
|
for key, value in obj.items(): |
|
if key == 'properties' and isinstance(value, dict): |
|
for prop_name, prop_def in value.items(): |
|
total += 1 |
|
if isinstance(prop_def, dict): |
|
if 'enum' in prop_def: |
|
enums += 1 |
|
if prop_def.get('type') == 'object': |
|
objects += 1 |
|
nested_total, nested_enums, nested_objects = count_fields(prop_def) |
|
total += nested_total |
|
enums += nested_enums |
|
objects += nested_objects |
|
elif isinstance(value, dict): |
|
nested_total, nested_enums, nested_objects = count_fields(value) |
|
total += nested_total |
|
enums += nested_enums |
|
objects += nested_objects |
|
|
|
return total, enums, objects |
|
|
|
max_depth = count_depth(schema.get('properties', {})) |
|
total_fields, enum_count, nested_objects = count_fields(schema) |
|
required_fields = len(schema.get('required', [])) |
|
|
|
return ComplexityMetrics( |
|
max_depth=max_depth, |
|
total_fields=total_fields, |
|
enum_count=enum_count, |
|
required_fields=required_fields, |
|
nested_objects=nested_objects |
|
) |
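
    # Illustrative: a flat contact schema with string properties "name",
    # "email", and "phone", all required, yields max_depth=1, total_fields=3,
    # enum_count=0, required_fields=3, nested_objects=0 -> tier 1.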
|
|
|
def create_extraction_plan(self, schema: Dict[str, Any], complexity: ComplexityMetrics) -> ExtractionPlan: |
|
if complexity.complexity_tier == 1: |
|
return self._create_simple_plan(schema) |
|
elif complexity.complexity_tier == 2: |
|
return self._create_medium_plan(schema) |
|
else: |
|
return self._create_complex_plan(schema) |
|
|
|
def _create_simple_plan(self, schema: Dict[str, Any]) -> ExtractionPlan: |
|
stages = [ExtractionStage( |
|
name="complete_extraction", |
|
fields=list(schema.get('properties', {}).keys()), |
|
schema_subset=schema, |
|
complexity=1, |
|
estimated_tokens=2000 |
|
)] |
|
|
|
return ExtractionPlan( |
|
stages=stages, |
|
estimated_cost=0.02, |
|
estimated_time=5.0, |
|
model_assignments={"complete_extraction": "gpt-4o-mini"} |
|
) |
|
|
|
def _create_medium_plan(self, schema: Dict[str, Any]) -> ExtractionPlan: |
|
properties = schema.get('properties', {}) |
|
simple_fields = [] |
|
complex_fields = [] |
|
|
|
for field_name, field_def in properties.items(): |
|
if isinstance(field_def, dict) and field_def.get('type') in ['object', 'array']: |
|
complex_fields.append(field_name) |
|
else: |
|
simple_fields.append(field_name) |
|
|
|
stages = [] |
|
if simple_fields: |
|
stages.append(ExtractionStage( |
|
name="simple_fields", |
|
fields=simple_fields, |
|
schema_subset=self._create_subset_schema(schema, simple_fields), |
|
complexity=1, |
|
estimated_tokens=1500 |
|
)) |
|
|
|
if complex_fields: |
|
stages.append(ExtractionStage( |
|
name="complex_fields", |
|
fields=complex_fields, |
|
schema_subset=self._create_subset_schema(schema, complex_fields), |
|
complexity=2, |
|
dependencies=["simple_fields"] if simple_fields else [], |
|
estimated_tokens=3000 |
|
)) |
|
|
|
return ExtractionPlan( |
|
stages=stages, |
|
estimated_cost=0.15, |
|
estimated_time=25.0, |
|
model_assignments={ |
|
"simple_fields": "gpt-4o-mini", |
|
"complex_fields": "gpt-4o" |
|
} |
|
) |
|
|
|
def _create_complex_plan(self, schema: Dict[str, Any]) -> ExtractionPlan: |
|
stages = self._create_hierarchical_stages(schema) |
|
|
|
model_assignments = { |
|
stage.name: "gpt-4o" if stage.complexity > 1 else "gpt-4o-mini" |
|
for stage in stages |
|
} |
|
|
|
estimated_cost = len(stages) * 0.10 |
|
estimated_time = len(stages) * 15.0 |
|
|
|
return ExtractionPlan( |
|
stages=stages, |
|
estimated_cost=min(estimated_cost, 2.0), |
|
estimated_time=min(estimated_time, 120.0), |
|
model_assignments=model_assignments |
|
) |
|
|
|
def _create_hierarchical_stages(self, schema: Dict[str, Any]) -> List[ExtractionStage]: |
|
stages = [] |
|
properties = schema.get('properties', {}) |
|
|
|
simple_fields = [ |
|
field_name for field_name, field_def in properties.items() |
|
if isinstance(field_def, dict) and field_def.get('type') in ['string', 'number', 'integer', 'boolean'] |
|
and 'enum' not in field_def |
|
] |
|
|
|
if simple_fields: |
|
stages.append(ExtractionStage( |
|
name="primitive_fields", |
|
fields=simple_fields, |
|
schema_subset=self._create_subset_schema(schema, simple_fields), |
|
complexity=1, |
|
estimated_tokens=1000 |
|
)) |
|
|
|
enum_fields = [ |
|
field_name for field_name, field_def in properties.items() |
|
if isinstance(field_def, dict) and 'enum' in field_def |
|
] |
|
|
|
if enum_fields: |
|
stages.append(ExtractionStage( |
|
name="enum_fields", |
|
fields=enum_fields, |
|
schema_subset=self._create_subset_schema(schema, enum_fields), |
|
complexity=1, |
|
dependencies=["primitive_fields"] if simple_fields else [], |
|
estimated_tokens=1500 |
|
)) |
|
|
|
array_fields = [ |
|
field_name for field_name, field_def in properties.items() |
|
if isinstance(field_def, dict) and field_def.get('type') == 'array' |
|
] |
|
|
|
if array_fields: |
|
stages.append(ExtractionStage( |
|
name="array_fields", |
|
fields=array_fields, |
|
schema_subset=self._create_subset_schema(schema, array_fields), |
|
complexity=2, |
|
dependencies=["primitive_fields", "enum_fields"], |
|
estimated_tokens=2500 |
|
)) |
|
|
|
object_fields = [ |
|
field_name for field_name, field_def in properties.items() |
|
if isinstance(field_def, dict) and field_def.get('type') == 'object' |
|
] |
|
|
|
if object_fields: |
|
stages.append(ExtractionStage( |
|
name="object_fields", |
|
fields=object_fields, |
|
schema_subset=self._create_subset_schema(schema, object_fields), |
|
complexity=3, |
|
dependencies=["primitive_fields", "enum_fields", "array_fields"], |
|
estimated_tokens=4000 |
|
)) |
|
|
|
return [stage for stage in stages if stage.fields] |
|
|
|
    def _create_subset_schema(self, full_schema: Dict[str, Any], fields: List[str]) -> Dict[str, Any]:
        properties = full_schema.get('properties', {})
        subset_properties = {field_name: properties[field_name] for field_name in fields if field_name in properties}

        subset = {k: v for k, v in full_schema.items() if k != 'properties'}
        subset['properties'] = subset_properties
        # Narrow 'required' to the fields in this subset so a stage's prompt
        # does not demand fields that are handled in other stages.
        if isinstance(subset.get('required'), list):
            subset['required'] = [f for f in subset['required'] if f in subset_properties]
        return subset
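
    # Illustrative: for a schema with properties {name, age, address} and
    # required=[name, address], _create_subset_schema(schema, ["name", "age"])
    # keeps top-level keys like "type", narrows "properties" to name/age, and
    # trims "required" to ["name"].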
|
|
|
class DocumentProcessor: |
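    """Splits oversized documents into overlapping, paragraph-aligned chunks."""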
|
def __init__(self, max_chunk_size: int = 100000): |
|
self.max_chunk_size = max_chunk_size |
|
|
|
def process_document(self, content: str, schema: Dict[str, Any]) -> List[str]: |
|
if len(content) <= self.max_chunk_size: |
|
return [content] |
|
|
|
logger.info(f"Document size {len(content)} exceeds chunk limit, creating semantic chunks") |
|
return self._semantic_chunking(content, schema) |
|
|
|
def _semantic_chunking(self, content: str, schema: Dict[str, Any]) -> List[str]: |
|
paragraphs = content.split('\n\n') |
|
chunks = [] |
|
current_chunk = "" |
|
overlap_size = 1000 |
|
|
|
for para in paragraphs: |
|
if len(current_chunk) + len(para) > self.max_chunk_size: |
|
if current_chunk: |
|
chunks.append(current_chunk) |
|
current_chunk = current_chunk[-overlap_size:] + "\n\n" + para |
|
else: |
|
current_chunk = para |
|
else: |
|
current_chunk += "\n\n" + para if current_chunk else para |
|
|
|
if current_chunk: |
|
chunks.append(current_chunk) |
|
|
|
logger.info(f"Created {len(chunks)} semantic chunks") |
|
return chunks |
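
    # Illustrative: with max_chunk_size=100_000 and overlap_size=1_000, a
    # 250k-character document yields roughly three chunks, each beginning with
    # the last ~1,000 characters of its predecessor so entities that straddle
    # a boundary remain visible to the next extraction pass.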
|
|
|
class ExtractionEngine: |
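    """Runs an ExtractionPlan stage by stage, feeding earlier results forward."""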
|
def __init__(self, api_key: str): |
|
self.models = { |
|
"gpt-4o-mini": OpenAIClient("gpt-4o-mini", api_key), |
|
"gpt-4o": OpenAIClient("gpt-4o", api_key), |
|
} |
|
|
|
async def extract(self, content: str, plan: ExtractionPlan, schema: Dict[str, Any]) -> ExtractionResult: |
|
        start_time = asyncio.get_running_loop().time()
|
results = {} |
|
confidence_scores = {} |
|
        stage_results = []
        completed_stages = set()  # names of stages that ran successfully
|
|
|
logger.info(f"Starting extraction with {len(plan.stages)} stages") |
|
|
|
for i, stage in enumerate(plan.stages): |
|
logger.info(f"Executing stage {i+1}/{len(plan.stages)}: {stage.name}") |
|
|
|
            if not self._dependencies_satisfied(stage.dependencies, completed_stages):
|
logger.warning(f"Dependencies not satisfied for stage {stage.name}, skipping") |
|
continue |
|
|
|
context = self._build_context(content, results, stage) |
|
model_name = plan.model_assignments.get(stage.name, "gpt-4o") |
|
model = self.models[model_name] |
|
|
|
prompt = self._create_extraction_prompt(context, stage.schema_subset, results) |
|
|
|
            stage_start = asyncio.get_running_loop().time()
            response, confidence = await model.complete(prompt, max_tokens=4000)
            stage_time = asyncio.get_running_loop().time() - stage_start
            stage_data = self._parse_response(response, stage.fields)

            results.update(stage_data)
            completed_stages.add(stage.name)
            for field_name in stage.fields:
                # Fields the model actually populated keep most of the model's
                # confidence; fields it omitted or nulled are marked low.
                confidence_scores[field_name] = confidence * (0.9 if stage_data.get(field_name) is not None else 0.3)

            stage_results.append({
                "stage": stage.name,
                "extracted_fields": [k for k, v in stage_data.items() if v is not None],
                "confidence": confidence,
                "model": model_name,
                "processing_time": stage_time
            })
|
|
|
        processing_time = asyncio.get_running_loop().time() - start_time
|
|
|
return ExtractionResult( |
|
data=results, |
|
confidence_scores=confidence_scores, |
|
stage_results=stage_results, |
|
metadata={ |
|
"total_stages": len(plan.stages), |
|
"estimated_cost": plan.estimated_cost, |
|
"processing_time": processing_time |
|
}, |
|
processing_time=processing_time |
|
) |
|
|
|
    def _dependencies_satisfied(self, dependencies: List[str], completed_stages: set) -> bool:
        # Dependencies name stages, so check them against the stages that have
        # actually completed rather than against extracted field keys.
        return all(dep in completed_stages for dep in dependencies)
|
|
|
def _build_context(self, content: str, previous_results: Dict[str, Any], stage: ExtractionStage) -> str: |
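        # Keep the prompt bounded: only the first 5,000 characters of the
        # document reach the model per stage in this sketch.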
|
context = f"Document Content:\n{content[:5000]}" |
|
if len(content) > 5000: |
|
context += "...[truncated]" |
|
|
|
if previous_results: |
|
context += f"\n\nPreviously Extracted Data:\n{json.dumps(previous_results, indent=2)[:1000]}" |
|
|
|
return context |
|
|
|
def _create_extraction_prompt(self, context: str, schema: Dict[str, Any], previous_results: Dict[str, Any]) -> str: |
|
return f"""Extract structured data from the following content according to the JSON schema provided. |
|
|
|
Context: |
|
{context} |
|
|
|
JSON Schema: |
|
{json.dumps(schema, indent=2)} |
|
|
|
Instructions: |
|
1. Extract only the fields specified in the schema |
|
2. Ensure the output is valid JSON |
|
3. If a field cannot be determined from the content, use null |
|
4. Be precise and follow the schema constraints exactly |
|
5. Use previous results as context when relevant |
|
|
|
Output the extracted data as a JSON object:""" |
|
|
|
    def _parse_response(self, response: str, expected_fields: List[str]) -> Dict[str, Any]:
        try:
            data = json.loads(response)
            if isinstance(data, dict):
                return data
        except json.JSONDecodeError:
            # The model may wrap the JSON in prose or code fences; try the
            # first brace-delimited block before giving up.
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                try:
                    data = json.loads(json_match.group())
                    if isinstance(data, dict):
                        return data
                except json.JSONDecodeError:
                    pass

        logger.warning("Failed to parse JSON response; returning nulls for expected fields")
        # Null placeholders (rather than fabricated values) let downstream
        # quality scoring treat these fields as missing.
        return {field_name: None for field_name in expected_fields}
|
|
|
class QualityAssessor: |
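    """Scores extraction output for schema compliance, consistency, and confidence."""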
|
def assess_extraction(self, result: ExtractionResult, schema: Dict[str, Any]) -> QualityReport: |
|
schema_compliance = self._validate_against_schema(result.data, schema) |
|
field_scores = result.confidence_scores.copy() |
|
consistency_score = self._check_consistency(result.data) |
|
|
|
required_fields = schema.get('required', []) |
|
|
|
if field_scores: |
|
total_weight = 0 |
|
weighted_confidence = 0 |
|
|
|
for field, confidence in field_scores.items(): |
|
weight = 2.0 if field in required_fields else 1.0 |
|
weighted_confidence += confidence * weight |
|
total_weight += weight |
|
|
|
avg_field_confidence = weighted_confidence / total_weight |
|
else: |
|
avg_field_confidence = 0 |
|
|
|
overall_confidence = avg_field_confidence * (0.8 + 0.2 * schema_compliance) * (0.9 + 0.1 * consistency_score) |
|
overall_confidence = min(overall_confidence, 1.0) |
|
|
|
review_flags = self._generate_review_flags(field_scores, schema_compliance, overall_confidence, required_fields, result.data) |
|
review_time = self._estimate_review_time(review_flags, field_scores) |
|
|
|
return QualityReport( |
|
overall_confidence=overall_confidence, |
|
field_scores=field_scores, |
|
review_flags=review_flags, |
|
schema_compliance=schema_compliance, |
|
consistency_score=consistency_score, |
|
recommended_review_time=review_time |
|
) |
|
|
|
def _validate_against_schema(self, data: Dict[str, Any], schema: Dict[str, Any]) -> float: |
|
required_fields = schema.get('required', []) |
|
properties = schema.get('properties', {}) |
|
|
|
required_present = sum(1 for field in required_fields if field in data and data[field] is not None) |
|
required_compliance = required_present / len(required_fields) if required_fields else 1.0 |
|
|
|
type_errors = 0 |
|
total_fields = 0 |
|
for field, value in data.items(): |
|
if field in properties: |
|
total_fields += 1 |
|
expected_type = properties[field].get('type') |
|
if expected_type and not self._check_type(value, expected_type): |
|
type_errors += 1 |
|
|
|
type_compliance = 1.0 - (type_errors / total_fields) if total_fields > 0 else 1.0 |
|
|
|
return (required_compliance * 0.7 + type_compliance * 0.3) |
|
|
|
    def _check_type(self, value: Any, expected_type: str) -> bool:
        if value is None:
            return True

        # bool is a subclass of int in Python, so exclude it explicitly for
        # numeric types; otherwise True/False would satisfy integer fields.
        if expected_type in ('integer', 'number') and isinstance(value, bool):
            return False

        type_mapping = {
            'string': str,
            'number': (int, float),
            'integer': int,
            'boolean': bool,
            'array': list,
            'object': dict
        }
        expected_python_type = type_mapping.get(expected_type)
        if expected_python_type is None:
            # Unknown or unsupported schema types are not penalized.
            return True
        return isinstance(value, expected_python_type)
|
|
|
def _check_consistency(self, data: Dict[str, Any]) -> float: |
|
consistency_score = 1.0 |
|
|
|
if 'email' in data and data['email']: |
|
if '@' not in str(data['email']): |
|
consistency_score -= 0.1 |
|
|
|
        if 'startDate' in data and 'endDate' in data:
            if data['startDate'] and data['endDate']:
                # Lexicographic comparison is only meaningful for ISO-8601
                # date strings; other formats may be mis-ordered here.
                if str(data['startDate']) > str(data['endDate']):
                    consistency_score -= 0.15
|
|
|
if isinstance(data, dict): |
|
for key, value in data.items(): |
|
if isinstance(value, list): |
|
for item in value: |
|
if isinstance(item, dict): |
|
consistency_score *= self._check_consistency(item) |
|
elif isinstance(value, dict): |
|
consistency_score *= self._check_consistency(value) |
|
|
|
return max(0.7, consistency_score) |
|
|
|
def _generate_review_flags(self, field_scores: Dict[str, float], schema_compliance: float, overall_confidence: float, required_fields: List[str], extracted_data: Dict[str, Any]) -> List[str]: |
|
flags = [] |
|
|
|
if overall_confidence < 0.6: |
|
flags.append("high_priority_review") |
|
elif overall_confidence < 0.8: |
|
flags.append("standard_review") |
|
|
|
if schema_compliance < 0.8: |
|
flags.append("schema_compliance_issues") |
|
|
|
low_confidence_fields = [field for field, score in field_scores.items() if score < 0.7] |
|
if low_confidence_fields: |
|
flags.append(f"uncertain_fields: {', '.join(low_confidence_fields[:3])}") |
|
|
|
missing_required = [field for field in required_fields if field not in extracted_data or extracted_data[field] is None] |
|
if missing_required: |
|
flags.append(f"missing_required: {', '.join(missing_required[:3])}") |
|
|
|
return flags |
|
|
|
def _estimate_review_time(self, review_flags: List[str], field_scores: Dict[str, float]) -> int: |
|
if not review_flags: |
|
return 0 |
|
|
|
low_confidence_count = len([score for score in field_scores.values() if score < 0.7]) |
|
base_time = 5 |
|
field_time = low_confidence_count * 2 |
|
|
|
return min(base_time + field_time, 60) |
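
    # Example: four fields below 0.7 confidence -> 5 + 4 * 2 = 13, capped at 60.
    # (The unit is unspecified in this sketch; minutes is a reasonable reading.)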
|
|
|
class StructuredExtractionSystem: |
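    """Facade wiring together schema analysis, chunking, extraction, and quality checks."""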
|
def __init__(self, api_key: str): |
|
self.schema_analyzer = SchemaAnalyzer() |
|
self.document_processor = DocumentProcessor() |
|
self.extraction_engine = ExtractionEngine(api_key) |
|
self.quality_assessor = QualityAssessor() |
|
|
|
async def extract_structured_data( |
|
self, |
|
content: str, |
|
schema: Dict[str, Any], |
|
options: Optional[Dict[str, Any]] = None |
|
) -> Dict[str, Any]: |
|
start_time = datetime.now() |
|
|
|
logger.info("Starting structured data extraction") |
|
logger.info(f"Content length: {len(content)} characters") |
|
|
|
complexity = self.schema_analyzer.analyze_complexity(schema) |
|
logger.info(f"Schema complexity: Tier {complexity.complexity_tier}") |
|
|
|
plan = self.schema_analyzer.create_extraction_plan(schema, complexity) |
|
logger.info(f"Extraction plan: {len(plan.stages)} stages") |
|
|
|
chunks = self.document_processor.process_document(content, schema) |
|
logger.info(f"Document chunks: {len(chunks)}") |
|
|
|
        # Only the first chunk is extracted in this sketch; merging results
        # across chunks is not implemented.
        if len(chunks) > 1:
            logger.warning("Multiple chunks created; only the first will be processed")
        result = await self.extraction_engine.extract(chunks[0], plan, schema)
|
quality = self.quality_assessor.assess_extraction(result, schema) |
|
|
|
processing_time = (datetime.now() - start_time).total_seconds() |
|
|
|
logger.info(f"Extraction completed in {processing_time:.2f} seconds") |
|
logger.info(f"Overall confidence: {quality.overall_confidence:.3f}") |
|
|
|
return { |
|
"data": result.data, |
|
"confidence_scores": result.confidence_scores, |
|
"overall_confidence": quality.overall_confidence, |
|
"review_flags": quality.review_flags, |
|
"extraction_metadata": { |
|
"complexity_tier": complexity.complexity_tier, |
|
"stages_executed": len(plan.stages), |
|
"estimated_cost": plan.estimated_cost, |
|
"actual_processing_time": processing_time, |
|
"schema_compliance": quality.schema_compliance, |
|
"recommended_review_time": quality.recommended_review_time |
|
} |
|
} |
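

# --- Usage sketch (illustrative only) ---
# Assumes an OPENAI_API_KEY environment variable is set; the invoice schema
# and document below are invented for demonstration and are not part of the
# pipeline itself.
if __name__ == "__main__":
    sample_schema = {
        "type": "object",
        "required": ["invoice_number", "total"],
        "properties": {
            "invoice_number": {"type": "string"},
            "total": {"type": "number"},
            "currency": {"type": "string", "enum": ["USD", "EUR", "GBP"]},
            "line_items": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "description": {"type": "string"},
                        "amount": {"type": "number"}
                    }
                }
            }
        }
    }

    sample_content = (
        "Invoice INV-1042\n\n"
        "Total due: 1,250.00 USD\n\n"
        "Line items: consulting services (1,000.00), travel (250.00)"
    )

    async def _demo() -> None:
        system = StructuredExtractionSystem(api_key=os.environ["OPENAI_API_KEY"])
        output = await system.extract_structured_data(sample_content, sample_schema)
        print(json.dumps(output, indent=2, default=str))

    asyncio.run(_demo())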