mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00
add LLMM_Silu
This commit is contained in:
parent
aef931ea5d
commit
fbc5a6a120
@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
|
|||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
|
||||||
from text_generation_server.utils import paged_attention, flash_attn
|
from text_generation_server.utils import paged_attention, flash_attn
|
||||||
from text_generation_server.utils.layers import (
|
from text_generation_server.utils.layers import (
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
@ -37,6 +38,12 @@ from text_generation_server.utils.layers import (
|
|||||||
FastRMSNorm,
|
FastRMSNorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if IS_ROCM_SYSTEM:
|
||||||
|
try:
|
||||||
|
from vllm import _custom_C
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||||
|
|
||||||
|
|
||||||
class LlamaConfig(PretrainedConfig):
|
class LlamaConfig(PretrainedConfig):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -245,6 +252,7 @@ class FlashLlamaAttention(torch.nn.Module):
|
|||||||
class LlamaMLP(nn.Module):
|
class LlamaMLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.act_func = config.hidden_act
|
||||||
act = config.hidden_act
|
act = config.hidden_act
|
||||||
self.act = (
|
self.act = (
|
||||||
ACT2FN[act]
|
ACT2FN[act]
|
||||||
@ -275,9 +283,19 @@ class LlamaMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
if IS_ROCM_SYSTEM and self.act_func == "silu" and hidden_states.shape[0] == 1:
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
out = torch.empty(
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
hidden_states.shape[0],
|
||||||
|
self.intermediate_size,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
||||||
|
return self.down_proj(out)
|
||||||
|
else:
|
||||||
|
gate_up_states = self.gate_up_proj(hidden_states)
|
||||||
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
|
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||||
|
|
||||||
|
|
||||||
class FlashLlamaLayer(nn.Module):
|
class FlashLlamaLayer(nn.Module):
|
||||||
|
@ -73,6 +73,7 @@ if IS_ROCM_SYSTEM:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||||
|
|
||||||
|
|
||||||
# Monkey patching
|
# Monkey patching
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_layer_norm(cls, prefix, weights, eps):
|
def load_layer_norm(cls, prefix, weights, eps):
|
||||||
@ -329,6 +330,7 @@ def warn_deprecate_bnb():
|
|||||||
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
|
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class FastLinearROCm(nn.Module):
|
class FastLinearROCm(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -361,14 +363,12 @@ class FastLinearROCm(nn.Module):
|
|||||||
if inp.dim() == 3:
|
if inp.dim() == 3:
|
||||||
inp = inp.view(-1, inp.size(-1))
|
inp = inp.view(-1, inp.size(-1))
|
||||||
batched = True
|
batched = True
|
||||||
|
|
||||||
m, k = weight.shape[0], inp.shape[1]
|
m, k = weight.shape[0], inp.shape[1]
|
||||||
out = torch.empty(inp.shape[0],
|
out = torch.empty(
|
||||||
weight.shape[0],
|
inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
|
||||||
dtype=inp.dtype,
|
)
|
||||||
device='cuda')
|
if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
|
||||||
if (k == 8192 and
|
|
||||||
(m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
|
|
||||||
_custom_C.LLMM1(weight, inp, out, 8)
|
_custom_C.LLMM1(weight, inp, out, 8)
|
||||||
elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
|
elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
|
||||||
_custom_C.LLMM1(weight, inp, out, 4)
|
_custom_C.LLMM1(weight, inp, out, 4)
|
||||||
@ -1293,4 +1293,4 @@ try:
|
|||||||
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
|
self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user