From 07a3050b208b0f6566de7b73a40b32bb9910bee1 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Wed, 10 Apr 2024 16:47:41 +0200
Subject: [PATCH] fix Cohere rotary embeddings and unpin transformers

---
 server/pyproject.toml                          |   2 +-
 .../custom_modeling/flash_cohere_modeling.py   | 142 ++++++------------
 server/text_generation_server/utils/layers.py  |   7 -
 3 files changed, 43 insertions(+), 108 deletions(-)

diff --git a/server/pyproject.toml b/server/pyproject.toml
index 7595752a..b5e9c3fd 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -26,7 +26,7 @@ hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "^0.15.0"
 huggingface-hub = "^0.19.3"
-transformers = { git = "https://github.com/huggingface/transformers", rev = "517a3e670d8fc11374895e870dd0dd041467c7fe" }
+transformers = "^4.38"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
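Note on the dependency change above: the server now tracks released transformers wheels instead of a pinned git revision. Under Poetry's caret semantics, "^4.38" resolves to any 4.38.0 <= version < 5.0.0. A minimal runtime sanity check, as a sketch (the `packaging` import is an assumption, not something this patch touches):

```python
# A minimal sketch: verify the resolved transformers version falls inside
# the "^4.38" range the new constraint allows (4.38.0 <= v < 5.0.0).
from packaging import version

import transformers

v = version.parse(transformers.__version__)
assert version.parse("4.38.0") <= v < version.parse("5.0.0"), f"unexpected version: {v}"
```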
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 468615a7..56d9a966 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -43,6 +43,43 @@ else:
     dropout_layer_norm = None
 
 
+class CohereRotary(PositionRotaryEmbedding):
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ):
+        # Such control flow may add some overhead.
+        if IS_CUDA_SYSTEM:
+            import rotary_emb
+
+            q1 = query[..., ::2]
+            q2 = query[..., 1::2]
+
+            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+
+            k1 = key[..., ::2]
+            k2 = key[..., 1::2]
+
+            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+        elif IS_ROCM_SYSTEM:
+            from vllm import pos_encoding_ops
+
+            # NOTE: On RoCm systems, we use a RoPE implementation adapted from vLLM which launches a single kernel for both query and key, contrary to the flash-attn implementation used on NVIDIA systems.
+            # When compiling the flash-attn rotary kernel on RoCm, hipcc appears unable to unroll loops, resulting in inference that is even slower than eager mode: https://github.com/pytorch/pytorch/issues/113773
+
+            head_size = query.shape[-1]
+
+            # Inplace operation, updating query and key.
+            pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, False)
+        else:
+            raise ValueError(
+                "Your system does not seem to be supported. Please check your installation or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
+            )
+
+
 class CohereLayerNorm(nn.Module):
     def __init__(self, prefix, weights, eps):
         super().__init__()
@@ -146,93 +183,6 @@ def _load_gqa(config, prefix: str, weights):
     )
 
 
-class CohereRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-    ):
-        super().__init__()
-        self.scaling_factor = scaling_factor
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (
-            self.base
-            ** (
-                torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
-                / self.dim
-            )
-        )
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    @torch.no_grad()
-    def forward(self, device_type, position_ids):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        inv_freq_expanded = (
-            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        )
-        position_ids_expanded = position_ids[None, None, :].float()
-
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = (
-            device_type
-            if isinstance(device_type, str) and device_type != "mps"
-            else "cpu"
-        )
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (
-                inv_freq_expanded.float().to(position_ids.device)
-                @ position_ids_expanded.float()
-            ).transpose(1, 2)
-            emb = torch.repeat_interleave(freqs, 2, dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos[0], sin[0]
-
-
-def rotate_half(x):
-    # Split and rotate
-    x1 = x[..., ::2]
-    x2 = x[..., 1::2]
-    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
-    return rot_x
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`, *optional*):
-            Deprecated and unused.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    dtype = q.dtype
-    q = q.float()
-    k = k.float()
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
-
-
 class FlashCohereAttention(torch.nn.Module):
     def __init__(
         self,
@@ -245,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
 
-        self.rotary_emb = PositionRotaryEmbedding.static(
+        self.rotary_emb = CohereRotary.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
@@ -325,9 +275,7 @@ class FlashCohereAttention(torch.nn.Module):
         key = key.view(-1, self.num_key_value_heads, self.head_size)
         value = value.view(-1, self.num_key_value_heads, self.head_size)
 
-        query, key = apply_rotary_pos_emb(query, key, cos, sin)
-
-        # self.rotary_emb(query, key, true_cos.reshape(*cos.shape), true_sin.reshape(*sin.shape))
+        self.rotary_emb(query, key, cos, sin)
 
         paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
@@ -487,11 +435,6 @@ class FlashCohereModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
-        self.rotary_true = CohereRotaryEmbedding(
-            self.head_size,
-            max_position_embeddings=config.max_position_embeddings,
-            base=config.rope_theta,
-        )
 
     def forward(
         self,
@@ -508,10 +451,9 @@ class FlashCohereModel(torch.nn.Module):
 
         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        # cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
-        #     position_ids, max_s, hidden_states.dtype
-        # )
-        cos, sin = self.rotary_true(hidden_states.device.type, position_ids)
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+            position_ids, max_s, hidden_states.dtype
+        )
 
         residual = None
         for i, layer in enumerate(self.layers):
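Note on the modeling change above: the new `CohereRotary.forward` dispatches to fused in-place kernels (`rotary_emb.apply_rotary` on CUDA, vLLM's `pos_encoding_ops.rotary_embedding` on RoCm) that rotate even/odd feature pairs, replacing the deleted pure-PyTorch `rotate_half`/`apply_rotary_pos_emb` path. Below is a minimal out-of-place reference of the rotation both kernels compute, useful for equivalence testing; the shapes are assumptions based on how `get_cos_sin` is consumed here, not something this patch defines:

```python
import torch


def rope_interleaved_reference(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Out-of-place reference for the interleaved rotation the fused kernels
    perform in place. Assumed shapes: x is [num_tokens, num_heads, head_size],
    cos/sin broadcastable to [num_tokens, 1, head_size // 2]."""
    x1 = x[..., ::2]   # even feature indices
    x2 = x[..., 1::2]  # odd feature indices
    out = torch.empty_like(x)
    out[..., ::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

Expanding the deleted `(q * cos) + (rotate_half(q) * sin)` on the even/odd slices yields exactly these two equations, so the old and new paths should agree up to numerical precision (the old path upcast to float32 before rotating).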
- """ - dtype = q.dtype - q = q.float() - k = k.float() - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) - - class FlashCohereAttention(torch.nn.Module): def __init__( self, @@ -245,7 +195,7 @@ class FlashCohereAttention(torch.nn.Module): self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads - self.rotary_emb = PositionRotaryEmbedding.static( + self.rotary_emb = CohereRotary.static( config=config, dim=self.head_size, base=config.rope_theta, @@ -325,9 +275,7 @@ class FlashCohereAttention(torch.nn.Module): key = key.view(-1, self.num_key_value_heads, self.head_size) value = value.view(-1, self.num_key_value_heads, self.head_size) - query, key = apply_rotary_pos_emb(query, key, cos, sin) - - # self.rotary_emb(query, key, true_cos.reshape(*cos.shape), true_sin.reshape(*sin.shape)) + self.rotary_emb(query, key, cos, sin) paged_attention.reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) @@ -487,11 +435,6 @@ class FlashCohereModel(torch.nn.Module): self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads - self.rotary_true = CohereRotaryEmbedding( - self.head_size, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) def forward( self, @@ -508,10 +451,9 @@ class FlashCohereModel(torch.nn.Module): # Get rotary cos and sin for this forward # Avoid to index in each layer - # cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( - # position_ids, max_s, hidden_states.dtype - # ) - cos, sin = self.rotary_true(hidden_states.device.type, position_ids) + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) residual = None for i, layer in enumerate(self.layers): diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index ad70651f..f29e55c5 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -19,7 +19,6 @@ from accelerate import init_empty_weights from text_generation_server.utils.gptq.quant_linear import QuantLinear from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM -from text_generation_server.utils.log import log_once HAS_AWQ = True try: @@ -35,12 +34,6 @@ except Exception: HAS_EXLLAMA = False CAN_EXLLAMA = major >= 8 or IS_ROCM_SYSTEM V2 = os.getenv("EXLLAMA_VERSION", "2") == "2" -# if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1: -# V2 = False -# log_once( -# logger.warning, -# "Disabling exllama v2 and using v1 instead because there are issues when sharding", -# ) if os.getenv("DISABLE_EXLLAMA") == "True": HAS_EXLLAMA = False