Spaces:
Sleeping
Sleeping
| import spaces | |
| import transformers | |
| import re | |
| import torch | |
| import gradio as gr | |
| import os | |
| import ctranslate2 | |
| from concurrent.futures import ThreadPoolExecutor | |
# Pick the execution device: GPU when PyTorch reports CUDA, CPU otherwise.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the CTranslate2 generator from a local model directory (path is
# relative to the working directory) and the Hugging Face tokenizer —
# presumably the same vocabulary the converted model was built from.
model_path = "ocronos_CT2"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")
# Stylesheet prepended to every HTML result produced by the processing code.
# NOTE(review): this copy of the file only contained the placeholder text
# "(your existing CSS)" inside the <style> tag, which is not valid CSS. The
# rules below are a minimal working replacement; restore the original
# stylesheet if it is available.
css = """
<style>
.generation {
    white-space: pre-wrap;
    line-height: 1.6;
    padding: 0.75em 1em;
}
.generation del {
    color: #b91c1c;
}
.generation ins {
    color: #15803d;
    text-decoration: none;
}
</style>
"""
# Helper functions
def generate_html_diff(old_text, new_text):
    """Render a word-level diff from *old_text* to *new_text* as HTML.

    Unchanged words are emitted as plain (escaped) text, removed words are
    wrapped in ``<del>`` and inserted words in ``<ins>``. Every word is
    HTML-escaped, so user-supplied text cannot inject markup into the HTML
    output. Styling is inline, so the fragment renders correctly even
    without a page stylesheet.

    Args:
        old_text: Original text (``None`` is treated as empty).
        new_text: Corrected text (``None`` is treated as empty).

    Returns:
        An HTML fragment string; ``""`` when both inputs are empty.
    """
    # Local imports keep this helper self-contained.
    import difflib
    import html

    old_words = (old_text or "").split()
    new_words = (new_text or "").split()
    # autojunk=False: the default heuristic treats very frequent items as
    # junk once a sequence has 200+ elements, which degrades diffs of long
    # pages where common words repeat a lot.
    matcher = difflib.SequenceMatcher(None, old_words, new_words, autojunk=False)

    def _escape(words):
        # Join a slice of words and escape it for safe embedding in HTML.
        return html.escape(" ".join(words))

    pieces = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            pieces.append(_escape(old_words[i1:i2]))
            continue
        # A "replace" opcode shows both the removed and the inserted words.
        if tag in ("delete", "replace"):
            pieces.append(
                '<del style="color:#b91c1c;text-decoration:line-through">'
                f"{_escape(old_words[i1:i2])}</del>"
            )
        if tag in ("insert", "replace"):
            pieces.append(
                '<ins style="color:#15803d;text-decoration:none;font-weight:bold">'
                f"{_escape(new_words[j1:j2])}</ins>"
            )
    return " ".join(pieces)
def preprocess_text(text):
    """Placeholder for input preprocessing; currently always returns ``None``.

    NOTE(review): the real body was elided ("unchanged") in this copy of the
    file, and nothing visible here calls this function. Restore the original
    implementation or remove the function.
    """
    return None
def split_text(text, max_tokens=400):
    """Split *text* into consecutive chunks of at most *max_tokens* tokens.

    The text is tokenized with the module-level ``tokenizer``, cut at fixed
    token offsets, and each slice is decoded back into a string.

    Args:
        text: Input text to split.
        max_tokens: Maximum number of tokens per chunk; must be positive.

    Returns:
        List of decoded chunk strings in input order. The list is empty when
        *text* encodes to no tokens.

    Raises:
        ValueError: If *max_tokens* is not positive.
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens!r}")
    # Encode without special tokens. A BOS/EOS id added here would be decoded
    # (decode keeps special tokens by default) into literal marker text in a
    # chunk, and then fed to the model as if it were part of the document.
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    # NOTE(review): fixed token offsets can split a word across two chunks.
    # If the tokenizer is byte-level BPE, they can also split a multi-byte
    # character across two chunks.
    return [
        tokenizer.decode(token_ids[start:start + max_tokens])
        for start in range(0, len(token_ids), max_tokens)
    ]
# Function to generate text using CTranslate2
def ocr_correction(
    prompt,
    max_new_tokens=600,
    *,
    chunk_tokens=400,
    temperature=0.7,
    top_k=20,
):
    """Correct OCR noise in *prompt* using the CTranslate2 generator.

    The input is split into chunks of at most *chunk_tokens* tokens with
    ``split_text``. Each chunk is wrapped in the prompt template
    ``### Text ###\\n<chunk>\\n\\n\\n### Correction ###\\n`` and sampled
    separately. The per-chunk corrections are trimmed and joined with
    single spaces.

    Args:
        prompt: Raw OCR text to correct.
        max_new_tokens: Passed to ``generate_batch`` as ``max_length``.
        chunk_tokens: Maximum tokens per input chunk (default matches the
            previously hard-coded 400).
        temperature: Sampling temperature (previously hard-coded 0.7).
        top_k: Top-k sampling cutoff (previously hard-coded 20).

    Returns:
        The corrected text, or ``""`` when *prompt* yields no chunks.
        Decoding is sampled, so repeated calls may return different text.
    """
    corrected_chunks = []
    for chunk in split_text(prompt, max_tokens=chunk_tokens):
        template = f"### Text ###\n{chunk}\n\n\n### Correction ###\n"
        # CTranslate2 generators take token strings, not ids.
        prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(template))
        result = generator.generate_batch(
            [prompt_tokens],
            max_length=max_new_tokens,
            sampling_temperature=temperature,
            sampling_topk=top_k,
            # Return only the continuation, not the echoed template.
            include_prompt_in_result=False,
        )[0]
        # skip_special_tokens keeps control tokens out of the HTML view.
        # strip() prevents doubled spaces or stray newlines where chunks are
        # joined below.
        text = tokenizer.decode(result.sequences_ids[0], skip_special_tokens=True)
        text = text.strip()
        if text:
            corrected_chunks.append(text)
    return " ".join(corrected_chunks)
# OCR Correction Class
class OCRCorrector:
    """Run OCR correction on a message and build an HTML diff of the result."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Stored on the instance; correct() below does not read it.
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return ``(corrected_text, html_diff)`` for *user_message*.

        The diff compares the original message against the corrected text.
        """
        corrected = ocr_correction(user_message)
        return corrected, generate_html_diff(user_message, corrected)
# Combined Processing Class
class TextProcessor:
    """Turn raw user text into the HTML shown in the Gradio output panel."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    def process(self, user_message):
        """Correct *user_message* and return the page HTML.

        The result is the module-level ``css`` followed by a centred
        "OCR Correction" heading and the diff wrapped in a
        ``<div class="generation">``.
        """
        # Only the diff is rendered; the plain corrected text is not used here.
        _corrected, diff_markup = self.ocr_corrector.correct(user_message)
        section = (
            '<h2 style="text-align:center">OCR Correction</h2>\n'
            f'<div class="generation">{diff_markup}</div>'
        )
        return f"{css}{section}"
# One shared processor instance serves every click of the button below.
text_processor = TextProcessor()

# Gradio UI: a title, an input textbox, a button, and an HTML output panel.
with gr.Blocks(theme="JohnSmith9982/small_and_pretty") as demo:
    gr.HTML('<h1 style="text-align:center">Vintage OCR corrector</h1>')
    text_input = gr.Textbox(lines=5, type="text", label="Your (bad?) text")
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    # Clicking sends the textbox contents to process() and shows its HTML.
    process_button.click(
        fn=text_processor.process,
        inputs=text_input,
        outputs=[text_output],
    )
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch()