Update torch-ext/triton_kernels/target_info.py num_sms() to support pytorch-rocm CUDA devices and gpt-oss models
When running transformers on an AMD gfx1151 GPU, the matmul_ogs kernel fails with the following error:
```
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 936, in wrapper
    outputs = func(self, *args, **kwargs_without_recordable)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 507, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 386, in forward
    hidden_states, _ = self.mlp(hidden_states)  # diff with llama: router scores
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 310, in mlp_forward
    routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 213, in forward
    intermediate_cache3 = matmul_ogs(
                          ^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 583, in matmul_ogs
    out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 252, in apply_postprocessing_features
    grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
                                          ^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 223, in compute_grid
    num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
              ~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/root/run_model_gpt-oss-20b.py", line 20, in <module>
    outputs = model.generate(**inputs, max_new_tokens=40)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 121, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 2542, in generate
    result = decoding_method(
             ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 2762, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 783, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 668, in forward
    outputs: MoeModelOutputWithPast = self.model(
                                      ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 938, in wrapper
    raise original_exception
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 929, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 507, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 386, in forward
    hidden_states, _ = self.mlp(hidden_states)  # diff with llama: router scores
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 310, in mlp_forward
    routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 213, in forward
    intermediate_cache3 = matmul_ogs(
                          ^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 583, in matmul_ogs
    out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 252, in apply_postprocessing_features
    grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
                                          ^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 223, in compute_grid
    num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
```
The cause is that `num_sms()` returns `None` in this case. The proposed change below allowed me to get gpt-oss running in transformers on an AMD gfx1151 GPU. However, I have not run full regression tests, so I cannot rule out that it breaks something elsewhere; the only other model I verified with this change is gemma-3-270m-it.
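For context, here is a minimal sketch of the fall-through that produces the `None` (assuming `is_cuda()` and `is_xpu()` test Triton's active backend, so both are False on a pytorch-rocm build even though `torch.cuda.is_available()` is True):

```python
def num_sms():
    if is_cuda():  # False on pytorch-rocm: Triton's active backend is "hip", not "cuda"
        return torch.cuda.get_device_properties(0).multi_processor_count
    if is_xpu():   # False: no Intel XPU device present
        return torch.xpu.get_device_properties(0).max_compute_units
    # No branch matches on ROCm, so the function implicitly returns None;
    # compute_grid() then fails with `NoneType * int`.
```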
```diff
@@ -96,3 +96,5 @@ def num_sms():
         return torch.cuda.get_device_properties(0).multi_processor_count
     if is_xpu():
         return torch.xpu.get_device_properties(0).max_compute_units
+    if is_hip() and torch.cuda.is_available():
+        return torch.cuda.get_device_properties(0).multi_processor_count
```
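As a quick sanity check (illustrative, not part of the patch): PyTorch's ROCm builds drive HIP devices through the `torch.cuda` API and set `torch.version.hip`, so the property the new branch reads should be available there.

```python
import torch

# On a pytorch-rocm build, HIP devices are exposed via the torch.cuda API:
# torch.cuda.is_available() is True and torch.version.hip is a version string.
if torch.cuda.is_available() and torch.version.hip is not None:
    props = torch.cuda.get_device_properties(0)
    # multi_processor_count reports the GPU's compute-unit count, which is
    # what num_sms() needs for grid sizing.
    print(props.name, props.multi_processor_count)
```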