Spaces:
Paused
Paused
import os
import threading
import time
from collections import deque
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from logger import log_info, log_warning
# Bearer-token extractor; it is the Depends() default for verify_token's
# `credentials` parameter, from which verify_token reads the raw token.
security = HTTPBearer()
# ===================== Rate Limiting =====================
class RateLimiter:
    """Thread-safe, in-memory sliding-window rate limiter.

    Each key maps to a deque of ``(timestamp, count)`` entries, one per
    accepted request (``count`` is always 1), oldest first. Timestamps come
    from ``clock``, which defaults to ``time.monotonic`` so wall-clock
    adjustments cannot shrink or stretch a window.

    State lives in this process's memory only: it is not shared between
    worker processes and is lost on restart.
    """

    def __init__(self, clock=None):
        """Create an empty limiter.

        Args:
            clock: Optional zero-argument callable returning seconds as a
                float; must be non-decreasing. Defaults to ``time.monotonic``.
        """
        # {key: deque([(timestamp, count), ...])}, oldest entry on the left
        self.requests = {}
        self.lock = threading.Lock()
        self._clock = clock if clock is not None else time.monotonic
        # Largest window seen so far and when the next stale-key sweep is due;
        # used to evict buckets of clients that never come back.
        self._max_window = 0.0
        self._next_sweep = float("-inf")

    def is_allowed(self, key: str, max_requests: int, window_seconds: int) -> bool:
        """Check whether a request for ``key`` fits the limit, recording it if so.

        Args:
            key: Bucket identifier chosen by the caller.
            max_requests: Maximum accepted requests per window.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            True if fewer than ``max_requests`` requests were accepted for
            ``key`` in the last ``window_seconds`` seconds (this request is
            then recorded); False otherwise (nothing is recorded).
        """
        with self.lock:
            now = self._clock()
            self._maybe_sweep(now, window_seconds)

            entries = self.requests.get(key)
            if entries is None:
                entries = self.requests[key] = deque()

            # Entries are time-ordered, so expired ones are all on the left.
            cutoff = now - window_seconds
            while entries and entries[0][0] <= cutoff:
                entries.popleft()

            total = sum(count for _, count in entries)
            if total >= max_requests:
                if not entries:
                    # Nothing recorded (e.g. max_requests <= 0): keep no empty bucket.
                    del self.requests[key]
                return False

            entries.append((now, 1))
            return True

    def reset(self, key: str):
        """Forget all recorded requests for ``key`` (no-op if unknown)."""
        with self.lock:
            self.requests.pop(key, None)

    def _maybe_sweep(self, now, window_seconds):
        """Evict keys whose newest entry is older than every window in use.

        ``is_allowed`` only prunes the key it is called with, so without this
        the dict would grow forever with one-off clients. Runs at most once per
        largest-window interval (minimum 1 s). Caller must hold ``self.lock``.
        """
        self._max_window = max(self._max_window, window_seconds)
        if now < self._next_sweep:
            return
        cutoff = now - self._max_window
        stale = [
            k for k, entries in self.requests.items()
            if not entries or entries[-1][0] <= cutoff
        ]
        for k in stale:
            del self.requests[k]
        self._next_sweep = now + max(self._max_window, 1.0)
# Create global rate limiter instance, shared by every importer of this module.
# (`threading`, needed by RateLimiter, is imported at the top of the file.)
rate_limiter = RateLimiter()
# ===================== JWT Config =====================
def get_jwt_config():
    """Return the JWT settings for the current deployment.

    If ``SPACE_ID`` is set (HuggingFace Space), settings are read straight
    from the process environment. Otherwise a local ``.env`` file is loaded
    with python-dotenv first, then the environment is read.

    Returns:
        dict with:
          - "secret": ``JWT_SECRET``; an unset *or empty* value falls back to
            a hard-coded development key and logs a warning.
          - "algorithm": ``JWT_ALGORITHM`` (default "HS256").
          - "expiration_hours": ``JWT_EXPIRATION_HOURS`` as int (default 24;
            a non-integer value logs a warning and uses 24).
    """
    fallback_secret = "flare-admin-secret-key-change-in-production"
    in_space = bool(os.getenv("SPACE_ID"))

    if not in_space:
        # On-premise mode - pull settings from a local .env file
        from dotenv import load_dotenv
        load_dotenv()

    jwt_secret = os.getenv("JWT_SECRET")
    # Treat empty like missing in both modes: an empty JWT_SECRET would
    # otherwise become an empty HMAC signing key.
    if not jwt_secret:
        if in_space:
            log_warning("⚠️ WARNING: JWT_SECRET not found in environment, using fallback")
        else:
            log_warning("⚠️ WARNING: JWT_SECRET not found in environment or .env, using fallback")
        jwt_secret = fallback_secret

    raw_hours = os.getenv("JWT_EXPIRATION_HOURS", "24")
    try:
        expiration_hours = int(raw_hours)
    except ValueError:
        # A typo in config should not turn every auth request into a 500.
        log_warning(f"⚠️ WARNING: Invalid JWT_EXPIRATION_HOURS={raw_hours!r}, using 24")
        expiration_hours = 24

    return {
        "secret": jwt_secret,
        "algorithm": os.getenv("JWT_ALGORITHM", "HS256"),
        "expiration_hours": expiration_hours,
    }
# ===================== Auth Helpers =====================
def create_token(username: str) -> str:
    """Create a signed JWT for ``username``.

    Claims: ``sub`` (the username), ``iat`` (issue time, UTC) and ``exp``
    (issue time + ``expiration_hours`` from get_jwt_config). Signed with the
    configured secret and algorithm.
    """
    config = get_jwt_config()
    # Read the clock once so iat and exp are exactly expiration_hours apart.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": username,
        "exp": issued_at + timedelta(hours=config["expiration_hours"]),
        "iat": issued_at,
    }
    return jwt.encode(payload, config["secret"], algorithm=config["algorithm"])
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """FastAPI dependency: validate the bearer JWT and return its username.

    Returns:
        The token's ``sub`` claim.

    Raises:
        HTTPException(401): "Token expired" if ``exp`` has passed;
            "Invalid token" for any other decode/validation failure,
            including tokens missing ``exp`` or ``sub``, or a non-string ``sub``.
    """
    token = credentials.credentials
    config = get_jwt_config()
    try:
        payload = jwt.decode(
            token,
            config["secret"],
            algorithms=[config["algorithm"]],
            # create_token always sets both claims. Without "require", PyJWT
            # accepts a token with no exp as never-expiring, and a missing sub
            # would surface below as a KeyError (HTTP 500).
            options={"require": ["exp", "sub"]},
        )
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")

    username = payload.get("sub")
    if not isinstance(username, str):
        raise HTTPException(status_code=401, detail="Invalid token")
    return username
# ===================== Utility Functions =====================
def truncate_string(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """Shorten ``text`` to at most ``max_length`` characters.

    Text that already fits is returned unchanged. Longer text is cut and
    ``suffix`` appended so the result is exactly ``max_length`` characters.
    When ``max_length`` is smaller than ``suffix`` there is no room for it,
    so the text is hard-cut to ``max_length`` (``""`` for non-positive values).
    """
    if len(text) <= max_length:
        return text
    if max_length < len(suffix):
        # Without this guard the slice bound goes negative and the result
        # would be longer than max_length.
        return text[:max(max_length, 0)]
    return text[:max_length - len(suffix)] + suffix
def format_file_size(size_bytes: int) -> str:
    """Render a byte count with 1024-based units (B, KB, MB, GB, TB, PB).

    The value is divided by 1024 until it drops below 1024 or the TB step is
    passed, then printed with two decimals, e.g. ``1536`` -> ``"1.50 KB"``.
    """
    scaled = size_bytes
    for label in ('B', 'KB', 'MB', 'GB', 'TB'):
        if scaled < 1024.0:
            break
        scaled /= 1024.0
    else:
        # Still >= 1024 after the TB step: report in petabytes.
        label = 'PB'
    return f"{scaled:.2f} {label}"
def is_safe_path(path: str, base_path: str) -> bool:
    """Return True if ``path`` (relative to ``base_path``) stays inside ``base_path``.

    Both paths are resolved with ``os.path.realpath`` (which follows
    symlinks), so ``..`` segments, absolute ``path`` values and symlinks
    pointing outside the base are all rejected. Containment is checked
    per path component, so ``/data/base_evil`` is not treated as inside
    ``/data/base``, as a plain string-prefix test would do.
    """
    base = os.path.realpath(base_path)
    target = os.path.realpath(os.path.join(base, path))
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # e.g. paths on different drives on Windows: cannot be contained.
        return False
def get_current_timestamp() -> str:
    """
    Get current UTC timestamp in ISO format with millisecond precision and Z suffix.

    Returns e.g. "2025-01-10T12:00:00.123Z". The fractional part is always
    exactly three digits; plain ``isoformat()`` would emit six digits, or
    none at all when the microsecond happens to be 0.
    """
    return (
        datetime.now(timezone.utc)
        .isoformat(timespec="milliseconds")
        .replace('+00:00', 'Z')
    )
def _aware_iso_to_utc(timestamp: str) -> Optional[str]:
    """Return ``timestamp`` as "YYYY-MM-DDTHH:MM:SSZ" in UTC, or None if it
    is not a timezone-aware ISO-8601 string."""
    # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11 on.
    iso = timestamp[:-1] + "+00:00" if timestamp.endswith("Z") else timestamp
    try:
        parsed = datetime.fromisoformat(iso)
    except ValueError:
        return None
    if parsed.tzinfo is None:
        return None
    try:
        utc = parsed.astimezone(timezone.utc)
    except OverflowError:
        return None
    # Fractional seconds are truncated (not rounded).
    return utc.replace(microsecond=0, tzinfo=None).isoformat() + "Z"


def normalize_timestamp(timestamp: Optional[str]) -> str:
    """
    Normalize timestamp string for consistent comparison.

    Timezone-aware ISO-8601 inputs are converted to UTC and rendered as
    "YYYY-MM-DDTHH:MM:SSZ" with fractional seconds truncated, so all of these
    become "2025-01-10T12:00:00Z":
    - "2025-01-10T12:00:00Z"
    - "2025-01-10T12:00:00.000Z"
    - "2025-01-10T12:00:00+00:00"
    - "2025-01-10 12:00:00+00:00"
    - "2025-01-10T15:00:00+03:00"

    Inputs without a timezone, or that cannot be parsed, get textual
    normalization only: space -> "T", "+00:00" -> "Z", and fractional seconds
    dropped when the result ends with "Z".

    Returns "" for None or an empty string.
    """
    if not timestamp:
        return ""

    as_utc = _aware_iso_to_utc(timestamp)
    if as_utc is not None:
        return as_utc

    # Textual fallback for naive or non-ISO values
    normalized = timestamp.replace(' ', 'T')  # Space to T
    normalized = normalized.replace('+00:00', 'Z')  # UTC timezone
    if '.' in normalized and normalized.endswith('Z'):
        normalized = normalized.split('.')[0] + 'Z'
    return normalized
def timestamps_equal(ts1: Optional[str], ts2: Optional[str]) -> bool:
    """
    Return True when two timestamps match after normalization.

    Both values go through normalize_timestamp before being compared, so
    differences that function smooths out are ignored; None and "" both
    normalize to "" and therefore compare equal.
    """
    left, right = (normalize_timestamp(ts) for ts in (ts1, ts2))
    return left == right