pingpong optimization

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

Commit: cd900c3b72
Parent: 29703dbd27
Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)
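The commit stores the logits of each forward pass on the batch (the new next_token_logits and speculative_logits fields below) and defers the HPU->CPU sync, sampling, detokenization and stopping checks to the start of the next generate_token call, split into three stages: collect the next tokens of previously started generations, run the new forward, then finish and return the previous generations. Host-side post-processing of step N therefore overlaps with the device forward of step N+1. Below is a minimal, self-contained sketch of that ping-pong pattern; PingPongDecoder, device_forward and cpu_postprocess are illustrative placeholders, not the TGI API.

from typing import List, Optional


def device_forward(step: int) -> List[float]:
    # Stand-in for the accelerator forward pass that produces logits.
    return [float(step)] * 4


def cpu_postprocess(logits: List[float]) -> str:
    # Stand-in for the HPU->CPU sync, sampling and detokenization.
    return f"token-for-step-{int(logits[0])}"


class PingPongDecoder:
    # Keeps the logits of step N and only post-processes them during step N+1,
    # so host-side work overlaps with the next device forward.

    def __init__(self) -> None:
        self.pending: Optional[List[float]] = None

    def step(self, step_id: int) -> Optional[str]:
        # Stage 1: finish the previous step on the CPU.
        finished = cpu_postprocess(self.pending) if self.pending is not None else None
        # Stage 2: run this step's forward and keep its logits for later.
        self.pending = device_forward(step_id)
        # Stage 3: return only the generations that belong to the previous step.
        return finished


if __name__ == "__main__":
    decoder = PingPongDecoder()
    for i in range(3):
        print(decoder.step(i))  # None, then token-for-step-0, token-for-step-1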
@@ -253,6 +253,9 @@ class FlashCausalLMBatch(Batch):

     hpu_attn_meta: Optional[HPUPagedAttentionMetadata]

+    next_token_logits: Optional[torch.Tensor]
+    speculative_logits: Optional[torch.Tensor]
+
     def to_pb(self) -> generate_pb2.CachedBatch:
         return generate_pb2.CachedBatch(
             id=self.batch_id,
@@ -490,6 +493,8 @@ class FlashCausalLMBatch(Batch):
             input_lengths_tensor=None,
             adapter_meta=None,
             hpu_attn_meta=None,
+            next_token_logits=None,
+            speculative_logits=None,
         )

     @classmethod
@@ -698,6 +703,8 @@ class FlashCausalLMBatch(Batch):
             speculative_ids=speculative_ids,
             adapter_meta=adapter_meta,
             hpu_attn_meta=None,
+            next_token_logits=None,
+            speculative_logits=None,
         )

     @classmethod
@@ -959,6 +966,8 @@ class FlashCausalLMBatch(Batch):
             speculative_ids=speculative_ids,
             adapter_meta=adapter_meta,
             hpu_attn_meta=None,
+            next_token_logits=None,
+            speculative_logits=None,
         )

     def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
@@ -1484,7 +1493,7 @@ class FlashCausalLM(Model):

         log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")
         if max_total_tokens is None:
-            max_total_tokens = sum(batch.cache_lengths)
+            max_total_tokens = sum(batch.input_lengths)

         if max_input_tokens is None:
             max_input_tokens = max_total_tokens - 1
@@ -1531,6 +1540,8 @@ class FlashCausalLM(Model):
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
         ):
+            if batch_size > block_num:
+                continue
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
@@ -1803,6 +1814,144 @@ class FlashCausalLM(Model):
     def generate_token(
         self, batches: List[FlashCausalLMBatch]
     ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
+        # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
+        # Stage 1. Collect next token ids of any previously started generations
+        prev_batches = []
+        requests_to_generate = []
+        for batch_id, batch in enumerate(batches):
+            if batch.next_token_logits is not None:
+                prefill = batch.prefilling
+                if batch.prefilling:
+                    batch.prefilling = False
+                    batch.prefilling_mask = [False] * len(batch)
+
+                speculate = get_speculate()
+                (
+                    next_input_ids,
+                    next_token_logprobs,
+                    logprobs,
+                    accepted_ids,
+                    speculative_ids,
+                ) = batch.next_token_chooser(
+                    batch.all_input_ids_tensor[:, : batch.max_current_length],
+                    batch.next_token_logits,
+                    speculate,
+                    batch.speculative_ids,
+                    batch.speculative_logits,
+                )
+
+                batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+                    batch.top_n_tokens,
+                    batch.top_n_tokens_tensor,
+                    logprobs,
+                    accepted_ids,
+                )
+
+                # Since we are done prefilling, all the tensors that were concatenating values for all the requests
+                # instantly become of shape [BATCH_SIZE]
+                if prefill:
+                    indices = batch.cu_seqlen_prefill[1:] - 1
+                    # pad in left
+                    if batch.prefill_cache_indices is not None:
+                        batch.position_ids = batch.position_ids[
+                            batch.prefill_cache_indices
+                        ][indices]
+                    else:
+                        batch.position_ids = batch.position_ids[indices]
+
+                    batch.slot_indices = batch.slot_indices[indices]
+                    batch.adapter_meta.adapter_indices = (
+                        batch.adapter_meta.adapter_indices[indices]
+                    )
+                # For each member of the batch
+                # Cumulative length
+                cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+                torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
+                if batch.speculative_logits is not None:
+                    for i in range(len(batch)):
+                        batch.all_input_ids_tensor[
+                            i,
+                            batch.cache_lengths_tensor[i]
+                            + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
+                            + batch.input_lengths[i]
+                            + accepted_ids[i],
+                        ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
+                else:
+                    index = batch.cache_lengths_tensor + batch.input_lengths_tensor
+                    batch_idx = torch.arange(
+                        0,
+                        batch.all_input_ids_tensor.shape[0],
+                        dtype=torch.long,
+                        device=batch.input_lengths_tensor.device,
+                    )
+                    batch.all_input_ids_tensor.index_put_(
+                        (batch_idx, index.long()), next_input_ids
+                    )
+
+                batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
+                batch.speculative_ids = speculative_ids
+                if batch.position_ids.dim() == 2:
+                    # Qwen2_vl case:
+                    batch.position_ids += accepted_ids.unsqueeze(-1)
+                else:
+                    batch.position_ids += accepted_ids
+                batch.cache_lengths_tensor += (
+                    batch.input_lengths_tensor + accepted_ids - 1
+                )
+                batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
+                batch.slot_indices += accepted_ids
+
+                # Does a HPU <-> CPU sync internally
+                if prefill:
+                    # adjust segment lengths to account for all request lengths being 1 during decoding
+                    adapter_segments, _ = find_segments(
+                        batch.adapter_meta.adapter_indices
+                    )
+                    batch.adapter_meta.adapter_segments = torch.tensor(
+                        adapter_segments,
+                        dtype=torch.int32,
+                        device=batch.adapter_meta.adapter_segments.device,
+                    )
+                prev_batches.append(
+                    {
+                        "next_token_ids": next_input_ids,
+                        "next_token_logprobs": next_token_logprobs,
+                        "accepted_ids": accepted_ids,
+                    }
+                )
+                idx = len(prev_batches) - 1
+
+                for req_idx, req in enumerate(batch.requests):
+                    requests_to_generate.append(
+                        {
+                            "idx": idx,
+                            "request_id": req.id,
+                            "cache_length": batch.cache_lengths[req_idx],
+                            "input_length": batch.input_lengths[req_idx],
+                            "prefix_offset": batch.prefix_offsets[req_idx],
+                            "read_offset": batch.read_offsets[req_idx],
+                            "stopping_criteria": batch.stopping_criterias[req_idx],
+                            "all_input_ids": batch.all_input_ids[req_idx],
+                            "do_sample": batch.next_token_chooser.do_sample[req_idx],
+                            "seed": batch.next_token_chooser.seeds[req_idx],
+                            "top_n_tokens": batch.top_n_tokens[req_idx],
+                            "top_token_ids": batch_top_token_ids[req_idx],
+                            "top_token_logprobs": batch_top_token_logprobs[req_idx],
+                        }
+                    )
+                if prefill:
+                    # We do not need prefill tensors anymore
+                    batch.cu_seqlen_prefill = None
+                    batch.prefill_cache_indices = None
+                    batch.prefill_cu_outlens = None
+                    batch.prefill_head_indices = None
+                    batch.prefill_next_token_indices = None
+                batch.next_token_logits = None
+                batch.speculative_ids = None
+
+        htorch.core.mark_step()
+        # Stage 2. Prepare new batch for speculative scheduling
         if len(batches) > 1:
             batch = self.batch_type.concatenate(batches)
         else:
@@ -1851,7 +2000,7 @@ class FlashCausalLM(Model):
         out, speculative_logits = self.forward(batch, adapter_data)

         if prefill:
-            next_token_logits = (
+            batch.next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
             if speculative_logits is not None:
@@ -1862,364 +2011,147 @@ class FlashCausalLM(Model):
                 )
         else:
             prefill_logprobs = None
-            next_token_logits = out
+            batch.next_token_logits = out
+        batch.speculative_logits = speculative_logits

-        finished_prefilling = True
-        next_chunk_lengths = []
-        current_prefilling_mask = batch.prefilling_mask
-        if prefill:
-            finished_prefilling = True
-            next_prefilling_mask = [False] * len(batch)
-            batch.prefilling = not finished_prefilling
-            batch.prefilling_mask = next_prefilling_mask
-
-        speculate = get_speculate()
-        (
-            next_input_ids,
-            next_token_logprobs,
-            logprobs,
-            accepted_ids,
-            speculative_ids,
-        ) = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_current_length],
-            next_token_logits,
-            speculate,
-            batch.speculative_ids,
-            speculative_logits,
-        )
-
-        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
-            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
-        )
-
-        # Since we are done prefilling, all the tensors that were concatenating values for all the requests
-        # instantly become of shape [BATCH_SIZE]
-        if prefill and finished_prefilling:
-            indices = batch.cu_seqlen_prefill[1:] - 1
-            # pad in left
-            if batch.prefill_cache_indices is not None:
-                batch.position_ids = batch.position_ids[batch.prefill_cache_indices][
-                    indices
-                ]
-            else:
-                batch.position_ids = batch.position_ids[indices]
-
-            batch.slot_indices = batch.slot_indices[indices]
-            batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
-                indices
-            ]
-        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
-        # one, we need to first do a HPU <-> CPU sync
-        # It is faster if we delay this sync for the maximum amount of time
-
-        # For each member of the batch
-        # Cumulative length
-        cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
-        torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
-        if speculative_logits is not None:
-            for i in range(len(batch)):
-                batch.all_input_ids_tensor[
-                    i,
-                    batch.cache_lengths_tensor[i]
-                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
-                    + batch.input_lengths[i]
-                    + accepted_ids[i],
-                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
-        else:
-            index = batch.cache_lengths_tensor + batch.input_lengths_tensor
-            batch_idx = torch.arange(
-                0,
-                batch.all_input_ids_tensor.shape[0],
-                dtype=torch.long,
-                device=batch.input_lengths_tensor.device,
-            )
-            batch.all_input_ids_tensor.index_put_(
-                (batch_idx, index.long()), next_input_ids
-            )
-
-        # Update values
-        # These values can be updated without a HPU -> CPU sync
-        if not prefill or (prefill and finished_prefilling):
-            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
-            batch.speculative_ids = speculative_ids
-            if batch.position_ids.dim() == 2:
-                # Qwen2_vl case:
-                batch.position_ids += accepted_ids.unsqueeze(-1)
-            else:
-                batch.position_ids += accepted_ids
-            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
-            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
-            batch.slot_indices += accepted_ids
-
-        # Does a HPU <-> CPU sync internally
-        if prefill and finished_prefilling:
-            # adjust segment lengths to account for all request lengths being 1 during decoding
-            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
-            batch.adapter_meta.adapter_segments = torch.tensor(
-                adapter_segments,
-                dtype=torch.int32,
-                device=batch.adapter_meta.adapter_segments.device,
-            )
-
-        # HPU <-> CPU sync
-        next_token_logprobs = next_token_logprobs.tolist()
-        next_token_ids = next_input_ids.tolist()
-        accepted_ids = accepted_ids.tolist()
-
-        # Update values if we need to continue prefilling
-        # This represents the `else` case of the `Update values` if above
-        # but since this require the `next_token_ids` to be on CPU, it is better to do it here
-        if prefill and not finished_prefilling:
-            # Speculation must be ignored while we prefill even with chunking
-            # it simplifies everything
-            assert batch.speculative_ids is None
-
-            all_postfix_ids = []
-            for i, (
-                request_prefilling,
-                next_token_id,
-                all_input_ids,
-                cache_length,
-                input_length,
-                next_chunk_length,
-            ) in enumerate(
-                zip(
-                    batch.prefilling_mask,
-                    next_token_ids,
-                    batch.all_input_ids,
-                    batch.cache_lengths,
-                    batch.input_lengths,
-                    next_chunk_lengths,
-                )
-            ):
-                if request_prefilling:
-                    next_cache_length = cache_length + input_length
-                    # Get new prompt IDs to prefill
-                    postfix_ids = all_input_ids[
-                        next_cache_length : next_cache_length + next_chunk_length
-                    ]
-                else:
-                    # This request is done prefilling, the new id is the one selected the sampling method
-                    postfix_ids = [next_token_id]
-
-                all_postfix_ids.append(postfix_ids)
-
-            batch.input_ids = all_postfix_ids
+        # HPU->CPU sync
+        for prev_batch in prev_batches:
+            prev_batch["next_token_logprobs"] = prev_batch[
+                "next_token_logprobs"
+            ].tolist()
+            prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
+            prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()

         start_decode = time.time_ns()
+        # Stage 3. Finish and return previous generations
         # Results
         generations: List[Generation] = []
-        stopped = True
+        stopped = len(requests_to_generate) > 0

-        # Zipped iterator
-        iterator = zip(
-            batch.requests,
-            batch.prompt_lengths,
-            batch.cache_lengths,
-            batch.input_lengths,
-            batch.prefix_offsets,
-            batch.read_offsets,
-            batch.stopping_criterias,
-            batch.all_input_ids,
-            batch.next_token_chooser.do_sample,
-            batch.next_token_chooser.seeds,
-            batch.top_n_tokens,
-            current_prefilling_mask,
-            batch.prefilling_mask,
-            accepted_ids,
-            batch_top_token_ids,
-            batch_top_token_logprobs,
-        )
-
         # Reset max_input_length
         batch.max_input_length = 0
         # For each member of the batch
-        index = 0
-        for i, (
-            request,
-            prompt_length,
-            cache_length,
-            input_length,
-            prefix_offset,
-            read_offset,
-            stopping_criteria,
-            all_input_ids,
-            do_sample,
-            seed,
-            top_n_tokens,
-            request_was_prefilling,
-            request_is_prefilling,
-            n_accepted_ids,
-            top_token_ids,
-            top_token_logprobs,
-        ) in enumerate(iterator):
-            # Compute logprobs first as, even though we might skip the token,
-            # it can still be required to compute the logprobs
-            # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
-            # this state to be stable
-            if request.id % self.world_size == self.rank:
-                # Prefill
-                if request_was_prefilling and request.prefill_logprobs:
-                    out_start_index = batch.prefill_cu_outlens[i]
-                    out_end_index = batch.prefill_cu_outlens[i + 1]
-                    if not request_is_prefilling:
-                        # The request is dones prefilling, meaning that we started generating new tokens
-                        # The last logprob is a logprob for a generated token that was not part of the prompt
-                        # We need to remove it
-                        out_end_index -= 1
-
-                    request_prefill_logprobs = prefill_logprobs[
-                        out_start_index:out_end_index
-                    ]
-                    # Logprobs generated by the model are for the next token
-                    # So we need to translate the id tensor by 1
-                    prefill_token_ids = all_input_ids[
-                        cache_length + 1 : cache_length + input_length + 1
-                    ]
-
-                    past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]
-
-                    if past_prefill_logprob_tokens is None:
-                        # add nan for cached prompt tokens/first token
-                        request_prefill_logprobs = [float("nan")] * (
-                            cache_length + 1
-                        ) + request_prefill_logprobs
-                        prefill_token_ids = (
-                            all_input_ids[: cache_length + 1] + prefill_token_ids
-                        )
-
-                    prefill_texts = self.tokenizer.batch_decode(
-                        prefill_token_ids,
-                        clean_up_tokenization_spaces=False,
-                        skip_special_tokens=False,
-                    )
-
-                    prefill_logprob_tokens = Tokens(
-                        prefill_token_ids,
-                        request_prefill_logprobs,
-                        prefill_texts,
-                        is_special=[],
-                    )
-                    if past_prefill_logprob_tokens is not None:
-                        prefill_logprob_tokens = (
-                            past_prefill_logprob_tokens + prefill_logprob_tokens
-                        )
-
-                    batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
-                else:
-                    batch.prefill_logprob_tokens[i] = None
-
-            # If it is, the tokens we decoded should be ignored
-            if request_is_prefilling:
-                # Make sure that we do not stop as even though this request did not create a token, it is still
-                # processing
-                stopped = False
-                new_input_length = next_chunk_lengths[i]
-                new_cache_length = cache_length + input_length
-            else:
-                new_input_length = 1
-                new_cache_length = cache_length + input_length + n_accepted_ids - 1
-                # Append next token to all tokens
-                next_token_texts = []
-                left = 0
-
-                if n_accepted_ids > 1:
-                    log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
-
-                current_stopped = False
-                for j in range(index, index + n_accepted_ids):
-                    # Generated token
-                    next_token_id = next_token_ids[j]
-                    all_input_ids.append(next_token_id)
-                    next_token_text, prefix_offset, read_offset = self.decode_token(
-                        all_input_ids,
-                        prefix_offset,
-                        read_offset,
-                    )
-                    next_token_texts.append(next_token_text)
-
-                    stop, reason = stopping_criteria(
-                        next_token_id,
-                        next_token_text,
-                    )
-
-                    if stop:
-                        left = index + n_accepted_ids - j - 1
-                        current_stopped = True
-                        break
-                    else:
-                        current_stopped = False
-                stopped = stopped and current_stopped
-
-                _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
-                _next_token_logprobs = next_token_logprobs[
-                    index : index + n_accepted_ids - left
-                ]
-
-                # Shard generations
-                # All generations will be appended in the rust sharded client
-                if request.id % self.world_size == self.rank:
-                    if stop:
-                        # Decode generated tokens
-                        output_text, _, _ = self.decode_token(
-                            all_input_ids,
-                            prefix_offset=len(all_input_ids)
-                            - stopping_criteria.current_tokens
-                            - 1,
-                            read_offset=len(all_input_ids)
-                            - stopping_criteria.current_tokens,
-                            skip_special_tokens=True,
-                        )
-                        generated_text = GeneratedText(
-                            output_text,
-                            stopping_criteria.current_tokens,
-                            reason,
-                            seed if do_sample else None,
-                        )
-                    else:
-                        generated_text = None
-
-                    if top_n_tokens > 0:
-                        all_top_tokens = []
-                        for top_token_ids, top_token_logprobs in zip(
-                            top_token_ids, top_token_logprobs
-                        ):
-                            toptoken_texts = self.tokenizer.batch_decode(
-                                top_token_ids,
-                                clean_up_tokenization_spaces=False,
-                                skip_special_tokens=False,
-                            )
-                            special_toptokens = [
-                                token_id in self.all_special_ids
-                                for token_id in top_token_ids
-                            ]
-                            top_tokens = Tokens(
-                                top_token_ids,
-                                top_token_logprobs,
-                                toptoken_texts,
-                                special_toptokens,
-                            )
-                            all_top_tokens.append(top_tokens)
-                        top_tokens = all_top_tokens
-                    else:
-                        top_tokens = None
-
-                    generation = Generation(
-                        request.id,
-                        batch.prefill_logprob_tokens[i],
-                        Tokens(
-                            _next_token_ids,
-                            _next_token_logprobs,
-                            next_token_texts,
-                            [nid in self.all_special_ids for nid in _next_token_ids],
-                        ),
-                        generated_text,
-                        top_tokens,
-                    )
-
-                    generations.append(generation)
-
+        indexs = [0] * len(prev_batches)
+        idx_accept_ids = [0] * len(prev_batches)
+        for i, req_data in enumerate(requests_to_generate):
+            idx = req_data["idx"]
+            request_id = req_data["request_id"]
+            cache_length = req_data["cache_length"]
+            input_length = req_data["input_length"]
+            prefix_offset = req_data["prefix_offset"]
+            read_offset = req_data["read_offset"]
+            stopping_criteria = req_data["stopping_criteria"]
+            all_input_ids = req_data["all_input_ids"]
+            do_sample = req_data["do_sample"]
+            seed = req_data["seed"]
+            top_n_tokens = req_data["top_n_tokens"]
+            n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
+            top_token_ids = req_data["top_token_ids"]
+            top_token_logprobs = req_data["top_token_logprobs"]
+
+            new_input_length = 1
+            new_cache_length = cache_length + input_length + n_accepted_ids - 1
+            # Append next token to all tokens
+            next_token_texts = []
+            left = 0
+
+            if n_accepted_ids > 1:
+                log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")
+
+            current_stopped = False
+            index = indexs[idx]
+            for j in range(index, index + n_accepted_ids):
+                # Generated token
+                next_token_id = prev_batches[idx]["next_token_ids"][j]
+                all_input_ids.append(next_token_id)
+                next_token_text, prefix_offset, read_offset = self.decode_token(
+                    all_input_ids,
+                    prefix_offset,
+                    read_offset,
+                )
+                next_token_texts.append(next_token_text)
+
+                stop, reason = stopping_criteria(
+                    next_token_id,
+                    next_token_text,
+                )
+
+                if stop:
+                    left = index + n_accepted_ids - j - 1
+                    current_stopped = True
+                    break
+                else:
+                    current_stopped = False
+            stopped = stopped and current_stopped
+
+            _next_token_ids = prev_batches[idx]["next_token_ids"][
+                index : index + n_accepted_ids - left
+            ]
+            _next_token_logprobs = prev_batches[idx]["next_token_logprobs"][
+                index : index + n_accepted_ids - left
+            ]
+
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if request_id % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids,
+                        prefix_offset=len(all_input_ids)
+                        - stopping_criteria.current_tokens
+                        - 1,
+                        read_offset=len(all_input_ids)
+                        - stopping_criteria.current_tokens,
+                        skip_special_tokens=True,
+                    )
+                    generated_text = GeneratedText(
+                        output_text,
+                        stopping_criteria.current_tokens,
+                        reason,
+                        seed if do_sample else None,
+                    )
+                else:
+                    generated_text = None
+
+                if top_n_tokens > 0:
+                    all_top_tokens = []
+                    for top_token_ids, top_token_logprobs in zip(
+                        top_token_ids, top_token_logprobs
+                    ):
+                        toptoken_texts = self.tokenizer.batch_decode(
+                            top_token_ids,
+                            clean_up_tokenization_spaces=False,
+                            skip_special_tokens=False,
+                        )
+                        special_toptokens = [
+                            token_id in self.all_special_ids
+                            for token_id in top_token_ids
+                        ]
+                        top_tokens = Tokens(
+                            top_token_ids,
+                            top_token_logprobs,
+                            toptoken_texts,
+                            special_toptokens,
+                        )
+                        all_top_tokens.append(top_tokens)
+                    top_tokens = all_top_tokens
+                else:
+                    top_tokens = None
+
+                generation = Generation(
+                    request_id,
+                    None,
+                    Tokens(
+                        _next_token_ids,
+                        _next_token_logprobs,
+                        next_token_texts,
+                        [nid in self.all_special_ids for nid in _next_token_ids],
+                    ),
+                    generated_text,
+                    top_tokens,
+                )
+
+                generations.append(generation)
+
             # accept each new token for this specific request since we may
             # have more than one new token per request with speculative decoding
@@ -2231,7 +2163,8 @@ class FlashCausalLM(Model):
             )

             # Update values
-            index += n_accepted_ids
+            indexs[idx] += n_accepted_ids
+            idx_accept_ids[idx] += 1
             batch.cache_lengths[i] = new_cache_length
             batch.max_input_length = max(batch.max_input_length, new_input_length)
             batch.input_lengths[i] = new_input_length
@@ -2248,14 +2181,6 @@ class FlashCausalLM(Model):
             decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)

-        if prefill and finished_prefilling:
-            # We do not need prefill tensors anymore
-            batch.cu_seqlen_prefill = None
-            batch.prefill_cache_indices = None
-            batch.prefill_cu_outlens = None
-            batch.prefill_head_indices = None
-            batch.prefill_next_token_indices = None
-
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
         return generations, batch, (forward_ns, decode_ns)
@@ -456,6 +456,8 @@ class FlashVlmCausalLM(FlashCausalLM):
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
         ):
+            if batch_size > block_num:
+                continue
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
@@ -368,6 +368,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         for i, (batch_size, block_num) in enumerate(
             reversed(self.bucketing_ctx.decode_buckets)
         ):
+            if batch_size > block_num:
+                continue
             log_master(
                 logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
             )
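The three warmup hunks above add the same guard. A minimal sketch of the intent, under the assumption that each decode request needs at least one KV-cache block, so a bucket whose batch_size exceeds its block_num can never actually run and is skipped during warmup (the bucket values below are made up):

# Illustrative only: filter warmup decode buckets the way the guard above does.
decode_buckets = [(1, 4), (8, 4), (4, 16)]  # (batch_size, block_num) pairs, made up
valid = [(bs, bn) for bs, bn in decode_buckets if bs <= bn]
print(valid)  # [(1, 4), (4, 16)] -- the (8, 4) bucket is skipped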