zach commited on
Commit
b50c10f
·
1 Parent(s): fc85b67

Fix types in utils.py

Browse files
Files changed (2) hide show
  1. src/custom_types.py +2 -2
  2. src/utils.py +17 -12
src/custom_types.py CHANGED
@@ -50,8 +50,8 @@ class VotingResults(TypedDict):
50
  winning_option: OptionKey
51
  option_a_provider: TTSProviderName
52
  option_b_provider: TTSProviderName
53
- option_a_generation_id: str
54
- option_b_generation_id: str
55
  character_description: str
56
  text: str
57
  is_custom_text: bool
 
50
  winning_option: OptionKey
51
  option_a_provider: TTSProviderName
52
  option_b_provider: TTSProviderName
53
+ option_a_generation_id: Optional[str]
54
+ option_b_generation_id: Optional[str]
55
  character_description: str
56
  text: str
57
  is_custom_text: bool
src/utils.py CHANGED
@@ -100,7 +100,7 @@ def validate_character_description_length(character_description: str) -> None:
100
  logger.debug(f"Character description length validation passed for character_description: {truncated_description}")
101
 
102
 
103
- def delete_files_older_than(directory: str, minutes: int = 30) -> None:
104
  """
105
  Delete all files in the specified directory that are older than a given number of minutes.
106
 
@@ -274,7 +274,7 @@ def determine_selected_option(
274
  return selected_option, other_option
275
 
276
 
277
- def determine_comparison_type(provider_a: TTSProviderName, provider_b: TTSProviderName) -> ComparisonType:
278
  """
279
  Determine the comparison type based on the given TTS provider names.
280
 
@@ -300,12 +300,17 @@ def determine_comparison_type(provider_a: TTSProviderName, provider_b: TTSProvid
300
  raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")
301
 
302
 
303
- def log_voting_results(voting_results: VotingResults) -> None:
304
  """Log the full voting results."""
305
  logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4))
306
 
307
 
308
- def handle_vote_failure(e: Exception, voting_results: VotingResults, is_dummy_db_session: bool, config: Config) -> None:
 
 
 
 
 
309
  """
310
  Handles logging when creating a vote record fails.
311
 
@@ -318,12 +323,12 @@ def handle_vote_failure(e: Exception, voting_results: VotingResults, is_dummy_db
318
  """
319
  if config.app_env == "prod" or (config.app_env == "dev" and not is_dummy_db_session):
320
  logger.error("Failed to create vote record: %s", e, exc_info=(config.app_env == "prod"))
321
- log_voting_results(voting_results)
322
  if config.app_env == "prod":
323
  raise e
324
  else:
325
  # Dev mode with a dummy session: only log the voting results.
326
- log_voting_results(voting_results)
327
 
328
 
329
  def _persist_vote(db_session_maker: DBSessionMaker, voting_results: VotingResults, config: Config) -> None:
