Spaces:
Running
Running
File size: 5,472 Bytes
20cccb6 0f77dec 20cccb6 de305ed 40403f3 20cccb6 de305ed 20cccb6 40403f3 20cccb6 40403f3 20cccb6 de305ed 20cccb6 de305ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
"""
crud.py
This module defines the operations for the Expressive TTS Arena project's database.
Since vote records are never updated or deleted, only functions to create and read votes are provided.
"""
# Third-Party Library Imports
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
# Local Application Imports
from src.config import logger
from src.custom_types import LeaderboardEntry, LeaderboardTableEntries, VotingResults
from src.database.models import VoteResult
async def _rollback_quietly(db: AsyncSession) -> None:
    """
    Roll back the session, logging (rather than raising) any error from the rollback itself.

    Used on failure paths so that the exception which triggered the rollback is the one
    propagated to the caller, instead of being replaced by a secondary rollback failure.

    Args:
        db (AsyncSession): The SQLAlchemy async database session. ``None`` is ignored.
    """
    if db is None:
        return
    try:
        await db.rollback()
    except Exception as rollback_error:
        logger.error(f"Error during rollback operation: {rollback_error}")


async def create_vote(db: AsyncSession, vote_data: VotingResults) -> VoteResult:
    """
    Create a new vote record in the database based on the given VotingResults data.

    Args:
        db (AsyncSession): The SQLAlchemy async database session.
        vote_data (VotingResults): The vote data to persist. Its ``character_description``
            value is stored in the record's ``voice_description`` column.

    Returns:
        VoteResult: The newly created vote record, refreshed after commit.

    Raises:
        SQLAlchemyError: If adding/committing/refreshing the record fails. The session is
            rolled back and the error is logged once before being re-raised.
        ValueError: Logged as invalid vote data; the session is rolled back and the error
            is re-raised.
        Exception: Any other error (e.g. a ``KeyError`` for a missing key in ``vote_data``)
            is logged, the session is rolled back, and the error is re-raised.
    """
    try:
        vote = VoteResult(
            comparison_type=vote_data["comparison_type"],
            winning_provider=vote_data["winning_provider"],
            winning_option=vote_data["winning_option"],
            option_a_provider=vote_data["option_a_provider"],
            option_b_provider=vote_data["option_b_provider"],
            option_a_generation_id=vote_data["option_a_generation_id"],
            option_b_generation_id=vote_data["option_b_generation_id"],
            voice_description=vote_data["character_description"],
            text=vote_data["text"],
            is_custom_text=vote_data["is_custom_text"],
        )
        db.add(vote)
        await db.commit()
        # Reload server-generated fields (e.g. the primary key logged below).
        await db.refresh(vote)
    except SQLAlchemyError as db_error:
        # Handled here exactly once: a single rollback and a single log line. The original
        # error is re-raised even if the rollback itself fails.
        await _rollback_quietly(db)
        logger.error(f"Database error while creating vote: {db_error}")
        raise
    except ValueError as val_error:
        # db.add() may already have run, so leave the session clean for the caller.
        await _rollback_quietly(db)
        logger.error(f"Invalid vote data: {val_error}")
        raise
    except Exception as e:
        await _rollback_quietly(db)
        logger.error(f"Unexpected error creating vote record: {e}")
        raise

    logger.info(f"Vote record created successfully: ID={vote.id}")
    return vote
async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
    """
    Fetch voting statistics from the database to populate a leaderboard.

    Computes per-provider statistics for Hume AI and ElevenLabs over all votes whose
    ``comparison_type`` is not 'Hume AI - Hume AI'. Returns one ranked entry per provider,
    ordered by win rate (highest first, ties broken alphabetically by provider name).

    Args:
        db (AsyncSession): The SQLAlchemy async database session.

    Returns:
        LeaderboardTableEntries: A list of LeaderboardEntry objects. Each entry holds:
            - the rank;
            - the provider name;
            - the model name ('Octave' for Hume AI, 'Voice Design' for ElevenLabs);
            - the win rate (a percentage string, e.g. "50.00%");
            - the provider's win count, in the ``votes`` field.
        If the query fails or yields no rows, two placeholder entries with empty
        provider/model are returned instead.
    """
    # Built per call so callers never share (and possibly mutate) one list instance.
    default_leaderboard = [
        LeaderboardEntry("1", "", "", "0%", "0"),
        LeaderboardEntry("2", "", "", "0%", "0"),
    ]
    try:
        # Each CTE branch is an aggregate without GROUP BY, so it always yields exactly
        # one row, even when vote_results is empty. SUM() over zero rows is NULL, so it is
        # COALESCEd to 0; otherwise the leaderboard would show "None" as the vote count.
        query = text(
            """
            WITH provider_stats AS (
                -- Get wins for Hume AI
                SELECT
                    'Hume AI' as provider,
                    COUNT(*) as total_comparisons,
                    COALESCE(
                        SUM(CASE WHEN winning_provider = 'Hume AI' THEN 1 ELSE 0 END), 0
                    ) as wins
                FROM vote_results
                WHERE comparison_type != 'Hume AI - Hume AI'
                UNION ALL
                -- Get wins for ElevenLabs
                SELECT
                    'ElevenLabs' as provider,
                    COUNT(*) as total_comparisons,
                    COALESCE(
                        SUM(CASE WHEN winning_provider = 'ElevenLabs' THEN 1 ELSE 0 END), 0
                    ) as wins
                FROM vote_results
                WHERE comparison_type != 'Hume AI - Hume AI'
            )
            SELECT
                provider,
                CASE
                    WHEN provider = 'Hume AI' THEN 'Octave'
                    WHEN provider = 'ElevenLabs' THEN 'Voice Design'
                END as model,
                CASE
                    WHEN total_comparisons > 0 THEN ROUND((wins * 100.0 / total_comparisons)::numeric, 2)
                    ELSE 0
                END as win_rate,
                wins as total_votes
            FROM provider_stats
            -- Secondary key keeps ranks stable when win rates tie (e.g. both 0%).
            ORDER BY win_rate DESC, provider ASC;
            """
        )
        result = await db.execute(query)
        rows = result.fetchall()

        # Format the rows for the leaderboard; ranks are 1-based in query order.
        leaderboard_data = [
            LeaderboardEntry(
                rank=f"{rank}",
                provider=provider,
                model=model,
                win_rate=f"{win_rate}%",
                votes=f"{total_votes}",
            )
            for rank, (provider, model, win_rate, total_votes) in enumerate(rows, start=1)
        ]

        # If no data was found, return default entries
        return leaderboard_data or default_leaderboard
    except SQLAlchemyError as e:
        logger.error(f"Database error while fetching leaderboard stats: {e}")
        return default_leaderboard
    except Exception as e:
        logger.error(f"Unexpected error while fetching leaderboard stats: {e}")
        return default_leaderboard
|