Spaces:
Running
Running
| import base64 | |
| import random | |
| import warnings | |
| from collections.abc import Coroutine | |
| from datetime import datetime, timedelta, timezone | |
| from typing import TYPE_CHECKING, Annotated | |
| from uuid import UUID | |
| from cryptography.fernet import Fernet | |
| from fastapi import Depends, HTTPException, Security, status | |
| from fastapi.security import APIKeyHeader, APIKeyQuery, OAuth2PasswordBearer | |
| from jose import JWTError, jwt | |
| from loguru import logger | |
| from sqlmodel.ext.asyncio.session import AsyncSession | |
| from starlette.websockets import WebSocket | |
| from langflow.services.database.models.api_key.crud import check_key | |
| from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at | |
| from langflow.services.database.models.user.model import User, UserRead | |
| from langflow.services.deps import get_db_service, get_session, get_settings_service | |
| from langflow.services.settings.service import SettingsService | |
| if TYPE_CHECKING: | |
| from langflow.services.database.models.api_key.model import ApiKey | |
# Bearer-token scheme. auto_error=False makes a missing token resolve to a falsy
# value instead of an immediate error, so get_current_user can fall back to
# API-key authentication.
oauth2_login = OAuth2PasswordBearer(auto_error=False, tokenUrl="api/v1/login")

# The same "x-api-key" name is accepted both as a query parameter and a header.
API_KEY_NAME = "x-api-key"
api_key_query = APIKeyQuery(auto_error=False, name=API_KEY_NAME, scheme_name="API key query")
api_key_header = APIKeyHeader(auto_error=False, name=API_KEY_NAME, scheme_name="API key header")

# Secrets shorter than this are expanded into a 32-byte key by ensure_valid_key.
MINIMUM_KEY_LENGTH = 32
# Source: https://github.com/mrtolkien/fastapi_simple_security/blob/master/fastapi_simple_security/security_api_key.py
async def api_key_security(
    query_param: Annotated[str, Security(api_key_query)],
    header_param: Annotated[str, Security(api_key_header)],
) -> UserRead | None:
    """Authenticate a request by API key (or implicitly, when AUTO_LOGIN is on).

    With AUTO_LOGIN enabled the configured superuser is returned without
    inspecting any key. Otherwise the key from the query string is preferred
    over the one from the header and validated with ``check_key``.

    Args:
        query_param: API key taken from the ``x-api-key`` query parameter.
        header_param: API key taken from the ``x-api-key`` header.

    Returns:
        A ``UserRead`` built from the resolved user.

    Raises:
        HTTPException: 400 if AUTO_LOGIN is on but no SUPERUSER is configured;
            403 if no key was supplied or the key/user could not be resolved.
        ValueError: if the lookup produced something other than a ``User``.
    """
    settings_service = get_settings_service()
    result: ApiKey | User | None

    async with get_db_service().with_async_session() as db:
        if settings_service.auth_settings.AUTO_LOGIN:
            superuser_name = settings_service.auth_settings.SUPERUSER
            if not superuser_name:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Missing first superuser credentials",
                )
            result = await get_user_by_username(db, superuser_name)
        else:
            # Query parameter wins; fall back to the header.
            supplied_key = query_param or header_param
            if not supplied_key:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="An API key must be passed as query or header",
                )
            result = await check_key(db, supplied_key)

        if not result:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Invalid or missing API key",
            )
        # Convert while the session is still open so ORM attributes are loadable.
        if isinstance(result, User):
            return UserRead.model_validate(result, from_attributes=True)

    msg = "Invalid result type"
    raise ValueError(msg)
async def get_current_user(
    token: Annotated[str, Security(oauth2_login)],
    query_param: Annotated[str, Security(api_key_query)],
    header_param: Annotated[str, Security(api_key_header)],
    db: Annotated[AsyncSession, Depends(get_session)],
) -> User:
    """FastAPI dependency resolving the caller from a JWT or an API key.

    A bearer token, when present, takes precedence; otherwise the request is
    authenticated through ``api_key_security``.

    Raises:
        HTTPException: 403 when API-key authentication yields no user, plus
            whatever the JWT / API-key paths raise themselves.
    """
    if token:
        return await get_current_user_by_jwt(token, db)

    api_key_user = await api_key_security(query_param, header_param)
    if not api_key_user:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or missing API key",
        )
    return api_key_user
