Update torch-ext/triton_kernels/target_info.py num_sms() to support pytorch-rocm cuda and gpt-oss models
#6
by
janantos
- opened
torch-ext/triton_kernels/target_info.py
CHANGED
|
@@ -96,3 +96,5 @@ def num_sms():
|
|
| 96 |
return torch.cuda.get_device_properties(0).multi_processor_count
|
| 97 |
if is_xpu():
|
| 98 |
return torch.xpu.get_device_properties(0).max_compute_units
|
|
|
|
|
|
|
|
|
| 96 |
return torch.cuda.get_device_properties(0).multi_processor_count
|
| 97 |
if is_xpu():
|
| 98 |
return torch.xpu.get_device_properties(0).max_compute_units
|
| 99 |
+
if is_hip() and torch.cuda.is_available():
|
| 100 |
+
return torch.cuda.get_device_properties(0).multi_processor_count
|