Spaces:
Runtime error
Runtime error
import difflib | |
import logging | |
from typing import List | |
import torch | |
from app.services.base import load_hf_pipeline | |
from app.core.config import settings | |
from app.core.exceptions import ServiceError | |
# Module-level logger, namespaced under the configured application name.
logger = logging.getLogger(f"{settings.APP_NAME}.services.grammar")
class GrammarCorrector:
    """Grammar-correction service built on a lazily loaded text2text pipeline.

    The pipeline is not created in ``__init__``; it is loaded on the first
    call that needs it and then reused for the lifetime of the instance.
    """

    def __init__(self):
        # Populated on first use by _get_pipeline().
        self._pipeline = None

    def _get_pipeline(self):
        """Return the grammar pipeline, loading it on the first call."""
        if self._pipeline is None:
            logger.info("Loading grammar correction pipeline...")
            self._pipeline = load_hf_pipeline(
                model_id=settings.GRAMMAR_MODEL_ID,
                task="text2text-generation",
                feature_name="Grammar Correction",
            )
        return self._pipeline

    async def correct(self, text: str) -> dict:
        """Correct the grammar of *text* and describe every change made.

        Args:
            text: The text to correct. Leading and trailing whitespace is
                stripped before processing.

        Returns:
            A dict with the keys ``original_text`` (the stripped input),
            ``corrected_text_suggestion`` (the stripped model output) and
            ``issues`` (the list produced by :meth:`get_diff_issues`).

        Raises:
            ServiceError: With status 400 when the stripped input is empty,
                or with status 500 when the model output is empty or any
                other error occurs while running the correction.
        """
        text = text.strip()
        if not text:
            raise ServiceError(
                status_code=400,
                detail="Input text is empty for grammar correction.",
            )
        try:
            pipeline = self._get_pipeline()
            # NOTE(review): this is a synchronous call inside a coroutine, so
            # it appears to hold the event loop while the model runs --
            # confirm whether callers offload it to a worker thread.
            result = pipeline(text, max_length=512, num_beams=4, early_stopping=True)
            corrected = result[0]["generated_text"].strip()
            if not corrected:
                raise ServiceError(
                    status_code=500,
                    detail="Failed to decode grammar correction output.",
                )
            issues = self.get_diff_issues(text, corrected)
            return {
                "original_text": text,
                "corrected_text_suggestion": corrected,
                "issues": issues,
            }
        except ServiceError:
            # Already a well-formed service error: log it, but keep its
            # specific status and detail instead of replacing them with the
            # generic message below.
            logger.error(
                "Grammar correction error for input: '%s...'",
                text[:50],
                exc_info=True,
            )
            raise
        except Exception as e:
            logger.error(
                "Grammar correction error for input: '%s...'",
                text[:50],
                exc_info=True,
            )
            raise ServiceError(
                status_code=500,
                detail="An internal error occurred during grammar correction.",
            ) from e

    def get_diff_issues(self, original: str, corrected: str) -> List[dict]:
        """Describe each character-level difference between two strings.

        Args:
            original: The text before correction.
            corrected: The text after correction.

        Returns:
            One dict per non-equal ``difflib`` opcode, in order of appearance.
            Each dict holds ``offset`` and ``length`` (position and size of
            the span in *original*), ``original_segment``,
            ``suggested_segment``, up to 15 characters of ``context_before``
            and ``context_after`` taken from *original*, a fixed ``message``,
            and the 1-based ``line`` and ``column`` of the span start.
        """
        context_width = 15
        matcher = difflib.SequenceMatcher(None, original, corrected)
        issues = []
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                continue
            # rfind yields -1 when no newline precedes i1, so this single
            # expression gives a 1-based column on the first line as well as
            # on every later line.
            last_newline = original.rfind("\n", 0, i1)
            issues.append({
                "offset": i1,
                "length": i2 - i1,
                "original_segment": original[i1:i2],
                "suggested_segment": corrected[j1:j2],
                "context_before": original[max(0, i1 - context_width):i1],
                "context_after": original[i2:i2 + context_width],
                "message": "Grammar correction",
                "line": original.count("\n", 0, i1) + 1,
                "column": i1 - last_newline,
            })
        return issues