async def get_current_user_by_jwt(
    token: str,
    db: AsyncSession,
) -> User:
    """Authenticate a JWT bearer token and return the active user it names.

    Args:
        token: The encoded JWT. If an un-awaited coroutine is passed instead,
            it is awaited first to obtain the token.
        db: Async session used to load the user from the ``sub`` claim.

    Returns:
        The active ``User`` identified by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the secret key is unset, the token cannot be
            decoded/verified, it is expired, it lacks ``sub`` or ``type``, or
            the user does not exist / is inactive.
    """
    settings_service = get_settings_service()

    if isinstance(token, Coroutine):
        token = await token

    secret_key = settings_service.auth_settings.SECRET_KEY.get_secret_value()
    if secret_key is None:
        logger.error("Secret key is not set in settings.")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            # Careful not to leak sensitive information
            detail="Authentication failure: Verify authentication settings.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        # Suppress any warnings emitted while decoding.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            payload = jwt.decode(token, secret_key, algorithms=[settings_service.auth_settings.ALGORITHM])
    except JWTError as e:
        logger.exception("JWT decoding error")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e

    user_id: UUID = payload.get("sub")  # type: ignore[assignment]
    token_type: str = payload.get("type")  # type: ignore[assignment]

    expiry_timestamp = payload.get("exp", None)
    if expiry_timestamp:
        # NOTE(review): python-jose normally rejects expired tokens inside
        # jwt.decode already; this looks like a defensive second check — confirm.
        expires_at = datetime.fromtimestamp(expiry_timestamp, timezone.utc)
        if datetime.now(timezone.utc) > expires_at:
            logger.info("Token expired for user")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has expired.",
                headers={"WWW-Authenticate": "Bearer"},
            )

    # NOTE(review): only the presence of "type" is checked, so tokens minted
    # with "type": "refresh" by create_user_tokens also pass here — confirm
    # whether that is intended.
    if user_id is None or token_type is None:
        logger.info(f"Invalid token payload. Token type: {token_type}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token details.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    user = await get_user_by_id(db, user_id)
    if user is None or not user.is_active:
        logger.info("User not found or inactive.")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or is inactive.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return user
async def get_current_user_for_websocket(
    websocket: WebSocket,
    db: Annotated[AsyncSession, Depends(get_session)],
    query_param: Annotated[str, Security(api_key_query)],
) -> User | None:
    """Authenticate a WebSocket connection from its query string.

    A ``token`` query parameter (JWT) is tried first, then ``x-api-key``.
    Returns ``None`` when neither is present.
    """
    ws_params = websocket.query_params
    jwt_token = ws_params.get("token")
    ws_api_key = ws_params.get("x-api-key")

    if jwt_token:
        return await get_current_user_by_jwt(jwt_token, db)
    if not ws_api_key:
        return None
    # NOTE(review): query_param is passed in the header_param position of
    # api_key_security; both come from the query string here — confirm intent.
    return await api_key_security(ws_api_key, query_param)
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
    """Return the authenticated user, rejecting deactivated accounts with 401."""
    if current_user.is_active:
        return current_user
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
async def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
    """Return the authenticated user only if it is active and a superuser.

    Raises:
        HTTPException: 401 for an inactive user, 403 for a non-superuser.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
    if current_user.is_superuser:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="The user doesn't have enough privileges")
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash via the settings' ``pwd_context``."""
    pwd_context = get_settings_service().auth_settings.pwd_context
    return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
    """Hash *password* with the settings' ``pwd_context`` for storage."""
    pwd_context = get_settings_service().auth_settings.pwd_context
    return pwd_context.hash(password)
def create_token(data: dict, expires_delta: timedelta):
    """Sign *data* as a JWT that expires ``expires_delta`` from now (UTC).

    *data* itself is not modified; an ``exp`` claim is added to a copy
    (overriding any ``exp`` already present).
    """
    auth_settings = get_settings_service().auth_settings
    claims = {**data, "exp": datetime.now(timezone.utc) + expires_delta}
    return jwt.encode(
        claims,
        auth_settings.SECRET_KEY.get_secret_value(),
        algorithm=auth_settings.ALGORITHM,
    )
async def create_super_user(
    username: str,
    password: str,
    db: AsyncSession,
) -> User:
    """Return the user called *username*, creating it as an active superuser if absent.

    An existing user is returned as-is (its password and flags are not touched).
    """
    existing_user = await get_user_by_username(db, username)
    if existing_user:
        return existing_user

    new_superuser = User(
        username=username,
        password=get_password_hash(password),
        is_superuser=True,
        is_active=True,
        last_login_at=None,
    )
    db.add(new_superuser)
    await db.commit()
    await db.refresh(new_superuser)
    return new_superuser
async def create_user_longterm_token(db: AsyncSession) -> tuple[UUID, dict]:
    """Issue a 365-day access token for the configured SUPERUSER.

    Also records the login by updating the superuser's ``last_login_at``.

    Returns:
        ``(superuser_id, tokens)`` where ``tokens`` holds ``access_token``,
        ``refresh_token`` (always ``None``) and ``token_type`` ``"bearer"``.

    Raises:
        HTTPException: 400 if the superuser does not exist yet.
    """
    superuser_name = get_settings_service().auth_settings.SUPERUSER
    super_user = await get_user_by_username(db, superuser_name)
    if not super_user:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Super user hasn't been created")

    one_year = timedelta(days=365)
    long_lived_token = create_token(
        data={"sub": str(super_user.id), "type": "access"},
        expires_delta=one_year,
    )

    await update_user_last_login_at(super_user.id, db)

    tokens = {
        "access_token": long_lived_token,
        "refresh_token": None,
        "token_type": "bearer",
    }
    return super_user.id, tokens
def create_user_api_key(user_id: UUID) -> dict:
    """Mint a JWT of type ``"api_key"`` for *user_id*, valid for 2 × 365 days."""
    two_years = timedelta(days=365 * 2)
    signed = create_token(data={"sub": str(user_id), "type": "api_key"}, expires_delta=two_years)
    return {"api_key": signed}
def get_user_id_from_token(token: str) -> UUID:
    """Extract the ``sub`` claim of *token* as a UUID without verifying the token.

    The signature and expiry are NOT checked, so the result must not be
    trusted for authorization decisions.

    Returns:
        The subject UUID, or the nil UUID (``UUID(int=0)``) when the token
        cannot be parsed, has no ``sub`` claim, or ``sub`` is not a valid UUID.
    """
    try:
        claims = jwt.get_unverified_claims(token)
        # str() so a non-string "sub" (e.g. a JSON number or null) ends in the
        # ValueError fallback instead of an AttributeError raised inside UUID().
        return UUID(str(claims["sub"]))
    except (KeyError, JWTError, ValueError):
        return UUID(int=0)
async def create_user_tokens(user_id: UUID, db: AsyncSession, *, update_last_login: bool = False) -> dict:
    """Issue an access/refresh token pair for *user_id*.

    Lifetimes come from ACCESS_TOKEN_EXPIRE_SECONDS and
    REFRESH_TOKEN_EXPIRE_SECONDS in the auth settings.

    Args:
        user_id: Subject of both tokens.
        db: Session used only when ``update_last_login`` is true.
        update_last_login: Also stamp the user's ``last_login_at``.

    Returns:
        Dict with ``access_token``, ``refresh_token`` and ``token_type`` ``"bearer"``.
    """
    auth_settings = get_settings_service().auth_settings
    subject = str(user_id)

    access_token = create_token(
        data={"sub": subject, "type": "access"},
        expires_delta=timedelta(seconds=auth_settings.ACCESS_TOKEN_EXPIRE_SECONDS),
    )
    refresh_token = create_token(
        data={"sub": subject, "type": "refresh"},
        expires_delta=timedelta(seconds=auth_settings.REFRESH_TOKEN_EXPIRE_SECONDS),
    )

    if update_last_login:
        await update_user_last_login_at(user_id, db)

    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
async def create_refresh_token(refresh_token: str, db: AsyncSession):
    """Exchange a refresh token for a new access/refresh token pair.

    Args:
        refresh_token: Encoded JWT issued by ``create_user_tokens`` with
            ``"type": "refresh"``.
        db: Async session used to look up the token's user.

    Returns:
        The dict produced by ``create_user_tokens``.

    Raises:
        HTTPException: 401 if the token fails to decode/verify, has no
            subject, is not a refresh token, or its user is missing/inactive.
    """
    settings_service = get_settings_service()

    try:
        # Ignore warning about datetime.utcnow
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            payload = jwt.decode(
                refresh_token,
                settings_service.auth_settings.SECRET_KEY.get_secret_value(),
                algorithms=[settings_service.auth_settings.ALGORITHM],
            )
        user_id: UUID = payload.get("sub")  # type: ignore[assignment]
        token_type: str = payload.get("type")  # type: ignore[assignment]

        # Only genuine refresh tokens may be exchanged. The previous check
        # (token_type == "") let access tokens, and tokens without a "type"
        # claim, mint fresh token pairs indefinitely.
        if user_id is None or token_type != "refresh":
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

        user_exists = await get_user_by_id(db, user_id)

        # Match get_current_user_by_jwt: inactive users get no new tokens.
        if user_exists is None or not user_exists.is_active:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token")

        return await create_user_tokens(user_id, db)

    except JWTError as e:
        logger.exception("JWT decoding error")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        ) from e
async def authenticate_user(username: str, password: str, db: AsyncSession) -> User | None:
    """Return the user if *username*/*password* match, otherwise ``None``.

    Raises:
        HTTPException: 400 "Waiting for approval" for an inactive user that
            has never logged in; 401 "Inactive user" for any other inactive user.
    """
    user = await get_user_by_username(db, username)
    if not user:
        return None

    if not user.is_active:
        if user.last_login_at:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
        # Never logged in and inactive: treated as a pending sign-up.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Waiting for approval")

    if verify_password(password, user.password):
        return user
    return None
def add_padding(s):
    """Append ``=`` characters so ``len(s)`` becomes a multiple of 4 (base64 alignment).

    Strings whose length is already a multiple of 4 are returned unchanged.
    """
    # -len(s) % 4 is 0 for aligned input. The previous 4 - len(s) % 4 appended
    # four spurious "=" in that case; lenient base64 decoding ignored them, so
    # the decoded bytes (and hence Fernet keys) are the same either way.
    padding_needed = -len(s) % 4
    return s + "=" * padding_needed
def ensure_valid_key(s: str) -> bytes:
    """Turn a secret string into key bytes suitable for ``Fernet``.

    Secrets shorter than ``MINIMUM_KEY_LENGTH`` are used as a seed to
    deterministically derive 32 bytes, which are URL-safe base64 encoded.
    Longer secrets are only padded to base64 alignment and encoded as bytes.

    Args:
        s: The raw secret (``get_fernet`` passes the configured SECRET_KEY).

    Returns:
        The key as bytes.
    """
    if len(s) < MINIMUM_KEY_LENGTH:
        # Use a private generator instead of random.seed(s). Reseeding the
        # module-level PRNG made every later random.* call in the process
        # predictable. Random(s) seeds exactly like random.seed(s), so derived
        # keys (and previously encrypted data) are unchanged.
        # NOTE(review): `random` is not a cryptographic KDF, so short secrets
        # yield weak keys. Changing the derivation would break decryption of
        # existing ciphertexts, so it is intentionally left as-is.
        rng = random.Random(s)
        key = bytes(rng.getrandbits(8) for _ in range(32))
        key = base64.urlsafe_b64encode(key)
    else:
        key = add_padding(s).encode()
    return key
