OlivierDehaene 2023-06-02 18:17:18 +02:00
parent bfd6928c3e
commit 3fc87f93bd

@@ -3,8 +3,6 @@ import torch.distributed
import numpy as np
from torch.nn import functional as F
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
@@ -206,6 +204,13 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
+        # Create tensors on device
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )
+        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
+        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
+
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
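
Note: the hunk above moves the device-side creation of all_input_ids_tensor, start_seq, and end_seq up to right after the padding loop. Below is a minimal standalone sketch of the same pattern (pad ragged token-id lists into a NumPy buffer, then transfer with a single torch.tensor call); the function name pad_and_transfer, its arguments, and the example ids are illustrative and not part of this file.

    import numpy as np
    import torch

    def pad_and_transfer(sequences, device="cpu"):
        # Host-side zero-padded buffer sized to the longest sequence.
        max_len = max(len(seq) for seq in sequences)
        buffer = np.zeros((len(sequences), max_len), dtype=np.int64)
        for i, seq in enumerate(sequences):
            buffer[i, : len(seq)] = seq
        # One host-to-device copy instead of one copy per sequence.
        return torch.tensor(buffer, dtype=torch.int64, device=device)

    ids = pad_and_transfer([[101, 7592, 102], [101, 102]])
    print(ids.shape)  # torch.Size([2, 3])
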
@@ -223,21 +228,15 @@ class FlashCausalLMBatch(Batch):
             start_seq_prefill = start_seq
             end_seq_prefill = end_seq
 
-        # Create tensors on device
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
-        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
         past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
 
         if all_prefill_logprobs:
             prefill_head_indices = None
-            prefill_next_token_indices = end_seq - 1
+            prefill_next_token_indices = end_seq_prefill - 1
         elif no_prefill_logprobs:
-            prefill_head_indices = end_seq - 1
+            prefill_head_indices = end_seq_prefill - 1
             prefill_next_token_indices = None
         else:
             prefill_head_indices = torch.tensor(
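
Note: the hunk above switches prefill_next_token_indices and prefill_head_indices from end_seq to end_seq_prefill, i.e. what appear to be the per-sequence end offsets of the flattened prefill batch. A small sketch of why `end - 1` addresses each packed sequence's last token; the lengths and hidden size below are made up for illustration.

    import torch

    # Three prompts packed into one flat token dimension.
    lengths = torch.tensor([3, 5, 2])
    end_seq_prefill = torch.cumsum(lengths, dim=0)  # tensor([ 3,  8, 10])
    last_token_indices = end_seq_prefill - 1        # tensor([2, 7, 9])

    hidden = torch.randn(int(lengths.sum()), 4)     # [total_tokens, hidden_size]
    last_hidden = hidden[last_token_indices]        # one row per sequence
    print(last_hidden.shape)                        # torch.Size([3, 4])
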
@@ -392,7 +391,6 @@ class FlashCausalLMBatch(Batch):
        requests_idx_mapping = {}
        total_batch_size = sum([len(b) for b in batches])
        total_tokens = sum(b.max_tokens for b in batches)
        dtype = batches[0].past_key_values.dtype
        device = batches[0].input_ids.device
@@ -605,9 +603,8 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         prefill_logprobs = batch.prefill_next_token_indices is not None
-        single_request = len(batch) == 1
 
-        if prefill and single_request:
+        if prefill:
             # Ask to pre-allocate kv to its max size
             # == Sum over batch size (number of tokens + max_new_tokens) - batch size
             pre_allocate_past_size = batch.max_tokens
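
Note: with the change above, KV-cache pre-allocation now runs for every prefill batch instead of only single-request batches. A small worked example of the size formula restated in the comment, using made-up request sizes; in the diff the value itself is taken from batch.max_tokens.

    # (prompt_tokens, max_new_tokens) per request -- illustrative values only.
    requests = [(5, 10), (3, 20)]
    batch_size = len(requests)

    # Sum over the batch of (number of tokens + max_new_tokens) - batch size,
    # as described in the comment above.
    pre_allocate_past_size = sum(n_in + n_new for n_in, n_new in requests) - batch_size
    print(pre_allocate_past_size)  # (5 + 10) + (3 + 20) - 2 = 36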