a1c00l commited on
Commit
4496556
·
verified ·
1 Parent(s): 8935d92

Update src/aibom_generator/api.py

Browse files
Files changed (1) hide show
  1. src/aibom_generator/api.py +360 -80
src/aibom_generator/api.py CHANGED
@@ -2,128 +2,408 @@
2
  FastAPI server for the AIBOM Generator with minimal UI.
3
  """
4
 
5
- import logging
6
- import os
7
- from typing import Dict, List, Optional, Any
8
-
9
- from fastapi import FastAPI, HTTPException, BackgroundTasks, Request, Form
10
  from fastapi.middleware.cors import CORSMiddleware
11
- from fastapi.responses import HTMLResponse, JSONResponse
12
- from fastapi.templating import Jinja2Templates
13
- from pydantic import BaseModel
 
 
 
 
 
 
 
14
 
15
- from aibom_generator.generator import AIBOMGenerator
16
- from aibom_generator.utils import setup_logging, calculate_completeness_score
17
 
18
- # Set up logging
19
- setup_logging()
20
- logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Create FastAPI app
23
  app = FastAPI(
24
- title="AIBOM Generator API",
25
- description="API for generating AI Bills of Materials (AIBOMs) in CycloneDX format for Hugging Face models.",
26
- version="0.1.0",
27
  )
28
 
29
  # Add CORS middleware
30
  app.add_middleware(
31
  CORSMiddleware,
32
- allow_origins=["*"],
33
  allow_credentials=True,
34
  allow_methods=["*"],
35
  allow_headers=["*"],
36
  )
37
 
38
- # Initialize templates
39
- templates = Jinja2Templates(directory="templates")
40
-
41
- # Create generator instance
42
- generator = AIBOMGenerator(
43
- hf_token=os.environ.get("HF_TOKEN"),
44
- inference_model_url=os.environ.get("AIBOM_INFERENCE_URL"),
45
- use_inference=os.environ.get("AIBOM_USE_INFERENCE", "true").lower() == "true",
46
- cache_dir=os.environ.get("AIBOM_CACHE_DIR"),
47
- )
48
 
 
 
49
 
50
- # Define request and response models
51
- class GenerateRequest(BaseModel):
52
- model_id: str
53
- include_inference: Optional[bool] = None
54
- completeness_threshold: Optional[int] = 0
 
55
 
 
 
 
 
 
 
56
 
57
- class GenerateResponse(BaseModel):
58
- aibom: Dict[str, Any]
59
- completeness_score: int
60
- model_id: str
 
 
61
 
 
 
62
 
63
  class StatusResponse(BaseModel):
64
- status: str
65
- version: str
 
66
 
 
 
 
 
 
67
 
68
- # Web UI endpoint
69
- @app.get("/", response_class=HTMLResponse)
70
- async def home(request: Request):
71
- return templates.TemplateResponse("index.html", {"request": request})
 
72
 
 
 
73
 
74
- @app.post("/generate", response_class=HTMLResponse)
75
- async def generate_from_ui(request: Request, model_id: str = Form(...)):
76
- try:
77
- aibom = generator.generate_aibom(model_id=model_id)
78
- completeness_score = calculate_completeness_score(aibom)
 
 
 
 
79
 
80
- return templates.TemplateResponse(
81
- "result.html",
82
- {"request": request, "aibom": aibom, "completeness_score": completeness_score, "model_id": model_id},
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
- except Exception as e:
85
- logger.error(f"Error generating AIBOM: {e}")
86
- return templates.TemplateResponse(
87
- "error.html",
88
- {"request": request, "error": str(e)},
 
 
 
 
 
 
 
 
 
 
89
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
-
92
- # Original JSON API endpoints (kept unchanged)
93
- @app.post("/generate/json", response_model=GenerateResponse)
94
- async def generate_aibom(request: GenerateRequest):
 
 
 
 
95
  try:
96
- aibom = generator.generate_aibom(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  model_id=request.model_id,
98
  include_inference=request.include_inference,
 
 
99
  )
100
- completeness_score = calculate_completeness_score(aibom)
101
-
102
- if completeness_score < request.completeness_threshold:
103
- raise HTTPException(
104
- status_code=400,
105
- detail=f"AIBOM completeness score ({completeness_score}) is below threshold ({request.completeness_threshold})",
106
- )
107
-
108
- return {
109
  "aibom": aibom,
110
- "completeness_score": completeness_score,
111
  "model_id": request.model_id,
 
 
 
 
112
  }
 
 
113
  except Exception as e:
114
- logger.error(f"Error generating AIBOM: {e}")
115
- raise HTTPException(
116
- status_code=500,
117
- detail=f"Error generating AIBOM: {str(e)}",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  )
 
 
 
 
 
 
 
 
 
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- @app.get("/health")
122
- async def health():
123
- return {"status": "healthy"}
 
 
 
 
 
 
 
 
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
- if __name__ == "__main__":
127
- import uvicorn
 
 
 
 
 
 
 
 
 
 
128
 
129
- uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  FastAPI server for the AIBOM Generator with minimal UI.
3
  """
