wjm55 commited on
Commit
4160d5b
·
1 Parent(s): aae7036

Refactor app.py to remove model caching and initialize model per request. Update root endpoint and modify predict function to handle uploaded images. Change test.py to use variable for local app URL.

Browse files
Files changed (3) hide show
  1. app.py +16 -31
  2. test.py +1 -1
  3. test2.py +21 -0
app.py CHANGED
@@ -1,22 +1,10 @@
1
  from fastapi import FastAPI, UploadFile
2
  from ultralytics import YOLO
3
- import io
4
  from PIL import Image
5
- import numpy as np
6
  import os
7
  from huggingface_hub import hf_hub_download
8
- from ultralytics import YOLO
9
- import requests
10
- import supervision as sv
11
-
12
- # Global variable to store model instances
13
- MODEL_CACHE = {}
14
 
15
  def init_model(model_id: str):
16
- # Return cached model if it exists
17
- if model_id in MODEL_CACHE:
18
- return MODEL_CACHE[model_id]
19
-
20
  # Define models
21
  MODEL_OPTIONS = {
22
  "YOLOv11-Nano": "medieval-yolov11n.pt",
@@ -28,34 +16,33 @@ def init_model(model_id: str):
28
 
29
  if model_id in MODEL_OPTIONS:
30
  os.makedirs("models", exist_ok=True)
31
- model_path = hf_hub_download(
32
  repo_id="biglam/medieval-manuscript-yolov11",
33
- filename=MODEL_OPTIONS[model_id],
34
- cache_dir="models" # Specify cache directory
35
  )
36
- model = YOLO(model_path)
37
- MODEL_CACHE[model_id] = model
 
38
  return model
39
  else:
40
  raise ValueError(f"Model {model_id} not found")
41
 
42
  app = FastAPI()
43
 
44
- # Initialize default model at startup
45
- @app.on_event("startup")
46
- async def startup_event():
47
- init_model("YOLOv11-XLarge") # Initialize default model
48
 
49
  @app.post("/predict")
50
  async def predict(image: UploadFile,
51
- model_id: str = "YOLOv11-XLarge",
52
- conf: float = 0.25,
53
- iou: float = 0.7
54
- ):
55
- # Get model from cache or initialize it
56
  model = init_model(model_id)
57
 
58
- # Download and open image from URL
59
  image = Image.open(image.file)
60
 
61
  # Run inference with the PIL Image
@@ -63,18 +50,16 @@ async def predict(image: UploadFile,
63
 
64
  # Extract detection results
65
  result = results[0]
66
- # print(result)
67
  detections = []
68
 
69
  for box in result.boxes:
70
  detection = {
71
  "class": result.names[int(box.cls[0])],
72
  "confidence": float(box.conf[0]),
73
- "bbox": box.xyxy[0].tolist() # Convert bbox tensor to list
74
  }
75
  detections.append(detection)
76
- print(detections)
77
-
78
  return {"detections": detections}
79
 
80
  if __name__ == "__main__":
 
1
  from fastapi import FastAPI, UploadFile
2
  from ultralytics import YOLO
 
3
  from PIL import Image
 
4
  import os
5
  from huggingface_hub import hf_hub_download
 
 
 
 
 
 
6
 
7
  def init_model(model_id: str):
 
 
 
 
8
  # Define models
9
  MODEL_OPTIONS = {
10
  "YOLOv11-Nano": "medieval-yolov11n.pt",
 
16
 
17
  if model_id in MODEL_OPTIONS:
18
  os.makedirs("models", exist_ok=True)
19
+ path = hf_hub_download(
20
  repo_id="biglam/medieval-manuscript-yolov11",
21
+ filename=MODEL_OPTIONS[model_id]
 
22
  )
23
+ local_path = os.path.join("models", path)
24
+ # Initialize and return model
25
+ model = YOLO(path)
26
  return model
27
  else:
28
  raise ValueError(f"Model {model_id} not found")
29
 
30
  app = FastAPI()
31
 
32
+ @app.get("/")
33
+ async def root():
34
+ return {"status": "ok"}
 
35
 
36
  @app.post("/predict")
37
  async def predict(image: UploadFile,
38
+ model_id: str = "YOLOv11-XLarge",
39
+ conf: float = 0.25,
40
+ iou: float = 0.7
41
+ ):
42
+ # Initialize model for each request
43
  model = init_model(model_id)
44
 
45
+ # Open image from uploaded file
46
  image = Image.open(image.file)
47
 
48
  # Run inference with the PIL Image
 
50
 
51
  # Extract detection results
52
  result = results[0]
 
53
  detections = []
54
 
55
  for box in result.boxes:
56
  detection = {
57
  "class": result.names[int(box.cls[0])],
58
  "confidence": float(box.conf[0]),
59
+ "bbox": box.xyxy[0].tolist()
60
  }
61
  detections.append(detection)
62
+
 
63
  return {"detections": detections}
64
 
65
  if __name__ == "__main__":
test.py CHANGED
@@ -18,7 +18,7 @@ with open(image_path, 'rb') as f:
18
  }
19
 
20
  # Send POST request to the endpoint
21
- response = requests.post(hf_app + ':7860/predict',
22
  files=files,
23
  params=params)
24
 
 
18
  }
19
 
20
  # Send POST request to the endpoint
21
+ response = requests.post(local_app + ':7860/predict',
22
  files=files,
23
  params=params)
24
 
test2.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from concurrent.futures import ThreadPoolExecutor
2
+
3
+ def process_item(item):
4
+ # Simulate a time-consuming task
5
+ import time
6
+ time.sleep(0.1)
7
+ return item * 2
8
+
9
+ def generate_data():
10
+ for i in range(10):
11
+ yield i
12
+
13
+ def process_data_in_parallel(data_generator):
14
+ with ThreadPoolExecutor(max_workers=4) as executor:
15
+ results = executor.map(process_item, data_generator)
16
+ return list(results)
17
+
18
+ # Usage example
19
+ data_generator = generate_data()
20
+ processed_data = process_data_in_parallel(data_generator)
21
+ print(processed_data)