Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-10 03:44:54 +00:00)
fix

commit 3fc87f93bd (parent bfd6928c3e)
@@ -3,8 +3,6 @@ import torch.distributed
 
 import numpy as np
 
-from torch.nn import functional as F
-
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
@@ -206,6 +204,13 @@ class FlashCausalLMBatch(Batch):
         for i, input_ids in enumerate(all_input_ids):
             all_input_ids_tensor[i, : len(input_ids)] = input_ids
 
+        # Create tensors on device
+        all_input_ids_tensor = torch.tensor(
+            all_input_ids_tensor, dtype=torch.int64, device=device
+        )
+        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
+        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
+
         if len(pb.requests) > 1:
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
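
Note: a minimal sketch of the cumulative start_seq/end_seq bookkeeping that this hunk now builds on device, using assumed toy sequence lengths rather than the real protobuf batch; the helper name build_seq_bounds is hypothetical.

import torch

def build_seq_bounds(seq_lengths, device="cpu"):
    # For sequences flattened into one token stream, start_seq[i]:end_seq[i]
    # bounds the tokens belonging to sequence i.
    start_seq, end_seq, offset = [], [], 0
    for length in seq_lengths:
        start_seq.append(offset)
        offset += length
        end_seq.append(offset)
    return (
        torch.tensor(start_seq, device=device, dtype=torch.int32),
        torch.tensor(end_seq, device=device, dtype=torch.int32),
    )

# Assumed toy batch: three prompts of 4, 2 and 3 tokens.
start_seq, end_seq = build_seq_bounds([4, 2, 3])
print(start_seq.tolist(), end_seq.tolist())  # [0, 4, 6] [4, 6, 9]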
@@ -223,21 +228,15 @@ class FlashCausalLMBatch(Batch):
             start_seq_prefill = start_seq
             end_seq_prefill = end_seq
 
-        # Create tensors on device
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
-        all_input_ids_tensor = torch.tensor(
-            all_input_ids_tensor, dtype=torch.int64, device=device
-        )
         position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
-        start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
-        end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
         past_present_indices = torch.tensor(past_present_indices, device=device, dtype=torch.int64)
 
         if all_prefill_logprobs:
             prefill_head_indices = None
-            prefill_next_token_indices = end_seq - 1
+            prefill_next_token_indices = end_seq_prefill - 1
         elif no_prefill_logprobs:
-            prefill_head_indices = end_seq - 1
+            prefill_head_indices = end_seq_prefill - 1
             prefill_next_token_indices = None
         else:
             prefill_head_indices = torch.tensor(
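
Note: a minimal sketch of why this index fix matters, using assumed toy tensors: the logits that predict each sequence's first generated token sit at the last position of that sequence in the flattened prefill stream, so the indices must come from end_seq_prefill (prefill layout) rather than end_seq, whose values can differ from the prefill bounds once the tensors are built earlier in from_pb.

import torch

# Assumed toy prefill: two prompts of 3 and 2 tokens flattened into 5 rows of logits.
vocab_size = 8
logits = torch.randn(5, vocab_size)

# Bounds of each sequence in the flattened prefill stream (mirrors end_seq_prefill).
end_seq_prefill = torch.tensor([3, 5], dtype=torch.int32)

# Last prefill position of each sequence -> logits for its first new token.
prefill_next_token_indices = (end_seq_prefill - 1).to(torch.int64)
next_token_logits = logits[prefill_next_token_indices]
print(next_token_logits.shape)  # torch.Size([2, 8])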
@@ -392,7 +391,6 @@ class FlashCausalLMBatch(Batch):
         requests_idx_mapping = {}
 
         total_batch_size = sum([len(b) for b in batches])
-        total_tokens = sum(b.max_tokens for b in batches)
 
         dtype = batches[0].past_key_values.dtype
         device = batches[0].input_ids.device
@@ -605,9 +603,8 @@ class FlashCausalLM(Model):
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
         prefill = batch.past_key_values is None
         prefill_logprobs = batch.prefill_next_token_indices is not None
-        single_request = len(batch) == 1
 
-        if prefill and single_request:
+        if prefill:
             # Ask to pre-allocate kv to its max size
             # == Sum over batch size (number of tokens + max_new_tokens) - batch size
             pre_allocate_past_size = batch.max_tokens
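
Note: a minimal sketch of the pre-allocation size now requested on every prefill rather than only for single-request batches, following the comment in this hunk; the toy (prompt_length, max_new_tokens) pairs are assumptions.

# Assumed toy requests: (prompt_length, max_new_tokens) pairs.
requests = [(5, 10), (3, 20)]

# Mirrors the hunk's comment: pre-allocate kv to its max size
# == sum over the batch of (number of tokens + max_new_tokens) - batch size.
pre_allocate_past_size = sum(n + m for n, m in requests) - len(requests)
print(pre_allocate_past_size)  # 36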