4
 
5
+ from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Query, File, UploadFile, Form
6
+ from fastapi.responses import JSONResponse, FileResponse
 
 
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.staticfiles import StaticFiles
9
+ from pydantic import BaseModel, Field
10
+ from typing import Optional, Dict, Any, List
11
+ import uvicorn
12
+ import json
13
+ import os
14
+ import sys
15
+ import uuid
16
+ import shutil
17
+ from datetime import datetime
18
 
19
+ # Add parent directory to path to import generator module
20
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
21
 
22
+ # Import the AIBOM generator
23
+ try:
24
+ from aibom_fix.final_generator import AIBOMGenerator
25
+ except ImportError:
26
+ # If not found, try the mapping directory
27
+ try:
28
+ from aibom_mapping.final_generator import AIBOMGenerator
29
+ except ImportError:
30
+ # If still not found, use the original generator
31
+ try:
32
+ from aibom_fix.generator import AIBOMGenerator
33
+ except ImportError:
34
+ from generator import AIBOMGenerator
35
 
36
  # Create FastAPI app
37
  app = FastAPI(
38
+ title="Aetheris AI SBOM Generator API",
39
+ description="API for generating CycloneDX JSON AI SBOMs for machine learning models",
40
+ version="1.0.0",
41
  )
42
 
43
  # Add CORS middleware
44
  app.add_middleware(
45
  CORSMiddleware,
46
+ allow_origins=["*"], # Allow all origins in development
47
  allow_credentials=True,
48
  allow_methods=["*"],
49
  allow_headers=["*"],
50
  )
51
 
