File size: 2,707 Bytes
ce2ce69
71192d1
73a6a7e
 
ce2ce69
 
73a6a7e
ce2ce69
73a6a7e
 
 
ce2ce69
71192d1
 
 
73a6a7e
71192d1
73a6a7e
 
 
 
 
 
 
ce2ce69
73a6a7e
71192d1
73a6a7e
 
 
 
71192d1
 
73a6a7e
ce2ce69
73a6a7e
 
 
 
 
ce2ce69
 
 
73a6a7e
ce2ce69
 
 
73a6a7e
ce2ce69
71192d1
73a6a7e
 
 
 
 
 
ce2ce69
 
 
 
 
73a6a7e
ce2ce69
 
 
 
 
73a6a7e
 
 
 
ce2ce69
 
73a6a7e
ce2ce69
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import difflib
import logging
from typing import List

import torch

from app.services.base import load_hf_pipeline
from app.core.config import settings
from app.core.exceptions import ServiceError

# Module-level logger named "<settings.APP_NAME>.services.grammar".
logger = logging.getLogger("{}.services.grammar".format(settings.APP_NAME))


class GrammarCorrector:
    """Grammar correction service backed by a lazily loaded text2text pipeline.

    The pipeline is created on first use (see ``_get_pipeline``) and cached on
    the instance for all subsequent calls.
    """

    def __init__(self):
        # Cached pipeline; stays None until the first call to _get_pipeline().
        self._pipeline = None

    def _get_pipeline(self):
        """Return the cached grammar pipeline, building it on first access.

        The pipeline is obtained from ``load_hf_pipeline`` using
        ``settings.GRAMMAR_MODEL_ID`` and the ``text2text-generation`` task.
        """
        if self._pipeline is None:
            logger.info("Loading grammar correction pipeline...")
            self._pipeline = load_hf_pipeline(
                model_id=settings.GRAMMAR_MODEL_ID,
                task="text2text-generation",
                feature_name="Grammar Correction"
            )
        return self._pipeline

    async def correct(self, text: str) -> dict:
        """Suggest a grammar-corrected version of *text* and list the edits.

        Args:
            text: Input text. Leading/trailing whitespace is stripped before
                processing; ``None`` is treated as empty input.

        Returns:
            A dict with ``original_text`` (the stripped input),
            ``corrected_text_suggestion`` (the stripped model output) and
            ``issues`` (the list produced by ``get_diff_issues``).

        Raises:
            ServiceError: 400 if the input is empty or whitespace-only; 500 if
                the model output is empty. Any ``ServiceError`` raised inside
                the processing step propagates unchanged; every other
                exception is logged and wrapped in a generic 500.
        """
        text = (text or "").strip()
        if not text:
            raise ServiceError(status_code=400, detail="Input text is empty for grammar correction.")

        try:
            pipeline = self._get_pipeline()

            # NOTE(review): this is a synchronous call inside an async method,
            # so it blocks the event loop for the whole inference (and for the
            # first-call model load). Consider offloading it, e.g. with
            # asyncio.to_thread, once the pipeline's thread-safety is confirmed.
            result = pipeline(text, max_length=512, num_beams=4, early_stopping=True)
            corrected = result[0]["generated_text"].strip()

            if not corrected:
                raise ServiceError(status_code=500, detail="Failed to decode grammar correction output.")

            issues = self.get_diff_issues(text, corrected)

            return {
                "original_text": text,
                "corrected_text_suggestion": corrected,
                "issues": issues
            }

        except ServiceError:
            # Already carries a specific status/detail; re-raise it as-is
            # instead of masking it behind the generic 500 below.
            raise
        except Exception as e:
            logger.exception("Grammar correction error for input: '%s...'", text[:50])
            raise ServiceError(status_code=500, detail="An internal error occurred during grammar correction.") from e

    def get_diff_issues(self, original: str, corrected: str) -> List[dict]:
        """Describe each character-level edit that turns *original* into *corrected*.

        Every non-equal ``difflib.SequenceMatcher`` opcode (replace, delete or
        insert) becomes one issue dict.

        Args:
            original: Text that all offsets, contexts and positions refer to.
            corrected: The suggested replacement text.

        Returns:
            One dict per change, containing:
            - ``offset`` and ``length``: the span in *original*.
            - ``original_segment`` and ``suggested_segment``.
            - ``context_before`` and ``context_after``: up to 15 characters
              of *original* on each side of the span.
            - ``message``: a fixed string.
            - ``line`` and ``column``: the 1-based position of ``offset``.
        """
        context_chars = 15

        # autojunk=False: by default, once the second sequence has 200+ items,
        # frequently repeated characters (spaces, common letters) are treated
        # as "popular" junk and cannot anchor matches, which yields needlessly
        # large replace blocks in a character-level diff.
        matcher = difflib.SequenceMatcher(None, original, corrected, autojunk=False)
        issues = []

        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                continue

            # 1-based column = distance from the preceding newline; rfind
            # returns -1 when there is none, giving i1 + 1 on the first line.
            last_newline = original.rfind("\n", 0, i1)

            issues.append({
                "offset": i1,
                "length": i2 - i1,
                "original_segment": original[i1:i2],
                "suggested_segment": corrected[j1:j2],
                # Clamp the start: a negative slice index would wrap around.
                "context_before": original[max(0, i1 - context_chars):i1],
                "context_after": original[i2:i2 + context_chars],
                "message": "Grammar correction",
                "line": original.count("\n", 0, i1) + 1,
                "column": i1 - last_newline,
            })

        return issues