mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 19:34:53 +00:00
wip
This commit is contained in:
parent
812de7ee50
commit
5d5a2de96c
@ -1,6 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -33,12 +35,12 @@ class FlashCausalLMBatch(Batch):
|
|||||||
requests_idx_mapping: Dict[int, int]
|
requests_idx_mapping: Dict[int, int]
|
||||||
|
|
||||||
# Decoder values
|
# Decoder values
|
||||||
input_ids: List[torch.Tensor]
|
input_ids: torch.Tensor
|
||||||
position_ids: List[torch.Tensor]
|
position_ids: torch.Tensor
|
||||||
# cumulative sequence lengths
|
# cumulative sequence lengths
|
||||||
cu_seqlens: List[int]
|
cu_seqlens: torch.Tensor
|
||||||
max_seqlen: int
|
max_seqlen: int
|
||||||
past_key_values: Optional[Union[torch.Tensor, List[torch.Tensor]]]
|
past_key_values: Optional[torch.Tensor]
|
||||||
|
|
||||||
# All tokens
|
# All tokens
|
||||||
all_input_ids: List[List[int]]
|
all_input_ids: List[List[int]]
|
||||||
@ -53,9 +55,6 @@ class FlashCausalLMBatch(Batch):
|
|||||||
next_token_choosers: List[NextTokenChooser]
|
next_token_choosers: List[NextTokenChooser]
|
||||||
stopping_criterias: List[StoppingCriteria]
|
stopping_criterias: List[StoppingCriteria]
|
||||||
|
|
||||||
# Constant shared tensor, ref here just so that it's accessible in concatentate()
|
|
||||||
past_pad: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
# Maximum number of tokens this batch will grow to
|
# Maximum number of tokens this batch will grow to
|
||||||
max_tokens: int
|
max_tokens: int
|
||||||
|
|
||||||
@ -69,12 +68,11 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pb(
|
def from_pb(
|
||||||
cls,
|
cls,
|
||||||
pb: generate_pb2.Batch,
|
pb: generate_pb2.Batch,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
) -> "FlashCausalLMBatch":
|
) -> "FlashCausalLMBatch":
|
||||||
input_ids = []
|
|
||||||
position_ids = []
|
position_ids = []
|
||||||
cu_seqlens = [0]
|
cu_seqlens = [0]
|
||||||
max_seqlen = 0
|
max_seqlen = 0
|
||||||
@ -83,7 +81,6 @@ class FlashCausalLMBatch(Batch):
|
|||||||
offsets = []
|
offsets = []
|
||||||
token_offsets = []
|
token_offsets = []
|
||||||
all_input_ids = []
|
all_input_ids = []
|
||||||
all_input_ids_tensor = []
|
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
next_token_choosers = []
|
next_token_choosers = []
|
||||||
@ -109,15 +106,11 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
offsets.append(None)
|
offsets.append(None)
|
||||||
token_offsets.append(None)
|
token_offsets.append(None)
|
||||||
|
|
||||||
all_input_ids.append(tokenized_input)
|
all_input_ids.append(tokenized_input)
|
||||||
|
|
||||||
tokenized_input = torch.tensor(tokenized_input, device=device)
|
|
||||||
input_ids.append(tokenized_input)
|
|
||||||
|
|
||||||
# Position ids
|
# Position ids
|
||||||
position_ids.append(
|
position_ids.append(np.arange(0, input_length))
|
||||||
torch.arange(0, input_length, dtype=torch.int32, device=device)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add cumulative lengths of all previous inputs
|
# Add cumulative lengths of all previous inputs
|
||||||
cu_seqlens.append(cumulative_length + input_length)
|
cu_seqlens.append(cumulative_length + input_length)
|
||||||
@ -130,14 +123,16 @@ class FlashCausalLMBatch(Batch):
|
|||||||
max_new_tokens = stopping_criteria.max_new_tokens
|
max_new_tokens = stopping_criteria.max_new_tokens
|
||||||
stopping_criterias.append(stopping_criteria)
|
stopping_criterias.append(stopping_criteria)
|
||||||
|
|
||||||
all_input_ids_tensor.append(
|
|
||||||
F.pad(tokenized_input, (0, stopping_criteria.max_new_tokens))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update
|
# Update
|
||||||
cumulative_length += input_length
|
cumulative_length += input_length
|
||||||
max_tokens += input_length + max_new_tokens
|
max_tokens += input_length + max_new_tokens
|
||||||
|
|
||||||
|
input_ids = torch.tensor(np.concatenate(all_input_ids), dtype=torch.int32, device=device)
|
||||||
|
position_ids = torch.tensor(np.concatenate(position_ids), dtype=torch.int32, device=device)
|
||||||
|
cu_seqlens = torch.tensor(
|
||||||
|
cu_seqlens, device=device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
batch_id=pb.id,
|
batch_id=pb.id,
|
||||||
requests=pb.requests,
|
requests=pb.requests,
|
||||||
@ -151,10 +146,9 @@ class FlashCausalLMBatch(Batch):
|
|||||||
offsets=offsets,
|
offsets=offsets,
|
||||||
token_offsets=token_offsets,
|
token_offsets=token_offsets,
|
||||||
all_input_ids=all_input_ids,
|
all_input_ids=all_input_ids,
|
||||||
all_input_ids_tensor=all_input_ids_tensor,
|
all_input_ids_tensor=[],
|
||||||
next_token_choosers=next_token_choosers,
|
next_token_choosers=next_token_choosers,
|
||||||
stopping_criterias=stopping_criterias,
|
stopping_criterias=stopping_criterias,
|
||||||
past_pad=None,
|
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -224,7 +218,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
cumulative_length += request_input_length
|
cumulative_length += request_input_length
|
||||||
max_tokens += request_input_length + (
|
max_tokens += request_input_length + (
|
||||||
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
if single_request:
|
if single_request:
|
||||||
@ -360,14 +354,13 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
class FlashCausalLM(Model):
|
class FlashCausalLM(Model):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_cls: Type[PreTrainedModel],
|
model_cls: Type[PreTrainedModel],
|
||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: bool = False,
|
quantize: bool = False,
|
||||||
decode_buffer: int = 3,
|
decode_buffer: int = 3,
|
||||||
):
|
):
|
||||||
self.past_pad = None
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||||
@ -406,13 +399,13 @@ class FlashCausalLM(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
position_ids: torch.Tensor,
|
position_ids: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
past_key_values: Optional = None,
|
past_key_values: Optional = None,
|
||||||
pre_allocate_past_size: Optional[int] = None,
|
pre_allocate_past_size: Optional[int] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
@ -426,42 +419,24 @@ class FlashCausalLM(Model):
|
|||||||
|
|
||||||
@tracer.start_as_current_span("generate_token")
|
@tracer.start_as_current_span("generate_token")
|
||||||
def generate_token(
|
def generate_token(
|
||||||
self, batch: FlashCausalLMBatch
|
self, batch: FlashCausalLMBatch
|
||||||
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
|
||||||
# Shortcut when batch_size == 1
|
# Shortcut when batch_size == 1
|
||||||
if len(batch) == 1:
|
|
||||||
input_ids = batch.input_ids[0].view(-1)
|
|
||||||
else:
|
|
||||||
# Concatenate tensors
|
|
||||||
if not isinstance(batch.input_ids, torch.Tensor):
|
|
||||||
input_ids = torch.cat(batch.input_ids).view(-1)
|
|
||||||
else:
|
|
||||||
input_ids = batch.input_ids.view(-1)
|
|
||||||
|
|
||||||
# if prefill and bs == 1
|
# if prefill and bs == 1
|
||||||
if batch.past_key_values is None and len(batch) == 1:
|
if batch.past_key_values is None and len(batch) == 1:
|
||||||
# Ask to pre-allocate kv to its max size
|
# Ask to pre-allocate kv to its max size
|
||||||
# == number of tokens + max_new_tokens
|
# == number of tokens + max_new_tokens
|
||||||
pre_allocate_past_size = (
|
pre_allocate_past_size = (
|
||||||
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
|
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pre_allocate_past_size = None
|
pre_allocate_past_size = None
|
||||||
|
|
||||||
# Concatenate when prefill, torch.tensor when decode
|
|
||||||
if batch.past_key_values is None:
|
|
||||||
position_ids = torch.cat(batch.position_ids)
|
|
||||||
else:
|
|
||||||
position_ids = batch.position_ids
|
|
||||||
|
|
||||||
cu_seqlens = torch.tensor(
|
|
||||||
batch.cu_seqlens, device=self.device, dtype=torch.int32
|
|
||||||
)
|
|
||||||
|
|
||||||
out, present = self.forward(
|
out, present = self.forward(
|
||||||
input_ids,
|
batch.input_ids,
|
||||||
position_ids,
|
batch.position_ids,
|
||||||
cu_seqlens,
|
batch.cu_seqlens,
|
||||||
batch.max_seqlen,
|
batch.max_seqlen,
|
||||||
batch.past_key_values,
|
batch.past_key_values,
|
||||||
pre_allocate_past_size,
|
pre_allocate_past_size,
|
||||||
@ -483,61 +458,72 @@ class FlashCausalLM(Model):
|
|||||||
batch.next_token_choosers,
|
batch.next_token_choosers,
|
||||||
batch.stopping_criterias,
|
batch.stopping_criterias,
|
||||||
batch.all_input_ids,
|
batch.all_input_ids,
|
||||||
batch.all_input_ids_tensor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
next_input_ids = input_ids.new_empty(len(batch.requests))
|
|
||||||
past_indices = []
|
past_indices = []
|
||||||
|
|
||||||
|
prefill = batch.past_key_values is None
|
||||||
|
|
||||||
# For each member of the batch
|
# For each member of the batch
|
||||||
for i, (
|
for i, (
|
||||||
request,
|
request,
|
||||||
input_length,
|
input_length,
|
||||||
offset,
|
offset,
|
||||||
token_offset,
|
token_offset,
|
||||||
next_token_chooser,
|
next_token_chooser,
|
||||||
stopping_criteria,
|
stopping_criteria,
|
||||||
all_input_ids,
|
all_input_ids,
|
||||||
all_input_ids_tensor,
|
|
||||||
) in enumerate(iterator):
|
) in enumerate(iterator):
|
||||||
# Indexing metadata
|
# Indexing metadata
|
||||||
start_index = cumulative_length
|
start_index = cumulative_length
|
||||||
end_index = cumulative_length + input_length
|
end_index = cumulative_length + input_length
|
||||||
|
|
||||||
prefill = stopping_criteria.current_tokens == 0
|
|
||||||
if prefill:
|
if prefill:
|
||||||
# Prefill mode
|
# Prefill mode
|
||||||
# out is of shape [cumulative_sequence_lengths, vocab_size]
|
# out is of shape [cumulative_sequence_lengths, vocab_size]
|
||||||
logits = out[start_index:end_index]
|
logits = out[start_index:end_index]
|
||||||
|
batch.all_input_ids_tensor.append(
|
||||||
|
F.pad(batch.input_ids[start_index:end_index], (0, stopping_criteria.max_new_tokens))
|
||||||
|
)
|
||||||
|
batch.position_ids[i] = input_length
|
||||||
else:
|
else:
|
||||||
# Decode mode
|
# Decode mode
|
||||||
# out is of shape [batch_size, vocab_size]
|
# out is of shape [batch_size, vocab_size]
|
||||||
logits = out[i].unsqueeze(0)
|
logits = out[i].unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
all_input_ids_tensor = batch.all_input_ids_tensor[i]
|
||||||
|
|
||||||
# Select next token
|
# Select next token
|
||||||
next_token_id, logprobs = next_token_chooser(
|
next_token_id, logprobs = next_token_chooser(
|
||||||
all_input_ids_tensor[None, :input_length], logits
|
all_input_ids_tensor[None, :input_length], logits
|
||||||
)
|
)
|
||||||
next_token_id_squeezed = next_token_id.squeeze()
|
next_token_id_squeezed = next_token_id.squeeze()
|
||||||
all_input_ids_tensor[input_length] = next_token_id_squeezed
|
all_input_ids_tensor[input_length] = next_token_id_squeezed
|
||||||
next_input_ids[i] = next_token_id_squeezed
|
|
||||||
past_indices.extend([j for j in range(start_index + i, end_index + i)])
|
past_indices.extend([j for j in range(start_index + i, end_index + i)])
|
||||||
|
|
||||||
|
batch.input_ids[i] = next_token_id_squeezed
|
||||||
|
|
||||||
|
|
||||||
|
if prefill:
|
||||||
|
batch.input_ids = batch.input_ids[:len(batch)]
|
||||||
|
batch.position_ids = batch.position_ids[:len(batch)]
|
||||||
|
else:
|
||||||
|
batch.position_ids += 1
|
||||||
|
|
||||||
# Initialize past_key_values in prefill
|
# Initialize past_key_values in prefill
|
||||||
if batch.past_key_values is None and len(batch) == 1:
|
if batch.past_key_values is None and len(batch) == 1:
|
||||||
# present is already pre-padded
|
# present is already pre-padded
|
||||||
batch.past_key_values = present
|
batch.past_key_values = present
|
||||||
|
|
||||||
if len(batch) > 1:
|
if len(batch) > 1:
|
||||||
batch.past_key_values = present.new_empty((present.shape[0], present.shape[1] + len(batch.requests), *present.shape[2:]))
|
batch.past_key_values = present.new_empty(
|
||||||
|
(present.shape[0], present.shape[1] + len(batch.requests), *present.shape[2:]))
|
||||||
batch.past_key_values[:, past_indices] = present
|
batch.past_key_values[:, past_indices] = present
|
||||||
|
|
||||||
if prefill:
|
batch.cu_seqlens = batch.cu_seqlens + torch.arange(0, len(batch) + 1, device=self.device, dtype=torch.int32)
|
||||||
batch.position_ids = torch.tensor(batch.input_lengths, device=self.device)
|
|
||||||
else:
|
|
||||||
batch.position_ids = batch.position_ids + 1
|
|
||||||
|
|
||||||
next_token_ids = next_input_ids.tolist()
|
next_token_ids = batch.input_ids.to("cpu").detach()
|
||||||
|
|
||||||
# Zipped iterator
|
# Zipped iterator
|
||||||
iterator = zip(
|
iterator = zip(
|
||||||
@ -584,7 +570,7 @@ class FlashCausalLM(Model):
|
|||||||
if stop:
|
if stop:
|
||||||
# Decode generated tokens
|
# Decode generated tokens
|
||||||
output_text = self.decode(
|
output_text = self.decode(
|
||||||
all_input_ids[-stopping_criteria.current_tokens :]
|
all_input_ids[-stopping_criteria.current_tokens:]
|
||||||
)
|
)
|
||||||
# Get seed
|
# Get seed
|
||||||
if isinstance(next_token_chooser.choice, Sampling):
|
if isinstance(next_token_chooser.choice, Sampling):
|
||||||
@ -599,7 +585,6 @@ class FlashCausalLM(Model):
|
|||||||
stopped = False
|
stopped = False
|
||||||
generated_text = None
|
generated_text = None
|
||||||
|
|
||||||
prefill = stopping_criteria.current_tokens == 0
|
|
||||||
# # Prefill
|
# # Prefill
|
||||||
# if prefill:
|
# if prefill:
|
||||||
# # Remove generated token to only have prefill and add nan for first prompt token
|
# # Remove generated token to only have prefill and add nan for first prompt token
|
||||||
@ -638,11 +623,6 @@ class FlashCausalLM(Model):
|
|||||||
batch.token_offsets[i] = token_offset
|
batch.token_offsets[i] = token_offset
|
||||||
batch.all_input_ids[i] = all_input_ids
|
batch.all_input_ids[i] = all_input_ids
|
||||||
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
|
batch.max_seqlen = max(batch.max_seqlen, new_input_length)
|
||||||
# Cumulative sum
|
|
||||||
batch.cu_seqlens[(i + 1)] = batch.cu_seqlens[i] + new_input_length
|
|
||||||
|
|
||||||
|
|
||||||
batch.input_ids = next_input_ids
|
|
||||||
|
|
||||||
# No need to return a batch if we know that all requests stopped
|
# No need to return a batch if we know that all requests stopped
|
||||||
return generations, batch if not stopped else None
|
return generations, batch if not stopped else None
|
||||||
|
Loading…
Reference in New Issue
Block a user