diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 6e072d4c..29cc9848 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1,6 +1,8 @@
 import torch
 import torch.distributed
 
+import numpy as np
+
 from torch.nn import functional as F
 
 from dataclasses import dataclass
@@ -33,12 +35,12 @@ class FlashCausalLMBatch(Batch):
     requests_idx_mapping: Dict[int, int]
 
     # Decoder values
-    input_ids: List[torch.Tensor]
-    position_ids: List[torch.Tensor]
+    input_ids: torch.Tensor
+    position_ids: torch.Tensor
     # cumulative sequence lengths
-    cu_seqlens: List[int]
+    cu_seqlens: torch.Tensor
     max_seqlen: int
-    past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]
+    past_key_values: Optional[torch.Tensor]
 
     # All tokens
     all_input_ids: List[List[int]]
@@ -53,9 +55,6 @@ class FlashCausalLMBatch(Batch):
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
 
-    # Constant shared tensor, ref here just so that it's accessible in concatentate()
-    past_pad: Optional[torch.Tensor]
-
    # Maximum number of tokens this batch will grow to
     max_tokens: int
 
@@ -69,12 +68,11 @@ class FlashCausalLMBatch(Batch):
 
     @classmethod
     def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        device: torch.device,
+            cls,
+            pb: generate_pb2.Batch,
+            tokenizer: PreTrainedTokenizerBase,
+            device: torch.device,
     ) -> "FlashCausalLMBatch":
-        input_ids = []
         position_ids = []
         cu_seqlens = [0]
         max_seqlen = 0
@@ -83,7 +81,6 @@ class FlashCausalLMBatch(Batch):
         offsets = []
         token_offsets = []
         all_input_ids = []
-        all_input_ids_tensor = []
 
         requests_idx_mapping = {}
         next_token_choosers = []
@@ -109,15 +106,11 @@ class FlashCausalLMBatch(Batch):
 
             offsets.append(None)
             token_offsets.append(None)
+            all_input_ids.append(tokenized_input)
 
-            tokenized_input = torch.tensor(tokenized_input, device=device)
-            input_ids.append(tokenized_input)
 
-            # Position ids
-            position_ids.append(
-                torch.arange(0, input_length, dtype=torch.int32, device=device)
-            )
+            position_ids.append(np.arange(0, input_length))
 
             # Add cumulative lengths of all previous inputs
             cu_seqlens.append(cumulative_length + input_length)
 
@@ -130,14 +123,16 @@ class FlashCausalLMBatch(Batch):
             max_new_tokens = stopping_criteria.max_new_tokens
             stopping_criterias.append(stopping_criteria)
 
-            all_input_ids_tensor.append(
-                F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
-            )
-
             # Update
             cumulative_length += input_length
             max_tokens += input_length + max_new_tokens
 
+        input_ids = torch.tensor(np.concatenate(all_input_ids), dtype=torch.int32, device=device)
+        position_ids = torch.tensor(np.concatenate(position_ids), dtype=torch.int32, device=device)
+        cu_seqlens = torch.tensor(
+            cu_seqlens, device=device, dtype=torch.int32
+        )
+
         return cls(
             batch_id=pb.id,
             requests=pb.requests,
@@ -151,10 +146,9 @@ class FlashCausalLMBatch(Batch):
             offsets=offsets,
             token_offsets=token_offsets,
             all_input_ids=all_input_ids,
-            all_input_ids_tensor=all_input_ids_tensor,
+            all_input_ids_tensor=[],
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
-            past_pad=None,
             max_tokens=max_tokens,
         )
 
@@ -224,7 +218,7 @@ class FlashCausalLMBatch(Batch):
 
             cumulative_length += request_input_length
             max_tokens += request_input_length + (
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+                    stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
 
         if single_request:
@@ -360,14 +354,13 @@ class FlashCausalLMBatch(Batch):
 
 class FlashCausalLM(Model):
     def __init__(
-        self,
-        model_cls: Type[PreTrainedModel],
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: bool = False,
-        decode_buffer: int = 3,
+            self,
+            model_cls: Type[PreTrainedModel],
+            model_id: str,
+            revision: Optional[str] = None,
+            quantize: bool = False,
+            decode_buffer: int = 3,
     ):
-        self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
@@ -406,13 +399,13 @@ class FlashCausalLM(Model):
         )
 
     def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        max_s: int,
-        past_key_values: Optional = None,
-        pre_allocate_past_size: Optional[int] = None,
+            self,
+            input_ids: torch.Tensor,
+            position_ids: torch.Tensor,
+            cu_seqlens: torch.Tensor,
+            max_s: int,
+            past_key_values: Optional = None,
+            pre_allocate_past_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Model Forward
         return self.model.forward(
@@ -426,42 +419,24 @@ class FlashCausalLM(Model):
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
-        self, batch: FlashCausalLMBatch
+            self, batch: FlashCausalLMBatch
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         # Shortcut when batch_size == 1
-        if len(batch) == 1:
-            input_ids = batch.input_ids[0].view(-1)
-        else:
-            # Concatenate tensors
-            if not isinstance(batch.input_ids, torch.Tensor):
-                input_ids = torch.cat(batch.input_ids).view(-1)
-            else:
-                input_ids = batch.input_ids.view(-1)
 
         # if prefill and bs == 1
         if batch.past_key_values is None and len(batch) == 1:
             # Ask to pre-allocate kv to its max size
             # == number of tokens + max_new_tokens
             pre_allocate_past_size = (
-                batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
+                    batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
             )
         else:
             pre_allocate_past_size = None
 
-        # Concatenate when prefill, torch.tensor when decode
-        if batch.past_key_values is None:
-            position_ids = torch.cat(batch.position_ids)
-        else:
-            position_ids = batch.position_ids
-
-        cu_seqlens = torch.tensor(
-            batch.cu_seqlens, device=self.device, dtype=torch.int32
-        )
-
         out, present = self.forward(
-            input_ids,
-            position_ids,
-            cu_seqlens,
+            batch.input_ids,
+            batch.position_ids,
+            batch.cu_seqlens,
             batch.max_seqlen,
             batch.past_key_values,
             pre_allocate_past_size,
@@ -483,61 +458,72 @@ class FlashCausalLM(Model):
 
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
-            batch.all_input_ids_tensor,
        )
 
-        next_input_ids = input_ids.new_empty(len(batch.requests))
         past_indices = []
 
+        prefill = batch.past_key_values is None
+
         # For each member of the batch
         for i, (
-            request,
-            input_length,
-            offset,
-            token_offset,
-            next_token_chooser,
-            stopping_criteria,
-            all_input_ids,
-            all_input_ids_tensor,
+                request,
+                input_length,
+                offset,
+                token_offset,
+                next_token_chooser,
+                stopping_criteria,
+                all_input_ids,
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
             end_index = cumulative_length + input_length
 
-            prefill = stopping_criteria.current_tokens == 0
             if prefill:
                 # Prefill mode
                 # out is of shape [cumulative_sequence_lengths, vocab_size]
                 logits = out[start_index:end_index]
+                batch.all_input_ids_tensor.append(
+                    F.pad(batch.input_ids[start_index:end_index], (0, stopping_criteria.max_new_tokens))
+                )
+                batch.position_ids[i] = input_length
             else:
                 # Decode mode
                 # out is of shape [batch_size, vocab_size]
                 logits = out[i].unsqueeze(0)
+
+            all_input_ids_tensor = batch.all_input_ids_tensor[i]
+
             # Select next token
             next_token_id, logprobs = next_token_chooser(
                 all_input_ids_tensor[None, :input_length], logits
             )
             next_token_id_squeezed = next_token_id.squeeze()
             all_input_ids_tensor[input_length] = next_token_id_squeezed
-            next_input_ids[i] = next_token_id_squeezed
             past_indices.extend([j for j in range(start_index + i, end_index + i)])
+            batch.input_ids[i] = next_token_id_squeezed
+
+
+        if prefill:
+            batch.input_ids = batch.input_ids[:len(batch)]
+            batch.position_ids = batch.position_ids[:len(batch)]
+        else:
+            batch.position_ids += 1
+
         # Initialize past_key_values in prefill
         if batch.past_key_values is None and len(batch) == 1:
             # present is already pre-padded
             batch.past_key_values = present
 
         if len(batch) > 1:
-            batch.past_key_values = present.new_empty((present.shape[0], present.shape[1] + len(batch.requests), *present.shape[2:]))
+            batch.past_key_values = present.new_empty(
+                (present.shape[0], present.shape[1] + len(batch.requests), *present.shape[2:]))
             batch.past_key_values[:, past_indices] = present
 
-        if prefill:
-            batch.position_ids = torch.tensor(batch.input_lengths, device=self.device)
-        else:
-            batch.position_ids = batch.position_ids + 1
+        batch.cu_seqlens = batch.cu_seqlens + torch.arange(0, len(batch) + 1, device=self.device, dtype=torch.int32)
 
-        next_token_ids = next_input_ids.tolist()
+        next_token_ids = batch.input_ids.to("cpu").detach()
 
         # Zipped iterator
         iterator = zip(
@@ -584,7 +570,7 @@ class FlashCausalLM(Model):
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :]
+                    all_input_ids[-stopping_criteria.current_tokens:]
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -599,7 +585,6 @@ class FlashCausalLM(Model):
                 stopped = False
                 generated_text = None
 
-            prefill = stopping_criteria.current_tokens == 0
             # # Prefill
             # if prefill:
             #     # Remove generated token to only have prefill and add nan for first prompt token
@@ -638,11 +623,6 @@ class FlashCausalLM(Model):
             batch.token_offsets[i] = token_offset
             batch.all_input_ids[i] = all_input_ids
             batch.max_seqlen = max(batch.max_seqlen, new_input_length)
-            # Cumulative sum
-            batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
-
-
-        batch.input_ids = next_input_ids
 
         # No need to return a batch if we know that all requests stopped
         return generations, batch if not stopped else None