Spaces:
Sleeping
Sleeping
| import spaces | |
| import transformers | |
| import re | |
| from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline | |
| from vllm import LLM, SamplingParams | |
| import torch | |
| import gradio as gr | |
| import json | |
| import os | |
| import shutil | |
| import requests | |
| import pandas as pd | |
| import difflib | |
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Compute device: GPU when available, otherwise CPU. Defined once, as a
# torch.device (previously it was first set to a plain string and then
# overwritten with a torch.device a few lines later).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# OCR correction model id. ``model_name`` is kept as an alias so any code
# referring to either name keeps working.
ocr_model_name = "PleIAs/OCRonos-Vintage"
model_name = ocr_model_name

# Load the pre-trained model and its tokenizer, then move the model to `device`.
# NOTE(review): `spaces` is imported at the top of the file but no function is
# decorated with `@spaces.GPU`; if this Space runs on ZeroGPU hardware, GPU work
# usually needs that decorator — confirm the target hardware.
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.to(device)
# CSS for formatting.
# Stylesheet (wrapped in a <style> element) that TextProcessor.process prepends
# to every HTML result it returns; `.generation` styles the result container
# and `h2` its heading.
# NOTE(review): `.deleted` / `.inserted` mirror the diff colours, but
# generate_html_diff emits inline styles rather than these classes. Other rules
# (.source, .tooltip, .manuscript, .annotation, .content, .title-content,
# .bibliography-content, .paratext-content, :target) do not match any markup
# built in this file and look like leftovers — confirm before removing.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float: left;
max-width: 17%;
margin-left: 2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.deleted {
background-color: #ffcccb;
text-decoration: line-through;
}
.inserted {
background-color: #90EE90;
}
.manuscript {
display: flex;
margin-bottom: 10px;
align-items: baseline;
}
.annotation {
width: 15%;
padding-right: 20px;
color: grey !important;
font-style: italic;
text-align: right;
}
.content {
width: 80%;
}
h2 {
margin: 0;
font-size: 1.5em;
}
.title-content h2 {
font-weight: bold;
}
.bibliography-content {
color: darkgreen !important;
margin-top: -5px;
}
.paratext-content {
color: #a4a4a4 !important;
margin-top: -5px;
}
</style>
"""
| # Helper functions | |
def generate_html_diff(old_text, new_text, show_deletions=False):
    """Render a word-level diff from *old_text* to *new_text* as an HTML fragment.

    Both texts are split on whitespace and compared with ``difflib.Differ``.
    Unchanged words are emitted as-is. Words added in *new_text* are wrapped in
    a green-highlighted ``<span>``. Words removed from *old_text* are dropped,
    unless *show_deletions* is true; then they are emitted in a red,
    struck-through ``<span>``.

    Every word is HTML-escaped. The fragment is rendered as HTML, and both
    texts come from user input or model output, so unescaped ``<`` / ``&``
    would break the page or inject markup.

    Args:
        old_text: Original text.
        new_text: Revised text.
        show_deletions: Also render removed words. The default ``False`` keeps
            the previous output, which shows only the revised text.

    Returns:
        The rendered words joined by single spaces ("" when both are empty).
    """
    import html  # local import keeps this helper self-contained

    rendered = []
    for entry in difflib.Differ().compare(old_text.split(), new_text.split()):
        tag, word = entry[:2], html.escape(entry[2:])
        if tag == '  ':
            rendered.append(word)
        elif tag == '+ ':
            rendered.append(f'<span style="background-color: #90EE90;">{word}</span>')
        elif tag == '- ' and show_deletions:
            rendered.append(
                '<span style="background-color: #ffcccb; text-decoration: line-through;">'
                f'{word}</span>'
            )
        # '? ' entries are Differ's intraline hints and carry no words.
    return ' '.join(rendered)
def preprocess_text(text):
    """Return *text* with tag-like markup removed and whitespace normalised.

    Anything of the form ``<...>`` is deleted. Newlines and other whitespace
    runs collapse to single spaces, and the result is trimmed at both ends.
    """
    # Applied in order: drop tags, turn newlines into spaces, squeeze runs.
    substitutions = (
        (r'<[^>]+>', ''),
        (r'\n', ' '),
        (r'\s+', ' '),
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def _split_overlong_text(text, max_tokens, count_tokens):
    """Recursively cut *text* at whitespace until every piece fits *max_tokens*.

    Pieces are stripped and empty pieces are dropped. A single character that
    still exceeds the budget is returned as-is, since it cannot be split further.
    """
    text = text.strip()
    if not text:
        return []
    if len(text) < 2 or count_tokens(text) <= max_tokens:
        return [text]
    mid = len(text) // 2
    # Cut-point preference:
    #   1. the first whitespace at/after the midpoint;
    #   2. otherwise the last whitespace before it;
    #   3. with no whitespace at all, a hard cut inside the word at the midpoint.
    # `text` is stripped, so every candidate index lies in 1..len-2. Both halves
    # are therefore strictly shorter than `text`, and the recursion terminates.
    cut = next((i for i in range(mid, len(text)) if text[i].isspace()), None)
    if cut is None:
        cut = next((i for i in range(mid - 1, 0, -1) if text[i].isspace()), mid)
    return (_split_overlong_text(text[:cut], max_tokens, count_tokens)
            + _split_overlong_text(text[cut:], max_tokens, count_tokens))


def split_text(text, max_tokens=500, tokenize=None):
    """Split *text* into chunks of at most *max_tokens* tokens.

    Whole lines are packed greedily, joined with newlines, while the running
    chunk stays within budget. Any packed chunk that is still too long (a
    single line longer than *max_tokens*, wherever it appears) is then cut at
    whitespace by ``_split_overlong_text``.

    Args:
        text: Text to split.
        max_tokens: Token budget per chunk.
        tokenize: Callable returning the token list for a string. Defaults to
            the module-level ``tokenizer.tokenize``.

    Returns:
        A list of chunks; empty for empty input.
    """
    if tokenize is None:
        tokenize = tokenizer.tokenize

    def count_tokens(s):
        return len(tokenize(s))

    # Pass 1: greedily pack whole lines.
    packed = []
    current = ""
    for line in text.split("\n"):
        candidate = f"{current}\n{line}" if current else line
        if count_tokens(candidate) <= max_tokens:
            current = candidate
        else:
            if current:
                packed.append(current)
            current = line
    if current:
        packed.append(current)

    # Pass 2: re-check every packed chunk, since a single line can exceed the
    # budget on its own no matter how many other chunks there are.
    chunks = []
    for chunk in packed:
        if count_tokens(chunk) <= max_tokens:
            chunks.append(chunk)
        else:
            chunks.extend(_split_overlong_text(chunk, max_tokens, count_tokens))
    return chunks
| # Function to generate text | |
def ocr_correction(prompt, max_new_tokens=600):
    """Ask the OCR-correction model to rewrite *prompt* and return its correction.

    The text is wrapped in the ``### Text ###`` / ``### Correction ###``
    template and continued with ``model.generate``, using the module-level
    ``model``, ``tokenizer`` and ``device``.

    Args:
        prompt: Raw (possibly OCR-damaged) text to correct.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded text the model generated after the template, with special
        tokens removed.
    """
    template = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    # Calling the tokenizer (rather than `encode`) also yields an attention mask.
    encoded = tokenizer(template, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[-1]
    # NOTE(review): no truncation is applied. If the prompt plus max_new_tokens
    # exceeds the model's context window, generation may fail — confirm limit.
    output = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        # top_k is a sampling parameter and do_sample is not set here; whether
        # it applies depends on the model's generation config — confirm.
        top_k=50,
    )
    # Decode only the newly generated tokens. Splitting the full decode on
    # "### Correction ###" picks the wrong segment whenever the user's text
    # itself contains that marker.
    result = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    print(result)
    return result
| # OCR Correction Class | |
class OCRCorrector:
    """Correct a piece of text with the OCR model and diff it against the input."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Stored for callers; correct() below does not read it.
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return ``(corrected_text, html_diff)`` for *user_message*.

        ``corrected_text`` is the output of ``ocr_correction``. ``html_diff`` is
        the word-level HTML diff from the input to that correction.
        """
        corrected = ocr_correction(user_message)
        return corrected, generate_html_diff(user_message, corrected)
| # Combined Processing Class | |
class TextProcessor:
    """Glue between the Gradio UI and ``OCRCorrector``; produces display HTML."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    def process(self, user_message):
        """Correct *user_message* and return ``css`` plus an "OCR Correction" section.

        Blank input (``None``, empty or whitespace-only) skips the model call
        and renders an empty result section, rather than asking the model to
        "correct" nothing (which can only produce invented text).

        NOTE(review): the whole input goes to the model in a single call;
        ``split_text`` in this module is not used here, so long inputs may
        exceed the model's context window — confirm.
        """
        if user_message is None or not user_message.strip():
            html_diff = ""
        else:
            # Only the diff is displayed; the plain corrected text is unused.
            _corrected_text, html_diff = self.ocr_corrector.correct(user_message)
        ocr_result = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{html_diff}</div>'
        return f"{css}{ocr_result}"
# Create the single TextProcessor instance that the button callback uses for
# every request.
text_processor = TextProcessor()
# Define the Gradio interface: a heading, an input textbox, a button, and an
# HTML panel showing what TextProcessor.process returns (CSS + highlighted diff).
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    # On click, the textbox value is passed to text_processor.process and the
    # returned HTML string is rendered in text_output.
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])
# Launch the app when run as a script, with Gradio's request queue enabled.
if __name__ == "__main__":
    demo.queue().launch()