Update torch-ext/triton_kernels/target_info.py num_sms() to support pytorch-rocm cuda and gpt-oss models
#6 · by janantos · opened
When running transformers on an AMD gfx1151 GPU, the matmul_ogs kernel fails with the following error:
Traceback (most recent call last):
File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 936, in wrapper
outputs = func(self, *args, **kwargs_without_recordable)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 507, in forward
hidden_states = decoder_layer(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 386, in forward
hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 310, in mlp_forward
routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 213, in forward
intermediate_cache3 = matmul_ogs(
^^^^^^^^^^^
File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 583, in matmul_ogs
out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 252, in apply_postprocessing_features
grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 223, in compute_grid
num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/root/run_model_gpt-oss-20b.py", line 20, in <module>
outputs = model.generate(**inputs, max_new_tokens=40)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 121, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 2542, in generate
result = decoding_method(
^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/generation/utils.py", line 2762, in _sample
outputs = self(**model_inputs, return_dict=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 783, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 668, in forward
outputs: MoeModelOutputWithPast = self.model(
^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 938, in wrapper
raise original_exception
File "/usr/local/lib/python3.12/site-packages/transformers/utils/generic.py", line 929, in wrapper
outputs = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 507, in forward
hidden_states = decoder_layer(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/models/gpt_oss/modeling_gpt_oss.py", line 386, in forward
hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores
^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 310, in mlp_forward
routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1788, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/site-packages/transformers/integrations/mxfp4.py", line 213, in forward
intermediate_cache3 = matmul_ogs(
^^^^^^^^^^^
File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 583, in matmul_ogs
out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 252, in apply_postprocessing_features
grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/hub/models--kernels-community--triton_kernels/snapshots/9c3f05aab46799e6c50dd1d3e760affc7096fe22/build/torch-universal/triton_kernels/matmul_ogs.py", line 223, in compute_grid
num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for *: 'NoneType' and 'int'
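For context, the failure reproduces with a plain generate() call, along the lines of the run_model_gpt-oss-20b.py script in the traceback. A hypothetical minimal repro (the model id, prompt, and loading options below are my assumptions, not the exact script):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# On AMD gfx1151 this call dies in matmul_ogs with the TypeError above.
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```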
The cause is that the num_sms() function returns None in this case. The proposed change allowed me to get gpt-oss running in transformers on an AMD gfx1151 GPU. However, I have not done full regression testing to check whether this breaks anything else; the only other model I tested with this change is gemma-3-270m-it.
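A minimal sketch of the kind of fallback the change proposes for num_sms() in torch-ext/triton_kernels/target_info.py (the actual diff may differ; routing through torch.cuda.get_device_properties() is an assumption here, not the repository's exact code):

```python
import torch

def num_sms():
    # pytorch-rocm exposes AMD GPUs through the torch.cuda namespace, so
    # multi_processor_count reports the number of CUs on AMD (and SMs on
    # NVIDIA) instead of letting the caller see None and crash on None * int.
    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        return props.multi_processor_count
    return None
```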
I strongly believe the previous commit (https://huggingface.co/kernels-community/triton_kernels/discussions/5) caused this regression; from a quick look at the code, AMD pytorch-rocm should work with the version before that commit.
janantos changed pull request status to closed
janantos changed pull request status to open
Can we please merge this change?