ntsc207 commited on
Commit
d2108ed
·
verified ·
1 Parent(s): 8ddd3c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -33,7 +33,7 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
33
  img.save(img_path)
34
  input_path = img_path
35
 
36
- output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
37
  elif vid_path is not None:
38
  vid_name = 'output.mp4'
39
 
@@ -73,9 +73,9 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
73
  output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
74
  elif tracking_algorithm == 'strong_sort':
75
  device_strongsort = torch.device('cuda:0')
76
- output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
77
  else:
78
- output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
79
  # Assuming output_path is the path to the output file
80
  _, output_extension = os.path.splitext(output_path)
81
  palette = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}
 
33
  img.save(img_path)
34
  input_path = img_path
35
 
36
+ output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True, hide_labels = True)
37
  elif vid_path is not None:
38
  vid_name = 'output.mp4'
39
 
 
73
  output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
74
  elif tracking_algorithm == 'strong_sort':
75
  device_strongsort = torch.device('cuda:0')
76
+ output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True, hide_labels = True)
77
  else:
78
+ output_path, df, frame_counts_df = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True, hide_labels = True)
79
  # Assuming output_path is the path to the output file
80
  _, output_extension = os.path.splitext(output_path)
81
  palette = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}