52
+ # Create output directory for AIBOMs
53
+ os.makedirs(os.path.join(os.path.dirname(os.path.abspath(__file__)), "output"), exist_ok=True)
54
+ app.mount("/output", StaticFiles(directory=os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")), name="output")
 
 
 
 
 
 
 
55
 
56
+ # Create a global generator instance
57
+ generator = AIBOMGenerator(use_best_practices=True)
58
 
59
+ # Define request models
60
+ class GenerateAIBOMRequest(BaseModel):
61
+ model_id: str = Field(..., description="The Hugging Face model ID (e.g., 'meta-llama/Llama-4-Scout-17B-16E-Instruct')")
62
+ hf_token: Optional[str] = Field(None, description="Optional Hugging Face API token for accessing private models")
63
+ include_inference: Optional[bool] = Field(True, description="Whether to use AI inference to enhance the AIBOM")
64
+ use_best_practices: Optional[bool] = Field(True, description="Whether to use industry best practices for scoring")
65
 
66
+ class AIBOMResponse(BaseModel):
67
+ aibom: Dict[str, Any] = Field(..., description="The generated AIBOM in CycloneDX JSON format")
68
+ model_id: str = Field(..., description="The model ID for which the AIBOM was generated")
69
+ generated_at: str = Field(..., description="Timestamp when the AIBOM was generated")
70
+ request_id: str = Field(..., description="Unique ID for this request")
71
+ download_url: Optional[str] = Field(None, description="URL to download the AIBOM JSON file")
72
 
73
+ class EnhancementReport(BaseModel):
74
+ ai_enhanced: bool = Field(..., description="Whether AI enhancement was applied")
75
+ ai_model: Optional[str] = Field(None, description="The AI model used for enhancement, if any")
76
+ original_score: Dict[str, Any] = Field(..., description="Original completeness score before enhancement")
77
+ final_score: Dict[str, Any] = Field(..., description="Final completeness score after enhancement")
78
+ improvement: float = Field(..., description="Score improvement from enhancement")
79
 
80
+ class AIBOMWithReportResponse(AIBOMResponse):
81
+ enhancement_report: Optional[EnhancementReport] = Field(None, description="Report on AI enhancement results")
82
 
83
  class StatusResponse(BaseModel):
84
+ status: str = Field(..., description="API status")
85
+ version: str = Field(..., description="API version")
86
+ generator_version: str = Field(..., description="AIBOM generator version")
87
 
88
+ class BatchGenerateRequest(BaseModel):
89
+ model_ids: List[str] = Field(..., description="List of Hugging Face model IDs to generate AIBOMs for")
90
+ hf_token: Optional[str] = Field(None, description="Optional Hugging Face API token for accessing private models")
91
+ include_inference: Optional[bool] = Field(True, description="Whether to use AI inference to enhance the AIBOM")
92
+ use_best_practices: Optional[bool] = Field(True, description="Whether to use industry best practices for scoring")
93
 
94
+ class BatchJobResponse(BaseModel):
95
+ job_id: str = Field(..., description="Unique ID for the batch job")
96
+ status: str = Field(..., description="Job status (e.g., 'queued', 'processing', 'completed')")
97
+ model_ids: List[str] = Field(..., description="List of model IDs in the batch")
98
+ created_at: str = Field(..., description="Timestamp when the job was created")
99
 
100
+ # In-memory storage for batch jobs
101
+ batch_jobs = {}
102
 
103
+ # Define API endpoints
104
+ @app.get("/", response_model=StatusResponse)
105
+ async def get_status():
106
+ """Get the API status and version information."""
107
+ return {
108
+ "status": "operational",
109
+ "version": "1.0.0",
110
+ "generator_version": "0.1.0",
111
+ }
112
 
113
+ @app.post("/generate", response_model=AIBOMResponse)
114
+ async def generate_aibom(request: GenerateAIBOMRequest):
115
+ """
116
+ Generate a CycloneDX JSON AI SBOM for a Hugging Face model.
117
+
118
+ This endpoint takes a model ID and optional parameters to generate
119
+ a comprehensive AI SBOM in CycloneDX format.
120
+ """
121
+ try:
122
+ # Create a new generator instance with the provided token if available
123
+ gen = AIBOMGenerator(
124
+ hf_token=request.hf_token,
125
+ use_inference=request.include_inference,
126
+ use_best_practices=request.use_best_practices
127
  )
128
+
129
+ # Generate a request ID
130
+ request_id = str(uuid.uuid4())
131
+
132
+ # Create output file path
133
+ safe_model_id = request.model_id.replace("/", "_")
134
+ output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
135
+ output_file = os.path.join(output_dir, f"{safe_model_id}_{request_id}.json")
136
+
137
+ # Generate the AIBOM
138
+ aibom = gen.generate_aibom(
139
+ model_id=request.model_id,
140
+ include_inference=request.include_inference,
141
+ use_best_practices=request.use_best_practices,
142
+ output_file=output_file
143
  )
144
+
145
+ # Create download URL
146
+ download_url = f"/output/{os.path.basename(output_file)}"
147
+
148
+ # Create response
149
+ response = {
150
+ "aibom": aibom,
151
+ "model_id": request.model_id,
152
+ "generated_at": datetime.utcnow().isoformat() + "Z",
153
+ "request_id": request_id,
154
+ "download_url": download_url
155
+ }
156
+
157
+ return response
158
+ except Exception as e:
159
+ raise HTTPException(status_code=500, detail=f"Error generating AIBOM: {str(e)}")
160
 
161
+ @app.post("/generate-with-report", response_model=AIBOMWithReportResponse)
162
+ async def generate_aibom_with_report(request: GenerateAIBOMRequest):
163
+ """
164
+ Generate a CycloneDX JSON AI SBOM with an enhancement report.
165
+
166
+ This endpoint is similar to /generate but also includes a report
167
+ on the AI enhancement results, including before/after scores.
168
+ """
169
  try:
170
+ # Create a new generator instance with the provided token if available
171
+ gen = AIBOMGenerator(
172
+ hf_token=request.hf_token,
173
+ use_inference=request.include_inference,
174
+ use_best_practices=request.use_best_practices
175
+ )
176
+
177
+ # Generate a request ID
178
+ request_id = str(uuid.uuid4())
179
+
180
+ # Create output file path
181
+ safe_model_id = request.model_id.replace("/", "_")
182
+ output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
183
+ output_file = os.path.join(output_dir, f"{safe_model_id}_{request_id}.json")
184
+
185
+ # Generate the AIBOM
186
+ aibom = gen.generate_aibom(
187
  model_id=request.model_id,
188
  include_inference=request.include_inference,
189
+ use_best_practices=request.use_best_practices,
190
+ output_file=output_file
191
  )
192
+
193
+ # Get the enhancement report
194
+ enhancement_report = gen.get_enhancement_report()
195
+
196
+ # Create download URL
197
+ download_url = f"/output/{os.path.basename(output_file)}"
198
+
199
+ # Create response
200
+ response = {
201
  "aibom": aibom,
 
202
  "model_id": request.model_id,
203
+ "generated_at": datetime.utcnow().isoformat() + "Z",
204
+ "request_id": request_id,
205
+ "download_url": download_url,
206
+ "enhancement_report": enhancement_report
207
  }
208
+
209
+ return response
210
  except Exception as e:
211
+ raise HTTPException(status_code=500, detail=f"Error generating AIBOM: {str(e)}")
212
+
213
+ @app.get("/models/{model_id}/score", response_model=Dict[str, Any])
214
+ async def get_model_score(
215
+ model_id: str,
216
+ hf_token: Optional[str] = Query(None, description="Optional Hugging Face API token for accessing private models"),
217
+ use_best_practices: bool = Query(True, description="Whether to use industry best practices for scoring")
218
+ ):
219
+ """
220
+ Get the completeness score for a model without generating a full AIBOM.
221
+
222
+ This is a lightweight endpoint that only returns the scoring information.
223
+ """
224
+ try:
225
+ # Create a new generator instance with the provided token if available
226
+ gen = AIBOMGenerator(
227
+ hf_token=hf_token,
228
+ use_inference=False, # Don't use inference for scoring only
229
+ use_best_practices=use_best_practices
230
+ )
231
+
232
+ # Generate the AIBOM (needed to calculate score)
233
+ aibom = gen.generate_aibom(
234
+ model_id=model_id,
235
+ include_inference=False, # Don't use inference for scoring only
236
+ use_best_practices=use_best_practices
237
  )
238
+
239
+ # Get the enhancement report for the score
240
+ enhancement_report = gen.get_enhancement_report()
241
+
242
+ if enhancement_report and "final_score" in enhancement_report:
243
+ return enhancement_report["final_score"]
244
+ else:
245
+ raise HTTPException(status_code=500, detail="Failed to calculate score")
246
+ except Exception as e:
247
+ raise HTTPException(status_code=500, detail=f"Error calculating score: {str(e)}")
248
 
249
+ @app.post("/batch", response_model=BatchJobResponse)
250
+ async def batch_generate(request: BatchGenerateRequest, background_tasks: BackgroundTasks):
251
+ """
252
+ Start a batch job to generate AIBOMs for multiple models.
253
+
254
+ This endpoint queues a background task to generate AIBOMs for all the
255
+ specified model IDs and returns a job ID that can be used to check status.
256
+ """
257
+ try:
258
+ # Generate a job ID
259
+ job_id = str(uuid.uuid4())
260
+
261
+ # Create job record
262
+ job = {
263
+ "job_id": job_id,
264
+ "status": "queued",
265
+ "model_ids": request.model_ids,
266
+ "created_at": datetime.utcnow().isoformat() + "Z",
267
+ "completed": 0,
268
+ "total": len(request.model_ids),
269
+ "results": {}
270
+ }
271
+
272
+ # Store job in memory
273
+ batch_jobs[job_id] = job
274
+
275
+ # Add background task to process the batch
276
+ background_tasks.add_task(
277
+ process_batch_job,
278
+ job_id=job_id,
279
+ model_ids=request.model_ids,
280
+ hf_token=request.hf_token,
281
+ include_inference=request.include_inference,
282
+ use_best_practices=request.use_best_practices
283
+ )
284
+
285
+ # Return job info
286
+ return {
287
+ "job_id": job_id,
288
+ "status": "queued",
289
+ "model_ids": request.model_ids,
290
+ "created_at": job["created_at"]
291
+ }
292
+ except Exception as e:
293
+ raise HTTPException(status_code=500, detail=f"Error starting batch job: {str(e)}")
294
 
295
+ @app.get("/batch/{job_id}", response_model=Dict[str, Any])
296
+ async def get_batch_status(job_id: str):
297
+ """
298
+ Get the status of a batch job.
299
+
300
+ This endpoint returns the current status of a batch job, including
301
+ progress information and results for completed models.
302
+ """
303
+ if job_id not in batch_jobs:
304
+ raise HTTPException(status_code=404, detail="Batch job not found")
305
+
306
+ return batch_jobs[job_id]
307
 
308
+ @app.post("/upload-model-card")
309
+ async def upload_model_card(
310
+ model_id: str = Form(...),
311
+ model_card: UploadFile = File(...),
312
+ include_inference: bool = Form(True),
313
+ use_best_practices: bool = Form(True)
314
+ ):
315
+ """
316
+ Generate an AIBOM from an uploaded model card file.
317
+
318
+ This endpoint allows users to upload a model card file directly
319
+ instead of requiring the model to be on Hugging Face.
320
+ """
321
+ try:
322
+ # Create a temporary directory to store the uploaded file
323
+ temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "temp")
324
+ os.makedirs(temp_dir, exist_ok=True)
325
+
326
+ # Save the uploaded file
327
+ file_path = os.path.join(temp_dir, model_card.filename)
328
+ with open(file_path, "wb") as f:
329
+ shutil.copyfileobj(model_card.file, f)
330
+
331
+ # TODO: Implement custom model card processing
332
+ # This would require modifying the AIBOMGenerator to accept a file path
333
+ # instead of a model ID, which is beyond the scope of this example
334
+
335
+ # For now, return a placeholder response
336
+ return {
337
+ "status": "not_implemented",
338
+ "message": "Custom model card processing is not yet implemented"
339
+ }
340
+ except Exception as e:
341
+ raise HTTPException(status_code=500, detail=f"Error processing uploaded model card: {str(e)}")
342
 
343
+ @app.get("/download/{filename}")
344
+ async def download_aibom(filename: str):
345
+ """
346
+ Download a previously generated AIBOM file.
347
+
348
+ This endpoint allows downloading AIBOM files by filename.
349
+ """
350
+ file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output", filename)
351
+ if not os.path.exists(file_path):
352
+ raise HTTPException(status_code=404, detail="File not found")
353
+
354
+ return FileResponse(file_path, media_type="application/json", filename=filename)
355
 
356
+ # Background task function for batch processing
357
+ async def process_batch_job(job_id: str, model_ids: List[str], hf_token: Optional[str], include_inference: bool, use_best_practices: bool):
358
+ """Process a batch job in the background."""
359
+ # Update job status
360
+ batch_jobs[job_id]["status"] = "processing"
361
+
362
+ # Create output directory
363
+ output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output", job_id)
364
+ os.makedirs(output_dir, exist_ok=True)
365
+
366
+ # Process each model
367
+ for model_id in model_ids:
368
+ try:
369
+ # Create a new generator instance
370
+ gen = AIBOMGenerator(
371
+ hf_token=hf_token,
372
+ use_inference=include_inference,
373
+ use_best_practices=use_best_practices
374
+ )
375
+
376
+ # Create output file path
377
+ safe_model_id = model_id.replace("/", "_")
378
+ output_file = os.path.join(output_dir, f"{safe_model_id}.json")
379
+
380
+ # Generate the AIBOM
381
+ aibom = gen.generate_aibom(
382
+ model_id=model_id,
383
+ include_inference=include_inference,
384
+ use_best_practices=use_best_practices,
385
+ output_file=output_file
386
+ )
387
+
388
+ # Get the enhancement report
389
+ enhancement_report = gen.get_enhancement_report()
390
+
391
+ # Create download URL
392
+ download_url = f"/output/{job_id}/{safe_model_id}.json"
393
+
394
+ # Store result
395
+ batch_jobs[job_id]["results"][model_id] = {
396
+ "status": "completed",
397
+ "download_url": download_url,
398
+ "enhancement_report": enhancement_report
399
+ }
400
+ except Exception as e:
401
+ # Store error
402
+ batch_jobs[job_id]["results"][model_id] = {
403
+ "status": "error",
404
+ "error": str(e)
405
+ }
406
+
407
+ # Update progress
408
+ batch_
409
+ (Content truncated due to size limit. Use line ranges to read in chunks)