# Hugging Face Spaces status: Running
import asyncio
import os
import traceback
from contextlib import asynccontextmanager
from datetime import datetime

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel

from RequestModel import PredictRequest
# Module-level initialization state.
# Flag read (and updated) by the lifespan hook and reported by the health check.
is_initialized: bool = False
# Lock the lifespan hook holds while it checks and updates the flag.
initialization_lock: asyncio.Lock = asyncio.Lock()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application start-up and shutdown for FastAPI.

    On start-up, ``initialize_application()`` is scheduled as a background
    task so the server can accept requests while it runs. The task is kept on
    ``app.state.init_task`` because the event loop holds only weak references
    to tasks. The module-level ``is_initialized`` flag is set only after that
    task finishes successfully, so anything reading the flag sees real
    completion rather than mere scheduling. On shutdown, a still-running
    initialization task is cancelled and awaited.

    Args:
        app: The application whose lifespan is being managed.

    Yields:
        None; control returns to FastAPI while the application serves requests.
    """
    # Runs on start-up.
    print("===== Application Startup at", datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "=====")

    def _record_init_result(task: asyncio.Task) -> None:
        # Mark success only. Calling task.exception() also counts as
        # retrieving a failure, so asyncio does not warn that the exception
        # was never retrieved (initialize_application prints the error itself).
        global is_initialized
        if not task.cancelled() and task.exception() is None:
            is_initialized = True

    init_task = None
    async with initialization_lock:
        previous = getattr(app.state, "init_task", None)
        # Schedule at most once, unless an earlier attempt was cancelled by a
        # previous shutdown before it could finish.
        if not is_initialized and (previous is None or previous.cancelled()):
            # Non-blocking initialization: run it as a background task.
            init_task = asyncio.create_task(initialize_application())
            # Keep a strong reference so the task is not garbage-collected.
            app.state.init_task = init_task
            init_task.add_done_callback(_record_init_result)
    print("===== FastAPI Application Ready =====")
    try:
        yield
    finally:
        # Runs on shutdown.
        print("===== Application Shutdown =====")
        if init_task is not None and not init_task.done():
            init_task.cancel()
            # asyncio.wait() returns (rather than raising) when the task ends
            # in cancellation.
            await asyncio.wait({init_task})
async def initialize_application():
    """Perform one-time application initialization.

    Lazily imports ``us_stock`` and awaits the result of its
    ``fetch_symbols()``, printing a progress marker before and after each step.

    Raises:
        Exception: Any ``Exception`` raised while importing ``us_stock`` or
            awaiting ``fetch_symbols()`` is printed and then re-raised.
    """
    print("===== Starting application initialization =====")
    try:
        # Announce the import *before* attempting it, so a failing import is
        # attributed to the right step in the log.
        print("Importing us_stock module...")
        from us_stock import fetch_symbols
        print("Calling fetch_symbols...")
        await fetch_symbols()
        print("fetch_symbols completed")
        print("===== Application initialization completed =====")
    except Exception as e:
        print(f"Error during initialization: {e}")
        print("===== Application initialization failed =====")
        raise
def _env_list(name: str, default: str = "*") -> list[str]:
    """Return the comma-separated values of environment variable *name*.

    Surrounding whitespace and empty entries are dropped. If the variable is
    unset or yields no entries, ``[default]`` is returned.
    """
    raw = os.environ.get(name, default)
    values = [part.strip() for part in raw.split(",") if part.strip()]
    return values or [default]


app = FastAPI(lifespan=lifespan)

# CORS middleware. Allowed origins come from CORS_ALLOW_ORIGINS (comma-separated)
# and default to "*", which keeps the original allow-everything behavior.
# NOTE(review): with allow_credentials=True and a wildcard origin, Starlette's
# CORSMiddleware mirrors the caller's Origin for credentialed requests, so any
# site can make cookie-bearing requests. Set CORS_ALLOW_ORIGINS to explicit
# origins in production (verify against the installed Starlette version).
# No rate limiting is configured here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_env_list("CORS_ALLOW_ORIGINS"),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Trusted-host middleware. ALLOWED_HOSTS (comma-separated) defaults to "*",
# which accepts any Host header and makes the middleware a no-op.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=_env_list("ALLOWED_HOSTS"),
)
# Request body model for the text-processing handlers.
class TextRequest(BaseModel):
    """JSON request body carrying a single required ``text`` string."""

    text: str  # the input string to process
# Handler: append the suffix 'aaa' to the submitted text.
async def api_aaa_post(request: TextRequest):
    """Return ``{"result": request.text + 'aaa'}``.

    NOTE(review): no route decorator is visible in this source; confirm how
    this handler is registered on ``app``.
    """
    text = request.text
    return dict(result=text + "aaa")
# Handler: append the suffix 'aaa' to the submitted text.
async def aaa(request: TextRequest):
    """Return ``{"result": request.text + 'aaa'}``.

    NOTE(review): no route decorator is visible in this source; confirm how
    this handler is registered on ``app``.
    """
    suffixed = request.text + "aaa"
    response = {"result": suffixed}
    return response
# Handler: append the suffix 'aaa' to the submitted text.
async def api_aaa_get(request: TextRequest):
    """Return ``{"result": request.text + 'aaa'}``.

    NOTE(review): despite the ``_get`` name, this reads a ``TextRequest`` body;
    many HTTP clients do not send a body with GET, so confirm the intended
    method. No route decorator is visible in this source.
    """
    text = request.text
    return dict(result=text + "aaa")
async def api_bbb(request: TextRequest):
    """Return ``{"result": request.text + 'bbb'}``.

    NOTE(review): no route decorator is visible in this source; confirm how
    this handler is registered on ``app``.
    """
    suffixed = request.text + "bbb"
    response = {"result": suffixed}
    return response
# Prediction handler: delegates to blkeras.predict off the event loop.
async def predict(request: PredictRequest):
    """Run ``blkeras.predict`` for the request and return its result.

    The call runs in a worker thread via ``asyncio.to_thread`` so the event
    loop stays free while it executes.

    Args:
        request: Supplies ``text`` and ``stock_codes``, passed positionally to
            ``blkeras.predict``.

    Returns:
        Whatever ``blkeras.predict`` returns, or an empty list if it raises an
        ``Exception``.
    """
    # Alias the import so it does not shadow this handler's own name.
    from blkeras import predict as blkeras_predict
    try:
        return await asyncio.to_thread(blkeras_predict, request.text, request.stock_codes)
    except Exception:
        # Keep the best-effort empty-list response, but leave a traceback
        # instead of swallowing the failure silently.
        traceback.print_exc()
        return []
async def root():
    """Return a welcome message that points clients at the processing endpoints."""
    welcome = "Welcome to the API. Use /api/aaa or /api/bbb for processing."
    return dict(message=welcome)
async def health_check():
    """Report liveness, the initialization flag, and the current time.

    Returns:
        A dict with ``status`` (always ``"healthy"``), ``initialized`` (the
        current value of the module-level ``is_initialized``) and ``timestamp``
        (naive local ``datetime.now()`` in ISO-8601 form).
    """
    payload = dict(status="healthy", initialized=is_initialized)
    payload["timestamp"] = datetime.now().isoformat()
    return payload