iamspruce committed
Commit 73a6a7e · 1 Parent(s): d893801

fixed the api
Dockerfile CHANGED
@@ -2,40 +2,35 @@ FROM python:3.10-slim
 
 WORKDIR /app
 
-# Install system dependencies
-# git might not be strictly necessary for deployment unless you're cloning repos at runtime
-# but it's often useful for debugging or specific workflows.
-RUN apt-get update && apt-get install -y git && \
-    rm -rf /var/lib/apt/lists/*
+# Install system dependencies (excluding git)
+# Clean up apt lists to reduce image size
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    # Add any other core system dependencies here if needed, but not git
+    # e.g., libpq-dev for psycopg2, if you add a PostgreSQL dependency later
+    # Example: libpq-dev
+    && rm -rf /var/lib/apt/lists/*
 
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
-# --- Install spaCy model ---
-# This downloads the small English model
+# --- Pre-download models during Docker build ---
+# Ensure spacy and nltk are installed via requirements.txt before these steps
 RUN python -m spacy download en_core_web_sm
-
-# --- Install NLTK WordNet data ---
-# This downloads the WordNet corpus for NLTK
 RUN python -m nltk.downloader wordnet
 
-# --- Configure cache directories for Hugging Face models ---
-# HF_HOME is where SentenceTransformers and other Hugging Face models will cache.
-# /.cache is also a common location many libraries default to if HF_HOME isn't set,
-# or for other internal caching. Setting permissions ensures the app can write there.
+# --- Configure cache directories using Docker ENV (these take precedence) ---
 ENV HF_HOME=/cache
 ENV TRANSFORMERS_CACHE=/cache
 ENV NLTK_DATA=/nltk_data
+ENV SPACY_DATA=/spacy_data
 
 # Create directories and set appropriate permissions
 RUN mkdir -p /cache && chmod -R 777 /cache
-RUN mkdir -p /root/.cache && chmod -R 777 /root/.cache
-
-# Ensure NLTK uses the specified data path.
-# This makes subsequent 'nltk.downloader' calls store data here,
-# and NLTK will look here first.
-RUN python -c "import nltk; nltk.data.path.append('/nltk_data')"
+RUN mkdir -p /nltk_data && chmod -R 777 /nltk_data
+RUN mkdir -p /spacy_data && chmod -R 777 /spacy_data
+
+# It's good to also create the /root/.cache for general system caching in Docker
+RUN mkdir -p /root/.cache && chmod -R 777 /root/.cache
 
 COPY app ./app
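The cache setup above is easiest to verify from inside the built image. A minimal sanity-check sketch (the `wellsaid-api` tag is hypothetical; any tag you build with works):

# sanity_check.py — e.g. `docker run --rm wellsaid-api python sanity_check.py`
# Prints whether each cache directory set via ENV exists and is writable.
import os
from pathlib import Path

for var in ("HF_HOME", "TRANSFORMERS_CACHE", "NLTK_DATA", "SPACY_DATA"):
    value = os.getenv(var)
    if value is None:
        print(f"{var} is not set")
        continue
    p = Path(value)
    print(f"{var}={value} exists={p.exists()} writable={os.access(value, os.W_OK)}")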
 
app/core/app.py DELETED
@@ -1,37 +0,0 @@
-import os
-from fastapi import FastAPI
-from fastapi.middleware.gzip import GZipMiddleware
-from contextlib import asynccontextmanager
-from app.routers import grammar, tone, voice, inclusive_language, readability, paraphrase, translate, rewrite, analyze
-from app.queue import start_workers
-from app.core.middleware import setup_middlewares
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    num_workers = int(os.getenv("WORKER_COUNT", 4))
-    start_workers(num_workers)
-    yield
-
-def create_app() -> FastAPI:
-    app = FastAPI(lifespan=lifespan)
-    app.add_middleware(GZipMiddleware, minimum_size=500)
-    setup_middlewares(app)
-
-    for router, tag in [
-        (grammar.router, "Grammar"),
-        (tone.router, "Tone"),
-        (voice.router, "Voice"),
-        (inclusive_language.router, "Inclusive Language"),
-        (readability.router, "Readability"),
-        (paraphrase.router, "Paraphrasing"),
-        (translate.router, "Translation"),
-        (rewrite.router, "Rewrite"),
-        (analyze.router, "Analyze")
-    ]:
-        app.include_router(router, tags=[tag])
-
-    @app.get("/")
-    def root():
-        return {"message": "Welcome to Wellsaid API"}
-
-    return app
app/core/config.py CHANGED
@@ -1,22 +1,90 @@
-from pydantic_settings import BaseSettings
-from typing import List
+import logging
+import os
+from pathlib import Path
+from typing import List, Optional
+
+from pydantic import Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+# ─────────────────────────────────────────────────────────────────────────────
+# ⛺ Paths & Constants
+# ─────────────────────────────────────────────────────────────────────────────
+
+PROJECT_ROOT = Path(__file__).parent.parent
+APP_DATA_ROOT_DIR = Path.home() / ".wellsaid_app_data"
+MODELS_DIR = APP_DATA_ROOT_DIR / "models"
+NLTK_DATA_DIR = APP_DATA_ROOT_DIR / "nltk_data"
+
+OFFLINE_MODE = os.getenv("OFFLINE_MODE", "false").lower() == "true"
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 📁 Ensure Directories Exist (for offline desktop usage)
+# ─────────────────────────────────────────────────────────────────────────────
+
+for directory in [MODELS_DIR, NLTK_DATA_DIR]:
+    try:
+        directory.mkdir(parents=True, exist_ok=True)
+    except Exception as e:
+        logging.warning(f"Failed to create directory {directory}: {e}")
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 🌍 Environment Variables Setup (only if not already set)
+# ─────────────────────────────────────────────────────────────────────────────
+
+env_defaults = {
+    "HF_HOME": str(MODELS_DIR / "hf_cache"),
+    "NLTK_DATA": str(NLTK_DATA_DIR),
+    "SPACY_DATA": str(MODELS_DIR),
+}
+
+for var, default in env_defaults.items():
+    if not os.getenv(var):
+        os.environ[var] = default
+
+# Update nltk.data.path immediately (if nltk is installed)
+try:
+    import nltk
+    if str(NLTK_DATA_DIR) not in nltk.data.path:
+        nltk.data.path.append(str(NLTK_DATA_DIR))
+except ImportError:
+    pass
+
+# ─────────────────────────────────────────────────────────────────────────────
+# ⚙️ Application Settings
+# ─────────────────────────────────────────────────────────────────────────────
 
 class Settings(BaseSettings):
-    # Server settings (user-configurable)
-    HOST: str = "0.0.0.0"
-    PORT: int = 7860
-    RELOAD: bool = True
+    model_config = SettingsConfigDict(env_file=".env", extra="ignore")
 
-    # Security & workers
-    WELLSAID_API_KEY: str = "12345"
-    WORKER_COUNT: int = 4
+    # App basics
+    APP_NAME: str = "WellSaidApp"
+    API_KEY: str = "your_strong_api_key_here"
 
-    # Fixed/internal settings
-    INCLUSIVE_RULES_DIR: str = "app/data/en"
+    # OpenAI
+    OPENAI_API_KEY: Optional[str] = None
     OPENAI_MODEL: str = "gpt-4o"
     OPENAI_TEMPERATURE: float = 0.7
-    OPENAI_MAX_TOKENS: int = 512
+    OPENAI_MAX_TOKENS: int = 1500
+
+    # API server
+    HOST: str = "127.0.0.1"
+    PORT: int = 8000
+    RELOAD: bool = False
+    WORKER_COUNT: int = 1
+
+    # NLP models
+    SPACY_MODEL_ID: str = "en_core_web_sm"
+    SENTENCE_TRANSFORMER_MODEL_ID: str = "all-MiniLM-L6-v2"
+    SENTENCE_TRANSFORMER_BATCH_SIZE: int = 2
+
+    GRAMMAR_MODEL_ID: str = "visheratin/t5-efficient-mini-grammar-correction"
+    PARAPHRASE_MODEL_ID: str = "humarin/chatgpt_paraphraser_on_T5_base"
+    TONE_MODEL_ID: str = "boltuix/NeuroFeel"
+    TONE_CONFIDENCE_THRESHOLD: float = 1.0
+    TRANSLATION_MODEL_ID: str = "Helsinki-NLP/opus-mt-en-ROMANCE"
+
+    WORDNET_NLTK_ID: str = "wordnet.zip"
+
     SUPPORTED_TRANSLATION_LANGUAGES: List[str] = [
         "fr", "fr_BE", "fr_CA", "fr_FR", "wa", "frp", "oc", "ca", "rm", "lld",
         "fur", "lij", "lmo", "es", "es_AR", "es_CL", "es_CO", "es_CR", "es_DO",
@@ -26,18 +94,28 @@ class Settings(BaseSettings):
         "sc", "ro", "la"
     ]
 
-    # Model names
-    GRAMMAR_MODEL: str = "visheratin/t5-efficient-mini-grammar-correction"
-    PARAPHRASE_MODEL: str = "humarin/chatgpt_paraphraser_on_T5_base"
-    TONE_MODEL: str = "boltuix/NeuroFeel"
-    TONE_CONFIDENCE_THRESHOLD: float = 0.2
-    TRANSLATION_MODEL: str = "Helsinki-NLP/opus-mt-en-ROMANCE"
-    SENTENCE_TRANSFORMER_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
-
-    class Config:
-        env_file = ".env"
-        case_sensitive = True
+    # Data dirs
+    INCLUSIVE_RULES_DIR: str = "app/data/en"
 
+# ─────────────────────────────────────────────────────────────────────────────
+# 📦 App-wide constants
+# ─────────────────────────────────────────────────────────────────────────────
 
-# Singleton instance
 settings = Settings()
+
+# Core settings for import
+APP_NAME = settings.APP_NAME
+LOCAL_API_HOST = settings.HOST
+LOCAL_API_PORT = settings.PORT
+
+# Model names
+SPACY_MODEL_ID = settings.SPACY_MODEL_ID
+SENTENCE_TRANSFORMER_MODEL_ID = settings.SENTENCE_TRANSFORMER_MODEL_ID
+GRAMMAR_MODEL_ID = settings.GRAMMAR_MODEL_ID
+PARAPHRASE_MODEL_ID = settings.PARAPHRASE_MODEL_ID
+TONE_MODEL_ID = settings.TONE_MODEL_ID
+TRANSLATION_MODEL_ID = settings.TRANSLATION_MODEL_ID
+WORDNET_NLTK_ID = settings.WORDNET_NLTK_ID
+
+# Data
+INCLUSIVE_RULES_DIR = settings.INCLUSIVE_RULES_DIR
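Because Settings now uses SettingsConfigDict(env_file=".env", extra="ignore"), any field can be overridden per-environment without code changes. A minimal sketch, assuming the app package is importable (note that importing app.core.config also creates the data directories as a side effect):

# Override Settings fields via environment variables before constructing Settings.
import os

os.environ["PORT"] = "9000"          # overrides Settings.PORT (default 8000)
os.environ["API_KEY"] = "local-dev"  # overrides Settings.API_KEY

from app.core.config import Settings

settings = Settings()
print(settings.HOST, settings.PORT)  # 127.0.0.1 9000
print(settings.API_KEY)              # local-dev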
app/core/exceptions.py ADDED
@@ -0,0 +1,53 @@
+# app/core/exceptions.py
+from typing import Optional
+from fastapi import HTTPException
+
+
+class ServiceError(HTTPException):
+    """
+    Base exception for general service-related errors.
+    Inherits from HTTPException to allow direct use in FastAPI responses.
+    """
+    def __init__(self, status_code: int, detail: str, error_type: str = "ServiceError"):
+        super().__init__(status_code=status_code, detail=detail)
+        self.error_type = error_type
+
+    def to_dict(self):
+        """Returns a dictionary representation of the exception."""
+        return {
+            "detail": self.detail,
+            "status_code": self.status_code,
+            "error_type": self.error_type
+        }
+
+
+class ModelNotDownloadedError(ServiceError):
+    """
+    Raised when a required model is not found locally.
+    Informs the client that a download is necessary.
+    """
+    def __init__(self, model_id: str, feature_name: str, detail: Optional[str] = None):
+        detail = detail or f"Model '{model_id}' required for '{feature_name}' is not downloaded."
+        super().__init__(status_code=424, detail=detail, error_type="ModelNotDownloaded")
+        self.model_id = model_id
+        self.feature_name = feature_name
+
+    def to_dict(self):
+        base_dict = super().to_dict()
+        base_dict.update({
+            "model_id": self.model_id,
+            "feature_name": self.feature_name
+        })
+        return base_dict
+
+
+class ModelDownloadFailedError(ServiceError):
+    """Exception raised when a model download operation fails."""
+    def __init__(self, model_id: str, feature_name: str, original_error: str = "Unknown error"):
+        # ServiceError.__init__ does not accept model_id/feature_name kwargs,
+        # so store them as attributes here instead of passing them to super().
+        super().__init__(
+            status_code=503,  # Service Unavailable
+            detail=f"Failed to download model '{model_id}' for '{feature_name}'. Please check your internet connection or try again. Error: {original_error}",
+            error_type="ModelDownloadFailed"
+        )
+        self.model_id = model_id
+        self.feature_name = feature_name
+        self.original_error = original_error
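Since both subclasses carry a to_dict() payload, the JSON a client receives is predictable. A minimal sketch of what the 424 case serializes to:

# What the global handler in app/main.py would serialize for a missing model.
from app.core.exceptions import ModelNotDownloadedError

try:
    raise ModelNotDownloadedError(model_id="en_core_web_sm", feature_name="Text Processing")
except ModelNotDownloadedError as exc:
    print(exc.status_code)  # 424 (Failed Dependency)
    print(exc.to_dict())
    # {'detail': "Model 'en_core_web_sm' required for 'Text Processing' is not downloaded.",
    #  'status_code': 424, 'error_type': 'ModelNotDownloaded',
    #  'model_id': 'en_core_web_sm', 'feature_name': 'Text Processing'}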
app/core/logging.py CHANGED
@@ -1,12 +1,55 @@
-import os
+
 import logging
+import os
+from pathlib import Path
+
+from app.core.config import APP_DATA_ROOT_DIR, APP_NAME
 
 def configure_logging():
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(asctime)s - %(levelname)s - %(message)s",
-        handlers=[
-            logging.StreamHandler(),
-            logging.FileHandler(os.getenv("LOG_FILE", "app.log"))
-        ]
+    """
+    Configures application-wide logging to both console and a file.
+    The log file is placed in the application's data directory.
+    """
+    log_dir = APP_DATA_ROOT_DIR / "logs"
+    log_dir.mkdir(parents=True, exist_ok=True)  # Ensure the log directory exists
+
+    log_file_path = log_dir / f"{APP_NAME.lower()}.log"
+
+    # Define a custom formatter
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
     )
+
+    # Console handler
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(formatter)
+    console_handler.setLevel(logging.INFO)  # Default console level
+
+    # File handler
+    file_handler = logging.FileHandler(log_file_path)
+    file_handler.setFormatter(formatter)
+    file_handler.setLevel(logging.INFO)  # Default file level
+
+    # Get the root logger
+    root_logger = logging.getLogger()
+    root_logger.setLevel(logging.INFO)  # Overall minimum logging level
+
+    # Clear existing handlers to prevent duplicate logs if called multiple times
+    if root_logger.hasHandlers():
+        root_logger.handlers.clear()
+
+    root_logger.addHandler(console_handler)
+    root_logger.addHandler(file_handler)
+
+    # Set specific log levels for libraries if needed (e.g., to reduce verbosity)
+    logging.getLogger("uvicorn").setLevel(logging.WARNING)
+    logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
+    logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
+    logging.getLogger("transformers").setLevel(logging.WARNING)
+    logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
+    logging.getLogger("nltk").setLevel(logging.WARNING)
+    logging.getLogger("urllib3").setLevel(logging.WARNING)
+    logging.getLogger("asyncio").setLevel(logging.WARNING)  # Reduce asyncio verbosity
+
+    logger = logging.getLogger(f"{APP_NAME}.core.logging")
+    logger.info(f"Logging configured. Logs are saved to: {log_file_path}")
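Typical usage is a single configure_logging() call at startup followed by per-module loggers; clearing the root handlers first makes repeat calls safe. A minimal sketch (the logger name assumes the default APP_NAME of "WellSaidApp"):

import logging
from app.core.logging import configure_logging

configure_logging()  # idempotent: existing root handlers are cleared first
logger = logging.getLogger("WellSaidApp.routers.example")
logger.info("Goes to the console and to ~/.wellsaid_app_data/logs/wellsaidapp.log")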
app/core/model_manager.py ADDED
@@ -0,0 +1,280 @@
+# app/core/model_manager.py
+import logging
+import os
+import asyncio
+from pathlib import Path
+from typing import Callable, Optional, Dict, List
+
+# Imports for downloading specific model types
+import nltk
+from huggingface_hub import snapshot_download
+import spacy.cli
+
+# Internal application imports
+from app.core.config import (
+    MODELS_DIR,
+    NLTK_DATA_DIR,
+    SPACY_MODEL_ID,
+    SENTENCE_TRANSFORMER_MODEL_ID,
+    TONE_MODEL_ID,
+    TRANSLATION_MODEL_ID,
+    WORDNET_NLTK_ID,
+    APP_NAME
+)
+from app.core.exceptions import ModelNotDownloadedError, ModelDownloadFailedError, ServiceError
+
+logger = logging.getLogger(f"{APP_NAME}.core.model_manager")
+
+# Type alias for progress callback
+ProgressCallback = Callable[[str, str, float, Optional[str]], None]  # (model_id, status, progress, message)
+
+def _get_hf_model_local_path(model_id: str) -> Path:
+    """Helper to get the expected local path for a Hugging Face model."""
+    # snapshot_download creates a specific folder structure inside MODELS_DIR/hf_cache,
+    # e.g. for "bert-base-uncased" it is MODELS_DIR/hf_cache/models--bert-base-uncased;
+    # the actual model files live inside that. The `transformers` library usually
+    # handles this resolution; we only need to check that the directory exists.
+    return MODELS_DIR / "hf_cache" / f"models--{model_id.replace('/', '--')}"  # Standard HF cache path logic
+
+
+def check_model_exists(model_id: str, model_type: str) -> bool:
+    """
+    Checks if a specific model or NLTK data is already downloaded locally.
+    """
+    if model_type == "huggingface":
+        local_path = _get_hf_model_local_path(model_id)
+        # Check if the directory exists and contains some files
+        return local_path.is_dir() and any(local_path.iterdir())
+    elif model_type == "spacy":
+        # spaCy models are symlinked or copied into a specific site-packages location.
+        # `spacy.load` (or spacy.util.is_package) is the most reliable check; the
+        # loading code in app.services.base handles the `is_package` check. Here we
+        # use a simplified check: whether the directory created by `spacy download`
+        # exists inside MODELS_DIR, assuming SPACY_DATA directs spaCy there.
+        spacy_target_path = MODELS_DIR / model_id
+        return spacy_target_path.is_dir() and any(spacy_target_path.iterdir())
+    elif model_type == "nltk":
+        # NLTK data check
+        try:
+            return nltk.data.find(f"corpora/{model_id}") is not None
+        except LookupError:
+            return False
+    else:
+        logger.warning(f"Unknown model type for check_model_exists: {model_type}")
+        return False
+
+# --- Download Functions ---
+
+async def download_hf_model_async(
+    model_id: str,
+    feature_name: str,
+    progress_callback: Optional[ProgressCallback] = None
+) -> None:
+    """
+    Asynchronously downloads a Hugging Face model from the Hub.
+    """
+    logger.info(f"Initiating download for Hugging Face model '{model_id}' for '{feature_name}'...")
+    if check_model_exists(model_id, "huggingface"):
+        logger.info(f"Hugging Face model '{model_id}' already exists locally. Skipping download.")
+        if progress_callback:
+            progress_callback(model_id, "completed", 1.0, "Already downloaded.")
+        return
+
+    # Use a thread for the blocking download operation
+    try:
+        def _blocking_download():
+            # This downloads to MODELS_DIR/hf_cache by default when HF_HOME points
+            # there (set in config.py); the cache_dir is passed explicitly anyway.
+            # snapshot_download does not expose a live progress callback, so we
+            # only log at the beginning and end.
+            snapshot_download(
+                repo_id=model_id,
+                cache_dir=str(MODELS_DIR / "hf_cache"),  # Explicitly set cache directory
+                local_dir_use_symlinks=False,  # Use False for a better self-contained app
+            )
+            logger.info(f"Hugging Face model '{model_id}' download complete.")
+
+        if progress_callback:
+            progress_callback(model_id, "downloading", 0.05, "Starting download...")
+
+        await asyncio.to_thread(_blocking_download)  # Run blocking download in a separate thread
+
+        if progress_callback:
+            progress_callback(model_id, "completed", 1.0, "Download successful.")
+
+    except Exception as e:
+        logger.error(f"Failed to download Hugging Face model '{model_id}': {e}", exc_info=True)
+        if progress_callback:
+            progress_callback(model_id, "failed", 0.0, f"Error: {e}")
+        raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e))
+
+
+async def download_spacy_model_async(
+    model_id: str,
+    feature_name: str,
+    progress_callback: Optional[ProgressCallback] = None
+) -> None:
+    """
+    Asynchronously downloads a spaCy model.
+    """
+    logger.info(f"Initiating download for spaCy model '{model_id}' for '{feature_name}'...")
+    # NOTE: The pre-check below may not be sufficient if SPACY_DATA isn't pointing
+    # at MODELS_DIR; `spacy.util.is_package` would be more robust. For now we trust
+    # `spacy.cli.download` to handle the check or fail gracefully.
+
+    # Ensure the SPACY_DATA environment variable is set to MODELS_DIR so that
+    # spacy.cli.download puts the model in our custom path.
+    original_spacy_data = os.environ.get("SPACY_DATA")
+    try:
+        os.environ["SPACY_DATA"] = str(MODELS_DIR)
+
+        if check_model_exists(model_id, "spacy"):  # Using our own simplified check
+            logger.info(f"SpaCy model '{model_id}' already exists locally. Skipping download.")
+            if progress_callback:
+                progress_callback(model_id, "completed", 1.0, "Already downloaded.")
+            return
+
+        def _blocking_download():
+            # spacy.cli.download attempts to download and link/copy the model package.
+            # We rely on check_model_exists above to avoid redundant downloads.
+            spacy.cli.download(model_id)
+            logger.info(f"SpaCy model '{model_id}' download complete.")
+
+        if progress_callback:
+            progress_callback(model_id, "downloading", 0.05, "Starting download...")
+
+        await asyncio.to_thread(_blocking_download)
+
+        if progress_callback:
+            progress_callback(model_id, "completed", 1.0, "Download successful.")
+
+    except Exception as e:
+        logger.error(f"Failed to download spaCy model '{model_id}': {e}", exc_info=True)
+        if progress_callback:
+            progress_callback(model_id, "failed", 0.0, f"Error: {e}")
+        raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e))
+    finally:
+        # Restore original SPACY_DATA if it was set
+        if original_spacy_data is not None:
+            os.environ["SPACY_DATA"] = original_spacy_data
+        elif "SPACY_DATA" in os.environ:
+            del os.environ["SPACY_DATA"]
+
+
+async def download_nltk_data_async(
+    data_id: str,
+    feature_name: str,
+    progress_callback: Optional[ProgressCallback] = None
+) -> None:
+    """
+    Asynchronously downloads NLTK data.
+    """
+    logger.info(f"Initiating download for NLTK data '{data_id}' for '{feature_name}'...")
+    # The NLTK data path is set via the NLTK_DATA environment variable in config.py;
+    # `nltk.download` will use this path.
+
+    if check_model_exists(data_id, "nltk"):
+        logger.info(f"NLTK data '{data_id}' already exists locally. Skipping download.")
+        if progress_callback:
+            progress_callback(data_id, "completed", 1.0, "Already downloaded.")
+        return
+
+    def _blocking_download():
+        # The NLTK downloader can show a GUI; `quiet=True` keeps the download
+        # programmatic. `download_dir` pins the target directory explicitly.
+        nltk.download(data_id, download_dir=str(NLTK_DATA_DIR), quiet=True)
+        logger.info(f"NLTK data '{data_id}' download complete.")
+
+    try:
+        if progress_callback:
+            progress_callback(data_id, "downloading", 0.05, "Starting download...")
+
+        await asyncio.to_thread(_blocking_download)
+
+        if progress_callback:
+            progress_callback(data_id, "completed", 1.0, "Download successful.")
+
+    except Exception as e:
+        logger.error(f"Failed to download NLTK data '{data_id}': {e}", exc_info=True)
+        if progress_callback:
+            progress_callback(data_id, "failed", 0.0, f"Error: {e}")
+        raise ModelDownloadFailedError(data_id, feature_name, original_error=str(e))
+
+
+# --- Comprehensive Model Management ---
+
+def get_all_required_models() -> List[Dict]:
+    """
+    Returns a list of all models required by the application, with their type and feature.
+    """
+    return [
+        {"id": SPACY_MODEL_ID, "type": "spacy", "feature": "Text Processing (General)"},
+        {"id": SENTENCE_TRANSFORMER_MODEL_ID, "type": "huggingface", "feature": "Sentence Embeddings"},
+        {"id": TONE_MODEL_ID, "type": "huggingface", "feature": "Tone Classification"},
+        {"id": TRANSLATION_MODEL_ID, "type": "huggingface", "feature": "Translation"},
+        {"id": WORDNET_NLTK_ID, "type": "nltk", "feature": "Synonym Suggestion"},
+        # Add any other models here as your application grows
+    ]
+
+async def download_all_required_models(progress_callback: Optional[ProgressCallback] = None) -> Dict[str, str]:
+    """
+    Attempts to download all required models.
+    Returns a dictionary of download statuses.
+    """
+    required_models = get_all_required_models()
+    download_statuses = {}
+
+    for model_info in required_models:
+        model_id = model_info["id"]
+        model_type = model_info["type"]
+        feature_name = model_info["feature"]
+
+        if check_model_exists(model_id, model_type):
+            status_message = f"'{model_id}' ({feature_name}) already downloaded."
+            logger.info(status_message)
+            download_statuses[model_id] = "already_downloaded"
+            if progress_callback:
+                progress_callback(model_id, "completed", 1.0, status_message)
+            continue
+
+        logger.info(f"Attempting to download '{model_id}' ({feature_name})...")
+        try:
+            if model_type == "huggingface":
+                await download_hf_model_async(model_id, feature_name, progress_callback)
+            elif model_type == "spacy":
+                await download_spacy_model_async(model_id, feature_name, progress_callback)
+            elif model_type == "nltk":
+                await download_nltk_data_async(model_id, feature_name, progress_callback)
+            else:
+                raise ValueError(f"Unsupported model type: {model_type}")
+
+            status_message = f"'{model_id}' ({feature_name}) downloaded successfully."
+            logger.info(status_message)
+            download_statuses[model_id] = "success"
+
+        except ModelDownloadFailedError as e:
+            status_message = f"Failed to download '{model_id}' ({feature_name}): {e.original_error}"
+            logger.error(status_message)
+            download_statuses[model_id] = "failed"
+            # The progress_callback is already invoked inside the download functions on failure
+        except Exception as e:
+            status_message = f"An unexpected error occurred while downloading '{model_id}' ({feature_name}): {e}"
+            logger.error(status_message, exc_info=True)
+            download_statuses[model_id] = "failed"
+            if progress_callback:
+                progress_callback(model_id, "failed", 0.0, status_message)
+
+    logger.info("Finished attempting to download all required models.")
+    return download_statuses
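With everything funneled through download_all_required_models, pre-fetching for offline use becomes a one-off script. A minimal sketch using the module's own progress-callback signature:

# prefetch_models.py — pre-download every required model before going offline.
import asyncio
from typing import Optional

from app.core.model_manager import download_all_required_models

def on_progress(model_id: str, status: str, progress: float, message: Optional[str]) -> None:
    print(f"[{model_id}] {status} {progress:.0%} {message or ''}")

statuses = asyncio.run(download_all_required_models(progress_callback=on_progress))
print(statuses)  # e.g. {'en_core_web_sm': 'success', 'all-MiniLM-L6-v2': 'already_downloaded', ...}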
app/core/prompts.py DELETED
@@ -1,28 +0,0 @@
-def tone_prompt(text: str, tone: str) -> str:
-    return f"Change the tone of this sentence to {tone}: {text.strip()}"
-
-def summarize_prompt(text: str) -> str:
-    return f"Summarize the following text:\n{text.strip()}"
-
-def clarity_prompt(text: str) -> str:
-    return f"Improve the clarity of the following sentence:\n{text.strip()}"
-
-def rewrite_prompt(text: str, instruction: str) -> str:
-    return f"{instruction.strip()}\n{text.strip()}"
-
-def vocabulary_prompt(text: str) -> str:
-    return (
-        "You are an expert vocabulary enhancer. Rewrite the following text "
-        "by replacing common and simple words with more sophisticated, "
-        "precise, and contextually appropriate synonyms. Do not change "
-        "the original meaning. Maintain the tone.\n" + text.strip()
-    )
-
-def concise_prompt(text: str) -> str:
-    return (
-        "You are an expert editor specializing in conciseness. "
-        "Rewrite the following text to be more concise and to the point, "
-        "removing any verbose phrases, redundant words, or unnecessary clauses. "
-        "Maintain the original meaning and professional tone.\n" + text.strip()
-    )
-
app/main.py CHANGED
@@ -1,5 +1,119 @@
-from app.core.app import create_app
-from app.core.logging import configure_logging
+# app/main.py
+import logging
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, Request, status
+from fastapi.responses import JSONResponse
+from fastapi.middleware.gzip import GZipMiddleware
+from fastapi.middleware.cors import CORSMiddleware
+
+from app.core.config import APP_NAME  # For logger naming
+from app.core.logging import configure_logging  # Import the new logging configuration
+from app.core.exceptions import ServiceError, ModelNotDownloadedError  # Import custom exceptions
+
+# Import your routers
+# Adjust these imports if your router file names or structure are different
+from app.routers import (
+    grammar, tone, voice, inclusive_language,
+    readability, paraphrase, translate, synonyms, rewrite, analyze
+)
+
+# Configure logging at the very beginning
 configure_logging()
-app = create_app()
+logger = logging.getLogger(f"{APP_NAME}.main")
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    Context manager for application startup and shutdown events.
+    Models are now lazily loaded, so no explicit loading here.
+    """
+    logger.info("Application starting up...")
+    # Any other global startup tasks can go here
+    yield
+    logger.info("Application shutting down...")
+    # Any global shutdown tasks can go here (e.g., closing database connections)
+
+
+app = FastAPI(
+    title="Writing Assistant API (Local)",
+    description="Local API for the desktop Writing Assistant application, providing various NLP functionalities.",
+    version="0.1.0",
+    lifespan=lifespan,
+)
+
+# --- Middleware Setup ---
+app.add_middleware(GZipMiddleware, minimum_size=500)
+
+# CORS Middleware for local development/desktop app scenarios
+# Allows all origins for local testing. Restrict as needed for deployment.
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Adjust this for specific origins in a web deployment
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# --- Global Exception Handlers ---
+@app.exception_handler(ServiceError)
+async def service_error_handler(request: Request, exc: ServiceError):
+    """
+    Handles custom ServiceError exceptions, returning a structured JSON response.
+    """
+    logger.error(f"Service Error caught for path {request.url.path}: {exc.detail}", exc_info=True)
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=exc.to_dict(),  # Use the to_dict method from ServiceError
+    )
+
+@app.exception_handler(ModelNotDownloadedError)
+async def model_not_downloaded_error_handler(request: Request, exc: ModelNotDownloadedError):
+    """
+    Handles ModelNotDownloadedError exceptions, informing the client a model is missing.
+    """
+    logger.warning(f"Model Not Downloaded Error caught for path {request.url.path}: Model '{exc.model_id}' is missing for feature '{exc.feature_name}'.")
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=exc.to_dict(),  # Use the to_dict method from ModelNotDownloadedError
+    )
+
+@app.exception_handler(Exception)
+async def general_exception_handler(request: Request, exc: Exception):
+    """
+    Handles all other unhandled exceptions, returning a generic server error.
+    """
+    logger.exception(f"Unhandled exception caught for path {request.url.path}: {exc}")  # Use logger.exception to log the traceback
+    return JSONResponse(
+        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        content={
+            "detail": "An unexpected internal server error occurred.",
+            "error_type": "InternalServerError",
+        },
+    )
+
+# --- Include Routers ---
+# Note: You will need to create/update these files in app/routers/
+# if they don't exist or don't match the new async service methods.
+for router, tag in [
+    (grammar.router, "Grammar"),
+    (tone.router, "Tone"),
+    (voice.router, "Voice"),
+    (inclusive_language.router, "Inclusive Language"),
+    (readability.router, "Readability"),
+    (rewrite.router, "Rewrite"),
+    (analyze.router, "Analyze"),
+    (paraphrase.router, "Paraphrasing"),
+    (translate.router, "Translation"),
+    (synonyms.router, "Synonyms")
+]:
+    app.include_router(router, tags=[tag])
+
+# --- Root Endpoint ---
+@app.get("/", tags=["Health Check"])
+async def root():
+    """
+    Root endpoint for health check.
+    """
+    return {"message": "Writing Assistant API is running!"}
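The assembled app can be exercised in-process without starting a server. A minimal sketch with FastAPI's TestClient, assuming the full app package and its service dependencies import cleanly:

from fastapi.testclient import TestClient
from app.main import app

client = TestClient(app)

resp = client.get("/")  # health check, no API key dependency
assert resp.status_code == 200
print(resp.json())      # {'message': 'Writing Assistant API is running!'}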
app/queue.py DELETED
@@ -1,104 +0,0 @@
-import asyncio
-import logging
-import time
-import uuid
-import inspect
-
-from app.services.grammar import GrammarCorrector
-from app.services.paraphrase import Paraphraser
-from app.services.translation import Translator
-from app.services.tone_classification import ToneClassifier
-from app.services.inclusive_language import InclusiveLanguageChecker
-from app.services.voice_detection import VoiceDetector
-from app.services.readability import ReadabilityScorer
-from app.services.synonyms import SynonymSuggester
-from app.core.config import settings
-
-# Configure logging
-logging.basicConfig(
-    level=logging.DEBUG if getattr(settings, "DEBUG", False) else logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s"
-)
-
-# Initialize service instances
-grammar = GrammarCorrector()
-paraphraser = Paraphraser()
-translator = Translator()
-tone = ToneClassifier()
-inclusive = InclusiveLanguageChecker()
-voice_analyzer = VoiceDetector()
-readability = ReadabilityScorer()
-synonyms = SynonymSuggester()
-
-# Create async task queue (optional: maxsize=100 to prevent overload)
-task_queue = asyncio.Queue(maxsize=100)
-
-# Task handler map
-SERVICE_HANDLERS = {
-    "grammar": lambda p: grammar.correct(p["text"]),
-    "paraphrase": lambda p: paraphraser.paraphrase(p["text"]),
-    "translate": lambda p: translator.translate(p["text"], p["target_lang"]),
-    "tone": lambda p: tone.classify(p["text"]),
-    "inclusive": lambda p: inclusive.check(p["text"]),
-    "voice": lambda p: voice_analyzer.classify(p["text"]),
-    "readability": lambda p: readability.compute(p["text"]),
-    "synonyms": lambda p: synonyms.suggest(p["text"]),  # ✅ This is async
-}
-
-async def worker(worker_id: int):
-    logging.info(f"Worker-{worker_id} started")
-
-    while True:
-        task = await task_queue.get()
-        future = task["future"]
-        task_type = task["type"]
-        payload = task["payload"]
-        task_id = task["id"]
-
-        start_time = time.perf_counter()
-        logging.info(f"[Worker-{worker_id}] Processing Task-{task_id} | Type: {task_type} | Queue size: {task_queue.qsize()}")
-
-        try:
-            handler = SERVICE_HANDLERS.get(task_type)
-            if not handler:
-                raise ValueError(f"Unknown task type: {task_type}")
-
-            result = handler(payload)
-            if inspect.isawaitable(result):
-                result = await result
-
-            elapsed = time.perf_counter() - start_time
-            logging.info(f"[Worker-{worker_id}] Finished Task-{task_id} in {elapsed:.2f}s")
-
-            if not future.done():
-                future.set_result(result)
-
-        except Exception as e:
-            logging.error(f"[Worker-{worker_id}] Error in Task-{task_id} ({task_type}): {e}")
-            if not future.done():
-                future.set_result({"error": str(e)})
-
-        task_queue.task_done()
-
-def start_workers(count: int = 2):
-    for i in range(count):
-        asyncio.create_task(worker(i))
-
-async def enqueue_task(task_type: str, payload: dict, timeout: float = 10.0):
-    future = asyncio.get_event_loop().create_future()
-    task_id = str(uuid.uuid4())[:8]
-
-    await task_queue.put({
-        "future": future,
-        "type": task_type,
-        "payload": payload,
-        "id": task_id
-    })
-
-    logging.info(f"[ENQUEUE] Task-{task_id} added to queue | Type: {task_type} | Queue size: {task_queue.qsize()}")
-
-    try:
-        return await asyncio.wait_for(future, timeout=timeout)
-    except asyncio.TimeoutError:
-        logging.warning(f"[ENQUEUE] Task-{task_id} timed out after {timeout}s")
-        return {"error": f"Task {task_type} timed out after {timeout} seconds."}
app/routers/analyze.py CHANGED
@@ -1,45 +1,97 @@
-from fastapi import APIRouter, Depends, HTTPException
-from app.core.security import verify_api_key
-from app.schemas.base import TextOnlyRequest
-from app.queue import task_queue
-import asyncio
-import uuid
+# app/routers/analyze.py
 import logging
+import asyncio
+from fastapi import APIRouter, Depends, HTTPException, status
+
+from app.schemas.base import TextOnlyRequest  # Assuming this Pydantic model exists
+from app.services.grammar import GrammarCorrector
+from app.services.tone_classification import ToneClassifier
+from app.services.inclusive_language import InclusiveLanguageChecker
+from app.services.voice_detection import VoiceDetector
+from app.services.readability import ReadabilityScorer
+from app.services.synonyms import SynonymSuggester
+from app.core.security import verify_api_key  # Assuming you still need API key verification
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError, ModelNotDownloadedError  # Import custom exceptions
+
+logger = logging.getLogger(f"{APP_NAME}.routers.analyze")
 
 router = APIRouter(prefix="/analyze", tags=["Analysis"])
-logger = logging.getLogger(__name__)
+
+# Initialize service instances once per application lifecycle
+# These services will handle lazy loading their models internally
+grammar_service = GrammarCorrector()
+tone_service = ToneClassifier()
+inclusive_service = InclusiveLanguageChecker()
+voice_service = VoiceDetector()
+readability_service = ReadabilityScorer()
+synonyms_service = SynonymSuggester()
+
 
 @router.post("/", dependencies=[Depends(verify_api_key)])
-async def analyze_text(payload: TextOnlyRequest):
+async def analyze_text_endpoint(payload: TextOnlyRequest):
+    """
+    Performs a comprehensive analysis of the provided text,
+    including grammar, tone, inclusive language, voice, readability, and synonyms.
+    """
     text = payload.text.strip()
     if not text:
-        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
-
-    loop = asyncio.get_event_loop()
-
-    task_definitions = [
-        ("grammar", {"text": text}),
-        ("tone", {"text": text}),
-        ("inclusive", {"text": text}),
-        ("voice", {"text": text}),
-        ("readability", {"text": text}),
-        ("synonyms", {"text": text}),
-    ]
-
-    futures = []
-    for task_type, task_payload in task_definitions:
-        future = loop.create_future()
-        task_id = str(uuid.uuid4())[:8]
-
-        await task_queue.put({
-            "type": task_type,
-            "payload": task_payload,
-            "future": future,
-            "id": task_id
-        })
-        futures.append((task_type, future))
-
-    results = await asyncio.gather(*[fut for _, fut in futures])
-    response = {task_type: result for (task_type, _), result in zip(futures, results)}
-
-    return {"analysis": response}
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received comprehensive analysis request for text (first 50 chars): '{text[:50]}...'")
+
+    # Define tasks to run concurrently
+    tasks = {
+        "grammar": grammar_service.correct(text),
+        "tone": tone_service.classify(text),
+        "inclusive_language": inclusive_service.check(text),
+        "voice": voice_service.classify(text),
+        "readability": readability_service.compute(text),
+        "synonyms": synonyms_service.suggest(text),
+    }
+
+    results = {}
+    coroutine_tasks = []
+    task_keys = []  # To map results back to their keys
+
+    for key, coroutine in tasks.items():
+        coroutine_tasks.append(coroutine)
+        task_keys.append(key)
+
+    # Run all tasks concurrently and handle potential exceptions for each
+    raw_results = await asyncio.gather(*coroutine_tasks, return_exceptions=True)
+
+    # Process results, handling errors gracefully for each sub-analysis
+    for i, result in enumerate(raw_results):
+        key = task_keys[i]
+        if isinstance(result, ModelNotDownloadedError):
+            logger.warning(f"Analysis for '{key}' skipped: Model '{result.model_id}' not downloaded. Detail: {result.detail}")
+            results[key] = {
+                "status": "skipped",
+                "message": result.detail,
+                "error_type": result.error_type,
+                "model_id": result.model_id,
+                "feature_name": result.feature_name
+            }
+        elif isinstance(result, ServiceError):
+            logger.error(f"Analysis for '{key}' failed with ServiceError. Detail: {result.detail}", exc_info=True)
+            results[key] = {
+                "status": "error",
+                "message": result.detail,
+                "error_type": result.error_type
+            }
+        elif isinstance(result, Exception):  # Catch any other unexpected exceptions from service methods
+            logger.exception(f"Analysis for '{key}' failed with unexpected error.")
+            results[key] = {
+                "status": "error",
+                "message": f"An unexpected error occurred: {str(result)}",
+                "error_type": "InternalServiceError"
+            }
+        else:
+            # If successful, merge the service's result into the main results dict
+            results[key] = result  # Direct assignment, as the service result is already a dict
+
+    logger.info(f"Comprehensive analysis complete for text (first 50 chars): '{text[:50]}...'")
+    return {"analysis_results": results}
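Over HTTP, a single POST now returns one entry per sub-analysis, each holding either a result or a structured skip/error record. A minimal client sketch — the X-API-Key header name is an assumption, since verify_api_key is not shown in this diff:

import httpx

resp = httpx.post(
    "http://127.0.0.1:8000/analyze/",
    headers={"X-API-Key": "your_strong_api_key_here"},  # header name assumed
    json={"text": "Their going to the store tomorrow."},
    timeout=120.0,  # the first call may trigger lazy model loading
)
print(resp.status_code)
print(sorted(resp.json()["analysis_results"]))
# ['grammar', 'inclusive_language', 'readability', 'synonyms', 'tone', 'voice']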
app/routers/grammar.py CHANGED
@@ -1,36 +1,49 @@
-import uuid
-import asyncio
+# app/routers/grammar.py
 import logging
-from fastapi import APIRouter, Depends, HTTPException, status
-from app.schemas.base import TextOnlyRequest
-from app.services.grammar import GrammarCorrector
-from app.core.security import verify_api_key
-from app.queue import task_queue
+from fastapi import APIRouter, Depends, status, HTTPException  # HTTPException for 400 validation errors
+
+from app.schemas.base import TextOnlyRequest  # Assuming this Pydantic model exists
+from app.services.grammar import GrammarCorrector  # Import the service class
+from app.core.security import verify_api_key  # Assuming you still need API key verification
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # Important for catching specific service errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.grammar")
 
 router = APIRouter(prefix="/grammar", tags=["Grammar"])
-logger = logging.getLogger(__name__)
 
-@router.post("/", dependencies=[Depends(verify_api_key)])
-async def correct_grammar(payload: TextOnlyRequest):
+# Initialize service instance once per application lifecycle
+# FastAPI handles dependency injection and lifecycle for routes,
+# so instantiate the service directly.
+grammar_corrector_service = GrammarCorrector()
+
+
+@router.post("/correct", dependencies=[Depends(verify_api_key)])  # Changed path to /correct for clarity
+async def correct_grammar_endpoint(payload: TextOnlyRequest):
+    """
+    Corrects grammar in the provided text.
+    """
     text = payload.text.strip()
     if not text:
-        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
-
-    future = asyncio.get_event_loop().create_future()
-    task_id = str(uuid.uuid4())[:8]
-
-    await task_queue.put({
-        "type": "grammar",
-        "payload": {"text": text},
-        "future": future,
-        "id": task_id
-    })
-
-    result = await future
-
-    if "error" in result:
-        detail = result["error"]
-        status_code = 400 if "empty" in detail.lower() else 500
-        raise HTTPException(status_code=status_code, detail=detail)
-
-    return {"grammar": result["result"]}
+        # Use FastAPI's HTTPException for direct validation errors
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received grammar correction request for text (first 50 chars): '{text[:50]}...'")
+
+    try:
+        # Directly call the async service method.
+        # ModelNotDownloadedError will be raised here if the model is missing,
+        # and caught by the global exception handler in app/main.py
+        result = await grammar_corrector_service.correct(text)
+
+        logger.info(f"Grammar correction successful for text (first 50 chars): '{text[:50]}...'")
+        return {"grammar_correction": result}
+
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        # This ensures consistent error responses across all services.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in grammar correction endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during grammar correction.") from e
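Note that the route moved from POST /grammar/ to POST /grammar/correct, and the response key changed from "grammar" to "grammar_correction", so existing clients need updating. A minimal sketch (same X-API-Key assumption as above):

import httpx

resp = httpx.post(
    "http://127.0.0.1:8000/grammar/correct",  # was POST /grammar/ before this commit
    headers={"X-API-Key": "your_strong_api_key_here"},
    json={"text": "She go to school every day."},
)
print(resp.json())  # {'grammar_correction': ...}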
app/routers/inclusive_language.py CHANGED
@@ -1,11 +1,45 @@
-from fastapi import APIRouter, Depends
+# app/routers/inclusive_language.py
+import logging
+from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation
+
 from app.schemas.base import TextOnlyRequest
-from app.services.inclusive_language import InclusiveLanguageChecker
-from app.core.security import verify_api_key
+from app.services.inclusive_language import InclusiveLanguageChecker  # Import the service class
+from app.core.security import verify_api_key  # Assuming API key verification is still used
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # For re-raising internal errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.inclusive_language")
 
 router = APIRouter(prefix="/inclusive-language", tags=["Inclusive Language"])
-checker = InclusiveLanguageChecker()
 
-@router.post("/", dependencies=[Depends(verify_api_key)])
-def check_inclusive_language(payload: TextOnlyRequest):
-    return {"suggestions": checker.check(payload.text)}
+# Initialize service instance once per application lifecycle
+inclusive_language_checker_service = InclusiveLanguageChecker()
+
+
+@router.post("/check", dependencies=[Depends(verify_api_key)])  # Added /check path for clarity
+async def check_inclusive_language_endpoint(payload: TextOnlyRequest):
+    """
+    Checks the provided text for inclusive language suggestions.
+    """
+    text = payload.text.strip()
+    if not text:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received inclusive language check request for text (first 50 chars): '{text[:50]}...'")
+
+    try:
+        # Directly call the async service method.
+        # ModelNotDownloadedError will be raised here if the model is missing,
+        # and caught by the global exception handler in app/main.py
+        result = await inclusive_language_checker_service.check(text)
+
+        logger.info(f"Inclusive language check successful for text (first 50 chars): '{text[:50]}...'")
+        return {"inclusive_language": result}
+
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in inclusive language check endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during inclusive language checking.") from e
app/routers/paraphrase.py CHANGED
@@ -1,24 +1,45 @@
-import asyncio
-import uuid
-from fastapi import APIRouter, Depends
+# app/routers/paraphrase.py
+import logging
+from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation
+
 from app.schemas.base import TextOnlyRequest
-from app.services.paraphrase import Paraphraser
+from app.services.paraphrase import Paraphraser  # Import the service class
 from app.core.security import verify_api_key
-from app.queue import task_queue
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # For re-raising internal errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.paraphrase")
 
 router = APIRouter(prefix="/paraphrase", tags=["Paraphrase"])
-paraphraser = Paraphraser()
-
-@router.post("/", dependencies=[Depends(verify_api_key)])
-async def paraphrase_text(payload: TextOnlyRequest):
-    future = asyncio.get_event_loop().create_future()
-    task_id = str(uuid.uuid4())[:8]
-    await task_queue.put({
-        "type": "paraphrase",
-        "payload": {"text": payload.text},
-        "future": future,
-        "id": task_id
-    })
-
-    result = await future
-    return {"result": result}
+
+# Initialize service instance once per application lifecycle
+paraphraser_service = Paraphraser()
+
+
+@router.post("/generate", dependencies=[Depends(verify_api_key)])
+async def paraphrase_text_endpoint(payload: TextOnlyRequest):
+    """
+    Generates a paraphrase for the provided text.
+    """
+    text = payload.text.strip()
+    if not text:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received paraphrase request for text (first 50 chars): '{text[:50]}...'")
+
+    try:
+        # Directly call the async service method.
+        # ModelNotDownloadedError will be raised here if the model is missing,
+        # and caught by the global exception handler in app/main.py
+        result = await paraphraser_service.paraphrase(text)
+
+        logger.info(f"Paraphrasing successful for text (first 50 chars): '{text[:50]}...'")
+        return {"paraphrase": result}  # Consistent key for response
+
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in paraphrasing endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during paraphrasing.") from e
app/routers/readability.py CHANGED
@@ -1,21 +1,43 @@
-from fastapi import APIRouter, Depends, HTTPException, status
+# app/routers/readability.py
+import logging
+from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation
+
 from app.schemas.base import TextOnlyRequest
+from app.services.readability import ReadabilityScorer  # Import the service class
 from app.core.security import verify_api_key
-from app.queue import enqueue_task
-import logging
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # For re-raising internal errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.readability")
 
 router = APIRouter(prefix="/readability", tags=["Readability"])
-logger = logging.getLogger(__name__)
 
-@router.post("/", dependencies=[Depends(verify_api_key)])
-async def readability_score(payload: TextOnlyRequest):
+# Initialize service instance once per application lifecycle
+readability_scorer_service = ReadabilityScorer()
+
+
+@router.post("/score", dependencies=[Depends(verify_api_key)])  # Added /score path for clarity
+async def readability_score_endpoint(payload: TextOnlyRequest):
+    """
+    Computes various readability scores for the provided text.
+    """
     text = payload.text.strip()
     if not text:
-        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received readability scoring request for text (first 50 chars): '{text[:50]}...'")
 
-    result = await enqueue_task("readability", {"text": text})
+    try:
+        # Directly call the async service method
+        result = await readability_scorer_service.compute(text)
 
-    if isinstance(result, dict) and result.get("error"):
-        raise HTTPException(status_code=500, detail=result["error"])
+        logger.info(f"Readability scoring successful for text (first 50 chars): '{text[:50]}...'")
+        return {"readability_scores": result}
 
-    return {"readability_scores": result["result"]}
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in readability scoring endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during readability scoring.") from e
app/routers/rewrite.py CHANGED
@@ -1,18 +1,59 @@
-# routers/rewrite.py
+# app/routers/rewrite.py
+import logging
+from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation
 
-from fastapi import APIRouter, Depends
-from app.schemas.base import RewriteRequest
-from app.services.gpt4_rewrite import GPT4Rewriter
-from app.core.security import verify_api_key
+from app.schemas.base import RewriteRequest  # Assuming this Pydantic model exists
+from app.services.gpt4_rewrite import GPT4Rewriter  # Import the service class
+from app.core.security import verify_api_key  # Assuming API key verification is still used
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # For re-raising internal errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.rewrite")
 
 router = APIRouter(prefix="/rewrite", tags=["Rewrite"])
-rewriter = GPT4Rewriter()
-
-@router.post("/", dependencies=[Depends(verify_api_key)])
-def rewrite_with_instruction(payload: RewriteRequest):
-    result = rewriter.rewrite(
-        text=payload.text,
-        instruction=payload.instruction,
-        user_api_key=payload.user_api_key
-    )
-    return {"result": result}
+
+# Initialize service instance once per application lifecycle
+gpt4_rewriter_service = GPT4Rewriter()
+
+
+@router.post("/with_instruction", dependencies=[Depends(verify_api_key)])  # Changed path to /with_instruction for clarity
+async def rewrite_with_instruction_endpoint(payload: RewriteRequest):
+    """
+    Rewrites the provided text based on a specific instruction using GPT-4.
+    Requires an OpenAI API key.
+    """
+    text = payload.text.strip()
+    instruction = payload.instruction.strip()
+    user_api_key = payload.user_api_key  # The user's provided API key
+
+    # Basic input validation for clarity, though service also validates
+    if not text:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+    if not instruction:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Instruction cannot be empty.")
+    if not user_api_key:
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OpenAI API key is required for this feature.")
+
+    logger.info(f"Received rewrite request for text (first 50 chars): '{text[:50]}...' with instruction (first 50 chars): '{instruction[:50]}...'")
+
+    try:
+        # Directly call the async service method.
+        # ServiceError will be raised here if there's an issue (e.g., missing API key, OpenAI API error),
+        # and caught by the global exception handler in app/main.py.
+        result = await gpt4_rewriter_service.rewrite(
+            text=text,
+            instruction=instruction,
+            user_api_key=user_api_key  # Pass the user's API key
+        )
+
+        logger.info(f"Rewriting successful for text (first 50 chars): '{text[:50]}...'")
+        return {"rewrite": result}  # Consistent key for response
+
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in rewriting endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during rewriting.") from e

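Note: this endpoint also moved (POST /rewrite/ to POST /rewrite/with_instruction) and now validates all three fields before calling the service. A payload sketch, assuming RewriteRequest carries exactly the fields read above (text, instruction, user_api_key); host, port, and header name are assumptions:

    import requests

    payload = {
        "text": "Our API are easy to use.",
        "instruction": "Rewrite this in a formal tone and fix any grammar issues.",
        "user_api_key": "sk-...",  # the caller's own OpenAI key
    }
    resp = requests.post(
        "http://localhost:8000/rewrite/with_instruction",  # assumed local dev server
        headers={"x-api-key": "YOUR_SERVICE_KEY"},         # header name depends on verify_api_key
        json=payload,
    )
    print(resp.json())  # {"rewrite": {"rewritten_text": "..."}}
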
app/routers/synonyms.py CHANGED
@@ -1,21 +1,45 @@
-from fastapi import APIRouter, Depends, HTTPException, status
-from app.schemas.base import TextOnlyRequest
-from app.services.synonyms import SynonymSuggester
-from app.core.security import verify_api_key
-from app.queue import enqueue_task
+# app/routers/synonyms.py
 import logging
+from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation
+
+from app.schemas.base import TextOnlyRequest
+from app.services.synonyms import SynonymSuggester  # Import the service class
+from app.core.security import verify_api_key  # Assuming API key verification is still used
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # For re-raising internal errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.synonyms")
 
 router = APIRouter(prefix="/synonyms", tags=["Synonyms"])
-logger = logging.getLogger(__name__)
 
-@router.post("/", dependencies=[Depends(verify_api_key)])
-async def suggest_synonyms(payload: TextOnlyRequest):
+# Initialize service instance once per application lifecycle
+synonym_suggester_service = SynonymSuggester()
+
+
+@router.post("/suggest", dependencies=[Depends(verify_api_key)])  # Added /suggest path for clarity
+async def suggest_synonyms_endpoint(payload: TextOnlyRequest):
+    """
+    Suggests synonyms for words in the provided text.
+    """
     text = payload.text.strip()
     if not text:
-        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received synonym suggestion request for text (first 50 chars): '{text[:50]}...'")
 
-    result = await enqueue_task("synonyms", {"text": text})
-    if "error" in result:
-        raise HTTPException(status_code=500, detail=result["error"])
+    try:
+        # Directly call the async service method.
+        # ModelNotDownloadedError will be raised here if the model/data is missing,
+        # and caught by the global exception handler in app/main.py.
+        result = await synonym_suggester_service.suggest(text)
+
+        logger.info(f"Synonym suggestion successful for text (first 50 chars): '{text[:50]}...'")
+        return {"synonyms": result}  # Consistent key for response
 
-    return {"synonyms": result["result"]}
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in synonym suggestion endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during synonym suggestion.") from e

app/routers/tone.py CHANGED
@@ -1,24 +1,45 @@
-import asyncio
-import uuid
-from fastapi import APIRouter, Depends
+# app/routers/tone.py
+import logging
+from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation
+
 from app.schemas.base import TextOnlyRequest
-from app.services.tone_classification import ToneClassifier
-from app.core.security import verify_api_key
-from app.queue import task_queue
+from app.services.tone_classification import ToneClassifier  # Import the service class
+from app.core.security import verify_api_key  # Assuming API key verification is still used
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # For re-raising internal errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.tone")
 
 router = APIRouter(prefix="/tone", tags=["Tone"])
-classifier = ToneClassifier()
-
-@router.post("/", dependencies=[Depends(verify_api_key)])
-async def classify_tone(payload: TextOnlyRequest):
-    future = asyncio.get_event_loop().create_future()
-    task_id = str(uuid.uuid4())[:8]
-    await task_queue.put({
-        "type": "tone",
-        "payload": {"text": payload.text},
-        "future": future,
-        "id": task_id
-    })
-
-    result = await future
-    return {"result": result}
+
+# Initialize service instance once per application lifecycle
+tone_classifier_service = ToneClassifier()
+
+
+@router.post("/classify", dependencies=[Depends(verify_api_key)])  # Added /classify path for clarity
+async def classify_tone_endpoint(payload: TextOnlyRequest):
+    """
+    Classifies the tone of the provided text.
+    """
+    text = payload.text.strip()
+    if not text:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received tone classification request for text (first 50 chars): '{text[:50]}...'")
+
+    try:
+        # Directly call the async service method.
+        # ModelNotDownloadedError will be raised here if the model is missing,
+        # and caught by the global exception handler in app/main.py.
+        result = await tone_classifier_service.classify(text)
+
+        logger.info(f"Tone classification successful for text (first 50 chars): '{text[:50]}...'")
+        return {"tone_classification": result}  # Consistent key for response
+
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in tone classification endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during tone classification.") from e

app/routers/translate.py CHANGED
@@ -1,24 +1,48 @@
-import asyncio
-import uuid
-from fastapi import APIRouter, Depends
-from app.schemas.base import TranslateRequest
-from app.services.translation import Translator
-from app.core.security import verify_api_key
-from app.queue import task_queue
+import logging
+from fastapi import APIRouter, Depends, HTTPException, status
+
+from app.schemas.base import TranslateRequest
+from app.services.translation import Translator
+from app.core.security import verify_api_key
+from app.core.config import APP_NAME
+from app.core.exceptions import ServiceError
+
+logger = logging.getLogger(f"{APP_NAME}.routers.translate")
 
 router = APIRouter(prefix="/translate", tags=["Translate"])
-translator = Translator()
+
+translator_service = Translator()
+
 
 @router.post("/", dependencies=[Depends(verify_api_key)])
-async def translate_text(payload: TranslateRequest):
-    future = asyncio.get_event_loop().create_future()
-    task_id = str(uuid.uuid4())[:8]
-    await task_queue.put({
-        "type": "translate",
-        "payload": {"text": payload.text, "target_lang": payload.target_lang},
-        "future": future,
-        "id": task_id
-    })
-
-    result = await future
-    return {"result": result}
+async def translate_text_endpoint(payload: TranslateRequest):
+    """
+    Translates the provided text to a specified target language.
+    """
+    text = payload.text.strip()
+    target_lang = payload.target_lang.strip()
+
+    if not text:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+    if not target_lang:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Target language cannot be empty.")
+
+    logger.info(f"Received translation request for text (first 50 chars): '{text[:50]}...' to '{target_lang}'")
+
+    try:
+        # Directly call the async service method.
+        # ModelNotDownloadedError will be raised here if the model is missing,
+        # and caught by the global exception handler in app/main.py.
+        result = await translator_service.translate(text, target_lang)
+
+        logger.info(f"Translation successful for text (first 50 chars): '{text[:50]}...' to '{target_lang}'")
+        return {"translation": result}  # Consistent key for response
+
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in translation endpoint for text: '{text[:50]}...' to '{target_lang}'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during translation.") from e

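Note: unlike the other routers, translate keeps its bare POST /translate/ path; only the handler body changed. The request shape gains target_lang from TranslateRequest (example values, host, and header name are illustrative assumptions):

    import requests

    resp = requests.post(
        "http://localhost:8000/translate/",         # assumed local dev server
        headers={"x-api-key": "YOUR_SERVICE_KEY"},  # header name depends on verify_api_key
        json={"text": "Good morning, team.", "target_lang": "fr"},
    )
    print(resp.json())  # {"translation": ...}
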
app/routers/voice.py CHANGED
@@ -1,11 +1,45 @@
-from fastapi import APIRouter, Depends
+# app/routers/voice.py
+import logging
+from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation
+
 from app.schemas.base import TextOnlyRequest
-from app.services.voice_detection import VoiceDetector
-from app.core.security import verify_api_key
+from app.services.voice_detection import VoiceDetector  # Import the service class
+from app.core.security import verify_api_key  # Assuming API key verification is still used
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # For re-raising internal errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.voice")
 
 router = APIRouter(prefix="/voice", tags=["Voice"])
-detector = VoiceDetector()
 
-@router.post("/", dependencies=[Depends(verify_api_key)])
-def detect_voice(payload: TextOnlyRequest):
-    return {"result": detector.classify(payload.text)}
+# Initialize service instance once per application lifecycle
+voice_detector_service = VoiceDetector()
+
+
+@router.post("/detect", dependencies=[Depends(verify_api_key)])  # Added /detect path for clarity
+async def detect_voice_endpoint(payload: TextOnlyRequest):
+    """
+    Detects the voice (active or passive) of the provided text.
+    """
+    text = payload.text.strip()
+    if not text:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received voice detection request for text (first 50 chars): '{text[:50]}...'")
+
+    try:
+        # Directly call the async service method.
+        # ModelNotDownloadedError will be raised here if the model is missing,
+        # and caught by the global exception handler in app/main.py.
+        result = await voice_detector_service.classify(text)
+
+        logger.info(f"Voice detection successful for text (first 50 chars): '{text[:50]}...'")
+        return {"voice_detection": result}  # Consistent key for response
+
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in voice detection endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during voice detection.") from e

app/services/base.py CHANGED
@@ -1,49 +1,132 @@
-import torch
-from threading import Lock
 import logging
+from pathlib import Path
+from functools import lru_cache
+
+import torch
+from transformers import (
+    pipeline,
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+    AutoModelForSeq2SeqLM,
+    AutoModelForMaskedLM,
+)
+
+from sentence_transformers import SentenceTransformer
+
+from app.core.config import (
+    MODELS_DIR, SPACY_MODEL_ID, SENTENCE_TRANSFORMER_MODEL_ID,
+    OFFLINE_MODE
+)
+from app.core.exceptions import ModelNotDownloadedError
 
 logger = logging.getLogger(__name__)
 
-DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-logger.info(f"Using device: {DEVICE}")
+# ─────────────────────────────────────────────────────────────────────────────
+# 🧠 SpaCy
+# ─────────────────────────────────────────────────────────────────────────────
 
-_models = {}
-_models_lock = Lock()
+@lru_cache(maxsize=1)
+def load_spacy_model(model_id: str = SPACY_MODEL_ID):
+    import spacy
+    from spacy.util import is_package
 
-def get_cached_model(model_name: str, load_fn):
-    with _models_lock:
-        if model_name not in _models:
-            logger.info(f"Loading model: {model_name}")
-            _models[model_name] = load_fn()
-        return _models[model_name]
+    logger.info(f"Loading spaCy model: {model_id}")
 
-def timed_model_load(label: str, load_fn):
+    if is_package(model_id):
+        return spacy.load(model_id)
+
+    possible_path = MODELS_DIR / model_id
+    if possible_path.exists():
+        return spacy.load(str(possible_path))
+
+    raise RuntimeError(f"Could not find spaCy model '{model_id}' at {possible_path}")
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 🔤 Sentence Transformers
+# ─────────────────────────────────────────────────────────────────────────────
+
+@lru_cache(maxsize=1)
+def load_sentence_transformer_model(model_id: str = SENTENCE_TRANSFORMER_MODEL_ID) -> SentenceTransformer:
+    logger.info(f"Loading SentenceTransformer: {model_id}")
+    return SentenceTransformer(model_name_or_path=model_id, cache_folder=MODELS_DIR)
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 🤗 Hugging Face Pipelines (T5 models, classifiers, etc.)
+# ─────────────────────────────────────────────────────────────────────────────
+
+def _check_model_downloaded(model_id: str, cache_dir: str) -> bool:
+    model_path = Path(cache_dir) / model_id.replace("/", "_")
+    return model_path.exists()
+
+def _timed_load(name: str, fn):
     import time
     start = time.time()
-    model = load_fn()
-    logger.info(f"{label} loaded in {time.time() - start:.2f}s")
+    model = fn()
+    elapsed = round(time.time() - start, 2)
+    logger.info(f"[{name}] model loaded in {elapsed}s")
     return model
 
-_nlp = None
-
-def get_spacy():
-    global _nlp
-    if _nlp is None:
-        import spacy
-        _nlp = spacy.load("en_core_web_sm")
-    return _nlp
-
-# Shared error and response
-class ServiceError(Exception):
-    def __init__(self, message: str):
-        super().__init__(message)
-        self.message = message
-
-def model_response(result: str = "", error: str = None) -> dict:
-    return {"result": result, "error": error}
+@lru_cache(maxsize=2)
+def load_hf_pipeline(model_id: str, task: str, feature_name: str, **kwargs):
+    if OFFLINE_MODE and not _check_model_downloaded(model_id, str(MODELS_DIR)):
+        raise ModelNotDownloadedError(model_id, feature_name, "Model not found locally in offline mode.")
+
+    try:
+        # Choose appropriate AutoModel loader based on task
+        if task == "text-classification":
+            model_loader = AutoModelForSequenceClassification
+        elif task == "text2text-generation" or task.startswith("translation"):
+            model_loader = AutoModelForSeq2SeqLM
+        elif task == "fill-mask":
+            model_loader = AutoModelForMaskedLM
+        else:
+            raise ValueError(f"Unsupported task type '{task}' for feature '{feature_name}'.")
+
+        model = _timed_load(
+            f"{feature_name}:{model_id} (model)",
+            lambda: model_loader.from_pretrained(
+                model_id,
+                cache_dir=MODELS_DIR,
+                local_files_only=OFFLINE_MODE
+            )
+        )
+
+        tokenizer = _timed_load(
+            f"{feature_name}:{model_id} (tokenizer)",
+            lambda: AutoTokenizer.from_pretrained(
+                model_id,
+                cache_dir=MODELS_DIR,
+                local_files_only=OFFLINE_MODE
+            )
+        )
+
+        return pipeline(
+            task=task,
+            model=model,
+            tokenizer=tokenizer,
+            device=0 if torch.cuda.is_available() else -1,
+            **kwargs
+        )
+
+    except Exception as e:
+        logger.error(f"Failed to load pipeline for '{feature_name}' - {model_id}: {e}", exc_info=True)
+        raise ModelNotDownloadedError(model_id, feature_name, str(e))
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 📚 NLTK
+# ─────────────────────────────────────────────────────────────────────────────
+
+@lru_cache(maxsize=1)
+def ensure_nltk_resource(resource_name: str = "wordnet") -> None:
+    import nltk
+    try:
+        nltk.data.find(f"corpora/{resource_name}")
+    except LookupError:
+        if OFFLINE_MODE:
+            raise RuntimeError(f"NLTK resource '{resource_name}' not found in offline mode.")
+        nltk.download(resource_name)
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 🎯 Ready-to-use Loaders (for your app use)
+# ─────────────────────────────────────────────────────────────────────────────

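Note on the cache sizing above: @lru_cache(maxsize=2) on load_hf_pipeline keeps at most two distinct (model_id, task, feature_name) pipelines; with grammar, paraphrasing, and translation all routed through it, a third distinct model evicts the least recently used one. The services below keep their own self._pipeline reference, so an already-initialized service keeps working, but a fresh lookup reloads from disk. A sketch of the intended call pattern; the model ID here is illustrative, not a value from this commit's config:

    from app.services.base import load_hf_pipeline

    pipe = load_hf_pipeline(
        model_id="vennify/t5-base-grammar-correction",  # illustrative model ID
        task="text2text-generation",
        feature_name="Grammar Correction",
    )
    # transformers text2text-generation pipelines return [{"generated_text": ...}];
    # the "grammar:" prompt prefix is model-specific.
    print(pipe("grammar: She no went to the market.", max_length=64)[0]["generated_text"])
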
app/services/gpt4_rewrite.py CHANGED
@@ -1,46 +1,76 @@
 import openai
 import logging
-from tenacity import retry, stop_after_attempt, wait_exponential
-from app.core.config import settings
-from app.services.base import model_response, ServiceError
+import asyncio
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
+from app.core.config import settings, APP_NAME
+from app.core.exceptions import ServiceError
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(f"{APP_NAME}.services.gpt4_rewrite")
 
 class GPT4Rewriter:
-    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
-    def rewrite(self, text: str, user_api_key: str, instruction: str) -> dict:
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=2, max=10),
+        retry=retry_if_exception_type(openai.APIError)
+    )
+    async def rewrite(self, text: str, user_api_key: str, instruction: str) -> dict:
         try:
             if not user_api_key:
-                raise ServiceError("Missing OpenAI API key.")
+                raise ServiceError(status_code=401, detail="OpenAI API key is missing. Please provide your key to use this feature.")
 
             text = text.strip()
             instruction = instruction.strip()
 
             if not text:
-                raise ServiceError("Input text is empty.")
+                raise ServiceError(status_code=400, detail="Input text is empty for rewriting.")
             if not instruction:
-                raise ServiceError("Missing rewrite instruction.")
+                raise ServiceError(status_code=400, detail="Rewrite instruction is missing.")
 
             messages = [
                 {"role": "system", "content": instruction},
                 {"role": "user", "content": text},
             ]
 
-            client = openai.OpenAI(api_key=user_api_key)
-            response = client.chat.completions.create(
-                model=settings.OPENAI_MODEL,
-                messages=messages,
-                temperature=settings.OPENAI_TEMPERATURE,
-                max_tokens=settings.OPENAI_MAX_TOKENS
-            )
-            result = response.choices[0].message.content.strip()
-            return model_response(result=result)
-
-        except ServiceError as se:
-            return model_response(error=str(se))
+            def _call_openai_api():
+                client = openai.OpenAI(api_key=user_api_key)
+                response = client.chat.completions.create(
+                    model=settings.OPENAI_MODEL,
+                    messages=messages,
+                    temperature=settings.OPENAI_TEMPERATURE,
+                    max_tokens=settings.OPENAI_MAX_TOKENS
+                )
+                return response.choices[0].message.content.strip()
+
+            result = await asyncio.to_thread(_call_openai_api)
+            return {"rewritten_text": result}
+
+        except openai.APIStatusError as e:
+            logger.error(f"OpenAI API status error: {e.status_code} - {e.response}", exc_info=True)
+            detail_message = "An OpenAI API error occurred."
+            if e.status_code == 401:
+                detail_message = "Invalid OpenAI API key. Please check your key."
+            elif e.status_code == 429:
+                detail_message = "OpenAI API rate limit exceeded or quota exhausted. Please try again later."
+            elif e.status_code == 400:
+                detail_message = f"OpenAI API request error: {e.response.json().get('detail', e.message)}"
+
+            raise ServiceError(status_code=e.status_code, detail=detail_message) from e
+
+        except openai.APITimeoutError as e:
+            logger.error(f"OpenAI API timeout error: {e}", exc_info=True)
+            raise ServiceError(status_code=504, detail="OpenAI API request timed out. Please try again.") from e
+
+        except openai.APIConnectionError as e:
+            logger.error(f"OpenAI API connection error: {e}", exc_info=True)
+            raise ServiceError(status_code=503, detail="Could not connect to OpenAI API. Please check your internet connection.") from e
+
         except openai.APIError as e:
-            logger.error(f"OpenAI API error: {e}")
-            return model_response(error=f"OpenAI API error: {e}")
+            logger.error(f"OpenAI API error: {e}", exc_info=True)
+            raise ServiceError(status_code=500, detail=f"An unexpected OpenAI API error occurred: {str(e)}") from e
+
+        except ServiceError as e:
+            raise e
+
         except Exception as e:
-            logger.error(f"Unexpected error in GPT-4 rewrite: {e}")
-            return model_response(error="Unexpected error during rewrite.")
+            logger.error(f"Unexpected error in GPT-4 rewrite for text: '{text[:50]}...'", exc_info=True)
+            raise ServiceError(status_code=500, detail="An unexpected error occurred during rewriting.") from e

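Note: tenacity's @retry does support coroutine functions, and asyncio.to_thread correctly keeps the blocking OpenAI client off the event loop. However, because every openai.APIError is translated into ServiceError inside the method body, retry_if_exception_type(openai.APIError) can never observe an APIError, so as written the retries never fire. A minimal sketch of a layout that would actually retry; this is an illustrative restructuring, not code from this commit:

    import asyncio
    import openai
    from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type(openai.APIError),
    )
    async def _chat_once(client: openai.OpenAI, messages: list, model: str) -> str:
        # Only the raw call lives under @retry, so APIError can reach the decorator;
        # translate to ServiceError in the caller, after retries are exhausted.
        response = await asyncio.to_thread(
            client.chat.completions.create, model=model, messages=messages
        )
        return response.choices[0].message.content.strip()
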
app/services/grammar.py CHANGED
@@ -1,78 +1,77 @@
 import difflib
 import logging
+from typing import List
+
 import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
-from app.services.base import (
-    get_cached_model, DEVICE, timed_model_load,
-    ServiceError, model_response
-)
+from app.services.base import load_hf_pipeline
 from app.core.config import settings
+from app.core.exceptions import ServiceError
+
+logger = logging.getLogger(f"{settings.APP_NAME}.services.grammar")
 
-logger = logging.getLogger(__name__)
 
 class GrammarCorrector:
     def __init__(self):
-        self.tokenizer, self.model = self._load_model()
+        self._pipeline = None
 
-    def _load_model(self):
-        def load_fn():
-            tokenizer = timed_model_load(
-                "grammar_tokenizer",
-                lambda: AutoTokenizer.from_pretrained(settings.GRAMMAR_MODEL)
-            )
-            model = timed_model_load(
-                "grammar_model",
-                lambda: AutoModelForSeq2SeqLM.from_pretrained(settings.GRAMMAR_MODEL)
+    def _get_pipeline(self):
+        if self._pipeline is None:
+            logger.info("Loading grammar correction pipeline...")
+            self._pipeline = load_hf_pipeline(
+                model_id=settings.GRAMMAR_MODEL_ID,
+                task="text2text-generation",
+                feature_name="Grammar Correction"
             )
-            model = model.to(DEVICE).eval()
-            return tokenizer, model
+        return self._pipeline
 
-        return get_cached_model("grammar", load_fn)
+    async def correct(self, text: str) -> dict:
+        text = text.strip()
+        if not text:
+            raise ServiceError(status_code=400, detail="Input text is empty for grammar correction.")
 
-    def correct(self, text: str) -> dict:
         try:
-            text = text.strip()
-            if not text:
-                raise ServiceError("Input text is empty.")
+            pipeline = self._get_pipeline()
 
-            with torch.no_grad():
-                inputs = self.tokenizer([text], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
-                outputs = self.model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True)
-                corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            result = pipeline(text, max_length=512, num_beams=4, early_stopping=True)
+            corrected = result[0]["generated_text"].strip()
+
+            if not corrected:
+                raise ServiceError(status_code=500, detail="Failed to decode grammar correction output.")
 
             issues = self.get_diff_issues(text, corrected)
 
-            return model_response(result={
+            return {
                 "original_text": text,
                 "corrected_text_suggestion": corrected,
                 "issues": issues
-            })
+            }
 
-        except ServiceError as se:
-            return model_response(error=str(se))
+        except ServiceError:
+            raise
         except Exception as e:
-            logger.error(f"Grammar correction error: {e}", exc_info=True)
-            return model_response(error="An error occurred during grammar correction.")
+            logger.error(f"Grammar correction error for input: '{text[:50]}...'", exc_info=True)
+            raise ServiceError(status_code=500, detail="An internal error occurred during grammar correction.") from e
 
-    def get_diff_issues(self, original: str, corrected: str):
+    def get_diff_issues(self, original: str, corrected: str) -> List[dict]:
+        def safe_slice(s: str, start: int, end: int) -> str:
+            return s[max(0, start):min(len(s), end)]
+
         matcher = difflib.SequenceMatcher(None, original, corrected)
         issues = []
 
         for tag, i1, i2, j1, j2 in matcher.get_opcodes():
-            if tag == 'equal':
+            if tag == "equal":
                 continue
 
             issues.append({
                 "offset": i1,
                 "length": i2 - i1,
-                "original": original[i1:i2],
-                "suggestion": corrected[j1:j2],
-                "context_before": original[max(0, i1 - 15):i1],
-                "context_after": original[i2:i2 + 15],
+                "original_segment": original[i1:i2],
+                "suggested_segment": corrected[j1:j2],
+                "context_before": safe_slice(original, i1 - 15, i1),
+                "context_after": safe_slice(original, i2, i2 + 15),
                 "message": "Grammar correction",
                 "line": original[:i1].count("\n") + 1,
                 "column": (i1 - original[:i1].rfind("\n") - 1) if "\n" in original[:i1] else i1 + 1
             })
 
         return issues

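Note: get_diff_issues is plain difflib over characters, so it is easy to sanity-check in isolation (the sample strings are illustrative):

    import difflib

    original = "She go to school."
    corrected = "She goes to school."
    for tag, i1, i2, j1, j2 in difflib.SequenceMatcher(None, original, corrected).get_opcodes():
        if tag != "equal":
            print(tag, repr(original[i1:i2]), "->", repr(corrected[j1:j2]))
    # typically a single 'insert' opcode: '' -> 'es' just after "go"
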
app/services/inclusive_language.py CHANGED
@@ -1,72 +1,120 @@
+import logging
 import yaml
 from pathlib import Path
 from typing import List, Dict
-from app.services.base import get_spacy, model_response, ServiceError
-from app.core.config import settings
-import logging
 
-logger = logging.getLogger(__name__)
+from app.services.base import load_spacy_model
+from app.core.config import settings, APP_NAME, SPACY_MODEL_ID
+from app.core.exceptions import ServiceError
+
+logger = logging.getLogger(f"{APP_NAME}.services.inclusive_language")
+
 
 class InclusiveLanguageChecker:
-    def __init__(self, rules_directory=settings.INCLUSIVE_RULES_DIR):
-        self.rules = self._load_inclusive_rules(rules_directory)
-        self.matcher = self._init_matcher()
+    def __init__(self, rules_directory: str = settings.INCLUSIVE_RULES_DIR):
+        self._nlp = None
+        self.matcher = None
+        self.rules = self._load_inclusive_rules(Path(rules_directory))
 
-    def _load_inclusive_rules(self, directory: str) -> Dict[str, Dict]:
+    def _load_inclusive_rules(self, rules_path: Path) -> Dict[str, Dict]:
+        """
+        Load YAML-based inclusive language rules from the given directory.
+        """
+        if not rules_path.is_dir():
+            logger.error(f"Inclusive language rules directory not found: {rules_path}")
+            raise ServiceError(
+                status_code=500,
+                detail=f"Inclusive language rules directory not found: {rules_path}"
+            )
+
         rules = {}
-        for path in Path(directory).glob("*.yml"):
+        for yaml_file in rules_path.glob("*.yml"):
             try:
-                with open(path, encoding="utf-8") as f:
+                with yaml_file.open(encoding="utf-8") as f:
                     rule_list = yaml.safe_load(f)
-                if not isinstance(rule_list, list):
-                    logger.warning(f"Skipping malformed rule file: {path}")
-                    continue
-                for rule in rule_list:
-                    note = rule.get("note", "")
-                    source = rule.get("source", "")
-                    considerate = rule.get("considerate", [])
-                    inconsiderate = rule.get("inconsiderate", [])
-
-                    if isinstance(considerate, str):
-                        considerate = [considerate]
-                    if isinstance(inconsiderate, str):
-                        inconsiderate = [inconsiderate]
-
-                    for phrase in inconsiderate:
-                        rules[phrase.lower()] = {
-                            "note": note,
-                            "considerate": considerate,
-                            "source": source,
-                            "type": rule.get("type", "basic")
-                        }
+
+                if not isinstance(rule_list, list):
+                    logger.warning(f"Skipping non-list rule file: {yaml_file}")
+                    continue
+
+                for rule in rule_list:
+                    inconsiderate = rule.get("inconsiderate", [])
+                    considerate = rule.get("considerate", [])
+                    note = rule.get("note", "")
+                    source = rule.get("source", "")
+                    rule_type = rule.get("type", "basic")
+
+                    # Ensure consistent formatting
+                    if isinstance(considerate, str):
+                        considerate = [considerate]
+                    if isinstance(inconsiderate, str):
+                        inconsiderate = [inconsiderate]
+
+                    for phrase in inconsiderate:
+                        rules[phrase.lower()] = {
+                            "considerate": considerate,
+                            "note": note,
+                            "source": source,
+                            "type": rule_type
+                        }
+
             except Exception as e:
-                logger.error(f"Error loading inclusive language rule from {path}: {e}")
+                logger.error(f"Error loading rule file {yaml_file}: {e}", exc_info=True)
+                raise ServiceError(
+                    status_code=500,
+                    detail=f"Failed to load inclusive language rules: {e}"
+                )
+
+        logger.info(f"Loaded {len(rules)} inclusive language rules from {rules_path}")
         return rules
 
-    def _init_matcher(self):
+    def _get_nlp(self):
+        """
+        Lazy-loads the spaCy model for NLP processing.
+        """
+        if self._nlp is None:
+            self._nlp = load_spacy_model(SPACY_MODEL_ID)
+        return self._nlp
+
+    def _init_matcher(self, nlp):
+        """
+        Initializes spaCy PhraseMatcher using loaded rules.
+        """
         from spacy.matcher import PhraseMatcher
-        matcher = PhraseMatcher(get_spacy().vocab, attr="LOWER")
+
+        matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
         for phrase in self.rules:
-            matcher.add(phrase, [get_spacy().make_doc(phrase)])
-        logger.info(f"Loaded {len(self.rules)} inclusive language rules.")
+            matcher.add(phrase, [nlp.make_doc(phrase)])
+
+        logger.info(f"PhraseMatcher initialized with {len(self.rules)} phrases.")
         return matcher
 
-    def check(self, text: str) -> dict:
+    async def check(self, text: str) -> dict:
+        """
+        Checks a string for non-inclusive language based on rule definitions.
+        """
+        text = text.strip()
+        if not text:
+            raise ServiceError(status_code=400, detail="Input text is empty for inclusive language check.")
+
         try:
-            text = text.strip()
-            if not text:
-                raise ServiceError("Input text is empty.")
-
-            nlp = get_spacy()
+            nlp = self._get_nlp()
+            if self.matcher is None:
+                self.matcher = self._init_matcher(nlp)
+
             doc = nlp(text)
             matches = self.matcher(doc)
             results = []
             matched_spans = set()
 
+            # Match exact phrases
            for match_id, start, end in matches:
-                phrase_str = nlp.vocab.strings[match_id]
+                phrase = nlp.vocab.strings[match_id].lower()
+                if any(s <= start < e or s < end <= e for s, e in matched_spans):
+                    continue  # Avoid overlapping matches
+
                matched_spans.add((start, end))
-                rule = self.rules.get(phrase_str.lower())
+                rule = self.rules.get(phrase)
                if rule:
                    results.append({
                        "term": doc[start:end].text,
@@ -74,15 +122,18 @@ class InclusiveLanguageChecker:
                        "note": rule["note"],
                        "suggestions": rule["considerate"],
                        "context": doc[start:end].sent.text,
-                        "start": doc[start].idx,
-                        "end": doc[end - 1].idx + len(doc[end - 1]),
-                        "source": rule.get("source", "")
+                        "start_char": doc[start].idx,
+                        "end_char": doc[end - 1].idx + len(doc[end - 1]),
+                        "source": rule["source"]
                    })
 
+            # Match individual token lemmas (fallback)
            for token in doc:
                lemma = token.lemma_.lower()
-                span_key = (token.i, token.i + 1)
-                if lemma in self.rules and span_key not in matched_spans:
+                if (token.i, token.i + 1) in matched_spans:
+                    continue  # Already matched in phrase
+
+                if lemma in self.rules:
                    rule = self.rules[lemma]
                    results.append({
                        "term": token.text,
@@ -90,15 +141,16 @@ class InclusiveLanguageChecker:
                        "note": rule["note"],
                        "suggestions": rule["considerate"],
                        "context": token.sent.text,
-                        "start": token.idx,
-                        "end": token.idx + len(token),
-                        "source": rule.get("source", "")
+                        "start_char": token.idx,
+                        "end_char": token.idx + len(token),
+                        "source": rule["source"]
                    })
 
-            return model_response(result=results)
+            return {"issues": results}
 
-        except ServiceError as se:
-            return model_response(error=str(se))
        except Exception as e:
-            logger.error(f"Inclusive language check error: {e}")
-            return model_response(error="An error occurred during inclusive language checking.")
+            logger.error(f"Inclusive language check error for text: '{text[:50]}...'", exc_info=True)
+            raise ServiceError(
+                status_code=500,
+                detail="An internal error occurred during inclusive language checking."
+            ) from e

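Note: the loader expects each *.yml file to be a top-level list of rules, with inconsiderate/considerate accepted as either a string or a list. A minimal rule that would parse, expressed as an inline YAML string so it can be checked directly; the rule content itself is illustrative:

    import textwrap
    import yaml

    sample = textwrap.dedent("""\
        - type: basic
          note: Avoid gendered job titles.
          source: illustrative
          inconsiderate: chairman
          considerate:
            - chairperson
            - chair
    """)
    rule_list = yaml.safe_load(sample)
    assert isinstance(rule_list, list)  # non-list files are skipped with a warning
    print(rule_list[0]["inconsiderate"], "->", rule_list[0]["considerate"])
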
app/services/paraphrase.py CHANGED
@@ -1,44 +1,40 @@
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-from app.services.base import get_cached_model, DEVICE, timed_model_load, model_response, ServiceError
-from app.core.config import settings
-import torch
 import logging
 
-logger = logging.getLogger(__name__)
+from app.services.base import load_hf_pipeline
+from app.core.config import settings, APP_NAME
+from app.core.exceptions import ServiceError
+
+logger = logging.getLogger(f"{APP_NAME}.services.paraphrase")
+
 
 class Paraphraser:
     def __init__(self):
-        self.tokenizer, self.model = self._load_model()
+        self._pipeline = None
 
-    def _load_model(self):
-        def load_fn():
-            tokenizer = timed_model_load("paraphrase_tokenizer", lambda: AutoTokenizer.from_pretrained(settings.PARAPHRASE_MODEL))
-            model = timed_model_load("paraphrase_model", lambda: AutoModelForSeq2SeqLM.from_pretrained(settings.PARAPHRASE_MODEL))
-            model = model.to(DEVICE).eval()
-            return tokenizer, model
-        return get_cached_model("paraphrase", load_fn)
+    def _get_pipeline(self):
+        if self._pipeline is None:
+            logger.info("Loading paraphrasing pipeline...")
+            self._pipeline = load_hf_pipeline(
+                model_id=settings.PARAPHRASE_MODEL_ID,
+                task="text2text-generation",
+                feature_name="Paraphrasing"
+            )
+        return self._pipeline
 
-    def paraphrase(self, text: str) -> dict:
-        try:
-            text = text.strip()
-            if not text:
-                raise ServiceError("Input text is empty.")
+    async def paraphrase(self, text: str) -> dict:
+        text = text.strip()
+        if not text:
+            raise ServiceError(status_code=400, detail="Input text is empty for paraphrasing.")
 
+        try:
+            pipeline = self._get_pipeline()
             prompt = f"paraphrase: {text} </s>"
-            with torch.no_grad():
-                inputs = self.tokenizer([prompt], return_tensors="pt", padding=True, truncation=True).to(DEVICE)
-                outputs = self.model.generate(
-                    **inputs,
-                    max_length=256,
-                    num_beams=5,
-                    num_return_sequences=1,
-                    early_stopping=True
-                )
-            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
-            return model_response(result=result)
-
-        except ServiceError as se:
-            return model_response(error=str(se))
+
+            results = pipeline(prompt, max_length=256, num_beams=5, num_return_sequences=1, early_stopping=True)
+            paraphrased = results[0]["generated_text"].strip()
+
+            return {"paraphrased_text": paraphrased}
+
         except Exception as e:
-            logger.error(f"Paraphrasing error: {e}")
-            return model_response(error="An error occurred during paraphrasing.")
+            logger.error(f"Paraphrasing error for text: '{text[:50]}...'", exc_info=True)
+            raise ServiceError(status_code=500, detail="An internal error occurred during paraphrasing.") from e

app/services/readability.py CHANGED
@@ -1,17 +1,18 @@
+# app/services/readability.py
 import textstat
 import logging
-from app.services.base import model_response, ServiceError
+from app.core.config import APP_NAME
+from app.core.exceptions import ServiceError
 
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(f"{APP_NAME}.services.readability")
 
 class ReadabilityScorer:
-    def compute(self, text: str) -> dict:
+    async def compute(self, text: str) -> dict:
         try:
             text = text.strip()
             if not text:
-                raise ServiceError("Input text is empty.")
+                raise ServiceError(status_code=400, detail="Input text is empty for readability scoring.")
 
-            # Compute scores
             scores = {
                 "flesch_reading_ease": textstat.flesch_reading_ease(text),
                 "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
@@ -21,41 +22,39 @@ class ReadabilityScorer:
                 "automated_readability_index": textstat.automated_readability_index(text),
             }
 
-            # Friendly descriptions
             friendly_scores = {
                 "flesch_reading_ease": {
-                    "score": scores["flesch_reading_ease"],
+                    "score": round(scores["flesch_reading_ease"], 2),
                     "label": "Flesch Reading Ease",
                     "description": "Higher is easier. 60–70 is plain English; 90+ is very easy."
                 },
                 "flesch_kincaid_grade": {
-                    "score": scores["flesch_kincaid_grade"],
+                    "score": round(scores["flesch_kincaid_grade"], 2),
                     "label": "Flesch-Kincaid Grade Level",
                     "description": "U.S. school grade. 8.0 means an 8th grader can understand it."
                 },
                 "gunning_fog_index": {
-                    "score": scores["gunning_fog_index"],
+                    "score": round(scores["gunning_fog_index"], 2),
                     "label": "Gunning Fog Index",
                     "description": "Estimates years of formal education needed to understand."
                 },
                 "smog_index": {
-                    "score": scores["smog_index"],
+                    "score": round(scores["smog_index"], 2),
                     "label": "SMOG Index",
                     "description": "Also estimates required years of education."
                 },
                 "coleman_liau_index": {
-                    "score": scores["coleman_liau_index"],
+                    "score": round(scores["coleman_liau_index"], 2),
                     "label": "Coleman-Liau Index",
                     "description": "Grade level based on characters, not syllables."
                 },
                 "automated_readability_index": {
-                    "score": scores["automated_readability_index"],
+                    "score": round(scores["automated_readability_index"], 2),
                     "label": "Automated Readability Index",
                     "description": "Grade level using word and sentence lengths."
                 }
             }
 
-            # Flesch score guide
             ease_score = scores["flesch_reading_ease"]
             if ease_score >= 90:
                 summary = "Very easy to read. Easily understood by 11-year-olds."
@@ -68,13 +67,13 @@ class ReadabilityScorer:
             else:
                 summary = "Very difficult. Best understood by university graduates."
 
-            return model_response(result={
+            return {
                 "readability_summary": summary,
                 "scores": friendly_scores
-            })
+            }
 
-        except ServiceError as se:
-            return model_response(error=str(se))
+        except ServiceError:
+            raise
         except Exception as e:
-            logger.error(f"Readability scoring error: {e}", exc_info=True)
-            return model_response(error="An error occurred during readability scoring.")
+            logger.error(f"Readability scoring error for text: '{text[:50]}...'", exc_info=True)
+            raise ServiceError(status_code=500, detail="An internal error occurred during readability scoring.") from e
+
+# You can continue pasting the rest of your services here for production hardening
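
Note: rounding to two decimals also keeps the JSON payload stable across textstat versions, which can return long floats. A quick check; the sample sentence is illustrative and exact values vary slightly by textstat version:

    import textstat

    text = "The quick brown fox jumps over the lazy dog."
    print(round(textstat.flesch_reading_ease(text), 2))   # high score, i.e. very easy
    print(round(textstat.flesch_kincaid_grade(text), 2))  # low U.S. grade level
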
app/services/synonyms.py CHANGED
@@ -1,21 +1,27 @@
1
  import logging
2
  import asyncio
3
- from nltk.corpus import wordnet
4
- from transformers import AutoTokenizer
5
- from sentence_transformers import SentenceTransformer, util
6
  from typing import List, Dict
7
  from functools import lru_cache
8
 
9
- # Assuming these are available in your project structure
10
  from app.services.base import (
11
- model_response, ServiceError, timed_model_load,
12
- get_cached_model, DEVICE, get_spacy
 
 
 
 
 
 
 
 
13
  )
14
- from app.core.config import settings # Assuming settings might contain a BATCH_SIZE
 
 
 
15
 
16
- logger = logging.getLogger(__name__)
17
 
18
- # Mapping spaCy POS tags to WordNet POS tags
19
  SPACY_TO_WORDNET_POS = {
20
  "NOUN": wordnet.NOUN,
21
  "VERB": wordnet.VERB,
@@ -23,140 +29,129 @@ SPACY_TO_WORDNET_POS = {
23
  "ADV": wordnet.ADV,
24
  }
25
 
26
- # Only consider these POS tags for synonym suggestions
27
  CONTENT_POS_TAGS = {"NOUN", "VERB", "ADJ", "ADV"}
28
- SENTENCE_TRANSFORMER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
29
- DEFAULT_BATCH_SIZE = settings.SENTENCE_TRANSFORMER_BATCH_SIZE if hasattr(settings, 'SENTENCE_TRANSFORMER_BATCH_SIZE') else 32
30
 
31
  class SynonymSuggester:
32
  def __init__(self):
33
- self.sentence_model = self._load_sentence_transformer_model()
34
- self.nlp = self._load_spacy_model()
35
-
36
- def _load_sentence_transformer_model(self):
37
- def load_fn():
38
- # SentenceTransformer automatically handles device placement if CUDA is available
39
- # It can also be explicitly passed: device=DEVICE if DEVICE else None
40
- model = timed_model_load(
41
- "sentence_transformer",
42
- lambda: SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
43
  )
44
- return model
45
- return get_cached_model("synonym_sentence_model", load_fn)
46
 
47
- def _load_spacy_model(self):
48
- # Using asyncio.to_thread for initial model load if it's blocking
49
- # but timed_model_load likely handles that by caching.
50
- return timed_model_load("spacy_en_model", lambda: get_spacy())
 
 
51
 
52
  async def suggest(self, text: str) -> dict:
53
- text = text.strip()
54
- if not text:
55
- raise ServiceError("Input text is empty.")
56
-
57
- # Use asyncio.to_thread consistently for blocking operations
58
- doc = await asyncio.to_thread(self.nlp, text)
59
-
60
- all_suggestions: Dict[str, List[str]] = {}
61
-
62
- # Encode original text once
63
- original_text_embedding = await asyncio.to_thread(
64
- self.sentence_model.encode, text, convert_to_tensor=True, normalize_embeddings=True
65
- )
66
-
67
- # 1. Collect all potential candidates and their contexts for batching
68
- candidate_data = [] # List of (original_word, wordnet_candidate, temp_sentence, original_word_idx)
69
-
70
- for token in doc:
71
- if (
72
- token.pos_ in CONTENT_POS_TAGS and
73
- len(token.text.strip()) > 2 and
74
- not token.is_punct and
75
- not token.is_space
76
- ):
77
- original_word = token.text
78
- word_start = token.idx
79
- word_end = token.idx + len(original_word)
80
-
81
- # Filter WordNet synonyms by the token's Part-of-Speech
82
- wordnet_pos = SPACY_TO_WORDNET_POS.get(token.pos_)
83
- if wordnet_pos is None:
84
- continue # Skip if no direct WordNet POS mapping
85
-
86
- wordnet_synonyms_candidates = await asyncio.to_thread(
87
- self._get_wordnet_synonyms_cached, original_word, wordnet_pos
88
- )
89
-
90
- if not wordnet_synonyms_candidates:
91
- continue
92
-
93
- for candidate in wordnet_synonyms_candidates:
94
- temp_sentence = text[:word_start] + candidate + text[word_end:]
95
- candidate_data.append({
96
- "original_word": original_word,
97
- "wordnet_candidate": candidate,
98
- "temp_sentence": temp_sentence,
99
- "original_word_idx": len(all_suggestions.get(original_word, [])) # Used for tracking initial suggestions count
100
- })
101
-
102
- # Initialize list for this original word, if not already
103
- if original_word not in all_suggestions:
104
- all_suggestions[original_word] = []
105
-
106
- if not candidate_data:
107
- return model_response(result={})
108
-
109
- # 2. Encode all candidate sentences in a single batch
110
- all_candidate_sentences = [data["temp_sentence"] for data in candidate_data]
111
- all_candidate_embeddings = await asyncio.to_thread(
112
- self.sentence_model.encode,
113
- all_candidate_sentences,
114
- batch_size=DEFAULT_BATCH_SIZE, # Use a configurable batch size
115
- convert_to_tensor=True,
116
- normalize_embeddings=True
117
- )
118
-
119
- # 3. Calculate similarities and filter
120
- # Ensure original_text_embedding is 2D for cos_sim if it's a single embedding
121
-             # util.cos_sim expects (A, B) where A and B are matrices of embeddings
-             # Reshape original_text_embedding if it's a 1D tensor
-             if original_text_embedding.dim() == 1:
-                 original_text_embedding = original_text_embedding.unsqueeze(0)
-
-             cosine_scores = util.cos_sim(original_text_embedding, all_candidate_embeddings)[0]  # [0] because cos_sim returns a matrix
-
-             similarity_threshold = 0.65
-             top_n_suggestions = 5
-
-             # Reconstruct results by iterating through candidate_data and scores
-             for i, data in enumerate(candidate_data):
-                 original_word = data["original_word"]
-                 candidate = data["wordnet_candidate"]
-                 score = cosine_scores[i].item()  # Get scalar score from tensor
-
-                 # Apply filtering criteria
-                 if score >= similarity_threshold and candidate.lower() != original_word.lower():
-                     if len(all_suggestions[original_word]) < top_n_suggestions:
-                         all_suggestions[original_word].append(candidate)
-
-             # Remove any words for which no suggestions were found after filtering
-             final_suggestions = {
-                 word: suggestions for word, suggestions in all_suggestions.items() if suggestions
-             }
-
-             return model_response(result=final_suggestions)
 
      @lru_cache(maxsize=5000)
      def _get_wordnet_synonyms_cached(self, word: str, pos: str) -> List[str]:
-         """
-         Retrieves synonyms for a word from WordNet, filtered by Part-of-Speech.
-         """
          synonyms = set()
-         for syn in wordnet.synsets(word, pos=pos):  # Filter by POS
              for lemma in syn.lemmas():
                  name = lemma.name().replace("_", " ").lower()
-                 # Basic filtering for valid word forms
                  if name.isalpha() and len(name) > 1:
                      synonyms.add(name)
-         synonyms.discard(word.lower())  # Ensure original word is not included
-         return list(synonyms)
 
  import logging
  import asyncio
  from typing import List, Dict
  from functools import lru_cache
  from app.services.base import (
+     load_spacy_model,
+     load_sentence_transformer_model,
+     ensure_nltk_resource
+ )
+ from app.core.config import (
+     settings,
+     APP_NAME,
+     SPACY_MODEL_ID,
+     WORDNET_NLTK_ID,
+     SENTENCE_TRANSFORMER_MODEL_ID
  )
+ from app.core.exceptions import ServiceError, ModelNotDownloadedError
+
+ from nltk.corpus import wordnet
+ from sentence_transformers.util import cos_sim
 
+ logger = logging.getLogger(f"{APP_NAME}.services.synonyms")
 
  SPACY_TO_WORDNET_POS = {
      "NOUN": wordnet.NOUN,
      "VERB": wordnet.VERB,
      "ADJ": wordnet.ADJ,
      "ADV": wordnet.ADV,
  }
 
  CONTENT_POS_TAGS = {"NOUN", "VERB", "ADJ", "ADV"}
 
  class SynonymSuggester:
      def __init__(self):
+         self._sentence_model = None
+         self._nlp = None
+
+     def _get_sentence_model(self):
+         if self._sentence_model is None:
+             self._sentence_model = load_sentence_transformer_model(
+                 SENTENCE_TRANSFORMER_MODEL_ID
+             )
+         return self._sentence_model
+
+     def _get_nlp(self):
+         if self._nlp is None:
+             self._nlp = load_spacy_model(SPACY_MODEL_ID)
+         return self._nlp
 
      async def suggest(self, text: str) -> dict:
+         try:
+             text = text.strip()
+             if not text:
+                 raise ServiceError(status_code=400, detail="Input text is empty for synonym suggestion.")
+
+             sentence_model = self._get_sentence_model()
+             nlp = self._get_nlp()
+             await asyncio.to_thread(ensure_nltk_resource, WORDNET_NLTK_ID)
+
+             doc = await asyncio.to_thread(nlp, text)
+             all_suggestions: Dict[str, List[str]] = {}
+
+             original_text_embedding = await asyncio.to_thread(
+                 sentence_model.encode, text,
+                 convert_to_tensor=True,
+                 normalize_embeddings=True
+             )
+
+             candidate_data = []
+
+             for token in doc:
+                 if token.pos_ in CONTENT_POS_TAGS and len(token.text.strip()) > 2 and not token.is_punct and not token.is_space:
+                     original_word = token.text
+                     word_start = token.idx
+                     word_end = token.idx + len(original_word)
+                     wordnet_pos = SPACY_TO_WORDNET_POS.get(token.pos_)
+                     if not wordnet_pos:
+                         continue
+
+                     wordnet_candidates = await asyncio.to_thread(
+                         self._get_wordnet_synonyms_cached, original_word, wordnet_pos
+                     )
+                     if not wordnet_candidates:
+                         continue
+
+                     if original_word not in all_suggestions:
+                         all_suggestions[original_word] = []
+
+                     for candidate in wordnet_candidates:
+                         temp_sentence = text[:word_start] + candidate + text[word_end:]
+                         candidate_data.append({
+                             "original_word": original_word,
+                             "wordnet_candidate": candidate,
+                             "temp_sentence": temp_sentence,
+                         })
+
+             if not candidate_data:
+                 return {"suggestions": {}}
+
+             all_candidate_sentences = [c["temp_sentence"] for c in candidate_data]
+             all_candidate_embeddings = await asyncio.to_thread(
+                 sentence_model.encode,
+                 all_candidate_sentences,
+                 batch_size=settings.SENTENCE_TRANSFORMER_BATCH_SIZE,
+                 convert_to_tensor=True,
+                 normalize_embeddings=True
+             )
+
+             if original_text_embedding.dim() == 1:
+                 original_text_embedding = original_text_embedding.unsqueeze(0)
+
+             cosine_scores = cos_sim(original_text_embedding, all_candidate_embeddings)[0]
+
+             similarity_threshold = 0.65
+             top_n = 5
+             temp_scored: Dict[str, List[tuple]] = {word: [] for word in all_suggestions}
+
+             for i, data in enumerate(candidate_data):
+                 word = data["original_word"]
+                 candidate = data["wordnet_candidate"]
+                 score = cosine_scores[i].item()
+                 if score >= similarity_threshold and candidate.lower() != word.lower():
+                     temp_scored[word].append((score, candidate))
+
+             final_suggestions = {}
+             for word, scored in temp_scored.items():
+                 if scored:
+                     sorted_unique = []
+                     seen = set()
+                     for score, candidate in sorted(scored, key=lambda x: x[0], reverse=True):
+                         if candidate not in seen:
+                             sorted_unique.append(candidate)
+                             seen.add(candidate)
+                         if len(sorted_unique) >= top_n:
+                             break
+                     final_suggestions[word] = sorted_unique
+
+             return {"suggestions": final_suggestions}
+
+         except Exception as e:
+             logger.error(f"Synonym suggestion error for text: '{text[:50]}...'", exc_info=True)
+             raise ServiceError(status_code=500, detail="An internal error occurred during synonym suggestion.") from e
 
      @lru_cache(maxsize=5000)
      def _get_wordnet_synonyms_cached(self, word: str, pos: str) -> List[str]:
          synonyms = set()
+         for syn in wordnet.synsets(word, pos=pos):
              for lemma in syn.lemmas():
                  name = lemma.name().replace("_", " ").lower()
                  if name.isalpha() and len(name) > 1:
                      synonyms.add(name)
+         synonyms.discard(word.lower())
+         return sorted(synonyms)
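A minimal smoke test for the reworked async API, assuming this file is app/services/synonyms.py (the module path is implied by the logger name; the sample sentence and the asyncio.run entry point are illustrative only). Both models load lazily, so the first call pays the spaCy and SentenceTransformer warm-up cost:

import asyncio
from app.services.synonyms import SynonymSuggester

async def main():
    suggester = SynonymSuggester()
    # New return shape: {"suggestions": {original_word: [candidates...]}}
    result = await suggester.suggest("The quick fox jumped over the lazy dog.")
    for word, candidates in result["suggestions"].items():
        print(word, "->", candidates)

asyncio.run(main())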
app/services/tone_classification.py CHANGED
@@ -1,75 +1,60 @@
  import logging
  import torch
- from transformers import pipeline
- from app.services.base import get_cached_model, model_response, ServiceError
- from app.core.config import settings
 
- logger = logging.getLogger(__name__)
 
  class ToneClassifier:
      def __init__(self):
-         self.classifier = self._load_model()
-
-     def _load_model(self):
-         """
-         Loads and caches the sentiment-analysis pipeline.
-         """
-         def load_fn():
-             model = pipeline(
-                 "sentiment-analysis",  # Or "text-classification" if you prefer
-                 model=settings.TONE_MODEL,
-                 device=0 if torch.cuda.is_available() else -1,
-                 return_all_scores=True  # Keep this true
              )
-             logger.info(f"ToneClassifier model loaded on {'CUDA' if torch.cuda.is_available() else 'CPU'}")
-             return model
-
-         return get_cached_model("tone_model", load_fn)
 
-     def classify(self, text: str) -> dict:
          try:
              text = text.strip()
              if not text:
-                 raise ServiceError("Input text is empty.")
 
-             raw_results = self.classifier(text)
 
-             # Check for expected pipeline output format
              if not (isinstance(raw_results, list) and raw_results and isinstance(raw_results[0], list)):
                  logger.error(f"Unexpected raw_results format from pipeline: {raw_results}")
-                 return model_response(error="Unexpected model output format.")
 
              scores_for_text = raw_results[0]
-
-             # Sort the emotions by score in descending order
              sorted_emotions = sorted(scores_for_text, key=lambda x: x['score'], reverse=True)
-
-             # --- NEW LOGGING ADDITIONS START HERE ---
 
-             # Log all emotion scores for the given text (useful for seeing the full distribution)
              logger.debug(f"Input Text: '{text}'")
              logger.debug("--- Emotion Scores (Label: Score) ---")
              for emotion in sorted_emotions:
                  logger.debug(f"  {emotion['label']}: {emotion['score']:.4f}")
              logger.debug("-------------------------------------")
 
-             # --- NEW LOGGING ADDITIONS END HERE ---
-
              top_emotion = sorted_emotions[0]
              predicted_label = top_emotion.get("label", "Unknown")
              predicted_score = top_emotion.get("score", 0.0)
 
-             # Apply the confidence threshold
              if predicted_score >= settings.TONE_CONFIDENCE_THRESHOLD:
-                 logger.info(f"Final prediction for '{text}': '{predicted_label}' (Score: {predicted_score:.4f}, Above Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f})")
-                 return model_response(result=predicted_label)
              else:
-                 logger.info(f"Final prediction for '{text}': 'neutral' (Top Score: {predicted_score:.4f}, Below Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f}).")
-                 return model_response(result="neutral")
 
-         except ServiceError as se:
-             logger.error(f"Tone classification ServiceError for text '{text}': {se}")
-             return model_response(error=str(se))
          except Exception as e:
-             logger.error(f"Tone classification unexpected error for text '{text}': {e}", exc_info=True)
-             return model_response(error="An error occurred during tone classification.")
 
  import logging
  import torch
+ from app.services.base import load_hf_pipeline
+ from app.core.config import APP_NAME, settings
+ from app.core.exceptions import ServiceError, ModelNotDownloadedError
 
+ logger = logging.getLogger(f"{APP_NAME}.services.tone_classification")
 
  class ToneClassifier:
      def __init__(self):
+         self._classifier = None
+
+     def _get_classifier(self):
+         if self._classifier is None:
+             self._classifier = load_hf_pipeline(
+                 model_id=settings.TONE_MODEL_ID,
+                 task="text-classification",
+                 feature_name="Tone Classification",
+                 top_k=None
              )
+         return self._classifier
 
+     async def classify(self, text: str) -> dict:
          try:
              text = text.strip()
              if not text:
+                 raise ServiceError(status_code=400, detail="Input text is empty for tone classification.")
 
+             classifier = self._get_classifier()
+             raw_results = classifier(text)
 
              if not (isinstance(raw_results, list) and raw_results and isinstance(raw_results[0], list)):
                  logger.error(f"Unexpected raw_results format from pipeline: {raw_results}")
+                 raise ServiceError(status_code=500, detail="Unexpected model output format for tone classification.")
 
              scores_for_text = raw_results[0]
              sorted_emotions = sorted(scores_for_text, key=lambda x: x['score'], reverse=True)
 
              logger.debug(f"Input Text: '{text}'")
              logger.debug("--- Emotion Scores (Label: Score) ---")
              for emotion in sorted_emotions:
                  logger.debug(f"  {emotion['label']}: {emotion['score']:.4f}")
              logger.debug("-------------------------------------")
 
              top_emotion = sorted_emotions[0]
              predicted_label = top_emotion.get("label", "Unknown")
              predicted_score = top_emotion.get("score", 0.0)
 
              if predicted_score >= settings.TONE_CONFIDENCE_THRESHOLD:
+                 logger.info(f"Final prediction for '{text[:50]}...': '{predicted_label}' (Score: {predicted_score:.4f}, Above Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f})")
+                 return {"tone": predicted_label}
              else:
+                 logger.info(f"Final prediction for '{text[:50]}...': 'neutral' (Top Score: {predicted_score:.4f}, Below Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f}).")
+                 return {"tone": "neutral"}
 
          except Exception as e:
+             logger.error(f"Tone classification unexpected error for text '{text[:50]}...': {e}", exc_info=True)
+             raise ServiceError(status_code=500, detail="An internal error occurred during tone classification.") from e
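A quick way to exercise the new async classify method, assuming this module path (the sample text is illustrative; the label set returned depends on the checkpoint behind settings.TONE_MODEL_ID):

import asyncio
from app.services.tone_classification import ToneClassifier

async def main():
    classifier = ToneClassifier()
    # Falls back to {"tone": "neutral"} when the top score is below
    # settings.TONE_CONFIDENCE_THRESHOLD.
    result = await classifier.classify("I absolutely loved the new release!")
    print(result)

asyncio.run(main())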
app/services/translation.py CHANGED
@@ -1,45 +1,47 @@
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- import torch
  import logging
- from app.services.base import get_cached_model, DEVICE, timed_model_load, ServiceError, model_response
- from app.core.config import settings
 
- logger = logging.getLogger(__name__)
 
  class Translator:
      def __init__(self):
-         self.tokenizer, self.model = self._load_model()
 
-     def _load_model(self):
-         def load_fn():
-             tokenizer = timed_model_load("translate_tokenizer", lambda: AutoTokenizer.from_pretrained(settings.TRANSLATION_MODEL))
-             model = timed_model_load("translate_model", lambda: AutoModelForSeq2SeqLM.from_pretrained(settings.TRANSLATION_MODEL))
-             model = model.to(DEVICE).eval()
-             return tokenizer, model
-         return get_cached_model("translate", load_fn)
-
-     def translate(self, text: str, target_lang: str) -> dict:
          try:
-             text = text.strip()
-             target_lang = target_lang.strip()
 
-             if not text:
-                 raise ServiceError("Input text is empty.")
-             if not target_lang:
-                 raise ServiceError("Target language is empty.")
-             if target_lang not in settings.SUPPORTED_TRANSLATION_LANGUAGES:
-                 raise ServiceError(f"Unsupported target language: {target_lang}")
 
-             prompt = f">>{target_lang}<< {text}"
-             with torch.no_grad():
-                 inputs = self.tokenizer([prompt], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
-                 outputs = self.model.generate(**inputs, max_length=256, num_beams=1, early_stopping=True)
-                 result = self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
-             return model_response(result=result)
-
-         except ServiceError as se:
-             return model_response(error=str(se))
          except Exception as e:
-             logger.error(f"Translation error: {e}")
-             return model_response(error="An error occurred during translation.")
 
  import logging
+ from app.services.base import load_hf_pipeline
+ from app.core.config import settings, APP_NAME
+ from app.core.exceptions import ServiceError
 
+ logger = logging.getLogger(f"{APP_NAME}.services.translation")
 
  class Translator:
      def __init__(self):
+         self._pipeline = None
+
+     def _get_pipeline(self):
+         if self._pipeline is None:
+             logger.info("Loading translation pipeline...")
+             self._pipeline = load_hf_pipeline(
+                 model_id=settings.TRANSLATION_MODEL_ID,
+                 task="translation",
+                 feature_name="Translation"
+             )
+         return self._pipeline
+
+     async def translate(self, text: str, target_lang: str) -> dict:
+         text = text.strip()
+         target_lang = target_lang.strip()
+
+         if not text:
+             raise ServiceError(status_code=400, detail="Input text is empty for translation.")
+         if not target_lang:
+             raise ServiceError(status_code=400, detail="Target language is empty for translation.")
+         if target_lang not in settings.SUPPORTED_TRANSLATION_LANGUAGES:
+             raise ServiceError(
+                 status_code=400,
+                 detail=f"Unsupported target language: {target_lang}. "
+                        f"Supported languages are: {', '.join(settings.SUPPORTED_TRANSLATION_LANGUAGES)}"
+             )
 
          try:
+             pipeline = self._get_pipeline()
+             prompt = f">>{target_lang}<< {text}"
+             result = pipeline(prompt, max_length=256, num_beams=1, early_stopping=True)[0]
+             translated_text = result.get("translation_text") or result.get("generated_text")
 
+             return {"translated_text": translated_text.strip()}
 
          except Exception as e:
+             logger.error(f"Translation error for text: '{text[:50]}...' to '{target_lang}'", exc_info=True)
+             raise ServiceError(status_code=500, detail="An internal error occurred during translation.") from e
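Usage sketch for the pipeline-based translator. The >>{target_lang}<< prefix matches multilingual MarianMT-style checkpoints; "fr" below is illustrative and only works if it appears in settings.SUPPORTED_TRANSLATION_LANGUAGES:

import asyncio
from app.services.translation import Translator

async def main():
    translator = Translator()
    # Raises ServiceError(400) for empty input or an unsupported language,
    # per the validation above.
    result = await translator.translate("Good morning, everyone.", "fr")
    print(result["translated_text"])

asyncio.run(main())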
 
app/services/voice_detection.py CHANGED
@@ -1,38 +1,55 @@
  import logging
- from app.services.base import get_spacy, model_response, ServiceError
 
- logger = logging.getLogger(__name__)
 
  class VoiceDetector:
      def __init__(self):
-         self.nlp = get_spacy()
 
-     def classify(self, text: str) -> dict:
          try:
              text = text.strip()
              if not text:
-                 raise ServiceError("Input text is empty.")
 
-             doc = self.nlp(text)
              passive_sentences = 0
              total_sentences = 0
 
              for sent in doc.sents:
                  total_sentences += 1
                  for token in sent:
-                     if token.dep_ == "nsubjpass":
-                         passive_sentences += 1
                          break
 
              if total_sentences == 0:
-                 return model_response(result="Unknown")
 
              ratio = passive_sentences / total_sentences
-             return model_response(result="Passive" if ratio > 0.5 else "Active")
 
-         except ServiceError as se:
-             return model_response(error=str(se))
-         except Exception as e:
-             logger.error(f"Voice detection error: {e}")
-             return model_response(error="An error occurred during voice detection.")
 
+ import asyncio
  import logging
+ from app.services.base import load_spacy_model
+ from app.core.config import APP_NAME, SPACY_MODEL_ID
+ from app.core.exceptions import ServiceError, ModelNotDownloadedError
 
+ logger = logging.getLogger(f"{APP_NAME}.services.voice_detection")
 
  class VoiceDetector:
      def __init__(self):
+         self._nlp = None
 
+     def _get_nlp(self):
+         if self._nlp is None:
+             self._nlp = load_spacy_model(SPACY_MODEL_ID)
+         return self._nlp
+
+     async def classify(self, text: str) -> dict:
          try:
              text = text.strip()
              if not text:
+                 raise ServiceError(status_code=400, detail="Input text is empty for voice detection.")
+
+             nlp = self._get_nlp()
+             doc = await asyncio.to_thread(nlp, text)
 
              passive_sentences = 0
              total_sentences = 0
 
              for sent in doc.sents:
                  total_sentences += 1
+                 is_passive_sentence = False
                  for token in sent:
+                     if token.dep_ == "nsubjpass" and token.head.pos_ == "VERB":
+                         is_passive_sentence = True
                          break
+                 if is_passive_sentence:
+                     passive_sentences += 1
 
              if total_sentences == 0:
+                 return {"voice": "unknown", "passive_ratio": 0.0}
 
              ratio = passive_sentences / total_sentences
+             voice_type = "Passive" if ratio > 0.1 else "Active"
 
+             return {
+                 "voice": voice_type,
+                 "passive_ratio": round(ratio, 3),
+                 "passive_sentences_count": passive_sentences,
+                 "total_sentences_count": total_sentences
+             }
 
          except Exception as e:
+             logger.error(f"Voice detection error for text: '{text[:50]}...': {e}", exc_info=True)
+             raise ServiceError(status_code=500, detail="An internal error occurred during voice detection.") from e
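Smoke test sketch for the richer voice payload (the sample sentences are illustrative; the first is passive, the second active, so the expected ratio is 0.5 and the label "Passive" under the 0.1 threshold above):

import asyncio
from app.services.voice_detection import VoiceDetector

async def main():
    detector = VoiceDetector()
    result = await detector.classify("The ball was thrown by the boy. He smiled.")
    # Keys per the new return shape: voice, passive_ratio,
    # passive_sentences_count, total_sentences_count.
    print(result)

asyncio.run(main())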