zach commited on
Commit
ec0cbf8
·
1 Parent(s): fcb34bb

Add additional validation logic for prompt and generated text to prevent calling APIs unless valid and displaying the error for user feedback.

Browse files
Files changed (2) hide show
  1. src/app.py +37 -21
  2. src/utils.py +2 -2
src/app.py CHANGED
@@ -37,35 +37,46 @@ from src.theme import CustomTheme
37
  from src.utils import truncate_text, validate_prompt_length
38
 
39
 
40
- def generate_text(prompt: str) -> Union[str, gr.update]:
41
  """
42
- Generates text using the Claude API.
 
 
43
 
44
  Args:
45
- prompt (str): User-provided text prompt.
46
 
47
  Returns:
48
- Union[str, gr.update]: The generated text wrapped in `gr.update` for Gradio UI,
49
- or an error message as a string if validation fails.
 
50
  """
51
- logger.info(f'Generating text with prompt: {truncate_text(prompt, max_length=100)}')
52
  try:
53
- # Validate prompt length
54
- validate_prompt_length(prompt, PROMPT_MAX_LENGTH, PROMPT_MIN_LENGTH)
 
 
55
 
56
- # Generate text
57
  generated_text = generate_text_with_claude(prompt)
 
 
 
58
  logger.info(f'Generated text ({len(generated_text)} characters).')
 
 
 
 
 
 
59
 
60
- return gr.update(value=generated_text)
61
-
62
- except ValueError as ve:
63
- logger.warning(f'Validation error: {ve}')
64
- return str(ve)
65
 
66
  def text_to_speech(prompt: str, generated_text: str) -> tuple[gr.update, gr.update, dict, str | None]:
67
  """
68
  Converts generated text to speech using Hume AI and ElevenLabs APIs.
 
 
 
69
 
70
  Args:
71
  prompt (str): The original user-provided prompt.
@@ -78,6 +89,10 @@ def text_to_speech(prompt: str, generated_text: str) -> tuple[gr.update, gr.upda
78
  - `options_map`: A dictionary mapping OPTION_ONE and OPTION_TWO to their providers.
79
  - `option_2_audio`: The second audio file path or `None` if an error occurs.
80
  """
 
 
 
 
81
  try:
82
  # Generate TTS output in parallel
83
  with ThreadPoolExecutor(max_workers=2) as executor:
@@ -111,8 +126,8 @@ def text_to_speech(prompt: str, generated_text: str) -> tuple[gr.update, gr.upda
111
  )
112
 
113
  except Exception as e:
114
- logger.error(f'Unexpected error: {e}')
115
- return None, None, {}
116
 
117
 
118
  def vote(
@@ -202,6 +217,7 @@ def build_gradio_interface() -> gr.Blocks:
202
  autoscroll=False,
203
  lines=5,
204
  max_lines=5,
 
205
  show_copy_button=True,
206
  )
207
 
@@ -256,14 +272,14 @@ def build_gradio_interface() -> gr.Blocks:
256
  vote_submitted_state
257
  ]
258
  ).then(
259
- # Generate text from user prompt
260
- fn=generate_text,
261
  inputs=[prompt_input],
262
- outputs=[generated_text]
263
  ).then(
264
- # Synthesize text to speech and trigger playback of generated audio
265
  fn=text_to_speech,
266
- inputs=[prompt_input, generated_text],
267
  outputs=[
268
  option1_audio_player,
269
  option2_audio_player,
 
37
  from src.utils import truncate_text, validate_prompt_length
38
 
39
 
40
+ def validate_and_generate_text(prompt: str) -> tuple[Union[str, gr.update], gr.update]:
41
  """
42
+ Validates the prompt before generating text.
43
+ - If valid, returns the generated text and keeps the button disabled.
44
+ - If invalid, returns an error message and re-enables the button.
45
 
46
  Args:
47
+ prompt (str): The user-provided text prompt.
48
 
49
  Returns:
50
+ tuple[Union[str, gr.update], gr.update]:
51
+ - The generated text or an error message.
52
+ - The updated state of the "Generate" button.
53
  """
 
54
  try:
55
+ validate_prompt_length(prompt, PROMPT_MAX_LENGTH, PROMPT_MIN_LENGTH) # Raises error if invalid
56
+ except ValueError as ve:
57
+ logger.warning(f'Validation error: {ve}')
58
+ return str(ve), gr.update(interactive=True) # Show error, re-enable button
59
 
60
+ try:
61
  generated_text = generate_text_with_claude(prompt)
62
+ if not generated_text:
63
+ raise ValueError("Claude API returned an empty response.")
64
+
65
  logger.info(f'Generated text ({len(generated_text)} characters).')
66
+ return gr.update(value=generated_text), gr.update(interactive=False) # Keep button disabled
67
+
68
+ except Exception as e:
69
+ logger.error(f'Error while generating text with Claude API: {e}')
70
+ return "Error: Failed to generate text. Please try again.", gr.update(interactive=True) # Re-enable button
71
+
72
 
 
 
 
 
 
73
 
74
  def text_to_speech(prompt: str, generated_text: str) -> tuple[gr.update, gr.update, dict, str | None]:
75
  """
76
  Converts generated text to speech using Hume AI and ElevenLabs APIs.
77
+
78
+ If the generated text is invalid (empty or an error message), this function
79
+ does nothing and returns `None` values to prevent TTS from running.
80
 
81
  Args:
82
  prompt (str): The original user-provided prompt.
 
89
  - `options_map`: A dictionary mapping OPTION_ONE and OPTION_TWO to their providers.
90
  - `option_2_audio`: The second audio file path or `None` if an error occurs.
91
  """
92
+ if not generated_text or generated_text.startswith("Error:"):
93
+ logger.warning("Skipping TTS generation due to invalid text.")
94
+ return gr.update(value=None), gr.update(value=None), {}, None # Return empty updates
95
+
96
  try:
97
  # Generate TTS output in parallel
98
  with ThreadPoolExecutor(max_workers=2) as executor:
 
126
  )
127
 
128
  except Exception as e:
129
+ logger.error(f'Unexpected error during TTS generation: {e}')
130
+ return gr.update(), gr.update(), {}, None
131
 
132
 
133
  def vote(
 
217
  autoscroll=False,
218
  lines=5,
219
  max_lines=5,
220
+ max_length=PROMPT_MAX_LENGTH,
221
  show_copy_button=True,
222
  )
223
 
 
272
  vote_submitted_state
273
  ]
274
  ).then(
275
+ # Validate and prompt and generate text
276
+ fn=validate_and_generate_text,
277
  inputs=[prompt_input],
278
+ outputs=[generated_text, generate_button] # Ensure button gets re-enabled on failure
279
  ).then(
280
+ # Validate generated text and synthesize text-to-speech
281
  fn=text_to_speech,
282
+ inputs=[prompt_input, generated_text], # Pass prompt & generated text
283
  outputs=[
284
  option1_audio_player,
285
  option2_audio_player,
src/utils.py CHANGED
@@ -107,14 +107,14 @@ def validate_prompt_length(prompt: str, max_length: int, min_length: int) -> Non
107
  # Check if prompt is too short
108
  if prompt_length < min_length:
109
  raise ValueError(
110
- f'Prompt must be at least {min_length} character(s) long. '
111
  f'Received only {prompt_length}.'
112
  )
113
 
114
  # Check if prompt exceeds max length
115
  if prompt_length > max_length:
116
  raise ValueError(
117
- f'The prompt exceeds the maximum allowed length of {max_length} characters. '
118
  f'Your prompt contains {prompt_length} characters.'
119
  )
120
 
 
107
  # Check if prompt is too short
108
  if prompt_length < min_length:
109
  raise ValueError(
110
+ f'Error: prompt must be at least {min_length} character(s) long. '
111
  f'Received only {prompt_length}.'
112
  )
113
 
114
  # Check if prompt exceeds max length
115
  if prompt_length > max_length:
116
  raise ValueError(
117
+ f'Error: the prompt exceeds the maximum allowed length of {max_length} characters. '
118
  f'Your prompt contains {prompt_length} characters.'
119
  )
120