ntsc207 commited on
Commit
a1225a7
·
verified ·
1 Parent(s): 74a077f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -31,18 +31,19 @@ def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm
31
  img.save(img_path)
32
  input_path = img_path
33
  print(input_path)
34
- output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
35
  elif vid_path is not None:
36
  #_, vid_extension = os.path.splitext(vid_path)
37
  #if vid_extension.lower() in vid_extensions:
38
  input_path = vid_path
39
  print(input_path)
40
  if tracking_algorithm == 'deep_sort':
41
- output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', draw_trails=True)
42
  elif tracking_algorithm == 'strong_sort':
43
- output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
 
44
  else:
45
- output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='cpu', hide_conf= True)
46
  # Assuming output_path is the path to the output file
47
  _, output_extension = os.path.splitext(output_path)
48
  if output_extension.lower() in img_extensions:
 
31
  img.save(img_path)
32
  input_path = img_path
33
  print(input_path)
34
+ output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
35
  elif vid_path is not None:
36
  #_, vid_extension = os.path.splitext(vid_path)
37
  #if vid_extension.lower() in vid_extensions:
38
  input_path = vid_path
39
  print(input_path)
40
  if tracking_algorithm == 'deep_sort':
41
+ output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
42
  elif tracking_algorithm == 'strong_sort':
43
+ device_strongsort = torch.device('cuda:0')
44
+ output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True)
45
  else:
46
+ output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True)
47
  # Assuming output_path is the path to the output file
48
  _, output_extension = os.path.splitext(output_path)
49
  if output_extension.lower() in img_extensions: