Spaces:
Sleeping
Sleeping
# core.py - Enhanced with Text Quality Assessment
import pyiqa | |
import torch | |
from PIL import Image | |
import glob | |
import logging | |
import numpy as np | |
import cv2 | |
import easyocr | |
from typing import Dict, List, Tuple, Optional | |
import warnings | |
warnings.filterwarnings("ignore") | |
class TextQualityAssessor:
    """Specialized text quality assessment using OCR confidence scores.

    An EasyOCR reader (English only) is created once per instance, on the GPU
    when ``torch.cuda.is_available()`` reports one. Each detected text region
    is mapped to a penalty according to its OCR confidence, and the penalties
    are averaged into a 0-100 text quality score.
    """

    # (minimum confidence, penalty points), checked from best to worst tier.
    _PENALTY_TIERS = (
        (0.9, 0),   # Excellent text
        (0.8, 5),   # Good text
        (0.6, 15),  # Readable but poor quality
        (0.4, 30),  # Heavily distorted
    )
    # Penalty for anything below the last tier: severely distorted/unreadable.
    _WORST_PENALTY = 50

    def __init__(self):
        """Create the OCR reader used by :meth:`assess_text_quality`."""
        self.ocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

    @classmethod
    def _penalty_for(cls, confidence: float) -> int:
        """Return the penalty points for a single region's OCR confidence."""
        for minimum, penalty in cls._PENALTY_TIERS:
            if confidence >= minimum:
                return penalty
        return cls._WORST_PENALTY

    @classmethod
    def _score_confidences(
        cls,
        confidences: List[float],
        low_quality_threshold: float = 0.8,
    ) -> Tuple[float, int]:
        """Turn per-region confidences into a score and a low-quality count.

        Args:
            confidences: Non-empty list of OCR confidences, one per region.
            low_quality_threshold: Regions strictly below this confidence are
                counted as low quality.

        Returns:
            ``(text_quality_score, low_quality_regions)`` where the score is
            ``100 - mean(penalty)``, floored at 0 and multiplied by 0.7 when
            more than half of the regions are low quality.
        """
        low_quality_regions = sum(
            1 for conf in confidences if conf < low_quality_threshold
        )
        penalties = [cls._penalty_for(conf) for conf in confidences]
        score = max(0.0, 100.0 - sum(penalties) / len(penalties))

        # Additional 30% penalty when more than half the regions are poor.
        if low_quality_regions / len(confidences) > 0.5:
            score *= 0.7
        return float(score), int(low_quality_regions)

    def assess_text_quality(self, image: "Image.Image") -> Dict:
        """Assess text quality using OCR confidence and detection metrics.

        Args:
            image: PIL image in any mode; it is converted to RGB before OCR.

        Returns:
            Dict with the keys ``text_detected``, ``text_quality_score``,
            ``avg_confidence``, ``text_regions``, ``low_quality_regions`` and
            ``details``. When no text is found the score is 100.0 (no text
            means no text quality issues). When any exception occurs the
            error is logged and a neutral score of 50.0 is returned with
            ``text_detected`` set to False.
        """
        try:
            # Greyscale/palette images yield a 2-D array, which the
            # RGB -> BGR conversion rejects; force three channels first.
            rgb_array = np.array(image.convert("RGB"))
            cv_image = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)

            # detail=1 makes each result carry its confidence at index 2.
            results = self.ocr_reader.readtext(cv_image, detail=1)
            if not results:
                return {
                    'text_detected': False,
                    'text_quality_score': 100.0,  # No text = no text quality issues
                    'avg_confidence': 1.0,
                    'text_regions': 0,
                    'low_quality_regions': 0,
                    'details': "No text detected"
                }

            # Plain floats keep the returned dict free of NumPy scalar types.
            confidences = [float(result[2]) for result in results]
            avg_confidence = float(np.mean(confidences))
            text_quality_score, low_quality_regions = self._score_confidences(
                confidences
            )

            return {
                'text_detected': True,
                'text_quality_score': text_quality_score,
                'avg_confidence': avg_confidence,
                'text_regions': len(results),
                'low_quality_regions': low_quality_regions,
                'details': f"Detected {len(results)} text regions, avg confidence: {avg_confidence:.3f}"
            }
        except Exception as e:
            # Deliberate best-effort: OCR failure must not abort the caller.
            logging.error("Text quality assessment error: %s", e)
            return {
                'text_detected': False,
                'text_quality_score': 50.0,  # Neutral score on error
                'avg_confidence': 0.0,
                'text_regions': 0,
                'low_quality_regions': 0,
                'details': f"Error: {str(e)}"
            }
