File size: 5,472 Bytes
20cccb6
 
 
0f77dec
20cccb6
 
 
 
de305ed
 
40403f3
20cccb6
 
de305ed
 
20cccb6
 
 
40403f3
20cccb6
 
 
 
40403f3
20cccb6
 
 
 
 
 
de305ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20cccb6
de305ed
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
crud.py

This module defines the operations for the Expressive TTS Arena project's database.
Since vote records are never updated or deleted, only functions to create and read votes are provided.
"""

# Third-Party Library Imports
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

# Local Application Imports
from src.config import logger
from src.custom_types import LeaderboardEntry, LeaderboardTableEntries, VotingResults
from src.database.models import VoteResult


async def _rollback_quietly(db: AsyncSession) -> None:
    """Roll back the session, logging (not raising) any failure of the rollback itself."""
    if db is None:
        return
    try:
        await db.rollback()
    except Exception as rollback_error:
        # A failed rollback must not mask the error that triggered it.
        logger.error(f"Error during rollback operation: {rollback_error}")


async def create_vote(db: AsyncSession, vote_data: VotingResults) -> VoteResult:
    """
    Create a new vote record in the database based on the given VotingResults data.

    The record is added to the session, committed, and refreshed before it is returned.
    Each failure is logged exactly once and then re-raised to the caller.

    Args:
        db (AsyncSession): The SQLAlchemy async database session.
        vote_data (VotingResults): The vote data to persist. It is read by key; note that
            its "character_description" value is stored in the record's
            ``voice_description`` field.

    Returns:
        VoteResult: The newly created vote record.

    Raises:
        SQLAlchemyError: If the database operation fails (the session is rolled back first).
        ValueError: If a ValueError is raised while building or persisting the record.
        Exception: Any other error, re-raised after a best-effort rollback.
    """
    try:
        # Create vote record
        vote = VoteResult(
            comparison_type=vote_data["comparison_type"],
            winning_provider=vote_data["winning_provider"],
            winning_option=vote_data["winning_option"],
            option_a_provider=vote_data["option_a_provider"],
            option_b_provider=vote_data["option_b_provider"],
            option_a_generation_id=vote_data["option_a_generation_id"],
            option_b_generation_id=vote_data["option_b_generation_id"],
            voice_description=vote_data["character_description"],
            text=vote_data["text"],
            is_custom_text=vote_data["is_custom_text"],
        )

        db.add(vote)
        await db.commit()
        await db.refresh(vote)
    except SQLAlchemyError as db_error:
        # Handled in its own branch so a database failure is rolled back and logged
        # once, instead of also falling through to the generic handler below.
        await _rollback_quietly(db)
        logger.error(f"Database error while creating vote: {db_error}")
        raise
    except ValueError as val_error:
        logger.error(f"Invalid vote data: {val_error}")
        raise
    except Exception as e:
        await _rollback_quietly(db)
        logger.error(f"Unexpected error creating vote record: {e}")
        raise

    logger.info(f"Vote record created successfully: ID={vote.id}")
    return vote


async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
    """
    Fetches voting statistics from the database to populate a leaderboard.

    This function calculates voting statistics for the 'Hume AI' and 'ElevenLabs'
    providers, excluding Hume-to-Hume comparisons, and returns data structured for a
    leaderboard display. Both providers share the same denominator (all comparisons
    that are not 'Hume AI - Hume AI'), and the reported vote count is the provider's
    number of wins. Rows are ranked by win rate, highest first.

    Errors are logged and never raised: on any failure a placeholder leaderboard
    is returned instead.

    Args:
        db (AsyncSession): The SQLAlchemy async database session.

    Returns:
        LeaderboardTableEntries: A list of LeaderboardEntry objects containing rank,
                                provider name, model name, win rate, and total votes.
    """
    default_leaderboard = [
        LeaderboardEntry("1", "", "", "0%", "0"),
        LeaderboardEntry("2", "", "", "0%", "0")
    ]

    try:
        # NOTE(review): the `::numeric` cast looks PostgreSQL-specific; confirm the
        # target database before running this against another backend.
        query = text(
            """
            WITH provider_stats AS (
                -- Get wins for Hume AI
                SELECT
                    'Hume AI' as provider,
                    COUNT(*) as total_comparisons,
                    SUM(CASE WHEN winning_provider = 'Hume AI' THEN 1 ELSE 0 END) as wins
                FROM vote_results
                WHERE comparison_type != 'Hume AI - Hume AI'

                UNION ALL

                -- Get wins for ElevenLabs
                SELECT
                    'ElevenLabs' as provider,
                    COUNT(*) as total_comparisons,
                    SUM(CASE WHEN winning_provider = 'ElevenLabs' THEN 1 ELSE 0 END) as wins
                FROM vote_results
                WHERE comparison_type != 'Hume AI - Hume AI'
            )
            SELECT
                provider,
                CASE
                    WHEN provider = 'Hume AI' THEN 'Octave'
                    WHEN provider = 'ElevenLabs' THEN 'Voice Design'
                END as model,
                CASE
                    WHEN total_comparisons > 0 THEN ROUND((wins * 100.0 / total_comparisons)::numeric, 2)
                    ELSE 0
                END as win_rate,
                wins as total_votes
            FROM provider_stats
            ORDER BY win_rate DESC;
            """
        )

        result = await db.execute(query)
        rows = result.fetchall()

        # Format the data for the leaderboard. SUM() over zero matching rows yields
        # NULL, so a missing win count is rendered as 0 rather than the string "None".
        leaderboard_data = [
            LeaderboardEntry(
                rank=f"{rank}",
                provider=provider,
                model=model,
                win_rate=f"{win_rate}%",
                votes=f"{wins if wins is not None else 0}",
            )
            for rank, (provider, model, win_rate, wins) in enumerate(rows, 1)
        ]

        # If no data was found, return default entries
        if not leaderboard_data:
            return default_leaderboard

        return leaderboard_data

    except SQLAlchemyError as e:
        logger.error(f"Database error while fetching leaderboard stats: {e}")
        return default_leaderboard
    except Exception as e:
        logger.error(f"Unexpected error while fetching leaderboard stats: {e}")
        return default_leaderboard