# wellsaid/app/services/grammar.py
# iamspruce
# fixed the api
# 73a6a7e
import difflib
import logging
from typing import List
import torch
from app.services.base import load_hf_pipeline
from app.core.config import settings
from app.core.exceptions import ServiceError
# Module-level logger, namespaced under the configured application name.
logger = logging.getLogger(f"{settings.APP_NAME}.services.grammar")
class GrammarCorrector:
    """Grammar correction service built on a lazily loaded text2text pipeline.

    The pipeline is created on first use through ``load_hf_pipeline`` and then
    cached on the instance, so constructing a ``GrammarCorrector`` is cheap.
    """

    def __init__(self):
        # Populated on the first call to ``_get_pipeline``.
        self._pipeline = None

    def _get_pipeline(self):
        """Return the cached pipeline, loading it on the first call."""
        if self._pipeline is None:
            logger.info("Loading grammar correction pipeline...")
            self._pipeline = load_hf_pipeline(
                model_id=settings.GRAMMAR_MODEL_ID,
                task="text2text-generation",
                feature_name="Grammar Correction",
            )
        return self._pipeline

    async def correct(self, text: str) -> dict:
        """Correct *text* and describe every change that was made.

        Args:
            text: Raw input; surrounding whitespace is stripped before use.

        Returns:
            A dict with ``original_text`` (the stripped input),
            ``corrected_text_suggestion`` (the stripped model output) and
            ``issues`` (the list produced by ``get_diff_issues``).

        Raises:
            ServiceError: 400 when the stripped input is empty; 500 when the
                model output is empty or any unexpected error occurs. A
                ``ServiceError`` raised while loading or running the pipeline
                is propagated unchanged.
        """
        text = text.strip()
        if not text:
            raise ServiceError(status_code=400, detail="Input text is empty for grammar correction.")

        try:
            pipeline = self._get_pipeline()
            # NOTE(review): this call is synchronous and runs inside an async
            # method; consider offloading it to a worker thread if request
            # latency under concurrency matters.
            result = pipeline(text, max_length=512, num_beams=4, early_stopping=True)
            corrected = result[0]["generated_text"].strip()
            if not corrected:
                raise ServiceError(status_code=500, detail="Failed to decode grammar correction output.")

            return {
                "original_text": text,
                "corrected_text_suggestion": corrected,
                "issues": self.get_diff_issues(text, corrected),
            }
        except ServiceError:
            # Already a well-formed service error; re-wrapping it in the
            # generic handler below would discard its specific detail.
            raise
        except Exception as e:
            logger.error(
                "Grammar correction error for input: '%s...'",
                text[:50],
                exc_info=True,
            )
            raise ServiceError(status_code=500, detail="An internal error occurred during grammar correction.") from e

    def get_diff_issues(self, original: str, corrected: str) -> List[dict]:
        """Return one issue dict per non-equal diff opcode between the texts.

        The diff is a character-level ``difflib.SequenceMatcher`` comparison.

        Args:
            original: Text before correction.
            corrected: Text after correction.

        Returns:
            A list of dicts, in opcode order, each containing:
            ``offset`` / ``length`` (0-based span in *original*),
            ``original_segment`` / ``suggested_segment``,
            ``context_before`` / ``context_after`` (up to 15 characters of
            *original* on either side), ``message``, and the 1-based
            ``line`` / ``column`` of ``offset`` within *original*.
        """
        context_width = 15
        matcher = difflib.SequenceMatcher(None, original, corrected)
        issues = []
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                continue
            # rfind returns -1 when no newline precedes i1, so the same
            # expression yields a 1-based column on the first line
            # (i1 + 1) and on every later line (distance past the newline).
            last_newline = original.rfind("\n", 0, i1)
            issues.append({
                "offset": i1,
                "length": i2 - i1,
                "original_segment": original[i1:i2],
                "suggested_segment": corrected[j1:j2],
                # Clamp the start so a negative index does not wrap around.
                "context_before": original[max(0, i1 - context_width):i1],
                "context_after": original[i2:i2 + context_width],
                "message": "Grammar correction",
                "line": original.count("\n", 0, i1) + 1,
                "column": i1 - last_newline,
            })
        return issues