@@ -331,22 +336,22 @@ def _persist_vote(db_session_maker: DBSessionMaker, voting_results: VotingResult
331
  is_dummy_db_session = getattr(db, "is_dummy", False)
332
  if is_dummy_db_session:
333
  logger.info("Vote record created successfully.")
334
- log_voting_results(voting_results)
335
  try:
336
  crud.create_vote(cast(Session, db), voting_results)
337
  except Exception as e:
338
- handle_vote_failure(e, voting_results, is_dummy_db_session, config)
339
  else:
340
  logger.info("Vote record created successfully.")
341
  if config.app_env == "dev":
342
- log_voting_results(voting_results)
343
  finally:
344
  db.close()
345
 
346
 
347
  def submit_voting_results(
348
  option_map: OptionMap,
349
- selected_option: str,
350
  text_modified: bool,
351
  character_description: str,
352
  text: str,
@@ -366,7 +371,7 @@ def submit_voting_results(
366
  """
367
  provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
368
  provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
369
- comparison_type: ComparisonType = determine_comparison_type(provider_a, provider_b)
370
 
371
  voting_results: VotingResults = {
372
  "comparison_type": comparison_type,
@@ -376,7 +381,7 @@ def submit_voting_results(
376
  "option_b_provider": provider_b,
377
  "option_a_generation_id": option_map[constants.OPTION_A_KEY]["generation_id"],
378
  "option_b_generation_id": option_map[constants.OPTION_B_KEY]["generation_id"],
379
- "voice_description": character_description,
380
  "text": text,
381
  "is_custom_text": text_modified,
382
  }
 
100
  logger.debug(f"Character description length validation passed for character_description: {truncated_description}")
101
 
102
 
103
+ def delete_files_older_than(directory: Path, minutes: int = 30) -> None:
104
  """
105
  Delete all files in the specified directory that are older than a given number of minutes.
106
 
 
274
  return selected_option, other_option
275
 
276
 
277
+ def _determine_comparison_type(provider_a: TTSProviderName, provider_b: TTSProviderName) -> ComparisonType:
278
  """
279
  Determine the comparison type based on the given TTS provider names.
280
 
 
300
  raise ValueError(f"Invalid provider combination: {provider_a}, {provider_b}")
301
 
302
 
303
+ def _log_voting_results(voting_results: VotingResults) -> None:
304
  """Log the full voting results."""
305
  logger.info("Voting results:\n%s", json.dumps(voting_results, indent=4))
306
 
307
 
308
+ def _handle_vote_failure(
309
+ e: Exception,
310
+ voting_results: VotingResults,
311
+ is_dummy_db_session: bool,
312
+ config: Config,
313
+ ) -> None:
314
  """
315
  Handles logging when creating a vote record fails.
316
 
 
323
  """
324
  if config.app_env == "prod" or (config.app_env == "dev" and not is_dummy_db_session):
325
  logger.error("Failed to create vote record: %s", e, exc_info=(config.app_env == "prod"))
326
+ _log_voting_results(voting_results)
327
  if config.app_env == "prod":
328
  raise e
329
  else:
330
  # Dev mode with a dummy session: only log the voting results.
331
+ _log_voting_results(voting_results)
332
 
333
 
334
  def _persist_vote(db_session_maker: DBSessionMaker, voting_results: VotingResults, config: Config) -> None:
 
336
  is_dummy_db_session = getattr(db, "is_dummy", False)
337
  if is_dummy_db_session:
338
  logger.info("Vote record created successfully.")
339
+ _log_voting_results(voting_results)
340
  try:
341
  crud.create_vote(cast(Session, db), voting_results)
342
  except Exception as e:
343
+ _handle_vote_failure(e, voting_results, is_dummy_db_session, config)
344
  else:
345
  logger.info("Vote record created successfully.")
346
  if config.app_env == "dev":
347
+ _log_voting_results(voting_results)
348
  finally:
349
  db.close()
350
 
351
 
352
  def submit_voting_results(
353
  option_map: OptionMap,
354
+ selected_option: OptionKey,
355
  text_modified: bool,
356
  character_description: str,
357
  text: str,
 
371
  """
372
  provider_a: TTSProviderName = option_map[constants.OPTION_A_KEY]["provider"]
373
  provider_b: TTSProviderName = option_map[constants.OPTION_B_KEY]["provider"]
374
+ comparison_type: ComparisonType = _determine_comparison_type(provider_a, provider_b)
375
 
376
  voting_results: VotingResults = {
377
  "comparison_type": comparison_type,
 
381
  "option_b_provider": provider_b,
382
  "option_a_generation_id": option_map[constants.OPTION_A_KEY]["generation_id"],
383
  "option_b_generation_id": option_map[constants.OPTION_B_KEY]["generation_id"],
384
+ "character_description": character_description,
385
  "text": text,
386
  "is_custom_text": text_modified,
387
  }