# Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team from ..config_v2 import RaggedInferenceEngineConfig from ..inference_utils import NormTypeEnum from .module_registry import ConfigBundle from ..modules.configs import ( DSEmbeddingsConfig, DSLinearConfig, DSMoEConfig, DSNormConfig, DSSelfAttentionConfig, DSUnembedConfig, ) from ..modules.interfaces import ( DSEmbeddingBase, DSEmbeddingRegistry, DSLinearBase, DSLinearRegistry, DSMoEBase, DSMoERegistry, DSPostNormBase, DSPostNormRegistry, DSPreNormBase, DSPreNormRegistry, DSSelfAttentionBase, DSSelfAttentionRegistry, DSUnembedBase, DSUnembedRegistry, ) def instantiate_attention(attention_config: DSSelfAttentionConfig, engine_config: RaggedInferenceEngineConfig) -> DSSelfAttentionBase: """ Choose an appropriate attention implementation based on the given configurations. This method is currently a stub, but as more implementations may be developed we can centralize the logic for choosing between them here. Arguments: attention_config (DSSelfAttentionConfig): Configuration for the attention module. engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. Returns: An attention module implementing the given configuration. """ # Currently, we only have one implementation, so we just return it. config = ConfigBundle(name="dense_blocked_attention", config=attention_config) return DSSelfAttentionRegistry.instantiate_config(config) def instantiate_embed(embed_config: DSEmbeddingsConfig, engine_config: RaggedInferenceEngineConfig) -> DSEmbeddingBase: """ Choose an appropriate embedding implementation based on the given configurations. This method is currently a stub, but as more implementations may be developed we can centralize the logic for choosing between them here. Arguments: embed_config (DSEmbeddingsConfig): Configuration for the embedding module. engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. Returns: An embedding module implementing the given configuration. """ # Currently, we only have one implementation, so we just return it. config = ConfigBundle(name="ragged_embedding", config=embed_config) return DSEmbeddingRegistry.instantiate_config(config) def instantiate_linear(linear_config: DSLinearConfig, engine_config: RaggedInferenceEngineConfig) -> DSLinearBase: """ Choose an appropriate linear implementation based on the given configurations. This method is currently a stub, but as more implementations may be developed we can centralize the logic for choosing between them here. Arguments: linear_config (DSLinearConfig): Configuration for the linear module. engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. Returns: A linear module implementing the given configuration. """ quantization_mode = engine_config.quantization.quantization_mode if quantization_mode is None: config = ConfigBundle(name="blas_fp_linear", config=linear_config) else: # Currently, we only support ``quantized_wf6af16_linear`` on NVIDIA Ampere GPUs. if quantization_mode == "wf6af16": import torch if not torch.cuda.is_available(): #ignore-cuda raise ValueError("WF6AF16 quantization is only supported on CUDA") else: is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None if is_rocm_pytorch: raise ValueError("WF6AF16 quantization is only supported on NVIDIA GPUs") elif torch.cuda.get_device_properties(0).major != 8: #ignore-cuda raise ValueError("WF6AF16 quantization is only supported on Ampere architectures") config = ConfigBundle(name="quantized_wf6af16_linear", config=linear_config) else: raise ValueError(f"Unsupported quantization mode: {quantization_mode}") return DSLinearRegistry.instantiate_config(config) def instantiate_moe(moe_config: DSMoEConfig, engine_config: RaggedInferenceEngineConfig) -> DSMoEBase: """ Choose an appropriate MoE implementation based on the given configurations. This method is currently a stub, but as more implementations may be developed we can centralize the logic for choosing between them here. Arguments: moe_config (DSMoEConfig): Configuration for the MoE module. engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. Returns: A MoE module implementing the given configuration. """ moe_type = "cutlass_multi_gemm_moe" if moe_type == "cutlass_multi_gemm_moe": # TODO: Get this off an engine config implementation_config = { "weight_dtype": moe_config.input_dtype, } # Currently, we only have one implementation, so we just return it. config = ConfigBundle(name="cutlass_multi_gemm_moe", config=moe_config, implementation_config=implementation_config) return DSMoERegistry.instantiate_config(config) def instantiate_post_norm(norm_config: DSNormConfig, engine_config: RaggedInferenceEngineConfig) -> DSPostNormBase: """ Choose an appropriate post-norm implementation based on the given configurations. This method is currently a stub, but as more implementations may be developed we can centralize the logic for choosing between them here. Arguments: norm_config (DSNormConfig): Configuration for the post-norm module. engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. Returns: A post-norm module implementing the given configuration. """ # Currently, we only have one implementation, so we just return it. config = ConfigBundle(name="cuda_post_ln", config=norm_config) return DSPostNormRegistry.instantiate_config(config) def instantiate_pre_norm(norm_config: DSNormConfig, engine_config: RaggedInferenceEngineConfig) -> DSPreNormBase: """ Choose an appropriate pre-norm implementation based on the given configurations. Currently, this will select between two CUDA implementations, one for LayerNorm and one for RMSNorm. Arguments: norm_config (DSNormConfig): Configuration for the pre-norm module. engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. Returns: A pre-norm module implementing the given configuration. """ if NormTypeEnum(norm_config.type) == NormTypeEnum.LayerNorm: module_name = "cuda_pre_ln" elif NormTypeEnum(norm_config.type) == NormTypeEnum.RMSNorm: module_name = "cuda_pre_rms" config = ConfigBundle(name=module_name, config=norm_config) return DSPreNormRegistry.instantiate_config(config) def instantiate_unembed(unembed_config: DSUnembedConfig, engine_config: RaggedInferenceEngineConfig) -> DSUnembedBase: """ Choose an appropriate unembedding implementation based on the given configurations. This method is currently a stub, but as more implementations may be developed we can centralize the logic for choosing between them here. Arguments: unembed_config (DSUnembedConfig): Configuration for the unembed module. engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine. Returns: An unembed module implementing the given configuration. """ # Currently, we only have one implementation, so we just return it. config = ConfigBundle(name="ragged_unembed", config=unembed_config) return DSUnembedRegistry.instantiate_config(config)