arjunanand13 committed on
Commit
5f40cee
·
verified ·
1 Parent(s): 91a9da3

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +124 -55
main.py CHANGED
@@ -78,18 +78,26 @@ class OpenAIClient:
78
  try:
79
  response = await self.client.chat.completions.create(
80
  model=self.model_name,
81
- messages=[{"role": "user", "content": prompt}],
 
 
 
82
  max_tokens=max_tokens,
83
- temperature=0.1
 
84
  )
85
 
86
  content = response.choices[0].message.content
87
- confidence = 0.92 if "gpt-4o" in self.model_name else 0.85
 
 
 
88
 
89
  return content, confidence
 
90
  except Exception as e:
91
  logger.error(f"OpenAI API error: {e}")
92
- return '{"error": "API call failed"}', 0.1
93
 
94
  class SchemaAnalyzer:
95
  def analyze_complexity(self, schema: Dict[str, Any]) -> ComplexityMetrics:
@@ -147,12 +155,23 @@ class SchemaAnalyzer:
147
  )
148
 
149
  def create_extraction_plan(self, schema: Dict[str, Any], complexity: ComplexityMetrics) -> ExtractionPlan:
150
- if complexity.complexity_tier == 1:
151
- return self._create_simple_plan(schema)
152
- elif complexity.complexity_tier == 2:
153
- return self._create_medium_plan(schema)
154
- else:
155
- return self._create_complex_plan(schema)
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  def _create_simple_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
158
  stages = [ExtractionStage(
@@ -167,7 +186,7 @@ class SchemaAnalyzer:
167
  stages=stages,
168
  estimated_cost=0.02,
169
  estimated_time=5.0,
170
- model_assignments={"complete_extraction": "gpt-4o-mini"}
171
  )
172
 
173
  def _create_medium_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
@@ -408,38 +427,85 @@ class ExtractionEngine:
408
  return context
409
 
410
  def _create_extraction_prompt(self, context: str, schema: Dict[str, Any], previous_results: Dict[str, Any]) -> str:
411
- return f"""Extract structured data from the following content according to the JSON schema provided.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
- Context:
414
- {context}
415
 
416
- JSON Schema:
417
- {json.dumps(schema, indent=2)}
418
 
419
- Instructions:
420
- 1. Extract only the fields specified in the schema
421
- 2. Ensure the output is valid JSON
422
- 3. If a field cannot be determined from the content, use null
423
- 4. Be precise and follow the schema constraints exactly
424
- 5. Use previous results as context when relevant
 
425
 
426
- Output the extracted data as a JSON object:"""
427
 
428
  def _parse_response(self, response: str, expected_fields: List[str]) -> Dict[str, Any]:
429
  try:
430
- data = json.loads(response)
431
- return data
432
- except json.JSONDecodeError:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  try:
434
- json_match = re.search(r'\{.*\}', response, re.DOTALL)
435
- if json_match:
436
- data = json.loads(json_match.group())
437
- return data
438
- except:
439
- pass
 
 
 
 
 
 
 
 
440
 
441
- logger.warning("Failed to parse JSON response, using fallback")
442
- return {field: f"extracted_value_for_{field}" for field in expected_fields[:2]}
443
 
444
  class QualityAssessor:
445
  def assess_extraction(self, result: ExtractionResult, schema: Dict[str, Any]) -> QualityReport:
@@ -448,24 +514,20 @@ class QualityAssessor:
448
  consistency_score = self._check_consistency(result.data)
449
 
450
  required_fields = schema.get('required', [])
 
 
 
 
451
 
452
  if field_scores:
453
- total_weight = 0
454
- weighted_confidence = 0
455
-
456
- for field, confidence in field_scores.items():
457
- weight = 2.0 if field in required_fields else 1.0
458
- weighted_confidence += confidence * weight
459
- total_weight += weight
460
-
461
- avg_field_confidence = weighted_confidence / total_weight
462
  else:
463
  avg_field_confidence = 0
464
 
465
- overall_confidence = avg_field_confidence * (0.8 + 0.2 * schema_compliance) * (0.9 + 0.1 * consistency_score)
466
  overall_confidence = min(overall_confidence, 1.0)
467
 
468
- review_flags = self._generate_review_flags(field_scores, schema_compliance, overall_confidence, required_fields, result.data)
469
  review_time = self._estimate_review_time(review_flags, field_scores)
470
 
471
  return QualityReport(
@@ -538,24 +600,31 @@ class QualityAssessor:
538
 
539
  return max(0.7, consistency_score)
540
 
541
- def _generate_review_flags(self, field_scores: Dict[str, float], schema_compliance: float, overall_confidence: float, required_fields: List[str], extracted_data: Dict[str, Any]) -> List[str]:
542
  flags = []
543
 
 
 
 
 
 
 
 
544
  if overall_confidence < 0.6:
545
- flags.append("high_priority_review")
546
  elif overall_confidence < 0.8:
547
- flags.append("standard_review")
548
 
549
- if schema_compliance < 0.8:
550
- flags.append("schema_compliance_issues")
551
-
552
- low_confidence_fields = [field for field, score in field_scores.items() if score < 0.7]
553
- if low_confidence_fields:
554
- flags.append(f"uncertain_fields: {', '.join(low_confidence_fields[:3])}")
555
 
556
  missing_required = [field for field in required_fields if field not in extracted_data or extracted_data[field] is None]
557
  if missing_required:
558
- flags.append(f"missing_required: {', '.join(missing_required[:3])}")
 
 
 
 
559
 
560
  return flags
561
 
 
78
  try:
79
  response = await self.client.chat.completions.create(
80
  model=self.model_name,
81
+ messages=[
82
+ {"role": "system", "content": "You are a precise data extraction specialist. Extract data according to the provided schema and output only valid JSON."},
83
+ {"role": "user", "content": prompt}
84
+ ],
85
  max_tokens=max_tokens,
86
+ temperature=0.1,
87
+ top_p=0.9
88
  )
89
 
90
  content = response.choices[0].message.content
91
+ confidence = 0.9 if "gpt-4o" in self.model_name else 0.8
92
+
93
+ if content and len(content.strip()) > 10:
94
+ confidence += 0.05
95
 
96
  return content, confidence
97
+
98
  except Exception as e:
99
  logger.error(f"OpenAI API error: {e}")
100
+ return '{"error": "API call failed", "details": "' + str(e) + '"}', 0.1
101
 
102
  class SchemaAnalyzer:
103
  def analyze_complexity(self, schema: Dict[str, Any]) -> ComplexityMetrics:
 
155
  )
156
 
157
  def create_extraction_plan(self, schema: Dict[str, Any], complexity: ComplexityMetrics) -> ExtractionPlan:
158
+ return self._create_single_pass_plan(schema)
159
+
160
+ def _create_single_pass_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
161
+ stages = [ExtractionStage(
162
+ name="complete_extraction",
163
+ fields=list(schema.get('properties', {}).keys()),
164
+ schema_subset=schema,
165
+ complexity=2,
166
+ estimated_tokens=4000
167
+ )]
168
+
169
+ return ExtractionPlan(
170
+ stages=stages,
171
+ estimated_cost=0.15,
172
+ estimated_time=15.0,
173
+ model_assignments={"complete_extraction": "gpt-4o"}
174
+ )
175
 
176
  def _create_simple_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
177
  stages = [ExtractionStage(
 
186
  stages=stages,
187
  estimated_cost=0.02,
188
  estimated_time=5.0,
189
+ model_assignments={"complete_extraction": "gpt-4o"}
190
  )
191
 
192
  def _create_medium_plan(self, schema: Dict[str, Any]) -> ExtractionPlan:
 
427
  return context
428
 
429
  def _create_extraction_prompt(self, context: str, schema: Dict[str, Any], previous_results: Dict[str, Any]) -> str:
430
+ schema_properties = schema.get('properties', {})
431
+ required_fields = schema.get('required', [])
432
+
433
+ field_descriptions = []
434
+ for field_name, field_def in schema_properties.items():
435
+ if isinstance(field_def, dict):
436
+ field_type = field_def.get('type', 'string')
437
+ is_required = field_name in required_fields
438
+ status = "REQUIRED" if is_required else "optional"
439
+ field_descriptions.append(f"- {field_name} ({field_type}) [{status}]")
440
+
441
+ previous_context = ""
442
+ if previous_results:
443
+ previous_context = f"\n\nPreviously extracted data:\n{json.dumps(previous_results, indent=2)}"
444
+
445
+ return f"""Extract ALL specified fields from the document content according to the JSON schema.
446
+
447
+ DOCUMENT CONTENT:
448
+ {context[:4000]}
449
 
450
+ REQUIRED OUTPUT FIELDS:
451
+ {chr(10).join(field_descriptions)}
452
 
453
+ SCHEMA STRUCTURE:
454
+ {json.dumps(schema, indent=2)}{previous_context}
455
 
456
+ CRITICAL INSTRUCTIONS:
457
+ 1. Extract ALL fields specified in the schema properties
458
+ 2. For arrays, extract ALL items found in the content
459
+ 3. For objects, extract ALL nested properties
460
+ 4. Use null only if data truly cannot be found
461
+ 5. Maintain exact schema structure and types
462
+ 6. Output ONLY valid JSON, no explanations
463
 
464
+ JSON OUTPUT:"""
465
 
466
  def _parse_response(self, response: str, expected_fields: List[str]) -> Dict[str, Any]:
467
  try:
468
+ cleaned_response = response.strip()
469
+
470
+ if not cleaned_response.startswith('{'):
471
+ json_start = cleaned_response.find('{')
472
+ if json_start != -1:
473
+ cleaned_response = cleaned_response[json_start:]
474
+
475
+ if not cleaned_response.endswith('}'):
476
+ json_end = cleaned_response.rfind('}')
477
+ if json_end != -1:
478
+ cleaned_response = cleaned_response[:json_end + 1]
479
+
480
+ data = json.loads(cleaned_response)
481
+
482
+ if isinstance(data, dict):
483
+ return data
484
+ else:
485
+ logger.warning("Response is not a dictionary")
486
+ return {}
487
+
488
+ except json.JSONDecodeError as e:
489
+ logger.warning(f"JSON decode error: {e}")
490
+
491
  try:
492
+ import re
493
+ json_pattern = r'\{(?:[^{}]|{(?:[^{}]|{[^{}]*})*})*\}'
494
+ matches = re.findall(json_pattern, response, re.DOTALL)
495
+
496
+ for match in matches:
497
+ try:
498
+ data = json.loads(match)
499
+ if isinstance(data, dict) and data:
500
+ return data
501
+ except:
502
+ continue
503
+
504
+ except Exception as e:
505
+ logger.warning(f"Regex parsing failed: {e}")
506
 
507
+ logger.error("All JSON parsing attempts failed")
508
+ return {}
509
 
510
  class QualityAssessor:
511
  def assess_extraction(self, result: ExtractionResult, schema: Dict[str, Any]) -> QualityReport:
 
514
  consistency_score = self._check_consistency(result.data)
515
 
516
  required_fields = schema.get('required', [])
517
+ total_expected_fields = len(schema.get('properties', {}))
518
+ extracted_fields = len([k for k, v in result.data.items() if v is not None])
519
+
520
+ completeness_score = extracted_fields / total_expected_fields if total_expected_fields > 0 else 0
521
 
522
  if field_scores:
523
+ avg_field_confidence = sum(field_scores.values()) / len(field_scores)
 
 
 
 
 
 
 
 
524
  else:
525
  avg_field_confidence = 0
526
 
527
+ overall_confidence = completeness_score * 0.6 + schema_compliance * 0.3 + consistency_score * 0.1
528
  overall_confidence = min(overall_confidence, 1.0)
529
 
530
+ review_flags = self._generate_review_flags(field_scores, schema_compliance, overall_confidence, required_fields, result.data, total_expected_fields, extracted_fields)
531
  review_time = self._estimate_review_time(review_flags, field_scores)
532
 
533
  return QualityReport(
 
600
 
601
  return max(0.7, consistency_score)
602
 
603
+ def _generate_review_flags(self, field_scores: Dict[str, float], schema_compliance: float, overall_confidence: float, required_fields: List[str], extracted_data: Dict[str, Any], total_expected: int, extracted_count: int) -> List[str]:
604
  flags = []
605
 
606
+ completeness_rate = extracted_count / total_expected if total_expected > 0 else 0
607
+
608
+ if completeness_rate < 0.5:
609
+ flags.append("incomplete_extraction")
610
+ elif completeness_rate < 0.8:
611
+ flags.append("partial_extraction")
612
+
613
  if overall_confidence < 0.6:
614
+ flags.append("low_quality")
615
  elif overall_confidence < 0.8:
616
+ flags.append("moderate_quality")
617
 
618
+ if schema_compliance < 0.7:
619
+ flags.append("schema_violations")
 
 
 
 
620
 
621
  missing_required = [field for field in required_fields if field not in extracted_data or extracted_data[field] is None]
622
  if missing_required:
623
+ flags.append(f"missing_required_fields")
624
+
625
+ empty_fields = [k for k, v in extracted_data.items() if v is None or v == ""]
626
+ if len(empty_fields) > total_expected * 0.3:
627
+ flags.append("many_empty_fields")
628
 
629
  return flags
630