add LLMM_Silu

Mohit Sharma 2024-04-23 15:02:53 +00:00
parent aef931ea5d
commit fbc5a6a120
2 changed files with 29 additions and 11 deletions

View File

@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
@@ -37,6 +38,12 @@ from text_generation_server.utils.layers import (
     FastRMSNorm,
 )
 
+if IS_ROCM_SYSTEM:
+    try:
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+
 
 class LlamaConfig(PretrainedConfig):
     def __init__(
@@ -245,6 +252,7 @@ class FlashLlamaAttention(torch.nn.Module):
 class LlamaMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
+        self.act_func = config.hidden_act
         act = config.hidden_act
         self.act = (
             ACT2FN[act]
@@ -275,9 +283,19 @@ class LlamaMLP(nn.Module):
         )
 
     def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
-        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        if IS_ROCM_SYSTEM and self.act_func == "silu" and hidden_states.shape[0] == 1:
+            out = torch.empty(
+                hidden_states.shape[0],
+                self.intermediate_size,
+                dtype=hidden_states.dtype,
+                device="cuda",
+            )
+            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
+            return self.down_proj(out)
+        else:
+            gate_up_states = self.gate_up_proj(hidden_states)
+            gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
+            return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
 
 
 class FlashLlamaLayer(nn.Module):
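
For context on the change above: on ROCm, when a single token is being decoded (`hidden_states.shape[0] == 1`) and the activation is SiLU, the MLP now calls the fused `_custom_C.LLMM_Silu` kernel instead of running the gate/up projection, the activation, and the elementwise product as separate ops. Below is a minimal PyTorch sketch of the unfused reference computation the fused call replaces; `mlp_reference`, `gate_up_weight`, and `down_weight` are illustrative names, not identifiers from this repository:

import torch
import torch.nn.functional as F

def mlp_reference(hidden_states, gate_up_weight, down_weight, intermediate_size):
    # gate_up_weight packs the gate and up projections row-wise:
    # shape (2 * intermediate_size, hidden_size)
    gate_up = F.linear(hidden_states, gate_up_weight)
    gate_up = gate_up.view(-1, 2, intermediate_size)
    gated = F.silu(gate_up[:, 0]) * gate_up[:, 1]  # SiLU(gate) * up, as in the else branch above
    return F.linear(gated, down_weight)            # down projection

Judging from the diff, `_custom_C.LLMM_Silu(weight, hidden_states, out, 8)` writes the `SiLU(gate) * up` result of shape `(1, intermediate_size)` directly into `out`, so only the down projection remains as a separate matmul.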

View File

@@ -73,6 +73,7 @@ if IS_ROCM_SYSTEM:
     except Exception as e:
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+
 
 # Monkey patching
 @classmethod
 def load_layer_norm(cls, prefix, weights, eps):
@@ -329,6 +330,7 @@ def warn_deprecate_bnb():
         "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
     )
+
 
 class FastLinearROCm(nn.Module):
     def __init__(
         self,
@@ -361,14 +363,12 @@ class FastLinearROCm(nn.Module):
         if inp.dim() == 3:
             inp = inp.view(-1, inp.size(-1))
             batched = True
 
         m, k = weight.shape[0], inp.shape[1]
-        out = torch.empty(inp.shape[0],
-                          weight.shape[0],
-                          dtype=inp.dtype,
-                          device='cuda')
-        if (k == 8192 and
-                (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
+        out = torch.empty(
+            inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
+        )
+        if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
             _custom_C.LLMM1(weight, inp, out, 8)
         elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
             _custom_C.LLMM1(weight, inp, out, 4)
@@ -1293,4 +1293,4 @@ try:
             self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype)
 
 except ImportError:
     pass
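
The second file also reformats the shape heuristic in `FastLinearROCm.forward` that decides when to call the custom `_custom_C.LLMM1` skinny-GEMM kernel. Here is a small sketch of that dispatch as it reads after this change; `pick_llmm1_path` is a hypothetical helper for illustration, and the final fallback branch is an assumption since it lies outside the hunk shown above:

def pick_llmm1_path(m: int, k: int) -> str:
    # m = weight.shape[0] (output features), k = inp.shape[1] (input features)
    if (k == 8192 and m in (1280, 7168)) or (k == 3584 and m == 8192):
        return "_custom_C.LLMM1(weight, inp, out, 8)"
    elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
        return "_custom_C.LLMM1(weight, inp, out, 4)"
    # assumed fallback: a plain GEMM such as F.linear (not shown in this hunk)
    return "fallback GEMM"

The integer passed as the last argument to `LLMM1` (8 or 4) is a kernel tuning parameter selected by these shape checks; its exact meaning is not documented in the diff.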