From ca5ea4518170b5449db5c7c75f386a56e240774f Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Fri, 3 May 2024 03:37:48 +0000
Subject: [PATCH] add LLMM_Silu mistral

---
 .../custom_modeling/flash_llama_modeling.py   | 13 ++++----
 .../custom_modeling/flash_mistral_modeling.py | 32 +++++++++++++++----
 2 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 72ccc1cc..9a21d043 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -25,7 +25,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -206,15 +206,14 @@ class FlashLlamaAttention(torch.nn.Module):
 class LlamaMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.act_func = config.hidden_act
-        act = config.hidden_act
+        self.hidden_act = config.hidden_act
         self.act = (
-            ACT2FN[act]
-            if "gelu" not in act
+            ACT2FN[self.hidden_act]
+            if "gelu" not in self.hidden_act
             else lambda x: torch.nn.functional.gelu(
                 x,
                 approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                 ),
             )
         )
@@ -245,7 +244,7 @@ class LlamaMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        if IS_ROCM_SYSTEM and self.act_func == "silu" and hidden_states.shape[0] == 1:
+        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
             out = torch.empty(
                 hidden_states.shape[0],
                 self.intermediate_size,
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index c2445cda..65b4e7ca 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -38,6 +39,13 @@ from text_generation_server.utils.layers import (
 )
 
 
+if IS_ROCM_SYSTEM:
+    try:
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+
+
 class MistralConfig(PretrainedConfig):
     model_type = "mistral"
 
@@ -249,14 +257,14 @@ class MistralAttention(torch.nn.Module):
 class MistralMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        act = config.hidden_act
+        self.hidden_act = config.hidden_act
         self.act = (
-            ACT2FN[act]
-            if "gelu" not in act
+            ACT2FN[self.hidden_act]
+            if "gelu" not in self.hidden_act
             else lambda x: torch.nn.functional.gelu(
                 x,
                 approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                 ),
             )
         )
@@ -279,9 +287,19 @@ class MistralMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
-        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+            out = torch.empty(
+                hidden_states.shape[0],
+                self.intermediate_size,
+                dtype=hidden_states.dtype,
+                device="cuda",
+            )
+            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
+            return self.down_proj(out)
+        else:
+            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
 
 
 class MistralLayer(nn.Module):
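
Note for reviewers: the ROCm-only branch added to MistralMLP.forward dispatches to vLLM's fused LLMM_Silu kernel (a single GEMM over the packed gate/up weight followed by SiLU-and-multiply) when hidden_act is "silu" and the batch has a single row, mirroring what LlamaMLP already does. The snippet below is only an unfused PyTorch reference for what that fast path is expected to compute; the helper name and tensor shapes are illustrative assumptions, not part of the patch, and the logic simply restates the else branch above.

import torch

def gated_silu_reference(gate_up_weight: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
    # Unfused reference for the fused fast path: one GEMM against the packed
    # [gate; up] weight of shape [2 * intermediate_size, hidden_size],
    # then silu(gate) * up. The patch only takes the fused path on ROCm when
    # hidden_act == "silu" and hidden_states.shape[0] == 1.
    intermediate_size = gate_up_weight.shape[0] // 2
    gate_up = hidden_states @ gate_up_weight.T         # [batch, 2 * intermediate_size]
    gate_up = gate_up.view(-1, 2, intermediate_size)   # split the packed gate/up projections
    return torch.nn.functional.silu(gate_up[:, 0]) * gate_up[:, 1]

Either branch's result then goes through down_proj, so on a ROCm box the kernel can be sanity-checked by comparing the out tensor written by _custom_C.LLMM_Silu against this reference with torch.allclose under a relaxed tolerance.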