File size: 3,369 Bytes
054900e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import func, select, update
from bot.cache.redis import build_key, cached, clear_cache
from bot.database.models import UserModel
if TYPE_CHECKING:
from aiogram.types import User
from sqlalchemy.ext.asyncio import AsyncSession
async def add_user(
    session: AsyncSession,
    user: User,
    referrer: str | None,
) -> None:
    """Insert a new ``UserModel`` row built from an aiogram ``User`` and commit.

    Args:
        session: Async SQLAlchemy session used for the insert and commit.
        user: The aiogram user whose id and profile fields are copied.
        referrer: Optional referrer value stored on the new row.

    After the commit, every cache entry that could now be stale is cleared:
    the per-user lookups for this id and the argument-free aggregates
    (user count and user list).
    """
    new_user = UserModel(
        id=user.id,
        first_name=user.first_name,
        last_name=user.last_name,
        username=user.username,
        language_code=user.language_code,
        # ``is_premium`` is optional on aiogram's User (may be None); store False.
        is_premium=user.is_premium or False,
        referrer=referrer,
    )
    session.add(new_user)
    await session.commit()

    # Lookups cached before the row existed (e.g. user_exists -> False,
    # get_first_name -> "") no longer reflect the database.
    for cached_lookup in (user_exists, get_first_name, get_language_code, is_admin):
        await clear_cache(cached_lookup, user.id)
    # The insert changes both aggregates; their key_builders take no
    # arguments besides the session, so no extra args are passed here.
    await clear_cache(get_user_count)
    await clear_cache(get_all_users)
@cached(key_builder=lambda session, user_id: build_key(user_id))
async def user_exists(session: AsyncSession, user_id: int) -> bool:
    """Return True if a ``UserModel`` row with ``user_id`` exists.

    The result is cached per ``user_id``; code that inserts a user must
    clear it with ``clear_cache(user_exists, user_id)``.
    """
    query = select(UserModel.id).filter_by(id=user_id).limit(1)
    result = await session.execute(query)
    # Compare against None instead of using truthiness: a matching row whose
    # id is 0 would otherwise be reported as absent.
    return result.scalar_one_or_none() is not None
@cached(key_builder=lambda session, user_id: build_key(user_id))
async def get_first_name(session: AsyncSession, user_id: int) -> str:
    """Look up the stored first name for ``user_id``.

    Returns an empty string when there is no matching row or the stored
    value is NULL/empty. The result is cached per ``user_id``.
    """
    rows = await session.execute(
        select(UserModel.first_name).filter_by(id=user_id),
    )
    name = rows.scalar_one_or_none()
    if name:
        return name
    return ""
@cached(key_builder=lambda session, user_id: build_key(user_id))
async def get_language_code(session: AsyncSession, user_id: int) -> str:
    """Look up the stored language code for ``user_id``.

    Returns an empty string when there is no matching row or the stored
    value is NULL/empty. The result is cached per ``user_id``.
    """
    rows = await session.execute(
        select(UserModel.language_code).filter_by(id=user_id),
    )
    code = rows.scalar_one_or_none()
    if code:
        return code
    return ""
async def set_language_code(
    session: AsyncSession,
    user_id: int,
    language_code: str,
) -> None:
    """Update the stored language code for ``user_id`` and commit.

    ``get_language_code`` caches its result per ``user_id``. That cache entry
    is cleared after the commit; otherwise callers served from the cache
    would keep seeing the previous language code.
    """
    stmt = (
        update(UserModel)
        .where(UserModel.id == user_id)
        .values(language_code=language_code)
    )
    await session.execute(stmt)
    await session.commit()
    await clear_cache(get_language_code, user_id)
@cached(key_builder=lambda session, user_id: build_key(user_id))
async def is_admin(session: AsyncSession, user_id: int) -> bool:
    """Report whether ``user_id`` has the admin flag set.

    A missing row or a NULL flag yields False. The result is cached per
    ``user_id``.
    """
    stmt = select(UserModel.is_admin).filter_by(id=user_id)
    flag = (await session.execute(stmt)).scalar_one_or_none()
    return bool(flag)
# Inside ``set_is_admin`` the ``is_admin`` parameter shadows the cached lookup
# function of the same name, so keep a module-level handle to the function.
_is_admin_lookup = is_admin


async def set_is_admin(session: AsyncSession, user_id: int, is_admin: bool) -> None:
    """Set the admin flag for ``user_id`` and commit.

    The cached ``is_admin`` lookup for this user is cleared after the commit;
    otherwise callers served from the cache would keep seeing the old flag.
    """
    stmt = update(UserModel).where(UserModel.id == user_id).values(is_admin=is_admin)
    await session.execute(stmt)
    await session.commit()
    # ``is_admin`` here is the bool argument, not the function; use the alias.
    await clear_cache(_is_admin_lookup, user_id)
@cached(key_builder=lambda session: build_key())
async def get_all_users(session: AsyncSession) -> list[UserModel]:
    """Fetch every ``UserModel`` row and return them as a list.

    Cached under a single key that does not depend on any argument.
    NOTE(review): this caches ORM instances; confirm the cache serializer can
    round-trip ``UserModel`` objects.
    """
    outcome = await session.execute(select(UserModel))
    return [row for row in outcome.scalars()]
@cached(key_builder=lambda session: build_key())
async def get_user_count(session: AsyncSession) -> int:
    """Return the number of ``UserModel`` rows, or 0 if the query yields nothing.

    Cached under a single key that does not depend on any argument.
    """
    stmt = select(func.count()).select_from(UserModel)
    total = (await session.execute(stmt)).scalar_one_or_none()
    if not total:
        total = 0
    return int(total)
|