Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 20:34:54 +00:00)

commit ca5ea45181: add LLMM_Silu mistral
parent: caf07decf0
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -25,7 +25,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
-from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -206,15 +206,14 @@ class FlashLlamaAttention(torch.nn.Module):
 class LlamaMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        self.act_func = config.hidden_act
-        act = config.hidden_act
+        self.hidden_act = config.hidden_act
         self.act = (
-            ACT2FN[act]
-            if "gelu" not in act
+            ACT2FN[self.hidden_act]
+            if "gelu" not in self.hidden_act
             else lambda x: torch.nn.functional.gelu(
                 x,
                 approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                 ),
             )
         )
@@ -245,7 +244,7 @@ class LlamaMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        if IS_ROCM_SYSTEM and self.act_func == "silu" and hidden_states.shape[0] == 1:
+        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
             out = torch.empty(
                 hidden_states.shape[0],
                 self.intermediate_size,
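For readability, the activation selection rewritten above can be viewed as the small helper below. This is an illustrative sketch only (the helper name resolve_activation is not part of the diff); it mirrors the ternary in LlamaMLP.__init__: non-GELU activations such as "silu" come straight from ACT2FN, while GELU variants use torch.nn.functional.gelu, with the tanh approximation for "gelu_fast" and "gelu_pytorch_tanh".

# Illustrative sketch (not from the diff): the same activation selection,
# written as a standalone helper for clarity.
import torch
from transformers.activations import ACT2FN


def resolve_activation(hidden_act: str):
    # Non-GELU activations (e.g. "silu") are looked up directly in ACT2FN.
    if "gelu" not in hidden_act:
        return ACT2FN[hidden_act]
    # GELU variants use the functional form; "gelu_fast" and
    # "gelu_pytorch_tanh" map to the tanh approximation.
    approximate = "tanh" if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
    return lambda x: torch.nn.functional.gelu(x, approximate=approximate)

For example, resolve_activation("silu")(torch.randn(2, 4)) applies SiLU elementwise, matching what self.act does after this change.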
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -38,6 +39,13 @@ from text_generation_server.utils.layers import (
 )
 
 
+if IS_ROCM_SYSTEM:
+    try:
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+
+
 class MistralConfig(PretrainedConfig):
     model_type = "mistral"
 
@@ -249,14 +257,14 @@ class MistralAttention(torch.nn.Module):
 class MistralMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        act = config.hidden_act
+        self.hidden_act = config.hidden_act
         self.act = (
-            ACT2FN[act]
-            if "gelu" not in act
+            ACT2FN[self.hidden_act]
+            if "gelu" not in self.hidden_act
             else lambda x: torch.nn.functional.gelu(
                 x,
                 approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh" if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                 ),
             )
         )
@@ -279,9 +287,19 @@ class MistralMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
-        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        if IS_ROCM_SYSTEM and self.hidden_act == "silu" and hidden_states.shape[0] == 1:
+            out = torch.empty(
+                hidden_states.shape[0],
+                self.intermediate_size,
+                dtype=hidden_states.dtype,
+                device="cuda",
+            )
+            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
+            return self.down_proj(out)
+        else:
+            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
 
 
 class MistralLayer(nn.Module):
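The fused _custom_C.LLMM_Silu call on the batch-size-1 (decode) path is expected to produce the same result as the unfused else branch, up to numerical precision, presumably by performing the matmul and SiLU gating in a single kernel launch. The sketch below is a reference for that unfused computation, written against the packed gate_up_proj weight layout implied by the view(-1, 2, self.intermediate_size) split (gate half first, up half second). The function name and signature are illustrative and not part of the codebase.

# Illustrative reference (not from the codebase): what the fused ROCm path
# is expected to compute before down_proj, matching the unfused else branch.
import torch


def silu_mlp_reference(
    hidden_states: torch.Tensor,   # [num_tokens, hidden_size]
    gate_up_weight: torch.Tensor,  # [2 * intermediate_size, hidden_size], gate rows first (assumed layout)
    intermediate_size: int,
) -> torch.Tensor:
    # Single matmul against the packed gate/up weight, as gate_up_proj does.
    gate_up = torch.nn.functional.linear(hidden_states, gate_up_weight)
    # Split the packed output into the gate half and the up half.
    gate_up = gate_up.view(-1, 2, intermediate_size)
    # SwiGLU-style gating: SiLU(gate) * up.
    return torch.nn.functional.silu(gate_up[:, 0]) * gate_up[:, 1]

In MistralMLP.forward, this result (or the fused kernel's out) is then passed through self.down_proj in both branches.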