iamspruce committed · Commit 6eff95d · Parent: 869988f

updated models

Files changed:
- app/models.py (+22 -5)
- app/routers/analyze.py (+9 -7)
app/models.py
CHANGED
@@ -2,11 +2,14 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 import torch
 
 # Set the device for model inference (CPU is used by default)
+# You can change to "cuda" if a compatible GPU is available for faster processing.
 device = torch.device("cpu")
 
 # --- Grammar model ---
 # Uses vennify/t5-base-grammar-correction for grammar correction tasks.
-# This model
+# Note: This model might not catch all subtle spelling or advanced grammar errors
+# as robustly as larger models or rule-based systems. Its performance depends on
+# its training data.
 grammar_tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
 grammar_model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction").to(device)
 
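The added comments hedge both the device choice and the grammar model's limits. As an illustration only (none of this is in the commit), the device pick can be automated with torch.cuda.is_available(), and the vennify checkpoint is conventionally prompted with a "grammar: " task prefix; the prefix and the generation settings below are assumptions about how run_grammar_correction might drive these objects:

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    # Use the GPU automatically when present; falls back to the commit's CPU default.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
    model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction").to(device)

    def correct(text: str) -> str:
        # "grammar: " follows the model card's usage example (an assumption here).
        inputs = tokenizer("grammar: " + text, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, max_new_tokens=128)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    print(correct("He go to school yesterday"))  # e.g. "He went to school yesterday."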
@@ -47,6 +50,7 @@ def run_grammar_correction(text: str) -> str:
 def run_flan_prompt(prompt: str) -> str:
     """
     Runs a given prompt through the FLAN-T5 model to generate a response.
+    Includes advanced generation parameters for better output quality.
 
     Args:
         prompt (str): The prompt string to be processed by FLAN-T5.
@@ -56,8 +60,21 @@ def run_flan_prompt(prompt: str) -> str:
     """
     # Prepare the input for the FLAN-T5 model
    inputs = flan_tokenizer(prompt, return_tensors="pt").to(device)
-
-
+
+    # Generate the output with improved parameters:
+    # max_new_tokens: Limits the maximum length of the generated response.
+    # num_beams: Uses beam search for higher quality, less repetitive outputs.
+    # do_sample: Enables sampling, allowing for more diverse outputs.
+    # top_k, top_p: Control the sampling process, making it more focused and coherent.
+    outputs = flan_model.generate(
+        **inputs,
+        max_new_tokens=100,  # Limit output length to prevent rambling
+        num_beams=5,         # Use beam search for better quality
+        do_sample=True,      # Enable sampling for diversity
+        top_k=50,            # Sample from top 50 most probable tokens
+        top_p=0.95,          # Sample from tokens that cumulatively exceed 95% probability
+        temperature=0.7      # Controls randomness; lower means more deterministic
+    )
     # Decode the generated tokens back into a readable string
     return flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
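Worth noting: in transformers, combining num_beams=5 with do_sample=True selects beam-search multinomial sampling rather than plain beam search, so temperature, top_k and top_p all shape what each beam draws. A hypothetical call to the patched helper (the instruction text is illustrative, not from this repo's prompts module):

    # Hypothetical usage of run_flan_prompt as patched above.
    response = run_flan_prompt("Rewrite this sentence in a formal tone: i cant come to the meeting")
    print(response)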
@@ -67,7 +84,7 @@ def run_translation(text: str, target_lang: str) -> str:
 
     Args:
         text (str): The input text to be translated.
-        target_lang (str): The target language code (e.g., "fr" for French).
+        target_lang (str): The target language code (e.g., "fr" for French, "es" for Spanish).
 
     Returns:
         str: The translated text.
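Only the docstring changes here; the translation model itself is outside this hunk. A minimal sketch of one common approach, assuming MarianMT opus-mt checkpoints keyed by the same two-letter codes (the model names are assumptions, not taken from this repo):

    from transformers import pipeline

    # Cache one translation pipeline per target language.
    _translators = {}

    def translate(text: str, target_lang: str) -> str:
        if target_lang not in _translators:
            # Checkpoint naming is an assumption (Helsinki-NLP opus-mt family).
            _translators[target_lang] = pipeline(
                "translation", model=f"Helsinki-NLP/opus-mt-en-{target_lang}"
            )
        return _translators[target_lang](text)[0]["translation_text"]

    print(translate("Good morning", "fr"))  # e.g. "Bonjour"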
@@ -87,7 +104,7 @@ def classify_tone(text: str) -> str:
         text (str): The input text for tone classification.
 
     Returns:
-        str: The detected emotional label (e.g., 'neutral', 'joy', 'sadness').
+        str: The detected emotional label (e.g., 'neutral', 'joy', 'sadness', 'anger', 'fear', 'disgust', 'surprise').
     """
     # The tone_classifier returns a list of dictionaries, where each dictionary
     # contains 'label' and 'score'. We extract the 'label' from the first (and only) result.
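The expanded label list matches common seven-class emotion models; the repo's actual checkpoint isn't visible in this hunk. A minimal sketch assuming j-hartmann/emotion-english-distilroberta-base, whose labels are exactly the seven named in the docstring:

    from transformers import pipeline

    # Checkpoint is an assumption; its labels match the docstring's list.
    tone_classifier = pipeline(
        "text-classification",
        model="j-hartmann/emotion-english-distilroberta-base",
    )

    def classify_tone(text: str) -> str:
        # The pipeline returns [{'label': ..., 'score': ...}]; keep the top label.
        return tone_classifier(text)[0]["label"]

    print(classify_tone("I can't believe you did that!"))  # e.g. 'surprise'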
app/routers/analyze.py
CHANGED
@@ -51,6 +51,9 @@ def analyze_text(payload: AnalyzeInput):
 
     # --- 1. Grammar Suggestions with Diffs ---
     # Get the grammatically corrected version of the original text.
+    # Note: The 'vennify/t5-base-grammar-correction' model's performance
+    # can vary. For more robust corrections, especially for subtle spelling
+    # and grammar errors, consider a larger or fine-tuned model if needed.
     corrected_grammar = models.run_grammar_correction(text)
 
     grammar_changes = []
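The word-level diffing that fills grammar_changes is outside this hunk; a minimal sketch of how such changes could be computed with Python's standard difflib (the dict shape is an assumption, not the repo's schema):

    import difflib

    def word_diff(original: str, corrected: str) -> list:
        """Collect word-level edits between the original and corrected text."""
        orig_words, corr_words = original.split(), corrected.split()
        matcher = difflib.SequenceMatcher(None, orig_words, corr_words)
        changes = []
        for op, i1, i2, j1, j2 in matcher.get_opcodes():
            if op != "equal":  # op is "replace", "insert", or "delete"
                changes.append({
                    "type": op,
                    "original": " ".join(orig_words[i1:i2]),
                    "suggestion": " ".join(corr_words[j1:j2]),
                })
        return changes

    print(word_diff("She go to school", "She goes to school"))
    # [{'type': 'replace', 'original': 'go', 'suggestion': 'goes'}]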
@@ -97,13 +100,12 @@ def analyze_text(payload: AnalyzeInput):
     tone_suggestion_text = ""
     # Provide a simple tone suggestion based on the detected tone.
     # This logic can be expanded for more sophisticated suggestions based on context or user goals.
-    if detected_tone in ["neutral", "joy"]:
-        #
-
-
-
-
-        tone_suggestion_text = models.run_flan_prompt(prompts.tone_prompt(text, "more uplifting"))
+    if detected_tone in ["neutral", "joy", "sadness", "anger", "fear", "disgust", "surprise"]:
+        # For simplicity, we'll try to make neutral/joy more formal, and other strong emotions more neutral/calm.
+        if detected_tone in ["neutral", "joy"]:
+            tone_suggestion_text = models.run_flan_prompt(prompts.tone_prompt(text, "formal"))
+        else:  # For emotions like anger, sadness, fear, etc., suggest a more neutral/calm tone
+            tone_suggestion_text = models.run_flan_prompt(prompts.tone_prompt(text, "neutral and calm"))
     else:
         # If no specific suggestion, indicate that the detected tone is generally fine.
         tone_suggestion_text = f"The detected tone '{detected_tone}' seems appropriate for general communication."
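prompts.tone_prompt is referenced but not changed in this commit; a plausible minimal sketch, assuming it simply templates a plain-language instruction for FLAN-T5:

    def tone_prompt(text: str, target_tone: str) -> str:
        # Assumed shape; FLAN-T5 handles direct natural-language instructions well.
        return f"Rewrite the following text in a {target_tone} tone: {text}"

With that assumption, the branch above asks for a "formal" rewrite when the tone is already positive or neutral, and a "neutral and calm" rewrite for strong negative emotions.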