Kernels

Update torch-ext/triton_kernels/target_info.py num_sms() to support pytorch-rocm cuda and gpt-oss models

#6
torch-ext/triton_kernels/target_info.py CHANGED
@@ -96,3 +96,5 @@ def num_sms():
96
  return torch.cuda.get_device_properties(0).multi_processor_count
97
  if is_xpu():
98
  return torch.xpu.get_device_properties(0).max_compute_units
 
 
 
96
  return torch.cuda.get_device_properties(0).multi_processor_count
97
  if is_xpu():
98
  return torch.xpu.get_device_properties(0).max_compute_units
99
+ if is_hip() and torch.cuda.is_available():
100
+ return torch.cuda.get_device_properties(0).multi_processor_count