Update num_sms() in torch-ext/triton_kernels/target_info.py to support PyTorch ROCm (via the torch.cuda API) and gpt-oss models

#6
by janantos - opened

When running transformers on an AMD gfx1151 GPU, the matmul_ogs kernel fails with the following error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 936, in wrapper
    outputs = func(self, *args, **kwargs_without_recordable)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 507, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 386, in forward
    hidden_states, _ = self.mlp(hidden_states)  # diff with llama: router scores
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 310, in mlp_forward
    routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 213, in forward
    intermediate_cache3 = matmul_ogs(
                          ^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 583, in matmul_ogs
    out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 252, in apply_postprocessing_features
    grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
                                          ^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 223, in compute_grid
    num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
              ~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/run_model_gpt-oss-20b.py", line 20, in <module>
    outputs = model.generate(**inputs, max_new_tokens=40)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 121, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 2542, in generate
    result = decoding_method(
             ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 2762, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 783, in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 668, in forward
    outputs: MoeModelOutputWithPast = self.model(
                                      ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 938, in wrapper
    raise original_exception
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 929, in wrapper
    outputs = func(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 507, in forward
    hidden_states = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 386, in forward
    hidden_states, _ = self.mlp(hidden_states)  # diff with llama: router scores
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 310, in mlp_forward
    routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 213, in forward
    intermediate_cache3 = matmul_ogs(
                          ^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 583, in matmul_ogs
    out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 252, in apply_postprocessing_features
    grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
                                          ^^^^^^^^^^^^^^^^
  File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 223, in compute_grid
    num_pid = target_info.num_sms() * (warps_per_sm // num_warps)

The cause is that num_sms() returns None in this case. As a result, the expression target_info.num_sms() * (warps_per_sm // num_warps) in compute_grid multiplies None by an int and raises a TypeError. The proposed change allowed me to run gpt-oss in transformers on an AMD gfx1151 GPU. However, I have not run full regression tests to check whether it breaks anything else. The only other model I tested with this change is gemma-3-270m-it.

I strongly believe that the previous commit (https://huggingface.co/kernels-community/triton_kernels/discussions/5) caused this regression. From a quick look at the code, AMD PyTorch ROCm should work with the version before that commit.

janantos changed pull request status to closed
janantos changed pull request status to open

Can we please merge this change?

Ready to merge
This branch is ready to get merged automatically.

Sign up or log in to comment