From 8617d4795a64873732e76b03d24fa683d091a72b Mon Sep 17 00:00:00 2001
From: Felix Marty
Date: Thu, 9 Nov 2023 09:38:32 +0000
Subject: [PATCH] move control flow into forward

---
 .../custom_modeling/flash_llama_modeling.py   |  2 +-
 .../custom_modeling/flash_mistral_modeling.py |  2 ++
 server/text_generation_server/utils/layers.py | 60 +++++++++----------
 3 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index f46c9192..ad0b20b5 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -161,7 +161,7 @@ class LlamaRMSNorm(nn.Module):
             )
             return out, residual
         else:
-            raise RuntimeError("system not supported")
+            raise ValueError("Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
 
 
 def load_attention(config, prefix, weights):
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 7d91722e..ec05bc35 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -166,6 +166,8 @@ class MistralRMSNorm(nn.Module):
                 self.variance_epsilon,
             )
             return out, residual
+        else:
+            raise ValueError("Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
 
 
 def load_attention(config, prefix, weights):
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 23e313ef..9c8b2ade 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -556,32 +556,6 @@ try:
         from flash_attn.layers.rotary import RotaryEmbedding
         import rotary_emb
 
-        def rope_forward_cuda(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
-            rotary_dim = cos.shape[-1]
-            x1 = x[..., :rotary_dim]
-            x2 = x[..., rotary_dim : 2 * rotary_dim]
-
-            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
-            return x
-    elif IS_ROCM_SYSTEM:
-        # For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
-        # We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
-        def rope_forward_rocm(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
-            rotary_dim = cos.shape[-1]
-
-            dtype = x.dtype
-            x_upcast = x.to(torch.float32)
-            cos = cos.to(torch.float32)
-            sin = sin.to(torch.float32)
-
-            x1 = x_upcast[..., :rotary_dim]
-            x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
-
-            # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
-            x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
-            x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
-            return x
-
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
             base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
         )
@@ -609,6 +583,35 @@ try:
             self.scaling_factor = scaling_factor
             self.dynamic_args = None
 
+        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+            # Such control flow may add some overhead.
+            if IS_CUDA_SYSTEM:
+                rotary_dim = cos.shape[-1]
+                x1 = x[..., :rotary_dim]
+                x2 = x[..., rotary_dim : 2 * rotary_dim]
+
+                rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
+                return x
+            elif IS_ROCM_SYSTEM:
+                # For RoCm, we fall back on a manual implementation given that Flash Attention's ROPE kernel can not be compiled for RoCm.
+                # We could use VLLM ROPE kernel here (compatible with RoCm), but the API is different and would require position_ids: https://github.com/vllm-project/vllm/blob/1a2bbc930135cd3b94fbff2aafbdf5c568acc8bd/csrc/pos_encoding.cpp#L3
+                rotary_dim = cos.shape[-1]
+
+                dtype = x.dtype
+                x_upcast = x.to(torch.float32)
+                cos = cos.to(torch.float32)
+                sin = sin.to(torch.float32)
+
+                x1 = x_upcast[..., :rotary_dim]
+                x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]
+
+                # Flash Attention rotary_emb kernel casts everything to float, not sure why, so we do so here as well.
+                x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
+                x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)
+                return x
+            else:
+                raise ValueError("Your system does not seem to be supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction.")
+
         @classmethod
         def static(cls, config, dim, base, device):
             inv_freq = _create_inv_freq(dim, base, device)
@@ -718,11 +721,6 @@
             sin = torch.index_select(self._sin_cached, 0, position_ids)
             return cos.unsqueeze(1), sin.unsqueeze(1)
 
-    if IS_CUDA_SYSTEM:
-        PositionRotaryEmbedding.forward = rope_forward_cuda
-    elif IS_ROCM_SYSTEM:
-        PositionRotaryEmbedding.forward = rope_forward_rocm
-
     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
             inv_freq = _create_inv_freq(dim, base, device)
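
For reference, a minimal standalone sketch of the rotation that the RoCm fallback branch applies in place: the first half of each head is rotated against the second half using per-position cos/sin tables derived from the inverse frequencies. The helper name, the toy shapes, and the way cos/sin are built below are illustrative assumptions, not TGI's API; the sketch computes the result out of place so the inputs are not clobbered.

import torch

def apply_rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotate the first 2 * rotary_dim channels; any channels beyond that pass through untouched.
    rotary_dim = cos.shape[-1]
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim : 2 * rotary_dim]
    rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return torch.cat((rotated, x[..., 2 * rotary_dim :]), dim=-1)

# Toy example (hypothetical shapes): [num_tokens, num_heads, head_dim] with rotary_dim = head_dim // 2.
num_tokens, num_heads, head_dim = 4, 8, 64
x = torch.randn(num_tokens, num_heads, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
freqs = torch.outer(torch.arange(num_tokens, dtype=torch.float32), inv_freq)
cos = torch.cos(freqs)[:, None, :]  # broadcast over heads
sin = torch.sin(freqs)[:, None, :]
out = apply_rope_reference(x, cos, sin)
print(out.shape)  # torch.Size([4, 8, 64])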