|
from fastapi import FastAPI, Request |
|
from fastapi.templating import Jinja2Templates |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.responses import JSONResponse |
|
from optimum.neuron import utils |
|
import logging |
|
import sys |
|
import os |
|
|
|
|
|
# Configure the root logger once at import time: INFO and above, written to
# stdout with a timestamp, the logger name and the level in front of each message.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every handler in this file.
logger = logging.getLogger(__name__)
|
|
|
# Static assets live in a "static" folder next to this file, so the path is
# resolved from the module location rather than the working directory.
static_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "static",
)
logger.info(f"Static directory path: {static_dir}")

# The ASGI application; the static files are served under /static.
app = FastAPI()
app.mount(
    "/static",
    StaticFiles(directory=static_dir),
    name="static",
)

# NOTE(review): unlike static_dir, this template path is relative to the
# process working directory - confirm the app is always started from the
# project root.
templates = Jinja2Templates(directory="app/templates")
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: log the call and report a fixed "healthy" status."""
    logger.info("Health check endpoint called")
    payload = {"status": "healthy"}
    return payload
|
|
|
@app.get("/")
async def home(request: Request):
    """Render the ``index.html`` template for the landing page.

    Args:
        request: The incoming request; passed to the template context under
            the ``"request"`` key.
    """
    logger.info("Home page requested")
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
@app.get("/api/models")
async def get_model_list():
    """Return the de-duplicated list of models cached for inference.

    Every entry produced by ``utils.get_hub_cached_models`` is unpacked as an
    ``(architecture, org, model_id)`` triple. Entries that share the same
    ``org/model_id`` are collapsed into one; the architecture of the first
    occurrence is the one reported.

    Returns:
        JSONResponse: On success, a JSON array of objects with the keys
        ``"id"``, ``"name"`` (both ``org/model_id``) and ``"type"`` (the
        architecture). On any failure, a 500 response whose body carries the
        error message under ``"error"`` and the exception class name under
        ``"type"``.
    """
    # Imported locally so the handler does not depend on a file-level import.
    from fastapi.concurrency import run_in_threadpool

    logger.info("Fetching model list")
    try:
        # Only the presence of the token is logged, never its value.
        logger.info("HF_TOKEN present: %s", bool(os.getenv("HF_TOKEN")))

        # The lookup is a plain synchronous call (presumably doing hub I/O -
        # confirm); running it in a worker thread keeps the event loop free
        # to serve other requests while it is in progress.
        model_list = await run_in_threadpool(
            utils.get_hub_cached_models, mode="inference"
        )
        logger.info("Found %d models", len(model_list))

        # Dicts keep insertion order, and setdefault() keeps the first
        # architecture seen for an id, so duplicates are dropped without
        # disturbing the order in which ids were first encountered.
        architecture_by_id = {}
        for architecture, org, model_id in model_list:
            architecture_by_id.setdefault(f"{org}/{model_id}", architecture)

        models = [
            {"id": full_model_id, "name": full_model_id, "type": architecture}
            for full_model_id, architecture in architecture_by_id.items()
        ]

        logger.info("Returning %d unique models", len(models))
        return JSONResponse(content=models)
    except Exception as e:
        # Top-level boundary for this endpoint: log the message and the full
        # traceback, then report the failure to the client as JSON.
        logger.error("Error fetching models: %s", e)
        logger.error("Full error details:", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": type(e).__name__},
        )
|
|
|
@app.get("/api/models/{model_id:path}")
async def get_model_info_endpoint(model_id: str):
    """Return the cached inference configurations for one model.

    Args:
        model_id: Model identifier taken from the URL. The ``:path``
            converter in the route lets it contain slashes (``org/name``).

    Returns:
        JSONResponse: ``{"configurations": [...]}`` holding whatever
        ``utils.get_hub_cached_entries`` returned, or an empty list when the
        result is empty. On any failure, a 500 response with the error
        message under ``"error"``.
    """
    # Imported locally so the handler does not depend on a file-level import.
    from fastapi.concurrency import run_in_threadpool

    logger.info("Fetching configurations for model: %s", model_id)
    try:
        # The lookup is a plain synchronous call (presumably doing hub I/O -
        # confirm); running it in a worker thread keeps the event loop free
        # to serve other requests while it is in progress.
        configs = await run_in_threadpool(
            utils.get_hub_cached_entries, model_id=model_id, mode="inference"
        )
        logger.info(
            "Found %d configurations for model %s", len(configs), model_id
        )

        # An empty result is always reported as an empty JSON array.
        return JSONResponse(content={"configurations": configs or []})
    except Exception as e:
        # Top-level boundary for this endpoint: log with traceback and turn
        # the failure into a JSON error response.
        logger.error(
            "Error fetching configurations for model %s: %s",
            model_id,
            e,
            exc_info=True,
        )
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )