add LLMM_Silu mistral

Mohit Sharma 2024-05-03 03:37:48 +00:00
parent caf07decf0
commit ca5ea45181
2 changed files with 31 additions and 14 deletions

View File

@@ -25,7 +25,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -206,15 +206,14 @@ class FlashLlamaAttention(torch.nn.Module):
 class LlamaMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.act_func = config.hidden_act
-        act = config.hidden_act
+        self.hidden_act = config.hidden_act
         self.act = (
-            ACT2FN[act]
-            if "gelu" not in act
+            ACT2FN[self.hidden_act]
+            if "gelu" not in self.hidden_act
             else lambda x: torch.nn.functional.gelu(
                 x,
                 approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                 ),
             )
         )
@@ -245,7 +244,7 @@ class LlamaMLP(nn.Module):
         )

     def forward(self, hidden_states):
-        if IS_ROCM_SYSTEM and self.act_func == "silu" and hidden_states.shape[0] == 1:
+        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
             out = torch.empty(
                 hidden_states.shape[0],
                 self.intermediate_size,
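For reference, the activation selection that LlamaMLP (and MistralMLP below) converges on after the act_func -> hidden_act rename boils down to the sketch below. build_activation is an illustrative standalone helper, not a function from the repo; the branching mirrors the diff: non-GELU activations come straight from ACT2FN, while GELU variants go through torch.nn.functional.gelu so the tanh approximation can be chosen for "gelu_fast" / "gelu_pytorch_tanh".

import torch
from transformers.activations import ACT2FN


def build_activation(hidden_act: str):
    # Non-GELU activations (e.g. "silu") are looked up directly in ACT2FN.
    if "gelu" not in hidden_act:
        return ACT2FN[hidden_act]
    # GELU variants pick the tanh approximation explicitly, as in the diff.
    approximate = "tanh" if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
    return lambda x: torch.nn.functional.gelu(x, approximate=approximate)


act = build_activation("silu")
print(act(torch.tensor([1.0, -1.0])))  # SiLU applied to a tiny example tensor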

View File

@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -38,6 +39,13 @@ from text_generation_server.utils.layers import (
 )

+if IS_ROCM_SYSTEM:
+    try:
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+
+
 class MistralConfig(PretrainedConfig):
     model_type = "mistral"
@@ -249,14 +257,14 @@ class MistralAttention(torch.nn.Module):
 class MistralMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        act = config.hidden_act
+        self.hidden_act = config.hidden_act
         self.act = (
-            ACT2FN[act]
-            if "gelu" not in act
+            ACT2FN[self.hidden_act]
+            if "gelu" not in self.hidden_act
             else lambda x: torch.nn.functional.gelu(
                 x,
                 approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                 ),
             )
         )
@@ -279,9 +287,19 @@ class MistralMLP(nn.Module):
         )

     def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
-        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+            out = torch.empty(
+                hidden_states.shape[0],
+                self.intermediate_size,
+                dtype=hidden_states.dtype,
+                device="cuda",
+            )
+            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
+            return self.down_proj(out)
+        else:
+            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


 class MistralLayer(nn.Module):
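The fused _custom_C.LLMM_Silu call is only taken for single-row inputs (decode-time batch of one) with SiLU, and is expected to match the unfused else branch: one matmul against the packed gate/up weight, then SiLU(gate) * up. Below is a rough pure-PyTorch sketch of that unfused computation, assuming gate_up_weight is gate_up_proj.linear.weight with shape (2 * intermediate_size, hidden_size); the commented comparison at the end is hypothetical and only meaningful on a ROCm system where vllm._custom_C imports.

import torch
import torch.nn.functional as F


def unfused_reference(gate_up_weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # What the `else` branch computes: project with the packed [gate; up]
    # weight, split the halves, then SiLU(gate) * up.
    gate_up = F.linear(x, gate_up_weight)                  # (1, 2 * intermediate_size)
    gate_up = gate_up.view(-1, 2, gate_up.shape[-1] // 2)  # split gate / up halves
    return F.silu(gate_up[:, 0]) * gate_up[:, 1]           # (1, intermediate_size)


# Hypothetical check against the fused kernel, same argument order as the diff
# (weight, input, out, 8); runs only where vllm._custom_C is available:
# out = torch.empty(1, gate_up_weight.shape[0] // 2, dtype=x.dtype, device="cuda")
# _custom_C.LLMM_Silu(gate_up_weight, x, out, 8)
# torch.testing.assert_close(out, unfused_reference(gate_up_weight, x))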