Commit ce2ce69 · committed by iamspruce
Parent(s): b7b6eaf

added synonyms generator
Files changed:
- .DS_Store +0 -0
- .gitignore +2 -1
- .python-version +1 -0
- Dockerfile +1 -1
- app/.DS_Store +0 -0
- app/core/app.py +27 -21
- app/core/config.py +38 -4
- app/core/logging.py +6 -2
- app/core/middleware.py +9 -9
- app/core/prompts.py +2 -4
- app/core/security.py +4 -2
- app/queue.py +104 -0
- app/routers/analyze.py +45 -0
- app/routers/grammar.py +21 -45
- app/routers/paraphrase.py +14 -2
- app/routers/readability.py +11 -23
- app/routers/synonyms.py +21 -0
- app/routers/tone.py +15 -2
- app/routers/translate.py +14 -5
- app/services/base.py +17 -10
- app/services/conciseness_suggestion.py +0 -30
- app/services/gpt4_rewrite.py +27 -25
- app/services/grammar.py +60 -14
- app/services/inclusive_language.py +60 -54
- app/services/paraphrase.py +19 -14
- app/services/readability.py +80 -0
- app/services/synonyms.py +162 -0
- app/services/tone_classification.py +62 -23
- app/services/translation.py +24 -19
- app/services/vocabulary_enhancement.py +0 -30
- app/services/voice_detection.py +16 -11
- app/test/test_rewrite.py +0 -15
- app/tests/__init__.py +0 -0
- app/tests/test_api.py +115 -0
- app/tests/test_services.py +149 -0
- requirements.txt +6 -6
- run.py +4 -9
.DS_Store
ADDED
Binary file (6.15 kB)
.gitignore
CHANGED
@@ -5,8 +5,9 @@ __pycache__/
 *.pyo
 *.pyd
 .Python
-env
+.env
 venv/
+torch-venv/
 *.env # Local environment variable files (e.g., for API keys)
 .venv/ # Another common virtual environment name
 
.python-version
ADDED
@@ -0,0 +1 @@
+3.10.13
Dockerfile
CHANGED
@@ -25,4 +25,4 @@ COPY app ./app
 # Expose the port your FastAPI application will run on
 EXPOSE 7860
 
-CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "4"]
app/.DS_Store
ADDED
Binary file (8.2 kB)
app/core/app.py
CHANGED
@@ -1,31 +1,37 @@
+import os
 from fastapi import FastAPI
 from fastapi.middleware.gzip import GZipMiddleware
-from
+from contextlib import asynccontextmanager
+from app.routers import grammar, tone, voice, inclusive_language, readability, paraphrase, translate, rewrite, analyze
+from app.queue import start_workers
+from app.core.middleware import setup_middlewares
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    num_workers = int(os.getenv("WORKER_COUNT", 4))
+    start_workers(num_workers)
+    yield
 
 def create_app() -> FastAPI:
-    app = FastAPI()
+    app = FastAPI(lifespan=lifespan)
     app.add_middleware(GZipMiddleware, minimum_size=500)
+    setup_middlewares(app)
+
+    for router, tag in [
+        (grammar.router, "Grammar"),
+        (tone.router, "Tone"),
+        (voice.router, "Voice"),
+        (inclusive_language.router, "Inclusive Language"),
+        (readability.router, "Readability"),
+        (paraphrase.router, "Paraphrasing"),
+        (translate.router, "Translation"),
+        (rewrite.router, "Rewrite"),
+        (analyze.router, "Analyze")
+    ]:
+        app.include_router(router, tags=[tag])
 
     @app.get("/")
     def root():
         return {"message": "Welcome to Wellsaid API"}
 
     return app
app/core/config.py
CHANGED
@@ -1,9 +1,43 @@
 from pydantic_settings import BaseSettings
+from typing import List
+
 
 class Settings(BaseSettings):
+    # Server settings (user-configurable)
+    HOST: str = "0.0.0.0"
+    PORT: int = 7860
+    RELOAD: bool = True
+
+    # Security & workers
+    WELLSAID_API_KEY: str = "12345"
+    WORKER_COUNT: int = 4
+
+    # Fixed/internal settings
+    INCLUSIVE_RULES_DIR: str = "app/data/en"
+    OPENAI_MODEL: str = "gpt-4o"
+    OPENAI_TEMPERATURE: float = 0.7
+    OPENAI_MAX_TOKENS: int = 512
+    SUPPORTED_TRANSLATION_LANGUAGES: List[str] = [
+        "fr", "fr_BE", "fr_CA", "fr_FR", "wa", "frp", "oc", "ca", "rm", "lld",
+        "fur", "lij", "lmo", "es", "es_AR", "es_CL", "es_CO", "es_CR", "es_DO",
+        "es_EC", "es_ES", "es_GT", "es_HN", "es_MX", "es_NI", "es_PA", "es_PE",
+        "es_PR", "es_SV", "es_UY", "es_VE", "pt", "pt_br", "pt_BR", "pt_PT",
+        "gl", "lad", "an", "mwl", "it", "it_IT", "co", "nap", "scn", "vec",
+        "sc", "ro", "la"
+    ]
+
+    # Model names
+    GRAMMAR_MODEL: str = "visheratin/t5-efficient-mini-grammar-correction"
+    PARAPHRASE_MODEL: str = "humarin/chatgpt_paraphraser_on_T5_base"
+    TONE_MODEL: str = "boltuix/NeuroFeel"
+    TONE_CONFIDENCE_THRESHOLD: float = 0.2
+    TRANSLATION_MODEL: str = "Helsinki-NLP/opus-mt-en-ROMANCE"
+    SENTENCE_TRANSFORMER_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
+
+    class Config:
+        env_file = ".env"
+        case_sensitive = True
+
 
+# Singleton instance
 settings = Settings()
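Since Settings extends pydantic's BaseSettings with env_file = ".env", any field above can be overridden by an environment variable or a .env entry of the same case-sensitive name. A minimal sketch of that override behavior (illustrative only, not part of this commit):

    # Simulate a .env entry by exporting the variable before Settings
    # is constructed; field names are the keys because case_sensitive = True.
    import os
    os.environ["WORKER_COUNT"] = "8"

    from app.core.config import Settings

    print(Settings().WORKER_COUNT)  # 8, overriding the default of 4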
app/core/logging.py
CHANGED
@@ -1,8 +1,12 @@
+import os
 import logging
 
 def configure_logging():
     logging.basicConfig(
         level=logging.INFO,
-        format="%(asctime)s - %(levelname)s - %(message)s"
+        format="%(asctime)s - %(levelname)s - %(message)s",
+        handlers=[
+            logging.StreamHandler(),
+            logging.FileHandler(os.getenv("LOG_FILE", "app.log"))
+        ]
     )
app/core/middleware.py
CHANGED
@@ -1,21 +1,21 @@
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
-from slowapi import Limiter
+from slowapi import Limiter, _rate_limit_exceeded_handler
 from slowapi.util import get_remote_address
 from slowapi.errors import RateLimitExceeded
 import logging
+import os
 
+def get_rate_limit():
+    return os.getenv("RATE_LIMIT", "100/minute")
+
+limiter = Limiter(key_func=get_remote_address, headers_enabled=True, default_limits=[get_rate_limit()])
 
 def setup_middlewares(app: FastAPI):
     @app.exception_handler(Exception)
     async def unhandled_exception_handler(request: Request, exc: Exception):
-        logging.exception("Unhandled exception")
+        logging.exception(f"Unhandled exception from {request.client.host}")
         return JSONResponse(status_code=500, content={"detail": "Internal server error"})
 
-        return JSONResponse(status_code=429, content={"detail": "Rate limit exceeded"})
-
-    limiter.init_app(app)
-    app.state.limiter = limiter
+    app.state.limiter = limiter
+    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
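Note that get_rate_limit() runs once at import time, so RATE_LIMIT must be set before the app starts. The Limiter above applies that value globally through default_limits; slowapi can also cap individual routes. A minimal sketch using the shared limiter (the /heavy route is hypothetical, not part of this commit):

    # slowapi requires the endpoint to accept the Request object
    # for per-route limits to work.
    from fastapi import APIRouter, Request
    from app.core.middleware import limiter

    router = APIRouter()

    @router.get("/heavy")
    @limiter.limit("10/minute")  # tighter than the global RATE_LIMIT
    async def heavy(request: Request):
        return {"ok": True}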
app/core/prompts.py
CHANGED
@@ -7,9 +7,6 @@ def summarize_prompt(text: str) -> str:
 def clarity_prompt(text: str) -> str:
     return f"Improve the clarity of the following sentence:\n{text.strip()}"
 
-def conciseness_prompt(text: str) -> str:
-    return f"Make the following sentence more concise:\n{text.strip()}"
-
 def rewrite_prompt(text: str, instruction: str) -> str:
     return f"{instruction.strip()}\n{text.strip()}"
 
@@ -27,4 +24,5 @@ def concise_prompt(text: str) -> str:
         "Rewrite the following text to be more concise and to the point, "
         "removing any verbose phrases, redundant words, or unnecessary clauses. "
         "Maintain the original meaning and professional tone.\n" + text.strip()
-    )
+    )
+
app/core/security.py
CHANGED
@@ -1,11 +1,13 @@
 import os
+import logging
+from fastapi import Header, HTTPException, status
 
 API_KEY = os.getenv("WELLSAID_API_KEY", "12345")
 
 def verify_api_key(x_api_key: str = Header(...)) -> None:
     if not x_api_key or x_api_key != API_KEY:
+        logging.warning("Unauthorized access attempt with key: %s", x_api_key)
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
             detail="Invalid or missing API key"
-        )
+        )
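Because the dependency reads a Header parameter named x_api_key, FastAPI matches it against an x-api-key request header. An illustrative client call against a local instance (URL, key, and text are placeholders, not part of this commit):

    import requests

    resp = requests.post(
        "http://localhost:7860/grammar/",
        headers={"x-api-key": "12345"},  # must match WELLSAID_API_KEY
        json={"text": "She go to school every days."},
    )
    print(resp.status_code, resp.json())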
app/queue.py
ADDED
@@ -0,0 +1,104 @@
+import asyncio
+import logging
+import time
+import uuid
+import inspect
+
+from app.services.grammar import GrammarCorrector
+from app.services.paraphrase import Paraphraser
+from app.services.translation import Translator
+from app.services.tone_classification import ToneClassifier
+from app.services.inclusive_language import InclusiveLanguageChecker
+from app.services.voice_detection import VoiceDetector
+from app.services.readability import ReadabilityScorer
+from app.services.synonyms import SynonymSuggester
+from app.core.config import settings
+
+# Configure logging
+logging.basicConfig(
+    level=logging.DEBUG if getattr(settings, "DEBUG", False) else logging.INFO,
+    format="%(asctime)s [%(levelname)s] %(message)s"
+)
+
+# Initialize service instances
+grammar = GrammarCorrector()
+paraphraser = Paraphraser()
+translator = Translator()
+tone = ToneClassifier()
+inclusive = InclusiveLanguageChecker()
+voice_analyzer = VoiceDetector()
+readability = ReadabilityScorer()
+synonyms = SynonymSuggester()
+
+# Create async task queue (optional: maxsize=100 to prevent overload)
+task_queue = asyncio.Queue(maxsize=100)
+
+# Task handler map
+SERVICE_HANDLERS = {
+    "grammar": lambda p: grammar.correct(p["text"]),
+    "paraphrase": lambda p: paraphraser.paraphrase(p["text"]),
+    "translate": lambda p: translator.translate(p["text"], p["target_lang"]),
+    "tone": lambda p: tone.classify(p["text"]),
+    "inclusive": lambda p: inclusive.check(p["text"]),
+    "voice": lambda p: voice_analyzer.classify(p["text"]),
+    "readability": lambda p: readability.compute(p["text"]),
+    "synonyms": lambda p: synonyms.suggest(p["text"]),  # ✅ This is async
+}
+
+async def worker(worker_id: int):
+    logging.info(f"Worker-{worker_id} started")
+
+    while True:
+        task = await task_queue.get()
+        future = task["future"]
+        task_type = task["type"]
+        payload = task["payload"]
+        task_id = task["id"]
+
+        start_time = time.perf_counter()
+        logging.info(f"[Worker-{worker_id}] Processing Task-{task_id} | Type: {task_type} | Queue size: {task_queue.qsize()}")
+
+        try:
+            handler = SERVICE_HANDLERS.get(task_type)
+            if not handler:
+                raise ValueError(f"Unknown task type: {task_type}")
+
+            result = handler(payload)
+            if inspect.isawaitable(result):
+                result = await result
+
+            elapsed = time.perf_counter() - start_time
+            logging.info(f"[Worker-{worker_id}] Finished Task-{task_id} in {elapsed:.2f}s")
+
+            if not future.done():
+                future.set_result(result)
+
+        except Exception as e:
+            logging.error(f"[Worker-{worker_id}] Error in Task-{task_id} ({task_type}): {e}")
+            if not future.done():
+                future.set_result({"error": str(e)})
+
+        task_queue.task_done()
+
+def start_workers(count: int = 2):
+    for i in range(count):
+        asyncio.create_task(worker(i))
+
+async def enqueue_task(task_type: str, payload: dict, timeout: float = 10.0):
+    future = asyncio.get_event_loop().create_future()
+    task_id = str(uuid.uuid4())[:8]
+
+    await task_queue.put({
+        "future": future,
+        "type": task_type,
+        "payload": payload,
+        "id": task_id
+    })
+
+    logging.info(f"[ENQUEUE] Task-{task_id} added to queue | Type: {task_type} | Queue size: {task_queue.qsize()}")
+
+    try:
+        return await asyncio.wait_for(future, timeout=timeout)
+    except asyncio.TimeoutError:
+        logging.warning(f"[ENQUEUE] Task-{task_id} timed out after {timeout}s")
+        return {"error": f"Task {task_type} timed out after {timeout} seconds."}
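enqueue_task packages the create-future / put / await pattern, plus a timeout, so callers need not build queue entries by hand. A minimal sketch of using it from an async context (assuming start_workers has already run, as in the lifespan hook in app/core/app.py):

    from app.queue import enqueue_task

    async def correct(text: str) -> dict:
        # Returns the service's model_response dict, or {"error": ...}
        # if no worker finishes within the timeout.
        return await enqueue_task("grammar", {"text": text}, timeout=15.0)

Several routers below still build the task dict manually; the newer readability and synonyms routers, which call enqueue_task, are noticeably shorter for it.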
app/routers/analyze.py
ADDED
@@ -0,0 +1,45 @@
+from fastapi import APIRouter, Depends, HTTPException
+from app.core.security import verify_api_key
+from app.schemas.base import TextOnlyRequest
+from app.queue import task_queue
+import asyncio
+import uuid
+import logging
+
+router = APIRouter(prefix="/analyze", tags=["Analysis"])
+logger = logging.getLogger(__name__)
+
+@router.post("/", dependencies=[Depends(verify_api_key)])
+async def analyze_text(payload: TextOnlyRequest):
+    text = payload.text.strip()
+    if not text:
+        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
+
+    loop = asyncio.get_event_loop()
+
+    task_definitions = [
+        ("grammar", {"text": text}),
+        ("tone", {"text": text}),
+        ("inclusive", {"text": text}),
+        ("voice", {"text": text}),
+        ("readability", {"text": text}),
+        ("synonyms", {"text": text}),
+    ]
+
+    futures = []
+    for task_type, task_payload in task_definitions:
+        future = loop.create_future()
+        task_id = str(uuid.uuid4())[:8]
+
+        await task_queue.put({
+            "type": task_type,
+            "payload": task_payload,
+            "future": future,
+            "id": task_id
+        })
+        futures.append((task_type, future))
+
+    results = await asyncio.gather(*[fut for _, fut in futures])
+    response = {task_type: result for (task_type, _), result in zip(futures, results)}
+
+    return {"analysis": response}
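Each gathered result is whatever the worker set on the future, i.e. a model_response envelope (or {"error": ...} on failure). The aggregated payload therefore looks roughly like the sketch below; the shape is illustrative and actual values depend on the models:

    example = {
        "analysis": {
            "grammar": {"result": {"original_text": "...", "corrected_text_suggestion": "...", "issues": []}, "error": None},
            "readability": {"result": {"readability_summary": "...", "scores": {}}, "error": None},
            "synonyms": {"result": {"word": ["synonym", "..."]}, "error": None},
            # tone, inclusive, and voice follow the same {"result", "error"} envelope
        }
    }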
app/routers/grammar.py
CHANGED
@@ -1,60 +1,36 @@
+import uuid
+import asyncio
+import logging
 from fastapi import APIRouter, Depends, HTTPException, status
 from app.schemas.base import TextOnlyRequest
 from app.services.grammar import GrammarCorrector
 from app.core.security import verify_api_key
-import difflib
-import logging
+from app.queue import task_queue
 
 router = APIRouter(prefix="/grammar", tags=["Grammar"])
-corrector = GrammarCorrector()
 logger = logging.getLogger(__name__)
 
-def get_diff_issues(original: str, corrected: str):
-    matcher = difflib.SequenceMatcher(None, original, corrected)
-    issues = []
-
-    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
-        if tag == 'equal':
-            continue
-
-        issue = {
-            "offset": i1,
-            "length": i2 - i1,
-            "original": original[i1:i2],
-            "suggestion": corrected[j1:j2],
-            "context_before": original[max(0, i1 - 15):i1],
-            "context_after": original[i2:i2 + 15],
-            "message": "Grammar correction",
-            "line": original[:i1].count("\n") + 1,
-            "column": i1 - original[:i1].rfind("\n") if "\n" in original[:i1] else i1 + 1
-        }
-        issues.append(issue)
-
-    return issues
-
 @router.post("/", dependencies=[Depends(verify_api_key)])
-def correct_grammar(payload: TextOnlyRequest):
+async def correct_grammar(payload: TextOnlyRequest):
     text = payload.text.strip()
-
     if not text:
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Input text cannot be empty."
-        )
+        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
 
+    future = asyncio.get_event_loop().create_future()
+    task_id = str(uuid.uuid4())[:8]
 
+    await task_queue.put({
+        "type": "grammar",
+        "payload": {"text": text},
+        "future": future,
+        "id": task_id
+    })
+
+    result = await future
 
-    return {
-        "grammar": {
-            "original_text": text,
-            "corrected_text_suggestion": corrected,
-            "issues": issues
-        }
-    }
+    if result.get("error"):  # test the value, not key presence: model_response always carries an "error" key
+        detail = result["error"]
+        status_code = 400 if "empty" in detail.lower() else 500
+        raise HTTPException(status_code=status_code, detail=detail)
+
+    return {"grammar": result["result"]}
app/routers/paraphrase.py
CHANGED
@@ -1,12 +1,24 @@
+import asyncio
+import uuid
 from fastapi import APIRouter, Depends
 from app.schemas.base import TextOnlyRequest
 from app.services.paraphrase import Paraphraser
 from app.core.security import verify_api_key
+from app.queue import task_queue
 
 router = APIRouter(prefix="/paraphrase", tags=["Paraphrase"])
 paraphraser = Paraphraser()
 
 @router.post("/", dependencies=[Depends(verify_api_key)])
-def paraphrase_text(payload: TextOnlyRequest):
+async def paraphrase_text(payload: TextOnlyRequest):
+    future = asyncio.get_event_loop().create_future()
+    task_id = str(uuid.uuid4())[:8]
+    await task_queue.put({
+        "type": "paraphrase",
+        "payload": {"text": payload.text},
+        "future": future,
+        "id": task_id
+    })
+
+    result = await future
     return {"result": result}
app/routers/readability.py
CHANGED
@@ -1,33 +1,21 @@
 from fastapi import APIRouter, Depends, HTTPException, status
 from app.schemas.base import TextOnlyRequest
 from app.core.security import verify_api_key
-from app.
-import asyncio
+from app.queue import enqueue_task
 import logging
-import textstat
 
 router = APIRouter(prefix="/readability", tags=["Readability"])
 logger = logging.getLogger(__name__)
 
-def compute_readability(text: str) -> dict:
-    return {
-        "flesch_reading_ease": textstat.flesch_reading_ease(text),
-        "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
-        "gunning_fog_index": textstat.gunning_fog(text),
-        "smog_index": textstat.smog_index(text),
-        "coleman_liau_index": textstat.coleman_liau_index(text),
-        "automated_readability_index": textstat.automated_readability_index(text),
-    }
-
 @router.post("/", dependencies=[Depends(verify_api_key)])
 async def readability_score(payload: TextOnlyRequest):
+    text = payload.text.strip()
+    if not text:
+        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
+
+    result = await enqueue_task("readability", {"text": text})
+
+    if isinstance(result, dict) and result.get("error"):
+        raise HTTPException(status_code=500, detail=result["error"])
+
+    return {"readability_scores": result["result"]}
app/routers/synonyms.py
ADDED
@@ -0,0 +1,21 @@
+from fastapi import APIRouter, Depends, HTTPException, status
+from app.schemas.base import TextOnlyRequest
+from app.services.synonyms import SynonymSuggester
+from app.core.security import verify_api_key
+from app.queue import enqueue_task
+import logging
+
+router = APIRouter(prefix="/synonyms", tags=["Synonyms"])
+logger = logging.getLogger(__name__)
+
+@router.post("/", dependencies=[Depends(verify_api_key)])
+async def suggest_synonyms(payload: TextOnlyRequest):
+    text = payload.text.strip()
+    if not text:
+        raise HTTPException(status_code=400, detail="Input text cannot be empty.")
+
+    result = await enqueue_task("synonyms", {"text": text})
+    if result.get("error"):  # check the value, not key presence: model_response always includes "error"
+        raise HTTPException(status_code=500, detail=result["error"])
+
+    return {"synonyms": result["result"]}
|
app/routers/tone.py
CHANGED
@@ -1,11 +1,24 @@
|
|
|
|
|
|
1 |
from fastapi import APIRouter, Depends
|
2 |
from app.schemas.base import TextOnlyRequest
|
3 |
from app.services.tone_classification import ToneClassifier
|
4 |
from app.core.security import verify_api_key
|
|
|
5 |
|
6 |
router = APIRouter(prefix="/tone", tags=["Tone"])
|
7 |
classifier = ToneClassifier()
|
8 |
|
9 |
@router.post("/", dependencies=[Depends(verify_api_key)])
|
10 |
-
def classify_tone(payload: TextOnlyRequest):
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import uuid
|
3 |
from fastapi import APIRouter, Depends
|
4 |
from app.schemas.base import TextOnlyRequest
|
5 |
from app.services.tone_classification import ToneClassifier
|
6 |
from app.core.security import verify_api_key
|
7 |
+
from app.queue import task_queue
|
8 |
|
9 |
router = APIRouter(prefix="/tone", tags=["Tone"])
|
10 |
classifier = ToneClassifier()
|
11 |
|
12 |
@router.post("/", dependencies=[Depends(verify_api_key)])
|
13 |
+
async def classify_tone(payload: TextOnlyRequest):
|
14 |
+
future = asyncio.get_event_loop().create_future()
|
15 |
+
task_id = str(uuid.uuid4())[:8]
|
16 |
+
await task_queue.put({
|
17 |
+
"type": "tone",
|
18 |
+
"payload": {"text": payload.text},
|
19 |
+
"future": future,
|
20 |
+
"id": task_id
|
21 |
+
})
|
22 |
+
|
23 |
+
result = await future
|
24 |
+
return {"result": result}
|
app/routers/translate.py
CHANGED
@@ -1,15 +1,24 @@
+import asyncio
+import uuid
 from fastapi import APIRouter, Depends
 from app.schemas.base import TranslateRequest
 from app.services.translation import Translator
 from app.core.security import verify_api_key
+from app.queue import task_queue
 
 router = APIRouter(prefix="/translate", tags=["Translate"])
 translator = Translator()
 
 @router.post("/", dependencies=[Depends(verify_api_key)])
-def translate_text(payload: TranslateRequest):
+async def translate_text(payload: TranslateRequest):
+    future = asyncio.get_event_loop().create_future()
+    task_id = str(uuid.uuid4())[:8]
+    await task_queue.put({
+        "type": "translate",
+        "payload": {"text": payload.text, "target_lang": payload.target_lang},
+        "future": future,
+        "id": task_id
+    })
+
+    result = await future
     return {"result": result}
app/services/base.py
CHANGED
@@ -1,12 +1,11 @@
 import torch
 from threading import Lock
 import logging
-import time
-logger = logging.getLogger(__name__)
 
+logger = logging.getLogger(__name__)
 
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
+DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+logger.info(f"Using device: {DEVICE}")
 
 _models = {}
 _models_lock = Lock()
@@ -15,24 +14,22 @@ _models_lock = Lock()
 def get_cached_model(model_name: str, load_fn):
     with _models_lock:
         if model_name not in _models:
+            logger.info(f"Loading model: {model_name}")
             _models[model_name] = load_fn()
     return _models[model_name]
 
 
-def load_with_timer(name, fn):
-    # TODO: Add timing later if needed
-    return fn()
-
 def timed_model_load(label: str, load_fn):
+    import time
     start = time.time()
     model = load_fn()
     logger.info(f"{label} loaded in {time.time() - start:.2f}s")
     return model
 
-
 _nlp = None
 
+
 def get_spacy():
     global _nlp
     if _nlp is None:
@@ -40,3 +37,13 @@ def get_spacy():
         _nlp = spacy.load("en_core_web_sm")
     return _nlp
 
+
+# Shared error and response
+class ServiceError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+        self.message = message
+
+
+def model_response(result: str = "", error: str = None) -> dict:
+    return {"result": result, "error": error}
app/services/conciseness_suggestion.py
DELETED
@@ -1,30 +0,0 @@
-import logging
-from app.services.gpt4_rewrite import GPT4Rewriter
-from app.core.prompts import concise_prompt
-
-logger = logging.getLogger(__name__)
-gpt4_rewriter = GPT4Rewriter()
-
-class ConcisenessSuggester:
-    def suggest(self, text: str, user_api_key: str) -> str:
-        text = text.strip()
-        if not text:
-            logger.warning("Conciseness suggestion requested for empty input.")
-            return "Input text is empty."
-
-        if not user_api_key:
-            logger.error("Conciseness suggestion failed: Missing user_api_key.")
-            return "Missing OpenAI API key."
-
-        instruction = concise_prompt(text)
-        concise_text = gpt4_rewriter.rewrite(text, user_api_key, instruction)
-
-        if concise_text.startswith("An OpenAI API error occurred:") or \
-           concise_text.startswith("An unexpected error occurred:") or \
-           concise_text.startswith("Missing OpenAI API key.") or \
-           concise_text.startswith("Input text is empty."):
-            logger.error(f"GPT-4 conciseness suggestion failed for text: '{text[:50]}...' - {concise_text}")
-            return concise_text
-
-        logger.info(f"Conciseness suggestion completed for text length: {len(text)}")
-        return concise_text
app/services/gpt4_rewrite.py
CHANGED
@@ -2,43 +2,45 @@ import openai
 import logging
 from tenacity import retry, stop_after_attempt, wait_exponential
 from app.core.config import settings
+from app.services.base import model_response, ServiceError
 
 logger = logging.getLogger(__name__)
 
 class GPT4Rewriter:
     @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
-    def rewrite(self, text: str, user_api_key: str, instruction: str) -> str:
+    def rewrite(self, text: str, user_api_key: str, instruction: str) -> dict:
+        try:
+            if not user_api_key:
+                raise ServiceError("Missing OpenAI API key.")
 
+            text = text.strip()
+            instruction = instruction.strip()
 
+            if not text:
+                raise ServiceError("Input text is empty.")
-            logger.error("GPT-4 rewrite requested without a specific instruction.")
-            return "Missing rewrite instruction. Please provide a clear instruction for the rewrite."
+            if not instruction:
+                raise ServiceError("Missing rewrite instruction.")
 
+            messages = [
+                {"role": "system", "content": instruction},
+                {"role": "user", "content": text},
+            ]
 
-        try:
             client = openai.OpenAI(api_key=user_api_key)
             response = client.chat.completions.create(
-                model=settings.
+                model=settings.OPENAI_MODEL,
                 messages=messages,
-                temperature=settings.
-                max_tokens=settings.
+                temperature=settings.OPENAI_TEMPERATURE,
+                max_tokens=settings.OPENAI_MAX_TOKENS
             )
+            result = response.choices[0].message.content.strip()
+            return model_response(result=result)
+
+        except ServiceError as se:
+            return model_response(error=str(se))
         except openai.APIError as e:
-            logger.error(f"OpenAI API error
-            return f"
+            logger.error(f"OpenAI API error: {e}")
+            return model_response(error=f"OpenAI API error: {e}")
         except Exception as e:
-            logger.error(f"Unexpected error
-            return "
+            logger.error(f"Unexpected error in GPT-4 rewrite: {e}")
+            return model_response(error="Unexpected error during rewrite.")
app/services/grammar.py
CHANGED
@@ -1,7 +1,15 @@
-import torch
-from .base import get_cached_model, DEVICE, load_with_timer
+import difflib
 import logging
+import torch
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+from app.services.base import (
+    get_cached_model, DEVICE, timed_model_load,
+    ServiceError, model_response
+)
+from app.core.config import settings
+
+logger = logging.getLogger(__name__)
 
 class GrammarCorrector:
     def __init__(self):
@@ -9,24 +17,62 @@ class GrammarCorrector:
 
     def _load_model(self):
         def load_fn():
-            tokenizer =
+            tokenizer = timed_model_load(
+                "grammar_tokenizer",
+                lambda: AutoTokenizer.from_pretrained(settings.GRAMMAR_MODEL)
+            )
+            model = timed_model_load(
+                "grammar_model",
+                lambda: AutoModelForSeq2SeqLM.from_pretrained(settings.GRAMMAR_MODEL)
+            )
             model = model.to(DEVICE).eval()
             return tokenizer, model
 
         return get_cached_model("grammar", load_fn)
 
-    def correct(self, text: str) -> str:
-        text = text.strip()
-        if not text:
-            logging.warning("Grammar correction requested for empty input.")
-            return "Input text is empty."
-
+    def correct(self, text: str) -> dict:
         try:
+            text = text.strip()
+            if not text:
+                raise ServiceError("Input text is empty.")
+
             with torch.no_grad():
                 inputs = self.tokenizer([text], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
                 outputs = self.model.generate(**inputs, max_length=256, num_beams=4, early_stopping=True)
+            corrected = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+            issues = self.get_diff_issues(text, corrected)
+
+            return model_response(result={
+                "original_text": text,
+                "corrected_text_suggestion": corrected,
+                "issues": issues
+            })
+
+        except ServiceError as se:
+            return model_response(error=str(se))
         except Exception as e:
-
-            return "An error occurred during grammar correction."
+            logger.error(f"Grammar correction error: {e}", exc_info=True)
+            return model_response(error="An error occurred during grammar correction.")
+
+    def get_diff_issues(self, original: str, corrected: str):
+        matcher = difflib.SequenceMatcher(None, original, corrected)
+        issues = []
+
+        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
+            if tag == 'equal':
+                continue
+
+            issues.append({
+                "offset": i1,
+                "length": i2 - i1,
+                "original": original[i1:i2],
+                "suggestion": corrected[j1:j2],
+                "context_before": original[max(0, i1 - 15):i1],
+                "context_after": original[i2:i2 + 15],
+                "message": "Grammar correction",
+                "line": original[:i1].count("\n") + 1,
+                "column": i1 - original[:i1].rfind("\n") if "\n" in original[:i1] else i1 + 1
+            })
+
+        return issues
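get_diff_issues leans entirely on difflib's SequenceMatcher opcodes. A standalone illustration of what those spans look like (example input and printed output are indicative only, not from this commit):

    import difflib

    original = "She go to school."
    corrected = "She goes to school."

    matcher = difflib.SequenceMatcher(None, original, corrected)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != "equal":
            # e.g. an 'insert' span where the corrected text adds "es" after "go"
            print(tag, repr(original[i1:i2]), "->", repr(corrected[j1:j2]))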
app/services/inclusive_language.py
CHANGED
@@ -1,16 +1,16 @@
 import yaml
 from pathlib import Path
-from spacy.matcher import PhraseMatcher
 from typing import List, Dict
-from .base import get_spacy
+from app.services.base import get_spacy, model_response, ServiceError
 from app.core.config import settings
 import logging
 
+logger = logging.getLogger(__name__)
+
 class InclusiveLanguageChecker:
-    def __init__(self, rules_directory=settings.
-        self.matcher = PhraseMatcher(get_spacy().vocab, attr="LOWER")
+    def __init__(self, rules_directory=settings.INCLUSIVE_RULES_DIR):
         self.rules = self._load_inclusive_rules(rules_directory)
-        self._init_matcher()
+        self.matcher = self._init_matcher()
 
     def _load_inclusive_rules(self, directory: str) -> Dict[str, Dict]:
         rules = {}
@@ -19,7 +19,7 @@ class InclusiveLanguageChecker:
             with open(path, encoding="utf-8") as f:
                 rule_list = yaml.safe_load(f)
             if not isinstance(rule_list, list):
+                logger.warning(f"Skipping malformed rule file: {path}")
                 continue
             for rule in rule_list:
                 note = rule.get("note", "")
@@ -33,66 +33,72 @@ class InclusiveLanguageChecker:
                     inconsiderate = [inconsiderate]
 
                 for phrase in inconsiderate:
-                    rules[key] = {
+                    rules[phrase.lower()] = {
                         "note": note,
                         "considerate": considerate,
                         "source": source,
                         "type": rule.get("type", "basic")
                     }
         except Exception as e:
+            logger.error(f"Error loading inclusive language rule from {path}: {e}")
         return rules
 
     def _init_matcher(self):
+        from spacy.matcher import PhraseMatcher
+        matcher = PhraseMatcher(get_spacy().vocab, attr="LOWER")
         for phrase in self.rules:
+            matcher.add(phrase, [get_spacy().make_doc(phrase)])
+        logger.info(f"Loaded {len(self.rules)} inclusive language rules.")
+        return matcher
+
+    def check(self, text: str) -> dict:
+        try:
+            text = text.strip()
+            if not text:
+                raise ServiceError("Input text is empty.")
+
+            nlp = get_spacy()
+            doc = nlp(text)
+            matches = self.matcher(doc)
+            results = []
+            matched_spans = set()
+
+            for match_id, start, end in matches:
+                phrase_str = nlp.vocab.strings[match_id]
+                matched_spans.add((start, end))
+                rule = self.rules.get(phrase_str.lower())
+                if rule:
+                    results.append({
+                        "term": doc[start:end].text,
+                        "type": rule["type"],
+                        "note": rule["note"],
+                        "suggestions": rule["considerate"],
+                        "context": doc[start:end].sent.text,
+                        "start": doc[start].idx,
+                        "end": doc[end - 1].idx + len(doc[end - 1]),
+                        "source": rule.get("source", "")
+                    })
 
             for token in doc:
                 lemma = token.lemma_.lower()
                 span_key = (token.i, token.i + 1)
                 if lemma in self.rules and span_key not in matched_spans:
                     rule = self.rules[lemma]
                     results.append({
                         "term": token.text,
                         "type": rule["type"],
                         "note": rule["note"],
                         "suggestions": rule["considerate"],
                         "context": token.sent.text,
                         "start": token.idx,
                         "end": token.idx + len(token),
                         "source": rule.get("source", "")
                     })
 
+            return model_response(result=results)
 
+        except ServiceError as se:
+            return model_response(error=str(se))
+        except Exception as e:
+            logger.error(f"Inclusive language check error: {e}")
+            return model_response(error="An error occurred during inclusive language checking.")
app/services/paraphrase.py
CHANGED
@@ -1,28 +1,30 @@
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from app.services.base import get_cached_model, DEVICE, timed_model_load, model_response, ServiceError
+from app.core.config import settings
 import torch
-from .base import get_cached_model, DEVICE, timed_model_load
 import logging
 
+logger = logging.getLogger(__name__)
+
 class Paraphraser:
     def __init__(self):
         self.tokenizer, self.model = self._load_model()
 
     def _load_model(self):
         def load_fn():
-            tokenizer = timed_model_load("paraphrase_tokenizer", lambda: AutoTokenizer.from_pretrained(
-            model = timed_model_load("paraphrase_model", lambda: AutoModelForSeq2SeqLM.from_pretrained(
+            tokenizer = timed_model_load("paraphrase_tokenizer", lambda: AutoTokenizer.from_pretrained(settings.PARAPHRASE_MODEL))
+            model = timed_model_load("paraphrase_model", lambda: AutoModelForSeq2SeqLM.from_pretrained(settings.PARAPHRASE_MODEL))
             model = model.to(DEVICE).eval()
             return tokenizer, model
         return get_cached_model("paraphrase", load_fn)
 
-    def paraphrase(self, text: str) -> str:
-        text = text.strip()
-        if not text:
-            logging.warning("Paraphrasing requested for empty input.")
-            return "Input text is empty."
-
-        prompt = f"paraphrase: {text} </s>"
+    def paraphrase(self, text: str) -> dict:
         try:
+            text = text.strip()
+            if not text:
+                raise ServiceError("Input text is empty.")
+
+            prompt = f"paraphrase: {text} </s>"
             with torch.no_grad():
                 inputs = self.tokenizer([prompt], return_tensors="pt", padding=True, truncation=True).to(DEVICE)
                 outputs = self.model.generate(
@@ -32,8 +34,11 @@ class Paraphraser:
                     num_return_sequences=1,
                     early_stopping=True
                 )
-            logging.error(f"Error during paraphrasing: {e}")
-            return f"An error occurred during paraphrasing: {e}"
+            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+            return model_response(result=result)
 
+        except ServiceError as se:
+            return model_response(error=str(se))
+        except Exception as e:
+            logger.error(f"Paraphrasing error: {e}")
+            return model_response(error="An error occurred during paraphrasing.")
app/services/readability.py
ADDED
@@ -0,0 +1,80 @@
+import textstat
+import logging
+from app.services.base import model_response, ServiceError
+
+logger = logging.getLogger(__name__)
+
+class ReadabilityScorer:
+    def compute(self, text: str) -> dict:
+        try:
+            text = text.strip()
+            if not text:
+                raise ServiceError("Input text is empty.")
+
+            # Compute scores
+            scores = {
+                "flesch_reading_ease": textstat.flesch_reading_ease(text),
+                "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
+                "gunning_fog_index": textstat.gunning_fog(text),
+                "smog_index": textstat.smog_index(text),
+                "coleman_liau_index": textstat.coleman_liau_index(text),
+                "automated_readability_index": textstat.automated_readability_index(text),
+            }
+
+            # Friendly descriptions
+            friendly_scores = {
+                "flesch_reading_ease": {
+                    "score": scores["flesch_reading_ease"],
+                    "label": "Flesch Reading Ease",
+                    "description": "Higher is easier. 60–70 is plain English; 90+ is very easy."
+                },
+                "flesch_kincaid_grade": {
+                    "score": scores["flesch_kincaid_grade"],
+                    "label": "Flesch-Kincaid Grade Level",
+                    "description": "U.S. school grade. 8.0 means an 8th grader can understand it."
+                },
+                "gunning_fog_index": {
+                    "score": scores["gunning_fog_index"],
+                    "label": "Gunning Fog Index",
+                    "description": "Estimates years of formal education needed to understand."
+                },
+                "smog_index": {
+                    "score": scores["smog_index"],
+                    "label": "SMOG Index",
+                    "description": "Also estimates required years of education."
+                },
+                "coleman_liau_index": {
+                    "score": scores["coleman_liau_index"],
+                    "label": "Coleman-Liau Index",
+                    "description": "Grade level based on characters, not syllables."
+                },
+                "automated_readability_index": {
+                    "score": scores["automated_readability_index"],
+                    "label": "Automated Readability Index",
+                    "description": "Grade level using word and sentence lengths."
+                }
+            }
+
+            # Flesch score guide
+            ease_score = scores["flesch_reading_ease"]
+            if ease_score >= 90:
+                summary = "Very easy to read. Easily understood by 11-year-olds."
+            elif ease_score >= 70:
+                summary = "Fairly easy. Conversational English for most people."
+            elif ease_score >= 60:
+                summary = "Plain English. Easily understood by 13–15-year-olds."
+            elif ease_score >= 30:
+                summary = "Fairly difficult. College-level reading."
+            else:
+                summary = "Very difficult. Best understood by university graduates."
+
+            return model_response(result={
+                "readability_summary": summary,
+                "scores": friendly_scores
+            })
+
+        except ServiceError as se:
+            return model_response(error=str(se))
+        except Exception as e:
+            logger.error(f"Readability scoring error: {e}", exc_info=True)
+            return model_response(error="An error occurred during readability scoring.")
app/services/synonyms.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import asyncio
|
3 |
+
from nltk.corpus import wordnet
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
from sentence_transformers import SentenceTransformer, util
|
6 |
+
from typing import List, Dict
|
7 |
+
from functools import lru_cache
|
8 |
+
|
9 |
+
# Assuming these are available in your project structure
|
10 |
+
from app.services.base import (
|
11 |
+
model_response, ServiceError, timed_model_load,
|
12 |
+
get_cached_model, DEVICE, get_spacy
|
13 |
+
)
|
14 |
+
from app.core.config import settings # Assuming settings might contain a BATCH_SIZE
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
# Mapping spaCy POS tags to WordNet POS tags
|
19 |
+
SPACY_TO_WORDNET_POS = {
    "NOUN": wordnet.NOUN,
    "VERB": wordnet.VERB,
    "ADJ": wordnet.ADJ,
    "ADV": wordnet.ADV,
}

# Only consider these POS tags for synonym suggestions
CONTENT_POS_TAGS = {"NOUN", "VERB", "ADJ", "ADV"}
SENTENCE_TRANSFORMER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_BATCH_SIZE = getattr(settings, "SENTENCE_TRANSFORMER_BATCH_SIZE", 32)

class SynonymSuggester:
    def __init__(self):
        self.sentence_model = self._load_sentence_transformer_model()
        self.nlp = self._load_spacy_model()

    def _load_sentence_transformer_model(self):
        def load_fn():
            # SentenceTransformer handles device placement itself when CUDA is available;
            # a device can also be passed explicitly (device=DEVICE).
            model = timed_model_load(
                "sentence_transformer",
                lambda: SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
            )
            return model
        return get_cached_model("synonym_sentence_model", load_fn)

    def _load_spacy_model(self):
        # timed_model_load caches the pipeline, so the blocking load happens only once.
        return timed_model_load("spacy_en_model", lambda: get_spacy())

    async def suggest(self, text: str) -> dict:
        text = text.strip()
        if not text:
            raise ServiceError("Input text is empty.")

        # Run blocking NLP work in a thread so the event loop stays responsive
        doc = await asyncio.to_thread(self.nlp, text)

        all_suggestions: Dict[str, List[str]] = {}

        # Encode the original text once
        original_text_embedding = await asyncio.to_thread(
            self.sentence_model.encode, text, convert_to_tensor=True, normalize_embeddings=True
        )

        # 1. Collect all candidate replacements and their contexts for batched encoding
        candidate_data = []  # dicts: original_word, wordnet_candidate, temp_sentence, original_word_idx

        for token in doc:
            if (
                token.pos_ in CONTENT_POS_TAGS and
                len(token.text.strip()) > 2 and
                not token.is_punct and
                not token.is_space
            ):
                original_word = token.text
                word_start = token.idx
                word_end = token.idx + len(original_word)

                # Filter WordNet synonyms by the token's part of speech
                wordnet_pos = SPACY_TO_WORDNET_POS.get(token.pos_)
                if wordnet_pos is None:
                    continue  # skip if there is no direct WordNet POS mapping

                wordnet_synonyms_candidates = await asyncio.to_thread(
                    self._get_wordnet_synonyms_cached, original_word, wordnet_pos
                )

                if not wordnet_synonyms_candidates:
                    continue

                for candidate in wordnet_synonyms_candidates:
                    temp_sentence = text[:word_start] + candidate + text[word_end:]
                    candidate_data.append({
                        "original_word": original_word,
                        "wordnet_candidate": candidate,
                        "temp_sentence": temp_sentence,
                        "original_word_idx": len(all_suggestions.get(original_word, []))  # tracks the initial suggestion count
                    })

                # Initialize the suggestion list for this word if needed
                if original_word not in all_suggestions:
                    all_suggestions[original_word] = []

        if not candidate_data:
            return model_response(result={})

        # 2. Encode all candidate sentences in a single batch
        all_candidate_sentences = [data["temp_sentence"] for data in candidate_data]
        all_candidate_embeddings = await asyncio.to_thread(
            self.sentence_model.encode,
            all_candidate_sentences,
            batch_size=DEFAULT_BATCH_SIZE,  # configurable batch size
            convert_to_tensor=True,
            normalize_embeddings=True
        )

        # 3. Compute similarities and filter.
        # util.cos_sim expects matrices of embeddings, so reshape the original
        # embedding to 2-D if it came back as a single 1-D tensor.
        if original_text_embedding.dim() == 1:
            original_text_embedding = original_text_embedding.unsqueeze(0)

        cosine_scores = util.cos_sim(original_text_embedding, all_candidate_embeddings)[0]  # cos_sim returns a matrix

        similarity_threshold = 0.65
        top_n_suggestions = 5

        # Reconstruct results by pairing candidate_data with the scores
        for i, data in enumerate(candidate_data):
            original_word = data["original_word"]
            candidate = data["wordnet_candidate"]
            score = cosine_scores[i].item()  # scalar score from the tensor

            # Apply the filtering criteria
            if score >= similarity_threshold and candidate.lower() != original_word.lower():
                if len(all_suggestions[original_word]) < top_n_suggestions:
                    all_suggestions[original_word].append(candidate)

        # Drop words for which no suggestions survived filtering
        final_suggestions = {
            word: suggestions for word, suggestions in all_suggestions.items() if suggestions
        }

        return model_response(result=final_suggestions)

    @lru_cache(maxsize=5000)
    def _get_wordnet_synonyms_cached(self, word: str, pos: str) -> List[str]:
        """
        Retrieves synonyms for a word from WordNet, filtered by part of speech.
        """
        synonyms = set()
        for syn in wordnet.synsets(word, pos=pos):  # filter by POS
            for lemma in syn.lemmas():
                name = lemma.name().replace("_", " ").lower()
                # Basic filtering for valid word forms
                if name.isalpha() and len(name) > 1:
                    synonyms.add(name)
        synonyms.discard(word.lower())  # never suggest the original word itself
        return list(synonyms)
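A minimal usage sketch for the new SynonymSuggester (illustrative, not part of the commit; it assumes the file lands at app/services/synonyms.py as added above). suggest() is async and returns the shared {"result": ..., "error": ...} envelope:

import asyncio
from app.services.synonyms import SynonymSuggester

async def main():
    suggester = SynonymSuggester()
    response = await suggester.suggest("The quick brown fox jumps over the lazy dog.")
    # e.g. {"result": {"quick": ["speedy", ...], ...}, "error": None}
    print(response)

asyncio.run(main())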
app/services/tone_classification.py
CHANGED
@@ -1,36 +1,75 @@
import logging
import torch
from transformers import pipeline
from app.services.base import get_cached_model, model_response, ServiceError
from app.core.config import settings

logger = logging.getLogger(__name__)

class ToneClassifier:
    def __init__(self):
        self.classifier = self._load_model()

    def _load_model(self):
        """
        Loads and caches the sentiment-analysis pipeline.
        """
        def load_fn():
            model = pipeline(
                "sentiment-analysis",  # or "text-classification"
                model=settings.TONE_MODEL,
                device=0 if torch.cuda.is_available() else -1,
                return_all_scores=True  # keep all label scores for thresholding below
            )
            logger.info(f"ToneClassifier model loaded on {'CUDA' if torch.cuda.is_available() else 'CPU'}")
            return model

        return get_cached_model("tone_model", load_fn)

    def classify(self, text: str) -> dict:
        try:
            text = text.strip()
            if not text:
                raise ServiceError("Input text is empty.")

            raw_results = self.classifier(text)

            # Check for the expected pipeline output format
            if not (isinstance(raw_results, list) and raw_results and isinstance(raw_results[0], list)):
                logger.error(f"Unexpected raw_results format from pipeline: {raw_results}")
                return model_response(error="Unexpected model output format.")

            scores_for_text = raw_results[0]

            # Sort the emotions by score in descending order
            sorted_emotions = sorted(scores_for_text, key=lambda x: x['score'], reverse=True)

            # Log the full score distribution for the given text (useful for debugging)
            logger.debug(f"Input Text: '{text}'")
            logger.debug("--- Emotion Scores (Label: Score) ---")
            for emotion in sorted_emotions:
                logger.debug(f"  {emotion['label']}: {emotion['score']:.4f}")
            logger.debug("-------------------------------------")

            top_emotion = sorted_emotions[0]
            predicted_label = top_emotion.get("label", "Unknown")
            predicted_score = top_emotion.get("score", 0.0)

            # Apply the confidence threshold
            if predicted_score >= settings.TONE_CONFIDENCE_THRESHOLD:
                logger.info(f"Final prediction for '{text}': '{predicted_label}' (Score: {predicted_score:.4f}, Above Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f})")
                return model_response(result=predicted_label)
            else:
                logger.info(f"Final prediction for '{text}': 'neutral' (Top Score: {predicted_score:.4f}, Below Threshold: {settings.TONE_CONFIDENCE_THRESHOLD:.2f}).")
                return model_response(result="neutral")

        except ServiceError as se:
            logger.error(f"Tone classification ServiceError for text '{text}': {se}")
            return model_response(error=str(se))
        except Exception as e:
            logger.error(f"Tone classification unexpected error for text '{text}': {e}", exc_info=True)
            return model_response(error="An error occurred during tone classification.")
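The "neutral" fallback above is just a threshold over the top pipeline score. A self-contained sketch of that decision rule (scores and threshold are made up, since settings.TONE_CONFIDENCE_THRESHOLD is defined elsewhere):

# Illustrative decision rule only; the values below are invented.
scores = [
    {"label": "joy", "score": 0.48},
    {"label": "neutral", "score": 0.32},
    {"label": "anger", "score": 0.20},
]
threshold = 0.55  # stand-in for settings.TONE_CONFIDENCE_THRESHOLD

top = max(scores, key=lambda s: s["score"])
label = top["label"] if top["score"] >= threshold else "neutral"
print(label)  # -> "neutral", because 0.48 < 0.55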
app/services/translation.py
CHANGED
@@ -1,40 +1,45 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import logging
from app.services.base import get_cached_model, DEVICE, timed_model_load, ServiceError, model_response
from app.core.config import settings

logger = logging.getLogger(__name__)

class Translator:
    def __init__(self):
        self.tokenizer, self.model = self._load_model()

    def _load_model(self):
        def load_fn():
            tokenizer = timed_model_load("translate_tokenizer", lambda: AutoTokenizer.from_pretrained(settings.TRANSLATION_MODEL))
            model = timed_model_load("translate_model", lambda: AutoModelForSeq2SeqLM.from_pretrained(settings.TRANSLATION_MODEL))
            model = model.to(DEVICE).eval()
            return tokenizer, model
        return get_cached_model("translate", load_fn)

    def translate(self, text: str, target_lang: str) -> dict:
        try:
            text = text.strip()
            target_lang = target_lang.strip()

            if not text:
                raise ServiceError("Input text is empty.")
            if not target_lang:
                raise ServiceError("Target language is empty.")
            if target_lang not in settings.SUPPORTED_TRANSLATION_LANGUAGES:
                raise ServiceError(f"Unsupported target language: {target_lang}")

            prompt = f">>{target_lang}<< {text}"
            with torch.no_grad():
                inputs = self.tokenizer([prompt], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
                outputs = self.model.generate(**inputs, max_length=256, num_beams=1, early_stopping=True)
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            return model_response(result=result)

        except ServiceError as se:
            return model_response(error=str(se))
        except Exception as e:
            logger.error(f"Translation error: {e}")
            return model_response(error="An error occurred during translation.")
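The ">>{target_lang}<<" prefix follows the OPUS-MT convention for multilingual Marian checkpoints, where a target-language token steers a shared model. The actual checkpoint comes from settings.TRANSLATION_MODEL, so the model name below is only an assumption:

# Illustrative only — the real model name is configured via settings.TRANSLATION_MODEL.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "Helsinki-NLP/opus-mt-en-roa"  # assumed English-to-Romance OPUS-MT checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval()

prompt = ">>fra<< Hello, how are you?"  # language token + source text
inputs = tokenizer([prompt], return_tensors="pt")
outputs = model.generate(**inputs, max_length=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))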
app/services/vocabulary_enhancement.py
DELETED
@@ -1,30 +0,0 @@
import logging
from app.services.gpt4_rewrite import GPT4Rewriter
from app.core.prompts import vocabulary_prompt

logger = logging.getLogger(__name__)
gpt4_rewriter = GPT4Rewriter()

class VocabularyEnhancer:
    def enhance(self, text: str, user_api_key: str) -> str:
        text = text.strip()
        if not text:
            logger.warning("Vocabulary enhancement requested for empty input.")
            return "Input text is empty."

        if not user_api_key:
            logger.error("Vocabulary enhancement failed: Missing user_api_key.")
            return "Missing OpenAI API key."

        instruction = vocabulary_prompt(text)
        enhanced_text = gpt4_rewriter.rewrite(text, user_api_key, instruction)

        if enhanced_text.startswith("An OpenAI API error occurred:") or \
           enhanced_text.startswith("An unexpected error occurred:") or \
           enhanced_text.startswith("Missing OpenAI API key.") or \
           enhanced_text.startswith("Input text is empty."):
            logger.error(f"GPT-4 vocabulary enhancement failed for text: '{text[:50]}...' - {enhanced_text}")
            return enhanced_text

        logger.info(f"Vocabulary enhancement completed for text length: {len(text)}")
        return enhanced_text
app/services/voice_detection.py
CHANGED
@@ -1,17 +1,18 @@
import logging
from app.services.base import get_spacy, model_response, ServiceError

logger = logging.getLogger(__name__)

class VoiceDetector:
    def __init__(self):
        self.nlp = get_spacy()

    def classify(self, text: str) -> dict:
        try:
            text = text.strip()
            if not text:
                raise ServiceError("Input text is empty.")

            doc = self.nlp(text)
            passive_sentences = 0
            total_sentences = 0
@@ -24,10 +25,14 @@ class VoiceDetector:
                    break

            if total_sentences == 0:
                return model_response(result="Unknown")

            ratio = passive_sentences / total_sentences
            return model_response(result="Passive" if ratio > 0.5 else "Active")

        except ServiceError as se:
            return model_response(error=str(se))
        except Exception as e:
            logger.error(f"Voice detection error: {e}")
            return model_response(error="An error occurred during voice detection.")
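The diff elides the sentence loop between the two hunks. A common way to flag passive sentences with spaCy — consistent with the per-sentence break kept as context above — is to look for passive dependency labels; a sketch under that assumption:

# Sketch only — the actual loop body is not shown in this diff.
import spacy

nlp = spacy.load("en_core_web_sm")
doc = nlp("The ball was thrown. The dog chased the cat.")

passive_sentences = 0
total_sentences = 0
for sent in doc.sents:
    total_sentences += 1
    if any(tok.dep_ in ("nsubjpass", "auxpass") for tok in sent):
        passive_sentences += 1  # one passive marker is enough per sentence

print(passive_sentences, total_sentences)  # -> 1 2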
app/test/test_rewrite.py
DELETED
@@ -1,15 +0,0 @@
import pytest
from fastapi.testclient import TestClient
from main import create_app

client = TestClient(create_app())

def test_rewrite_success():
    payload = {
        "text": "This is a very long and unnecessarily wordy sentence that could be simpler.",
        "instruction": "Make this more concise"
    }
    headers = {"x-api-key": "your-secret-key"}
    response = client.post("/rewrite/", json=payload, headers=headers)
    assert response.status_code == 200
    assert "result" in response.json()
app/tests/__init__.py
ADDED
File without changes
app/tests/test_api.py
ADDED
@@ -0,0 +1,115 @@
# app/tests/test_api_with_worker.py
import pytest
from httpx import AsyncClient
from unittest.mock import AsyncMock, patch, MagicMock
import asyncio
from app.core.app import create_app
from app.core.config import settings

app = create_app()
TEST_API_KEY = "test_fixture_key"
HEADERS = {"x-api-key": TEST_API_KEY}


@pytest.fixture(autouse=True)
def mock_api_key_dependency():
    with patch("app.core.security.verify_api_key", return_value=True):
        yield


@pytest.fixture(scope="function")
def fresh_future_and_queue_mock():
    """Creates a fresh future and mocks the task_queue.put logic."""
    future = asyncio.Future()

    with patch("asyncio.get_event_loop") as mock_loop:
        mock_event_loop = MagicMock()
        mock_event_loop.create_future.return_value = future
        mock_loop.return_value = mock_event_loop

        with patch("app.queue.task_queue.put", new_callable=AsyncMock) as mock_put:
            def side_effect(task_data):
                pass  # the future is controlled manually in each test
            mock_put.side_effect = side_effect
            yield future, mock_put


@pytest.fixture(scope="module")
async def client():
    async with AsyncClient(app=app, base_url="http://test") as ac:
        yield ac


async def test_root(client):
    response = await client.get("/")
    assert response.status_code == 200
    assert response.json() == {"message": "Welcome to Wellsaid API"}


@patch('app.services.grammar.GrammarCorrector.correct', new_callable=AsyncMock)
async def test_grammar(mock_correct, client, fresh_future_and_queue_mock):
    future, mock_put = fresh_future_and_queue_mock
    original = "She go to school."
    corrected = "She goes to school."
    mock_correct.return_value = corrected
    future.set_result(corrected)

    response = await client.post("/grammar", json={"text": original}, headers=HEADERS)
    assert response.status_code == 200
    data = response.json()["grammar"]
    assert data["original_text"] == original
    assert data["corrected_text_suggestion"] == corrected
    assert "issues" in data
    mock_put.assert_called_once()


@patch('app.services.paraphrase.Paraphraser.paraphrase', new_callable=AsyncMock)
async def test_paraphrase(mock_paraphrase, client, fresh_future_and_queue_mock):
    future, mock_put = fresh_future_and_queue_mock
    input_text = "This is a simple sentence."
    result_text = "Here's a straightforward phrase."
    mock_paraphrase.return_value = result_text
    future.set_result(result_text)

    response = await client.post("/paraphrase", json={"text": input_text}, headers=HEADERS)
    assert response.status_code == 200
    assert response.json()["result"] == result_text
    mock_put.assert_called_once()


@patch('app.services.tone.ToneClassifier.classify', new_callable=AsyncMock)
async def test_tone(mock_classify, client, fresh_future_and_queue_mock):
    future, mock_put = fresh_future_and_queue_mock
    tone_result = "Positive"
    mock_classify.return_value = tone_result
    future.set_result(tone_result)

    response = await client.post("/tone", json={"text": "Great job!"}, headers=HEADERS)
    assert response.status_code == 200
    assert response.json()["result"] == tone_result
    mock_put.assert_called_once()


@patch('app.services.translation.Translator.translate', new_callable=AsyncMock)
async def test_translate(mock_translate, client, fresh_future_and_queue_mock):
    future, mock_put = fresh_future_and_queue_mock
    translated = "Bonjour"
    mock_translate.return_value = translated
    future.set_result(translated)

    response = await client.post("/translate", json={"text": "Hello", "target_lang": "fr"}, headers=HEADERS)
    assert response.status_code == 200
    assert response.json()["result"] == translated
    mock_put.assert_called_once()


@patch('app.services.voice.VoiceAnalyzer.analyze_voice', new_callable=AsyncMock)
async def test_voice(mock_voice, client, fresh_future_and_queue_mock):
    future, mock_put = fresh_future_and_queue_mock
    mock_voice.return_value = "Passive"
    future.set_result("Passive")

    response = await client.post("/voice", json={"text": "The ball was thrown."}, headers=HEADERS)
    assert response.status_code == 200
    assert response.json()["result"] == "Passive"
    mock_put.assert_called_once()
app/tests/test_services.py
ADDED
@@ -0,0 +1,149 @@
import pytest
from app.services.translation import Translator
from app.services.tone_classification import ToneClassifier
from app.services.voice_detection import VoiceDetector
from app.services.gpt4_rewrite import GPT4Rewriter
from app.services.grammar import GrammarCorrector
from app.services.paraphrase import Paraphraser
from app.services.inclusive_language import InclusiveLanguageChecker


# --- Translation Tests ---
@pytest.fixture(scope="module")
def translator():
    return Translator()


def test_translate_valid(translator):
    response = translator.translate("Hello", "fr")
    assert "result" in response
    assert response["error"] is None


def test_translate_empty(translator):
    response = translator.translate("", "fr")
    assert response["result"] == ""
    assert response["error"] == "Input text is empty."


def test_translate_invalid_lang(translator):
    response = translator.translate("Hello", "xx")
    assert "Unsupported target language" in response["error"]


# --- Tone Classification Tests ---
@pytest.fixture(scope="module")
def tone_classifier():
    return ToneClassifier()


def test_tone_classify_valid(tone_classifier):
    response = tone_classifier.classify("I am very happy today!")
    assert "result" in response
    assert response["error"] is None


def test_tone_classify_empty(tone_classifier):
    response = tone_classifier.classify("")
    assert response["result"] == ""
    assert response["error"] == "Input text is empty."


# --- Voice Detection Tests ---
@pytest.fixture(scope="module")
def voice_detector():
    return VoiceDetector()


def test_voice_classify_active(voice_detector):
    response = voice_detector.classify("The dog chased the cat.")
    assert response["result"] in ["Active", "Passive"]
    assert response["error"] is None


def test_voice_classify_empty(voice_detector):
    response = voice_detector.classify("")
    assert response["result"] == ""
    assert response["error"] == "Input text is empty."


# --- GPT-4 Rewrite Tests ---
@pytest.fixture(scope="module")
def gpt4_rewriter():
    return GPT4Rewriter()


def test_gpt4_rewrite_valid(gpt4_rewriter):
    response = gpt4_rewriter.rewrite(
        "Rewrite this professionally.", "your_key_here", "You are a helpful assistant."
    )
    assert "result" in response or "error" in response


def test_gpt4_rewrite_missing_input(gpt4_rewriter):
    response = gpt4_rewriter.rewrite("", "your_key_here", "instruction")
    assert response["error"] == "Input text is empty."


def test_gpt4_rewrite_missing_key(gpt4_rewriter):
    response = gpt4_rewriter.rewrite("Text", "", "instruction")
    assert response["error"] == "Missing OpenAI API key."


def test_gpt4_rewrite_missing_instruction(gpt4_rewriter):
    response = gpt4_rewriter.rewrite("Text", "your_key_here", "")
    assert response["error"] == "Missing rewrite instruction."


# --- Grammar Correction Tests ---
@pytest.fixture(scope="module")
def grammar_corrector():
    return GrammarCorrector()


def test_grammar_correct_valid(grammar_corrector):
    response = grammar_corrector.correct("She go to school.")
    assert "result" in response
    assert response["error"] is None


def test_grammar_correct_empty(grammar_corrector):
    response = grammar_corrector.correct("")
    assert response["result"] == ""
    assert response["error"] == "Input text is empty."


# --- Paraphraser Tests ---
@pytest.fixture(scope="module")
def paraphraser():
    return Paraphraser()


def test_paraphrase_valid(paraphraser):
    response = paraphraser.paraphrase("This is a test sentence.")
    assert "result" in response
    assert response["error"] is None


def test_paraphrase_empty(paraphraser):
    response = paraphraser.paraphrase("")
    assert response["result"] == ""
    assert response["error"] == "Input text is empty."


# --- Inclusive Language Checker Tests ---
@pytest.fixture(scope="module")
def inclusive_checker():
    return InclusiveLanguageChecker()


def test_inclusive_check_valid(inclusive_checker):
    response = inclusive_checker.check("The chairman will arrive soon.")
    assert "result" in response
    assert isinstance(response["result"], list)


def test_inclusive_check_empty(inclusive_checker):
    response = inclusive_checker.check("")
    assert response["result"] == ""
    assert response["error"] == "Input text is empty."
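These tests consistently assert on response["result"] and response["error"], which pins down the envelope returned by model_response in app/services/base.py. That helper is not part of this commit, but a minimal version consistent with the assertions would look like:

# Hypothetical reconstruction — app/services/base.py is not shown in this commit.
def model_response(result="", error=None):
    """Uniform service envelope: a successful call carries result, a failed one carries error."""
    return {"result": result if error is None else "", "error": error}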
requirements.txt
CHANGED
@@ -1,16 +1,16 @@
fastapi
uvicorn[standard]
transformers==4.41.2
# CPU-only torch wheel; pip requires the index URL as a standalone option line,
# not appended to the requirement itself (--extra-index-url keeps PyPI for everything else)
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.2.2
sentencepiece
spacy
nltk
textstat
numpy==1.26.4
pydantic_settings
openai
slowapi
pydantic
tenacity
sentence-transformers==2.6.1
run.py
CHANGED
@@ -1,15 +1,10 @@
import uvicorn
import os

host = os.getenv("HOST", "0.0.0.0")
port = int(os.getenv("PORT", 7860))
app_module = "app.main:app"

if __name__ == "__main__":
    print(f"Starting Uvicorn server for {app_module} at http://{host}:{port}")
    uvicorn.run(app_module, host=host, port=port, reload=os.getenv("RELOAD", "true") == "true", log_level="info")