File size: 3,330 Bytes
deb3471
 
 
 
 
d352fe2
 
 
 
 
 
 
 
 
 
 
 
 
deb3471
 
 
d352fe2
 
 
 
deb3471
d352fe2
deb3471
 
 
 
d352fe2
deb3471
 
 
 
d352fe2
deb3471
 
 
 
d352fe2
deb3471
e0174a0
 
 
deb3471
 
d352fe2
deb3471
e0174a0
 
 
 
deb3471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d352fe2
deb3471
 
d352fe2
deb3471
 
 
 
 
 
 
d352fe2
deb3471
 
d352fe2
deb3471
 
 
 
 
d352fe2
deb3471
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import logging
import os
import sys

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from optimum.neuron import utils

# Configure logging: INFO-and-above records go to stdout, prefixed with a
# timestamp, the logger name and the level. Note that logging.basicConfig
# only takes effect if the root logger has no handlers yet.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Resolve asset directories relative to this file rather than the process's
# working directory, so static files and templates are found consistently.
_base_dir = os.path.dirname(os.path.abspath(__file__))
static_dir = os.path.join(_base_dir, "static")
logger.info(f"Static directory path: {static_dir}")

# Prefer a "templates" directory next to this file. If there is none, fall
# back to the historical CWD-relative "app/templates" path so existing
# layouts keep working.
_file_templates_dir = os.path.join(_base_dir, "templates")
templates_dir = (
    _file_templates_dir if os.path.isdir(_file_templates_dir) else "app/templates"
)
logger.info(f"Templates directory path: {templates_dir}")

# Mount static files and templates
app.mount("/static", StaticFiles(directory=static_dir), name="static")
templates = Jinja2Templates(directory=templates_dir)

@app.get("/health")
async def health_check():
    """Log the call and report a fixed healthy status for liveness checks."""
    logger.info("Health check endpoint called")
    payload = {"status": "healthy"}
    return payload

@app.get("/")
async def home(request: Request):
    """Render the landing page from the ``index.html`` template.

    The incoming request is passed into the template context.
    """
    logger.info("Home page requested")
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)

def _to_model_entries(model_list):
    """Convert ``(architecture, org, model_id)`` tuples into unique UI entries.

    Keeps the first occurrence of each ``org/model_id`` and preserves the
    input order.
    """
    models = []
    seen_models = set()
    for architecture, org, model_id in model_list:
        full_model_id = f"{org}/{model_id}"
        if full_model_id in seen_models:
            continue
        seen_models.add(full_model_id)
        models.append({
            "id": full_model_id,
            "name": full_model_id,  # This will be used as the title
            "type": architecture,   # This will be used as the subtitle
        })
    return models


@app.get("/api/models")
async def get_model_list():
    """Return the deduplicated list of Hub-cached models for inference.

    Returns:
        A JSON list of ``{"id", "name", "type"}`` objects (empty when no
        models are found), or a 500 response with ``{"error": <message>}``
        if the lookup fails.
    """
    logger.info("Fetching model list")
    try:
        # Ensure cache directory exists
        os.makedirs("/cache", exist_ok=True)

        # The Hub lookup is blocking I/O. Run it in the thread pool so it
        # does not stall the event loop for concurrent requests.
        model_list = await run_in_threadpool(
            utils.get_hub_cached_models, mode="inference"
        )
        # Normalise a missing/None result before measuring it.
        model_list = list(model_list or [])
        logger.info(f"Found {len(model_list)} models")

        if not model_list:
            logger.warning("No models found")
            return JSONResponse(content=[])

        models = _to_model_entries(model_list)
        logger.info(f"Returning {len(models)} unique models")
        return JSONResponse(content=models)
    except Exception as e:
        logger.exception(f"Error fetching models: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )

@app.get("/api/models/{model_id:path}")
async def get_model_info_endpoint(model_id: str):
    """Return the Hub-cached inference configurations for one model.

    ``model_id`` may contain ``/`` (for example the ``org/model`` ids built
    by ``/api/models``). The ``:path`` converter is required for such ids to
    match this route; a plain ``{model_id}`` segment stops at the first ``/``.

    Returns:
        ``{"configurations": [...]}``, with an empty list when nothing is
        cached, or a 500 response with ``{"error": <message>}`` on failure.
    """
    logger.info(f"Fetching configurations for model: {model_id}")
    try:
        # Blocking Hub lookup: offload to the thread pool so the event loop
        # stays responsive.
        configs = await run_in_threadpool(
            utils.get_hub_cached_entries, model_id=model_id, mode="inference"
        )
        # Normalise an empty/None result to [] *before* calling len() on it.
        configs = configs or []
        logger.info(f"Found {len(configs)} configurations for model {model_id}")
        return JSONResponse(content={"configurations": configs})
    except Exception as e:
        logger.exception(f"Error fetching configurations for model {model_id}: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )