Update main.py
Browse files
main.py
CHANGED
@@ -78,18 +78,26 @@ class OpenAIClient:
|
|
78 |
try:
|
79 |
response = await self.client.chat.completions.create(
|
80 |
model=self.model_name,
|
81 |
-
messages=[
|
|
|
|
|
|
|
82 |
max_tokens=max_tokens,
|
83 |
-
temperature=0.1
|
|
|
84 |
)
|
85 |
|
86 |
content = response.choices[0].message.content
|
87 |
-
confidence = 0.
|
|
|
|
|
|
|
88 |
|
89 |
return content, confidence
|
|
|
90 |
except Exception as e:
|
91 |
logger.error(f"OpenAI API error: {e}")
|
92 |
-
return '{"error": "API call failed"}', 0.1
|
93 |
|
94 |
class SchemaAnalyzer:
|
95 |
def analyze_complexity(self, schema: Dict[str, Any]) -> ComplexityMetrics:
|
@@ -147,12 +155,23 @@ class SchemaAnalyzer:
|
|
147 |
)
|
148 |
|
149 |
def create_extraction_plan(self, schema: Dict[str, Any], complexity: ComplexityMetrics) -> ExtractionPlan:
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
def _create_simple_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
|
158 |
stages = [ExtractionStage(
|
@@ -167,7 +186,7 @@ class SchemaAnalyzer:
|
|
167 |
stages=stages,
|
168 |
estimated_cost=0.02,
|
169 |
estimated_time=5.0,
|
170 |
-
model_assignments={"complete_extraction": "gpt-4o
|
171 |
)
|
172 |
|
173 |
def _create_medium_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
|
@@ -408,38 +427,85 @@ class ExtractionEngine:
|
|
408 |
return context
|
409 |
|
410 |
def _create_extraction_prompt(self, context: str, schema: Dict[str, Any], previous_results: Dict[str, Any]) -> str:
|
411 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
412 |
|
413 |
-
|
414 |
-
{
|
415 |
|
416 |
-
|
417 |
-
{json.dumps(schema, indent=2)}
|
418 |
|
419 |
-
|
420 |
-
1. Extract
|
421 |
-
2.
|
422 |
-
3.
|
423 |
-
4.
|
424 |
-
5.
|
|
|
425 |
|
426 |
-
|
427 |
|
428 |
def _parse_response(self, response: str, expected_fields: List[str]) -> Dict[str, Any]:
|
429 |
try:
|
430 |
-
|
431 |
-
|
432 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
433 |
try:
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
|
441 |
-
logger.
|
442 |
-
return {
|
443 |
|
444 |
class QualityAssessor:
|
445 |
def assess_extraction(self, result: ExtractionResult, schema: Dict[str, Any]) -> QualityReport:
|
@@ -448,24 +514,20 @@ class QualityAssessor:
|
|
448 |
consistency_score = self._check_consistency(result.data)
|
449 |
|
450 |
required_fields = schema.get('required', [])
|
|
|
|
|
|
|
|
|
451 |
|
452 |
if field_scores:
|
453 |
-
|
454 |
-
weighted_confidence = 0
|
455 |
-
|
456 |
-
for field, confidence in field_scores.items():
|
457 |
-
weight = 2.0 if field in required_fields else 1.0
|
458 |
-
weighted_confidence += confidence * weight
|
459 |
-
total_weight += weight
|
460 |
-
|
461 |
-
avg_field_confidence = weighted_confidence / total_weight
|
462 |
else:
|
463 |
avg_field_confidence = 0
|
464 |
|
465 |
-
overall_confidence =
|
466 |
overall_confidence = min(overall_confidence, 1.0)
|
467 |
|
468 |
-
review_flags = self._generate_review_flags(field_scores, schema_compliance, overall_confidence, required_fields, result.data)
|
469 |
review_time = self._estimate_review_time(review_flags, field_scores)
|
470 |
|
471 |
return QualityReport(
|
@@ -538,24 +600,31 @@ class QualityAssessor:
|
|
538 |
|
539 |
return max(0.7, consistency_score)
|
540 |
|
541 |
-
def _generate_review_flags(self, field_scores: Dict[str, float], schema_compliance: float, overall_confidence: float, required_fields: List[str], extracted_data: Dict[str, Any]) -> List[str]:
|
542 |
flags = []
|
543 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
544 |
if overall_confidence < 0.6:
|
545 |
-
flags.append("
|
546 |
elif overall_confidence < 0.8:
|
547 |
-
flags.append("
|
548 |
|
549 |
-
if schema_compliance < 0.
|
550 |
-
flags.append("
|
551 |
-
|
552 |
-
low_confidence_fields = [field for field, score in field_scores.items() if score < 0.7]
|
553 |
-
if low_confidence_fields:
|
554 |
-
flags.append(f"uncertain_fields: {', '.join(low_confidence_fields[:3])}")
|
555 |
|
556 |
missing_required = [field for field in required_fields if field not in extracted_data or extracted_data[field] is None]
|
557 |
if missing_required:
|
558 |
-
flags.append(f"
|
|
|
|
|
|
|
|
|
559 |
|
560 |
return flags
|
561 |
|
|
|
78 |
try:
|
79 |
response = await self.client.chat.completions.create(
|
80 |
model=self.model_name,
|
81 |
+
messages=[
|
82 |
+
{"role": "system", "content": "You are a precise data extraction specialist. Extract data according to the provided schema and output only valid JSON."},
|
83 |
+
{"role": "user", "content": prompt}
|
84 |
+
],
|
85 |
max_tokens=max_tokens,
|
86 |
+
temperature=0.1,
|
87 |
+
top_p=0.9
|
88 |
)
|
89 |
|
90 |
content = response.choices[0].message.content
|
91 |
+
confidence = 0.9 if "gpt-4o" in self.model_name else 0.8
|
92 |
+
|
93 |
+
if content and len(content.strip()) > 10:
|
94 |
+
confidence += 0.05
|
95 |
|
96 |
return content, confidence
|
97 |
+
|
98 |
except Exception as e:
|
99 |
logger.error(f"OpenAI API error: {e}")
|
100 |
+
return '{"error": "API call failed", "details": "' + str(e) + '"}', 0.1
|
101 |
|
102 |
class SchemaAnalyzer:
|
103 |
def analyze_complexity(self, schema: Dict[str, Any]) -> ComplexityMetrics:
|
|
|
155 |
)
|
156 |
|
157 |
def create_extraction_plan(self, schema: Dict[str, Any], complexity: ComplexityMetrics) -> ExtractionPlan:
|
158 |
+
return self._create_single_pass_plan(schema)
|
159 |
+
|
160 |
+
def _create_single_pass_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
|
161 |
+
stages = [ExtractionStage(
|
162 |
+
name="complete_extraction",
|
163 |
+
fields=list(schema.get('properties', {}).keys()),
|
164 |
+
schema_subset=schema,
|
165 |
+
complexity=2,
|
166 |
+
estimated_tokens=4000
|
167 |
+
)]
|
168 |
+
|
169 |
+
return ExtractionPlan(
|
170 |
+
stages=stages,
|
171 |
+
estimated_cost=0.15,
|
172 |
+
estimated_time=15.0,
|
173 |
+
model_assignments={"complete_extraction": "gpt-4o"}
|
174 |
+
)
|
175 |
|
176 |
def _create_simple_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
|
177 |
stages = [ExtractionStage(
|
|
|
186 |
stages=stages,
|
187 |
estimated_cost=0.02,
|
188 |
estimated_time=5.0,
|
189 |
+
model_assignments={"complete_extraction": "gpt-4o"}
|
190 |
)
|
191 |
|
192 |
def _create_medium_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
|
|
|
427 |
return context
|
428 |
|
429 |
def _create_extraction_prompt(self, context: str, schema: Dict[str, Any], previous_results: Dict[str, Any]) -> str:
|
430 |
+
schema_properties = schema.get('properties', {})
|
431 |
+
required_fields = schema.get('required', [])
|
432 |
+
|
433 |
+
field_descriptions = []
|
434 |
+
for field_name, field_def in schema_properties.items():
|
435 |
+
if isinstance(field_def, dict):
|
436 |
+
field_type = field_def.get('type', 'string')
|
437 |
+
is_required = field_name in required_fields
|
438 |
+
status = "REQUIRED" if is_required else "optional"
|
439 |
+
field_descriptions.append(f"- {field_name} ({field_type}) [{status}]")
|
440 |
+
|
441 |
+
previous_context = ""
|
442 |
+
if previous_results:
|
443 |
+
previous_context = f"\n\nPreviously extracted data:\n{json.dumps(previous_results, indent=2)}"
|
444 |
+
|
445 |
+
return f"""Extract ALL specified fields from the document content according to the JSON schema.
|
446 |
+
|
447 |
+
DOCUMENT CONTENT:
|
448 |
+
{context[:4000]}
|
449 |
|
450 |
+
REQUIRED OUTPUT FIELDS:
|
451 |
+
{chr(10).join(field_descriptions)}
|
452 |
|
453 |
+
SCHEMA STRUCTURE:
|
454 |
+
{json.dumps(schema, indent=2)}{previous_context}
|
455 |
|
456 |
+
CRITICAL INSTRUCTIONS:
|
457 |
+
1. Extract ALL fields specified in the schema properties
|
458 |
+
2. For arrays, extract ALL items found in the content
|
459 |
+
3. For objects, extract ALL nested properties
|
460 |
+
4. Use null only if data truly cannot be found
|
461 |
+
5. Maintain exact schema structure and types
|
462 |
+
6. Output ONLY valid JSON, no explanations
|
463 |
|
464 |
+
JSON OUTPUT:"""
|
465 |
|
466 |
def _parse_response(self, response: str, expected_fields: List[str]) -> Dict[str, Any]:
|
467 |
try:
|
468 |
+
cleaned_response = response.strip()
|
469 |
+
|
470 |
+
if not cleaned_response.startswith('{'):
|
471 |
+
json_start = cleaned_response.find('{')
|
472 |
+
if json_start != -1:
|
473 |
+
cleaned_response = cleaned_response[json_start:]
|
474 |
+
|
475 |
+
if not cleaned_response.endswith('}'):
|
476 |
+
json_end = cleaned_response.rfind('}')
|
477 |
+
if json_end != -1:
|
478 |
+
cleaned_response = cleaned_response[:json_end + 1]
|
479 |
+
|
480 |
+
data = json.loads(cleaned_response)
|
481 |
+
|
482 |
+
if isinstance(data, dict):
|
483 |
+
return data
|
484 |
+
else:
|
485 |
+
logger.warning("Response is not a dictionary")
|
486 |
+
return {}
|
487 |
+
|
488 |
+
except json.JSONDecodeError as e:
|
489 |
+
logger.warning(f"JSON decode error: {e}")
|
490 |
+
|
491 |
try:
|
492 |
+
import re
|
493 |
+
json_pattern = r'\{(?:[^{}]|{(?:[^{}]|{[^{}]*})*})*\}'
|
494 |
+
matches = re.findall(json_pattern, response, re.DOTALL)
|
495 |
+
|
496 |
+
for match in matches:
|
497 |
+
try:
|
498 |
+
data = json.loads(match)
|
499 |
+
if isinstance(data, dict) and data:
|
500 |
+
return data
|
501 |
+
except:
|
502 |
+
continue
|
503 |
+
|
504 |
+
except Exception as e:
|
505 |
+
logger.warning(f"Regex parsing failed: {e}")
|
506 |
|
507 |
+
logger.error("All JSON parsing attempts failed")
|
508 |
+
return {}
|
509 |
|
510 |
class QualityAssessor:
|
511 |
def assess_extraction(self, result: ExtractionResult, schema: Dict[str, Any]) -> QualityReport:
|
|
|
514 |
consistency_score = self._check_consistency(result.data)
|
515 |
|
516 |
required_fields = schema.get('required', [])
|
517 |
+
total_expected_fields = len(schema.get('properties', {}))
|
518 |
+
extracted_fields = len([k for k, v in result.data.items() if v is not None])
|
519 |
+
|
520 |
+
completeness_score = extracted_fields / total_expected_fields if total_expected_fields > 0 else 0
|
521 |
|
522 |
if field_scores:
|
523 |
+
avg_field_confidence = sum(field_scores.values()) / len(field_scores)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
524 |
else:
|
525 |
avg_field_confidence = 0
|
526 |
|
527 |
+
overall_confidence = completeness_score * 0.6 + schema_compliance * 0.3 + consistency_score * 0.1
|
528 |
overall_confidence = min(overall_confidence, 1.0)
|
529 |
|
530 |
+
review_flags = self._generate_review_flags(field_scores, schema_compliance, overall_confidence, required_fields, result.data, total_expected_fields, extracted_fields)
|
531 |
review_time = self._estimate_review_time(review_flags, field_scores)
|
532 |
|
533 |
return QualityReport(
|
|
|
600 |
|
601 |
return max(0.7, consistency_score)
|
602 |
|
603 |
+
def _generate_review_flags(self, field_scores: Dict[str, float], schema_compliance: float, overall_confidence: float, required_fields: List[str], extracted_data: Dict[str, Any], total_expected: int, extracted_count: int) -> List[str]:
|
604 |
flags = []
|
605 |
|
606 |
+
completeness_rate = extracted_count / total_expected if total_expected > 0 else 0
|
607 |
+
|
608 |
+
if completeness_rate < 0.5:
|
609 |
+
flags.append("incomplete_extraction")
|
610 |
+
elif completeness_rate < 0.8:
|
611 |
+
flags.append("partial_extraction")
|
612 |
+
|
613 |
if overall_confidence < 0.6:
|
614 |
+
flags.append("low_quality")
|
615 |
elif overall_confidence < 0.8:
|
616 |
+
flags.append("moderate_quality")
|
617 |
|
618 |
+
if schema_compliance < 0.7:
|
619 |
+
flags.append("schema_violations")
|
|
|
|
|
|
|
|
|
620 |
|
621 |
missing_required = [field for field in required_fields if field not in extracted_data or extracted_data[field] is None]
|
622 |
if missing_required:
|
623 |
+
flags.append(f"missing_required_fields")
|
624 |
+
|
625 |
+
empty_fields = [k for k, v in extracted_data.items() if v is None or v == ""]
|
626 |
+
if len(empty_fields) > total_expected * 0.3:
|
627 |
+
flags.append("many_empty_fields")
|
628 |
|
629 |
return flags
|
630 |
|