From e5e22993e7ee732d17c131f9f7f5ba2f66074b6a Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 23 Mar 2023 13:33:32 +0100
Subject: [PATCH] faster

---
 .../models/flash_neox.py          |   3 +-
 .../models/flash_neox_modeling.py | 218 ++++++++++--------
 2 files changed, 122 insertions(+), 99 deletions(-)

diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 51270b55..206e39e7 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -338,7 +338,6 @@ class FlashNeoX(Model):
 
         # Create final next batch tensors
         device = out.device
-        next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
         next_batch_position_ids = torch.tensor(
             next_batch_position_ids, dtype=torch.int32, device=device
         )
@@ -346,8 +345,10 @@ class FlashNeoX(Model):
             next_batch_cu_seqlens, dtype=torch.int32, device=device
         )
         if len(next_batch_keep_indices) > 1:
+            next_batch_input_ids = torch.concat(next_batch_input_ids, dim=0)
             next_batch_past_key_values = torch.concat(next_batch_past_key_values)
         else:
+            next_batch_input_ids = next_batch_input_ids[0]
             next_batch_past_key_values = next_batch_past_key_values[0]
 
         next_batch = FlashNeoXBatch(
diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py
index f21b2771..1f8c6c4d 100644
--- a/server/text_generation_server/models/flash_neox_modeling.py
+++ b/server/text_generation_server/models/flash_neox_modeling.py
@@ -18,13 +18,13 @@ from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
 
 class TensorParallelColumnLinear(nn.Linear):
     def __init__(
-            self,
-            in_features,
-            out_features,
-            process_group: torch.distributed.ProcessGroup,
-            bias=True,
-            device=None,
-            dtype=None,
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -49,13 +49,13 @@ class TensorParallelRowLinear(nn.Linear):
     def __init__(
-            self,
-            in_features,
-            out_features,
-            process_group: torch.distributed.ProcessGroup,
-            bias=True,
-            device=None,
-            dtype=None,
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -83,18 +83,18 @@ class TensorParallelEmbedding(nn.Embedding):
     def __init__(
-            self,
-            num_embeddings,
-            embedding_dim,
-            process_group: torch.distributed.ProcessGroup,
-            padding_idx=None,
-            max_norm=None,
-            norm_type=2.0,
-            scale_grad_by_freq=False,
-            sparse=False,
-            _weight=None,
-            device=None,
-            dtype=None,
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
     ):
         self.process_group = process_group
         self.tp_rank = process_group.rank()
@@ -125,7 +125,7 @@ class TensorParallelEmbedding(nn.Embedding):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Sanity check
         if torch.any(
-                torch.logical_or(0 > input, input >= self.original_num_embeddings)
+            torch.logical_or(0 > input, input >= self.original_num_embeddings)
         ):
             raise IndexError(
                 f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
@@ -148,9 +148,9 @@ class PositionRotaryEmbedding(RotaryEmbedding):
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if (
-                seqlen > self._seq_len_cached
-                or self._cos_cached.device != device
-                or self._cos_cached.dtype != dtype
+            seqlen > self._seq_len_cached
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
@@ -162,11 +162,11 @@ class PositionRotaryEmbedding(RotaryEmbedding):
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
-                        torch.arange(
-                            seqlen, dtype=self.scale.dtype, device=self.scale.device
-                        )
-                        - seqlen // 2
-                ) / self.scale_base
+                    torch.arange(
+                        seqlen, dtype=self.scale.dtype, device=self.scale.device
+                    )
+                    - seqlen // 2
+                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(1)
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
@@ -174,36 +174,28 @@ class PositionRotaryEmbedding(RotaryEmbedding):
             self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
             self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
 
-    def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor, max_s: int):
-        self._update_cos_sin_cache(qkv.dtype, qkv.device, max_s)
+    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
+        self._update_cos_sin_cache(dtype, position_ids.device, max_s)
+
+        cos = torch.index_select(self._cos_cached, 0, position_ids)
+        sin = torch.index_select(self._sin_cached, 0, position_ids)
+        return cos.unsqueeze(1), sin.unsqueeze(1)
+
+    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+        rotary_dim = cos.shape[-1]
+        q1 = qkv[:, 0, :, :rotary_dim]
+        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
+        k1 = qkv[:, 1, :, :rotary_dim]
+        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
 
-        q1, q2, k1, k2, cos, sin = _prepare_rotary(
-            qkv, self._cos_cached, self._sin_cached, position_ids
-        )
         rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
         rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
         return qkv
 
 
-@torch.jit.script
-def _prepare_rotary(
-    qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
-):
-    cos = torch.index_select(cos, 0, position_ids)
-    sin = torch.index_select(sin, 0, position_ids)
-
-    rotary_dim = cos.shape[-1]
-    q1 = qkv[:, 0, :, :rotary_dim]
-    q2 = qkv[:, 0, :, rotary_dim: 2 * rotary_dim]
-    k1 = qkv[:, 1, :, :rotary_dim]
-    k2 = qkv[:, 1, :, rotary_dim: 2 * rotary_dim]
-
-    return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1)
-
-
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-            self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -229,7 +221,7 @@ class FlashNeoxAttention(torch.nn.Module):
             hidden_size,
             process_group=process_group,
         )
-        self.swap_dims = False
+        self.swap_dims = True
 
     def _swap_dims(self):
         self.query_key_value.weight = torch.nn.Parameter(
@@ -244,17 +236,25 @@ class FlashNeoxAttention(torch.nn.Module):
             .permute(1, 0, 2)
             .reshape(-1)
         )
-        self.swap_dims = True
+        self.swap_dims = False
 
     def forward(
-        self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
+        self,
+        hidden_states,
+        cos,
+        sin,
+        cu_seqlens,
+        max_s,
+        layer_past,
+        layer_past_present_indices,
+        cu_seqlens_q,
     ):
-        if not self.swap_dims:
+        if self.swap_dims:
             self._swap_dims()
 
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
-        qkv_rot = self.rotary_emb(qkv, position_ids, max_s)
+        qkv_rot = self.rotary_emb(qkv, cos, sin)
 
         if layer_past_present_indices is None:
             layer_past[...] = qkv_rot[:, 1:]
@@ -348,16 +348,16 @@ class FlashMLP(nn.Module):
 
 class FlashNeoXLayer(nn.Module):
     def __init__(
-            self,
-            num_heads,
-            act,
-            hidden_size,
-            intermediate_size,
-            rotary_pct,
-            rotary_emb_base,
-            layer_norm_eps,
-            use_parallel_residual,
-            process_group=None,
+        self,
+        num_heads,
+        act,
+        hidden_size,
+        intermediate_size,
+        rotary_pct,
+        rotary_emb_base,
+        layer_norm_eps,
+        use_parallel_residual,
+        process_group=None,
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
@@ -369,15 +369,16 @@ class FlashNeoXLayer(nn.Module):
         self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
 
     def forward(
-            self,
-            hidden_states,
-            residual,
-            position_ids,
-            cu_seqlens,
-            max_s,
-            layer_past,
-            layer_past_present_indices,
-            cu_seqlens_q,
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        cu_seqlens,
+        max_s,
+        layer_past,
+        layer_past_present_indices,
+        cu_seqlens_q,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -399,7 +400,14 @@ class FlashNeoXLayer(nn.Module):
             )
 
             attn_output = self.attention(
-                ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
+                ln1_hidden_states,
+                cos,
+                sin,
+                cu_seqlens,
+                max_s,
+                layer_past,
+                layer_past_present_indices,
+                cu_seqlens_q,
             )
 
             ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -442,7 +450,14 @@ class FlashNeoXLayer(nn.Module):
             )
 
             hidden_states = self.attention(
-                hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
+                hidden_states,
+                cos,
+                sin,
+                cu_seqlens,
+                max_s,
+                layer_past,
+                layer_past_present_indices,
+                cu_seqlens_q,
             )
 
             hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -520,12 +535,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.num_heads = self.layers[0].attention.num_heads
 
     def forward(
-            self,
-            input_ids,
-            position_ids,
-            cu_seqlens,
-            max_s,
-            past_key_values=None,
+        self,
+        input_ids,
+        position_ids,
+        cu_seqlens,
+        max_s,
+        past_key_values=None,
     ):
         hidden_states = self.embed_in(input_ids)
 
@@ -543,19 +558,26 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
             cu_seqlens_q = None
         else:
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(len(cu_seqlens), dtype=torch.int32, device=hidden_states.device)
+            cu_seqlens_q = torch.arange(
+                len(cu_seqlens), dtype=torch.int32, device=hidden_states.device
+            )
+
+        cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(
+            position_ids, max_s, hidden_states.dtype
+        )
 
         residual = None
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
-                position_ids,
+                cos,
+                sin,
                 cu_seqlens,
                 max_s,
                 past_key_values[i],
                 layer_past_present_indices,
-                cu_seqlens_q
+                cu_seqlens_q,
             )
 
         hidden_states = self.final_layer_norm(hidden_states)
@@ -586,12 +608,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         )
 
     def forward(
-            self,
-            input_ids,
-            position_ids,
-            cu_seqlens,
-            max_s,
-            past_key_values=None,
+        self,
+        input_ids,
+        position_ids,
+        cu_seqlens,
+        max_s,
+        past_key_values=None,
     ):
         hidden_states, present = self.gpt_neox(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values
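
The main change in flash_neox_modeling.py is that the rotary cos/sin tables are now gathered once per forward pass (PositionRotaryEmbedding.get_cos_sin, called from FlashGPTNeoXModel.forward) and handed to every layer, instead of each FlashNeoxAttention re-indexing the cached tables through the jit-scripted _prepare_rotary helper on every call. Below is a minimal, self-contained pure-PyTorch sketch of that idea; the class name, helper names, and the rotation arithmetic are illustrative stand-ins for the flash_attn rotary_emb kernel, not the patched code itself.

import torch


class SharedRotary(torch.nn.Module):
    # Sketch: cache cos/sin tables once and gather them per flattened token position.

    def __init__(self, rotary_ndims: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims))
        self.register_buffer("inv_freq", inv_freq)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cache(self, dtype, device, seqlen):
        # Rebuild the tables only when the max sequence length grows
        # or the dtype/device changes.
        if (
            self._cos_cached is None
            or seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        # Called once per model forward; every layer reuses the result.
        self._update_cache(dtype, position_ids.device, max_s)
        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos.unsqueeze(1), sin.unsqueeze(1)


def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [total_tokens, num_heads, head_size]; cos/sin: [total_tokens, 1, rotary_dim].
    # Rotates the first 2 * rotary_dim features pairwise (non-interleaved layout),
    # writing back into x, and leaves the remaining features untouched.
    rotary_dim = cos.shape[-1]
    x1 = x[..., :rotary_dim]
    x2 = x[..., rotary_dim : 2 * rotary_dim]
    rotated_1 = x1 * cos - x2 * sin
    rotated_2 = x1 * sin + x2 * cos
    x[..., :rotary_dim] = rotated_1
    x[..., rotary_dim : 2 * rotary_dim] = rotated_2
    return x


if __name__ == "__main__":
    rotary = SharedRotary(rotary_ndims=16)
    position_ids = torch.tensor([0, 1, 2, 0, 1], dtype=torch.int64)  # two flattened sequences
    cos, sin = rotary.get_cos_sin(position_ids, max_s=3, dtype=torch.float32)
    qkv = torch.randn(5, 3, 8, 64)  # [total_tokens, qkv, heads, head_size]
    # The same cos/sin pair can now be reused by every layer's attention.
    apply_rotary(qkv[:, 0], cos, sin)  # queries
    apply_rotary(qkv[:, 1], cos, sin)  # keys

With this split, the cache lookup and index_select run once per model forward, and the per-layer work reduces to the rotation itself over tensors that were already gathered.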
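
The flash_neox.py hunk applies the same "do less per step" idea to batch bookkeeping: next_batch_input_ids is only concatenated when more than one request survives filtering, and a single surviving tensor is reused as-is. A small hypothetical helper showing the pattern (merge_next_batch_tensors is not a name from the repository):

from typing import List

import torch


def merge_next_batch_tensors(tensors: List[torch.Tensor]) -> torch.Tensor:
    # torch.concat allocates a new tensor and copies even when the list holds a
    # single element, so skip it entirely for a "batch" of one.
    if len(tensors) > 1:
        return torch.concat(tensors, dim=0)
    return tensors[0]


if __name__ == "__main__":
    one = [torch.tensor([11])]
    many = [torch.tensor([11]), torch.tensor([22, 33])]
    assert merge_next_batch_tensors(one) is one[0]  # no copy made
    assert merge_next_batch_tensors(many).tolist() == [11, 22, 33]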
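
The swap_dims flag is also inverted so that True now means "the fused QKV weights still need their one-time permutation"; the first forward call performs the permute and clears the flag. A rough, self-contained sketch of that lazy reordering, assuming the checkpoint stores the fused projection interleaved per head (LazyFusedQKV and its attributes are illustrative, not the patched class):

import torch
from torch import nn


class LazyFusedQKV(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
        # True means the weights are still in checkpoint order and need permuting.
        self.swap_dims = True

    def _swap_dims(self):
        # Assumed checkpoint order: per-head [q, k, v] blocks. Regroup once into
        # [all q | all k | all v] so forward can simply view(-1, 3, heads, head_size).
        weight = self.query_key_value.weight.data
        weight = (
            weight.view(self.num_heads, 3, self.head_size, -1)
            .permute(1, 0, 2, 3)
            .reshape(3 * self.num_heads * self.head_size, -1)
        )
        self.query_key_value.weight = nn.Parameter(weight)
        bias = self.query_key_value.bias.data
        bias = bias.view(self.num_heads, 3, self.head_size).permute(1, 0, 2).reshape(-1)
        self.query_key_value.bias = nn.Parameter(bias)
        self.swap_dims = False

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.swap_dims:
            self._swap_dims()  # runs exactly once, on the first call
        qkv = self.query_key_value(hidden_states)
        return qkv.view(-1, 3, self.num_heads, self.head_size)


if __name__ == "__main__":
    module = LazyFusedQKV(hidden_size=512, num_heads=8)
    out = module(torch.randn(4, 512))
    assert out.shape == (4, 3, 8, 64) and module.swap_dims is False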