iamspruce commited on
Commit
6eff95d
·
1 Parent(s): 869988f

updated models

Browse files
Files changed (2) hide show
  1. app/models.py +22 -5
  2. app/routers/analyze.py +9 -7
app/models.py CHANGED
@@ -2,11 +2,14 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
2
  import torch
3
 
4
  # Set the device for model inference (CPU is used by default)
 
5
  device = torch.device("cpu")
6
 
7
  # --- Grammar model ---
8
  # Uses vennify/t5-base-grammar-correction for grammar correction tasks.
9
- # This model takes text and returns a grammatically corrected version.
 
 
10
  grammar_tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
11
  grammar_model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction").to(device)
12
 
@@ -47,6 +50,7 @@ def run_grammar_correction(text: str) -> str:
47
  def run_flan_prompt(prompt: str) -> str:
48
  """
49
  Runs a given prompt through the FLAN-T5 model to generate a response.
 
50
 
51
  Args:
52
  prompt (str): The prompt string to be processed by FLAN-T5.
@@ -56,8 +60,21 @@ def run_flan_prompt(prompt: str) -> str:
56
  """
57
  # Prepare the input for the FLAN-T5 model
58
  inputs = flan_tokenizer(prompt, return_tensors="pt").to(device)
59
- # Generate the output based on the prompt
60
- outputs = flan_model.generate(**inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Decode the generated tokens back into a readable string
62
  return flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
63
 
@@ -67,7 +84,7 @@ def run_translation(text: str, target_lang: str) -> str:
67
 
68
  Args:
69
  text (str): The input text to be translated.
70
- target_lang (str): The target language code (e.g., "fr" for French).
71
 
72
  Returns:
73
  str: The translated text.
@@ -87,7 +104,7 @@ def classify_tone(text: str) -> str:
87
  text (str): The input text for tone classification.
88
 
89
  Returns:
90
- str: The detected emotional label (e.g., 'neutral', 'joy', 'sadness').
91
  """
92
  # The tone_classifier returns a list of dictionaries, where each dictionary
93
  # contains 'label' and 'score'. We extract the 'label' from the first (and only) result.
 
2
  import torch
3
 
4
  # Set the device for model inference (CPU is used by default)
5
+ # You can change to "cuda" if a compatible GPU is available for faster processing.
6
  device = torch.device("cpu")
7
 
8
  # --- Grammar model ---
9
  # Uses vennify/t5-base-grammar-correction for grammar correction tasks.
10
+ # Note: This model might not catch all subtle spelling or advanced grammar errors
11
+ # as robustly as larger models or rule-based systems. Its performance depends on
12
+ # its training data.
13
  grammar_tokenizer = AutoTokenizer.from_pretrained("vennify/t5-base-grammar-correction")
14
  grammar_model = AutoModelForSeq2SeqLM.from_pretrained("vennify/t5-base-grammar-correction").to(device)
15
 
 
50
  def run_flan_prompt(prompt: str) -> str:
51
  """
52
  Runs a given prompt through the FLAN-T5 model to generate a response.
53
+ Includes advanced generation parameters for better output quality.
54
 
55
  Args:
56
  prompt (str): The prompt string to be processed by FLAN-T5.
 
60
  """
61
  # Prepare the input for the FLAN-T5 model
62
  inputs = flan_tokenizer(prompt, return_tensors="pt").to(device)
63
+
64
+ # Generate the output with improved parameters:
65
+ # max_new_tokens: Limits the maximum length of the generated response.
66
+ # num_beams: Uses beam search for higher quality, less repetitive outputs.
67
+ # do_sample: Enables sampling, allowing for more diverse outputs.
68
+ # top_k, top_p: Control the sampling process, making it more focused and coherent.
69
+ outputs = flan_model.generate(
70
+ **inputs,
71
+ max_new_tokens=100, # Limit output length to prevent rambling
72
+ num_beams=5, # Use beam search for better quality
73
+ do_sample=True, # Enable sampling for diversity
74
+ top_k=50, # Sample from top 50 most probable tokens
75
+ top_p=0.95, # Sample from tokens that cumulatively exceed 95% probability
76
+ temperature=0.7 # Controls randomness; lower means more deterministic
77
+ )
78
  # Decode the generated tokens back into a readable string
79
  return flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
80
 
 
84
 
85
  Args:
86
  text (str): The input text to be translated.
87
+ target_lang (str): The target language code (e.g., "fr" for French, "es" for Spanish).
88
 
89
  Returns:
90
  str: The translated text.
 
104
  text (str): The input text for tone classification.
105
 
106
  Returns:
107
+ str: The detected emotional label (e.g., 'neutral', 'joy', 'sadness', 'anger', 'fear', 'disgust', 'surprise').
108
  """
109
  # The tone_classifier returns a list of dictionaries, where each dictionary
110
  # contains 'label' and 'score'. We extract the 'label' from the first (and only) result.
app/routers/analyze.py CHANGED
@@ -51,6 +51,9 @@ def analyze_text(payload: AnalyzeInput):
51
 
52
  # --- 1. Grammar Suggestions with Diffs ---
53
  # Get the grammatically corrected version of the original text.
 
 
 
54
  corrected_grammar = models.run_grammar_correction(text)
55
 
56
  grammar_changes = []
@@ -97,13 +100,12 @@ def analyze_text(payload: AnalyzeInput):
97
  tone_suggestion_text = ""
98
  # Provide a simple tone suggestion based on the detected tone.
99
  # This logic can be expanded for more sophisticated suggestions based on context or user goals.
100
- if detected_tone in ["neutral", "joy"]: # Example: if text is neutral or joyful, suggest a formal alternative
101
- # Generate a formal tone version using FLAN-T5.
102
- tone_suggestion_text = models.run_flan_prompt(prompts.tone_prompt(text, "formal"))
103
- elif detected_tone == "anger":
104
- tone_suggestion_text = models.run_flan_prompt(prompts.tone_prompt(text, "calm and professional"))
105
- elif detected_tone == "sadness":
106
- tone_suggestion_text = models.run_flan_prompt(prompts.tone_prompt(text, "more uplifting"))
107
  else:
108
  # If no specific suggestion, indicate that the detected tone is generally fine.
109
  tone_suggestion_text = f"The detected tone '{detected_tone}' seems appropriate for general communication."
 
51
 
52
  # --- 1. Grammar Suggestions with Diffs ---
53
  # Get the grammatically corrected version of the original text.
54
+ # Note: The 'vennify/t5-base-grammar-correction' model's performance
55
+ # can vary. For more robust corrections, especially for subtle spelling
56
+ # and grammar errors, consider a larger or fine-tuned model if needed.
57
  corrected_grammar = models.run_grammar_correction(text)
58
 
59
  grammar_changes = []
 
100
  tone_suggestion_text = ""
101
  # Provide a simple tone suggestion based on the detected tone.
102
  # This logic can be expanded for more sophisticated suggestions based on context or user goals.
103
+ if detected_tone in ["neutral", "joy", "sadness", "anger", "fear", "disgust", "surprise"]:
104
+ # For simplicity, we'll try to make neutral/joy more formal, and other strong emotions more neutral/calm.
105
+ if detected_tone in ["neutral", "joy"]:
106
+ tone_suggestion_text = models.run_flan_prompt(prompts.tone_prompt(text, "formal"))
107
+ else: # For emotions like anger, sadness, fear, etc., suggest a more neutral/calm tone
108
+ tone_suggestion_text = models.run_flan_prompt(prompts.tone_prompt(text, "neutral and calm"))
 
109
  else:
110
  # If no specific suggestion, indicate that the detected tone is generally fine.
111
  tone_suggestion_text = f"The detected tone '{detected_tone}' seems appropriate for general communication."