diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py
index 15d24e80..c4aa6c7d 100644
--- a/server/text_generation_server/layers/layernorm.py
+++ b/server/text_generation_server/layers/layernorm.py
@@ -72,7 +72,7 @@ if SYSTEM == "cuda":
             return normed_hidden_states, residual
 
 elif SYSTEM == "rocm":
-    from vllm import layernorm_ops
+    from vllm._C import ops
 
     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
@@ -172,7 +172,7 @@ class FastRMSNorm(nn.Module):
             residual = hidden_states
 
             out = torch.empty_like(hidden_states)
-            layernorm_ops.rms_norm(
+            ops.rms_norm(
                 out,
                 hidden_states,
                 self.weight.data,
diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index d137a500..90978d6a 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -2,6 +2,11 @@ import torch
 from torch.nn import functional as F
 from text_generation_server.utils.import_utils import SYSTEM
 
+if SYSTEM == "rocm":
+    try:
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 
 class FastLinear(torch.nn.Module):
     def __init__(
@@ -29,9 +34,66 @@ class FastLinear(torch.nn.Module):
         return F.linear(input, self.weight, self.bias)
 
 
+class FastLinearROCm(torch.nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        self.weight = torch.nn.Parameter(weight)
+        if bias is not None:
+            self.bias = torch.nn.Parameter(bias)
+        else:
+            self.bias = None
+
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
+        weight = weights.get_tensor(f"{prefix}.weight")
+        if bias:
+            bias = weights.get_tensor(f"{prefix}.bias")
+        else:
+            bias = None
+        return cls(weight, bias)
+
+    def forward(self, inp: torch.Tensor) -> torch.Tensor:
+        weight = self.weight
+        bias = self.bias
+
+        if SYSTEM == "rocm" and inp.numel() // inp.size(-1) == 1:
+            # Decode-time (single-token) GEMM: dispatch to vLLM's custom
+            # skinny-GEMM kernel when the shapes allow it, else use F.linear.
+            batched = False
+            inp_shape = inp.shape
+
+            if inp.dim() == 3:
+                inp = inp.view(-1, inp.size(-1))
+                batched = True
+
+            m, k = weight.shape[0], inp.shape[1]
+            out = torch.empty(
+                inp.shape[0], weight.shape[0], dtype=inp.dtype, device="cuda"
+            )
+            if (k == 8192 and (m == 1280 or m == 7168)) or (k == 3584 and m == 8192):
+                _custom_C.LLMM1(weight, inp, out, 8)
+            elif k <= 8192 and k % 8 == 0 and m % 4 == 0:
+                _custom_C.LLMM1(weight, inp, out, 4)
+            else:
+                out = F.linear(inp, weight)
+            if batched:
+                out = out.view(*inp_shape[:-1], out.shape[-1])
+            if bias is not None:
+                out = out + bias
+            return out
+        return F.linear(inp, self.weight, self.bias)
+
+
 def get_linear(weight, bias, quantize):
     if quantize is None:
-        linear = FastLinear(weight, bias)
+        if SYSTEM == "rocm":
+            linear = FastLinearROCm(weight, bias)
+        else:
+            linear = FastLinear(weight, bias)
     elif quantize == "eetq":
         try:
             from text_generation_server.layers.eetq import EETQLinear
diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 503dd554..198e5d8d 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -8,7 +8,7 @@ if SYSTEM == "cuda":
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 elif SYSTEM == "rocm":
-    from vllm import pos_encoding_ops
+    from vllm._C import ops
 
 
 def _create_inv_freq(dim, base, device):
@@ -66,7 +66,7 @@ class PositionRotaryEmbedding(nn.Module):
             head_size = query.shape[-1]
 
             # Inplace operation, updating query and key.
-            pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True)
+            ops.rotary_embedding(query, key, head_size, cos, sin, True)
         elif SYSTEM == "xpu":
             ipex.llm.functional.rotary_embedding(
                 query, key, sin, cos, query.size(-1), True
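For reference, a minimal sketch of how the new ROCm decode fast path in FastLinearROCm could be exercised and checked against the stock F.linear result. It is not part of the patch; it assumes a ROCm GPU with vLLM's _custom_C extension built and fp16 weights, and the shapes are illustrative, chosen only so the first LLMM1 branch above is taken.

# Sanity-check sketch (assumes ROCm + vLLM _custom_C available); shapes are illustrative.
import torch
from torch.nn import functional as F
from text_generation_server.layers.linear import FastLinearROCm

m, k = 1280, 8192  # shape pair routed to _custom_C.LLMM1(..., 8) in the diff above
weight = torch.randn(m, k, dtype=torch.float16, device="cuda")
bias = torch.randn(m, dtype=torch.float16, device="cuda")
linear = FastLinearROCm(weight, bias)

inp = torch.randn(1, 1, k, dtype=torch.float16, device="cuda")  # single decode token
out = linear(inp)                  # custom skinny-GEMM path
ref = F.linear(inp, weight, bias)  # eager reference
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
print(out.shape)  # torch.Size([1, 1, 1280])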