Spaces:
Runtime error
Runtime error
from fastapi import Depends, HTTPException, Request | |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
from jose import jwt, JWTError, ExpiredSignatureError | |
import os | |
from dotenv import load_dotenv | |
# from db.mongoDB import user_collection, blacklist_collection | |
from db.mongoDB import mongo_db | |
import torch | |
import rag_components | |
from schemas.chat import AppState | |
from pydantic import ValidationError | |
from config import SECRET_KEY, ALGORITHM,EMBEDDING_MODEL_NAME,WEAVIATE_COLLECTION_NAME, WEAVIATE_URL | |
from utils.utils import load_legal_dictionary | |
import config | |
from langchain_groq import ChatGroq | |
from typing import Annotated, Optional | |
from schemas.user import UserOut, UserRole | |
from fastapi import status | |
from datetime import datetime, timezone | |
from db.redis import get_redis_client # Giả sử bạn đã định nghĩa hàm này trong db/redis.py | |
from utils.AdvancedLawRetriever import AdvancedLawRetriever | |
from services.reranker_service import get_reranker_compressor | |
from db.weaviateDB import connect_to_weaviate | |
import logging | |
logger = logging.getLogger(__name__) | |
handler = logging.StreamHandler() # Gửi log đến stdout | |
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) | |
logger.addHandler(handler) | |
# Bearer token security scheme | |
bearer_scheme = HTTPBearer(auto_error=False) | |
def get_app_state(request: Request): | |
if not hasattr(request.app.state, 'app_state'): | |
print("Error in get_app_state: request.app.state.app_state is not set!") | |
raise RuntimeError("Application state ('app_state') not found. Initialization failed?") | |
return request.app.state.app_state | |
async def initialize_api_components(app_state: AppState): | |
"""Khởi tạo các thành phần cần thiết cho API """ | |
logger.info("🔸Bắt đầu Khởi tạo API Components") | |
load_dotenv() | |
# --- Kiểm tra kết nối tới Redis --- | |
app_state.process_input_llm = ChatGroq(model=config.GROQ_MODEL_NAME,temperature=0.2) | |
try: | |
app_state.redis = await get_redis_client() # Gọi hàm khởi tạo redis | |
except Exception as e: | |
logger.error(f"☠️ LỖI NGHIÊM TRỌNG khi khởi tạo Redis trong initialize_api_components: {e}") | |
raise | |
app_state.dict = load_legal_dictionary(config.LEGAL_DIC_FOLDER+ "/legal_terms.json") | |
app_state.weaviateDB = connect_to_weaviate(run_diagnostics=False) | |
# --- Kiểm tra kết nối tới MongoDB --- | |
if mongo_db.users is None or app_state.weaviateDB is None: | |
logger.error("🔸Lỗi kết nối tới MongoDB hoặc Weaviate.", mongo_db.users) | |
raise HTTPException(status_code=500, detail="Lỗi kết nối tới database.") | |
# app_state.google_api_key = os.environ.get("GOOGLE_API_KEY") | |
# if not app_state.google_api_key: | |
# logger.error("🔸GG API Key không được cung cấp.") | |
# raise HTTPException(status_code=500, detail="Missing GG API Key") | |
app_state.device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
logger.info(f"🔸Sử dụng thiết bị: {app_state.device}") | |
# 1. Tải Embedding Model (giữ nguyên) | |
print(f"Đang tải Embedding Model...") | |
app_state.embeddings = rag_components.get_huggingface_embeddings( | |
EMBEDDING_MODEL_NAME, app_state.device | |
) | |
if not app_state.embeddings: | |
raise HTTPException(status_code=500, detail="Failed to load embedding model") | |
# 2. Tải Vector Store | |
print(f"Đang tải Vector Store...") | |
app_state.vectorstore = rag_components.create_or_load_vectorstore( | |
embeddings=app_state.embeddings, | |
weaviate_url=WEAVIATE_URL, | |
collection_name=WEAVIATE_COLLECTION_NAME, | |
weaviate_client=app_state.weaviateDB, | |
chunks=None, | |
) | |
if not app_state.vectorstore: | |
raise HTTPException(status_code=500, detail="Failed to load or create Vectorstore") | |
# 3. Tải LLM | |
logger.info(f"🔸Đang tải LLM...") | |
# llm = rag_components.get_google_llm(app_state.google_api_key) | |
# app_state.llm = llm | |
if not config.GOOGLE_API_KEYS: | |
logger.error("🚨 BIẾN MÔI TRƯỜNG 'GOOGLE_API_KEYS' CHƯA ĐƯỢC THIẾT LẬP TRÊN HUGGING FACE SPACES!") | |
# Tại đây, bạn có thể dừng chương trình hoặc xử lý lỗi | |
google_api_keys_list = [] | |
else: | |
google_api_keys_list = [key.strip() for key in config.GOOGLE_API_KEYS.split(',') if key.strip()] | |
logger.info(f"✅ Đã tìm thấy và tải {len(google_api_keys_list)} API key từ secrets.") | |
if not google_api_keys_list: | |
logger.error("🚨 KHÔNG CÓ GOOGLE API KEYS NÀO ĐƯỢC CẤP PHÁT!") | |
raise HTTPException(status_code=500, detail="No Google API keys found") | |
# app_state.llm = rag_components.create_llm_from_google_key_list(google_api_keys=google_api_keys_list) | |
app_state.llm = rag_components.get_google_llm(config.GOOGLE_API_KEYS) | |
if not app_state.llm: | |
raise HTTPException(status_code=500, detail="Failed to load LLM") | |
# 4. Tạo retriever (giữ nguyên) | |
logger.info(f"🔸Đang tạo retriever...") | |
app_state.reranker = get_reranker_compressor() # Singleton re-ranker | |
app_state.retriever = AdvancedLawRetriever( | |
client=app_state.weaviateDB, | |
collection_name=WEAVIATE_COLLECTION_NAME, | |
llm=app_state.llm, | |
reranker=app_state.reranker, # Singleton re-ranker | |
embeddings_model=app_state.embeddings | |
) | |
if app_state.retriever is None: | |
raise HTTPException(status_code=500, detail="Failed to create retriever") | |
logger.info(f"🔸Đã tạo retriever thành công.") | |
# 5. Tạo QA Chain (giữ nguyên) | |
logger.info(f"🔸Đang tạo QA Chain...") | |
app_state.qa_chain = rag_components.create_qa_chain( | |
llm=app_state.llm, | |
retriever=app_state.retriever, | |
process_input_llm=app_state.process_input_llm | |
) | |
if app_state.qa_chain is None: | |
raise HTTPException(status_code=500, detail="Failed to create QA Chain") | |
logger.info(f"🔸Khởi tạo API Components hoàn tất ") | |
async def get_access_token_from_cookie(request: Request) -> Optional[str]: | |
""" | |
Lấy access token từ cookie 'access_token_cookie'. | |
""" | |
token = request.cookies.get("access_token_cookie") | |
logger.info(f"Token {token}") | |
logger.debug(f"GET_ACCESS_TOKEN_FROM_COOKIE: Cookies nhận được: {request.cookies}") | |
logger.info(f"GET_ACCESS_TOKEN_FROM_COOKIE: Token trích xuất từ 'access_token_cookie': {'PRESENT' if token else 'MISSING'}") | |
return token | |
async def get_current_user( | |
request: Request, | |
token_from_cookie: Optional[str] = Depends(get_access_token_from_cookie), | |
auth_header: Optional[HTTPAuthorizationCredentials] = Depends(bearer_scheme) # Sẽ raise 403 nếu header Auth sai format | |
) -> UserOut: | |
logger.warning("GET_CURRENT_USER: *** BẮT ĐẦU XÁC THỰC ***") # Sẽ không thấy log này nếu bearer_scheme raise 403 | |
token_to_verify: Optional[str] = None | |
source_of_token: str = "NONE" | |
if token_from_cookie: | |
token_to_verify = token_from_cookie | |
source_of_token = "COOKIE" | |
logger.info("GET_CURRENT_USER: Sử dụng token từ cookie.") | |
elif auth_header: # Chỉ dùng nếu không có token từ cookie | |
token_to_verify = auth_header.credentials | |
source_of_token = "AUTHORIZATION_HEADER" | |
logger.info("GET_CURRENT_USER: Không có token từ cookie, sử dụng token từ Authorization header.") | |
# Không cần else, 'if not token_to_verify' ở dưới sẽ xử lý | |
credentials_exception = HTTPException( | |
status_code=status.HTTP_401_UNAUTHORIZED, | |
detail="Không thể xác thực người dùng. Vui lòng đăng nhập lại.", | |
headers={"WWW-Authenticate": "Bearer"}, # Thêm header này là good practice | |
) | |
if not token_to_verify: | |
logger.error(f"GET_CURRENT_USER: *** KHÔNG TÌM THẤY TOKEN (Nguồn: {source_of_token}) - RAISING 401 ***") | |
raise credentials_exception | |
# logger.info(f"GET_CURRENT_USER: Token để verify (nguồn: {source_of_token}): {token_to_verify[:20]}...") | |
# 1. Kiểm tra token trong blacklist | |
try: | |
logger.info("GET_CURRENT_USER: Đang kiểm tra blacklist...") | |
is_blacklisted =await mongo_db.token_blacklist.find_one({"token": token_to_verify}) | |
if is_blacklisted: | |
logger.error(f"GET_CURRENT_USER: *** TOKEN TRONG BLACKLIST - RAISING 401 ***") | |
raise HTTPException( # Sử dụng credentials_exception hoặc một cái cụ thể hơn | |
status_code=status.HTTP_401_UNAUTHORIZED, | |
detail="Token đã bị thu hồi hoặc không hợp lệ.", | |
headers={"WWW-Authenticate": "Bearer"}, | |
) | |
logger.info("GET_CURRENT_USER: Token không trong blacklist - OK") | |
except HTTPException: | |
raise | |
except Exception as db_error: | |
logger.error(f"GET_CURRENT_USER: *** LỖI DATABASE BLACKLIST: {db_error} ***") | |
raise HTTPException( | |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
detail="Lỗi máy chủ khi kiểm tra trạng thái token." | |
) | |
# 2. Giải mã và xác thực JWT | |
payload: Optional[dict] = None | |
email: Optional[str] = None | |
try: | |
logger.info("GET_CURRENT_USER: Đang decode JWT...") | |
if not SECRET_KEY: # Kiểm tra này quan trọng | |
logger.error("GET_CURRENT_USER: *** SECRET_KEY CHƯA ĐƯỢC CẤU HÌNH ***") | |
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Lỗi cấu hình máy chủ.") | |
payload = jwt.decode(token_to_verify, SECRET_KEY, algorithms=[ALGORITHM]) | |
logger.info(f'CHECK: {payload}') | |
email = payload.get("sub") | |
exp = payload.get("exp") | |
# logger.info(f"GET_CURRENT_USER: JWT decode thành công - email: {email}, exp: {exp}") | |
if not isinstance(email, str) or not email: | |
logger.error("GET_CURRENT_USER: *** EMAIL KHÔNG HỢP LỆ TRONG TOKEN ***") | |
raise credentials_exception # Sử dụng lại credentials_exception đã định nghĩa | |
if not isinstance(exp, int): # Thường 'exp' là int (timestamp) | |
logger.error("GET_CURRENT_USER: *** EXP KHÔNG HỢP LỆ TRONG TOKEN ***") | |
raise credentials_exception | |
expiration_datetime = datetime.fromtimestamp(exp, tz=timezone.utc) | |
current_datetime_utc = datetime.now(tz=timezone.utc) | |
if expiration_datetime < current_datetime_utc: | |
logger.error(f"GET_CURRENT_USER: *** TOKEN HẾT HẠN - RAISING 401 ***") | |
raise HTTPException( # Sử dụng credentials_exception hoặc một cái cụ thể hơn | |
status_code=status.HTTP_401_UNAUTHORIZED, | |
detail="Token đã hết hạn. Vui lòng đăng nhập lại.", | |
headers={"WWW-Authenticate": "Bearer"}, | |
) | |
logger.info(f"GET_CURRENT_USER: Token còn hạn - OK") | |
except ExpiredSignatureError: # Bắt lỗi cụ thể này từ PyJWT | |
logger.error(f"GET_CURRENT_USER: *** TOKEN HẾT HẠN (ExpiredSignatureError) ***") | |
raise HTTPException( | |
status_code=status.HTTP_401_UNAUTHORIZED, | |
detail="Token đã hết hạn (JWT validation). Vui lòng đăng nhập lại.", | |
headers={"WWW-Authenticate": "Bearer"}, | |
) | |
except JWTError as e: # Bắt lỗi chung từ PyJWT | |
logger.error(f"GET_CURRENT_USER: *** LỖI JWT: {e} ***") | |
raise HTTPException( # Có thể dùng credentials_exception hoặc thông báo cụ thể hơn | |
status_code=status.HTTP_401_UNAUTHORIZED, | |
detail=f"Token không hợp lệ hoặc có vấn đề khi giải mã.", # Thông báo chung chung hơn | |
headers={"WWW-Authenticate": "Bearer"}, | |
) | |
except HTTPException: # Re-raise nếu là HTTPException đã được raise từ bên trong try | |
raise | |
except Exception as e_decode: # Bắt các lỗi không mong muốn khác | |
logger.error(f"GET_CURRENT_USER: *** LỖI KHÔNG XÁC ĐỊNH KHI DECODE JWT: {e_decode} ***") | |
raise credentials_exception # Trả về lỗi chung | |
# 3. Lấy thông tin người dùng từ database | |
user_data: Optional[dict] = None # Khởi tạo để tránh UnboundLocalError | |
try: | |
# logger.info(f"GET_CURRENT_USER: Đang tìm user trong DB: {email.lower()}") # email đã được validate là str | |
user_data = await mongo_db.users.find_one({"email": email.lower()}, {"password": 0, "_id": 0}) | |
# print(user_data) # Bỏ print trong production | |
if user_data is None: | |
logger.error(f"GET_CURRENT_USER: *** KHÔNG TÌM THẤY USER TRONG DB ({email.lower()}) - RAISING 401 ***") | |
raise credentials_exception | |
# logger.info(f"GET_CURRENT_USER: Tìm thấy user - data: {user_data}") | |
except HTTPException: | |
raise | |
except Exception as db_user_error: | |
logger.error(f"GET_CURRENT_USER: *** LỖI DATABASE USER: {db_user_error} ***") | |
raise HTTPException( | |
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | |
detail="Lỗi máy chủ khi truy xuất thông tin người dùng." | |
) | |
# 4. Tạo đối tượng UserOut và kiểm tra is_active | |
try: | |
if user_data and ('username' not in user_data or not user_data.get('username')): | |
user_data['username'] = email.lower().split('@')[0] | |
logger.info(f"GET_CURRENT_USER: Set default username: {user_data['username']}") | |
user = UserOut(**user_data) | |
if not user.is_active: | |
logger.error(f"GET_CURRENT_USER: *** TÀI KHOẢN BỊ KHÓA ({user.email}) - RAISING 403 ***") | |
raise HTTPException( | |
status_code=status.HTTP_403_FORBIDDEN, # 403 là phù hợp ở đây | |
detail="Tài khoản của bạn đã bị khóa hoặc không hoạt động. Vui lòng liên hệ quản trị viên.", | |
) | |
logger.info(f"GET_CURRENT_USER: *** XÁC THỰC THÀNH CÔNG *** - User: {user.email}, Active: {user.is_active}, Role: {user.role}") | |
return user | |
except ValidationError as ve: | |
logger.error(f"GET_CURRENT_USER: *** LỖI PYDANTIC VALIDATION: {ve.errors()} ***") | |
logger.error(f"GET_CURRENT_USER: Dữ liệu gây lỗi: {user_data}") | |
raise HTTPException( | |
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, # 422 khi dữ liệu không thể xử lý | |
detail=f"Dữ liệu người dùng không hợp lệ từ DB: {ve.errors()}" # Có thể trả về lỗi cụ thể nếu an toàn | |
) | |
except HTTPException: # Re-raise | |
raise | |
except Exception as e_userout: | |
logger.error(f"GET_CURRENT_USER: *** LỖI TẠO USEROUT HOẶC KIỂM TRA IS_ACTIVE: {e_userout} ***") | |
raise credentials_exception # Lỗi chung nếu không rõ nguyên nhân | |
async def admin_required( | |
current_user: Annotated[UserOut, Depends(get_current_user)] | |
) -> UserOut: | |
""" | |
Dependency kiểm tra người dùng hiện tại có quyền admin hay không. | |
Trả về thông tin người dùng nếu có quyền admin, nếu không raise HTTPException. | |
Usage: | |
@router.get("/admin-only") | |
async def admin_route(user: UserOut = Depends(admin_required)): | |
return {"message": "You have admin access"} | |
""" | |
if not current_user.role or current_user.role not in [UserRole.ADMIN]: | |
raise HTTPException( | |
status_code=status.HTTP_403_FORBIDDEN, | |
detail="Bạn không có quyền truy cập chức năng này", | |
headers={"WWW-Authenticate": "Bearer"}, | |
) | |
return current_user | |