From cdc70f4c23244232b25100beebc5054d6a38696e Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Thu, 23 Mar 2023 12:35:38 +0100 Subject: [PATCH] faster --- .../models/flash_neox.py | 7 +- .../models/flash_neox_modeling.py | 350 +++++++++++------- 2 files changed, 228 insertions(+), 129 deletions(-) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index b5e12882..51270b55 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -13,7 +13,7 @@ from text_generation_server.models.flash_neox_modeling import ( FlashGPTNeoXForCausalLM, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelColumnLinear + TensorParallelColumnLinear, ) from text_generation_server.models.types import ( Batch, @@ -115,7 +115,6 @@ class FlashNeoXBatch(Batch): def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch": raise NotImplementedError - def __len__(self): return len(self.requests) @@ -259,7 +258,9 @@ class FlashNeoX(Model): if stop: # Decode generated tokens - output_text = self.decode(all_input_ids[-stopping_criteria.current_tokens :]) + output_text = self.decode( + all_input_ids[-stopping_criteria.current_tokens :] + ) # Get seed if isinstance(next_token_chooser.choice, Sampling): seed = next_token_chooser.choice.seed diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py index 4cd2a452..65a061b9 100644 --- a/server/text_generation_server/models/flash_neox_modeling.py +++ b/server/text_generation_server/models/flash_neox_modeling.py @@ -9,43 +9,35 @@ from transformers.models.gpt_neox import GPTNeoXConfig import rotary_emb import flash_attn_cuda +import dropout_layer_norm + +import fused_dense_lib as fused_dense_cuda -from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_qkvpacked_func, - flash_attn_unpadded_kvpacked_func, -) -# from flash_attn.ops.fused_dense import ( -# FusedDense, -# ColumnParallelLinear, -# RowParallelLinear, -# fused_mlp_func, -# ) from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_ -# from flash_attn.ops.layer_norm import dropout_add_layer_norm - - class TensorParallelColumnLinear(nn.Linear): def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - bias=True, - device=None, - dtype=None, + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, ): self.process_group = process_group self.tp_world_size = process_group.size() assert out_features % self.tp_world_size == 0 out_features = out_features // self.tp_world_size - super().__init__(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype) + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) @staticmethod def linear(input, weight, bias): @@ -57,24 +49,26 @@ class TensorParallelColumnLinear(nn.Linear): class TensorParallelRowLinear(nn.Linear): def __init__( - self, - in_features, - out_features, - process_group: torch.distributed.ProcessGroup, - bias=True, - device=None, - dtype=None, + self, + in_features, + out_features, + process_group: torch.distributed.ProcessGroup, + bias=True, + device=None, + dtype=None, ): self.process_group = process_group self.tp_world_size = process_group.size() assert in_features % self.tp_world_size == 0 in_features = in_features // self.tp_world_size - super().__init__(in_features=in_features, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype) + super().__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) @staticmethod def linear(input, weight, bias): @@ -89,18 +83,18 @@ class TensorParallelRowLinear(nn.Linear): class TensorParallelEmbedding(nn.Embedding): def __init__( - self, - num_embeddings, - embedding_dim, - process_group: torch.distributed.ProcessGroup, - padding_idx=None, - max_norm=None, - norm_type=2.0, - scale_grad_by_freq=False, - sparse=False, - _weight=None, - device=None, - dtype=None + self, + num_embeddings, + embedding_dim, + process_group: torch.distributed.ProcessGroup, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + _weight=None, + device=None, + dtype=None, ): self.process_group = process_group self.tp_rank = process_group.rank() @@ -115,15 +109,27 @@ class TensorParallelEmbedding(nn.Embedding): self.min_id = self.tp_rank * block_size self.max_id = (self.tp_rank + 1) * block_size - super().__init__(block_size, embedding_dim, padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, _weight=_weight, device=device, - dtype=dtype) + super().__init__( + block_size, + embedding_dim, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + _weight=_weight, + device=device, + dtype=dtype, + ) def forward(self, input: torch.Tensor) -> torch.Tensor: # Sanity check - if torch.any(torch.logical_or(0 > input, input >= self.original_num_embeddings)): + if torch.any( + torch.logical_or(0 > input, input >= self.original_num_embeddings) + ): raise IndexError( - f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}") + f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}" + ) # `0` if input is in the correct interval, else `1` input_mask = torch.logical_or(self.min_id > input, input >= self.max_id) @@ -141,8 +147,11 @@ class PositionRotaryEmbedding(RotaryEmbedding): def _update_cos_sin_cache(self, dtype, device, seqlen): # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) - if (seqlen > self._seq_len_cached or self._cos_cached.device != device - or self._cos_cached.dtype != dtype): + if ( + seqlen > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + ): self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) # Don't do einsum, it converts fp32 to fp16 @@ -152,8 +161,12 @@ class PositionRotaryEmbedding(RotaryEmbedding): self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) else: - power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - - seqlen // 2) / self.scale_base) + power = ( + torch.arange( + seqlen, dtype=self.scale.dtype, device=self.scale.device + ) + - seqlen // 2 + ) / self.scale_base scale = self.scale.to(device=power.device) ** power.unsqueeze(1) # We want the multiplication by scale to happen in fp32 self._cos_cached = (torch.cos(freqs) * scale).to(dtype) @@ -164,29 +177,33 @@ class PositionRotaryEmbedding(RotaryEmbedding): def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor, max_s: int): self._update_cos_sin_cache(qkv.dtype, qkv.device, max_s) - q1, q2, k1, k2, cos, sin = _prepare_rotary(qkv, self._cos_cached, self._sin_cached, position_ids) + q1, q2, k1, k2, cos, sin = _prepare_rotary( + qkv, self._cos_cached, self._sin_cached, position_ids + ) rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False) return qkv @torch.jit.script -def _prepare_rotary(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor): +def _prepare_rotary( + qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor +): cos = torch.index_select(cos, 0, position_ids) sin = torch.index_select(sin, 0, position_ids) rotary_dim = cos.shape[-1] q1 = qkv[:, 0, :, :rotary_dim] - q2 = qkv[:, 0, :, rotary_dim:2*rotary_dim] + q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim] k1 = qkv[:, 1, :, :rotary_dim] - k2 = qkv[:, 1, :, rotary_dim: 2*rotary_dim] + k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim] return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1) class FlashNeoxAttention(torch.nn.Module): def __init__( - self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None + self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None ): super().__init__() self.num_heads = num_heads @@ -216,17 +233,21 @@ class FlashNeoxAttention(torch.nn.Module): def _swap_dims(self): self.query_key_value.weight = torch.nn.Parameter( - self.query_key_value.weight.view(self.num_heads, 3, self.head_size, self.hidden_size) - .permute(1, 0, 2, 3).reshape(-1, self.hidden_size) + self.query_key_value.weight.view( + self.num_heads, 3, self.head_size, self.hidden_size + ) + .permute(1, 0, 2, 3) + .reshape(-1, self.hidden_size) ) self.query_key_value.bias = torch.nn.Parameter( self.query_key_value.bias.view(self.num_heads, 3, self.head_size) - .permute(1, 0, 2).reshape(-1) + .permute(1, 0, 2) + .reshape(-1) ) self.swap_dims = True def forward( - self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill + self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill ): if not self.swap_dims: self._swap_dims() @@ -240,9 +261,21 @@ class FlashNeoxAttention(torch.nn.Module): attn_output = torch.empty_like(qkv[:, 0]) flash_attn_cuda.fwd( - qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, cu_seqlens, cu_seqlens, max_s, max_s, 0.0, + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + attn_output, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, self.softmax_scale, - False, True, False, 0, None + False, + True, + False, + 0, + None, ) else: query = qkv_rot[:, 0] @@ -250,12 +283,21 @@ class FlashNeoxAttention(torch.nn.Module): attn_output = torch.empty_like(query) flash_attn_cuda.fwd( - query, layer_past[:, 0], layer_past[:, 1], attn_output, - torch.arange(len(cu_seqlens), dtype=torch.int32).to( - query.device - ), cu_seqlens, torch.tensor(1, dtype=torch.int32).to(query.device), max_s, 0.0, + query, + layer_past[:, 0], + layer_past[:, 1], + attn_output, + torch.arange(len(cu_seqlens), dtype=torch.int32).to(query.device), + cu_seqlens, + torch.tensor(1, dtype=torch.int32).to(query.device), + max_s, + 0.0, self.softmax_scale, - False, False, False, 0, None + False, + False, + False, + 0, + None, ) return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) @@ -264,11 +306,11 @@ class FlashNeoxAttention(torch.nn.Module): class FlashMLP(nn.Module): def __init__(self, act, hidden_size, intermediate_size, process_group=None): super().__init__() - assert "gelu" in act - # if "gelu" in act: - # act = "gelu_approx" - # assert act in ["gelu_approx", "relu"] - self.act = lambda x: F.gelu(x, approximate="tanh") + if "gelu" in act: + act = "gelu_approx" + assert act in ["gelu_approx", "relu"] + self.is_gelu = act == "gelu_approx" + # self.act = lambda x: F.gelu(x, approximate="tanh") if process_group is None: self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size) @@ -288,24 +330,34 @@ class FlashMLP(nn.Module): self.process_group = process_group def forward(self, hidden_states): - hidden_states = self.dense_h_to_4h(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.dense_4h_to_h(hidden_states) - return hidden_states + hidden_states, *rest = fused_dense_cuda.linear_act_forward( + hidden_states, + self.dense_h_to_4h.weight, + self.dense_h_to_4h.bias, + self.is_gelu, + False, + 0, + ) + return self.dense_4h_to_h(hidden_states) + # + # hidden_states = self.dense_h_to_4h(hidden_states) + # hidden_states = self.act(hidden_states) + # hidden_states = self.dense_4h_to_h(hidden_states) + # return hidden_states class FlashNeoXLayer(nn.Module): def __init__( - self, - num_heads, - act, - hidden_size, - intermediate_size, - rotary_pct, - rotary_emb_base, - layer_norm_eps, - use_parallel_residual, - process_group=None, + self, + num_heads, + act, + hidden_size, + intermediate_size, + rotary_pct, + rotary_emb_base, + layer_norm_eps, + use_parallel_residual, + process_group=None, ): super().__init__() self.use_parallel_residual = use_parallel_residual @@ -317,51 +369,97 @@ class FlashNeoXLayer(nn.Module): self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group) def forward( - self, - hidden_states, - residual, - position_ids, - cu_seqlens, - max_s, - layer_past, - prefill, + self, + hidden_states, + residual, + position_ids, + cu_seqlens, + max_s, + layer_past, + prefill, ): if self.use_parallel_residual: - attn_output = self.attention( - self.input_layernorm(hidden_states), position_ids, cu_seqlens, max_s, layer_past, prefill + ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + None, + self.input_layernorm.weight, + self.input_layernorm.bias, + None, + None, + None, + None, + 0.0, + self.input_layernorm.eps, + 1.0, + 0, + None, + False, + False, ) - mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) - return mlp_output + attn_output + hidden_states, None + attn_output = self.attention( + ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill + ) + ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + None, + self.post_attention_layernorm.weight, + self.post_attention_layernorm.bias, + None, + None, + None, + None, + 0.0, + self.post_attention_layernorm.eps, + 1.0, + 0, + None, + False, + False, + ) + + mlp_output = self.mlp(ln2_hidden_states) + return mlp_output + attn_output + hidden_states, None else: - raise NotImplementedError - hidden_states, residual = dropout_add_layer_norm( + hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, residual, self.input_layernorm.weight, self.input_layernorm.bias, + None, + None, + None, + None, 0.0, self.input_layernorm.eps, - rowscale=None, - prenorm=True, - residual_in_fp32=True, + 1.0, + 0, + None, + False, + False, ) hidden_states = self.attention( hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill ) - hidden_states, residual = dropout_add_layer_norm( + hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( hidden_states, residual, self.post_attention_layernorm.weight, self.post_attention_layernorm.bias, + None, + None, + None, + None, 0.0, self.post_attention_layernorm.eps, - rowscale=None, - prenorm=True, - residual_in_fp32=True, + 1.0, + 0, + None, + False, + False, ) mlp_output = self.mlp(hidden_states) @@ -421,12 +519,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): self.num_heads = self.layers[0].attention.num_heads def forward( - self, - input_ids, - position_ids, - cu_seqlens, - max_s, - past_key_values=None, + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, ): hidden_states = self.embed_in(input_ids) @@ -483,12 +581,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): ) def forward( - self, - input_ids, - position_ids, - cu_seqlens, - max_s, - past_key_values=None, + self, + input_ids, + position_ids, + cu_seqlens, + max_s, + past_key_values=None, ): hidden_states, present = self.gpt_neox( input_ids, position_ids, cu_seqlens, max_s, past_key_values