Initial deployment of MailPilot application
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .env +18 -0
- Dockerfile +62 -0
- README.md +14 -6
- Spacefile +2 -0
- app/__init__.py +0 -0
- app/__pycache__/__init__.cpython-312.pyc +0 -0
- app/__pycache__/main.cpython-312.pyc +0 -0
- app/__pycache__/router.cpython-312.pyc +0 -0
- app/api/__init__.py +0 -0
- app/api/__pycache__/__init__.cpython-312.pyc +0 -0
- app/api/endpoints/__init__.py +0 -0
- app/api/endpoints/__pycache__/__init__.cpython-312.pyc +0 -0
- app/api/endpoints/v1/__init__.py +0 -0
- app/api/endpoints/v1/__pycache__/__init__.cpython-312.pyc +0 -0
- app/api/endpoints/v1/firebaseauth/__init__.py +0 -0
- app/api/endpoints/v1/firebaseauth/__pycache__/__init__.cpython-312.pyc +0 -0
- app/api/endpoints/v1/firebaseauth/__pycache__/app.cpython-312.pyc +0 -0
- app/api/endpoints/v1/firebaseauth/app.py +361 -0
- app/api/endpoints/v1/login/__init__.py +0 -0
- app/api/endpoints/v1/login/__pycache__/__init__.cpython-312.pyc +0 -0
- app/api/endpoints/v1/login/__pycache__/api.cpython-312.pyc +0 -0
- app/api/endpoints/v1/login/api.py +10 -0
- app/core/__init__.py +0 -0
- app/core/__pycache__/__init__.cpython-312.pyc +0 -0
- app/core/__pycache__/config.cpython-312.pyc +0 -0
- app/core/__pycache__/logger.cpython-312.pyc +0 -0
- app/core/cache/__init__.py +0 -0
- app/core/cache/cache.py +112 -0
- app/core/config.py +97 -0
- app/core/database/__init__.py +0 -0
- app/core/database/__pycache__/__init__.cpython-312.pyc +0 -0
- app/core/database/__pycache__/session_manager.cpython-312.pyc +0 -0
- app/core/database/session_manager.py +64 -0
- app/core/logger.py +181 -0
- app/core/middlewares/__init__.py +20 -0
- app/core/middlewares/__pycache__/__init__.cpython-312.pyc +0 -0
- app/core/middlewares/__pycache__/execution_middleware.cpython-312.pyc +0 -0
- app/core/middlewares/execution_middleware.py +22 -0
- app/llm/llm_interface.py +14 -0
- app/llm/provider/bedrock_provider.py +63 -0
- app/llm/token/token_manager.py +37 -0
- app/main.py +34 -0
- app/migrations/__init__.py +0 -0
- app/migrations/__pycache__/env.cpython-312.pyc +0 -0
- app/migrations/alembic.ini +101 -0
- app/migrations/env.py +96 -0
- app/migrations/script.py.mako +24 -0
- app/migrations/utils.py +24 -0
- app/migrations/versions/2025041655_new_migration_0c372b179073.py +32 -0
- app/migrations/versions/__pycache__/2025041655_new_migration_0c372b179073.cpython-312.pyc +0 -0
.env
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ENVIRONMENT=PROD
|
2 |
+
DATABASE_HOSTNAME=ep-royal-meadow-a4zzp6z8-pooler.us-east-1.aws.neon.tech
|
3 |
+
DATABASE_USER=neondb_owner
|
4 |
+
DATABASE_PASSWORD=npg_Kuh24FTfEsrx
|
5 |
+
DATABASE_PORT=5432
|
6 |
+
DATABASE_DB=neondb
|
7 |
+
DATABASE_SSL_MODE=require
|
8 |
+
|
9 |
+
CACHE_HOST=localhost
|
10 |
+
CACHE_PORT=11211
|
11 |
+
CACHE_TTL=300
|
12 |
+
|
13 |
+
UVICORN_HOST=0.0.0.0
|
14 |
+
UVICORN_PORT=7860
|
15 |
+
|
16 |
+
LOG_LEVEL=INFO
|
17 |
+
LOG_JSON_FORMAT=False
|
18 |
+
ROOT_PATH=/
|
Dockerfile
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#
|
2 |
+
# Dependencies stage
|
3 |
+
#
|
4 |
+
FROM python:3.12-slim-bullseye AS deps
|
5 |
+
|
6 |
+
ENV POETRY_VERSION 1.5.1
|
7 |
+
|
8 |
+
RUN apt-get update && apt-get install --no-install-recommends -y \
|
9 |
+
gcc \
|
10 |
+
libc-dev \
|
11 |
+
libpq-dev \
|
12 |
+
libpq5 \
|
13 |
+
&& rm -rf /var/lib/apt/lists/*
|
14 |
+
|
15 |
+
WORKDIR /tmp
|
16 |
+
COPY ./pyproject.toml /tmp
|
17 |
+
|
18 |
+
RUN pip install --no-cache-dir email-validator==2.1.0
|
19 |
+
COPY requirements.txt .
|
20 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
21 |
+
RUN pip install -q --no-cache-dir poetry==$POETRY_VERSION \
|
22 |
+
&& poetry lock -q -n \
|
23 |
+
&& poetry export -f requirements.txt -o /tmp/requirements.txt --without-hashes \
|
24 |
+
&& pip uninstall -y poetry \
|
25 |
+
&& pip install --no-cache-dir -q -r /tmp/requirements.txt
|
26 |
+
|
27 |
+
#
|
28 |
+
# Base stage
|
29 |
+
#
|
30 |
+
FROM python:3.12-slim-bullseye AS base
|
31 |
+
|
32 |
+
ENV APP_NAME MailPilot_ai_agents
|
33 |
+
ENV PREFIX /opt/MailPilot
|
34 |
+
ENV PREFIX_APP ${PREFIX}/${APP_NAME}
|
35 |
+
|
36 |
+
ENV PYTHONUNBUFFERED 1
|
37 |
+
|
38 |
+
RUN groupadd -g 20001 MailPilot \
|
39 |
+
&& useradd -l -M -u 10001 -g MailPilot MailPilot
|
40 |
+
|
41 |
+
WORKDIR ${PREFIX_APP}
|
42 |
+
|
43 |
+
COPY ./docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh
|
44 |
+
RUN chmod +x /usr/local/bin/docker-entrypoint.sh
|
45 |
+
|
46 |
+
RUN apt-get update && apt-get install --no-install-recommends -y libpq5 postgresql-client \
|
47 |
+
&& rm -rf /var/lib/apt/lists/*
|
48 |
+
|
49 |
+
COPY --from=deps /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
|
50 |
+
COPY --from=deps /usr/local/bin /usr/local/bin
|
51 |
+
COPY . ${PREFIX_APP}
|
52 |
+
|
53 |
+
RUN chown -R MailPilot:MailPilot ${PREFIX_APP}
|
54 |
+
|
55 |
+
# Hugging Face specific configuration
|
56 |
+
EXPOSE 7860
|
57 |
+
ENV UVICORN_PORT=7860
|
58 |
+
ENV UVICORN_HOST=0.0.0.0
|
59 |
+
|
60 |
+
USER MailPilot
|
61 |
+
|
62 |
+
CMD ["uvicorn", "app.main:fastapi_app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
@@ -1,10 +1,18 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: docker
|
7 |
-
|
8 |
---
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: MailPilot AI Agents
|
3 |
+
emoji: 📧
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
sdk: docker
|
7 |
+
app_port: 7860
|
8 |
---
|
9 |
|
10 |
+
# MailPilot AI Agents API
|
11 |
+
|
12 |
+
FastAPI-based AI agent application for email processing and analysis.
|
13 |
+
|
14 |
+
## API Documentation
|
15 |
+
|
16 |
+
Once deployed, API documentation will be available at:
|
17 |
+
- Swagger UI: `/docs`
|
18 |
+
- ReDoc: `/redoc`
|
Spacefile
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
dockerfile: Dockerfile.huggingface
|
2 |
+
port: 7860
|
app/__init__.py
ADDED
File without changes
|
app/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (147 Bytes). View file
|
|
app/__pycache__/main.cpython-312.pyc
ADDED
Binary file (1.33 kB). View file
|
|
app/__pycache__/router.cpython-312.pyc
ADDED
Binary file (605 Bytes). View file
|
|
app/api/__init__.py
ADDED
File without changes
|
app/api/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (151 Bytes). View file
|
|
app/api/endpoints/__init__.py
ADDED
File without changes
|
app/api/endpoints/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (161 Bytes). View file
|
|
app/api/endpoints/v1/__init__.py
ADDED
File without changes
|
app/api/endpoints/v1/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (164 Bytes). View file
|
|
app/api/endpoints/v1/firebaseauth/__init__.py
ADDED
File without changes
|
app/api/endpoints/v1/firebaseauth/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (177 Bytes). View file
|
|
app/api/endpoints/v1/firebaseauth/__pycache__/app.cpython-312.pyc
ADDED
Binary file (16.2 kB). View file
|
|
app/api/endpoints/v1/firebaseauth/app.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, Depends, HTTPException, status, Request, APIRouter
|
2 |
+
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
3 |
+
import firebase_admin
|
4 |
+
from firebase_admin import credentials, auth
|
5 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
6 |
+
from sqlalchemy import select
|
7 |
+
from app.models.api.user import UserCreate, UserSignIn, PasswordReset, TokenVerify, UserResponse
|
8 |
+
from app.models.database.DBUser import DBUser
|
9 |
+
import datetime
|
10 |
+
import os
|
11 |
+
from app.core.database.session_manager import get_db_session as get_db
|
12 |
+
from pydantic import BaseModel, EmailStr
|
13 |
+
|
14 |
+
router = APIRouter(prefix="/FirebaseAuth", tags=["FirebaseAuth related APIs"])
|
15 |
+
|
16 |
+
# Initialize Firebase Admin SDK with better error handling
|
17 |
+
try:
|
18 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
19 |
+
# Try multiple possible paths for the service account file
|
20 |
+
service_account_paths = [
|
21 |
+
"/opt/MailPilot/MailPilot_ai_agents/app/serviceAccountKey/mailpoilt-firebase-adminsdk-fbsvc-26bb455f79.json",
|
22 |
+
os.path.join(current_dir, "../serviceAccountKey/mailpoilt-firebase-adminsdk-fbsvc-26bb455f79.json"),
|
23 |
+
os.path.join(current_dir, "../../serviceAccountKey/mailpoilt-firebase-adminsdk-fbsvc-26bb455f79.json")
|
24 |
+
]
|
25 |
+
|
26 |
+
cred = None
|
27 |
+
for path in service_account_paths:
|
28 |
+
if os.path.exists(path):
|
29 |
+
cred = credentials.Certificate(path)
|
30 |
+
break
|
31 |
+
|
32 |
+
if cred is None:
|
33 |
+
raise FileNotFoundError("Firebase service account key not found")
|
34 |
+
|
35 |
+
if not firebase_admin._apps:
|
36 |
+
firebase_admin.initialize_app(cred)
|
37 |
+
|
38 |
+
except Exception as e:
|
39 |
+
print(f"Firebase initialization error: {str(e)}")
|
40 |
+
# Continue without crashing, but auth functions will fail
|
41 |
+
|
42 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/FirebaseAuth/signin")
|
43 |
+
async def get_current_user(token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(get_db)):
|
44 |
+
try:
|
45 |
+
decoded_token = auth.verify_id_token(token)
|
46 |
+
user_id = decoded_token["uid"]
|
47 |
+
|
48 |
+
# Get the Firebase user
|
49 |
+
firebase_user = auth.get_user(user_id)
|
50 |
+
|
51 |
+
result = await db.execute(select(DBUser).filter(DBUser.firebase_uid == user_id))
|
52 |
+
db_user = result.scalar_one_or_none()
|
53 |
+
|
54 |
+
if db_user is None:
|
55 |
+
raise HTTPException(status_code=404, detail="User not found in database")
|
56 |
+
|
57 |
+
return UserResponse(
|
58 |
+
firebase_uid=db_user.firebase_uid,
|
59 |
+
email=db_user.email,
|
60 |
+
display_name=db_user.display_name,
|
61 |
+
is_active=db_user.is_active,
|
62 |
+
created_at=db_user.created_at,
|
63 |
+
last_login=db_user.last_login,
|
64 |
+
provider=db_user.provider,
|
65 |
+
email_verified=firebase_user.email_verified
|
66 |
+
)
|
67 |
+
except Exception as e:
|
68 |
+
raise HTTPException(
|
69 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
70 |
+
detail=f"Invalid authentication credentials: {str(e)}",
|
71 |
+
headers={"WWW-Authenticate": "Bearer"},
|
72 |
+
)
|
73 |
+
|
74 |
+
@router.post("/signup", response_model=dict)
|
75 |
+
async def create_user(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
|
76 |
+
"""Create a new user with email and password and store in database"""
|
77 |
+
try:
|
78 |
+
# Check if user already exists
|
79 |
+
try:
|
80 |
+
existing_user = auth.get_user_by_email(user_data.email)
|
81 |
+
raise HTTPException(
|
82 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
83 |
+
detail=f"User with email {user_data.email} already exists"
|
84 |
+
)
|
85 |
+
except auth.UserNotFoundError:
|
86 |
+
# This is what we want - user doesn't exist yet
|
87 |
+
pass
|
88 |
+
|
89 |
+
# Create Firebase user
|
90 |
+
firebase_user = auth.create_user(
|
91 |
+
email=user_data.email,
|
92 |
+
password=user_data.password,
|
93 |
+
display_name=user_data.display_name,
|
94 |
+
email_verified=False # Explicitly set to false
|
95 |
+
)
|
96 |
+
|
97 |
+
# Generate email verification link
|
98 |
+
action_code_settings = auth.ActionCodeSettings(
|
99 |
+
url=f"https://mailpoilt.web.app/verify-email?email={user_data.email}",
|
100 |
+
handle_code_in_app=True
|
101 |
+
)
|
102 |
+
verification_link = auth.generate_email_verification_link(
|
103 |
+
user_data.email,
|
104 |
+
action_code_settings
|
105 |
+
)
|
106 |
+
|
107 |
+
# Firebase will handle sending the verification email automatically
|
108 |
+
|
109 |
+
current_time = datetime.datetime.utcnow()
|
110 |
+
|
111 |
+
db_user = DBUser(
|
112 |
+
firebase_uid=firebase_user.uid,
|
113 |
+
email=user_data.email,
|
114 |
+
display_name=user_data.display_name,
|
115 |
+
is_active=True,
|
116 |
+
created_at=current_time,
|
117 |
+
last_login=current_time,
|
118 |
+
provider="email"
|
119 |
+
)
|
120 |
+
|
121 |
+
db.add(db_user)
|
122 |
+
await db.commit()
|
123 |
+
await db.refresh(db_user)
|
124 |
+
|
125 |
+
return {
|
126 |
+
"message": "User created successfully. Please check your email to verify your account.",
|
127 |
+
"verification_link": verification_link, # In production, you might not return this
|
128 |
+
"user": {
|
129 |
+
"firebase_uid": db_user.firebase_uid,
|
130 |
+
"email": db_user.email,
|
131 |
+
"display_name": db_user.display_name,
|
132 |
+
"is_active": db_user.is_active,
|
133 |
+
"created_at": db_user.created_at.isoformat() if db_user.created_at else None,
|
134 |
+
"last_login": db_user.last_login.isoformat() if db_user.last_login else None,
|
135 |
+
"provider": db_user.provider,
|
136 |
+
"email_verified": firebase_user.email_verified
|
137 |
+
}
|
138 |
+
}
|
139 |
+
except Exception as e:
|
140 |
+
await db.rollback()
|
141 |
+
try:
|
142 |
+
if 'firebase_user' in locals():
|
143 |
+
auth.delete_user(firebase_user.uid)
|
144 |
+
except:
|
145 |
+
pass
|
146 |
+
raise HTTPException(
|
147 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
148 |
+
detail=f"Error creating user: {str(e)}"
|
149 |
+
)
|
150 |
+
|
151 |
+
@router.post("/signin", response_model=dict)
|
152 |
+
async def signin_user(user_data: UserSignIn, db: AsyncSession = Depends(get_db)):
|
153 |
+
"""Sign in a user with email and password"""
|
154 |
+
try:
|
155 |
+
try:
|
156 |
+
firebase_user = auth.get_user_by_email(user_data.email)
|
157 |
+
except auth.UserNotFoundError:
|
158 |
+
raise HTTPException(
|
159 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
160 |
+
detail=f"No user found with email: {user_data.email}"
|
161 |
+
)
|
162 |
+
|
163 |
+
# Generate a custom token that can be exchanged for an ID token
|
164 |
+
custom_token = auth.create_custom_token(firebase_user.uid)
|
165 |
+
|
166 |
+
# Update last login time
|
167 |
+
result = await db.execute(select(DBUser).filter(DBUser.firebase_uid == firebase_user.uid))
|
168 |
+
db_user = result.scalar_one_or_none()
|
169 |
+
|
170 |
+
if not db_user:
|
171 |
+
# Create db user if not exists
|
172 |
+
db_user = DBUser(
|
173 |
+
firebase_uid=firebase_user.uid,
|
174 |
+
email=firebase_user.email,
|
175 |
+
display_name=firebase_user.display_name or user_data.email.split('@')[0],
|
176 |
+
is_active=True,
|
177 |
+
created_at=datetime.datetime.utcnow(),
|
178 |
+
last_login=datetime.datetime.utcnow(),
|
179 |
+
provider="email"
|
180 |
+
)
|
181 |
+
db.add(db_user)
|
182 |
+
else:
|
183 |
+
db_user.last_login = datetime.datetime.utcnow()
|
184 |
+
|
185 |
+
await db.commit()
|
186 |
+
await db.refresh(db_user)
|
187 |
+
|
188 |
+
user_info = {
|
189 |
+
"firebase_uid": db_user.firebase_uid,
|
190 |
+
"email": db_user.email,
|
191 |
+
"display_name": db_user.display_name,
|
192 |
+
"is_active": db_user.is_active,
|
193 |
+
"created_at": db_user.created_at.isoformat() if db_user.created_at else None,
|
194 |
+
"last_login": db_user.last_login.isoformat() if db_user.last_login else None,
|
195 |
+
"provider": db_user.provider,
|
196 |
+
"email_verified": firebase_user.email_verified,
|
197 |
+
"custom_token": custom_token.decode("utf-8") if isinstance(custom_token, bytes) else custom_token
|
198 |
+
}
|
199 |
+
|
200 |
+
return {
|
201 |
+
"message": "Login successful",
|
202 |
+
"user": user_info,
|
203 |
+
"custom_token": custom_token.decode("utf-8") if isinstance(custom_token, bytes) else custom_token,
|
204 |
+
"email_verified": firebase_user.email_verified
|
205 |
+
}
|
206 |
+
except Exception as e:
|
207 |
+
if isinstance(e, HTTPException):
|
208 |
+
raise e
|
209 |
+
await db.rollback()
|
210 |
+
raise HTTPException(
|
211 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
212 |
+
detail=f"Authentication failed: {str(e)}"
|
213 |
+
)
|
214 |
+
class EmailVerifyRequest(BaseModel):
|
215 |
+
email: EmailStr
|
216 |
+
|
217 |
+
@router.post("/resend-verification", status_code=status.HTTP_200_OK)
|
218 |
+
async def resend_verification_email(
|
219 |
+
email_data: EmailVerifyRequest = None,
|
220 |
+
current_user: UserResponse = Depends(get_current_user)
|
221 |
+
):
|
222 |
+
"""
|
223 |
+
Resend verification email to a user
|
224 |
+
|
225 |
+
If user is logged in, uses their email.
|
226 |
+
Otherwise, uses the email provided in the request body.
|
227 |
+
"""
|
228 |
+
try:
|
229 |
+
# If email is provided in request body, use that
|
230 |
+
# Otherwise use logged in user's email
|
231 |
+
email = email_data.email if email_data else current_user.email
|
232 |
+
|
233 |
+
# Check if user exists
|
234 |
+
try:
|
235 |
+
firebase_user = auth.get_user_by_email(email)
|
236 |
+
except auth.UserNotFoundError:
|
237 |
+
raise HTTPException(
|
238 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
239 |
+
detail=f"No user found with email: {email}"
|
240 |
+
)
|
241 |
+
|
242 |
+
# Check if email is already verified
|
243 |
+
if firebase_user.email_verified:
|
244 |
+
return {"message": "Email is already verified"}
|
245 |
+
|
246 |
+
# Generate a new verification link
|
247 |
+
action_code_settings = auth.ActionCodeSettings(
|
248 |
+
url=f"https://mailpoilt.web.app/verify-email?email={email}",
|
249 |
+
handle_code_in_app=True
|
250 |
+
)
|
251 |
+
verification_link = auth.generate_email_verification_link(
|
252 |
+
email,
|
253 |
+
action_code_settings
|
254 |
+
)
|
255 |
+
|
256 |
+
return {
|
257 |
+
"message": "Verification email sent successfully",
|
258 |
+
"verification_link": verification_link
|
259 |
+
}
|
260 |
+
except Exception as e:
|
261 |
+
if isinstance(e, HTTPException):
|
262 |
+
raise e
|
263 |
+
raise HTTPException(
|
264 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
265 |
+
detail=f"Failed to resend verification email: {str(e)}"
|
266 |
+
)
|
267 |
+
|
268 |
+
|
269 |
+
email: EmailStr
|
270 |
+
|
271 |
+
@router.post("/check-email-verified")
|
272 |
+
async def check_email_verified(email_data: EmailVerifyRequest):
|
273 |
+
"""Check if a user's email is verified"""
|
274 |
+
try:
|
275 |
+
# Check if user exists
|
276 |
+
try:
|
277 |
+
firebase_user = auth.get_user_by_email(email_data.email)
|
278 |
+
except auth.UserNotFoundError:
|
279 |
+
raise HTTPException(
|
280 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
281 |
+
detail=f"No user found with email: {email_data.email}"
|
282 |
+
)
|
283 |
+
|
284 |
+
return {
|
285 |
+
"email": email_data.email,
|
286 |
+
"email_verified": firebase_user.email_verified
|
287 |
+
}
|
288 |
+
except Exception as e:
|
289 |
+
if isinstance(e, HTTPException):
|
290 |
+
raise e
|
291 |
+
raise HTTPException(
|
292 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
293 |
+
detail=f"Failed to check email verification status: {str(e)}"
|
294 |
+
)
|
295 |
+
|
296 |
+
@router.post("/verify-token", response_model=UserResponse)
|
297 |
+
async def verify_token(token_data: TokenVerify, db: AsyncSession = Depends(get_db)):
|
298 |
+
"""Verify a Firebase ID token or UID and return user data"""
|
299 |
+
try:
|
300 |
+
# First try to verify as an ID token
|
301 |
+
try:
|
302 |
+
decoded_token = auth.verify_id_token(token_data.token)
|
303 |
+
user_id = decoded_token["uid"]
|
304 |
+
except:
|
305 |
+
# If that fails, treat it as a UID
|
306 |
+
user_id = token_data.token
|
307 |
+
|
308 |
+
try:
|
309 |
+
firebase_user = auth.get_user(user_id)
|
310 |
+
except auth.UserNotFoundError:
|
311 |
+
raise HTTPException(
|
312 |
+
status_code=status.HTTP_404_NOT_FOUND,
|
313 |
+
detail="User not found"
|
314 |
+
)
|
315 |
+
|
316 |
+
result = await db.execute(select(DBUser).filter(DBUser.firebase_uid == user_id))
|
317 |
+
db_user = result.scalar_one_or_none()
|
318 |
+
|
319 |
+
if not db_user:
|
320 |
+
# Create DB user if it doesn't exist
|
321 |
+
db_user = DBUser(
|
322 |
+
firebase_uid=user_id,
|
323 |
+
email=firebase_user.email,
|
324 |
+
display_name=firebase_user.display_name or firebase_user.email.split('@')[0],
|
325 |
+
is_active=True,
|
326 |
+
created_at=datetime.datetime.utcnow(),
|
327 |
+
last_login=datetime.datetime.utcnow(),
|
328 |
+
provider="firebase"
|
329 |
+
)
|
330 |
+
db.add(db_user)
|
331 |
+
await db.commit()
|
332 |
+
await db.refresh(db_user)
|
333 |
+
else:
|
334 |
+
# Update last_login time
|
335 |
+
db_user.last_login = datetime.datetime.utcnow()
|
336 |
+
await db.commit()
|
337 |
+
await db.refresh(db_user)
|
338 |
+
|
339 |
+
return UserResponse(
|
340 |
+
firebase_uid=db_user.firebase_uid,
|
341 |
+
email=db_user.email,
|
342 |
+
display_name=db_user.display_name,
|
343 |
+
is_active=db_user.is_active,
|
344 |
+
created_at=db_user.created_at,
|
345 |
+
last_login=db_user.last_login,
|
346 |
+
provider=db_user.provider,
|
347 |
+
email_verified=firebase_user.email_verified
|
348 |
+
)
|
349 |
+
except Exception as e:
|
350 |
+
if isinstance(e, HTTPException):
|
351 |
+
raise e
|
352 |
+
raise HTTPException(
|
353 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
354 |
+
detail=f"Token verification failed: {str(e)}"
|
355 |
+
)
|
356 |
+
@router.post("/token")
|
357 |
+
async def get_token(form_data: OAuth2PasswordRequestForm = Depends(), db: AsyncSession = Depends(get_db)):
|
358 |
+
return await signin_user(
|
359 |
+
UserSignIn(email=form_data.username, password=form_data.password),
|
360 |
+
db
|
361 |
+
)
|
app/api/endpoints/v1/login/__init__.py
ADDED
File without changes
|
app/api/endpoints/v1/login/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (170 Bytes). View file
|
|
app/api/endpoints/v1/login/__pycache__/api.cpython-312.pyc
ADDED
Binary file (729 Bytes). View file
|
|
app/api/endpoints/v1/login/api.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import APIRouter, Depends , HTTPException
|
2 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
3 |
+
from app.core.database.session_manager import get_db_session as db_session
|
4 |
+
|
5 |
+
router = APIRouter(prefix="/login", tags=["login related APIs"])
|
6 |
+
|
7 |
+
|
8 |
+
@router.post("/login")
|
9 |
+
async def home():
|
10 |
+
return {"message": "Welcome to the Simple Router!"}
|
app/core/__init__.py
ADDED
File without changes
|
app/core/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (152 Bytes). View file
|
|
app/core/__pycache__/config.cpython-312.pyc
ADDED
Binary file (5.03 kB). View file
|
|
app/core/__pycache__/logger.cpython-312.pyc
ADDED
Binary file (8.36 kB). View file
|
|
app/core/cache/__init__.py
ADDED
File without changes
|
app/core/cache/cache.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
from logging import Logger
|
3 |
+
|
4 |
+
from pymemcache.client.base import Client
|
5 |
+
|
6 |
+
from app.core.config import settings
|
7 |
+
from app.core.exceptions.base_exception import (
|
8 |
+
ConnectionException,
|
9 |
+
CouldNotEditMemcache,
|
10 |
+
KeyNotFoundException,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
class Cache:
|
15 |
+
"""
|
16 |
+
A generic cache class for interacting with Memcached.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, logger: Logger) -> None:
|
20 |
+
"""
|
21 |
+
Initialize the cache connection.
|
22 |
+
|
23 |
+
:param logger: Logger instance for logging operations
|
24 |
+
"""
|
25 |
+
|
26 |
+
# Load Memcache config from .env
|
27 |
+
self.host = settings.CACHE_HOST
|
28 |
+
self.port = settings.CACHE_HOST
|
29 |
+
self.default_ttl = settings.CACHE_TTL
|
30 |
+
|
31 |
+
self.client = self._initialize_connection()
|
32 |
+
self.logger = logger
|
33 |
+
|
34 |
+
def _initialize_connection(self):
|
35 |
+
"""
|
36 |
+
Establish a connection to the Memcached server.
|
37 |
+
|
38 |
+
:return: Client instance
|
39 |
+
:raises ConnectionException: If the connection cannot be established
|
40 |
+
"""
|
41 |
+
client = Client((self.host, self.port))
|
42 |
+
if client:
|
43 |
+
self.logger.info(f"Connected to Memcached at {self.host}: {self.port}")
|
44 |
+
return client
|
45 |
+
else:
|
46 |
+
raise ConnectionException("Could not connect to Memcached server.")
|
47 |
+
|
48 |
+
def add(self, key: str, value: dict):
|
49 |
+
"""
|
50 |
+
Add an item to the cache.
|
51 |
+
|
52 |
+
:param key: Cache key
|
53 |
+
:param value: Value to store (serialized using pickle)
|
54 |
+
:raises CouldNotEditMemcache: If the item could not be added
|
55 |
+
"""
|
56 |
+
serialized_value = pickle.dumps(value)
|
57 |
+
res = self.client.add(key, serialized_value, expire=self.default_ttl)
|
58 |
+
if not res:
|
59 |
+
raise CouldNotEditMemcache(f"Could not add key {key} to cache.")
|
60 |
+
self.logger.info(f"Added {key} to cache.")
|
61 |
+
|
62 |
+
def get(self, key: str):
|
63 |
+
"""
|
64 |
+
Retrieve an item from the cache.
|
65 |
+
|
66 |
+
:param key: Cache key
|
67 |
+
:return: Deserialized value
|
68 |
+
:raises KeyNotFoundException: If the key is not found in the cache
|
69 |
+
"""
|
70 |
+
byte_string = self.get_raw(key)
|
71 |
+
return pickle.loads(byte_string)
|
72 |
+
|
73 |
+
def get_raw(self, key: str):
|
74 |
+
"""
|
75 |
+
Retrieve the raw byte string from the cache.
|
76 |
+
|
77 |
+
:param key: Cache key
|
78 |
+
:return: Raw byte string
|
79 |
+
:raises KeyNotFoundException: If the key is not found in the cache
|
80 |
+
"""
|
81 |
+
byte_string = self.client.get(key)
|
82 |
+
if not byte_string:
|
83 |
+
raise KeyNotFoundException(f"Key {key} not found in cache.") # noqa: E713
|
84 |
+
return byte_string
|
85 |
+
|
86 |
+
def delete(self, key: str):
|
87 |
+
"""
|
88 |
+
Delete an item from the cache.
|
89 |
+
|
90 |
+
:param key: Cache key
|
91 |
+
:return: Result of the delete operation
|
92 |
+
:raises CouldNotEditMemcache: If the item could not be deleted
|
93 |
+
"""
|
94 |
+
res = self.client.delete(key)
|
95 |
+
if not res:
|
96 |
+
raise CouldNotEditMemcache(f"Could not delete key {key} from cache.")
|
97 |
+
self.logger.info(f"Deleted {key} from cache.")
|
98 |
+
return res
|
99 |
+
|
100 |
+
def update(self, key: str, value: dict):
|
101 |
+
"""
|
102 |
+
Update an item in the cache.
|
103 |
+
|
104 |
+
:param key: Cache key
|
105 |
+
:param value: New value to store (serialized using pickle)
|
106 |
+
:raises CouldNotEditMemcache: If the item could not be updated
|
107 |
+
"""
|
108 |
+
serialized_value = pickle.dumps(value)
|
109 |
+
res = self.client.set(key, serialized_value, expire=self.default_ttl)
|
110 |
+
if not res:
|
111 |
+
raise CouldNotEditMemcache(f"Could not update key {key} in cache.")
|
112 |
+
self.logger.info(f"Updated {key} in cache.")
|
app/core/config.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
File with environment variables and general configuration logic.
|
3 |
+
Environment variables are loaded from `.env`, with default values as fallback.
|
4 |
+
|
5 |
+
For project metadata, pyproject.toml is used.
|
6 |
+
Complex types like lists are read as JSON-encoded strings.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import tomllib
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Literal
|
12 |
+
from urllib.parse import quote_plus
|
13 |
+
|
14 |
+
from environs import Env
|
15 |
+
from pydantic import validator
|
16 |
+
from pydantic_settings import BaseSettings
|
17 |
+
from structlog.stdlib import BoundLogger
|
18 |
+
|
19 |
+
|
20 |
+
from app.core.logger import Logger
|
21 |
+
|
22 |
+
PROJECT_DIR = Path(__file__).parent.parent.parent
|
23 |
+
with open(f"{PROJECT_DIR}/pyproject.toml", "rb") as f:
|
24 |
+
PYPROJECT_CONTENT = tomllib.load(f)["tool"]["poetry"]
|
25 |
+
|
26 |
+
env = Env()
|
27 |
+
env.read_env()
|
28 |
+
|
29 |
+
CORS_ALLOWED_HEADERS = list(map(str.strip, env.list("CORS_ALLOWED_HEADERS", ["*"])))
|
30 |
+
CORS_ORIGINS = list(map(str.strip, env.list("CORS_ORIGINS", ["http://localhost:3000"])))
|
31 |
+
|
32 |
+
|
33 |
+
class Settings(BaseSettings):
|
34 |
+
# CORE SETTINGS
|
35 |
+
ENVIRONMENT: Literal["DEV", "STG", "PROD"] = env.str("ENVIRONMENT", "DEV").upper()
|
36 |
+
|
37 |
+
# CORS SETTINGS
|
38 |
+
# BACKEND_CORS_ORIGINS: list[str] = env.list("BACKEND_CORS_ORIGINS", ["http://localhost:3000"])
|
39 |
+
# BACKEND_CORS_HEADERS: list[str] = env.list("BACKEND_CORS_HEADERS", ["*"])
|
40 |
+
# ALLOWED_HOSTS: list[str] = env.list("ALLOWED_HOSTS", ["*"])
|
41 |
+
|
42 |
+
# LOG SETTINGS
|
43 |
+
LOG_LEVEL: Literal["INFO", "DEBUG", "WARN", "ERROR"] = env.str("LOG_LEVEL", "INFO")
|
44 |
+
LOG_JSON_FORMAT: bool = env.bool("LOG_JSON_FORMAT", False)
|
45 |
+
|
46 |
+
# PROJECT NAME, VERSION AND DESCRIPTION
|
47 |
+
PROJECT_NAME: str = PYPROJECT_CONTENT["name"]
|
48 |
+
VERSION: str = PYPROJECT_CONTENT["version"]
|
49 |
+
DESCRIPTION: str = PYPROJECT_CONTENT["description"]
|
50 |
+
|
51 |
+
ROOT_PATH: str = env.str("ROOT_PATH", "")
|
52 |
+
|
53 |
+
# DOCS SETTINGS
|
54 |
+
DOCS_URL: str = f"{ROOT_PATH}/docs"
|
55 |
+
OPENAPI_URL: str = f"{ROOT_PATH}/openapi.json"
|
56 |
+
|
57 |
+
# POSTGRESQL DATABASE SETTINGS
|
58 |
+
DATABASE_HOSTNAME: str = env.str("DATABASE_HOSTNAME")
|
59 |
+
DATABASE_USER: str = env.str("DATABASE_USER")
|
60 |
+
DATABASE_PASSWORD: str = env.str("DATABASE_PASSWORD")
|
61 |
+
DATABASE_PORT: str = env.str("DATABASE_PORT", "5432")
|
62 |
+
DATABASE_DB: str = env.str("DATABASE_DB")
|
63 |
+
SQLALCHEMY_DATABASE_URI: str = ""
|
64 |
+
|
65 |
+
@validator("SQLALCHEMY_DATABASE_URI")
|
66 |
+
def _assemble_db_connection(cls, v: str, values: dict[str, str]) -> str:
|
67 |
+
return "postgresql+asyncpg://{}:{}@{}:{}/{}".format(
|
68 |
+
values["DATABASE_USER"],
|
69 |
+
quote_plus(values["DATABASE_PASSWORD"]),
|
70 |
+
values["DATABASE_HOSTNAME"],
|
71 |
+
values["DATABASE_PORT"],
|
72 |
+
values["DATABASE_DB"],
|
73 |
+
)
|
74 |
+
|
75 |
+
# UVICORN SETTINGS
|
76 |
+
UVICORN_HOST: str = env.str("UVICORN_HOST", "0.0.0.0")
|
77 |
+
UVICORN_PORT: int = env.int("UVICORN_PORT", 5001)
|
78 |
+
|
79 |
+
CACHE_HOST: str = env.str("CACHE_HOST", "localhost")
|
80 |
+
CACHE_PORT: int = env.int("CACHE_PORT", 11211)
|
81 |
+
CACHE_TTL: int = env.int("CACHE_TTL", 300)
|
82 |
+
|
83 |
+
BEDROCK_MODEL_ID: str = env.str("BEDROCK_MODEL_ID", "anthropic.claude-v2")
|
84 |
+
BEDROCK_PROVIDER: str = env.str("BEDROCK_PROVIDER", "anthropic")
|
85 |
+
AWS_ACCESS_KEY: str = env.str("AWS_ACCESS_KEY", "")
|
86 |
+
AWS_SECRET_KEY: str = env.str("AWS_SECRET_KEY", "")
|
87 |
+
AWS_REGION: str = env.str("AWS_REGION", "us-east-1")
|
88 |
+
|
89 |
+
TOKENIZER_MODEL: str = env.str("TOKENIZER_MODEL")
|
90 |
+
TOKEN_LIMIT_PER_REQUEST: int = env.int("TOKEN_LIMIT_PER_REQUEST", 20000)
|
91 |
+
|
92 |
+
|
93 |
+
settings: Settings = Settings() # type: ignore
|
94 |
+
|
95 |
+
log: BoundLogger = Logger(
|
96 |
+
json_logs=settings.LOG_JSON_FORMAT, log_level=settings.LOG_LEVEL
|
97 |
+
).setup_logging()
|
app/core/database/__init__.py
ADDED
File without changes
|
app/core/database/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (161 Bytes). View file
|
|
app/core/database/__pycache__/session_manager.cpython-312.pyc
ADDED
Binary file (3.48 kB). View file
|
|
app/core/database/session_manager.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
from collections.abc import AsyncIterator
|
3 |
+
|
4 |
+
from sqlalchemy.ext.asyncio import (
|
5 |
+
AsyncConnection,
|
6 |
+
AsyncSession,
|
7 |
+
async_sessionmaker,
|
8 |
+
create_async_engine,
|
9 |
+
)
|
10 |
+
from sqlalchemy.orm import declarative_base
|
11 |
+
|
12 |
+
from app.core.config import settings
|
13 |
+
|
14 |
+
Base = declarative_base()
|
15 |
+
|
16 |
+
# Inspiration https://praciano.com.br/fastapi-and-async-sqlalchemy-20-with-pytest-done-right.html
|
17 |
+
|
18 |
+
|
19 |
+
class DatabaseSessionManager:
|
20 |
+
def __init__(self):
|
21 |
+
self._engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URI)
|
22 |
+
self._sessionmaker = async_sessionmaker(autocommit=False, bind=self._engine)
|
23 |
+
|
24 |
+
async def close(self):
|
25 |
+
if self._engine is None:
|
26 |
+
raise Exception("DatabaseSessionManager is not initialized")
|
27 |
+
await self._engine.dispose()
|
28 |
+
|
29 |
+
self._engine = None
|
30 |
+
self._sessionmaker = None
|
31 |
+
|
32 |
+
@contextlib.asynccontextmanager
|
33 |
+
async def connect(self) -> AsyncIterator[AsyncConnection]:
|
34 |
+
if self._engine is None:
|
35 |
+
raise Exception("DatabaseSessionManager is not initialized")
|
36 |
+
|
37 |
+
async with self._engine.begin() as connection:
|
38 |
+
try:
|
39 |
+
yield connection
|
40 |
+
except Exception:
|
41 |
+
await connection.rollback()
|
42 |
+
raise
|
43 |
+
|
44 |
+
@contextlib.asynccontextmanager
|
45 |
+
async def session(self) -> AsyncIterator[AsyncSession]:
|
46 |
+
if self._sessionmaker is None:
|
47 |
+
raise Exception("DatabaseSessionManager is not initialized")
|
48 |
+
|
49 |
+
session = self._sessionmaker()
|
50 |
+
try:
|
51 |
+
yield session
|
52 |
+
except Exception:
|
53 |
+
await session.rollback()
|
54 |
+
raise
|
55 |
+
finally:
|
56 |
+
await session.close()
|
57 |
+
|
58 |
+
|
59 |
+
sessionmanager = DatabaseSessionManager()
|
60 |
+
|
61 |
+
|
62 |
+
async def get_db_session():
|
63 |
+
async with sessionmanager.session() as session:
|
64 |
+
yield session
|
app/core/logger.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from logging.handlers import TimedRotatingFileHandler
|
5 |
+
|
6 |
+
import structlog
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
from structlog.processors import CallsiteParameter
|
9 |
+
from structlog.stdlib import BoundLogger
|
10 |
+
from structlog.typing import EventDict, Processor
|
11 |
+
|
12 |
+
# Load environment variables
|
13 |
+
load_dotenv()
|
14 |
+
|
15 |
+
|
16 |
+
class Logger:
|
17 |
+
"""
|
18 |
+
Configure and setup logging with Structlog.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
json_logs (bool, optional): Whether to log in JSON format. Defaults to False.
|
22 |
+
log_level (str, optional): Minimum log level to display. Defaults to "INFO".
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, json_logs: bool = False, log_level: str = "INFO"):
|
26 |
+
self.json_logs = json_logs
|
27 |
+
self.log_level = log_level.upper()
|
28 |
+
|
29 |
+
self.environment = os.getenv("ENVIRONMENT", "PROD").upper() # Default to PROD
|
30 |
+
self.log_file_path = os.getenv(
|
31 |
+
"LOG_FILE_PATH", self._get_default_log_file_path()
|
32 |
+
)
|
33 |
+
|
34 |
+
def _get_default_log_file_path(self) -> str | None:
|
35 |
+
"""
|
36 |
+
Provides a default log file path outside the project folder.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
str: The default log file path.
|
40 |
+
"""
|
41 |
+
return
|
42 |
+
# default_log_dir = os.path.expanduser("./logs")
|
43 |
+
# if not os.path.exists(default_log_dir):
|
44 |
+
# os.makedirs(default_log_dir)
|
45 |
+
# return os.path.join(default_log_dir, "app.log")
|
46 |
+
|
47 |
+
def _rename_event_key(self, _, __, event_dict: EventDict) -> EventDict:
|
48 |
+
"""
|
49 |
+
Renames the 'event' key to 'message' in log entries.
|
50 |
+
"""
|
51 |
+
event_dict["message"] = event_dict.pop("event", "")
|
52 |
+
return event_dict
|
53 |
+
|
54 |
+
def _drop_color_message_key(self, _, __, event_dict: EventDict) -> EventDict:
|
55 |
+
"""
|
56 |
+
Removes the 'color_message' key from log entries.
|
57 |
+
"""
|
58 |
+
event_dict.pop("color_message", None)
|
59 |
+
return event_dict
|
60 |
+
|
61 |
+
def _get_processors(self) -> list[Processor]:
|
62 |
+
"""
|
63 |
+
Returns a list of structlog processors based on the specified configuration.
|
64 |
+
"""
|
65 |
+
processors: list[Processor] = [
|
66 |
+
structlog.contextvars.merge_contextvars,
|
67 |
+
structlog.stdlib.add_logger_name,
|
68 |
+
structlog.stdlib.add_log_level,
|
69 |
+
structlog.stdlib.PositionalArgumentsFormatter(),
|
70 |
+
structlog.stdlib.ExtraAdder(),
|
71 |
+
self._drop_color_message_key,
|
72 |
+
structlog.processors.TimeStamper(fmt="iso"),
|
73 |
+
structlog.processors.StackInfoRenderer(),
|
74 |
+
structlog.processors.CallsiteParameterAdder(
|
75 |
+
[
|
76 |
+
CallsiteParameter.FILENAME,
|
77 |
+
CallsiteParameter.FUNC_NAME,
|
78 |
+
CallsiteParameter.LINENO,
|
79 |
+
],
|
80 |
+
),
|
81 |
+
]
|
82 |
+
|
83 |
+
if self.json_logs:
|
84 |
+
processors.append(self._rename_event_key)
|
85 |
+
processors.append(structlog.processors.format_exc_info)
|
86 |
+
|
87 |
+
return processors
|
88 |
+
|
89 |
+
def _clear_uvicorn_loggers(self):
|
90 |
+
"""
|
91 |
+
Clears the log handlers for uvicorn loggers.
|
92 |
+
"""
|
93 |
+
for _log in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
|
94 |
+
logging.getLogger(_log).handlers.clear()
|
95 |
+
logging.getLogger(_log).propagate = True
|
96 |
+
|
97 |
+
def _configure_structlog(self, processors: list[Processor]):
|
98 |
+
"""
|
99 |
+
Configures structlog with the specified processors.
|
100 |
+
"""
|
101 |
+
structlog.configure(
|
102 |
+
processors=processors
|
103 |
+
+ [
|
104 |
+
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
105 |
+
],
|
106 |
+
logger_factory=structlog.stdlib.LoggerFactory(),
|
107 |
+
cache_logger_on_first_use=True,
|
108 |
+
)
|
109 |
+
|
110 |
+
def _configure_logging(self, processors: list[Processor]) -> logging.Logger:
|
111 |
+
"""
|
112 |
+
Configures logging with the specified processors based on the environment.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
logging.Logger: The configured root logger.
|
116 |
+
"""
|
117 |
+
formatter = structlog.stdlib.ProcessorFormatter(
|
118 |
+
foreign_pre_chain=processors,
|
119 |
+
processors=[
|
120 |
+
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
|
121 |
+
structlog.processors.JSONRenderer()
|
122 |
+
if self.json_logs
|
123 |
+
else structlog.dev.ConsoleRenderer(colors=True),
|
124 |
+
],
|
125 |
+
)
|
126 |
+
|
127 |
+
root_logger = logging.getLogger()
|
128 |
+
root_logger.handlers.clear() # Clear existing handlers
|
129 |
+
|
130 |
+
if self.environment == "DEV":
|
131 |
+
# Console logging for development
|
132 |
+
stream_handler = logging.StreamHandler()
|
133 |
+
stream_handler.setFormatter(formatter)
|
134 |
+
root_logger.addHandler(stream_handler)
|
135 |
+
else:
|
136 |
+
# File logging for production
|
137 |
+
file_handler = TimedRotatingFileHandler(
|
138 |
+
filename=self.log_file_path,
|
139 |
+
when="midnight",
|
140 |
+
interval=1,
|
141 |
+
backupCount=7,
|
142 |
+
encoding="utf-8",
|
143 |
+
)
|
144 |
+
file_handler.setFormatter(formatter)
|
145 |
+
root_logger.addHandler(file_handler)
|
146 |
+
|
147 |
+
root_logger.setLevel(self.log_level.upper())
|
148 |
+
return root_logger
|
149 |
+
|
150 |
+
def _configure(self):
|
151 |
+
"""
|
152 |
+
Configures logging and structlog, and sets up exception handling.
|
153 |
+
"""
|
154 |
+
shared_processors: list[Processor] = self._get_processors()
|
155 |
+
self._configure_structlog(shared_processors)
|
156 |
+
root_logger = self._configure_logging(shared_processors)
|
157 |
+
self._clear_uvicorn_loggers()
|
158 |
+
|
159 |
+
def handle_exception(exc_type, exc_value, exc_traceback):
|
160 |
+
"""
|
161 |
+
Logs uncaught exceptions.
|
162 |
+
"""
|
163 |
+
if issubclass(exc_type, KeyboardInterrupt):
|
164 |
+
sys.__excepthook__(exc_type, exc_value, exc_traceback)
|
165 |
+
return
|
166 |
+
|
167 |
+
root_logger.error(
|
168 |
+
"Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)
|
169 |
+
)
|
170 |
+
|
171 |
+
sys.excepthook = handle_exception
|
172 |
+
|
173 |
+
def setup_logging(self) -> BoundLogger:
|
174 |
+
"""
|
175 |
+
Sets up logging configuration for the application.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
BoundLogger: The configured logger instance.
|
179 |
+
"""
|
180 |
+
self._configure()
|
181 |
+
return structlog.get_logger()
|
app/core/middlewares/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
4 |
+
|
5 |
+
from app.core.config import CORS_ALLOWED_HEADERS, CORS_ORIGINS
|
6 |
+
from app.core.middlewares.execution_middleware import measure_execution_time
|
7 |
+
|
8 |
+
|
9 |
+
def add_middlewares(app: FastAPI) -> None:
|
10 |
+
"""
|
11 |
+
Wrap FastAPI application, with various of middlewares
|
12 |
+
"""
|
13 |
+
app.add_middleware(
|
14 |
+
CORSMiddleware,
|
15 |
+
allow_origins=["*"], # For development only. In production, use specific origins
|
16 |
+
allow_credentials=True,
|
17 |
+
allow_methods=["*"],
|
18 |
+
allow_headers=["*"],
|
19 |
+
)
|
20 |
+
app.add_middleware(BaseHTTPMiddleware, dispatch=measure_execution_time)
|
app/core/middlewares/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.01 kB). View file
|
|
app/core/middlewares/__pycache__/execution_middleware.cpython-312.pyc
ADDED
Binary file (1.01 kB). View file
|
|
app/core/middlewares/execution_middleware.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import time
|
3 |
+
|
4 |
+
from fastapi import Request
|
5 |
+
|
6 |
+
from app.core.config import log
|
7 |
+
|
8 |
+
|
9 |
+
async def measure_execution_time(request: Request, call_next):
|
10 |
+
start_time = time.time()
|
11 |
+
response = await call_next(request)
|
12 |
+
process_time = time.time() - start_time
|
13 |
+
response.headers["X-Process-Time"] = f"{process_time:.2f} s" # noqa: E231
|
14 |
+
|
15 |
+
log_dict = {
|
16 |
+
"url": request.url.path,
|
17 |
+
"method": request.method,
|
18 |
+
"process_time": process_time,
|
19 |
+
}
|
20 |
+
log.info(log_dict, extra=log_dict)
|
21 |
+
|
22 |
+
return response
|
app/llm/llm_interface.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from langchain_core.messages import BaseMessage
|
3 |
+
|
4 |
+
|
5 |
+
class LLMInterface(ABC):
|
6 |
+
@abstractmethod
|
7 |
+
def query(self, messages: list[BaseMessage]) -> BaseMessage:
|
8 |
+
"""Query the LLM with a list of messages"""
|
9 |
+
pass
|
10 |
+
|
11 |
+
@abstractmethod
|
12 |
+
async def aquery(self, messages: list[BaseMessage]) -> BaseMessage:
|
13 |
+
"""Asynchronously query the LLM with a list of messages"""
|
14 |
+
pass
|
app/llm/provider/bedrock_provider.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import threading
|
2 |
+
from langchain_core.messages import BaseMessage
|
3 |
+
from langchain_aws import ChatBedrock
|
4 |
+
|
5 |
+
from app.llm.token.token_manager import TokenManager
|
6 |
+
from app.core.config import settings
|
7 |
+
from app.llm.llm_interface import LLMInterface
|
8 |
+
|
9 |
+
|
10 |
+
class BedrockProvider(LLMInterface):
|
11 |
+
_instance = None
|
12 |
+
_lock = threading.Lock()
|
13 |
+
token_manager = TokenManager(token_limit=50000, reset_interval=30)
|
14 |
+
|
15 |
+
def __new__(cls):
|
16 |
+
if cls._instance is None:
|
17 |
+
with cls._lock:
|
18 |
+
if cls._instance is None:
|
19 |
+
cls._instance = super().__new__(cls)
|
20 |
+
cls._instance._initialized = False
|
21 |
+
return cls._instance
|
22 |
+
|
23 |
+
def __init__(self):
|
24 |
+
if not self._initialized:
|
25 |
+
self.model_id = settings.BEDROCK_MODEL_ID
|
26 |
+
self.aws_access_key = settings.AWS_ACCESS_KEY
|
27 |
+
self.aws_secret_key = settings.AWS_SECRET_KEY
|
28 |
+
self.aws_region = settings.AWS_REGION
|
29 |
+
self.provider = settings.BEDROCK_PROVIDER
|
30 |
+
|
31 |
+
# Initialize BedrockChat
|
32 |
+
self.llm = ChatBedrock(
|
33 |
+
model_id=self.model_id,
|
34 |
+
region_name=self.aws_region,
|
35 |
+
aws_access_key_id=self.aws_access_key,
|
36 |
+
aws_secret_access_key=self.aws_secret_key,
|
37 |
+
provider=self.provider,
|
38 |
+
streaming=False,
|
39 |
+
model_kwargs={
|
40 |
+
"temperature": 0.7,
|
41 |
+
"max_tokens": 2000
|
42 |
+
}
|
43 |
+
)
|
44 |
+
|
45 |
+
self._initialized = True
|
46 |
+
|
47 |
+
def query(self, messages: list[BaseMessage]) -> BaseMessage:
|
48 |
+
"""Query AWS Bedrock with messages"""
|
49 |
+
response = self.llm.invoke(messages)
|
50 |
+
self._track_tokens(response)
|
51 |
+
return response
|
52 |
+
|
53 |
+
async def aquery(self, messages: list[BaseMessage]) -> BaseMessage:
|
54 |
+
"""Asynchronous query method"""
|
55 |
+
response = await self.llm.ainvoke(messages)
|
56 |
+
self._track_tokens(response)
|
57 |
+
return response
|
58 |
+
|
59 |
+
def _track_tokens(self, response: BaseMessage) -> None:
|
60 |
+
"""Helper to track token usage"""
|
61 |
+
token_usage = response.response_metadata.get("token_usage", {}) if hasattr(response, "response_metadata") else {}
|
62 |
+
total_tokens = token_usage.get("total_tokens", 0)
|
63 |
+
self.token_manager.track_tokens(total_tokens)
|
app/llm/token/token_manager.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from threading import Lock
|
3 |
+
|
4 |
+
|
5 |
+
class TokenManager:
|
6 |
+
def __init__(self, token_limit: int = 50000, reset_interval: int = 30):
|
7 |
+
self.token_limit = token_limit
|
8 |
+
self.reset_interval = reset_interval
|
9 |
+
self.token_count = 0
|
10 |
+
self.last_reset = time.time()
|
11 |
+
self.lock = Lock()
|
12 |
+
|
13 |
+
def track_tokens(self, tokens: int) -> None:
|
14 |
+
"""
|
15 |
+
Track token usage and reset if needed
|
16 |
+
"""
|
17 |
+
with self.lock:
|
18 |
+
current_time = time.time()
|
19 |
+
if current_time - self.last_reset > self.reset_interval:
|
20 |
+
self.token_count = 0
|
21 |
+
self.last_reset = current_time
|
22 |
+
|
23 |
+
self.token_count += tokens
|
24 |
+
if self.token_count > self.token_limit:
|
25 |
+
print(f"Warning: Token limit of {self.token_limit} exceeded!")
|
26 |
+
|
27 |
+
def get_token_usage(self) -> int:
|
28 |
+
"""
|
29 |
+
Get current token usage
|
30 |
+
"""
|
31 |
+
with self.lock:
|
32 |
+
current_time = time.time()
|
33 |
+
if current_time - self.last_reset > self.reset_interval:
|
34 |
+
self.token_count = 0
|
35 |
+
self.last_reset = current_time
|
36 |
+
|
37 |
+
return self.token_count
|
app/main.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Main FastAPI app instance declaration."""
|
2 |
+
import fastapi
|
3 |
+
import structlog
|
4 |
+
import uvicorn
|
5 |
+
|
6 |
+
from app.core.config import settings
|
7 |
+
from .core.middlewares import add_middlewares
|
8 |
+
from app.router import api_router
|
9 |
+
|
10 |
+
# Set up structlog for logging
|
11 |
+
logger = structlog.get_logger()
|
12 |
+
|
13 |
+
|
14 |
+
fastapi_app = fastapi.FastAPI(
|
15 |
+
title=settings.PROJECT_NAME,
|
16 |
+
version=settings.VERSION,
|
17 |
+
description=settings.DESCRIPTION,
|
18 |
+
openapi_url=settings.OPENAPI_URL,
|
19 |
+
docs_url=settings.DOCS_URL,
|
20 |
+
)
|
21 |
+
fastapi_app.include_router(api_router)
|
22 |
+
add_middlewares(fastapi_app)
|
23 |
+
# Log the app startup
|
24 |
+
logger.info(
|
25 |
+
"Application started", project=settings.PROJECT_NAME, version=settings.VERSION
|
26 |
+
)
|
27 |
+
|
28 |
+
if __name__ == "__main__":
|
29 |
+
uvicorn.run(
|
30 |
+
"main:fastapi_app",
|
31 |
+
host=settings.UVICORN_HOST,
|
32 |
+
port=settings.UVICORN_PORT,
|
33 |
+
reload=True,
|
34 |
+
)
|
app/migrations/__init__.py
ADDED
File without changes
|
app/migrations/__pycache__/env.cpython-312.pyc
ADDED
Binary file (3.41 kB). View file
|
|
app/migrations/alembic.ini
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# A generic, single database configuration.
|
2 |
+
|
3 |
+
[alembic]
|
4 |
+
# path to migration scripts
|
5 |
+
script_location = /opt/MailPilot/MailPilot_ai_agents/app/migrations
|
6 |
+
|
7 |
+
# template used to generate migration files
|
8 |
+
file_template = %%(year)d%%(month).2d%%(day).2d%%(minute).2d_%%(slug)s_%%(rev)s
|
9 |
+
|
10 |
+
# sys.path path, will be prepended to sys.path if present.
|
11 |
+
# defaults to the current working directory.
|
12 |
+
prepend_sys_path = .
|
13 |
+
|
14 |
+
# timezone to use when rendering the date within the migration file
|
15 |
+
# as well as the filename.
|
16 |
+
# If specified, requires the python-dateutil library that can be
|
17 |
+
# installed by adding `alembic[tz]` to the pip requirements
|
18 |
+
# string value is passed to dateutil.tz.gettz()
|
19 |
+
# leave blank for localtime
|
20 |
+
# timezone =
|
21 |
+
|
22 |
+
# max length of characters to apply to the
|
23 |
+
# "slug" field
|
24 |
+
truncate_slug_length = 40
|
25 |
+
|
26 |
+
# set to 'true' to run the environment during
|
27 |
+
# the 'revision' command, regardless of autogenerate
|
28 |
+
# revision_environment = false
|
29 |
+
|
30 |
+
# set to 'true' to allow .pyc and .pyo files without
|
31 |
+
# a source .py file to be detected as revisions in the
|
32 |
+
# versions/ directory
|
33 |
+
# sourceless = false
|
34 |
+
|
35 |
+
# version location specification; This defaults
|
36 |
+
# to migrations/versions. When using multiple version
|
37 |
+
# directories, initial revisions must be specified with --version-path.
|
38 |
+
# The path separator used here should be the separator specified by "version_path_separator"
|
39 |
+
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
|
40 |
+
|
41 |
+
# version path separator; As mentioned above, this is the character used to split
|
42 |
+
# version_locations. Valid values are:
|
43 |
+
#
|
44 |
+
# version_path_separator = :
|
45 |
+
# version_path_separator = ;
|
46 |
+
# version_path_separator = space
|
47 |
+
version_path_separator = os # default: use os.pathsep
|
48 |
+
|
49 |
+
# the output encoding used when revision files
|
50 |
+
# are written from script.py.mako
|
51 |
+
# output_encoding = utf-8
|
52 |
+
|
53 |
+
sqlalchemy.url = driver://user:pass@localhost/dbname
|
54 |
+
|
55 |
+
|
56 |
+
[post_write_hooks]
|
57 |
+
# post_write_hooks defines scripts or Python functions that are run
|
58 |
+
# on newly generated revision scripts. See the documentation for further
|
59 |
+
# detail and examples
|
60 |
+
|
61 |
+
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
62 |
+
hooks = black
|
63 |
+
|
64 |
+
black.type = console_scripts
|
65 |
+
black.entrypoint = black
|
66 |
+
black.options = REVISION_SCRIPT_FILENAME
|
67 |
+
|
68 |
+
# Logging configuration
|
69 |
+
[loggers]
|
70 |
+
keys = root,sqlalchemy,alembic
|
71 |
+
|
72 |
+
[handlers]
|
73 |
+
keys = console
|
74 |
+
|
75 |
+
[formatters]
|
76 |
+
keys = generic
|
77 |
+
|
78 |
+
[logger_root]
|
79 |
+
level = WARN
|
80 |
+
handlers = console
|
81 |
+
qualname =
|
82 |
+
|
83 |
+
[logger_sqlalchemy]
|
84 |
+
level = WARN
|
85 |
+
handlers =
|
86 |
+
qualname = sqlalchemy.engine
|
87 |
+
|
88 |
+
[logger_alembic]
|
89 |
+
level = INFO
|
90 |
+
handlers =
|
91 |
+
qualname = alembic
|
92 |
+
|
93 |
+
[handler_console]
|
94 |
+
class = StreamHandler
|
95 |
+
args = (sys.stderr,)
|
96 |
+
level = NOTSET
|
97 |
+
formatter = generic
|
98 |
+
|
99 |
+
[formatter_generic]
|
100 |
+
format = %(levelname)-5.5s [%(name)s] %(message)s
|
101 |
+
datefmt = %H:%M:%S
|
app/migrations/env.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from logging.config import fileConfig
|
3 |
+
|
4 |
+
from alembic import context
|
5 |
+
from sqlalchemy import engine_from_config, pool
|
6 |
+
from sqlalchemy.ext.asyncio import AsyncEngine
|
7 |
+
|
8 |
+
from app.core import config as app_config
|
9 |
+
|
10 |
+
# this is the Alembic Config object, which provides
|
11 |
+
# access to the values within the .ini file in use.
|
12 |
+
config = context.config
|
13 |
+
|
14 |
+
# Interpret the config file for Python logging.
|
15 |
+
# This line sets up loggers basically.
|
16 |
+
fileConfig(config.config_file_name) # type: ignore
|
17 |
+
|
18 |
+
# add your model's MetaData object here
|
19 |
+
# for 'autogenerate' support
|
20 |
+
# from myapp import mymodel
|
21 |
+
# target_metadata = mymodel.Base.metadata
|
22 |
+
from app.models.database.base import Base # noqa
|
23 |
+
|
24 |
+
target_metadata = Base.metadata
|
25 |
+
|
26 |
+
# other values from the config, defined by the needs of env.py,
|
27 |
+
# can be acquired:
|
28 |
+
# my_important_option = config.get_main_option("my_important_option")
|
29 |
+
# ... etc.
|
30 |
+
|
31 |
+
|
32 |
+
def get_database_uri():
|
33 |
+
return app_config.settings.SQLALCHEMY_DATABASE_URI
|
34 |
+
|
35 |
+
|
36 |
+
def run_migrations_offline():
|
37 |
+
"""Run migrations in 'offline' mode.
|
38 |
+
|
39 |
+
This configures the context with just a URL
|
40 |
+
and not an Engine, though an Engine is acceptable
|
41 |
+
here as well. By skipping the Engine creation
|
42 |
+
we don't even need a DBAPI to be available.
|
43 |
+
|
44 |
+
Calls to context.execute() here emit the given string to the
|
45 |
+
script output.
|
46 |
+
|
47 |
+
"""
|
48 |
+
url = get_database_uri()
|
49 |
+
context.configure(
|
50 |
+
url=url,
|
51 |
+
target_metadata=target_metadata,
|
52 |
+
literal_binds=True,
|
53 |
+
dialect_opts={"paramstyle": "named"},
|
54 |
+
compare_type=True,
|
55 |
+
compare_server_default=True,
|
56 |
+
)
|
57 |
+
|
58 |
+
with context.begin_transaction():
|
59 |
+
context.run_migrations()
|
60 |
+
|
61 |
+
|
62 |
+
def do_run_migrations(connection):
|
63 |
+
context.configure(
|
64 |
+
connection=connection, target_metadata=target_metadata, compare_type=True
|
65 |
+
)
|
66 |
+
|
67 |
+
with context.begin_transaction():
|
68 |
+
context.run_migrations()
|
69 |
+
|
70 |
+
|
71 |
+
async def run_migrations_online():
|
72 |
+
"""Run migrations in 'online' mode.
|
73 |
+
|
74 |
+
In this scenario we need to create an Engine
|
75 |
+
and associate a connection with the context.
|
76 |
+
|
77 |
+
"""
|
78 |
+
configuration = config.get_section(config.config_ini_section)
|
79 |
+
assert configuration
|
80 |
+
configuration["sqlalchemy.url"] = get_database_uri()
|
81 |
+
connectable = AsyncEngine(
|
82 |
+
engine_from_config(
|
83 |
+
configuration,
|
84 |
+
prefix="sqlalchemy.",
|
85 |
+
poolclass=pool.NullPool,
|
86 |
+
future=True,
|
87 |
+
) # type: ignore
|
88 |
+
)
|
89 |
+
async with connectable.connect() as connection:
|
90 |
+
await connection.run_sync(do_run_migrations)
|
91 |
+
|
92 |
+
|
93 |
+
if context.is_offline_mode():
|
94 |
+
run_migrations_offline()
|
95 |
+
else:
|
96 |
+
asyncio.run(run_migrations_online())
|
app/migrations/script.py.mako
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""${message}
|
2 |
+
|
3 |
+
Revision ID: ${up_revision}
|
4 |
+
Revises: ${down_revision | comma,n}
|
5 |
+
Create Date: ${create_date}
|
6 |
+
|
7 |
+
"""
|
8 |
+
from alembic import op
|
9 |
+
import sqlalchemy as sa
|
10 |
+
${imports if imports else ""}
|
11 |
+
|
12 |
+
# revision identifiers, used by Alembic.
|
13 |
+
revision = ${repr(up_revision)}
|
14 |
+
down_revision = ${repr(down_revision)}
|
15 |
+
branch_labels = ${repr(branch_labels)}
|
16 |
+
depends_on = ${repr(depends_on)}
|
17 |
+
|
18 |
+
|
19 |
+
def upgrade():
|
20 |
+
${upgrades if upgrades else "pass"}
|
21 |
+
|
22 |
+
|
23 |
+
def downgrade():
|
24 |
+
${downgrades if downgrades else "pass"}
|
app/migrations/utils.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from alembic import op
|
2 |
+
from sqlalchemy import text
|
3 |
+
from sqlalchemy.engine import reflection
|
4 |
+
|
5 |
+
|
6 |
+
def table_has_column(table: str, column: str):
|
7 |
+
if not hasattr(table_has_column, "inspection"):
|
8 |
+
conn = op.get_bind()
|
9 |
+
insp = table_has_column.inspection = reflection.Inspector.from_engine(conn)
|
10 |
+
else:
|
11 |
+
insp = table_has_column.inspection
|
12 |
+
has_column = False
|
13 |
+
for col in insp.get_columns(table):
|
14 |
+
if column not in col["name"]:
|
15 |
+
continue
|
16 |
+
has_column = True
|
17 |
+
return has_column
|
18 |
+
|
19 |
+
|
20 |
+
def table_exists(table):
|
21 |
+
conn = op.get_bind()
|
22 |
+
inspector = reflection.Inspector.from_engine(conn)
|
23 |
+
tables = inspector.get_table_names()
|
24 |
+
return table in tables
|
app/migrations/versions/2025041655_new_migration_0c372b179073.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""new migration
|
2 |
+
|
3 |
+
Revision ID: 0c372b179073
|
4 |
+
Revises:
|
5 |
+
Create Date: 2025-04-16 14:55:45.297069
|
6 |
+
|
7 |
+
"""
|
8 |
+
from alembic import op
|
9 |
+
import sqlalchemy as sa
|
10 |
+
|
11 |
+
# revision identifiers, used by Alembic.
|
12 |
+
revision = '0c372b179073'
|
13 |
+
down_revision = None
|
14 |
+
branch_labels = None
|
15 |
+
depends_on = None
|
16 |
+
|
17 |
+
def upgrade():
|
18 |
+
op.create_table(
|
19 |
+
"users",
|
20 |
+
sa.Column("firebase_uid", sa.String(), nullable=False),
|
21 |
+
sa.Column("email", sa.String(), nullable=False),
|
22 |
+
sa.Column("display_name", sa.String(), nullable=True),
|
23 |
+
sa.Column("is_active", sa.Boolean(), nullable=False, server_default='true'),
|
24 |
+
sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.text('CURRENT_TIMESTAMP')),
|
25 |
+
sa.Column("last_login", sa.DateTime(), nullable=True),
|
26 |
+
sa.Column("provider", sa.String(), nullable=False, server_default='email'),
|
27 |
+
sa.PrimaryKeyConstraint("firebase_uid"),
|
28 |
+
sa.UniqueConstraint("email"),
|
29 |
+
)
|
30 |
+
|
31 |
+
def downgrade():
|
32 |
+
op.drop_table("users")
|
app/migrations/versions/__pycache__/2025041655_new_migration_0c372b179073.cpython-312.pyc
ADDED
Binary file (1.96 kB). View file
|
|