File size: 2,055 Bytes
17f3a9b
 
 
 
f56051d
 
d4b1508
 
 
62f31c8
 
 
17f3a9b
2ae9fb3
f56051d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17f3a9b
 
 
2ae9fb3
17f3a9b
 
d4b1508
 
 
 
 
 
 
 
 
 
 
 
 
 
17f3a9b
 
2ae9fb3
17f3a9b
 
 
 
8d84024
d4b1508
62f31c8
 
 
 
d4b1508
f56051d
62f31c8
 
 
d4b1508
62f31c8
f56051d
 
 
 
 
 
 
 
62f31c8
 
d4b1508
efad2c7
 
 
a65e7e5
62f31c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
from typing import Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.middleware.wsgi import WSGIMiddleware
from pydantic import BaseModel
from transformers import pipeline

from RequestModel import PredictRequest
from us_stock import fetch_symbols

# The application instance; every route and middleware below registers on it.
app = FastAPI()

# CORS: accept cross-origin requests from any origin, with any method and any
# header, and with credentials allowed. No rate limiting is configured here.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# website issue credentialed cross-origin requests; consider an explicit list
# of trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Host-header validation. With "*" every Host value is accepted, so this
# middleware currently does not restrict anything.
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["*"])

# 定义请求模型
class TextRequest(BaseModel):
    text: str

# POST /api/aaa: echo the submitted text with "aaa" appended.
@app.post("/api/aaa")
async def api_aaa_post(request: TextRequest):
    """Return ``{"result": request.text + "aaa"}``."""
    return {"result": request.text + 'aaa'}

# POST /aaa: the same "append aaa" operation, exposed without the /api prefix.
@app.post("/aaa")
async def aaa(request: TextRequest):
    """Append ``"aaa"`` to the request text and return it under ``"result"``."""
    suffixed = request.text + 'aaa'
    payload = {"result": suffixed}
    return payload


# GET /aaa: append "aaa" to the supplied text.
# Browsers (fetch/XHR) cannot send a body with GET and many proxies drop it,
# so the text is primarily taken from the ``?text=`` query parameter. A JSON
# body ``{"text": ...}`` is still accepted so existing clients keep working.
@app.get("/aaa")
async def api_aaa_get(
    request: Optional[TextRequest] = None,
    text: Optional[str] = None,
):
    """Return ``{"result": <text> + "aaa"}``.

    Args:
        request: Optional JSON body (the original calling convention).
        text: Optional ``text`` query parameter; takes precedence over the
            body when both are supplied.

    Returns:
        ``{"result": <text> + "aaa"}``.

    Raises:
        HTTPException: 422 when neither the query parameter nor a body is given.
    """
    from fastapi import HTTPException

    if text is None:
        if request is None:
            raise HTTPException(
                status_code=422,
                detail="Missing 'text' query parameter or JSON body.",
            )
        text = request.text
    result = text + 'aaa'
    return {"result": result}

@app.post("/api/bbb")
async def api_bbb(request: TextRequest):
    """Echo the submitted text with ``"bbb"`` appended, under ``"result"``."""
    incoming = request.text
    return {"result": incoming + 'bbb'}


@app.on_event("startup")
async def initialize_symbols():
    """Await ``fetch_symbols()`` once while the application starts up.

    Intended to initialise the symbol state from ``us_stock`` before requests
    are served.
    """
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of a ``lifespan`` handler passed to ``FastAPI(...)``.
    await fetch_symbols()

# Prediction route: runs the (synchronous) model call in a worker thread.
@app.post("/api/predict")
async def predict(request: PredictRequest):
    """Run ``blkeras.predict`` for the request and return its result.

    ``blkeras.predict`` is called synchronously, so it is executed through
    ``asyncio.to_thread`` to keep the event loop free while it runs.

    Args:
        request: Supplies ``text`` and ``stock_codes``, which are passed
            positionally to ``blkeras.predict``.

    Returns:
        Whatever ``blkeras.predict`` returns, or ``{"error": <message>}`` if
        the call raises (the HTTP status is then still the default 200).
    """
    import asyncio
    import logging

    # Imported lazily (presumably to keep the model import off the module
    # import path — TODO confirm) and aliased so it does not shadow this
    # route function's own name.
    from blkeras import predict as run_prediction

    try:
        return await asyncio.to_thread(
            run_prediction,
            request.text,
            request.stock_codes,
        )
    except Exception as e:
        # Record the full traceback server-side; the client only receives
        # the exception message in the "error" field.
        logging.getLogger(__name__).exception("Prediction failed")
        return {"error": str(e)}

@app.get("/")
async def root():
    """Landing endpoint that points clients at the text-processing routes."""
    welcome = "Welcome to the API. Use /api/aaa or /api/bbb for processing."
    return {"message": welcome}