From fbc5a6a120b01116d9965e2c9b543ef8e27521b9 Mon Sep 17 00:00:00 2001
From: Mohit Sharma
Date: Tue, 23 Apr 2024 15:02:53 +0000
Subject: [PATCH] add LLMM_Silu

---
 .../custom_modeling/flash_llama_modeling.py   | 24 ++++++++++++++++---
 server/text_generation_server/utils/layers.py | 16 ++++++-------
 2 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 4cf0fcf2..313db503 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -37,6 +38,12 @@ from text_generation_server.utils.layers import (
     FastRMSNorm,
 )
 
+if IS_ROCM_SYSTEM:
+    try:
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+
 
 class LlamaConfig(PretrainedConfig):
     def __init__(
@@ -245,6 +252,7 @@ class FlashLlamaAttention(torch.nn.Module):
 class LlamaMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
+        self.act_func = config.hidden_act
         act = config.hidden_act
         self.act = (
             ACT2FN[act]
@@ -275,9 +283,19 @@
         )
 
     def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
-        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        if IS_ROCM_SYSTEM and self.act_func == "silu" and hidden_states.shape[0] == 1:
+            out = torch.empty(
+                hidden_states.shape[0],
+                self.intermediate_size,
+                dtype=hidden_states.dtype,
+                device="cuda",
+            )
+            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
+            return self.down_proj(out)
+        else:
+            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
 
 
 class FlashLlamaLayer(nn.Module):
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 8e36f654..38dfe0f0 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -73,6 +73,7 @@ if IS_ROCM_SYSTEM:
     except Exception as e:
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 
+
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -329,6 +330,7 @@ def warn_deprecate_bnb():
         "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
     )
 
+
 class FastLinearROCm(nn.Module):
     def __init__(
         self,
@@ -361,14 +363,12 @@ class FastLinearROCm(nn.Module):
         if inp.dim() == 3:
             inp = inp.view(-1, inp.size(-1))
             batched = True
-
+
         m, k = weight.shape[0], inp.shape[1]
-        out = torch.empty(inp.shape[0],
-                          weight.shape[0],
-                          dtype=inp.dtype,
-                          device='cuda')
-        if (k == 8192 and
-            (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
+        out = torch.empty(
+            inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
+        )
+        if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
             _custom_C.LLMM1(weight, inp, out, 8)
         elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
             _custom_C.LLMM1(weight, inp, out, 4)
@@ -1293,4 +1293,4 @@ try:
             self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
 
 except ImportError:
-    pass
\ No newline at end of file
+    pass