def get_fernet(settings_service: SettingsService):
    """Build a ``Fernet`` cipher keyed from the configured SECRET_KEY."""
    raw_secret = settings_service.auth_settings.SECRET_KEY.get_secret_value()
    return Fernet(ensure_valid_key(raw_secret))
def encrypt_api_key(api_key: str, settings_service: SettingsService):
    """Reversibly encrypt *api_key* with Fernet and return the token as text."""
    cipher = get_fernet(settings_service)
    token_bytes = cipher.encrypt(api_key.encode())
    return token_bytes.decode()
def decrypt_api_key(encrypted_api_key: str, settings_service: SettingsService):
    """Decrypt a Fernet token produced by ``encrypt_api_key``.

    Returns ``""`` when *encrypted_api_key* is not a ``str``. If decrypting the
    encoded token fails, the raw string is tried once more. An error from that
    retry propagates to the caller.
    """
    cipher = get_fernet(settings_service)

    if not isinstance(encrypted_api_key, str):
        return ""

    try:
        return cipher.decrypt(encrypted_api_key.encode()).decode()
    except Exception:  # noqa: BLE001
        logger.debug("Failed to decrypt API key")
        # NOTE(review): recent `cryptography` versions accept str tokens too,
        # so this retry likely fails the same way — confirm before removing.
        return cipher.decrypt(encrypted_api_key).decode()