mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
wip
This commit is contained in:
parent
e36dfaa8de
commit
7169cbae6d
@ -149,11 +149,26 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
max_seqlen: int
|
max_seqlen: int
|
||||||
|
|
||||||
# Prefill metadata tensors to efficiently compute logprobs
|
# Prefill metadata tensors
|
||||||
prefill_head_indices: Optional[torch.Tensor]
|
prefill_head_indices: Optional[torch.Tensor]
|
||||||
prefill_next_token_indices: Optional[torch.tensor]
|
prefill_next_token_indices: Optional[torch.Tensor]
|
||||||
prefill_cu_outlens: Optional[List[int]]
|
prefill_cu_outlens: Optional[List[int]]
|
||||||
|
|
||||||
|
# Whether at least one request is prefilling/chunking
|
||||||
|
# == any(prefilling_mask)
|
||||||
|
prefilling: bool
|
||||||
|
# For each request, whether they are still prefilling/chunking
|
||||||
|
prefilling_mask: List[bool]
|
||||||
|
# For each request, whether the model output should be used or discarded
|
||||||
|
# If we are chunking, we don't care about the output as it might be different
|
||||||
|
# from the token in the prompt
|
||||||
|
use_output_token: List[bool]
|
||||||
|
|
||||||
|
# If the request is decoding, `next_chunk_length = 1`
|
||||||
|
# `None if not batch.prefilling`
|
||||||
|
next_chunk_lengths: Optional[List[int]]
|
||||||
|
next_chunk_lengths_tensor: Optional[torch.Tensor]
|
||||||
|
|
||||||
# Prefixes
|
# Prefixes
|
||||||
prefix_ids: List[List[int]]
|
prefix_ids: List[List[int]]
|
||||||
|
|
||||||
@ -232,11 +247,14 @@ class FlashCausalLMBatch(Batch):
|
|||||||
prefix_ids = []
|
prefix_ids = []
|
||||||
requests_idx_mapping = {}
|
requests_idx_mapping = {}
|
||||||
|
|
||||||
|
chunking = False
|
||||||
all_prefill_logprobs = True
|
all_prefill_logprobs = True
|
||||||
no_prefill_logprobs = True
|
no_prefill_logprobs = True
|
||||||
prefill_head_indices = []
|
prefill_head_indices = []
|
||||||
prefill_next_token_indices = []
|
prefill_next_token_indices = []
|
||||||
prefill_cu_outlens = [0]
|
prefill_cu_outlens = [0]
|
||||||
|
next_chunk_lengths = []
|
||||||
|
use_output_token = []
|
||||||
|
|
||||||
next_token_chooser_parameters = []
|
next_token_chooser_parameters = []
|
||||||
stopping_criterias = []
|
stopping_criterias = []
|
||||||
@ -276,6 +294,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
assert prefix_len > 0
|
assert prefix_len > 0
|
||||||
prefix_len -= 1
|
prefix_len -= 1
|
||||||
|
|
||||||
|
|
||||||
# Commented as it's costly.
|
# Commented as it's costly.
|
||||||
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
|
# log_master(logger.debug, "Tokenized input ids {tokenized_input}")
|
||||||
prefix_ids.append(tokenized_input[:prefix_len])
|
prefix_ids.append(tokenized_input[:prefix_len])
|
||||||
@ -284,9 +303,18 @@ class FlashCausalLMBatch(Batch):
|
|||||||
input_length = len(tokenized_input)
|
input_length = len(tokenized_input)
|
||||||
input_lengths.append(input_length)
|
input_lengths.append(input_length)
|
||||||
|
|
||||||
|
if True:
|
||||||
|
# This request only requires one prefill and no chunking
|
||||||
|
use_output_token.append(True)
|
||||||
|
next_chunk_lengths.append(1)
|
||||||
|
else:
|
||||||
|
chunking = True
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
prefix_offsets.append(input_length - 5)
|
prefix_offsets.append(input_length - 5)
|
||||||
read_offsets.append(input_length)
|
read_offsets.append(input_length)
|
||||||
|
|
||||||
|
# FIXME: use all input tokens not just postfix ones
|
||||||
all_input_ids.append(tokenized_input)
|
all_input_ids.append(tokenized_input)
|
||||||
|
|
||||||
# Position ids
|
# Position ids
|
||||||
@ -357,6 +385,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
|
|
||||||
# Create tensor to slice into the kv tensor in prefill
|
# Create tensor to slice into the kv tensor in prefill
|
||||||
if sliding_window is not None:
|
if sliding_window is not None:
|
||||||
|
raise NotImplementedError
|
||||||
request_prefill_cache_indices = torch.arange(
|
request_prefill_cache_indices = torch.arange(
|
||||||
cumulative_length + max(0, input_length - sliding_window),
|
cumulative_length + max(0, input_length - sliding_window),
|
||||||
cumulative_length + input_length,
|
cumulative_length + input_length,
|
||||||
@ -368,6 +397,7 @@ class FlashCausalLMBatch(Batch):
|
|||||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||||
|
|
||||||
if r.prefill_logprobs:
|
if r.prefill_logprobs:
|
||||||
|
raise NotImplementedError
|
||||||
prefill_head_indices.append(request_position_ids + cumulative_length)
|
prefill_head_indices.append(request_position_ids + cumulative_length)
|
||||||
prefill_next_token_indices.append(
|
prefill_next_token_indices.append(
|
||||||
prefill_out_cumulative_length + input_length - 1
|
prefill_out_cumulative_length + input_length - 1
|
||||||
@ -445,6 +475,12 @@ class FlashCausalLMBatch(Batch):
|
|||||||
adapter_segments, dtype=torch.int32, device=device
|
adapter_segments, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if chunking:
|
||||||
|
next_chunk_lengths_tensor = torch.tensor(next_chunk_lengths, dtype=torch.int64, device=device)
|
||||||
|
else:
|
||||||
|
next_chunk_lengths = None
|
||||||
|
next_chunk_lengths_tensor = None
|
||||||
|
|
||||||
if all_prefill_logprobs:
|
if all_prefill_logprobs:
|
||||||
prefill_head_indices = None
|
prefill_head_indices = None
|
||||||
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
||||||
@ -491,6 +527,11 @@ class FlashCausalLMBatch(Batch):
|
|||||||
prefill_head_indices=prefill_head_indices,
|
prefill_head_indices=prefill_head_indices,
|
||||||
prefill_next_token_indices=prefill_next_token_indices,
|
prefill_next_token_indices=prefill_next_token_indices,
|
||||||
prefill_cu_outlens=prefill_cu_outlens,
|
prefill_cu_outlens=prefill_cu_outlens,
|
||||||
|
prefilling=True,
|
||||||
|
prefilling_mask=[True] * pb.requests.len(),
|
||||||
|
use_output_token=use_output_token,
|
||||||
|
next_chunk_lengths=next_chunk_lengths,
|
||||||
|
next_chunk_lengths_tensor=next_chunk_lengths_tensor,
|
||||||
input_lengths=input_lengths,
|
input_lengths=input_lengths,
|
||||||
input_lengths_tensor=input_lengths_tensor,
|
input_lengths_tensor=input_lengths_tensor,
|
||||||
prefix_offsets=prefix_offsets,
|
prefix_offsets=prefix_offsets,
|
||||||
@ -1426,7 +1467,7 @@ class FlashCausalLM(Model):
|
|||||||
max_s = batch.max_seqlen
|
max_s = batch.max_seqlen
|
||||||
lm_head_indices = batch.prefill_head_indices
|
lm_head_indices = batch.prefill_head_indices
|
||||||
|
|
||||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
if not batch.prefilling and self.max_past() is not None:
|
||||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||||
# in a circular buffer mode.
|
# in a circular buffer mode.
|
||||||
# This makes sure the max_s for the decode pass is correct.
|
# This makes sure the max_s for the decode pass is correct.
|
||||||
@ -1440,7 +1481,7 @@ class FlashCausalLM(Model):
|
|||||||
else:
|
else:
|
||||||
cuda_graph = None
|
cuda_graph = None
|
||||||
|
|
||||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
if batch.prefilling or cuda_graph is None:
|
||||||
if ATTENTION == "flashinfer":
|
if ATTENTION == "flashinfer":
|
||||||
block_tables = block_tables_to_ragged(
|
block_tables = block_tables_to_ragged(
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
@ -1475,6 +1516,7 @@ class FlashCausalLM(Model):
|
|||||||
adapter_data=adapter_data,
|
adapter_data=adapter_data,
|
||||||
)
|
)
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
|
raise NotImplementedError
|
||||||
batch.prefill_cache_indices = None
|
batch.prefill_cache_indices = None
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
||||||
@ -1528,7 +1570,6 @@ class FlashCausalLM(Model):
|
|||||||
self, batch: FlashCausalLMBatch
|
self, batch: FlashCausalLMBatch
|
||||||
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
|
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
|
||||||
start = time.time_ns()
|
start = time.time_ns()
|
||||||
prefill = batch.cu_seqlen_prefill is not None
|
|
||||||
prefill_logprobs = batch.prefill_next_token_indices is not None
|
prefill_logprobs = batch.prefill_next_token_indices is not None
|
||||||
|
|
||||||
# Update adapter indices for speculative tokens (if present)
|
# Update adapter indices for speculative tokens (if present)
|
||||||
@ -1554,13 +1595,13 @@ class FlashCausalLM(Model):
|
|||||||
adapter_data = AdapterBatchData.from_meta(
|
adapter_data = AdapterBatchData.from_meta(
|
||||||
adapter_meta,
|
adapter_meta,
|
||||||
self.layer_to_adapter_weights,
|
self.layer_to_adapter_weights,
|
||||||
prefill,
|
batch.prefilling,
|
||||||
batch.prefill_head_indices,
|
batch.prefill_head_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
out, speculative_logits = self.forward(batch, adapter_data)
|
out, speculative_logits = self.forward(batch, adapter_data)
|
||||||
|
|
||||||
if prefill:
|
if batch.prefilling:
|
||||||
next_token_logits = (
|
next_token_logits = (
|
||||||
out[batch.prefill_next_token_indices] if prefill_logprobs else out
|
out[batch.prefill_next_token_indices] if prefill_logprobs else out
|
||||||
)
|
)
|
||||||
@ -1597,22 +1638,31 @@ class FlashCausalLM(Model):
|
|||||||
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
|
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if prefill:
|
if batch.prefilling:
|
||||||
if len(batch) > 1 and prefill_logprobs:
|
if len(batch) > 1 and prefill_logprobs:
|
||||||
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
|
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
|
||||||
# When batch == 1, we will just use the batch.input_ids values directly
|
# When batch == 1, we will just use the batch.input_ids values directly
|
||||||
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
|
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
|
||||||
|
|
||||||
next_position_ids = batch.position_ids.new_empty(len(batch))
|
if batch.next_chunk_lengths is None:
|
||||||
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
|
# We are done prefilling after this forward
|
||||||
# We do not need cu_seqlen_prefill anymore
|
next_position_ids = batch.position_ids.new_empty(len(batch))
|
||||||
batch.cu_seqlen_prefill = None
|
# [BATCH_SIZE]
|
||||||
|
# Last slot for each request, will be incremented later
|
||||||
|
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
|
||||||
|
else:
|
||||||
|
# We still have prefill chunks to go through
|
||||||
|
next_forward_size = sum(batch.next_chunk_lengths)
|
||||||
|
next_position_ids = batch.position_ids.new_empty(next_forward_size)
|
||||||
|
batch.slot_indices = batch.slot_indices.new_empty(next_forward_size)
|
||||||
|
batch.cu_seqlen_prefill[1:] = torch.cumsum(batch.next_chunk_lengths_tensor, dim=0)
|
||||||
else:
|
else:
|
||||||
prefill_logprobs = None
|
prefill_logprobs = None
|
||||||
next_position_ids = batch.position_ids
|
next_position_ids = batch.position_ids
|
||||||
|
|
||||||
# Cumulative length
|
# Cumulative length
|
||||||
cumulative_length = 0
|
cumulative_length = 0
|
||||||
|
cumulative_chunk_lengths = 0
|
||||||
|
|
||||||
# Results
|
# Results
|
||||||
generations: List[Generation] = []
|
generations: List[Generation] = []
|
||||||
@ -1625,21 +1675,32 @@ class FlashCausalLM(Model):
|
|||||||
# one, we need to first do a GPU <-> CPU sync
|
# one, we need to first do a GPU <-> CPU sync
|
||||||
# It is faster if we delay this sync for the maximum amount of time
|
# It is faster if we delay this sync for the maximum amount of time
|
||||||
|
|
||||||
# For each member of the batch
|
|
||||||
index = 0
|
index = 0
|
||||||
|
# For each member of the batch
|
||||||
for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
|
for i, (input_length, all_input_ids, n_accepted_ids) in enumerate(iterator):
|
||||||
# Indexing metadata
|
# Indexing metadata
|
||||||
start_index = cumulative_length
|
start_index = cumulative_length
|
||||||
end_index = cumulative_length + input_length
|
end_index = cumulative_length + input_length
|
||||||
|
|
||||||
if prefill:
|
if batch.prefilling:
|
||||||
|
if batch.next_chunk_lengths is not None:
|
||||||
|
next_chunk_length = batch.next_chunk_lengths[i]
|
||||||
|
else:
|
||||||
|
next_chunk_length = 1
|
||||||
|
|
||||||
# Indexing metadata
|
# Indexing metadata
|
||||||
out_start_index = batch.prefill_cu_outlens[i]
|
out_start_index = batch.prefill_cu_outlens[i]
|
||||||
out_end_index = batch.prefill_cu_outlens[i + 1]
|
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||||
out_length = out_end_index - out_start_index
|
out_length = out_end_index - out_start_index
|
||||||
|
|
||||||
|
position_start_index =
|
||||||
|
|
||||||
# Initialize position_ids
|
# Initialize position_ids
|
||||||
# In decode, we do not need this as we can just increment position ids
|
# In decode, we do not need this as we can just increment position ids
|
||||||
|
|
||||||
|
|
||||||
|
next_position_ids
|
||||||
|
|
||||||
next_position_ids[i] = batch.position_ids[end_index - 1]
|
next_position_ids[i] = batch.position_ids[end_index - 1]
|
||||||
|
|
||||||
# Initialize adapter indices
|
# Initialize adapter indices
|
||||||
@ -1651,6 +1712,7 @@ class FlashCausalLM(Model):
|
|||||||
# Used to gather prefill logprobs
|
# Used to gather prefill logprobs
|
||||||
# Copy batch.input_ids to prefill_token_indices
|
# Copy batch.input_ids to prefill_token_indices
|
||||||
if prefill_logprobs:
|
if prefill_logprobs:
|
||||||
|
raise NotImplementedError
|
||||||
if len(batch) > 1:
|
if len(batch) > 1:
|
||||||
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
|
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
|
||||||
batch.input_ids[start_index + 1 : start_index + out_length]
|
batch.input_ids[start_index + 1 : start_index + out_length]
|
||||||
@ -1668,14 +1730,23 @@ class FlashCausalLM(Model):
|
|||||||
cumulative_length += input_length
|
cumulative_length += input_length
|
||||||
|
|
||||||
# Update values
|
# Update values
|
||||||
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
|
if batch.next_prefilling_chunk_lengths is None:
|
||||||
batch.speculative_ids = speculative_ids
|
# We are done prefilling
|
||||||
batch.position_ids = next_position_ids + accepted_ids
|
batch.prefilling = False
|
||||||
batch.input_lengths_tensor += accepted_ids
|
batch.next_prefilling_chunk_lengths = None
|
||||||
batch.slot_indices += accepted_ids
|
batch.next_prefilling_chunk_lengths_tensor = None
|
||||||
batch.adapter_meta.adapter_indices = next_adapter_indices
|
# We do not need cu_seqlen_prefill anymore
|
||||||
|
batch.cu_seqlen_prefill = None
|
||||||
|
|
||||||
if prefill:
|
if not batch.prefilling:
|
||||||
|
batch.input_ids = next_input_ids[accepted_ids.cumsum(dim=-1) - 1]
|
||||||
|
batch.speculative_ids = speculative_ids
|
||||||
|
batch.position_ids = next_position_ids + accepted_ids
|
||||||
|
batch.input_lengths_tensor += accepted_ids
|
||||||
|
batch.slot_indices += accepted_ids
|
||||||
|
batch.adapter_meta.adapter_indices = next_adapter_indices
|
||||||
|
|
||||||
|
if batch.prefilling:
|
||||||
# adjust segment lengths to account for all request lengths being 1 during decoding
|
# adjust segment lengths to account for all request lengths being 1 during decoding
|
||||||
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
|
adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
|
||||||
batch.adapter_meta.adapter_segments = torch.tensor(
|
batch.adapter_meta.adapter_segments = torch.tensor(
|
||||||
@ -1684,7 +1755,8 @@ class FlashCausalLM(Model):
|
|||||||
device=batch.adapter_meta.adapter_segments.device,
|
device=batch.adapter_meta.adapter_segments.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if prefill and prefill_logprobs:
|
if batch.prefilling and prefill_logprobs:
|
||||||
|
raise NotImplementedError
|
||||||
# Get prefill logprobs
|
# Get prefill logprobs
|
||||||
prefill_logprobs_tensor = torch.log_softmax(out, -1)
|
prefill_logprobs_tensor = torch.log_softmax(out, -1)
|
||||||
prefill_logprobs = torch.gather(
|
prefill_logprobs = torch.gather(
|
||||||
@ -1795,7 +1867,8 @@ class FlashCausalLM(Model):
|
|||||||
generated_text = None
|
generated_text = None
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if prefill and request.prefill_logprobs:
|
if batch.prefilling and request.prefill_logprobs:
|
||||||
|
raise NotImplementedError
|
||||||
out_start_index = batch.prefill_cu_outlens[i]
|
out_start_index = batch.prefill_cu_outlens[i]
|
||||||
out_end_index = batch.prefill_cu_outlens[i + 1]
|
out_end_index = batch.prefill_cu_outlens[i + 1]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user