File size: 2,396 Bytes
20cccb6
 
 
 
 
 
8e2cadd
 
20cccb6
 
8e2cadd
1ed6720
8e2cadd
20cccb6
1ed6720
20cccb6
 
 
1ed6720
8e2cadd
 
1ed6720
 
8e2cadd
1ed6720
 
8e2cadd
1ed6720
 
8e2cadd
1ed6720
 
 
8e2cadd
1ed6720
 
 
8e2cadd
1ed6720
 
 
20cccb6
1ed6720
 
 
20cccb6
1ed6720
 
 
20cccb6
 
 
1ed6720
 
a4afe51
1ed6720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
database.py

This module sets up the SQLAlchemy database connection for the Expressive TTS Arena project.
It initializes the PostgreSQL engine, creates a session factory for handling database transactions,
and defines a declarative base class for ORM models.

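If no DATABASE_URL is configured outside of production, a dummy session factory is used instead so
the application can still run without a database (writes fail at commit time). In production, a
missing DATABASE_URL raises an error.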
"""

# Standard Library Imports
from typing import Callable, Optional

# Third-Party Library Imports
from sqlalchemy import Engine, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

# Local Application Imports
from src.config import Config


class DummySession:
    """Placeholder session handed out when no real database is configured.

    It exposes the context-manager protocol plus ``add``, ``commit``,
    ``refresh``, ``rollback`` and ``close``. Calls that merely stage or discard
    work are silent no-ops; ``commit`` and ``refresh`` raise ``RuntimeError``
    so that an attempted write surfaces as a failure.
    """

    # Flag that identifies this object as a dummy session.
    is_dummy = True

    def __enter__(self):
        """Enter a ``with`` block, yielding this same object."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave a ``with`` block; returns ``None`` so exceptions propagate."""
        return None

    def add(self, _instance, _warn=True):
        """Accept ``_instance`` and discard it; nothing is stored."""
        return None

    def commit(self):
        """Always fail, simulating a write that could not be performed."""
        raise RuntimeError("DummySession does not support commit operations.")

    def refresh(self, _instance):
        """Always fail, since there is no stored state to reload from."""
        raise RuntimeError("DummySession does not support refresh operations.")

    def rollback(self):
        """Do nothing: no pending work exists to undo."""
        return None

    def close(self):
        """Do nothing: no resources are held."""
        return None


# Declarative base class for the project's ORM models.
Base = declarative_base()
# Module-level engine handle. It starts as None; init_db() rebinds it to the
# engine it creates, or back to None when it falls back to the dummy session.
engine: Optional[Engine] = None


# Type of the value returned by init_db(): either a SQLAlchemy sessionmaker or
# a zero-argument callable that produces DummySession objects.
DBSessionMaker = sessionmaker | Callable[[], DummySession]


def init_db(config: Config) -> DBSessionMaker:
    """Create the session factory that matches the given configuration.

    When ``config.database_url`` is set, an engine is created for it, stored in
    the module-level ``engine`` variable, and a ``sessionmaker`` bound to that
    engine is returned. When it is not set, the outcome depends on
    ``config.app_env``: production refuses to start, any other environment
    gets a factory producing ``DummySession`` objects and ``engine`` is reset
    to ``None``.

    Args:
        config: Application configuration; ``app_env`` and ``database_url``
            are the only attributes read here.

    Returns:
        A ``sessionmaker`` bound to the new engine, or a zero-argument callable
        returning a fresh ``DummySession`` on each call.

    Raises:
        ValueError: If ``config.app_env`` is ``"prod"`` and no database URL is
            configured.
    """
    # ruff objects to rebinding module globals; doing so is deliberate here so
    # the engine is reachable at module level.
    global engine  # noqa

    environment = config.app_env
    url = config.database_url

    # A configured URL is honoured in every environment.
    if url:
        engine = create_engine(url)
        return sessionmaker(bind=engine)

    # From here on there is no URL: production must not continue without one.
    if environment == "prod":
        raise ValueError("DATABASE_URL must be set in production!")

    # Outside production, fall back to sessions that do nothing.
    engine = None

    def dummy_session_factory():
        return DummySession()

    return dummy_session_factory