
Historical OCR Application Improvements

Based on a thorough code review of the Historical OCR application, I've identified several areas for improvement to reduce technical debt and enhance the application's functionality, maintainability, and performance.

1. Code Organization and Structure

1.1 Modularize Large Functions

Several functions in the codebase are excessively long and handle multiple responsibilities:

  • Issue: process_file() in ocr_processing.py is over 400 lines and handles file validation, preprocessing, OCR processing, and result formatting.
  • Solution: Break down into smaller, focused functions:
    def process_file(uploaded_file, options):
        # Validate and prepare file
        file_info = validate_and_prepare_file(uploaded_file)
        
        # Apply preprocessing based on document type
        preprocessed_file = preprocess_document(file_info, options)
        
        # Perform OCR processing
        ocr_result = perform_ocr(preprocessed_file, options)
        
        # Format and enhance results
        return format_and_enhance_results(ocr_result, file_info)
    

1.2 Consistent Error Handling

Error handling approaches vary across modules:

  • Issue: Some functions use try/except blocks with detailed logging, while others return error dictionaries or raise exceptions.
  • Solution: Implement a consistent error handling strategy:
    import functools
    import logging

    logger = logging.getLogger(__name__)

    class OCRError(Exception):
        """Application-level error with a machine-readable code and optional details."""
        def __init__(self, message, error_code=None, details=None):
            self.message = message
            self.error_code = error_code
            self.details = details
            super().__init__(self.message)
            
    def handle_error(func):
        """Convert any exception raised by func into a uniform error dictionary."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except OCRError as e:
                logger.error(f"OCR Error: {e.message} (Code: {e.error_code})")
                return {"error": e.message, "error_code": e.error_code, "details": e.details}
            except Exception as e:
                # logger.exception also records the traceback
                logger.exception(f"Unexpected error: {e}")
                return {"error": "An unexpected error occurred", "details": str(e)}
        return wrapper
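A minimal usage sketch, with `OCRError` and `handle_error` repeated so the block runs on its own: every decorated stage returns either its normal value or a dictionary with an `error` key, whatever went wrong. `parse_page_count` is a hypothetical stage invented for this example.

```python
import functools
import logging

logger = logging.getLogger(__name__)


class OCRError(Exception):
    def __init__(self, message, error_code=None, details=None):
        self.message = message
        self.error_code = error_code
        self.details = details
        super().__init__(self.message)


def handle_error(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except OCRError as e:
            logger.error(f"OCR Error: {e.message} (Code: {e.error_code})")
            return {"error": e.message, "error_code": e.error_code, "details": e.details}
        except Exception as e:
            logger.exception(f"Unexpected error: {e}")
            return {"error": "An unexpected error occurred", "details": str(e)}
    return wrapper


@handle_error
def parse_page_count(value):
    # Hypothetical stage: raises OCRError for a known bad input,
    # and lets int() raise ValueError for anything non-numeric
    if value == "":
        raise OCRError("Empty page count", "EMPTY_INPUT")
    return int(value)


print(parse_page_count("12"))  # 12
print(parse_page_count(""))    # {'error': 'Empty page count', 'error_code': 'EMPTY_INPUT', 'details': None}
```

One trade-off: because failures become return values, callers must check for the `error` key, and stages whose normal output is itself a dictionary should avoid using that key.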
    

2. API Integration and Performance

2.1 API Client Optimization

The Mistral API client initialization and usage can be improved:

  • Issue: The client is initialized for each request and error handling is duplicated.
  • Solution: Create a singleton API client with centralized error handling:
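    import logging
    import os

    from mistralai import Mistral

    logger = logging.getLogger(__name__)

    class MistralClient:
        _instance = None
        
        @classmethod
        def get_instance(cls, api_key=None):
            # Note: api_key is only used on the first call; later calls return the existing client
            if cls._instance is None:
                cls._instance = cls(api_key)
            return cls._instance
            
        def __init__(self, api_key=None):
            self.api_key = api_key or os.environ.get("MISTRAL_API_KEY", "")
            self.client = Mistral(api_key=self.api_key)
            
        def _handle_api_error(self, error):
            # Log once here, then surface the failure as the OCRError type from section 1.2
            logger.error(f"Mistral API error: {error}")
            raise OCRError("OCR request to the Mistral API failed", "API_ERROR", str(error)) from error
            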
        def process_ocr(self, document, **kwargs):
            try:
                return self.client.ocr.process(document=document, **kwargs)
            except Exception as e:
                # Centralized error handling
                return self._handle_api_error(e)
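Streamlit runs each user session's script in its own thread, so two sessions calling `get_instance()` at the same moment could both see `_instance is None` and build two clients. A sketch of a lock-protected (double-checked) variant follows; `LazySingleton` and `make_client` are hypothetical names, with `make_client` standing in for `Mistral(api_key=...)`. Note also that `get_instance(api_key)` in the class above silently ignores `api_key` after the first call.

```python
import threading


class LazySingleton:
    """Thread-safe lazy singleton: the factory runs at most once per process."""

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls, factory):
        # Fast path: skip locking once the instance exists
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created it meanwhile
                if cls._instance is None:
                    cls._instance = factory()
        return cls._instance


calls = []


def make_client():
    # Hypothetical stand-in for Mistral(api_key=...); counts how often it runs
    calls.append(1)
    return object()


threads = [threading.Thread(target=LazySingleton.get_instance, args=(make_client,))
           for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(len(calls))  # 1
```

Streamlit's own `st.cache_resource` decorator is another way to share a single client object across sessions.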
    

2.2 Caching Strategy

The current caching approach can be improved:

  • Issue: Cache keys don't always account for all relevant parameters, and TTL is fixed at 24 hours.
  • Solution: Implement a more sophisticated caching strategy:
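    import hashlib
    import json

    def generate_cache_key(file_content, options):
        # Hash the file bytes and a canonical (key-sorted) JSON dump of all options,
        # so a change to either one yields a different key.
        # MD5 serves only as a fast fingerprint here, not as a security measure.
        options_str = json.dumps(options, sort_keys=True, default=str)
        content_hash = hashlib.md5(file_content).hexdigest()
        return f"{content_hash}_{hashlib.md5(options_str.encode()).hexdigest()}"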
        
    # Adaptive TTL based on document type
    def get_cache_ttl(document_type):
        ttl_map = {
            "handwritten": 48 * 3600,  # 48 hours for handwritten docs
            "newspaper": 24 * 3600,    # 24 hours for newspapers
            "standard": 12 * 3600      # 12 hours for standard docs
        }
        return ttl_map.get(document_type, 24 * 3600)
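The two functions above still need a store that honors a per-entry TTL. Streamlit's `st.cache_data(ttl=...)` sets one TTL for the whole decorated function, so a per-document-type TTL calls for something like the minimal in-memory cache below. `TTLCache` and `cached_ocr` are hypothetical names; the cache is process-local and not thread-safe, and the clock is injectable so expiry can be checked without waiting.

```python
import hashlib
import json
import time


def generate_cache_key(file_content, options):
    options_str = json.dumps(options, sort_keys=True, default=str)
    content_hash = hashlib.md5(file_content).hexdigest()
    return f"{content_hash}_{hashlib.md5(options_str.encode()).hexdigest()}"


def get_cache_ttl(document_type):
    ttl_map = {"handwritten": 48 * 3600, "newspaper": 24 * 3600, "standard": 12 * 3600}
    return ttl_map.get(document_type, 24 * 3600)


class TTLCache:
    """In-memory cache where each entry carries its own expiry time."""

    def __init__(self, clock=time.monotonic):
        self._store = {}     # key -> (expires_at, value)
        self._clock = clock  # injectable so tests need not sleep

    def get(self, key):
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if self._clock() >= expires_at:
            del self._store[key]  # expired: evict lazily on read
            return None
        return value

    def set(self, key, value, ttl_seconds):
        self._store[key] = (self._clock() + ttl_seconds, value)


def cached_ocr(cache, file_content, options, run_ocr):
    """Return a fresh cached result, or run OCR and cache it with a type-specific TTL."""
    key = generate_cache_key(file_content, options)
    result = cache.get(key)
    if result is None:
        result = run_ocr(file_content, options)
        cache.set(key, result, get_cache_ttl(options.get("document_type")))
    return result


# Demo with a fake clock and a fake OCR function
now = [0.0]
cache = TTLCache(clock=lambda: now[0])
calls = []


def fake_ocr(content, options):
    calls.append(1)
    return {"text": "hello"}


opts = {"document_type": "standard"}
cached_ocr(cache, b"img", opts, fake_ocr)  # miss: runs OCR
cached_ocr(cache, b"img", opts, fake_ocr)  # hit: served from cache
now[0] += 12 * 3600                        # the 12-hour "standard" TTL elapses
cached_ocr(cache, b"img", opts, fake_ocr)  # expired: runs OCR again
print(len(calls))  # 2
```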
    

3. State Management

3.1 Streamlit Session State

The application uses a complex state management approach:

  • Issue: Many session state variables with unclear relationships and reset logic.
  • Solution: Implement a more structured state management approach:
    import os

    import streamlit as st

    class DocumentState:
        def __init__(self):
            self.document = None
            self.original_bytes = None
            self.name = None
            self.mime_type = None
            self.is_sample = False
            self.processed = False
            self.temp_files = []
            
        def reset(self):
            # Clean up temp files
            for temp_file in self.temp_files:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
            
            # Reset state
            self.__init__()
            
    # Initialize in session state
    if 'document_state' not in st.session_state:
        st.session_state.document_state = DocumentState()
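The same state can be sketched as a dataclass, so each field's default is declared once and adding a field is a one-line change. This is a suggested refactor, not existing code: it keeps the in-place reset (the object stored in `st.session_state` keeps its identity) and deletes temp files without a separate `os.path.exists()` check, which avoids a race if a file disappears in between. The demo uses a real temporary file and no Streamlit.

```python
import os
import tempfile
from dataclasses import dataclass, field, fields
from typing import Any, List, Optional


@dataclass
class DocumentState:
    document: Any = None
    original_bytes: Optional[bytes] = None
    name: Optional[str] = None
    mime_type: Optional[str] = None
    is_sample: bool = False
    processed: bool = False
    temp_files: List[str] = field(default_factory=list)

    def reset(self):
        # Remove temp files; try/except avoids a race between exists() and unlink()
        for path in self.temp_files:
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass  # already gone
        # Restore every field to its declared default, in place
        fresh = DocumentState()
        for f in fields(self):
            setattr(self, f.name, getattr(fresh, f.name))


state = DocumentState(name="page1.jpg", processed=True)
fd, path = tempfile.mkstemp(suffix=".jpg")
os.close(fd)
state.temp_files.append(path)
state.reset()
print(os.path.exists(path), state.name, state.temp_files)  # False None []
```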
    

3.2 Result History Management

The current approach to managing result history can be improved:

  • Issue: Results are stored directly in session state with limited management.
  • Solution: Create a dedicated class for result history:
    import json
    from datetime import datetime

    class ResultHistory:
        def __init__(self, max_results=20):
            self.results = []
            self.max_results = max_results
            
        def add_result(self, result):
            # Normalize the result, then store it newest first
            result = self._prepare_result(result)
            self.results.insert(0, result)
            
            # Trim to max size
            if len(self.results) > self.max_results:
                self.results = self.results[:self.max_results]
                
        def _prepare_result(self, result):
            # Copy so the caller's dict is not mutated, then add a timestamp
            result = result.copy()
            result['timestamp'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            
            # Ensure result is serializable
            return json.loads(json.dumps(result, default=str))
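A `collections.deque` with `maxlen` gives the same newest-first, capped behavior without manual slicing: when the deque is full, `appendleft()` discards an item from the opposite (oldest) end. A minimal sketch, assuming the rest of the app only iterates over or indexes `results` (deques do not support slicing):

```python
import json
from collections import deque
from datetime import datetime


class ResultHistory:
    """Newest-first result history capped at max_results entries."""

    def __init__(self, max_results=20):
        # A full deque drops from the opposite end on append, so
        # appendleft() keeps the newest results and discards the oldest
        self.results = deque(maxlen=max_results)

    def add_result(self, result):
        self.results.appendleft(self._prepare_result(result))

    def _prepare_result(self, result):
        # Copy, timestamp, and round-trip through JSON so every value is serializable
        result = dict(result)
        result["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return json.loads(json.dumps(result, default=str))


history = ResultHistory(max_results=3)
for i in range(5):
    history.add_result({"file": f"doc{i}.pdf"})
print([r["file"] for r in history.results])  # ['doc4.pdf', 'doc3.pdf', 'doc2.pdf']
```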
    

4. Image Processing Pipeline

4.1 Preprocessing Configuration

The preprocessing configuration can be improved:

  • Issue: Preprocessing options are scattered across different parts of the code.
  • Solution: Create a centralized preprocessing configuration:
    PREPROCESSING_CONFIGS = {
        "standard": {
            "grayscale": True,
            "denoise": True,
            "contrast": 5,
            "deskew": True
        },
        "handwritten": {
            "grayscale": True,
            "denoise": True,
            "contrast": 10,
            "deskew": True,
            "adaptive_threshold": {
                "block_size": 21,
                "constant": 5
            }
        },
        "newspaper": {
            "grayscale": True,
            "denoise": True,
            "contrast": 5,
            "deskew": True,
            "column_detection": True
        }
    }
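Once the profiles live in one dictionary, callers need a safe way to apply per-run tweaks without mutating the shared defaults. A sketch, assuming user overrides arrive as a (possibly nested) dictionary and that unknown document types should fall back to `standard`; `deep_merge` and `get_preprocessing_config` are hypothetical helpers, and only two profiles are repeated here for brevity.

```python
import copy

# Two of the profiles from PREPROCESSING_CONFIGS, repeated for a self-contained example
PREPROCESSING_CONFIGS = {
    "standard": {"grayscale": True, "denoise": True, "contrast": 5, "deskew": True},
    "handwritten": {
        "grayscale": True,
        "denoise": True,
        "contrast": 10,
        "deskew": True,
        "adaptive_threshold": {"block_size": 21, "constant": 5},
    },
}


def deep_merge(base, overrides):
    """Return a new dict: base with overrides applied, recursing into nested dicts."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def get_preprocessing_config(document_type, overrides=None):
    # Assumption: unknown document types fall back to the "standard" profile
    base = PREPROCESSING_CONFIGS.get(document_type, PREPROCESSING_CONFIGS["standard"])
    return deep_merge(base, overrides or {})


cfg = get_preprocessing_config("handwritten", {"contrast": 15, "adaptive_threshold": {"constant": 7}})
print(cfg["contrast"], cfg["adaptive_threshold"])  # 15 {'block_size': 21, 'constant': 7}
```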
    

4.2 Image Segmentation

The image segmentation approach can be improved:

  • Issue: Segmentation is optional and not well-integrated with the preprocessing pipeline.
  • Solution: Make segmentation a standard part of the preprocessing pipeline for certain document types:
    def preprocess_document(file_info, options):
        # Apply basic preprocessing
        preprocessed_file = apply_basic_preprocessing(file_info, options)
        
        # Apply segmentation for specific document types
        if options["document_type"] in ["newspaper", "book", "multi_column"]:
            return apply_segmentation(preprocessed_file, options)
        
        return preprocessed_file
    

5. User Experience Enhancements

5.1 Progressive Loading

Improve the user experience during processing:

  • Issue: The UI can appear frozen during long-running operations.
  • Solution: Implement progressive loading and feedback:
    def process_with_feedback(file, options, progress_callback):
        # Update progress at each step
        progress_callback(10, "Validating document...")
        file_info = validate_and_prepare_file(file)
        
        progress_callback(30, "Preprocessing document...")
        preprocessed_file = preprocess_document(file_info, options)
        
        progress_callback(50, "Performing OCR...")
        ocr_result = perform_ocr(preprocessed_file, options)
        
        progress_callback(80, "Enhancing results...")
        final_result = format_and_enhance_results(ocr_result, file_info)
        
        progress_callback(100, "Complete!")
        return final_result
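Hard-coded percentages have to be re-balanced whenever a step is added. A hypothetical `run_stages` helper can derive them from the list of stages instead. It gives every stage equal weight, which is a simplification: OCR usually dominates the runtime, as the fixed 50 to 80 jump above reflects.

```python
def run_stages(stages, value, progress_callback):
    """Run (label, func) stages in order, feeding each output into the next stage."""
    total = len(stages)
    for index, (label, func) in enumerate(stages):
        # Report the stage about to start, as the share of stages already finished
        progress_callback(int(100 * index / total), label)
        value = func(value)
    progress_callback(100, "Complete!")
    return value


updates = []
result = run_stages(
    [("Validating document...", str.strip), ("Performing OCR...", str.upper)],
    "  page text  ",
    lambda pct, msg: updates.append((pct, msg)),
)
print(result)   # PAGE TEXT
print(updates)  # [(0, 'Validating document...'), (50, 'Performing OCR...'), (100, 'Complete!')]
```

In the Streamlit app, the callback could wrap a progress bar, for example `bar = st.progress(0)` followed by `lambda pct, msg: bar.progress(pct, text=msg)`.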
    

5.2 Result Visualization

Enhance the visualization of OCR results:

  • Issue: Results are displayed in a basic format with limited visualization.
  • Solution: Implement enhanced visualization options:
    def display_enhanced_results(result):
        # Create tabs for different views
        tabs = st.tabs(["Text", "Annotated", "Side-by-Side", "JSON"])
        
        with tabs[0]:
            # Display formatted text
            st.markdown(format_ocr_text(result["ocr_contents"]["raw_text"]))
            
        with tabs[1]:
            # Display annotated image with bounding boxes
            display_annotated_image(result)
            
        with tabs[2]:
            # Display side-by-side comparison
            col1, col2 = st.columns(2)
            with col1:
                st.image(result["original_image"])
            with col2:
                st.markdown(format_ocr_text(result["ocr_contents"]["raw_text"]))
                
        with tabs[3]:
            # Display raw JSON
            st.json(result)
    

6. Testing and Reliability

6.1 Automated Testing

Implement comprehensive testing:

  • Issue: Limited or no automated testing.
  • Solution: Implement unit and integration tests:
    # Unit test for preprocessing
    def test_preprocess_image():
        # Test with various document types
        for doc_type in ["standard", "handwritten", "newspaper"]:
            # Load test image
            with open(f"test_data/{doc_type}_sample.jpg", "rb") as f:
                image_bytes = f.read()
                
            # Apply preprocessing
            options = {"document_type": doc_type, "grayscale": True, "denoise": True}
            result = preprocess_image(image_bytes, options)
            
            # Assert result is not None and different from original
            assert result is not None
            assert result != image_bytes
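The test above depends on sample images under `test_data/`. Pure functions can be covered with no fixtures at all; for instance, the cache-key function from section 2.2 (repeated here) should be stable under option reordering and sensitive to any change in content or options. These are plain functions with bare asserts, so pytest would collect them from a `test_*.py` file as-is; the test names are illustrative.

```python
import hashlib
import json


def generate_cache_key(file_content, options):
    options_str = json.dumps(options, sort_keys=True, default=str)
    content_hash = hashlib.md5(file_content).hexdigest()
    return f"{content_hash}_{hashlib.md5(options_str.encode()).hexdigest()}"


def test_cache_key_ignores_option_order():
    a = generate_cache_key(b"page", {"document_type": "newspaper", "deskew": True})
    b = generate_cache_key(b"page", {"deskew": True, "document_type": "newspaper"})
    assert a == b


def test_cache_key_changes_with_options():
    base = generate_cache_key(b"page", {"document_type": "standard"})
    assert base != generate_cache_key(b"page", {"document_type": "handwritten"})


def test_cache_key_changes_with_content():
    opts = {"document_type": "standard"}
    assert generate_cache_key(b"page-1", opts) != generate_cache_key(b"page-2", opts)
```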
    

6.2 Error Recovery

Implement better error recovery mechanisms:

  • Issue: Errors in one part of the pipeline can cause the entire process to fail.
  • Solution: Implement graceful degradation:
    def process_with_fallbacks(file, options):
        try:
            # Try full processing pipeline
            return full_processing_pipeline(file, options)
        except OCRError as e:
            logger.warning(f"Full pipeline failed: {e.message}. Trying simplified pipeline.")
            try:
                # Try simplified pipeline
                return simplified_processing_pipeline(file, options)
            except Exception as e2:
                logger.error(f"Simplified pipeline failed: {str(e2)}. Falling back to basic OCR.")
                # Fall back to basic OCR
                return basic_ocr_only(file)
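The nested version grows one level per fallback, and its first `except` only catches `OCRError`, so any other exception from the full pipeline skips the fallbacks entirely. A flatter sketch tries an ordered list of strategies and stops at the first success; `run_with_fallbacks` and the three stub strategies are hypothetical.

```python
import logging

logger = logging.getLogger(__name__)


def run_with_fallbacks(strategies, *args, **kwargs):
    """Try each (name, func) in order and return the first successful result.

    Raises RuntimeError, chained to the last failure, if every strategy fails.
    """
    last_error = None
    for name, strategy in strategies:
        try:
            return strategy(*args, **kwargs)
        except Exception as exc:
            logger.warning("%s pipeline failed: %s", name, exc)
            last_error = exc
    raise RuntimeError("All processing strategies failed") from last_error


# Hypothetical stand-ins for the three pipelines
def full(doc):
    raise ValueError("segmentation failed")


def simplified(doc):
    return {"text": doc.upper(), "pipeline": "simplified"}


def basic(doc):
    return {"text": doc, "pipeline": "basic"}


result = run_with_fallbacks([("full", full), ("simplified", simplified), ("basic", basic)], "scan")
print(result)  # {'text': 'SCAN', 'pipeline': 'simplified'}
```

Recording which pipeline produced the result (here, the `pipeline` key) would also let the UI tell users when output quality may be reduced.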
    

7. Documentation and Maintainability

7.1 Code Documentation

Improve code documentation:

  • Issue: Inconsistent documentation across modules.
  • Solution: Implement consistent docstring format and add module-level documentation:
    """
    OCR Processing Module
    
    This module handles the core OCR processing functionality, including:
    - File validation and preparation
    - Image preprocessing
    - OCR processing with Mistral AI
    - Result formatting and enhancement
    
    The main entry point is the `process_file` function.
    """
    
    def process_file(file, options):
        """
        Process a file with OCR.
        
        Args:
            file: The file to process (UploadedFile or bytes)
            options: Dictionary of processing options
                - document_type: Type of document (standard, handwritten, etc.)
                - preprocessing: Dictionary of preprocessing options
                - use_vision: Whether to use vision model
                
        Returns:
            Dictionary containing OCR results and metadata
            
        Raises:
            OCRError: If OCR processing fails
        """
        # Implementation
    

7.2 Configuration Management

Improve configuration management:

  • Issue: Configuration is scattered across multiple files.
  • Solution: Implement a centralized configuration system:
    """
    Configuration Module
    
    This module provides a centralized configuration system for the application.
    """
    
    import os
    import yaml
    from pathlib import Path
    
    class Config:
        _instance = None
        
        @classmethod
        def get_instance(cls):
            if cls._instance is None:
                cls._instance = cls()
            return cls._instance
            
        def __init__(self):
            self.config = {}
            self.load_config()
            
        def load_config(self):
            # Load from config file
            config_path = Path(__file__).parent / "config.yaml"
            if config_path.exists():
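                with open(config_path, "r", encoding="utf-8") as f:
                    # safe_load returns None for an empty file, so fall back to {}
                    self.config = yaml.safe_load(f) or {}
                    
            # Override with environment variables (OCR_FOO -> "foo"); values are plain strings
            for key, value in os.environ.items():
                if key.startswith("OCR_"):
                    config_key = key[4:].lower()
                    self.config[config_key] = value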
                    
        def get(self, key, default=None):
            return self.config.get(key, default)
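The override loop in `load_config` stores every environment value as a string, so `OCR_MAX_WORKERS=4` yields `"4"` and `OCR_USE_VISION=false` yields the non-empty, and therefore truthy, string `"false"`. One way to recover types, assuming override values are written as JSON literals, is sketched below; `env_overrides` is a hypothetical helper.

```python
import json


def env_overrides(environ, prefix="OCR_"):
    """Collect PREFIX_* variables as lowercase config keys, decoding JSON literals.

    Values such as "4", "false" or "[1, 2]" become int/bool/list; anything
    that is not valid JSON is kept as a plain string.
    """
    overrides = {}
    for key, raw in environ.items():
        if not key.startswith(prefix):
            continue
        try:
            value = json.loads(raw)
        except json.JSONDecodeError:
            value = raw
        overrides[key[len(prefix):].lower()] = value
    return overrides


env = {"OCR_MAX_WORKERS": "4", "OCR_USE_VISION": "false", "OCR_LOG_LEVEL": "debug", "HOME": "/root"}
print(env_overrides(env))  # {'max_workers': 4, 'use_vision': False, 'log_level': 'debug'}
```

Inside `load_config`, the loop could then become `self.config.update(env_overrides(os.environ))`. Values that look like JSON but are meant as strings (for example `"123"` used as an identifier) would be converted too, so a typed schema is the stricter option.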
    

8. Security Enhancements

8.1 API Key Management

Improve API key management:

  • Issue: API keys are stored in environment variables with limited validation.
  • Solution: Implement secure API key management:
    def get_api_key():
        # Try to get from secure storage first
        api_key = get_from_secure_storage("mistral_api_key")
        
        # Fall back to environment variable
        if not api_key:
            api_key = os.environ.get("MISTRAL_API_KEY", "")
            
        # Validate key format
        if api_key and not re.match(r'^[A-Za-z0-9_-]{30,}$', api_key):
            logger.warning("API key format appears invalid")
            
        return api_key
    

8.2 Input Validation

Improve input validation:

  • Issue: Limited validation of user inputs.
  • Solution: Implement comprehensive input validation:
    def validate_file(file):
        # Check file size
        if len(file.getvalue()) > MAX_FILE_SIZE:
            raise OCRError("File too large", "FILE_TOO_LARGE")
            
        # Check file type
        file_type = get_file_type(file)
        if file_type not in ALLOWED_FILE_TYPES:
            raise OCRError(f"Unsupported file type: {file_type}", "UNSUPPORTED_FILE_TYPE")
            
        # Check for malicious content
        if is_potentially_malicious(file):
            raise OCRError("File appears to be malicious", "SECURITY_RISK")
            
        return file_type
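`get_file_type` is not shown above. One option is to trust the file's leading bytes rather than its extension or the browser-reported MIME type, since both of those come from the client. A minimal sketch, assuming PDF, PNG, JPEG and TIFF are the accepted formats (the signatures below are the standard ones for those formats); `sniff_file_type` is a hypothetical name.

```python
# Standard leading bytes ("magic numbers") of the formats assumed to be accepted
FILE_SIGNATURES = {
    b"%PDF-": "application/pdf",
    b"\x89PNG\r\n\x1a\n": "image/png",
    b"\xff\xd8\xff": "image/jpeg",
    b"II*\x00": "image/tiff",  # little-endian TIFF
    b"MM\x00*": "image/tiff",  # big-endian TIFF
}


def sniff_file_type(data):
    """Return the MIME type implied by the first bytes of data, or None if unrecognized."""
    for signature, mime_type in FILE_SIGNATURES.items():
        if data.startswith(signature):
            return mime_type
    return None


print(sniff_file_type(b"%PDF-1.7\n"))  # application/pdf
print(sniff_file_type(b"GIF89a"))      # None
```

With a Streamlit upload, the check could run on `sniff_file_type(file.getvalue()[:8])`. A signature check confirms the format, not that the file is harmless, so it complements rather than replaces `is_potentially_malicious`.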
    

9. Performance Optimizations

9.1 Parallel Processing

Implement parallel processing for multi-page documents:

  • Issue: Pages are processed sequentially, which can be slow for large documents.
  • Solution: Implement parallel processing:
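    import concurrent.futures
    import logging

    logger = logging.getLogger(__name__)

    def process_pdf_pages(pdf_path, options):
        # Extract pages
        pages = extract_pdf_pages(pdf_path)
        
        # Process pages in parallel; map each future back to its page index
        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            future_to_page = {executor.submit(process_page, page, options): i
                              for i, page in enumerate(pages)}
            
            results = []
            # as_completed yields futures in completion order, not page order
            for future in concurrent.futures.as_completed(future_to_page):
                page_idx = future_to_page[future]
                try:
                    results.append((page_idx, future.result()))
                except Exception as e:
                    logger.error(f"Error processing page {page_idx}: {e}")
                    # Keep a placeholder so a failed page is reported instead of silently dropped
                    results.append((page_idx, {"error": str(e)}))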
                    
        # Sort results by page index
        results.sort(key=lambda x: x[0])
        
        # Combine results
        return combine_page_results([r[1] for r in results])
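Because `as_completed` yields futures out of order, the code above has to re-sort by page index. `Executor.map` already returns results in input order, so a shorter variant is possible. A sketch, assuming `process_page` takes a single page (options could be bound with `functools.partial`); `process_pages_in_order` and `fake_ocr` are hypothetical.

```python
import concurrent.futures


def process_pages_in_order(pages, process_page, max_workers=4):
    """Process pages concurrently and return results in page order.

    A failed page yields {"error": ...} instead of disappearing, so page
    numbering in the combined output stays aligned.
    """
    def safe_process(page):
        try:
            return process_page(page)
        except Exception as exc:
            return {"error": str(exc)}

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in input order, so no re-sorting is needed
        return list(executor.map(safe_process, pages))


def fake_ocr(page):
    # Hypothetical per-page OCR call
    if page == "blank":
        raise ValueError("no text found")
    return {"text": page.upper()}


print(process_pages_in_order(["p1", "blank", "p3"], fake_ocr))
# [{'text': 'P1'}, {'error': 'no text found'}, {'text': 'P3'}]
```

Threads fit this workload because each page mostly waits on the network call to the OCR API; `max_workers` should also stay within whatever rate limit the API account allows.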
    

9.2 Resource Management

Improve resource management:

  • Issue: Temporary files are not always cleaned up properly.
  • Solution: Implement better resource management:
    class TempFileManager:
        def __init__(self):
            self.temp_files = []
            
        def create_temp_file(self, content, suffix=".tmp"):
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(content)
                self.temp_files.append(tmp.name)
                return tmp.name
                
        def cleanup(self):
            for temp_file in self.temp_files:
                try:
                    if os.path.exists(temp_file):
                        os.unlink(temp_file)
                except Exception as e:
                    logger.warning(f"Failed to remove temp file {temp_file}: {str(e)}")
            self.temp_files = []
            
        def __enter__(self):
            return self
            
        def __exit__(self, exc_type, exc_val, exc_tb):
            self.cleanup()
    

10. Extensibility

10.1 Plugin System

Implement a plugin system for extensibility:

  • Issue: Adding new document types or processing methods requires code changes.
  • Solution: Implement a plugin system:
    class OCRPlugin:
        def __init__(self, name, description):
            self.name = name
            self.description = description
            
        def can_handle(self, file_info):
            """Return True if this plugin can handle the file"""
            raise NotImplementedError
            
        def process(self, file_info, options):
            """Process the file and return results"""
            raise NotImplementedError
            
    # Example plugin
    class HandwrittenDocumentPlugin(OCRPlugin):
        def __init__(self):
            super().__init__("handwritten", "Handwritten document processor")
            
        def can_handle(self, file_info):
            # Check if this is a handwritten document
            return file_info.get("document_type") == "handwritten"
            
        def process(self, file_info, options):
            # Specialized processing for handwritten documents
            ...  # placeholder body
    
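Plugins also need a way to be selected. A minimal registry sketch that asks each plugin, in registration order, whether it `can_handle` the file and dispatches to the first match; `PluginRegistry`, `HandwrittenStub` and `EchoPlugin` are hypothetical, and the base class is repeated so the block runs on its own.

```python
class OCRPlugin:
    def __init__(self, name, description):
        self.name = name
        self.description = description

    def can_handle(self, file_info):
        raise NotImplementedError

    def process(self, file_info, options):
        raise NotImplementedError


class PluginRegistry:
    """Keeps plugins in registration order and dispatches to the first match."""

    def __init__(self):
        self._plugins = []

    def register(self, plugin):
        self._plugins.append(plugin)
        return plugin

    def find(self, file_info):
        # First plugin that claims the file wins, so order acts as priority
        for plugin in self._plugins:
            if plugin.can_handle(file_info):
                return plugin
        return None

    def process(self, file_info, options):
        plugin = self.find(file_info)
        if plugin is None:
            raise LookupError(f"No plugin can handle {file_info.get('name', 'this file')}")
        return plugin.process(file_info, options)


class HandwrittenStub(OCRPlugin):
    """Hypothetical stand-in for HandwrittenDocumentPlugin."""

    def __init__(self):
        super().__init__("handwritten", "Handwritten document processor")

    def can_handle(self, file_info):
        return file_info.get("document_type") == "handwritten"

    def process(self, file_info, options):
        return {"plugin": self.name}


class EchoPlugin(OCRPlugin):
    """Hypothetical catch-all plugin that accepts any file."""

    def __init__(self):
        super().__init__("echo", "Fallback plugin for any document")

    def can_handle(self, file_info):
        return True

    def process(self, file_info, options):
        return {"plugin": self.name}


registry = PluginRegistry()
registry.register(HandwrittenStub())
registry.register(EchoPlugin())
print(registry.process({"document_type": "handwritten"}, {}))  # {'plugin': 'handwritten'}
print(registry.process({"document_type": "newspaper"}, {}))    # {'plugin': 'echo'}
```

Registration order doubles as priority, so a catch-all plugin should be registered last. Third-party plugins could instead be discovered through package entry points (`importlib.metadata.entry_points`).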

10.2 API Abstraction

Create an abstraction layer for the OCR API:

  • Issue: The application is tightly coupled to the Mistral AI API.
  • Solution: Implement an abstraction layer:
    class OCRProvider:
        def process_image(self, image_path, options):
            """Process an image and return OCR results"""
            raise NotImplementedError
            
        def process_pdf(self, pdf_path, options):
            """Process a PDF and return OCR results"""
            raise NotImplementedError
            
    class MistralOCRProvider(OCRProvider):
        def __init__(self, api_key=None):
            self.client = MistralClient.get_instance(api_key)
            
        def process_image(self, image_path, options):
            ...  # Implementation using the Mistral API
            
        def process_pdf(self, pdf_path, options):
            ...  # Implementation using the Mistral API
            
    # Factory function to get the appropriate provider
    def get_ocr_provider(provider_name="mistral"):
        if provider_name == "mistral":
            return MistralOCRProvider()
        # Add more providers as needed
        raise ValueError(f"Unknown OCR provider: {provider_name}")
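The `if` chain in `get_ocr_provider` has to be edited for every new provider. A registry populated by a class decorator avoids that; the sketch below uses hypothetical names (`register_provider`, `FakeOCRProvider`) and registers only an offline fake provider, which is also handy for tests that must not call the real API.

```python
class OCRProvider:
    def process_image(self, image_path, options):
        raise NotImplementedError

    def process_pdf(self, pdf_path, options):
        raise NotImplementedError


_PROVIDERS = {}


def register_provider(name):
    """Class decorator that makes a provider available to get_ocr_provider."""
    def decorator(cls):
        _PROVIDERS[name] = cls
        return cls
    return decorator


def get_ocr_provider(provider_name="mistral", **kwargs):
    try:
        provider_cls = _PROVIDERS[provider_name]
    except KeyError:
        raise ValueError(f"Unknown OCR provider: {provider_name}") from None
    # Constructor arguments (e.g. api_key) are passed through unchanged
    return provider_cls(**kwargs)


@register_provider("fake")
class FakeOCRProvider(OCRProvider):
    """Hypothetical offline provider for tests."""

    def process_image(self, image_path, options):
        return {"text": f"fake OCR of {image_path}"}

    def process_pdf(self, pdf_path, options):
        return {"text": f"fake OCR of {pdf_path}"}


provider = get_ocr_provider("fake")
print(provider.process_image("scan.jpg", {}))  # {'text': 'fake OCR of scan.jpg'}
```

In the application, `MistralOCRProvider` would carry `@register_provider("mistral")`; in this standalone block the default `"mistral"` is not registered, so `get_ocr_provider()` without arguments raises `ValueError`.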
    

Implementation Priority