File size: 3,190 Bytes
17f3a9b
298d9ab
17f3a9b
 
f56051d
 
068fdbc
 
d4b1508
62f31c8
 
068fdbc
 
 
 
 
 
 
298d9ab
068fdbc
 
 
298d9ab
 
068fdbc
298d9ab
068fdbc
 
298d9ab
068fdbc
 
 
 
298d9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
068fdbc
 
2ae9fb3
f56051d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17f3a9b
 
 
2ae9fb3
17f3a9b
 
d4b1508
 
 
 
 
 
 
 
 
 
 
 
 
 
17f3a9b
 
2ae9fb3
17f3a9b
 
 
 
8d84024
f56051d
62f31c8
 
 
 
068fdbc
f56051d
62f31c8
558076d
d4b1508
efad2c7
 
 
a65e7e5
298d9ab
 
 
 
 
 
 
 
62f31c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
from datetime import datetime
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
import asyncio
from contextlib import asynccontextmanager

from RequestModel import PredictRequest

# Global state used to track the initialization status.
# `is_initialized` is written by `lifespan` and reported by the /health endpoint.
is_initialized: bool = False
# Held by `lifespan` while it checks `is_initialized` and schedules initialization.
# NOTE(review): on Python versions before 3.10 an asyncio.Lock created at import
# time binds to the event loop current at that moment — confirm the deployment's
# Python version.
initialization_lock: asyncio.Lock = asyncio.Lock()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown logic around the application's lifetime.

    On startup, ``initialize_application`` is scheduled as a background task
    so the server can start accepting requests without waiting for it. The
    module-level ``is_initialized`` flag is set to True only once that task
    has finished successfully, so a pending or failed initialization is not
    reported as complete. On shutdown, a still-running initialization task is
    cancelled and awaited.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None, while the application is serving requests.
    """
    print("===== Application Startup at", datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "=====")

    def _mark_initialized(task):
        """Record a successful initialization once the background task ends."""
        global is_initialized
        if task.cancelled():
            return
        # Reading the exception marks it as retrieved, which avoids asyncio's
        # "Task exception was never retrieved" warning; the error itself has
        # already been printed by initialize_application.
        if task.exception() is None:
            is_initialized = True

    # The event loop only keeps weak references to tasks, so hold a strong
    # reference here; this frame stays alive for as long as the app runs.
    init_task = None
    async with initialization_lock:
        if not is_initialized:
            # Non-blocking initialization: runs as a background task.
            init_task = asyncio.create_task(initialize_application())
            init_task.add_done_callback(_mark_initialized)
    print("===== FastAPI Application Ready =====")
    try:
        yield
    finally:
        # Runs on shutdown, including when the application exits with an error.
        print("===== Application Shutdown =====")
        if init_task is not None and not init_task.done():
            init_task.cancel()
            try:
                await init_task
            except asyncio.CancelledError:
                # Expected outcome of cancelling the task just above.
                pass
            except Exception:
                # Already printed by initialize_application; do not let a
                # failed initialization abort the shutdown sequence.
                pass

async def initialize_application():
    """Perform the application's startup initialization.

    Lazily imports ``fetch_symbols`` from the ``us_stock`` module and awaits
    it, printing a progress marker before and after each step.

    Raises:
        Exception: Any error raised by the import or by ``fetch_symbols`` is
            printed and then re-raised unchanged.
    """
    print("===== Starting application initialization =====")
    try:
        # Announce the import before performing it, so the output shows which
        # step was in progress if the import itself fails.
        print("Importing us_stock module...")
        from us_stock import fetch_symbols

        print("Calling fetch_symbols...")
        await fetch_symbols()
        print("fetch_symbols completed")

        print("===== Application initialization completed =====")
    except Exception as e:
        print(f"Error during initialization: {e}")
        print("===== Application initialization failed =====")
        raise

# Application instance; startup and shutdown handling is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

# CORS middleware: every origin, method and header is accepted, with credentials.
# NOTE(review): FastAPI's CORS documentation advises against combining wildcard
# values with allow_credentials=True — confirm whether credentials are needed
# or list the allowed origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_methods=["*"], allow_headers=["*"],
    allow_credentials=True,
)

# Trusted-host middleware, configured with a wildcard host list.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["*"])

# Request model with a single string field, used by the /aaa and /bbb handlers.
# Documented with comments rather than a docstring so the schema generated
# from this model is left unchanged.
class TextRequest(BaseModel):
    text: str  # input text; the handlers append a fixed suffix to it

# POST /api/aaa: respond with the submitted text followed by "aaa".
# (Comment instead of a docstring so the generated API description is unchanged.)
@app.post("/api/aaa")
async def api_aaa_post(request: TextRequest):
    return {"result": request.text + "aaa"}

# POST /aaa: respond with the submitted text followed by "aaa"
# (same processing as POST /api/aaa, exposed without the /api prefix).
@app.post("/aaa")
async def aaa(request: TextRequest):
    payload = {"result": request.text + "aaa"}
    return payload


# GET /aaa: respond with the submitted text followed by "aaa".
# NOTE(review): `request` is a Pydantic model parameter, which FastAPI reads
# from the request body; bodies on GET requests are discouraged and poorly
# supported by clients — confirm whether a query parameter was intended.
@app.get("/aaa")
async def api_aaa_get(request: TextRequest):
    suffixed = request.text + "aaa"
    return dict(result=suffixed)

# POST /api/bbb: respond with the submitted text followed by "bbb".
@app.post("/api/bbb")
async def api_bbb(request: TextRequest):
    return {"result": request.text + "bbb"}

# POST /api/predict: run the blkeras prediction for the submitted text and
# stock codes and return its result.
# The prediction function is executed in a worker thread via asyncio.to_thread
# so the event loop is not held up while it runs. If it raises, the error is
# printed and an empty list is returned.
@app.post("/api/predict")
async def predict(request: PredictRequest):
    # Imported lazily, under an alias so it does not shadow this handler's own
    # name inside the function body.
    from blkeras import predict as run_prediction
    try:
        return await asyncio.to_thread(run_prediction, request.text, request.stock_codes)
    except Exception as e:
        # Best-effort endpoint: keep answering with an empty list on failure,
        # but report the error instead of discarding it silently.
        print(f"Error during prediction: {e!r}")
        return []

# GET /: static welcome message pointing at the text-processing endpoints.
@app.get("/")
async def root():
    welcome = "Welcome to the API. Use /api/aaa or /api/bbb for processing."
    return {"message": welcome}

# GET /health: report a fixed "healthy" status, the current value of the
# module-level `is_initialized` flag, and the server's local time in ISO format.
@app.get("/health")
async def health_check():
    report = {"status": "healthy"}
    report["initialized"] = is_initialized
    report["timestamp"] = datetime.now().isoformat()
    return report