From ead19abb0ec13cba236df16886a0d637b8d6fdb5 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 23 Mar 2023 12:00:52 +0100
Subject: [PATCH] faster

---
 .../models/flash_neox.py          |  24 +-
 .../models/flash_neox_modeling.py | 537 ++++++++----------
 2 files changed, 232 insertions(+), 329 deletions(-)

diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 3adf8062..b5e12882 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -3,7 +3,6 @@ import torch.distributed
 
 from accelerate import init_empty_weights
 from dataclasses import dataclass
-from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 from opentelemetry import trace
 from safetensors import safe_open
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, AutoConfig
@@ -13,6 +12,8 @@ from text_generation_server.models import Model
 from text_generation_server.models.flash_neox_modeling import (
     FlashGPTNeoXForCausalLM,
     TensorParallelEmbedding,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear
 )
 from text_generation_server.models.types import (
     Batch,
@@ -42,7 +43,7 @@ class FlashNeoXBatch(Batch):
     position_ids: torch.Tensor
     # cumulative sequence lengths
     cu_seqlens: torch.Tensor
-    max_seqlen: torch.Tensor
+    max_seqlen: int
     past_key_values: Optional[torch.Tensor]
 
     # All tokens
@@ -95,7 +96,6 @@ class FlashNeoXBatch(Batch):
         input_ids = torch.concat(input_ids).unsqueeze(1)
         position_ids = torch.concat(position_ids)
         cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
-        max_seqlen = torch.tensor(max_seqlen, dtype=torch.int32, device=device)
 
         return cls(
             batch_id=pb.id,
@@ -168,7 +168,7 @@ class FlashNeoX(Model):
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        max_s: torch.Tensor,
+        max_s: int,
         past_key_values: Optional = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
@@ -184,10 +184,6 @@ class FlashNeoX(Model):
     def generate_token(
         self, batch: FlashNeoXBatch
     ) -> Tuple[List[Generation], Optional[FlashNeoXBatch]]:
-        print("pos", batch.position_ids)
-        print("cu", batch.cu_seqlens)
-        print("max", batch.max_seqlen)
-
         out, present = self.forward(
             batch.input_ids.squeeze(1),
             batch.position_ids,
@@ -228,7 +224,7 @@ class FlashNeoX(Model):
             # Indexing metadata
             start_index = batch.cu_seqlens[i]
             end_index = batch.cu_seqlens[i + 1]
-            seq_length = end_index - start_index
+            seq_length = (end_index - start_index).item()
 
             if batch.past_key_values is None:
                 # Prefill mode
@@ -263,7 +259,7 @@ class FlashNeoX(Model):
 
             if stop:
                 # Decode generated tokens
-                output_text = self.decode(all_input_ids)
+                output_text = self.decode(all_input_ids[-stopping_criteria.current_tokens :])
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
                     seed = next_token_chooser.choice.seed
@@ -282,7 +278,7 @@ class FlashNeoX(Model):
                 generated_text = None
                 next_batch_keep_indices.append(i)
                 next_batch_input_ids.append(next_token_id)
-                next_batch_position_ids.append(new_input_length)
+                next_batch_position_ids.append(seq_length)
                 next_batch_cu_seqlens.append(
                     next_batch_cu_seqlens[i] + new_input_length
                 )
@@ -435,13 +431,13 @@ class FlashNeoXSharded(FlashNeoX):
 
                     slice_ = f.get_slice(name)
 
-                    if isinstance(module, ColumnParallelLinear):
+                    if isinstance(module, TensorParallelColumnLinear):
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
                         stop = (rank + 1) * block_size
                         tensor = slice_[start:stop]
-                    elif isinstance(module, RowParallelLinear):
+                    elif isinstance(module, TensorParallelRowLinear):
                         if param_name == "weight":
                             size = slice_.get_shape()[1]
                             block_size = size // world_size
@@ -491,7 +487,7 @@ class FlashNeoXSharded(FlashNeoX):
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        max_s: torch.Tensor,
+        max_s: int,
         past_key_values: Optional = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.model.gpt_neox.tp_embeddings:
diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py
index 06fab145..4cd2a452 100644
--- a/server/text_generation_server/models/flash_neox_modeling.py
+++ b/server/text_generation_server/models/flash_neox_modeling.py
@@ -1,38 +1,106 @@
 import torch
+import torch.distributed
+
+import torch.nn.functional as F
 
 from torch import nn
-
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.gpt_neox import GPTNeoXConfig
-from einops import rearrange
+
+import rotary_emb
+import flash_attn_cuda
+
 from flash_attn.flash_attn_interface import (
     flash_attn_unpadded_qkvpacked_func,
     flash_attn_unpadded_kvpacked_func,
 )
-from flash_attn.ops.fused_dense import (
-    FusedDense,
-    ColumnParallelLinear,
-    RowParallelLinear,
-    fused_mlp_func,
-)
+# from flash_attn.ops.fused_dense import (
+#     FusedDense,
+#     ColumnParallelLinear,
+#     RowParallelLinear,
+#     fused_mlp_func,
+# )
 from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_qkv_
-from flash_attn.ops.layer_norm import dropout_add_layer_norm
+
+
+# from flash_attn.ops.layer_norm import dropout_add_layer_norm
+
+
+class TensorParallelColumnLinear(nn.Linear):
+    def __init__(
+            self,
+            in_features,
+            out_features,
+            process_group: torch.distributed.ProcessGroup,
+            bias=True,
+            device=None,
+            dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert out_features % self.tp_world_size == 0
+        out_features = out_features // self.tp_world_size
+
+        super().__init__(in_features=in_features,
+                         out_features=out_features,
+                         bias=bias,
+                         device=device,
+                         dtype=dtype)
+
+    @staticmethod
+    def linear(input, weight, bias):
+        return F.linear(input, weight, bias)
+
+    def forward(self, input):
+        return self.linear(input, self.weight, self.bias)
+
+
+class TensorParallelRowLinear(nn.Linear):
+    def __init__(
+            self,
+            in_features,
+            out_features,
+            process_group: torch.distributed.ProcessGroup,
+            bias=True,
+            device=None,
+            dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert in_features % self.tp_world_size == 0
+        in_features = in_features // self.tp_world_size
+
+        super().__init__(in_features=in_features,
+                         out_features=out_features,
+                         bias=bias,
+                         device=device,
+                         dtype=dtype)
+
+    @staticmethod
+    def linear(input, weight, bias):
+        return F.linear(input, weight, bias)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = self.linear(input, self.weight, self.bias)
+        torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
 
 
 class TensorParallelEmbedding(nn.Embedding):
     def __init__(
-        self,
-        num_embeddings,
-        embedding_dim,
-        process_group: torch.distributed.ProcessGroup,
-        padding_idx=None,
-        max_norm=None,
-        norm_type=2.0,
-        scale_grad_by_freq=False,
-        sparse=False,
-        _weight=None,
-        device=None,
-        dtype=None,
+            self,
+            num_embeddings,
+            embedding_dim,
+            process_group: torch.distributed.ProcessGroup,
+            padding_idx=None,
+            max_norm=None,
+            norm_type=2.0,
+            scale_grad_by_freq=False,
+            sparse=False,
+            _weight=None,
+            device=None,
+            dtype=None
     ):
         self.process_group = process_group
         self.tp_rank = process_group.rank()
@@ -40,33 +108,22 @@ class TensorParallelEmbedding(nn.Embedding):
 
         self.original_num_embeddings = num_embeddings
 
+        # TODO @thomasw21 fix and remove that constraint
         assert num_embeddings % self.tp_world_size == 0
         block_size = num_embeddings // self.tp_world_size
         # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
         self.min_id = self.tp_rank * block_size
         self.max_id = (self.tp_rank + 1) * block_size
 
-        super().__init__(
-            block_size,
-            embedding_dim,
-            padding_idx=padding_idx,
-            max_norm=max_norm,
-            norm_type=norm_type,
-            scale_grad_by_freq=scale_grad_by_freq,
-            sparse=sparse,
-            _weight=_weight,
-            device=device,
-            dtype=dtype,
-        )
+        super().__init__(block_size, embedding_dim, padding_idx=padding_idx, max_norm=max_norm, norm_type=norm_type,
+                         scale_grad_by_freq=scale_grad_by_freq, sparse=sparse, _weight=_weight, device=device,
+                         dtype=dtype)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Sanity check
-        if torch.any(
-            torch.logical_or(0 > input, input >= self.original_num_embeddings)
-        ):
+        if torch.any(torch.logical_or(0 > input, input >= self.original_num_embeddings)):
             raise IndexError(
-                f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}"
-            )
+                f"Input is required to be in [0, {self.original_num_embeddings}[, got min: {torch.min(input)} and max: {torch.max(input)}")
 
         # `0` if input is in the correct interval, else `1`
         input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
@@ -81,20 +138,55 @@ class TensorParallelEmbedding(nn.Embedding):
 
 
 class PositionRotaryEmbedding(RotaryEmbedding):
-    def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor):
-        assert self.scale is None
+    def _update_cos_sin_cache(self, dtype, device, seqlen):
+        # Reset the tables if the sequence length has changed,
+        # or if we're on a new device (possibly due to tracing for instance)
+        if (seqlen > self._seq_len_cached or self._cos_cached.device != device
+                or self._cos_cached.dtype != dtype):
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+            # Don't do einsum, it converts fp32 to fp16
+            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(dtype)
+                self._sin_cached = torch.sin(freqs).to(dtype)
+            else:
+                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
+                          - seqlen // 2) / self.scale_base)
+                scale = self.scale.to(device=power.device) ** power.unsqueeze(1)
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
 
-        self._update_cos_sin_cache(qkv, position_ids.max() + 1)
+    def forward(self, qkv: torch.Tensor, position_ids: torch.Tensor, max_s: int):
+        self._update_cos_sin_cache(qkv.dtype, qkv.device, max_s)
 
-        cos = self._cos_cached[position_ids]
-        sin = self._sin_cached[position_ids]
+        q1, q2, k1, k2, cos, sin = _prepare_rotary(qkv, self._cos_cached, self._sin_cached, position_ids)
+        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+        return qkv
 
-        return apply_rotary_emb_qkv_(qkv, cos, sin, None, None)
+
+@torch.jit.script
+def _prepare_rotary(qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor):
+    cos = torch.index_select(cos, 0, position_ids)
+    sin = torch.index_select(sin, 0, position_ids)
+
+    rotary_dim = cos.shape[-1]
+    q1 = qkv[:, 0, :, :rotary_dim]
+    q2 = qkv[:, 0, :, rotary_dim:2*rotary_dim]
+    k1 = qkv[:, 1, :, :rotary_dim]
+    k2 = qkv[:, 1, :, rotary_dim: 2*rotary_dim]
+
+    return q1, q2, k1, k2, cos.unsqueeze(1), sin.unsqueeze(1)
 
 
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+            self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -106,129 +198,114 @@ class FlashNeoxAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)
 
         if process_group is None:
-            self.query_key_value = FusedDense(hidden_size, 3 * hidden_size)
-            self.dense = FusedDense(hidden_size, hidden_size)
+            self.query_key_value = nn.Linear(hidden_size, 3 * hidden_size)
+            self.dense = nn.Linear(hidden_size, hidden_size)
         else:
             self.num_heads = self.num_heads // process_group.size()
-            self.query_key_value = ColumnParallelLinear(
+            self.query_key_value = TensorParallelColumnLinear(
                 hidden_size,
                 3 * hidden_size,
                 process_group=process_group,
-                sequence_parallel=False,
             )
-            self.dense = RowParallelLinear(
+            self.dense = TensorParallelRowLinear(
                 hidden_size,
                 hidden_size,
                 process_group=process_group,
-                sequence_parallel=False,
             )
+        self.swap_dims = False
+
+    def _swap_dims(self):
+        self.query_key_value.weight = torch.nn.Parameter(
+            self.query_key_value.weight.view(self.num_heads, 3, self.head_size, self.hidden_size)
+            .permute(1, 0, 2, 3).reshape(-1, self.hidden_size)
+        )
+        self.query_key_value.bias = torch.nn.Parameter(
+            self.query_key_value.bias.view(self.num_heads, 3, self.head_size)
+            .permute(1, 0, 2).reshape(-1)
+        )
+        self.swap_dims = True
 
     def forward(
-        self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
+            self, hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
     ):
+        if not self.swap_dims:
+            self._swap_dims()
+
         qkv = self.query_key_value(hidden_states)
-        qkv = rearrange(
-            qkv, "... (h three d) -> ... h three d", three=3, d=self.head_size
-        ).permute(0, 2, 1, 3)
-        qkv_rot = self.rotary_emb(qkv.unsqueeze(0), position_ids).squeeze(0)
+        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
+        qkv_rot = self.rotary_emb(qkv, position_ids, max_s)
 
         if prefill:
             layer_past[...] = qkv_rot[:, 1:]
 
-            # test flash_attn_unpadded_qkvpacked_split_func
-            attn_output = flash_attn_unpadded_qkvpacked_func(
-                qkv_rot, cu_seqlens, max_s, 0.0, self.softmax_scale, causal=True
+            attn_output = torch.empty_like(qkv[:, 0])
+            flash_attn_cuda.fwd(
+                qkv[:, 0], qkv[:, 1], qkv[:, 2], attn_output, cu_seqlens, cu_seqlens, max_s, max_s, 0.0,
+                self.softmax_scale,
+                False, True, False, 0, None
             )
         else:
            query = qkv_rot[:, 0]
            layer_past[cu_seqlens[1:] - 1] = qkv_rot[:, 1:]
 
-            attn_output = flash_attn_unpadded_kvpacked_func(
-                query,
-                layer_past,
-                cu_seqlens_q=torch.arange(len(cu_seqlens), dtype=torch.int32).to(
+            attn_output = torch.empty_like(query)
+            flash_attn_cuda.fwd(
+                query, layer_past[:, 0], layer_past[:, 1], attn_output,
+                torch.arange(len(cu_seqlens), dtype=torch.int32).to(
                     query.device
-                ),
-                max_seqlen_q=torch.tensor(1, dtype=torch.int32).to(query.device),
-                cu_seqlens_k=cu_seqlens,
-                max_seqlen_k=max_s,
-                dropout_p=0.0,
-                softmax_scale=self.softmax_scale,
-                causal=False,
+                ), cu_seqlens, torch.tensor(1, dtype=torch.int32).to(query.device), max_s, 0.0,
+                self.softmax_scale,
+                False, False, False, 0, None
             )
 
-        return self.dense(rearrange(attn_output, "... h d -> ... (h d)"))
+        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
 
 
 class FlashMLP(nn.Module):
     def __init__(self, act, hidden_size, intermediate_size, process_group=None):
         super().__init__()
-        if "gelu" in act:
-            act = "gelu_approx"
-        assert act in ["gelu_approx", "relu"]
-        self.act = act
+        assert "gelu" in act
+        # if "gelu" in act:
+        #     act = "gelu_approx"
+        # assert act in ["gelu_approx", "relu"]
+        self.act = lambda x: F.gelu(x, approximate="tanh")
 
         if process_group is None:
-            self.dense_h_to_4h = FusedDense(hidden_size, intermediate_size)
-            self.dense_4h_to_h = FusedDense(intermediate_size, hidden_size)
+            self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size)
+            self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size)
         else:
-            self.dense_h_to_4h = ColumnParallelLinear(
+            self.dense_h_to_4h = TensorParallelColumnLinear(
                 hidden_size,
                 intermediate_size,
                 process_group=process_group,
-                sequence_parallel=False,
             )
-            self.dense_4h_to_h = RowParallelLinear(
+            self.dense_4h_to_h = TensorParallelRowLinear(
                 intermediate_size,
                 hidden_size,
                 process_group=process_group,
-                sequence_parallel=False,
             )
         self.heuristic = "auto"
         self.process_group = process_group
 
-    def forward(self, x):
-        if self.heuristic == "auto":
-            if self.act == "gelu_approx":
-                cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
-                self.heuristic = (
-                    0
-                    if cuda_ver >= (11, 8)
-                    else (1 if x.dtype == torch.float16 else -1)
-                )
-            else:
-                self.heuristic = 0
-
-        out = fused_mlp_func(
-            x,
-            self.dense_h_to_4h.weight,
-            self.dense_4h_to_h.weight,
-            self.dense_h_to_4h.bias,
-            self.dense_4h_to_h.bias,
-            activation=self.act,
-            save_pre_act=self.training,
-            checkpoint_lvl=0,
-            heuristic=self.heuristic,
-            process_group=self.process_group,
-            sequence_parallel=False,
-        )
-        if self.process_group is not None:
-            torch.distributed.all_reduce(out, group=self.process_group)
-        return out
+    def forward(self, hidden_states):
+        hidden_states = self.dense_h_to_4h(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.dense_4h_to_h(hidden_states)
+        return hidden_states
 
 
 class FlashNeoXLayer(nn.Module):
     def __init__(
-        self,
-        num_heads,
-        act,
-        hidden_size,
-        intermediate_size,
-        rotary_pct,
-        rotary_emb_base,
-        layer_norm_eps,
-        use_parallel_residual,
-        process_group=None,
+            self,
+            num_heads,
+            act,
+            hidden_size,
+            intermediate_size,
+            rotary_pct,
+            rotary_emb_base,
+            layer_norm_eps,
+            use_parallel_residual,
+            process_group=None,
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
@@ -240,46 +317,25 @@ class FlashNeoXLayer(nn.Module):
         self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
 
     def forward(
-        self,
-        hidden_states,
-        residual,
-        position_ids,
-        cu_seqlens,
-        max_s,
-        layer_past,
-        prefill,
+            self,
+            hidden_states,
+            residual,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            layer_past,
+            prefill,
     ):
         if self.use_parallel_residual:
-            ln1_hidden_states = dropout_add_layer_norm(
-                hidden_states,
-                residual,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                0.0,
-                self.input_layernorm.eps,
-                rowscale=None,
-                prenorm=False,
-                residual_in_fp32=False,
-            )
 
             attn_output = self.attention(
-                ln1_hidden_states, position_ids, cu_seqlens, max_s, layer_past, prefill
+                self.input_layernorm(hidden_states), position_ids, cu_seqlens, max_s, layer_past, prefill
             )
 
-            ln2_hidden_states = dropout_add_layer_norm(
-                hidden_states,
-                residual,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                0.0,
-                self.post_attention_layernorm.eps,
-                rowscale=None,
-                prenorm=False,
-                residual_in_fp32=False,
-            )
-            mlp_output = self.mlp(ln2_hidden_states)
+            mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
             return mlp_output + attn_output + hidden_states, None
         else:
+            raise NotImplementedError
             hidden_states, residual = dropout_add_layer_norm(
                 hidden_states,
                 residual,
@@ -365,12 +421,12 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.num_heads = self.layers[0].attention.num_heads
 
     def forward(
-        self,
-        input_ids,
-        position_ids,
-        cu_seqlens,
-        max_s,
-        past_key_values=None,
+            self,
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values=None,
     ):
         hidden_states = self.embed_in(input_ids)
 
@@ -399,17 +455,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 prefill,
             )
 
-        hidden_states = dropout_add_layer_norm(
-            hidden_states,
-            residual,
-            self.final_layer_norm.weight,
-            self.final_layer_norm.bias,
-            0.0,
-            self.final_layer_norm.eps,
-            rowscale=None,
-            prenorm=False,
-            residual_in_fp32=False,
-        )
+        hidden_states = self.final_layer_norm(hidden_states)
 
         return hidden_states, past_key_values
 
@@ -426,164 +472,25 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         self.gpt_neox = FlashGPTNeoXModel(config, process_group)
 
         if self.gpt_neox.tp_embeddings:
-            self.embed_out = FusedDense(
+            self.embed_out = nn.Linear(
                 config.hidden_size,
                 config.vocab_size // process_group.size(),
                 bias=False,
             )
         else:
-            self.embed_out = FusedDense(
+            self.embed_out = nn.Linear(
                 config.hidden_size, config.vocab_size, bias=False
             )
 
     def forward(
-        self,
-        input_ids,
-        position_ids,
-        cu_seqlens,
-        max_s,
-        past_key_values=None,
+            self,
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            max_s,
+            past_key_values=None,
     ):
         hidden_states, present = self.gpt_neox(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values
         )
         return self.embed_out(hidden_states), present
-
-
-if __name__ == "__main__":
-    from transformers import AutoTokenizer
-    from flash_attn.bert_padding import unpad_input
-
-    model = (
-        FlashGPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m")
-        .cuda()
-        .to(torch.half)
-    )
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        "EleutherAI/pythia-160m", padding_side="left"
-    )
-    tokenizer.pad_token = tokenizer.eos_token
-
-    tokenized_inputs = tokenizer(
-        ["What is this?\n\nA:\n\nThe answer to the problem?", "hello!"],
-        padding=True,
-        return_tensors="pt",
-    ).to("cuda")
-
-    input_ids, indices, cu_seqlens, max_seqlen = unpad_input(
-        tokenized_inputs["input_ids"].unsqueeze(-1), tokenized_inputs["attention_mask"]
-    )
-
-    position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
-    position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 0)
-
-    unpad_position_ids = torch.gather(position_ids.view(-1).cuda(), 0, indices)
-
-    gen_input_ids = input_ids.squeeze(1).cuda().clone()
-    gen_position_ids = unpad_position_ids.clone()
-    gen_indices = indices.clone()
-    gen_cu_seqlens = cu_seqlens.clone()
-    gen_max_seqlen = max_seqlen
-
-    past_key_values = None
-
-    results = []
-    with torch.no_grad():
-        out, present, _ = model(
-            gen_input_ids,
-            gen_position_ids,
-            gen_cu_seqlens,
-            gen_max_seqlen,
-            past_key_values=past_key_values,
-        )
-
-        futures = []
-        new_gen_cu_seqlens = [0]
-        new_position_ids = []
-        next_token_ids = []
-
-        for i in range(len(gen_cu_seqlens) - 1):
-            start_index = gen_cu_seqlens[i]
-            end_index = gen_cu_seqlens[i + 1]
-
-            seq_logits = out[start_index:end_index]
-            next_token_id = torch.argmax(seq_logits[-1:], dim=1)
-            next_token_ids.append(next_token_id)
-
-            sequence_length = end_index - start_index
-            new_gen_cu_seqlens.append(new_gen_cu_seqlens[i] + sequence_length + 1)
-
-            seq_position_ids = gen_position_ids[start_index:end_index]
-            new_position_ids.append(
-                torch.concat([seq_position_ids, seq_position_ids[-1:] + 1])
-            )
-
-            seq_present = present[:, start_index:end_index]
-            future = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
-
-            futures.append(future)
-
-        past_key_values = torch.concat(futures, dim=1)
-        new_position_ids = torch.concat(new_position_ids)
-        new_gen_cu_seqlens = torch.tensor(
-            new_gen_cu_seqlens, device=past_key_values.device, dtype=torch.int32
-        )
-        next_token_ids = torch.concat(next_token_ids)
-
-        gen_max_seqlen += 1
-
-        gen_input_ids = next_token_ids
-        gen_position_ids = new_position_ids
-        gen_cu_seqlens = new_gen_cu_seqlens
-
-        print(tokenizer.batch_decode(gen_input_ids))
-
-        for _ in range(40):
-            out, present, _ = model(
-                gen_input_ids,
-                gen_position_ids,
-                gen_cu_seqlens,
-                gen_max_seqlen,
-                past_key_values=past_key_values,
-            )
-
-            futures = []
-            new_gen_cu_seqlens = [0]
-            new_position_ids = []
-            next_token_ids = []
-            for i in range(len(gen_cu_seqlens) - 1):
-                start_index = gen_cu_seqlens[i]
-                end_index = gen_cu_seqlens[i + 1]
-
-                seq_logits = out[i]
-                next_token_id = torch.argmax(seq_logits.view(1, -1)[-1:], dim=1)
-                next_token_ids.append(next_token_id)
-
-                sequence_length = end_index - start_index
-                new_gen_cu_seqlens.append(new_gen_cu_seqlens[i] + sequence_length + 1)
-
-                seq_position_ids = gen_position_ids[start_index:end_index]
-                new_position_ids.append(
-                    torch.concat([seq_position_ids, seq_position_ids[-1:] + 1])
-                )
-
-                seq_present = present[:, start_index:end_index]
-                future = torch.nn.functional.pad(seq_present, (0, 0, 0, 0, 0, 0, 0, 1))
-
-                futures.append(future)
-
-            past_key_values = torch.concat(futures, dim=1)
-            new_position_ids = torch.concat(new_position_ids)
-            new_gen_cu_seqlens = torch.tensor(
-                new_gen_cu_seqlens, device=past_key_values.device, dtype=torch.int32
-            )
-            next_token_ids = torch.concat(next_token_ids)
-
-            gen_max_seqlen += 1
-
-            gen_input_ids = next_token_ids
-            gen_position_ids = new_position_ids
-            gen_cu_seqlens = new_gen_cu_seqlens
-
-            print(tokenizer.batch_decode(gen_input_ids))
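
A minimal sketch (not part of the patch above) of the column-parallel/row-parallel split that the new TensorParallelColumnLinear and TensorParallelRowLinear layers implement for the MLP. It is illustrative only: it runs in a single process, uses plain F.linear, and replaces the torch.distributed.all_reduce with an explicit sum over emulated ranks; the variable names are hypothetical and do not appear in the patch.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

hidden_size, intermediate_size, world_size = 8, 32, 4
x = torch.randn(3, hidden_size)

# Unsharded reference: h_to_4h, gelu, then 4h_to_h, as in FlashMLP (biases omitted).
w_h_to_4h = torch.randn(intermediate_size, hidden_size)
w_4h_to_h = torch.randn(hidden_size, intermediate_size)
reference = F.linear(F.gelu(F.linear(x, w_h_to_4h), approximate="tanh"), w_4h_to_h)

# Each emulated rank keeps a slice of the weights:
#   column parallel: a block of w_h_to_4h's output rows
#   row parallel:    the matching block of w_4h_to_h's input columns
block = intermediate_size // world_size
partial_outputs = []
for rank in range(world_size):
    w1_shard = w_h_to_4h[rank * block:(rank + 1) * block]     # (block, hidden_size)
    w2_shard = w_4h_to_h[:, rank * block:(rank + 1) * block]  # (hidden_size, block)
    intermediate = F.gelu(F.linear(x, w1_shard), approximate="tanh")
    partial_outputs.append(F.linear(intermediate, w2_shard))

# Summing the partial outputs plays the role of the all_reduce in TensorParallelRowLinear.
sharded = torch.stack(partial_outputs).sum(dim=0)
torch.testing.assert_close(sharded, reference)
print("sharded result matches the unsharded reference")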