From 946bf4424288b79851ee317640086fa6c53611f6 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 5 Apr 2024 18:42:33 +0200
Subject: [PATCH] fix cohere

---
 Dockerfile                                    |  2 +-
 .../custom_modeling/flash_cohere_modeling.py  | 73 +++++++++++++++++--
 .../models/flash_cohere.py                    |  2 +-
 3 files changed, 68 insertions(+), 9 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index 793e4e99..e8c3096e 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -211,7 +211,7 @@ COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /op
 COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
 
 # Install flash-attention dependencies
-RUN pip install einops --no-cache-dir
+RUN pip install einops prometheus_client --no-cache-dir
 
 # Install server
 COPY proto proto
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index a652c1ca..c9d87972 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -27,6 +27,7 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.import_utils import IS_ROCM_SYSTEM, IS_CUDA_SYSTEM
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -34,9 +35,65 @@ from text_generation_server.utils.layers import (
     PositionRotaryEmbedding,
     SpeculativeHead,
     get_linear,
-    FastRMSNorm,
+    FastLayerNorm,
 )
 
+if IS_CUDA_SYSTEM:
+    import dropout_layer_norm
+else:
+    dropout_layer_norm = None
+
+
+class CohereLayerNorm(nn.Module):
+    def __init__(self, prefix, weights, eps):
+        super().__init__()
+        weight = weights.get_tensor(f"{prefix}.weight")
+        self.weight = nn.Parameter(weight)
+        # Fake weights
+        self.ones = weight.new_ones(weight.shape[1])
+        self.eps = eps
+
+    def forward(self, hidden_states):
+        if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
+            input_dtype = hidden_states.dtype
+            hidden_states = hidden_states.to(torch.float32)
+            mean = hidden_states.mean(-1, keepdim=True)
+            hidden_states_minus_mean = hidden_states - mean
+            variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
+            hidden_states = self.weight.to(torch.float32) * hidden_states
+            return hidden_states.to(input_dtype)
+
+        (
+            hidden_states,
+            *rest,
+        ) = dropout_layer_norm.dropout_add_ln_fwd(
+            hidden_states,
+            None,
+            self.ones,
+            None,
+            None,
+            None,
+            None,
+            None,
+            0.0,
+            self.eps,
+            1.0,
+            0,
+            None,
+            False,
+            False,
+        )
+
+        # Required to apply one weight matrix per head
+        hidden_states = hidden_states.view(
+            -1, self.weight.shape[0], self.weight.shape[1]
+        )
+        hidden_states = self.weight * hidden_states
+        hidden_states = hidden_states.view(-1, self.weight.shape[1])
+
+        return hidden_states
+
 
 class CohereConfig(PretrainedConfig):
     def __init__(
@@ -180,7 +237,7 @@ class FlashCohereAttention(torch.nn.Module):
         self.use_qk_norm = config.use_qk_norm
         if self.use_qk_norm:
             rank = weights.process_group.rank()
-            self.q_norm = FastRMSNorm.load(
+            self.q_norm = CohereLayerNorm(
                 prefix=f"{prefix}.q_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
@@ -188,7 +245,7 @@ class FlashCohereAttention(torch.nn.Module):
             self.q_norm.weight.data = self.q_norm.weight[
                 self.num_heads * rank : self.num_heads * (rank + 1)
             ]
-            self.k_norm = FastRMSNorm.load(
+            self.k_norm = CohereLayerNorm(
                 prefix=f"{prefix}.k_norm",
                 weights=weights,
                 eps=config.layer_norm_eps,
@@ -230,8 +287,10 @@ class FlashCohereAttention(torch.nn.Module):
         )
 
         if self.use_qk_norm:
-            query, _ = self.q_norm(query.contiguous())
-            key, _ = self.k_norm(key.contiguous())
+            query = query.reshape(-1, self.head_size)
+            key = key.reshape(-1, self.head_size)
+            query = self.q_norm(query)
+            key = self.k_norm(key)
 
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_key_value_heads, self.head_size)
@@ -324,7 +383,7 @@ class FlashCohereLayer(nn.Module):
         )
         self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
-        self.input_layernorm = FastRMSNorm.load(
+        self.input_layernorm = FastLayerNorm.load_no_bias(
             prefix=f"{prefix}.input_layernorm",
             weights=weights,
             eps=config.layer_norm_eps,
@@ -388,7 +447,7 @@ class FlashCohereModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
+        self.norm = FastLayerNorm.load_no_bias(
             prefix="model.norm", weights=weights, eps=config.layer_norm_eps
         )
 
diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py
index 181a93b1..0c64a036 100644
--- a/server/text_generation_server/models/flash_cohere.py
+++ b/server/text_generation_server/models/flash_cohere.py
@@ -32,7 +32,7 @@ class FlashCohere(FlashCausalLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if dtype is None else dtype
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashCohere is only available on GPU")
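
Note on the modeling change above: Cohere's q_norm/k_norm checkpoints store one layer-norm weight vector per attention head with no bias, and the normalization subtracts the mean (a standard layer norm rather than an RMS norm), which is why FastRMSNorm is replaced by the custom CohereLayerNorm. The snippet below is a minimal, self-contained sketch of what that module computes on its pure-PyTorch path; the function name per_head_layer_norm and the example shapes are illustrative only and are not part of the repository's API.

# Sketch of Cohere's per-head QK layer norm (illustrative, not the repo's API).
# Assumes a weight of shape (num_heads, head_size) and inputs already reshaped
# to (num_tokens * num_heads, head_size), as done in the attention forward above.
import torch


def per_head_layer_norm(
    hidden_states: torch.Tensor,  # (num_tokens * num_heads, head_size)
    weight: torch.Tensor,         # (num_heads, head_size), no bias
    eps: float = 1e-5,
) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    # Standard layer norm: subtract the mean (RMSNorm would skip this step).
    x = x - x.mean(-1, keepdim=True)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Apply one weight vector per head, then flatten back to (-1, head_size).
    x = x.view(-1, weight.shape[0], weight.shape[1])
    x = weight.to(torch.float32) * x
    return x.view(-1, weight.shape[1]).to(input_dtype)


if __name__ == "__main__":
    num_tokens, num_heads, head_size = 4, 8, 64
    q = torch.randn(num_tokens * num_heads, head_size, dtype=torch.float16)
    w = torch.randn(num_heads, head_size, dtype=torch.float16)
    print(per_head_layer_norm(q, w).shape)  # torch.Size([32, 64])

On CUDA, the patch instead routes the normalization itself through dropout_layer_norm.dropout_add_ln_fwd using unit "fake" weights (self.ones) and applies the real per-head weights afterwards, so both branches end in the same per-head scaling.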