class HybridIQA:
    """Enhanced IQA with text-specific quality assessment.

    Wraps a pyiqa metric and a :class:`TextQualityAssessor`. When text is
    detected in the image the two scores are blended using ``text_weight``;
    otherwise the pyiqa score is returned on its own.
    """

    def __init__(self, model_name="qualiclip+", text_weight=0.3):
        """Load the pyiqa metric and the OCR-based text assessor.

        Args:
            model_name: Metric name passed to ``pyiqa.create_metric``.
            text_weight: Weight of the text quality score in the combined
                score; the pyiqa score gets ``1 - text_weight``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = pyiqa.create_metric(model_name, device=device)
        self.text_assessor = TextQualityAssessor()
        self.text_weight = text_weight  # Weight for text quality in final score
        self.model_name = model_name
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        self.logger.info(
            "Hybrid IQA loaded: %s + Text Quality Assessment on %s",
            model_name,
            device,
        )

    @staticmethod
    def _normalize_score(score):
        """Rescale a score lying in [0, 1] to the 0-100 range.

        Values outside [0, 1] are returned unchanged.
        NOTE(review): a metric with a wider native range that happens to
        return a value in [0, 1] is rescaled too - confirm per metric.
        """
        if 0 <= score <= 1:
            score *= 100
        return score

    @staticmethod
    def _combine_scores(traditional_score, text_analysis, text_weight):
        """Blend the traditional score with the text quality score.

        Args:
            traditional_score: Metric score, already on the 0-100 scale.
            text_analysis: Dict with ``text_detected`` and
                ``text_quality_score`` keys.
            text_weight: Weight given to the text quality score.

        Returns:
            The traditional score when no text was detected; otherwise the
            weighted blend, reduced by a further 20% when the text quality
            score is below 30.
        """
        if not text_analysis['text_detected']:
            return traditional_score

        text_score = text_analysis['text_quality_score']
        combined = (1 - text_weight) * traditional_score + text_weight * text_score
        if text_score < 30:
            combined *= 0.8  # 20% additional penalty for severely poor text
        return combined

    def __call__(self, image, return_details=False):
        """
        Evaluate image quality with both traditional IQA and text-specific assessment

        Args:
            image: PIL Image or path to image
            return_details: If True, return detailed breakdown

        Returns:
            If return_details=False: Combined quality score (0-100), or None
                when processing failed.
            If return_details=True: Dict with detailed scores and analysis,
                or ``{'error': message}`` when processing failed.
        """
        try:
            if isinstance(image, Image.Image):
                image = image.convert("RGB")
            else:
                # Close the underlying file once the pixels are converted.
                with Image.open(image) as opened:
                    image = opened.convert("RGB")

            # Get traditional IQA score
            if self.model_name == 'qalign':
                # Q-Align has special interface for quality assessment
                traditional_score = self.model(image, task_='quality')
            else:
                traditional_score = self.model(image)

            # Unwrap tensor-like results into a plain Python number.
            if hasattr(traditional_score, 'item'):
                traditional_score = traditional_score.item()
            traditional_score = self._normalize_score(traditional_score)

            text_analysis = self.text_assessor.assess_text_quality(image)
            combined_score = self._combine_scores(
                traditional_score, text_analysis, self.text_weight
            )

            if return_details:
                return {
                    'combined_score': combined_score,
                    'traditional_score': traditional_score,
                    'text_analysis': text_analysis,
                    'model_used': self.model_name,
                    'text_weight': self.text_weight
                }
            return combined_score
        except Exception as e:
            # Best-effort API: report the failure instead of raising.
            self.logger.error("Error processing image: %s", e)
            return None if not return_details else {'error': str(e)}
# Backward compatibility - maintain original IQA interface
class IQA(HybridIQA):
    """Single-score facade over HybridIQA with a fixed text weight of 0.3."""

    def __init__(self, model_name="qualiclip+"):
        """Build the hybrid pipeline for ``model_name`` with text weight 0.3."""
        super().__init__(model_name=model_name, text_weight=0.3)

    def __call__(self, image):
        """Return only the combined score for ``image`` (original interface)."""
        score = super().__call__(image, return_details=False)
        return score

    def detailed_analysis(self, image):
        """Return the detailed breakdown for ``image`` instead of one score."""
        breakdown = super().__call__(image, return_details=True)
        return breakdown
# Advanced usage class for power users
class TextAwareIQA:
    """Advanced interface with configurable text assessment parameters.

    Holds a :class:`HybridIQA` instance and recomputes its combined score
    under a selectable text penalty mode.
    """

    def __init__(self, model_name="qualiclip+", text_weight=0.3, text_threshold=0.8):
        """Create the underlying hybrid pipeline.

        Args:
            model_name: Metric name forwarded to ``HybridIQA``.
            text_weight: Text weight forwarded to ``HybridIQA``; also the
                weight used by the 'balanced' mode of :meth:`evaluate`.
            text_threshold: Stored on the instance; not consulted by
                :meth:`evaluate`.
        """
        self.hybrid_iqa = HybridIQA(model_name, text_weight)
        self.text_threshold = text_threshold

    def evaluate(self, image, text_penalty_mode='balanced'):
        """
        Evaluate with different text penalty modes

        Args:
            image: PIL Image or path
            text_penalty_mode: 'strict', 'balanced', or 'lenient'. Any other
                value is treated as 'balanced'.

        Returns:
            The details dict from ``HybridIQA`` with ``combined_score``
            recomputed (only when text was detected) and a ``penalty_mode``
            key added. A ``None`` or error result from ``HybridIQA`` is
            returned unchanged.
        """
        details = self.hybrid_iqa(image, return_details=True)
        if details is None or 'error' in details:
            return details

        text_analysis = details['text_analysis']
        if text_analysis['text_detected']:
            text_score = text_analysis['text_quality_score']
            traditional_score = details['traditional_score']

            if text_penalty_mode == 'strict':
                # Heavily penalize any text quality issues
                weight = 0.5
                if text_score < 70:
                    text_score *= 0.6
            elif text_penalty_mode == 'lenient':
                # Only penalize severe text issues
                weight = 0.1
                if text_score > 40:
                    text_score = min(text_score * 1.2, 100)
            else:  # balanced
                # Honour the weight the caller configured instead of a
                # hard-coded 0.3 (identical for the default configuration).
                weight = self.hybrid_iqa.text_weight

            details['combined_score'] = (
                (1 - weight) * traditional_score + weight * text_score
            )

        details['penalty_mode'] = text_penalty_mode
        return details
if __name__ == "__main__":
    # Demo: run both interfaces over the first few files in ./samples.
    print("Testing Hybrid IQA System")
    print("=" * 50)

    # Original interface (backward compatible)
    print("\n1. Original Interface (Backward Compatible):")
    iqa_metric = IQA(model_name="qualiclip+")

    # Advanced interface
    print("\n2. Advanced Interface:")
    advanced_iqa = TextAwareIQA(model_name="qualiclip+", text_weight=0.4)

    # Sort so the three files analysed are the same on every run.
    image_files = sorted(glob.glob("samples/*"))
    if not image_files:
        print("No images found in samples directory. Please add images or adjust the path.")
    else:
        for image_file in image_files[:3]:  # Test first 3 images
            print(f"\nAnalyzing: {image_file}")

            # Original score
            score = iqa_metric(image_file)
            if score is not None:
                print(f" Simple Score: {score:.2f}/100")

            # Detailed analysis
            details = iqa_metric.detailed_analysis(image_file)
            if details and 'error' not in details:
                text_analysis = details['text_analysis']
                print(f" Traditional IQA: {details['traditional_score']:.2f}/100")
                print(f" Text Quality: {text_analysis['text_quality_score']:.2f}/100")
                print(f" Combined Score: {details['combined_score']:.2f}/100")
                print(f" Text Details: {text_analysis['details']}")
                if text_analysis['text_detected']:
                    print(f" Text Regions: {text_analysis['text_regions']}")
                    print(f" Low Quality Regions: {text_analysis['low_quality_regions']}")
                    print(f" Avg OCR Confidence: {text_analysis['avg_confidence']:.3f}")

            # Exercise the advanced interface, which was previously built
            # but never called.
            advanced = advanced_iqa.evaluate(image_file, text_penalty_mode='strict')
            if advanced and 'error' not in advanced:
                print(f" Advanced Score ({advanced['penalty_mode']}): {advanced['combined_score']:.2f}/100")

            print("-" * 30)