From 3fc87f93bd1c7795edd222ebc08f92fa706e71bb Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Fri, 2 Jun 2023 18:17:18 +0200
Subject: [PATCH] fix

---
 .../models/flash_causal_lm.py | 23 ++++++++-----------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 1a8bf6fc..bdfa5051 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -3,8 +3,6 @@ import torch.distributed
 
 import numpy as np
 
-from torch.nn import functional as F
-
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
@@ -206,6 +204,13 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
+        # Create tensors on device
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )
+        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
+        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
+
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
@@ -223,21 +228,15 @@ class FlashCausalLMBatch(Batch):
             start_seq_prefill = start_seq
             end_seq_prefill = end_seq
 
-        # Create tensors on device
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
-        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
         past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
 
         if all_prefill_logprobs:
             prefill_head_indices = None
-            prefill_next_token_indices = end_seq - 1
+            prefill_next_token_indices = end_seq_prefill - 1
         elif no_prefill_logprobs:
-            prefill_head_indices = end_seq - 1
+            prefill_head_indices = end_seq_prefill - 1
             prefill_next_token_indices = None
         else:
             prefill_head_indices = torch.tensor(
@@ -392,7 +391,6 @@ class FlashCausalLMBatch(Batch):
         requests_idx_mapping = {}
 
         total_batch_size = sum([len(b) for b in batches])
-        total_tokens = sum(b.max_tokens for b in batches)
 
         dtype = batches[0].past_key_values.dtype
         device = batches[0].input_ids.device
@@ -605,9 +603,8 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         prefill_logprobs = batch.prefill_next_token_indices is not None
-        single_request = len(batch) == 1
 
-        if prefill and single_request:
+        if prefill:
             # Ask to pre-allocate kv to its max size
             # == Sum over batch size (number of tokens + max_new_tokens) - batch size
             pre_allocate_past_size = batch.max_tokens
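
A few notes on the change (all example values below are invented for
illustration, not taken from the patch):

On the first two hunks: moving the device-tensor creation for
all_input_ids_tensor, start_seq and end_seq above the
"if len(pb.requests) > 1:" branch is presumably what lets the single-request
fallback (start_seq_prefill = start_seq; end_seq_prefill = end_seq) bind to
CUDA tensors rather than the Python lists it previously aliased, so the
tensor arithmetic introduced below works in both branches. A minimal
illustration of the failure mode the move avoids:

    import torch

    end_seq_prefill = [5, 12]          # plain Python list, as before the move
    # end_seq_prefill - 1              # TypeError: lists don't support "- 1"

    end_seq_prefill = torch.tensor(end_seq_prefill, dtype=torch.int32)
    print(end_seq_prefill - 1)         # tensor([ 4, 11], dtype=torch.int32)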
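
On end_seq vs. end_seq_prefill: with the last hunk pre-allocating the KV
cache for every prefill batch instead of only single requests, end_seq
describes the pre-allocated layout, where each sequence reserves its input
length plus its generation budget, while the prefill logits stay packed one
row per input token. A minimal sketch of why the prefill indices must come
from end_seq_prefill, assuming that layout split and invented lengths:

    import torch

    input_lengths = [5, 7]       # prompt tokens run through the prefill pass
    max_new_tokens = [10, 20]    # per-request generation budget

    # cumulative ends over the pre-allocated KV slots
    # (length + budget - 1 per sequence, matching the kept comment's formula)
    end_seq = torch.tensor([14, 40], dtype=torch.int32)
    # cumulative ends over the packed prefill tokens only
    end_seq_prefill = torch.tensor([5, 12], dtype=torch.int32)

    # prefill emits 5 + 7 = 12 logit rows, so the last row of each request
    # is end_seq_prefill - 1; end_seq - 1 would point at rows 13 and 39,
    # which do not exist in the prefill logits tensor
    prefill_next_token_indices = end_seq_prefill - 1   # tensor([ 4, 11])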
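
On the pre-allocation size: the comment kept in the last hunk states that
pre_allocate_past_size = batch.max_tokens equals "Sum over batch size
(number of tokens + max_new_tokens) - batch size". Worked through for the
same invented batch as above:

    input_lengths = [5, 7]
    max_new_tokens = [10, 20]
    batch_size = 2

    pre_allocate_past_size = (
        sum(l + m for l, m in zip(input_lengths, max_new_tokens)) - batch_size
    )
    print(pre_allocate_past_size)   # (5 + 10) + (7 + 20) - 2 = 40 KV slots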