diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py
index 65a061b9..f21b2771 100644
--- a/server/text_generation_server/models/flash_neox_modeling.py
+++ b/server/text_generation_server/models/flash_neox_modeling.py
@@ -18,13 +18,13 @@ from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
 
 class TensorParallelColumnLinear(nn.Linear):
     def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        bias=True,
-        device=None,
-        dtype=None,
+            self,
+            in_features,
+            out_features,
+            process_group: torch.distributed.ProcessGroup,
+            bias=True,
+            device=None,
+            dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -49,13 +49,13 @@ class TensorParallelColumnLinear(nn.Linear):
 
 class TensorParallelRowLinear(nn.Linear):
     def __init__(
-        self,
-        in_features,
-        out_features,
-        process_group: torch.distributed.ProcessGroup,
-        bias=True,
-        device=None,
-        dtype=None,
+            self,
+            in_features,
+            out_features,
+            process_group: torch.distributed.ProcessGroup,
+            bias=True,
+            device=None,
+            dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
@@ -83,18 +83,18 @@ class TensorParallelRowLinear(nn.Linear):
 
 class TensorParallelEmbedding(nn.Embedding):
     def __init__(
-        self,
-        num_embeddings,
-        embedding_dim,
-        process_group: torch.distributed.ProcessGroup,
-        padding_idx=None,
-        max_norm=None,
-        norm_type=2.0,
-        scale_grad_by_freq=False,
-        sparse=False,
-        _weight=None,
-        device=None,
-        dtype=None,
+            self,
+            num_embeddings,
+            embedding_dim,
+            process_group: torch.distributed.ProcessGroup,
+            padding_idx=None,
+            max_norm=None,
+            norm_type=2.0,
+            scale_grad_by_freq=False,
+            sparse=False,
+            _weight=None,
+            device=None,
+            dtype=None,
     ):
         self.process_group = process_group
         self.tp_rank = process_group.rank()
@@ -125,7 +125,7 @@ class TensorParallelEmbedding(nn.Embedding):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Sanity check
         if torch.any(
-            torch.logical_or(0 > input, input >= self.original_num_embeddings)
+                torch.logical_or(0 > input, input >= self.original_num_embeddings)
         ):
             raise IndexError(
                 f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
@@ -148,9 +148,9 @@ class PositionRotaryEmbedding(RotaryEmbedding):
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if (
-            seqlen > self._seq_len_cached
-            or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype
+                seqlen > self._seq_len_cached
+                or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
@@ -162,11 +162,11 @@ class PositionRotaryEmbedding(RotaryEmbedding):
                 self._sin_cached = torch.sin(freqs).to(dtype)
             else:
                 power = (
-                    torch.arange(
-                        seqlen, dtype=self.scale.dtype, device=self.scale.device
-                    )
-                    - seqlen // 2
-                ) / self.scale_base
+                        torch.arange(
+                            seqlen, dtype=self.scale.dtype, device=self.scale.device
+                        )
+                        - seqlen // 2
+                    ) / self.scale_base
                 scale = self.scale.to(device=power.device) ** power.unsqueeze(1)
                 # We want the multiplication by scale to happen in fp32
                 self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
@@ -187,23 +187,23 @@
 
 @torch.jit.script
 def _prepare_rotary(
-    qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
+        qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
 ):
     cos = torch.index_select(cos, 0, position_ids)
     sin = torch.index_select(sin, 0, position_ids)
 
     rotary_dim = cos.shape[-1]
     q1 = qkv[:, 0, :, :rotary_dim]
-    q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
+    q2 = qkv[:, 0, :, rotary_dim: 2 * rotary_dim]
     k1 = qkv[:, 1, :, :rotary_dim]
-    k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
+    k2 = qkv[:, 1, :, rotary_dim: 2 * rotary_dim]
 
     return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1)
 
 
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+            self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -247,7 +247,7 @@ class FlashNeoxAttention(torch.nn.Module):
         self.swap_dims = True
 
     def forward(
-        self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
+        self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
     ):
         if not self.swap_dims:
             self._swap_dims()
@@ -256,7 +256,7 @@
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
         qkv_rot = self.rotary_emb(qkv, position_ids, max_s)
 
-        if prefill:
+        if layer_past_present_indices is None:
             layer_past[...] = qkv_rot[:, 1:]
 
             attn_output = torch.empty_like(qkv[:, 0])
@@ -279,7 +279,7 @@
             )
         else:
             query = qkv_rot[:, 0]
-            layer_past[cu_seqlens[1:] - 1] = qkv_rot[:, 1:]
+            layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
 
             attn_output = torch.empty_like(query)
             flash_attn_cuda.fwd(
@@ -287,9 +287,9 @@
                 layer_past[:, 0],
                 layer_past[:, 1],
                 attn_output,
-                torch.arange(len(cu_seqlens), dtype=torch.int32).to(query.device),
+                cu_seqlens_q,
                 cu_seqlens,
-                torch.tensor(1, dtype=torch.int32).to(query.device),
+                1,
                 max_s,
                 0.0,
                 self.softmax_scale,
@@ -348,16 +348,16 @@ class FlashMLP(nn.Module):
 
 class FlashNeoXLayer(nn.Module):
     def __init__(
-        self,
-        num_heads,
-        act,
-        hidden_size,
-        intermediate_size,
-        rotary_pct,
-        rotary_emb_base,
-        layer_norm_eps,
-        use_parallel_residual,
-        process_group=None,
+            self,
+            num_heads,
+            act,
+            hidden_size,
+            intermediate_size,
+            rotary_pct,
+            rotary_emb_base,
+            layer_norm_eps,
+            use_parallel_residual,
+            process_group=None,
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
@@ -369,14 +369,15 @@ class FlashNeoXLayer(nn.Module):
         self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
 
     def forward(
-        self,
-        hidden_states,
-        residual,
-        position_ids,
-        cu_seqlens,
-        max_s,
-        layer_past,
-        prefill,
+            self,
+            hidden_states,
+            residual,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            layer_past,
+            layer_past_present_indices,
+            cu_seqlens_q,
     ):
         if self.use_parallel_residual:
             ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -398,7 +399,7 @@ class FlashNeoXLayer(nn.Module):
             )
 
             attn_output = self.attention(
-                ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
+                ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
             )
 
             ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -441,7 +442,7 @@ class FlashNeoXLayer(nn.Module):
             )
 
             hidden_states = self.attention(
-                hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
+                hidden_states, position_ids, cu_seqlens, max_s, layer_past, layer_past_present_indices, cu_seqlens_q
             )
 
             hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
@@ -519,16 +520,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.num_heads = self.layers[0].attention.num_heads
 
     def forward(
-        self,
-        input_ids,
-        position_ids,
-        cu_seqlens,
-        max_s,
-        past_key_values=None,
+            self,
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values=None,
     ):
         hidden_states = self.embed_in(input_ids)
 
-        prefill = False
         if past_key_values is None:
             past_key_values = hidden_states.new_empty(
                 (
@@ -539,7 +539,11 @@
                     self.head_size,
                 )
             )
-            prefill = True
+            layer_past_present_indices = None
+            cu_seqlens_q = None
+        else:
+            layer_past_present_indices = cu_seqlens[1:] - 1
+            cu_seqlens_q = torch.arange(len(cu_seqlens), dtype=torch.int32, device=hidden_states.device)
 
         residual = None
         for i, layer in enumerate(self.layers):
@@ -550,7 +554,8 @@
                 cu_seqlens,
                 max_s,
                 past_key_values[i],
-                prefill,
+                layer_past_present_indices,
+                cu_seqlens_q
             )
 
         hidden_states = self.final_layer_norm(hidden_states)
@@ -581,12 +586,12 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         )
 
     def forward(
-        self,
-        input_ids,
-        position_ids,
-        cu_seqlens,
-        max_s,
-        past_key_values=None,
+            self,
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values=None,
     ):
         hidden_states, present = self.gpt_neox(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values
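
Note on the functional change (most hunks above are indentation churn only): the per-layer prefill boolean is replaced by layer_past_present_indices and cu_seqlens_q, which FlashGPTNeoXModel.forward now computes once per call and passes down to every layer; prefill is signalled by passing None for both. cu_seqlens_q also replaces the torch.arange tensor that was previously rebuilt inside FlashNeoxAttention.forward on every decode step, and max_seqlen_q is passed to flash_attn_cuda.fwd as a plain 1 instead of a one-element tensor. Below is a minimal sketch of the index math, plain PyTorch only, with toy cu_seqlens values chosen for illustration (not taken from the patch):

import torch

# Packed batch of 3 sequences with lengths 2, 4 and 3 tokens.
# cu_seqlens holds the cumulative sequence lengths, as in the diff above.
cu_seqlens = torch.tensor([0, 2, 6, 9], dtype=torch.int32)

# Decode step: each sequence contributes exactly one new token, so the
# present key/value of sequence i lands in the last slot of its segment
# of the packed layout.
layer_past_present_indices = cu_seqlens[1:] - 1
print(layer_past_present_indices)  # tensor([1, 5, 8], dtype=torch.int32)

# Query-side cumulative lengths for single-token decode: one query per
# sequence, i.e. [0, 1, ..., batch_size]. Built once in the model forward
# and reused by every layer instead of being recreated per attention call.
cu_seqlens_q = torch.arange(len(cu_seqlens), dtype=torch.int32)
print(cu_seqlens_q)  # tensor([0, 1, 2, 3], dtype=torch.int32)

During prefill both values stay None, which is the condition the attention layers now branch on in place of the old prefill flag.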