Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 11:54:52 +00:00)

commit 3fc87f93bd ("fix")
parent bfd6928c3e
@@ -3,8 +3,6 @@ import torch.distributed
 
 import numpy as np
 
 from torch.nn import functional as F
 
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
@@ -206,6 +204,13 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
+        # Create tensors on device
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )
+        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
+        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
+
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
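The hunk above moves the device-tensor creation for all_input_ids_tensor, start_seq and end_seq ahead of the multi-request branch. A minimal sketch of the underlying pad-then-transfer pattern, with a hypothetical pad_input_ids helper and toy values (only the names all_input_ids, device and the torch.tensor(..., device=device) call mirror the diff):

    import numpy as np
    import torch

    def pad_input_ids(all_input_ids, device):
        # Build a zero-padded (batch, max_len) buffer on the host first ...
        max_length = max(len(ids) for ids in all_input_ids)
        padded = np.zeros((len(all_input_ids), max_length), dtype=np.int64)
        for i, input_ids in enumerate(all_input_ids):
            padded[i, : len(input_ids)] = input_ids
        # ... then create the device tensor in a single call, as the hunk does.
        return torch.tensor(padded, dtype=torch.int64, device=device)

    # Example: two requests of different prompt lengths.
    batch_ids = pad_input_ids([[1, 2, 3], [4, 5]], device="cpu")
    # tensor([[1, 2, 3],
    #         [4, 5, 0]])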
@@ -223,21 +228,15 @@ class FlashCausalLMBatch(Batch):
             start_seq_prefill = start_seq
             end_seq_prefill = end_seq
 
         # Create tensors on device
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
-        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
         past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
 
         if all_prefill_logprobs:
             prefill_head_indices = None
-            prefill_next_token_indices = end_seq - 1
+            prefill_next_token_indices = end_seq_prefill - 1
         elif no_prefill_logprobs:
-            prefill_head_indices = end_seq - 1
+            prefill_head_indices = end_seq_prefill - 1
             prefill_next_token_indices = None
         else:
             prefill_head_indices = torch.tensor(
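The index fix above swaps end_seq for end_seq_prefill when choosing which positions feed the prefill logprob/next-token selection. A small sketch of what end_seq_prefill - 1 selects, assuming end_seq_prefill holds the cumulative end offsets of each sequence in the flattened prefill token layout (the toy lengths are invented):

    import torch

    # Two prompts of lengths 3 and 2, flattened into one run of 5 token
    # positions during prefill: [p0, p1, p2, q0, q1].
    end_seq_prefill = torch.tensor([3, 5], dtype=torch.int32)

    # Last prompt position of each sequence in the flattened layout; the
    # logits at these positions predict each sequence's first new token.
    prefill_next_token_indices = end_seq_prefill - 1
    print(prefill_next_token_indices)  # tensor([2, 4], dtype=torch.int32)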
@@ -392,7 +391,6 @@ class FlashCausalLMBatch(Batch):
         requests_idx_mapping = {}
 
         total_batch_size = sum([len(b) for b in batches])
         total_tokens = sum(b.max_tokens for b in batches)
 
         dtype = batches[0].past_key_values.dtype
         device = batches[0].input_ids.device
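For orientation, a toy illustration of the two totals computed at the start of the concatenate path, using a hypothetical _ToyBatch stand-in (only len(b) and b.max_tokens mirror the FlashCausalLMBatch interface used above):

    class _ToyBatch:
        def __init__(self, n_requests, max_tokens):
            self.n_requests = n_requests
            self.max_tokens = max_tokens

        def __len__(self):
            return self.n_requests

    batches = [_ToyBatch(2, 60), _ToyBatch(1, 25)]
    total_batch_size = sum([len(b) for b in batches])  # 3 requests in the merged batch
    total_tokens = sum(b.max_tokens for b in batches)  # 85 token slots needed overall
    print(total_batch_size, total_tokens)  # 3 85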
@@ -605,9 +603,8 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         prefill_logprobs = batch.prefill_next_token_indices is not None
-        single_request = len(batch) == 1
 
-        if prefill and single_request:
+        if prefill:
            # Ask to pre-allocate kv to its max size
            # == Sum over batch size (number of tokens + max_new_tokens) - batch size
            pre_allocate_past_size = batch.max_tokens
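With the single_request guard dropped above, the KV cache is pre-allocated for every prefill batch rather than only single-request ones. A toy calculation of the size the in-code comment describes (sum over the batch of prompt tokens plus max_new_tokens, minus the batch size); the request sizes are invented:

    # Two hypothetical requests in a prefill batch.
    requests = [
        {"n_tokens": 5, "max_new_tokens": 20},
        {"n_tokens": 3, "max_new_tokens": 10},
    ]
    batch_size = len(requests)

    # "Sum over batch size (number of tokens + max_new_tokens) - batch size",
    # as stated in the comment above.
    pre_allocate_past_size = (
        sum(r["n_tokens"] + r["max_new_tokens"] for r in requests) - batch_size
    )
    print(pre_allocate_past_size)  # (5 + 20) + (3 + 10) - 2 = 36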