Spaces: Runtime error

iamspruce committed · commit 73a6a7e · parent: d893801
"fixed the api"

Files changed:
- Dockerfile +16 -21
- app/core/app.py +0 -37
- app/core/config.py +102 -24
- app/core/exceptions.py +53 -0
- app/core/logging.py +51 -8
- app/core/model_manager.py +280 -0
- app/core/prompts.py +0 -28
- app/main.py +117 -3
- app/queue.py +0 -104
- app/routers/analyze.py +90 -38
- app/routers/grammar.py +39 -26
- app/routers/inclusive_language.py +41 -7
- app/routers/paraphrase.py +41 -20
- app/routers/readability.py +33 -11
- app/routers/rewrite.py +56 -15
- app/routers/synonyms.py +37 -13
- app/routers/tone.py +42 -21
- app/routers/translate.py +44 -20
- app/routers/voice.py +41 -7
- app/services/base.py +112 -29
- app/services/gpt4_rewrite.py +55 -25
- app/services/grammar.py +39 -40
- app/services/inclusive_language.py +108 -56
- app/services/paraphrase.py +30 -34
- app/services/readability.py +18 -19
- app/services/synonyms.py +125 -130
- app/services/tone_classification.py +27 -42
- app/services/translation.py +37 -35
- app/services/voice_detection.py +32 -15
Dockerfile
CHANGED
@@ -2,40 +2,35 @@ FROM python:3.10-slim
 
 WORKDIR /app
 
-# Install system dependencies
-#
+# Install system dependencies (excluding git)
+# Clean up apt lists to reduce image size
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    # Add any other core system dependencies here if needed, but not git
+    # e.g., libpq-dev for psycopg2, if you add a PostgreSQL dependency later
+    # Example: libpq-dev
+    && rm -rf /var/lib/apt/lists/*
 
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
-# ---
-#
+# --- Pre-download models during Docker build ---
+# Ensure spacy and nltk are installed via requirements.txt before these steps
 RUN python -m spacy download en_core_web_sm
-
-# --- Install NLTK WordNet data ---
-# This downloads the WordNet corpus for NLTK
 RUN python -m nltk.downloader wordnet
 
-# --- Configure cache directories
-# HF_HOME is where SentenceTransformers and other Hugging Face models will cache.
-# /.cache is also a common location many libraries default to if HF_HOME isn't set,
-# or for other internal caching. Setting permissions ensures the app can write there.
+# --- Configure cache directories using Docker ENV (these take precedence) ---
 ENV HF_HOME=/cache
 ENV TRANSFORMERS_CACHE=/cache
 ENV NLTK_DATA=/nltk_data
+ENV SPACY_DATA=/spacy_data
 
 # Create directories and set appropriate permissions
 RUN mkdir -p /cache && chmod -R 777 /cache
-RUN mkdir -p /
+RUN mkdir -p /nltk_data && chmod -R 777 /nltk_data
+RUN mkdir -p /spacy_data && chmod -R 777 /spacy_data
 
-# Ensure NLTK uses the specified data path.
-# This makes subsequent 'nltk.downloader' calls store data here,
-# and NLTK will look here first.
-RUN python -c "import nltk; nltk.data.path.append('/nltk_data')"
+# It's good to also create the /root/.cache for general system caching in Docker
+RUN mkdir -p /root/.cache && chmod -R 777 /root/.cache
 
 COPY app ./app
 
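Example (not part of the commit): a minimal Python sketch to sanity-check the cache layout this Dockerfile configures, run inside the built container. The variable names match the ENV lines above; everything else is illustrative.

    import os
    from pathlib import Path

    # Each of these is set by an ENV instruction in the Dockerfile above.
    for var in ("HF_HOME", "TRANSFORMERS_CACHE", "NLTK_DATA", "SPACY_DATA"):
        value = os.environ.get(var)
        if value is None:
            print(f"{var} is not set")
            continue
        p = Path(value)
        print(f"{var}={value} exists={p.is_dir()} writable={os.access(value, os.W_OK)}")
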
app/core/app.py
DELETED
@@ -1,37 +0,0 @@
-import os
-from fastapi import FastAPI
-from fastapi.middleware.gzip import GZipMiddleware
-from contextlib import asynccontextmanager
-from app.routers import grammar, tone, voice, inclusive_language, readability, paraphrase, translate, rewrite, analyze
-from app.queue import start_workers
-from app.core.middleware import setup_middlewares
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    num_workers = int(os.getenv("WORKER_COUNT", 4))
-    start_workers(num_workers)
-    yield
-
-def create_app() -> FastAPI:
-    app = FastAPI(lifespan=lifespan)
-    app.add_middleware(GZipMiddleware, minimum_size=500)
-    setup_middlewares(app)
-
-    for router, tag in [
-        (grammar.router, "Grammar"),
-        (tone.router, "Tone"),
-        (voice.router, "Voice"),
-        (inclusive_language.router, "Inclusive Language"),
-        (readability.router, "Readability"),
-        (paraphrase.router, "Paraphrasing"),
-        (translate.router, "Translation"),
-        (rewrite.router, "Rewrite"),
-        (analyze.router, "Analyze")
-    ]:
-        app.include_router(router, tags=[tag])
-
-    @app.get("/")
-    def root():
-        return {"message": "Welcome to Wellsaid API"}
-
-    return app
app/core/config.py
CHANGED
@@ -1,22 +1,90 @@
+import logging
+import os
+from pathlib import Path
+from typing import List, Optional
+
+from pydantic import Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+# ─────────────────────────────────────────────────────────────────────────────
+# ⛺ Paths & Constants
+# ─────────────────────────────────────────────────────────────────────────────
+
+PROJECT_ROOT = Path(__file__).parent.parent
+APP_DATA_ROOT_DIR = Path.home() / ".wellsaid_app_data"
+MODELS_DIR = APP_DATA_ROOT_DIR / "models"
+NLTK_DATA_DIR = APP_DATA_ROOT_DIR / "nltk_data"
+
+OFFLINE_MODE = os.getenv("OFFLINE_MODE", "false").lower() == "true"
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 📁 Ensure Directories Exist (for offline desktop usage)
+# ─────────────────────────────────────────────────────────────────────────────
+
+for directory in [MODELS_DIR, NLTK_DATA_DIR]:
+    try:
+        directory.mkdir(parents=True, exist_ok=True)
+    except Exception as e:
+        logging.warning(f"Failed to create directory {directory}: {e}")
+
+# ─────────────────────────────────────────────────────────────────────────────
+# 🌍 Environment Variables Setup (only if not already set)
+# ─────────────────────────────────────────────────────────────────────────────
+
+env_defaults = {
+    "HF_HOME": str(MODELS_DIR / "hf_cache"),
+    "NLTK_DATA": str(NLTK_DATA_DIR),
+    "SPACY_DATA": str(MODELS_DIR),
+}
+
+for var, default in env_defaults.items():
+    if not os.getenv(var):
+        os.environ[var] = default
+
+# Update nltk.data.path immediately (if nltk is installed)
+try:
+    import nltk
+    if str(NLTK_DATA_DIR) not in nltk.data.path:
+        nltk.data.path.append(str(NLTK_DATA_DIR))
+except ImportError:
+    pass
+
+# ─────────────────────────────────────────────────────────────────────────────
+# ⚙️ Application Settings
+# ─────────────────────────────────────────────────────────────────────────────
 
 class Settings(BaseSettings):
-    HOST: str = "0.0.0.0"
-    PORT: int = 7860
-    RELOAD: bool = True
+    model_config = SettingsConfigDict(env_file=".env", extra="ignore")
+
+    # App basics
+    APP_NAME: str = "WellSaidApp"
+    API_KEY: str = "your_strong_api_key_here"
+
+    # OpenAI
+    OPENAI_API_KEY: Optional[str] = None
     OPENAI_MODEL: str = "gpt-4o"
     OPENAI_TEMPERATURE: float = 0.7
-    OPENAI_MAX_TOKENS: int =
+    OPENAI_MAX_TOKENS: int = 1500
+
+    # API server
+    HOST: str = "127.0.0.1"
+    PORT: int = 8000
+    RELOAD: bool = False
+    WORKER_COUNT: int = 1
+
+    # NLP models
+    SPACY_MODEL_ID: str = "en_core_web_sm"
+    SENTENCE_TRANSFORMER_MODEL_ID: str = "all-MiniLM-L6-v2"
+    SENTENCE_TRANSFORMER_BATCH_SIZE: int = 2
+
+    GRAMMAR_MODEL_ID: str = "visheratin/t5-efficient-mini-grammar-correction"
+    PARAPHRASE_MODEL_ID: str = "humarin/chatgpt_paraphraser_on_T5_base"
+    TONE_MODEL_ID: str = "boltuix/NeuroFeel"
+    TONE_CONFIDENCE_THRESHOLD: float = 1.0
+    TRANSLATION_MODEL_ID: str = "Helsinki-NLP/opus-mt-en-ROMANCE"
+
+    WORDNET_NLTK_ID: str = "wordnet.zip"
+
     SUPPORTED_TRANSLATION_LANGUAGES: List[str] = [
         "fr", "fr_BE", "fr_CA", "fr_FR", "wa", "frp", "oc", "ca", "rm", "lld",
         "fur", "lij", "lmo", "es", "es_AR", "es_CL", "es_CO", "es_CR", "es_DO",
@@ -26,18 +94,28 @@ class Settings(BaseSettings):
         "sc", "ro", "la"
     ]
 
-    PARAPHRASE_MODEL: str = "humarin/chatgpt_paraphraser_on_T5_base"
-    TONE_MODEL: str = "boltuix/NeuroFeel"
-    TONE_CONFIDENCE_THRESHOLD: float = 0.2
-    TRANSLATION_MODEL: str = "Helsinki-NLP/opus-mt-en-ROMANCE"
-    SENTENCE_TRANSFORMER_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
-
-    class Config:
-        env_file = ".env"
-        case_sensitive = True
+    # Data dirs
+    INCLUSIVE_RULES_DIR: str = "app/data/en"
 
+# ─────────────────────────────────────────────────────────────────────────────
+# 📦 App-wide constants
+# ─────────────────────────────────────────────────────────────────────────────
 
-# Singleton instance
 settings = Settings()
+
+# Core settings for import
+APP_NAME = settings.APP_NAME
+LOCAL_API_HOST = settings.HOST
+LOCAL_API_PORT = settings.PORT
+
+# Model names
+SPACY_MODEL_ID = settings.SPACY_MODEL_ID
+SENTENCE_TRANSFORMER_MODEL_ID = settings.SENTENCE_TRANSFORMER_MODEL_ID
+GRAMMAR_MODEL_ID = settings.GRAMMAR_MODEL_ID
+PARAPHRASE_MODEL_ID = settings.PARAPHRASE_MODEL_ID
+TONE_MODEL_ID = settings.TONE_MODEL_ID
+TRANSLATION_MODEL_ID = settings.TRANSLATION_MODEL_ID
+WORDNET_NLTK_ID = settings.WORDNET_NLTK_ID
+
+# Data
+INCLUSIVE_RULES_DIR = settings.INCLUSIVE_RULES_DIR
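Example (not part of the commit): how these settings resolve in practice. Environment variables win over the defaults above because pydantic-settings reads them at instantiation; the module-level aliases are captured once at import.

    import os
    os.environ["PORT"] = "9001"  # must be set before app.core.config is imported

    from app.core.config import settings, LOCAL_API_PORT

    print(settings.PORT)    # 9001, taken from the environment
    print(LOCAL_API_PORT)   # alias captured when the module was imported
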
app/core/exceptions.py
ADDED
@@ -0,0 +1,53 @@
+# app/core/exceptions.py
+from fastapi import HTTPException
+
+
+class ServiceError(HTTPException):
+    """
+    Base exception for general service-related errors.
+    Inherits from HTTPException to allow direct use in FastAPI responses.
+    """
+    def __init__(self, status_code: int, detail: str, error_type: str = "ServiceError"):
+        super().__init__(status_code=status_code, detail=detail)
+        self.error_type = error_type
+
+    def to_dict(self):
+        """Returns a dictionary representation of the exception."""
+        return {
+            "detail": self.detail,
+            "status_code": self.status_code,
+            "error_type": self.error_type
+        }
+
+
+class ModelNotDownloadedError(ServiceError):
+    """
+    Raised when a required model is not found locally.
+    Informs the client that a download is necessary.
+    """
+    def __init__(self, model_id: str, feature_name: str, detail: str = None):
+        detail = detail or f"Model '{model_id}' required for '{feature_name}' is not downloaded."
+        super().__init__(status_code=424, detail=detail, error_type="ModelNotDownloaded")
+        self.model_id = model_id
+        self.feature_name = feature_name
+
+    def to_dict(self):
+        base_dict = super().to_dict()
+        base_dict.update({
+            "model_id": self.model_id,
+            "feature_name": self.feature_name
+        })
+        return base_dict
+
+
+class ModelDownloadFailedError(ServiceError):
+    """Exception raised when a model download operation fails."""
+    def __init__(self, model_id: str, feature_name: str, original_error: str = "Unknown error"):
+        super().__init__(
+            status_code=503,  # Service Unavailable
+            detail=f"Failed to download model '{model_id}' for '{feature_name}'. Please check your internet connection or try again. Error: {original_error}",
+            error_type="ModelDownloadFailed",
+        )
+        # ServiceError.__init__ does not accept model_id/feature_name kwargs,
+        # so set them directly here (the original passed them to super(), a TypeError).
+        self.model_id = model_id
+        self.feature_name = feature_name
+        self.original_error = original_error
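Example (not part of the commit): raising and serializing these errors the way the global handlers in app/main.py consume them. The model id and feature name are illustrative values taken from this commit's config.

    from app.core.exceptions import ModelNotDownloadedError

    try:
        raise ModelNotDownloadedError("boltuix/NeuroFeel", "Tone Classification")
    except ModelNotDownloadedError as exc:
        print(exc.status_code)  # 424 (Failed Dependency)
        print(exc.to_dict())    # the JSON body the exception handler returns
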
app/core/logging.py
CHANGED
@@ -1,12 +1,55 @@
 import logging
+import os
+from pathlib import Path
+
+from app.core.config import APP_DATA_ROOT_DIR, APP_NAME
 
 def configure_logging():
+    """
+    Configures application-wide logging to both console and a file.
+    The log file is placed in the application's data directory.
+    """
+    log_dir = APP_DATA_ROOT_DIR / "logs"
+    log_dir.mkdir(parents=True, exist_ok=True)  # Ensure the log directory exists
+
+    log_file_path = log_dir / f"{APP_NAME.lower()}.log"
+
+    # Define a custom formatter
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
     )
+
+    # Console handler
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(formatter)
+    console_handler.setLevel(logging.INFO)  # Default console level
+
+    # File handler
+    file_handler = logging.FileHandler(log_file_path)
+    file_handler.setFormatter(formatter)
+    file_handler.setLevel(logging.INFO)  # Default file level
+
+    # Get the root logger
+    root_logger = logging.getLogger()
+    root_logger.setLevel(logging.INFO)  # Overall minimum logging level
+
+    # Clear existing handlers to prevent duplicate logs if called multiple times
+    if root_logger.hasHandlers():
+        root_logger.handlers.clear()
+
+    root_logger.addHandler(console_handler)
+    root_logger.addHandler(file_handler)
+
+    # Set specific log levels for libraries if needed (e.g., to reduce verbosity)
+    logging.getLogger("uvicorn").setLevel(logging.WARNING)
+    logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
+    logging.getLogger("huggingface_hub").setLevel(logging.WARNING)
+    logging.getLogger("transformers").setLevel(logging.WARNING)
+    logging.getLogger("sentence_transformers").setLevel(logging.WARNING)
+    logging.getLogger("nltk").setLevel(logging.WARNING)
+    logging.getLogger("urllib3").setLevel(logging.WARNING)
+    logging.getLogger("asyncio").setLevel(logging.WARNING)  # Reduce asyncio verbosity
+
+    logger = logging.getLogger(f"{APP_NAME}.core.logging")
+    logger.info(f"Logging configured. Logs are saved to: {log_file_path}")
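Example (not part of the commit): typical usage, mirroring what app/main.py does at import time.

    import logging
    from app.core.logging import configure_logging

    configure_logging()
    # Writes to the console and to ~/.wellsaid_app_data/logs/wellsaidapp.log
    logging.getLogger("WellSaidApp.routers.grammar").info("logging is active")
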
app/core/model_manager.py
ADDED
@@ -0,0 +1,280 @@
+# app/core/model_manager.py
+import logging
+import os
+import asyncio
+from pathlib import Path
+from typing import Callable, Optional, Dict, List
+
+# Imports for downloading specific model types
+import nltk
+from huggingface_hub import snapshot_download
+import spacy.cli
+
+# Internal application imports
+from app.core.config import (
+    MODELS_DIR,
+    NLTK_DATA_DIR,
+    SPACY_MODEL_ID,
+    SENTENCE_TRANSFORMER_MODEL_ID,
+    TONE_MODEL_ID,
+    TRANSLATION_MODEL_ID,
+    WORDNET_NLTK_ID,
+    APP_NAME
+)
+from app.core.exceptions import ModelNotDownloadedError, ModelDownloadFailedError, ServiceError
+
+logger = logging.getLogger(f"{APP_NAME}.core.model_manager")
+
+# Type alias for progress callback
+ProgressCallback = Callable[[str, str, float, Optional[str]], None]  # (model_id, status, progress, message)
+
+def _get_hf_model_local_path(model_id: str) -> Path:
+    """Helper to get the expected local path for a Hugging Face model."""
+    # snapshot_download creates a specific folder structure inside MODELS_DIR/hf_cache,
+    # e.g. MODELS_DIR/hf_cache/models--bert-base-uncased, with the model files inside.
+    # The `transformers` library usually handles this resolution; here we only need
+    # to check whether the directory created by snapshot_download exists.
+    # A robust check involves looking inside that directory.
+    return MODELS_DIR / "hf_cache" / model_id.replace("/", "--")  # Standard HF cache path logic
+
+
+def check_model_exists(model_id: str, model_type: str) -> bool:
+    """
+    Checks if a specific model or NLTK data is already downloaded locally.
+    """
+    if model_type == "huggingface":
+        local_path = _get_hf_model_local_path(model_id)
+        # Check if the directory exists and contains some files
+        return local_path.is_dir() and any(local_path.iterdir())
+    elif model_type == "spacy":
+        # spaCy models are symlinked or copied into a specific site-packages location.
+        # The easiest check is to try loading the model, or spacy.util.is_package.
+        # Here we check if the directory created by `spacy download` exists within
+        # MODELS_DIR, assuming we direct spaCy there; this is a simplified check.
+        # The actual loading in app.services.base handles the `is_package` check.
+        # For `spacy.cli.download` to use MODELS_DIR, SPACY_DATA must be set.
+        spacy_target_path = MODELS_DIR / model_id
+        return spacy_target_path.is_dir() and any(spacy_target_path.iterdir())
+    elif model_type == "nltk":
+        # NLTK data check
+        try:
+            return nltk.data.find(f"corpora/{model_id}") is not None
+        except LookupError:
+            return False
+    else:
+        logger.warning(f"Unknown model type for check_model_exists: {model_type}")
+        return False
+
+# --- Download Functions ---
+
+async def download_hf_model_async(
+    model_id: str,
+    feature_name: str,
+    progress_callback: Optional[ProgressCallback] = None
+) -> None:
+    """
+    Asynchronously downloads a Hugging Face model from the Hub.
+    """
+    logger.info(f"Initiating download for Hugging Face model '{model_id}' for '{feature_name}'...")
+    if check_model_exists(model_id, "huggingface"):
+        logger.info(f"Hugging Face model '{model_id}' already exists locally. Skipping download.")
+        if progress_callback:
+            progress_callback(model_id, "completed", 1.0, "Already downloaded.")
+        return
+
+    # Use a thread pool for the blocking download operation
+    try:
+        def _blocking_download():
+            # This downloads to MODELS_DIR/hf_cache; cache_dir is set explicitly
+            # rather than relying on the HF_HOME default from config.py.
+            snapshot_download(
+                repo_id=model_id,
+                cache_dir=str(MODELS_DIR / "hf_cache"),  # Explicitly set cache directory
+                local_dir_use_symlinks=False,  # False gives a better self-contained app
+                # snapshot_download does not expose a live progress callback,
+                # so we log at the beginning and end.
+            )
+            logger.info(f"Hugging Face model '{model_id}' download complete.")
+
+        if progress_callback:
+            progress_callback(model_id, "downloading", 0.05, "Starting download...")
+
+        await asyncio.to_thread(_blocking_download)  # Run blocking download in a separate thread
+
+        if progress_callback:
+            progress_callback(model_id, "completed", 1.0, "Download successful.")
+
+    except Exception as e:
+        logger.error(f"Failed to download Hugging Face model '{model_id}': {e}", exc_info=True)
+        if progress_callback:
+            progress_callback(model_id, "failed", 0.0, f"Error: {e}")
+        raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e))
+
+
+async def download_spacy_model_async(
+    model_id: str,
+    feature_name: str,
+    progress_callback: Optional[ProgressCallback] = None
+) -> None:
+    """
+    Asynchronously downloads a spaCy model.
+    """
+    logger.info(f"Initiating download for spaCy model '{model_id}' for '{feature_name}'...")
+    # NOTE: our existence check may not be sufficient if SPACY_DATA isn't pointing
+    # correctly; `spacy.util.is_package` would be more robust but requires `import spacy`
+    # first. For now, we trust `spacy.cli.download` to handle the check or fail gracefully.
+
+    # SPACY_DATA must be set to MODELS_DIR for spacy.cli.download to use our custom path.
+    original_spacy_data = os.environ.get("SPACY_DATA")
+    try:
+        os.environ["SPACY_DATA"] = str(MODELS_DIR)
+
+        if check_model_exists(model_id, "spacy"):  # Using our own simplified check
+            logger.info(f"SpaCy model '{model_id}' already exists locally. Skipping download.")
+            if progress_callback:
+                progress_callback(model_id, "completed", 1.0, "Already downloaded.")
+            return
+
+        def _blocking_download():
+            # spacy.cli.download attempts to download and link/copy the model.
+            # It can raise if the model is already present and can't be linked;
+            # we rely on check_model_exists above to avoid that.
+            spacy.cli.download(model_id)
+            logger.info(f"SpaCy model '{model_id}' download complete.")
+
+        if progress_callback:
+            progress_callback(model_id, "downloading", 0.05, "Starting download...")
+
+        await asyncio.to_thread(_blocking_download)
+
+        if progress_callback:
+            progress_callback(model_id, "completed", 1.0, "Download successful.")
+
+    except Exception as e:
+        logger.error(f"Failed to download spaCy model '{model_id}': {e}", exc_info=True)
+        if progress_callback:
+            progress_callback(model_id, "failed", 0.0, f"Error: {e}")
+        raise ModelDownloadFailedError(model_id, feature_name, original_error=str(e))
+    finally:
+        # Restore original SPACY_DATA if it was set
+        if original_spacy_data is not None:
+            os.environ["SPACY_DATA"] = original_spacy_data
+        else:
+            if "SPACY_DATA" in os.environ:
+                del os.environ["SPACY_DATA"]
+
+
+async def download_nltk_data_async(
+    data_id: str,
+    feature_name: str,
+    progress_callback: Optional[ProgressCallback] = None
+) -> None:
+    """
+    Asynchronously downloads NLTK data.
+    """
+    logger.info(f"Initiating download for NLTK data '{data_id}' for '{feature_name}'...")
+    # The NLTK data path is set via the NLTK_DATA environment variable in config.py;
+    # `nltk.download` will use this path.
+
+    if check_model_exists(data_id, "nltk"):
+        logger.info(f"NLTK data '{data_id}' already exists locally. Skipping download.")
+        if progress_callback:
+            progress_callback(data_id, "completed", 1.0, "Already downloaded.")
+        return
+
+    def _blocking_download():
+        # The NLTK downloader can show a GUI; `quiet=True` is important for
+        # programmatic download. `download_dir` matches the NLTK_DATA env variable.
+        nltk.download(data_id, download_dir=str(NLTK_DATA_DIR), quiet=True)
+        logger.info(f"NLTK data '{data_id}' download complete.")
+
+    try:
+        if progress_callback:
+            progress_callback(data_id, "downloading", 0.05, "Starting download...")
+
+        await asyncio.to_thread(_blocking_download)
+
+        if progress_callback:
+            progress_callback(data_id, "completed", 1.0, "Download successful.")
+
+    except Exception as e:
+        logger.error(f"Failed to download NLTK data '{data_id}': {e}", exc_info=True)
+        if progress_callback:
+            progress_callback(data_id, "failed", 0.0, f"Error: {e}")
+        raise ModelDownloadFailedError(data_id, feature_name, original_error=str(e))
+
+
+# --- Comprehensive Model Management ---
+
+def get_all_required_models() -> List[Dict]:
+    """
+    Returns a list of all models required by the application, with their type and feature.
+    """
+    return [
+        {"id": SPACY_MODEL_ID, "type": "spacy", "feature": "Text Processing (General)"},
+        {"id": SENTENCE_TRANSFORMER_MODEL_ID, "type": "huggingface", "feature": "Sentence Embeddings"},
+        {"id": TONE_MODEL_ID, "type": "huggingface", "feature": "Tone Classification"},
+        {"id": TRANSLATION_MODEL_ID, "type": "huggingface", "feature": "Translation"},
+        {"id": WORDNET_NLTK_ID, "type": "nltk", "feature": "Synonym Suggestion"},
+        # Add any other models here as your application grows
+    ]
+
+async def download_all_required_models(progress_callback: Optional[ProgressCallback] = None) -> Dict[str, str]:
+    """
+    Attempts to download all required models.
+    Returns a dictionary of download statuses.
+    """
+    required_models = get_all_required_models()
+    download_statuses = {}
+
+    for model_info in required_models:
+        model_id = model_info["id"]
+        model_type = model_info["type"]
+        feature_name = model_info["feature"]
+
+        if check_model_exists(model_id, model_type):
+            status_message = f"'{model_id}' ({feature_name}) already downloaded."
+            logger.info(status_message)
+            download_statuses[model_id] = "already_downloaded"
+            if progress_callback:
+                progress_callback(model_id, "completed", 1.0, status_message)
+            continue
+
+        logger.info(f"Attempting to download '{model_id}' ({feature_name})...")
+        try:
+            if model_type == "huggingface":
+                await download_hf_model_async(model_id, feature_name, progress_callback)
+            elif model_type == "spacy":
+                await download_spacy_model_async(model_id, feature_name, progress_callback)
+            elif model_type == "nltk":
+                await download_nltk_data_async(model_id, feature_name, progress_callback)
+            else:
+                raise ValueError(f"Unsupported model type: {model_type}")
+
+            status_message = f"'{model_id}' ({feature_name}) downloaded successfully."
+            logger.info(status_message)
+            download_statuses[model_id] = "success"
+
+        except ModelDownloadFailedError as e:
+            status_message = f"Failed to download '{model_id}' ({feature_name}): {e.original_error}"
+            logger.error(status_message)
+            download_statuses[model_id] = "failed"
+            # The progress_callback is already called within the specific download functions on failure
+        except Exception as e:
+            status_message = f"An unexpected error occurred while downloading '{model_id}' ({feature_name}): {e}"
+            logger.error(status_message, exc_info=True)
+            download_statuses[model_id] = "failed"
+            if progress_callback:
+                progress_callback(model_id, "failed", 0.0, status_message)
+
+    logger.info("Finished attempting to download all required models.")
+    return download_statuses
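Example (not part of the commit): driving the downloader with a progress callback that matches the ProgressCallback signature defined above.

    import asyncio
    from app.core.model_manager import download_all_required_models

    def on_progress(model_id, status, progress, message=None):
        # (model_id, status, progress, message), as the type alias documents
        print(f"[{status:>11}] {model_id} {progress:.0%} {message or ''}")

    statuses = asyncio.run(download_all_required_models(progress_callback=on_progress))
    print(statuses)  # e.g. {"en_core_web_sm": "success", ...}
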
app/core/prompts.py
DELETED
@@ -1,28 +0,0 @@
-def tone_prompt(text: str, tone: str) -> str:
-    return f"Change the tone of this sentence to {tone}: {text.strip()}"
-
-def summarize_prompt(text: str) -> str:
-    return f"Summarize the following text:\n{text.strip()}"
-
-def clarity_prompt(text: str) -> str:
-    return f"Improve the clarity of the following sentence:\n{text.strip()}"
-
-def rewrite_prompt(text: str, instruction: str) -> str:
-    return f"{instruction.strip()}\n{text.strip()}"
-
-def vocabulary_prompt(text: str) -> str:
-    return (
-        "You are an expert vocabulary enhancer. Rewrite the following text "
-        "by replacing common and simple words with more sophisticated, "
-        "precise, and contextually appropriate synonyms. Do not change "
-        "the original meaning. Maintain the tone.\n" + text.strip()
-    )
-
-def concise_prompt(text: str) -> str:
-    return (
-        "You are an expert editor specializing in conciseness. "
-        "Rewrite the following text to be more concise and to the point, "
-        "removing any verbose phrases, redundant words, or unnecessary clauses. "
-        "Maintain the original meaning and professional tone.\n" + text.strip()
-    )
app/main.py
CHANGED
@@ -1,5 +1,119 @@
+# app/main.py
+import logging
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, Request, status
+from fastapi.responses import JSONResponse
+from fastapi.middleware.gzip import GZipMiddleware
+from fastapi.middleware.cors import CORSMiddleware
+
+from app.core.config import APP_NAME  # For logger naming
+from app.core.logging import configure_logging  # Import the new logging configuration
+from app.core.exceptions import ServiceError, ModelNotDownloadedError  # Import custom exceptions
+
+# Import your routers
+# Adjust these imports if your router file names or structure are different
+from app.routers import (
+    grammar, tone, voice, inclusive_language,
+    readability, paraphrase, translate, synonyms, rewrite, analyze
+)
+
+# Configure logging at the very beginning
 configure_logging()
+logger = logging.getLogger(f"{APP_NAME}.main")
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    Context manager for application startup and shutdown events.
+    Models are now lazily loaded, so no explicit loading here.
+    """
+    logger.info("Application starting up...")
+    # Any other global startup tasks can go here
+    yield
+    logger.info("Application shutting down...")
+    # Any global shutdown tasks can go here (e.g., closing database connections)
+
+
+app = FastAPI(
+    title="Writing Assistant API (Local)",
+    description="Local API for the desktop Writing Assistant application, providing various NLP functionalities.",
+    version="0.1.0",
+    lifespan=lifespan,
+)
+
+# --- Middleware Setup ---
+app.add_middleware(GZipMiddleware, minimum_size=500)
+
+# CORS Middleware for local development/desktop app scenarios.
+# Allows all origins for local testing. Restrict as needed for deployment.
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Adjust this for specific origins in a web deployment
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# --- Global Exception Handlers ---
+@app.exception_handler(ServiceError)
+async def service_error_handler(request: Request, exc: ServiceError):
+    """
+    Handles custom ServiceError exceptions, returning a structured JSON response.
+    """
+    logger.error(f"Service Error caught for path {request.url.path}: {exc.detail}", exc_info=True)
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=exc.to_dict(),  # Use the to_dict method from ServiceError
+    )
+
+@app.exception_handler(ModelNotDownloadedError)
+async def model_not_downloaded_error_handler(request: Request, exc: ModelNotDownloadedError):
+    """
+    Handles ModelNotDownloadedError exceptions, informing the client a model is missing.
+    """
+    logger.warning(f"Model Not Downloaded Error caught for path {request.url.path}: Model '{exc.model_id}' is missing for feature '{exc.feature_name}'.")
+    return JSONResponse(
+        status_code=exc.status_code,
+        content=exc.to_dict(),  # Use the to_dict method from ModelNotDownloadedError
+    )
+
+@app.exception_handler(Exception)
+async def general_exception_handler(request: Request, exc: Exception):
+    """
+    Handles all other unhandled exceptions, returning a generic server error.
+    """
+    logger.exception(f"Unhandled exception caught for path {request.url.path}: {exc}")  # logger.exception logs the traceback
+    return JSONResponse(
+        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+        content={
+            "detail": "An unexpected internal server error occurred.",
+            "error_type": "InternalServerError",
+        },
+    )
+
+# --- Include Routers ---
+# Note: You will need to create/update these files in app/routers/
+# if they don't exist or don't match the new async service methods.
+for router, tag in [
+    (grammar.router, "Grammar"),
+    (tone.router, "Tone"),
+    (voice.router, "Voice"),
+    (inclusive_language.router, "Inclusive Language"),
+    (readability.router, "Readability"),
+    (rewrite.router, "Rewrite"),
+    (analyze.router, "Analyze"),
+    (paraphrase.router, "Paraphrasing"),
+    (translate.router, "Translation"),
+    (synonyms.router, "Synonyms")
+]:
+    app.include_router(router, tags=[tag])
+
+# --- Root Endpoint ---
+@app.get("/", tags=["Health Check"])
+async def root():
+    """
+    Root endpoint for health check.
+    """
+    return {"message": "Writing Assistant API is running!"}
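Example (not part of the commit): one way to serve this app locally, reusing the host/port constants from app/core/config.py. The commit itself does not include a runner script, so this is a sketch.

    import uvicorn
    from app.core.config import LOCAL_API_HOST, LOCAL_API_PORT

    if __name__ == "__main__":
        uvicorn.run("app.main:app", host=LOCAL_API_HOST, port=LOCAL_API_PORT)
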
app/queue.py
DELETED
@@ -1,104 +0,0 @@
-import asyncio
-import logging
-import time
-import uuid
-import inspect
-
-from app.services.grammar import GrammarCorrector
-from app.services.paraphrase import Paraphraser
-from app.services.translation import Translator
-from app.services.tone_classification import ToneClassifier
-from app.services.inclusive_language import InclusiveLanguageChecker
-from app.services.voice_detection import VoiceDetector
-from app.services.readability import ReadabilityScorer
-from app.services.synonyms import SynonymSuggester
-from app.core.config import settings
-
-# Configure logging
-logging.basicConfig(
-    level=logging.DEBUG if getattr(settings, "DEBUG", False) else logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s"
-)
-
-# Initialize service instances
-grammar = GrammarCorrector()
-paraphraser = Paraphraser()
-translator = Translator()
-tone = ToneClassifier()
-inclusive = InclusiveLanguageChecker()
-voice_analyzer = VoiceDetector()
-readability = ReadabilityScorer()
-synonyms = SynonymSuggester()
-
-# Create async task queue (optional: maxsize=100 to prevent overload)
-task_queue = asyncio.Queue(maxsize=100)
-
-# Task handler map
-SERVICE_HANDLERS = {
-    "grammar": lambda p: grammar.correct(p["text"]),
-    "paraphrase": lambda p: paraphraser.paraphrase(p["text"]),
-    "translate": lambda p: translator.translate(p["text"], p["target_lang"]),
-    "tone": lambda p: tone.classify(p["text"]),
-    "inclusive": lambda p: inclusive.check(p["text"]),
-    "voice": lambda p: voice_analyzer.classify(p["text"]),
-    "readability": lambda p: readability.compute(p["text"]),
-    "synonyms": lambda p: synonyms.suggest(p["text"]),  # ✅ This is async
-}
-
-async def worker(worker_id: int):
-    logging.info(f"Worker-{worker_id} started")
-
-    while True:
-        task = await task_queue.get()
-        future = task["future"]
-        task_type = task["type"]
-        payload = task["payload"]
-        task_id = task["id"]
-
-        start_time = time.perf_counter()
-        logging.info(f"[Worker-{worker_id}] Processing Task-{task_id} | Type: {task_type} | Queue size: {task_queue.qsize()}")
-
-        try:
-            handler = SERVICE_HANDLERS.get(task_type)
-            if not handler:
-                raise ValueError(f"Unknown task type: {task_type}")
-
-            result = handler(payload)
-            if inspect.isawaitable(result):
-                result = await result
-
-            elapsed = time.perf_counter() - start_time
-            logging.info(f"[Worker-{worker_id}] Finished Task-{task_id} in {elapsed:.2f}s")
-
-            if not future.done():
-                future.set_result(result)
-
-        except Exception as e:
-            logging.error(f"[Worker-{worker_id}] Error in Task-{task_id} ({task_type}): {e}")
-            if not future.done():
-                future.set_result({"error": str(e)})
-
-        task_queue.task_done()
-
-def start_workers(count: int = 2):
-    for i in range(count):
-        asyncio.create_task(worker(i))
-
-async def enqueue_task(task_type: str, payload: dict, timeout: float = 10.0):
-    future = asyncio.get_event_loop().create_future()
-    task_id = str(uuid.uuid4())[:8]
-
-    await task_queue.put({
-        "future": future,
-        "type": task_type,
-        "payload": payload,
-        "id": task_id
-    })
-
-    logging.info(f"[ENQUEUE] Task-{task_id} added to queue | Type: {task_type} | Queue size: {task_queue.qsize()}")
-
-    try:
-        return await asyncio.wait_for(future, timeout=timeout)
-    except asyncio.TimeoutError:
-        logging.warning(f"[ENQUEUE] Task-{task_id} timed out after {timeout}s")
-        return {"error": f"Task {task_type} timed out after {timeout} seconds."}
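Example (not part of the commit): how a caller used the deleted queue at the parent revision, reconstructed from enqueue_task's own signature above. This only runs against d893801, since the module is removed here.

    import asyncio
    from app.queue import enqueue_task, start_workers

    async def demo():
        start_workers(2)  # spawn worker coroutines on the running event loop
        result = await enqueue_task("grammar", {"text": "She go to school."}, timeout=10.0)
        print(result)     # correction result, or {"error": ...} on timeout

    asyncio.run(demo())
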
app/routers/analyze.py
CHANGED
@@ -1,45 +1,97 @@
-from app.core.security import verify_api_key
-from app.schemas.base import TextOnlyRequest
-from app.queue import task_queue
-import asyncio
-import uuid
+# app/routers/analyze.py
 import logging
+import asyncio
+from fastapi import APIRouter, Depends, HTTPException, status
+
+from app.schemas.base import TextOnlyRequest  # Assuming this Pydantic model exists
+from app.services.grammar import GrammarCorrector
+from app.services.tone_classification import ToneClassifier
+from app.services.inclusive_language import InclusiveLanguageChecker
+from app.services.voice_detection import VoiceDetector
+from app.services.readability import ReadabilityScorer
+from app.services.synonyms import SynonymSuggester
+from app.core.security import verify_api_key  # Assuming you still need API key verification
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError, ModelNotDownloadedError  # Import custom exceptions
+
+logger = logging.getLogger(f"{APP_NAME}.routers.analyze")
 
 router = APIRouter(prefix="/analyze", tags=["Analysis"])
+
+# Initialize service instances once per application lifecycle.
+# These services handle lazy loading of their models internally.
+grammar_service = GrammarCorrector()
+tone_service = ToneClassifier()
+inclusive_service = InclusiveLanguageChecker()
+voice_service = VoiceDetector()
+readability_service = ReadabilityScorer()
+synonyms_service = SynonymSuggester()
 
 @router.post("/", dependencies=[Depends(verify_api_key)])
-async def
+async def analyze_text_endpoint(payload: TextOnlyRequest):
+    """
+    Performs a comprehensive analysis of the provided text,
+    including grammar, tone, inclusive language, voice, readability, and synonyms.
+    """
     text = payload.text.strip()
     if not text:
-        raise HTTPException(status_code=
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received comprehensive analysis request for text (first 50 chars): '{text[:50]}...'")
+
+    # Define tasks to run concurrently
+    tasks = {
+        "grammar": grammar_service.correct(text),
+        "tone": tone_service.classify(text),
+        "inclusive_language": inclusive_service.check(text),
+        "voice": voice_service.classify(text),
+        "readability": readability_service.compute(text),
+        "synonyms": synonyms_service.suggest(text),
+    }
+
+    results = {}
+    coroutine_tasks = []
+    task_keys = []  # To map results back to their keys
+
+    for key, coroutine in tasks.items():
+        coroutine_tasks.append(coroutine)
+        task_keys.append(key)
+
+    # Run all tasks concurrently and handle potential exceptions for each
+    raw_results = await asyncio.gather(*coroutine_tasks, return_exceptions=True)
+
+    # Process results, handling errors gracefully for each sub-analysis
+    for i, result in enumerate(raw_results):
+        key = task_keys[i]
+        if isinstance(result, ModelNotDownloadedError):
+            logger.warning(f"Analysis for '{key}' skipped: Model '{result.model_id}' not downloaded. Detail: {result.detail}")
+            results[key] = {
+                "status": "skipped",
+                "message": result.detail,
+                "error_type": result.error_type,
+                "model_id": result.model_id,
+                "feature_name": result.feature_name
+            }
+        elif isinstance(result, ServiceError):
+            logger.error(f"Analysis for '{key}' failed with ServiceError. Detail: {result.detail}", exc_info=True)
+            results[key] = {
+                "status": "error",
+                "message": result.detail,
+                "error_type": result.error_type
+            }
+        elif isinstance(result, Exception):  # Catch any other unexpected exceptions from service methods
+            logger.exception(f"Analysis for '{key}' failed with unexpected error.")
+            results[key] = {
+                "status": "error",
+                "message": f"An unexpected error occurred: {str(result)}",
+                "error_type": "InternalServiceError"
+            }
+        else:
+            # On success, store the service's result under its key
+            # (each service is expected to return a dict)
+            results[key] = result
+
+    logger.info(f"Comprehensive analysis complete for text (first 50 chars): '{text[:50]}...'")
+    return {"analysis_results": results}
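Example (not part of the commit): calling the endpoint from a client. The API-key header name depends on app/core/security.verify_api_key, which is not shown in this diff, so "X-API-Key" is an assumption.

    import requests

    resp = requests.post(
        "http://127.0.0.1:8000/analyze/",
        json={"text": "The quick brown fox jumps over the lazy dog."},
        headers={"X-API-Key": "your_strong_api_key_here"},  # hypothetical header name
    )
    print(resp.status_code)
    print(resp.json()["analysis_results"].keys())
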
app/routers/grammar.py
CHANGED
@@ -1,36 +1,49 @@
-import asyncio
+# app/routers/grammar.py
 import logging
-from fastapi import APIRouter, Depends,
-from app.
-from app.
-from app.
+from fastapi import APIRouter, Depends, status, HTTPException  # HTTPException for 400 validation errors
+
+from app.schemas.base import TextOnlyRequest  # Assuming this Pydantic model exists
+from app.services.grammar import GrammarCorrector  # Import the service class
+from app.core.security import verify_api_key  # Assuming you still need API key verification
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # Important for catching specific service errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.grammar")
 
 router = APIRouter(prefix="/grammar", tags=["Grammar"])
-logger = logging.getLogger(__name__)
 
+# Initialize the service instance once per application lifecycle.
+# FastAPI handles dependency injection and lifecycle for routes,
+# so instantiate the service directly.
+grammar_corrector_service = GrammarCorrector()
+
+
+@router.post("/correct", dependencies=[Depends(verify_api_key)])  # Changed path to /correct for clarity
+async def correct_grammar_endpoint(payload: TextOnlyRequest):
+    """
+    Corrects grammar in the provided text.
+    """
     text = payload.text.strip()
     if not text:
-    future = asyncio.get_event_loop().create_future()
-    task_id = str(uuid.uuid4())[:8]
-        "type": "grammar",
-        "payload": {"text": text},
-        "future": future,
-        "id": task_id
-    })
-        status_code = 400 if "empty" in detail.lower() else 500
-        raise HTTPException(status_code=status_code, detail=detail)
+        # Use FastAPI's HTTPException for direct validation errors
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received grammar correction request for text (first 50 chars): '{text[:50]}...'")
+
+    try:
+        # Directly call the async service method.
+        # ModelNotDownloadedError will be raised here if the model is missing,
+        # and caught by the global exception handler in app/main.py.
+        result = await grammar_corrector_service.correct(text)
+
+        logger.info(f"Grammar correction successful for text (first 50 chars): '{text[:50]}...'")
+        return {"grammar_correction": result}
+
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        # This ensures consistent error responses across all services.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in grammar correction endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during grammar correction.") from e
CHANGED
@@ -1,11 +1,45 @@
|
|
1 |
-
|
|
|
|
|
|
|
2 |
from app.schemas.base import TextOnlyRequest
|
3 |
-
from app.services.inclusive_language import InclusiveLanguageChecker
|
4 |
-
from app.core.security import verify_api_key
|
|
|
|
|
|
|
|
|
5 |
|
6 |
router = APIRouter(prefix="/inclusive-language", tags=["Inclusive Language"])
|
7 |
-
checker = InclusiveLanguageChecker()
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app/routers/inclusive_language.py
|
2 |
+
import logging
|
3 |
+
from fastapi import APIRouter, Depends, HTTPException, status # Import HTTPException and status for validation
|
4 |
+
|
5 |
from app.schemas.base import TextOnlyRequest
|
6 |
+
from app.services.inclusive_language import InclusiveLanguageChecker # Import the service class
|
7 |
+
from app.core.security import verify_api_key # Assuming API key verification is still used
|
8 |
+
from app.core.config import APP_NAME # For logger naming
|
9 |
+
from app.core.exceptions import ServiceError # For re-raising internal errors
|
10 |
+
|
11 |
+
logger = logging.getLogger(f"{APP_NAME}.routers.inclusive_language")
|
12 |
|
13 |
router = APIRouter(prefix="/inclusive-language", tags=["Inclusive Language"])
|
|
|
14 |
|
15 |
+
# Initialize service instance once per application lifecycle
|
16 |
+
inclusive_language_checker_service = InclusiveLanguageChecker()
|
17 |
+
|
18 |
+
|
19 |
+
@router.post("/check", dependencies=[Depends(verify_api_key)]) # Added /check path for clarity
|
20 |
+
async def check_inclusive_language_endpoint(payload: TextOnlyRequest):
|
21 |
+
"""
|
22 |
+
Checks the provided text for inclusive language suggestions.
|
23 |
+
"""
|
24 |
+
text = payload.text.strip()
|
25 |
+
if not text:
|
26 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
|
27 |
+
|
28 |
+
logger.info(f"Received inclusive language check request for text (first 50 chars): '{text[:50]}...'")
|
29 |
+
|
30 |
+
try:
|
31 |
+
# Directly call the async service method
|
32 |
+
# ModelNotDownloadedError will be raised here if model is missing,
|
33 |
+
# and caught by the global exception handler in app/main.py
|
34 |
+
result = await inclusive_language_checker_service.check(text)
|
35 |
+
|
36 |
+
logger.info(f"Inclusive language check successful for text (first 50 chars): '{text[:50]}...'")
|
37 |
+
return {"inclusive_language": result}
|
38 |
+
|
39 |
+
except ServiceError as e:
|
40 |
+
# Re-raise ServiceError. It will be caught by the global exception handler.
|
41 |
+
raise e
|
42 |
+
except Exception as e:
|
43 |
+
# Catch any unexpected exceptions and re-raise as a generic ServiceError
|
44 |
+
logger.exception(f"Unhandled error in inclusive language check endpoint for text: '{text[:50]}...'")
|
45 |
+
raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during inclusive language checking.") from e
|
app/routers/paraphrase.py
CHANGED
@@ -1,24 +1,45 @@
-import
-from fastapi import APIRouter, Depends
+# app/routers/paraphrase.py
+import logging
+from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation
+
 from app.schemas.base import TextOnlyRequest
-from app.services.paraphrase import Paraphraser
+from app.services.paraphrase import Paraphraser  # Import the service class
 from app.core.security import verify_api_key
-from app.
+from app.core.config import APP_NAME  # For logger naming
+from app.core.exceptions import ServiceError  # For re-raising internal errors
+
+logger = logging.getLogger(f"{APP_NAME}.routers.paraphrase")
 
 router = APIRouter(prefix="/paraphrase", tags=["Paraphrase"])
+
+# Initialize the service instance once per application lifecycle
+paraphraser_service = Paraphraser()
+
+
+@router.post("/generate", dependencies=[Depends(verify_api_key)])
+async def paraphrase_text_endpoint(payload: TextOnlyRequest):
+    """
+    Generates a paraphrase for the provided text.
+    """
+    text = payload.text.strip()
+    if not text:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
+
+    logger.info(f"Received paraphrase request for text (first 50 chars): '{text[:50]}...'")
+
+    try:
+        # Directly call the async service method.
+        # ModelNotDownloadedError will be raised here if the model is missing,
+        # and caught by the global exception handler in app/main.py.
+        result = await paraphraser_service.paraphrase(text)
+
+        logger.info(f"Paraphrasing successful for text (first 50 chars): '{text[:50]}...'")
+        return {"paraphrase": result}  # Consistent key for response
+
+    except ServiceError as e:
+        # Re-raise ServiceError. It will be caught by the global exception handler.
+        raise e
+    except Exception as e:
+        # Catch any unexpected exceptions and re-raise as a generic ServiceError
+        logger.exception(f"Unhandled error in paraphrasing endpoint for text: '{text[:50]}...'")
+        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during paraphrasing.") from e
app/routers/readability.py
CHANGED
@@ -1,21 +1,43 @@
# app/routers/readability.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation

from app.schemas.base import TextOnlyRequest
from app.services.readability import ReadabilityScorer  # Import the service class
from app.core.security import verify_api_key
from app.core.config import APP_NAME  # For logger naming
from app.core.exceptions import ServiceError  # For re-raising internal errors

logger = logging.getLogger(f"{APP_NAME}.routers.readability")

router = APIRouter(prefix="/readability", tags=["Readability"])

# Initialize service instance once per application lifecycle
readability_scorer_service = ReadabilityScorer()


@router.post("/score", dependencies=[Depends(verify_api_key)])  # Added /score path for clarity
async def readability_score_endpoint(payload: TextOnlyRequest):
    """
    Computes various readability scores for the provided text.
    """
    text = payload.text.strip()
    if not text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")

    logger.info(f"Received readability scoring request for text (first 50 chars): '{text[:50]}...'")

    try:
        # Directly call the async service method.
        result = await readability_scorer_service.compute(text)

        logger.info(f"Readability scoring successful for text (first 50 chars): '{text[:50]}...'")
        return {"readability_scores": result}

    except ServiceError as e:
        # Re-raise ServiceError. It will be caught by the global exception handler.
        raise e
    except Exception as e:
        # Catch any unexpected exceptions and re-raise as a generic ServiceError.
        logger.exception(f"Unhandled error in readability scoring endpoint for text: '{text[:50]}...'")
        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during readability scoring.") from e
app/routers/rewrite.py
CHANGED
@@ -1,18 +1,59 @@
# app/routers/rewrite.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation

from app.schemas.base import RewriteRequest  # Assuming this Pydantic model exists
from app.services.gpt4_rewrite import GPT4Rewriter  # Import the service class
from app.core.security import verify_api_key  # Assuming API key verification is still used
from app.core.config import APP_NAME  # For logger naming
from app.core.exceptions import ServiceError  # For re-raising internal errors

logger = logging.getLogger(f"{APP_NAME}.routers.rewrite")

router = APIRouter(prefix="/rewrite", tags=["Rewrite"])

# Initialize service instance once per application lifecycle
gpt4_rewriter_service = GPT4Rewriter()


@router.post("/with_instruction", dependencies=[Depends(verify_api_key)])  # Changed path to /with_instruction for clarity
async def rewrite_with_instruction_endpoint(payload: RewriteRequest):
    """
    Rewrites the provided text based on a specific instruction using GPT-4.
    Requires an OpenAI API key.
    """
    text = payload.text.strip()
    instruction = payload.instruction.strip()
    user_api_key = payload.user_api_key  # The user's provided API key

    # Basic input validation for clarity, though the service also validates
    if not text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
    if not instruction:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Instruction cannot be empty.")
    if not user_api_key:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="OpenAI API key is required for this feature.")

    logger.info(f"Received rewrite request for text (first 50 chars): '{text[:50]}...' with instruction (first 50 chars): '{instruction[:50]}...'")

    try:
        # Directly call the async service method.
        # ServiceError will be raised here if there's an issue (e.g., missing API key, OpenAI API error),
        # and caught by the global exception handler in app/main.py.
        result = await gpt4_rewriter_service.rewrite(
            text=text,
            instruction=instruction,
            user_api_key=user_api_key  # Pass the user's API key
        )

        logger.info(f"Rewriting successful for text (first 50 chars): '{text[:50]}...'")
        return {"rewrite": result}  # Consistent key for response

    except ServiceError as e:
        # Re-raise ServiceError. It will be caught by the global exception handler.
        raise e
    except Exception as e:
        # Catch any unexpected exceptions and re-raise as a generic ServiceError.
        logger.exception(f"Unhandled error in rewriting endpoint for text: '{text[:50]}...'")
        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during rewriting.") from e
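The rewrite endpoint expects three fields, mirroring the attributes the handler reads from RewriteRequest (text, instruction, user_api_key). A sketch of a call, with the same header-name caveat as above:

# Hypothetical client call for POST /rewrite/with_instruction.
import requests

resp = requests.post(
    "http://localhost:8000/rewrite/with_instruction",
    json={
        "text": "The meeting was moved to Friday.",
        "instruction": "Rewrite this in a formal tone.",
        "user_api_key": "sk-your-openai-key",  # the caller's own OpenAI key
    },
    headers={"X-API-Key": "your-api-key"},  # header name is an assumption
    timeout=60,
)
print(resp.json()["rewrite"])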
app/routers/synonyms.py
CHANGED
@@ -1,21 +1,45 @@
# app/routers/synonyms.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation

from app.schemas.base import TextOnlyRequest
from app.services.synonyms import SynonymSuggester  # Import the service class
from app.core.security import verify_api_key  # Assuming API key verification is still used
from app.core.config import APP_NAME  # For logger naming
from app.core.exceptions import ServiceError  # For re-raising internal errors

logger = logging.getLogger(f"{APP_NAME}.routers.synonyms")

router = APIRouter(prefix="/synonyms", tags=["Synonyms"])

# Initialize service instance once per application lifecycle
synonym_suggester_service = SynonymSuggester()


@router.post("/suggest", dependencies=[Depends(verify_api_key)])  # Added /suggest path for clarity
async def suggest_synonyms_endpoint(payload: TextOnlyRequest):
    """
    Suggests synonyms for words in the provided text.
    """
    text = payload.text.strip()
    if not text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")

    logger.info(f"Received synonym suggestion request for text (first 50 chars): '{text[:50]}...'")

    try:
        # Directly call the async service method.
        # ModelNotDownloadedError will be raised here if the model/data is missing,
        # and caught by the global exception handler in app/main.py.
        result = await synonym_suggester_service.suggest(text)

        logger.info(f"Synonym suggestion successful for text (first 50 chars): '{text[:50]}...'")
        return {"synonyms": result}  # Consistent key for response

    except ServiceError as e:
        # Re-raise ServiceError. It will be caught by the global exception handler.
        raise e
    except Exception as e:
        # Catch any unexpected exceptions and re-raise as a generic ServiceError.
        logger.exception(f"Unhandled error in synonym suggestion endpoint for text: '{text[:50]}...'")
        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during synonym suggestion.") from e
app/routers/tone.py
CHANGED
@@ -1,24 +1,45 @@
# app/routers/tone.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation

from app.schemas.base import TextOnlyRequest
from app.services.tone_classification import ToneClassifier  # Import the service class
from app.core.security import verify_api_key  # Assuming API key verification is still used
from app.core.config import APP_NAME  # For logger naming
from app.core.exceptions import ServiceError  # For re-raising internal errors

logger = logging.getLogger(f"{APP_NAME}.routers.tone")

router = APIRouter(prefix="/tone", tags=["Tone"])

# Initialize service instance once per application lifecycle
tone_classifier_service = ToneClassifier()


@router.post("/classify", dependencies=[Depends(verify_api_key)])  # Added /classify path for clarity
async def classify_tone_endpoint(payload: TextOnlyRequest):
    """
    Classifies the tone of the provided text.
    """
    text = payload.text.strip()
    if not text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")

    logger.info(f"Received tone classification request for text (first 50 chars): '{text[:50]}...'")

    try:
        # Directly call the async service method.
        # ModelNotDownloadedError will be raised here if the model is missing,
        # and caught by the global exception handler in app/main.py.
        result = await tone_classifier_service.classify(text)

        logger.info(f"Tone classification successful for text (first 50 chars): '{text[:50]}...'")
        return {"tone_classification": result}  # Consistent key for response

    except ServiceError as e:
        # Re-raise ServiceError. It will be caught by the global exception handler.
        raise e
    except Exception as e:
        # Catch any unexpected exceptions and re-raise as a generic ServiceError.
        logger.exception(f"Unhandled error in tone classification endpoint for text: '{text[:50]}...'")
        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during tone classification.") from e
app/routers/translate.py
CHANGED
@@ -1,24 +1,48 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, status

from app.schemas.base import TranslateRequest
from app.services.translation import Translator
from app.core.security import verify_api_key
from app.core.config import APP_NAME
from app.core.exceptions import ServiceError

logger = logging.getLogger(f"{APP_NAME}.routers.translate")

router = APIRouter(prefix="/translate", tags=["Translate"])


translator_service = Translator()


@router.post("/", dependencies=[Depends(verify_api_key)])
async def translate_text_endpoint(payload: TranslateRequest):
    """
    Translates the provided text to a specified target language.
    """
    text = payload.text.strip()
    target_lang = payload.target_lang.strip()

    if not text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")
    if not target_lang:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Target language cannot be empty.")

    logger.info(f"Received translation request for text (first 50 chars): '{text[:50]}...' to '{target_lang}'")

    try:
        # Directly call the async service method.
        # ModelNotDownloadedError will be raised here if the model is missing,
        # and caught by the global exception handler in app/main.py.
        result = await translator_service.translate(text, target_lang)

        logger.info(f"Translation successful for text (first 50 chars): '{text[:50]}...' to '{target_lang}'")
        return {"translation": result}  # Consistent key for response

    except ServiceError as e:
        # Re-raise ServiceError. It will be caught by the global exception handler.
        raise e
    except Exception as e:
        # Catch any unexpected exceptions and re-raise as a generic ServiceError.
        logger.exception(f"Unhandled error in translation endpoint for text: '{text[:50]}...' to '{target_lang}'")
        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during translation.") from e
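A minimal sketch of exercising the Translator service directly (run from the project root); it assumes target_lang is a language code the configured translation model accepts, since the Translator implementation itself is in app/services/translation.py, changed elsewhere in this commit:

import asyncio

from app.services.translation import Translator

async def main():
    translator = Translator()
    # "fr" as a target code is an assumption; valid codes depend on the model.
    result = await translator.translate("Good morning", "fr")
    print(result)

asyncio.run(main())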
app/routers/voice.py
CHANGED
@@ -1,11 +1,45 @@
# app/routers/voice.py
import logging
from fastapi import APIRouter, Depends, HTTPException, status  # Import HTTPException and status for validation

from app.schemas.base import TextOnlyRequest
from app.services.voice_detection import VoiceDetector  # Import the service class
from app.core.security import verify_api_key  # Assuming API key verification is still used
from app.core.config import APP_NAME  # For logger naming
from app.core.exceptions import ServiceError  # For re-raising internal errors

logger = logging.getLogger(f"{APP_NAME}.routers.voice")

router = APIRouter(prefix="/voice", tags=["Voice"])

# Initialize service instance once per application lifecycle
voice_detector_service = VoiceDetector()


@router.post("/detect", dependencies=[Depends(verify_api_key)])  # Added /detect path for clarity
async def detect_voice_endpoint(payload: TextOnlyRequest):
    """
    Detects the voice (active or passive) of the provided text.
    """
    text = payload.text.strip()
    if not text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Input text cannot be empty.")

    logger.info(f"Received voice detection request for text (first 50 chars): '{text[:50]}...'")

    try:
        # Directly call the async service method.
        # ModelNotDownloadedError will be raised here if the model is missing,
        # and caught by the global exception handler in app/main.py.
        result = await voice_detector_service.classify(text)

        logger.info(f"Voice detection successful for text (first 50 chars): '{text[:50]}...'")
        return {"voice_detection": result}  # Consistent key for response

    except ServiceError as e:
        # Re-raise ServiceError. It will be caught by the global exception handler.
        raise e
    except Exception as e:
        # Catch any unexpected exceptions and re-raise as a generic ServiceError.
        logger.exception(f"Unhandled error in voice detection endpoint for text: '{text[:50]}...'")
        raise ServiceError(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during voice detection.") from e
app/services/base.py
CHANGED
@@ -1,49 +1,132 @@
import logging
from pathlib import Path
from functools import lru_cache

import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    AutoModelForMaskedLM,
)

from sentence_transformers import SentenceTransformer

from app.core.config import (
    MODELS_DIR, SPACY_MODEL_ID, SENTENCE_TRANSFORMER_MODEL_ID,
    OFFLINE_MODE
)
from app.core.exceptions import ModelNotDownloadedError

logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────────────────────────────────────
# 🧠 SpaCy
# ─────────────────────────────────────────────────────────────────────────────

@lru_cache(maxsize=1)
def load_spacy_model(model_id: str = SPACY_MODEL_ID):
    import spacy
    from spacy.util import is_package

    logger.info(f"Loading spaCy model: {model_id}")

    if is_package(model_id):
        return spacy.load(model_id)

    possible_path = MODELS_DIR / model_id
    if possible_path.exists():
        return spacy.load(str(possible_path))

    raise RuntimeError(f"Could not find spaCy model '{model_id}' at {possible_path}")

# ─────────────────────────────────────────────────────────────────────────────
# 🔤 Sentence Transformers
# ─────────────────────────────────────────────────────────────────────────────

@lru_cache(maxsize=1)
def load_sentence_transformer_model(model_id: str = SENTENCE_TRANSFORMER_MODEL_ID) -> SentenceTransformer:
    logger.info(f"Loading SentenceTransformer: {model_id}")
    return SentenceTransformer(model_name_or_path=model_id, cache_folder=MODELS_DIR)

# ─────────────────────────────────────────────────────────────────────────────
# 🤗 Hugging Face Pipelines (T5 models, classifiers, etc.)
# ─────────────────────────────────────────────────────────────────────────────

def _check_model_downloaded(model_id: str, cache_dir: str) -> bool:
    model_path = Path(cache_dir) / model_id.replace("/", "_")
    return model_path.exists()

def _timed_load(name: str, fn):
    import time
    start = time.time()
    model = fn()
    elapsed = round(time.time() - start, 2)
    logger.info(f"[{name}] model loaded in {elapsed}s")
    return model

@lru_cache(maxsize=2)
def load_hf_pipeline(model_id: str, task: str, feature_name: str, **kwargs):
    if OFFLINE_MODE and not _check_model_downloaded(model_id, str(MODELS_DIR)):
        raise ModelNotDownloadedError(model_id, feature_name, "Model not found locally in offline mode.")

    try:
        # Choose the appropriate AutoModel loader based on the task
        if task == "text-classification":
            model_loader = AutoModelForSequenceClassification
        elif task == "text2text-generation" or task.startswith("translation"):
            model_loader = AutoModelForSeq2SeqLM
        elif task == "fill-mask":
            model_loader = AutoModelForMaskedLM
        else:
            raise ValueError(f"Unsupported task type '{task}' for feature '{feature_name}'.")

        model = _timed_load(
            f"{feature_name}:{model_id} (model)",
            lambda: model_loader.from_pretrained(
                model_id,
                cache_dir=MODELS_DIR,
                local_files_only=OFFLINE_MODE
            )
        )

        tokenizer = _timed_load(
            f"{feature_name}:{model_id} (tokenizer)",
            lambda: AutoTokenizer.from_pretrained(
                model_id,
                cache_dir=MODELS_DIR,
                local_files_only=OFFLINE_MODE
            )
        )

        return pipeline(
            task=task,
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            **kwargs
        )

    except Exception as e:
        logger.error(f"Failed to load pipeline for '{feature_name}' - {model_id}: {e}", exc_info=True)
        raise ModelNotDownloadedError(model_id, feature_name, str(e))

# ─────────────────────────────────────────────────────────────────────────────
# 📚 NLTK
# ─────────────────────────────────────────────────────────────────────────────

@lru_cache(maxsize=1)
def ensure_nltk_resource(resource_name: str = "wordnet") -> None:
    import nltk  # imported before the try block so the download path below always has it

    try:
        nltk.data.find(f"corpora/{resource_name}")
    except LookupError:
        if OFFLINE_MODE:
            raise RuntimeError(f"NLTK resource '{resource_name}' not found in offline mode.")
        nltk.download(resource_name)

# ─────────────────────────────────────────────────────────────────────────────
# 🎯 Ready-to-use Loaders (for your app use)
# ─────────────────────────────────────────────────────────────────────────────
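The loaders above rely on functools.lru_cache to behave as per-process singletons; a minimal sketch of that behavior:

from functools import lru_cache

@lru_cache(maxsize=1)
def load_expensive_resource(name: str) -> dict:
    print(f"loading {name}...")  # runs only on the first call
    return {"name": name}

a = load_expensive_resource("en_core_web_sm")
b = load_expensive_resource("en_core_web_sm")
assert a is b  # cached: the same object, loaded once per process

Note that load_hf_pipeline uses maxsize=2, so only the two most recently used (model_id, task) combinations stay resident; requesting a third distinct pipeline evicts the least recently used one.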
app/services/gpt4_rewrite.py
CHANGED
@@ -1,46 +1,76 @@
import openai
import logging
import asyncio
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from app.core.config import settings, APP_NAME
from app.core.exceptions import ServiceError

logger = logging.getLogger(f"{APP_NAME}.services.gpt4_rewrite")

class GPT4Rewriter:
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        retry=retry_if_exception_type(openai.APIError)
    )
    async def rewrite(self, text: str, user_api_key: str, instruction: str) -> dict:
        try:
            if not user_api_key:
                raise ServiceError(status_code=401, detail="OpenAI API key is missing. Please provide your key to use this feature.")

            text = text.strip()
            instruction = instruction.strip()

            if not text:
                raise ServiceError(status_code=400, detail="Input text is empty for rewriting.")
            if not instruction:
                raise ServiceError(status_code=400, detail="Rewrite instruction is missing.")

            messages = [
                {"role": "system", "content": instruction},
                {"role": "user", "content": text},
            ]

            def _call_openai_api():
                client = openai.OpenAI(api_key=user_api_key)
                response = client.chat.completions.create(
                    model=settings.OPENAI_MODEL,
                    messages=messages,
                    temperature=settings.OPENAI_TEMPERATURE,
                    max_tokens=settings.OPENAI_MAX_TOKENS
                )
                return response.choices[0].message.content.strip()

            result = await asyncio.to_thread(_call_openai_api)
            return {"rewritten_text": result}

        except openai.APIStatusError as e:
            logger.error(f"OpenAI API status error: {e.status_code} - {e.response}", exc_info=True)
            detail_message = "An OpenAI API error occurred."
            if e.status_code == 401:
                detail_message = "Invalid OpenAI API key. Please check your key."
            elif e.status_code == 429:
                detail_message = "OpenAI API rate limit exceeded or quota exhausted. Please try again later."
            elif e.status_code == 400:
                detail_message = f"OpenAI API request error: {e.response.json().get('detail', e.message)}"

            raise ServiceError(status_code=e.status_code, detail=detail_message) from e

        except openai.APITimeoutError as e:
            logger.error(f"OpenAI API timeout error: {e}", exc_info=True)
            raise ServiceError(status_code=504, detail="OpenAI API request timed out. Please try again.") from e

        except openai.APIConnectionError as e:
            logger.error(f"OpenAI API connection error: {e}", exc_info=True)
            raise ServiceError(status_code=503, detail="Could not connect to OpenAI API. Please check your internet connection.") from e

        except openai.APIError as e:
            logger.error(f"OpenAI API error: {e}", exc_info=True)
            raise ServiceError(status_code=500, detail=f"An unexpected OpenAI API error occurred: {str(e)}") from e

        except ServiceError as e:
            raise e

        except Exception as e:
            logger.error(f"Unexpected error in GPT-4 rewrite for text: '{text[:50]}...'", exc_info=True)
            raise ServiceError(status_code=500, detail="An unexpected error occurred during rewriting.") from e
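The blocking OpenAI client call is pushed off the event loop with asyncio.to_thread, and tenacity's retry decorator supports coroutine functions, so the exponential backoff applies across awaited attempts. The offloading pattern in miniature:

import asyncio
import time

def blocking_call() -> str:
    time.sleep(1)  # stands in for the synchronous OpenAI HTTP round trip
    return "done"

async def handler() -> str:
    # The event loop stays free to serve other requests meanwhile.
    return await asyncio.to_thread(blocking_call)

print(asyncio.run(handler()))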
app/services/grammar.py
CHANGED
@@ -1,78 +1,77 @@
import difflib
import logging
from typing import List

import torch

from app.services.base import load_hf_pipeline
from app.core.config import settings
from app.core.exceptions import ServiceError

logger = logging.getLogger(f"{settings.APP_NAME}.services.grammar")


class GrammarCorrector:
    def __init__(self):
        self._pipeline = None

    def _get_pipeline(self):
        if self._pipeline is None:
            logger.info("Loading grammar correction pipeline...")
            self._pipeline = load_hf_pipeline(
                model_id=settings.GRAMMAR_MODEL_ID,
                task="text2text-generation",
                feature_name="Grammar Correction"
            )
        return self._pipeline

    async def correct(self, text: str) -> dict:
        text = text.strip()
        if not text:
            raise ServiceError(status_code=400, detail="Input text is empty for grammar correction.")

        try:
            pipeline = self._get_pipeline()

            result = pipeline(text, max_length=512, num_beams=4, early_stopping=True)
            corrected = result[0]["generated_text"].strip()

            if not corrected:
                raise ServiceError(status_code=500, detail="Failed to decode grammar correction output.")

            issues = self.get_diff_issues(text, corrected)

            return {
                "original_text": text,
                "corrected_text_suggestion": corrected,
                "issues": issues
            }

        except Exception as e:
            logger.error(f"Grammar correction error for input: '{text[:50]}...'", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during grammar correction.") from e

    def get_diff_issues(self, original: str, corrected: str) -> List[dict]:
        def safe_slice(s: str, start: int, end: int) -> str:
            return s[max(0, start):min(len(s), end)]

        matcher = difflib.SequenceMatcher(None, original, corrected)
        issues = []

        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag == "equal":
                continue

            issues.append({
                "offset": i1,
                "length": i2 - i1,
                "original_segment": original[i1:i2],
                "suggested_segment": corrected[j1:j2],
                "context_before": safe_slice(original, i1 - 15, i1),
                "context_after": safe_slice(original, i2, i2 + 15),
                "message": "Grammar correction",
                "line": original[:i1].count("\n") + 1,
                "column": (i1 - original[:i1].rfind("\n") - 1) if "\n" in original[:i1] else i1 + 1
            })

        return issues
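get_diff_issues walks difflib.SequenceMatcher opcodes; a worked example of what those (tag, i1, i2, j1, j2) tuples look like for a small correction:

import difflib

original = "She go to school."
corrected = "She goes to school."

matcher = difflib.SequenceMatcher(None, original, corrected)
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
    if tag == "equal":
        continue
    print(tag, repr(original[i1:i2]), "->", repr(corrected[j1:j2]))
# Typically prints: insert '' -> 'es'
# (the exact opcodes depend on difflib's matching heuristics).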
app/services/inclusive_language.py
CHANGED
@@ -1,72 +1,120 @@
import logging
import yaml
from pathlib import Path
from typing import List, Dict

from app.services.base import load_spacy_model
from app.core.config import settings, APP_NAME, SPACY_MODEL_ID
from app.core.exceptions import ServiceError

logger = logging.getLogger(f"{APP_NAME}.services.inclusive_language")


class InclusiveLanguageChecker:
    def __init__(self, rules_directory: str = settings.INCLUSIVE_RULES_DIR):
        self._nlp = None
        self.matcher = None
        self.rules = self._load_inclusive_rules(Path(rules_directory))

    def _load_inclusive_rules(self, rules_path: Path) -> Dict[str, Dict]:
        """
        Load YAML-based inclusive language rules from the given directory.
        """
        if not rules_path.is_dir():
            logger.error(f"Inclusive language rules directory not found: {rules_path}")
            raise ServiceError(
                status_code=500,
                detail=f"Inclusive language rules directory not found: {rules_path}"
            )

        rules = {}
        for yaml_file in rules_path.glob("*.yml"):
            try:
                with yaml_file.open(encoding="utf-8") as f:
                    rule_list = yaml.safe_load(f)

                if not isinstance(rule_list, list):
                    logger.warning(f"Skipping non-list rule file: {yaml_file}")
                    continue

                for rule in rule_list:
                    inconsiderate = rule.get("inconsiderate", [])
                    considerate = rule.get("considerate", [])
                    note = rule.get("note", "")
                    source = rule.get("source", "")
                    rule_type = rule.get("type", "basic")

                    # Ensure consistent formatting
                    if isinstance(considerate, str):
                        considerate = [considerate]
                    if isinstance(inconsiderate, str):
                        inconsiderate = [inconsiderate]

                    for phrase in inconsiderate:
                        rules[phrase.lower()] = {
                            "considerate": considerate,
                            "note": note,
                            "source": source,
                            "type": rule_type
                        }

            except Exception as e:
                logger.error(f"Error loading rule file {yaml_file}: {e}", exc_info=True)
                raise ServiceError(
                    status_code=500,
                    detail=f"Failed to load inclusive language rules: {e}"
                )

        logger.info(f"Loaded {len(rules)} inclusive language rules from {rules_path}")
        return rules

    def _get_nlp(self):
        """
        Lazy-loads the spaCy model for NLP processing.
        """
        if self._nlp is None:
            self._nlp = load_spacy_model(SPACY_MODEL_ID)
        return self._nlp

    def _init_matcher(self, nlp):
        """
        Initializes spaCy PhraseMatcher using loaded rules.
        """
        from spacy.matcher import PhraseMatcher

        matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
        for phrase in self.rules:
            matcher.add(phrase, [nlp.make_doc(phrase)])

        logger.info(f"PhraseMatcher initialized with {len(self.rules)} phrases.")
        return matcher

    async def check(self, text: str) -> dict:
        """
        Checks a string for non-inclusive language based on rule definitions.
        """
        text = text.strip()
        if not text:
            raise ServiceError(status_code=400, detail="Input text is empty for inclusive language check.")

        try:
            nlp = self._get_nlp()
            if self.matcher is None:
                self.matcher = self._init_matcher(nlp)

            doc = nlp(text)
            matches = self.matcher(doc)
            results = []
            matched_spans = set()

            # Match exact phrases
            for match_id, start, end in matches:
                phrase = nlp.vocab.strings[match_id].lower()
                if any(s <= start < e or s < end <= e for s, e in matched_spans):
                    continue  # Avoid overlapping matches

                matched_spans.add((start, end))
                rule = self.rules.get(phrase)
                if rule:
                    results.append({
                        "term": doc[start:end].text,
@@ -74,15 +122,18 @@ class InclusiveLanguageChecker:
                        "note": rule["note"],
                        "suggestions": rule["considerate"],
                        "context": doc[start:end].sent.text,
                        "start_char": doc[start].idx,
                        "end_char": doc[end - 1].idx + len(doc[end - 1]),
                        "source": rule["source"]
                    })

            # Match individual token lemmas (fallback)
            for token in doc:
                lemma = token.lemma_.lower()
                if (token.i, token.i + 1) in matched_spans:
                    continue  # Already matched in phrase

                if lemma in self.rules:
                    rule = self.rules[lemma]
                    results.append({
                        "term": token.text,
@@ -90,15 +141,16 @@ class InclusiveLanguageChecker:
                        "note": rule["note"],
                        "suggestions": rule["considerate"],
                        "context": token.sent.text,
                        "start_char": token.idx,
                        "end_char": token.idx + len(token),
                        "source": rule["source"]
                    })

            return {"issues": results}

        except Exception as e:
            logger.error(f"Inclusive language check error for text: '{text[:50]}...'", exc_info=True)
            raise ServiceError(
                status_code=500,
                detail="An internal error occurred during inclusive language checking."
            ) from e
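The rule schema is inferred from the keys the loader reads; a sketch of what yaml.safe_load should yield per file, and how the loader indexes it (field values here are illustrative, not taken from the repo's rule files):

sample_rule_list = [
    {
        "inconsiderate": ["whitelist"],   # flagged phrases
        "considerate": ["allowlist"],     # suggested replacements
        "note": "Prefer allowlist/denylist.",
        "source": "example-style-guide",  # hypothetical source label
        "type": "basic",
    }
]

# Same normalization as _load_inclusive_rules: index by lowercased phrase.
rules = {
    phrase.lower(): {
        "considerate": rule["considerate"],
        "note": rule["note"],
        "source": rule["source"],
        "type": rule["type"],
    }
    for rule in sample_rule_list
    for phrase in rule["inconsiderate"]
}
print(rules["whitelist"]["considerate"])  # ['allowlist']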
app/services/paraphrase.py
CHANGED
@@ -1,44 +1,40 @@
import logging

from app.services.base import load_hf_pipeline
from app.core.config import settings, APP_NAME
from app.core.exceptions import ServiceError

logger = logging.getLogger(f"{APP_NAME}.services.paraphrase")


class Paraphraser:
    def __init__(self):
        self._pipeline = None

    def _get_pipeline(self):
        if self._pipeline is None:
            logger.info("Loading paraphrasing pipeline...")
            self._pipeline = load_hf_pipeline(
                model_id=settings.PARAPHRASE_MODEL_ID,
                task="text2text-generation",
                feature_name="Paraphrasing"
            )
        return self._pipeline

    async def paraphrase(self, text: str) -> dict:
        text = text.strip()
        if not text:
            raise ServiceError(status_code=400, detail="Input text is empty for paraphrasing.")

        try:
            pipeline = self._get_pipeline()
            prompt = f"paraphrase: {text} </s>"

            results = pipeline(prompt, max_length=256, num_beams=5, num_return_sequences=1, early_stopping=True)
            paraphrased = results[0]["generated_text"].strip()

            return {"paraphrased_text": paraphrased}

        except Exception as e:
            logger.error(f"Paraphrasing error for text: '{text[:50]}...'", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during paraphrasing.") from e
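A sketch of driving the service outside FastAPI (run from the project root). The "paraphrase: ... </s>" prompt format assumes PARAPHRASE_MODEL_ID points at a T5-style paraphrase checkpoint, which the prompt construction above implies:

import asyncio

from app.services.paraphrase import Paraphraser

async def main():
    service = Paraphraser()  # pipeline loads lazily on first call
    result = await service.paraphrase("The quick brown fox jumps over the lazy dog.")
    print(result["paraphrased_text"])

asyncio.run(main())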
app/services/readability.py
CHANGED
@@ -1,17 +1,18 @@
# app/services/readability.py
import textstat
import logging
from app.core.config import APP_NAME
from app.core.exceptions import ServiceError

logger = logging.getLogger(f"{APP_NAME}.services.readability")

class ReadabilityScorer:
    async def compute(self, text: str) -> dict:
        try:
            text = text.strip()
            if not text:
                raise ServiceError(status_code=400, detail="Input text is empty for readability scoring.")

            scores = {
                "flesch_reading_ease": textstat.flesch_reading_ease(text),
                "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
@@ -21,41 +22,39 @@ class ReadabilityScorer:
                "automated_readability_index": textstat.automated_readability_index(text),
            }

            friendly_scores = {
                "flesch_reading_ease": {
                    "score": round(scores["flesch_reading_ease"], 2),
                    "label": "Flesch Reading Ease",
                    "description": "Higher is easier. 60–70 is plain English; 90+ is very easy."
                },
                "flesch_kincaid_grade": {
                    "score": round(scores["flesch_kincaid_grade"], 2),
                    "label": "Flesch-Kincaid Grade Level",
                    "description": "U.S. school grade. 8.0 means an 8th grader can understand it."
                },
                "gunning_fog_index": {
                    "score": round(scores["gunning_fog_index"], 2),
                    "label": "Gunning Fog Index",
                    "description": "Estimates years of formal education needed to understand."
                },
                "smog_index": {
                    "score": round(scores["smog_index"], 2),
                    "label": "SMOG Index",
                    "description": "Also estimates required years of education."
                },
                "coleman_liau_index": {
                    "score": round(scores["coleman_liau_index"], 2),
                    "label": "Coleman-Liau Index",
                    "description": "Grade level based on characters, not syllables."
                },
                "automated_readability_index": {
                    "score": round(scores["automated_readability_index"], 2),
                    "label": "Automated Readability Index",
                    "description": "Grade level using word and sentence lengths."
                }
            }

            ease_score = scores["flesch_reading_ease"]
            if ease_score >= 90:
                summary = "Very easy to read. Easily understood by 11-year-olds."
@@ -68,13 +67,13 @@ class ReadabilityScorer:
            else:
                summary = "Very difficult. Best understood by university graduates."

            return {
                "readability_summary": summary,
                "scores": friendly_scores
            }

        except Exception as e:
            logger.error(f"Readability scoring error for text: '{text[:50]}...'", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during readability scoring.") from e

# You can continue pasting the rest of your services here for production hardening
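For a feel of the raw numbers the service rounds and labels, textstat can be called directly:

import textstat

sample = "The cat sat on the mat. It was a sunny day."
print(textstat.flesch_reading_ease(sample))   # high score: very easy text
print(textstat.flesch_kincaid_grade(sample))  # low U.S. grade level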
app/services/synonyms.py
CHANGED
@@ -1,21 +1,27 @@
|
|
1 |
import logging
|
2 |
import asyncio
|
3 |
-
from nltk.corpus import wordnet
|
4 |
-
from transformers import AutoTokenizer
|
5 |
-
from sentence_transformers import SentenceTransformer, util
|
6 |
from typing import List, Dict
|
7 |
from functools import lru_cache
|
8 |
|
9 |
-
# Assuming these are available in your project structure
|
10 |
from app.services.base import (
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
)
|
14 |
-
from app.core.
|
|
|
|
|
|
|
15 |
|
16 |
-
logger = logging.getLogger(
|
17 |
|
18 |
-
# Mapping spaCy POS tags to WordNet POS tags
|
19 |
SPACY_TO_WORDNET_POS = {
|
20 |
"NOUN": wordnet.NOUN,
|
21 |
"VERB": wordnet.VERB,
|
@@ -23,140 +29,129 @@ SPACY_TO_WORDNET_POS = {
|
|
23 |
"ADV": wordnet.ADV,
|
24 |
}
|
25 |
|
26 |
-
# Only consider these POS tags for synonym suggestions
|
27 |
CONTENT_POS_TAGS = {"NOUN", "VERB", "ADJ", "ADV"}
|
28 |
-
|
29 |
-
DEFAULT_BATCH_SIZE = settings.SENTENCE_TRANSFORMER_BATCH_SIZE if hasattr(settings, 'SENTENCE_TRANSFORMER_BATCH_SIZE') else 32
|
30 |
|
31 |
class SynonymSuggester:
|
32 |
def __init__(self):
|
33 |
-
self.
|
34 |
-
self.
|
35 |
-
|
36 |
-
def
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
model = timed_model_load(
|
41 |
-
"sentence_transformer",
|
42 |
-
lambda: SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
|
43 |
)
|
44 |
-
|
45 |
-
return get_cached_model("synonym_sentence_model", load_fn)
|
46 |
|
47 |
-
def
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
51 |
|
52 |
async def suggest(self, text: str) -> dict:
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
not token.is_punct and
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
}
|
146 |
-
|
147 |
-
return model_response(result=final_suggestions)
|
148 |
|
149 |
@lru_cache(maxsize=5000)
|
150 |
def _get_wordnet_synonyms_cached(self, word: str, pos: str) -> List[str]:
|
151 |
-
"""
|
152 |
-
Retrieves synonyms for a word from WordNet, filtered by Part-of-Speech.
|
153 |
-
"""
|
154 |
synonyms = set()
|
155 |
-
for syn in wordnet.synsets(word, pos=pos):
|
156 |
for lemma in syn.lemmas():
|
157 |
name = lemma.name().replace("_", " ").lower()
|
158 |
-
# Basic filtering for valid word forms
|
159 |
if name.isalpha() and len(name) > 1:
|
160 |
synonyms.add(name)
|
161 |
-
synonyms.discard(word.lower())
|
162 |
-
return
|
|
|
import logging
import asyncio

from typing import List, Dict
from functools import lru_cache

from app.services.base import (
    load_spacy_model,
    load_sentence_transformer_model,
    ensure_nltk_resource
)
from app.core.config import (
    settings,
    APP_NAME,
    SPACY_MODEL_ID,
    WORDNET_NLTK_ID,
    SENTENCE_TRANSFORMER_MODEL_ID
)
from app.core.exceptions import ServiceError, ModelNotDownloadedError

from nltk.corpus import wordnet
from sentence_transformers.util import cos_sim

logger = logging.getLogger(f"{APP_NAME}.services.synonyms")

SPACY_TO_WORDNET_POS = {
    "NOUN": wordnet.NOUN,
    "VERB": wordnet.VERB,
    "ADJ": wordnet.ADJ,
    "ADV": wordnet.ADV,
}

CONTENT_POS_TAGS = {"NOUN", "VERB", "ADJ", "ADV"}


class SynonymSuggester:
    def __init__(self):
        self._sentence_model = None
        self._nlp = None

    def _get_sentence_model(self):
        if self._sentence_model is None:
            self._sentence_model = load_sentence_transformer_model(
                SENTENCE_TRANSFORMER_MODEL_ID
            )
        return self._sentence_model

    def _get_nlp(self):
        if self._nlp is None:
            self._nlp = load_spacy_model(SPACY_MODEL_ID)
        return self._nlp

    async def suggest(self, text: str) -> dict:
        try:
            text = text.strip()
            if not text:
                raise ServiceError(status_code=400, detail="Input text is empty for synonym suggestion.")

            sentence_model = self._get_sentence_model()
            nlp = self._get_nlp()
            await asyncio.to_thread(ensure_nltk_resource, WORDNET_NLTK_ID)

            doc = await asyncio.to_thread(nlp, text)
            all_suggestions: Dict[str, List[str]] = {}

            original_text_embedding = await asyncio.to_thread(
                sentence_model.encode, text,
                convert_to_tensor=True,
                normalize_embeddings=True
            )

            candidate_data = []

            for token in doc:
                if token.pos_ in CONTENT_POS_TAGS and len(token.text.strip()) > 2 and not token.is_punct and not token.is_space:
                    original_word = token.text
                    word_start = token.idx
                    word_end = token.idx + len(original_word)
                    wordnet_pos = SPACY_TO_WORDNET_POS.get(token.pos_)
                    if not wordnet_pos:
                        continue

                    wordnet_candidates = await asyncio.to_thread(
                        self._get_wordnet_synonyms_cached, original_word, wordnet_pos
                    )
                    if not wordnet_candidates:
                        continue

                    if original_word not in all_suggestions:
                        all_suggestions[original_word] = []

                    for candidate in wordnet_candidates:
                        temp_sentence = text[:word_start] + candidate + text[word_end:]
                        candidate_data.append({
                            "original_word": original_word,
                            "wordnet_candidate": candidate,
                            "temp_sentence": temp_sentence,
                        })

            if not candidate_data:
                return {"suggestions": {}}

            all_candidate_sentences = [c["temp_sentence"] for c in candidate_data]
            all_candidate_embeddings = await asyncio.to_thread(
                sentence_model.encode,
                all_candidate_sentences,
                batch_size=settings.SENTENCE_TRANSFORMER_BATCH_SIZE,
                convert_to_tensor=True,
                normalize_embeddings=True
            )

            if original_text_embedding.dim() == 1:
                original_text_embedding = original_text_embedding.unsqueeze(0)

            cosine_scores = cos_sim(original_text_embedding, all_candidate_embeddings)[0]

            similarity_threshold = 0.65
            top_n = 5
            temp_scored: Dict[str, List[tuple]] = {word: [] for word in all_suggestions}

            for i, data in enumerate(candidate_data):
                word = data["original_word"]
                candidate = data["wordnet_candidate"]
                score = cosine_scores[i].item()
                if score >= similarity_threshold and candidate.lower() != word.lower():
                    temp_scored[word].append((score, candidate))

            final_suggestions = {}
            for word, scored in temp_scored.items():
                if scored:
                    sorted_unique = []
                    seen = set()
                    for score, candidate in sorted(scored, key=lambda x: x[0], reverse=True):
                        if candidate not in seen:
                            sorted_unique.append(candidate)
                            seen.add(candidate)
                        if len(sorted_unique) >= top_n:
                            break
                    final_suggestions[word] = sorted_unique

            return {"suggestions": final_suggestions}

        except ServiceError:
            # Re-raise intentional service errors (e.g. the empty-input 400)
            # so they are not rewrapped as a generic 500 below.
            raise
        except Exception as e:
            logger.error(f"Synonym suggestion error for text: '{text[:50]}...'", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during synonym suggestion.") from e

    @lru_cache(maxsize=5000)
    def _get_wordnet_synonyms_cached(self, word: str, pos: str) -> List[str]:
        synonyms = set()
        for syn in wordnet.synsets(word, pos=pos):
            for lemma in syn.lemmas():
                name = lemma.name().replace("_", " ").lower()
                if name.isalpha() and len(name) > 1:
                    synonyms.add(name)
        synonyms.discard(word.lower())
        return sorted(synonyms)
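
The refactored suggest flow ranks WordNet candidates by whole-sentence semantic similarity: each candidate is substituted into the original sentence, both versions are embedded, and only substitutions whose cosine similarity to the original stays at or above 0.65 survive (top 5 per word). A quick smoke-test sketch, not part of the commit, assuming the repo's app package is importable and the spaCy model, WordNet corpus, and sentence-transformer weights are already present as the Dockerfile pre-downloads:

import asyncio

from app.services.synonyms import SynonymSuggester

async def main():
    suggester = SynonymSuggester()
    # First call lazily loads spaCy and the sentence model, and verifies WordNet.
    result = await suggester.suggest("The quick brown fox jumps over the lazy dog.")
    for word, candidates in result["suggestions"].items():
        print(f"{word}: {candidates}")

asyncio.run(main())

Embedding full sentences rather than isolated words is what lets the similarity threshold filter out WordNet senses that do not fit the surrounding context.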
app/services/tone_classification.py
CHANGED
@@ -1,75 +1,60 @@
import logging
import torch
from transformers import pipeline
from app.services.base import get_cached_model, model_response, ServiceError
from app.core.config import settings

logger = logging.getLogger(...)

class ToneClassifier:
    def __init__(self):
        ...

    def _load_model(self):
        def load_fn():
            model = pipeline(
                "text-classification",
                model=settings.TONE_MODEL,
                device=0 if torch.cuda.is_available() else -1,
                return_all_scores=True  # Keep this true
            )
            return model
        return get_cached_model("tone_model", load_fn)

    def classify(self, text: str) -> dict:
        try:
            text = text.strip()
            if not text:
                raise ServiceError("Input text is empty.")

            raw_results = ...

            # Check for expected pipeline output format
            if not (isinstance(raw_results, list) and raw_results and isinstance(raw_results[0], list)):
                logger.error(f"Unexpected raw_results format from pipeline: {raw_results}")
                ...

            scores_for_text = raw_results[0]

            # Sort the emotions by score in descending order
            sorted_emotions = sorted(scores_for_text, key=lambda x: x['score'], reverse=True)

            # Log all emotion scores for the given text (useful for seeing the full distribution)
            logger.debug(f"Input Text: '{text}'")
            logger.debug("--- Emotion Scores (Label: Score) ---")
            for emotion in sorted_emotions:
                logger.debug(f"  {emotion['label']}: {emotion['score']:.4f}")
            logger.debug("-------------------------------------")

            top_emotion = sorted_emotions[0]
            predicted_label = top_emotion.get("label", "Unknown")
            predicted_score = top_emotion.get("score", 0.0)

            # Apply the confidence threshold
            if predicted_score >= settings.TONE_CONFIDENCE_THRESHOLD:
                logger.info(f"Final prediction for '{text}': '{predicted_label}' (Score: {predicted_score:.4f}, Above Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f})")
                return ...
            else:
                logger.info(f"Final prediction for '{text}': 'neutral' (Top Score: {predicted_score:.4f}, Below Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f}).")
                return ...

        except ServiceError as se:
            logger.error(f"Tone classification ServiceError for text '{text}': {se}")
            return model_response(error=str(se))
        except Exception as e:
            logger.error(f"Tone classification unexpected error for text '{text}': {e}", exc_info=True)
            ...

import logging
import torch
from app.services.base import load_hf_pipeline
from app.core.config import APP_NAME, settings
from app.core.exceptions import ServiceError, ModelNotDownloadedError

logger = logging.getLogger(f"{APP_NAME}.services.tone_classification")

class ToneClassifier:
    def __init__(self):
        self._classifier = None

    def _get_classifier(self):
        if self._classifier is None:
            self._classifier = load_hf_pipeline(
                model_id=settings.TONE_MODEL_ID,
                task="text-classification",
                feature_name="Tone Classification",
                top_k=None
            )
        return self._classifier

    async def classify(self, text: str) -> dict:
        try:
            text = text.strip()
            if not text:
                raise ServiceError(status_code=400, detail="Input text is empty for tone classification.")

            classifier = self._get_classifier()
            raw_results = classifier(text)

            if not (isinstance(raw_results, list) and raw_results and isinstance(raw_results[0], list)):
                logger.error(f"Unexpected raw_results format from pipeline: {raw_results}")
                raise ServiceError(status_code=500, detail="Unexpected model output format for tone classification.")

            scores_for_text = raw_results[0]
            sorted_emotions = sorted(scores_for_text, key=lambda x: x['score'], reverse=True)

            logger.debug(f"Input Text: '{text}'")
            logger.debug("--- Emotion Scores (Label: Score) ---")
            for emotion in sorted_emotions:
                logger.debug(f"  {emotion['label']}: {emotion['score']:.4f}")
            logger.debug("-------------------------------------")

            top_emotion = sorted_emotions[0]
            predicted_label = top_emotion.get("label", "Unknown")
            predicted_score = top_emotion.get("score", 0.0)

            if predicted_score >= settings.TONE_CONFIDENCE_THRESHOLD:
                logger.info(f"Final prediction for '{text[:50]}...': '{predicted_label}' (Score: {predicted_score:.4f}, Above Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f})")
                return {"tone": predicted_label}
            else:
                logger.info(f"Final prediction for '{text[:50]}...': 'neutral' (Top Score: {predicted_score:.4f}, Below Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f}).")
                return {"tone": "neutral"}

        except ServiceError:
            # Propagate intentional service errors unchanged rather than
            # converting a 400 into a generic 500 below.
            raise
        except Exception as e:
            logger.error(f"Tone classification unexpected error for text '{text[:50]}...': {e}", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during tone classification.") from e
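
Usage follows the same async pattern as the other services. A sketch, not part of the commit; the label returned depends entirely on the checkpoint behind settings.TONE_MODEL_ID and on TONE_CONFIDENCE_THRESHOLD:

import asyncio

from app.services.tone_classification import ToneClassifier

async def main():
    classifier = ToneClassifier()
    result = await classifier.classify("I can't believe how well this turned out!")
    print(result)  # e.g. {"tone": "joy"}, or {"tone": "neutral"} below the threshold

asyncio.run(main())

Note that, unlike the synonym and voice services, the pipeline call here still runs synchronously inside the coroutine; wrapping it in asyncio.to_thread would keep the event loop responsive under load.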
app/services/translation.py
CHANGED
@@ -1,45 +1,47 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import logging
from app.services.base import get_cached_model, timed_model_load, model_response, ServiceError, DEVICE
from app.core.config import settings

logger = logging.getLogger(...)

class Translator:
    def __init__(self):
        self.tokenizer, self.model = self._load_model()

    def _load_model(self):
        def load_fn():
            tokenizer = timed_model_load("translate_tokenizer", lambda: AutoTokenizer.from_pretrained(settings.TRANSLATION_MODEL))
            model = timed_model_load("translate_model", lambda: AutoModelForSeq2SeqLM.from_pretrained(settings.TRANSLATION_MODEL))
            model = model.to(DEVICE).eval()
            return tokenizer, model
        return get_cached_model("translate", load_fn)

    def translate(self, text: str, target_lang: str) -> dict:
        try:
            text = text.strip()
            target_lang = target_lang.strip()

            if not text:
                raise ServiceError("Input text is empty.")
            if not target_lang:
                raise ServiceError("Target language is empty.")
            if target_lang not in settings.SUPPORTED_TRANSLATION_LANGUAGES:
                raise ServiceError(f"Unsupported target language: {target_lang}")

            prompt = f">>{target_lang}<< {text}"
            with torch.no_grad():
                inputs = self.tokenizer([prompt], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
                outputs = self.model.generate(**inputs, max_length=256, num_beams=1, early_stopping=True)
                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            return model_response(result=result)

        except ServiceError as se:
            return model_response(error=str(se))
        except Exception as e:
            logger.error(f"Translation error: {e}")
            ...

import logging
from app.services.base import load_hf_pipeline
from app.core.config import settings, APP_NAME
from app.core.exceptions import ServiceError

logger = logging.getLogger(f"{APP_NAME}.services.translation")

class Translator:
    def __init__(self):
        self._pipeline = None

    def _get_pipeline(self):
        if self._pipeline is None:
            logger.info("Loading translation pipeline...")
            self._pipeline = load_hf_pipeline(
                model_id=settings.TRANSLATION_MODEL_ID,
                task="translation",
                feature_name="Translation"
            )
        return self._pipeline

    async def translate(self, text: str, target_lang: str) -> dict:
        text = text.strip()
        target_lang = target_lang.strip()

        if not text:
            raise ServiceError(status_code=400, detail="Input text is empty for translation.")
        if not target_lang:
            raise ServiceError(status_code=400, detail="Target language is empty for translation.")
        if target_lang not in settings.SUPPORTED_TRANSLATION_LANGUAGES:
            raise ServiceError(
                status_code=400,
                detail=f"Unsupported target language: {target_lang}. "
                       f"Supported languages are: {', '.join(settings.SUPPORTED_TRANSLATION_LANGUAGES)}"
            )

        try:
            pipeline = self._get_pipeline()
            prompt = f">>{target_lang}<< {text}"
            result = pipeline(prompt, max_length=256, num_beams=1, early_stopping=True)[0]
            translated_text = result.get("translation_text") or result.get("generated_text")

            return {"translated_text": translated_text.strip()}

        except Exception as e:
            logger.error(f"Translation error for text: '{text[:50]}...' to '{target_lang}'", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during translation.") from e
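
A usage sketch, not part of the commit; it assumes "fr" appears in settings.SUPPORTED_TRANSLATION_LANGUAGES and that TRANSLATION_MODEL_ID names a Marian-style multilingual model, which is what the ">>lang<<" prefix convention implies:

import asyncio

from app.services.translation import Translator

async def main():
    translator = Translator()
    result = await translator.translate("Hello, how are you today?", target_lang="fr")
    print(result["translated_text"])

asyncio.run(main())

Moving the input validation ahead of the try block is the key behavioral fix in this file: malformed requests now surface as 400s instead of being swallowed by the generic 500 handler.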
app/services/voice_detection.py
CHANGED
@@ -1,38 +1,55 @@
import logging
from app.services.base import ...

logger = logging.getLogger(...)

class VoiceDetector:
    def __init__(self):
        self.nlp = ...

    def classify(self, text: str) -> dict:
        try:
            text = text.strip()
            if not text:
                raise ServiceError("Input text is empty.")

            doc = self.nlp(text)
            passive_sentences = 0
            total_sentences = 0

            for sent in doc.sents:
                total_sentences += 1
                for token in sent:
                    if token.dep_ == "nsubjpass":
                        passive_sentences += 1
                        break

            if total_sentences == 0:
                return ...

            ratio = passive_sentences / total_sentences
            ...

import asyncio
import logging
from app.services.base import load_spacy_model
from app.core.config import APP_NAME, SPACY_MODEL_ID
from app.core.exceptions import ServiceError, ModelNotDownloadedError

logger = logging.getLogger(f"{APP_NAME}.services.voice_detection")

class VoiceDetector:
    def __init__(self):
        self._nlp = None

    def _get_nlp(self):
        if self._nlp is None:
            self._nlp = load_spacy_model(SPACY_MODEL_ID)
        return self._nlp

    async def classify(self, text: str) -> dict:
        try:
            text = text.strip()
            if not text:
                raise ServiceError(status_code=400, detail="Input text is empty for voice detection.")

            nlp = self._get_nlp()
            doc = await asyncio.to_thread(nlp, text)

            passive_sentences = 0
            total_sentences = 0

            for sent in doc.sents:
                total_sentences += 1
                is_passive_sentence = False
                for token in sent:
                    if token.dep_ == "nsubjpass" and token.head.pos_ == "VERB":
                        is_passive_sentence = True
                        break
                if is_passive_sentence:
                    passive_sentences += 1

            if total_sentences == 0:
                return {"voice": "unknown", "passive_ratio": 0.0}

            ratio = passive_sentences / total_sentences
            voice_type = "Passive" if ratio > 0.1 else "Active"

            return {
                "voice": voice_type,
                "passive_ratio": round(ratio, 3),
                "passive_sentences_count": passive_sentences,
                "total_sentences_count": total_sentences
            }

        except ServiceError:
            # Re-raise the empty-input 400 unchanged instead of rewrapping it
            # as a 500 in the generic handler below.
            raise
        except Exception as e:
            logger.error(f"Voice detection error for text: '{text[:50]}...': {e}", exc_info=True)
            raise ServiceError(status_code=500, detail="An internal error occurred during voice detection.") from e
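
The heuristic marks a sentence as passive when it contains a passive nominal subject (nsubjpass) attached to a verb, and labels the whole text "Passive" once more than 10% of its sentences qualify. A usage sketch, not part of the commit, assuming the en_core_web_sm model is downloaded as the Dockerfile arranges:

import asyncio

from app.services.voice_detection import VoiceDetector

async def main():
    detector = VoiceDetector()
    result = await detector.classify("The ball was thrown by the boy. The dog chased it.")
    print(result)
    # e.g. {"voice": "Passive", "passive_ratio": 0.5,
    #       "passive_sentences_count": 1, "total_sentences_count": 2}

asyncio.run(main())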