Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-19 22:02:06 +00:00)

pingpong optimization
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

Parent: 29703dbd27
Commit: cd900c3b72
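
This commit restructures generate_token on the HPU backend into a pipelined ("pingpong") flow: the logits produced by a forward pass are kept on the batch (next_token_logits / speculative_logits) and are only turned into token ids on the next scheduler call, so the CPU-side sampling and detokenization of step N overlap with the device forward of step N+1. Below is a minimal, self-contained sketch of that idea only; fake_forward and sample_on_cpu are invented stand-ins, not functions from this repository.

from typing import List, Optional

def fake_forward(step_id: int) -> List[int]:
    # placeholder for the HPU forward pass of step `step_id`
    return [step_id * 10 + i for i in range(4)]

def sample_on_cpu(logits: List[int]) -> int:
    # placeholder for next_token_chooser + the .tolist() device->CPU sync
    return max(logits)

def run(num_steps: int) -> List[int]:
    tokens: List[int] = []
    pending: Optional[List[int]] = None      # plays the role of batch.next_token_logits
    for step_id in range(num_steps):
        logits = fake_forward(step_id)       # launch step N on the device
        if pending is not None:              # meanwhile, finish step N-1 on the CPU
            tokens.append(sample_on_cpu(pending))
        pending = logits
    if pending is not None:                  # drain the last deferred step
        tokens.append(sample_on_cpu(pending))
    return tokens

print(run(3))  # [3, 13, 23]
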
@@ -253,6 +253,9 @@ class FlashCausalLMBatch(Batch):
    hpu_attn_meta: Optional[HPUPagedAttentionMetadata]

    next_token_logits: Optional[torch.Tensor]
    speculative_logits: Optional[torch.Tensor]

    def to_pb(self) -> generate_pb2.CachedBatch:
        return generate_pb2.CachedBatch(
            id=self.batch_id,
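
The two new optional fields carry the not-yet-sampled logits across generate_token calls; Stage 1 only runs for batches that have such deferred state. A hedged sketch of that gating, with a simplified stand-in class (MiniBatch and has_pending_tokens are illustrative names, not from the repository):

from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniBatch:
    batch_id: int
    next_token_logits: Optional[object] = None   # set by the previous forward pass
    speculative_logits: Optional[object] = None  # set when speculation is enabled

def has_pending_tokens(batch: MiniBatch) -> bool:
    # Stage 1 of generate_token only processes batches whose previous forward
    # already produced logits that have not been sampled yet.
    return batch.next_token_logits is not None

b = MiniBatch(batch_id=0)
assert not has_pending_tokens(b)
b.next_token_logits = [[0.1, 0.9]]
assert has_pending_tokens(b)
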
@@ -490,6 +493,8 @@ class FlashCausalLMBatch(Batch):
            input_lengths_tensor=None,
            adapter_meta=None,
            hpu_attn_meta=None,
            next_token_logits=None,
            speculative_logits=None,
        )

    @classmethod
@@ -698,6 +703,8 @@ class FlashCausalLMBatch(Batch):
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
            hpu_attn_meta=None,
            next_token_logits=None,
            speculative_logits=None,
        )

    @classmethod
@@ -959,6 +966,8 @@ class FlashCausalLMBatch(Batch):
            speculative_ids=speculative_ids,
            adapter_meta=adapter_meta,
            hpu_attn_meta=None,
            next_token_logits=None,
            speculative_logits=None,
        )

    def prepare_for_decode(self, dtype, use_contiguous_pa, bucketing_ctx):
@@ -1484,7 +1493,7 @@ class FlashCausalLM(Model):
        log_master(logger.info, f"KV-cache blocks: {num_blocks}, size: {BLOCK_SIZE}")

        if max_total_tokens is None:
            max_total_tokens = sum(batch.cache_lengths)
            max_total_tokens = sum(batch.input_lengths)

        if max_input_tokens is None:
            max_input_tokens = max_total_tokens - 1
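
When the launcher does not supply max_total_tokens, it is now derived from the warmup batch's input lengths rather than its cache lengths. A toy example of that fallback arithmetic; the values are invented, not from a real warmup batch:

input_lengths = [512, 256, 1024]         # tokens per request in a hypothetical warmup batch
max_total_tokens = sum(input_lengths)    # 1792, mirroring `sum(batch.input_lengths)`
max_input_tokens = max_total_tokens - 1  # 1791, leaving room for at least one generated token
print(max_total_tokens, max_input_tokens)
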
@@ -1531,6 +1540,8 @@ class FlashCausalLM(Model):
        for i, (batch_size, block_num) in enumerate(
            reversed(self.bucketing_ctx.decode_buckets)
        ):
            if batch_size > block_num:
                continue
            log_master(
                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
            )
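
The added guard skips decode buckets that can never occur at runtime: each sequence needs at least one KV-cache block, so a bucket whose batch size exceeds its block budget is unreachable and warming it up would waste time. A small sketch of the same filtering over made-up buckets:

# Illustrative (batch_size, block_num) pairs; real buckets come from bucketing_ctx.
decode_buckets = [(1, 64), (8, 32), (32, 16), (64, 256)]

warmed_up = []
for batch_size, block_num in reversed(decode_buckets):
    if batch_size > block_num:
        # A batch of N sequences needs at least N blocks, so this shape is unreachable.
        continue
    warmed_up.append((batch_size, block_num))

print(warmed_up)  # [(64, 256), (8, 32), (1, 64)]
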
@@ -1803,6 +1814,144 @@ class FlashCausalLM(Model):
    def generate_token(
        self, batches: List[FlashCausalLMBatch]
    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:

        # In order to pipeline any actions on CPU we perform the operation in 3 main stages:
        # Stage 1. Collect next token ids of any previously started generations
        prev_batches = []
        requests_to_generate = []
        for batch_id, batch in enumerate(batches):
            if batch.next_token_logits is not None:
                prefill = batch.prefilling
                if batch.prefilling:
                    batch.prefilling = False
                    batch.prefilling_mask = [False] * len(batch)

                speculate = get_speculate()
                (
                    next_input_ids,
                    next_token_logprobs,
                    logprobs,
                    accepted_ids,
                    speculative_ids,
                ) = batch.next_token_chooser(
                    batch.all_input_ids_tensor[:, : batch.max_current_length],
                    batch.next_token_logits,
                    speculate,
                    batch.speculative_ids,
                    batch.speculative_logits,
                )

                batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                    batch.top_n_tokens,
                    batch.top_n_tokens_tensor,
                    logprobs,
                    accepted_ids,
                )

                # Since we are done prefilling, all the tensors that were concatenating values for all the requests
                # instantly become of shape [BATCH_SIZE]
                if prefill:
                    indices = batch.cu_seqlen_prefill[1:] - 1
                    # pad in left
                    if batch.prefill_cache_indices is not None:
                        batch.position_ids = batch.position_ids[
                            batch.prefill_cache_indices
                        ][indices]
                    else:
                        batch.position_ids = batch.position_ids[indices]

                    batch.slot_indices = batch.slot_indices[indices]
                    batch.adapter_meta.adapter_indices = (
                        batch.adapter_meta.adapter_indices[indices]
                    )
                # For each member of the batch
                # Cumulative length
                cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
                torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
                if batch.speculative_logits is not None:
                    for i in range(len(batch)):
                        batch.all_input_ids_tensor[
                            i,
                            batch.cache_lengths_tensor[i]
                            + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
                            + batch.input_lengths[i]
                            + accepted_ids[i],
                        ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
                else:
                    index = batch.cache_lengths_tensor + batch.input_lengths_tensor
                    batch_idx = torch.arange(
                        0,
                        batch.all_input_ids_tensor.shape[0],
                        dtype=torch.long,
                        device=batch.input_lengths_tensor.device,
                    )
                    batch.all_input_ids_tensor.index_put_(
                        (batch_idx, index.long()), next_input_ids
                    )

                batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
                batch.speculative_ids = speculative_ids
                if batch.position_ids.dim() == 2:
                    # Qwen2_vl case:
                    batch.position_ids += accepted_ids.unsqueeze(-1)
                else:
                    batch.position_ids += accepted_ids
                batch.cache_lengths_tensor += (
                    batch.input_lengths_tensor + accepted_ids - 1
                )
                batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
                batch.slot_indices += accepted_ids

                # Does a HPU <-> CPU sync internally
                if prefill:
                    # adjust segment lengths to account for all request lengths being 1 during decoding
                    adapter_segments, _ = find_segments(
                        batch.adapter_meta.adapter_indices
                    )
                    batch.adapter_meta.adapter_segments = torch.tensor(
                        adapter_segments,
                        dtype=torch.int32,
                        device=batch.adapter_meta.adapter_segments.device,
                    )
                prev_batches.append(
                    {
                        "next_token_ids": next_input_ids,
                        "next_token_logprobs": next_token_logprobs,
                        "accepted_ids": accepted_ids,
                    }
                )
                idx = len(prev_batches) - 1

                for req_idx, req in enumerate(batch.requests):
                    requests_to_generate.append(
                        {
                            "idx": idx,
                            "request_id": req.id,
                            "cache_length": batch.cache_lengths[req_idx],
                            "input_length": batch.input_lengths[req_idx],
                            "prefix_offset": batch.prefix_offsets[req_idx],
                            "read_offset": batch.read_offsets[req_idx],
                            "stopping_criteria": batch.stopping_criterias[req_idx],
                            "all_input_ids": batch.all_input_ids[req_idx],
                            "do_sample": batch.next_token_chooser.do_sample[req_idx],
                            "seed": batch.next_token_chooser.seeds[req_idx],
                            "top_n_tokens": batch.top_n_tokens[req_idx],
                            "top_token_ids": batch_top_token_ids[req_idx],
                            "top_token_logprobs": batch_top_token_logprobs[req_idx],
                        }
                    )
                if prefill:
                    # We do not need prefill tensors anymore
                    batch.cu_seqlen_prefill = None
                    batch.prefill_cache_indices = None
                    batch.prefill_cu_outlens = None
                    batch.prefill_head_indices = None
                    batch.prefill_next_token_indices = None
                batch.next_token_logits = None
                batch.speculative_ids = None

        htorch.core.mark_step()
        # Stage 2. Prepare new batch for speculative scheduling
        if len(batches) > 1:
            batch = self.batch_type.concatenate(batches)
        else:
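
Stage 1 flattens the accepted tokens of every request into one tensor; cu_accepted_ids then gives the per-request slice boundaries, and the last accepted token of each slice becomes that request's next input id. A small sketch of that indexing with invented shapes and values:

import torch

accepted_ids = torch.tensor([1, 3, 2])                    # tokens accepted per request (speculation)
next_input_ids = torch.tensor([11, 21, 22, 23, 31, 32])   # flattened accepted tokens

cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
# cu_accepted_ids == [0, 1, 4, 6]: request i owns next_input_ids[cu[i]:cu[i+1]]

for i in range(len(accepted_ids)):
    print(i, next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]].tolist())
# 0 [11]
# 1 [21, 22, 23]
# 2 [31, 32]

# The last accepted token of each request becomes its next input id:
print(next_input_ids[cu_accepted_ids[1:] - 1].tolist())  # [11, 23, 32]
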
@@ -1851,7 +2000,7 @@ class FlashCausalLM(Model):
        out, speculative_logits = self.forward(batch, adapter_data)

        if prefill:
            next_token_logits = (
            batch.next_token_logits = (
                out[batch.prefill_next_token_indices] if prefill_logprobs else out
            )
            if speculative_logits is not None:
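
Instead of the old local next_token_logits variable, the forward result is now stashed on the batch so that sampling happens on the next generate_token call, while the current forward is still in flight. A reduced sketch of the prefill branch; the class B, the tensor sizes, and the flag values are stand-ins, not the real objects:

import torch

class B:  # minimal stand-in for FlashCausalLMBatch
    prefill_next_token_indices = torch.tensor([2, 5])
    next_token_logits = None
    speculative_logits = None

batch = B()
out = torch.randn(6, 128)   # logits for every prefill position (invented shape)
prefill_logprobs = True
prefill = True

if prefill:
    # keep only the logits of the last position of each request
    batch.next_token_logits = (
        out[batch.prefill_next_token_indices] if prefill_logprobs else out
    )
else:
    batch.next_token_logits = out
print(batch.next_token_logits.shape)  # torch.Size([2, 128])
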
@@ -1862,263 +2011,43 @@ class FlashCausalLM(Model):
                )
        else:
            prefill_logprobs = None
            next_token_logits = out
            batch.next_token_logits = out
        batch.speculative_logits = speculative_logits

        finished_prefilling = True
        next_chunk_lengths = []
        current_prefilling_mask = batch.prefilling_mask
        if prefill:
            finished_prefilling = True
            next_prefilling_mask = [False] * len(batch)

            batch.prefilling = not finished_prefilling
            batch.prefilling_mask = next_prefilling_mask

        speculate = get_speculate()
        (
            next_input_ids,
            next_token_logprobs,
            logprobs,
            accepted_ids,
            speculative_ids,
        ) = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_current_length],
            next_token_logits,
            speculate,
            batch.speculative_ids,
            speculative_logits,
        )

        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
            batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
        )

        # Since we are done prefilling, all the tensors that were concatenating values for all the requests
        # instantly become of shape [BATCH_SIZE]
        if prefill and finished_prefilling:
            indices = batch.cu_seqlen_prefill[1:] - 1
            # pad in left
            if batch.prefill_cache_indices is not None:
                batch.position_ids = batch.position_ids[batch.prefill_cache_indices][
                    indices
                ]
            else:
                batch.position_ids = batch.position_ids[indices]

            batch.slot_indices = batch.slot_indices[indices]
            batch.adapter_meta.adapter_indices = batch.adapter_meta.adapter_indices[
                indices
            ]
        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
        # one, we need to first do a HPU <-> CPU sync
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        # Cumulative length
        cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
        torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
        if speculative_logits is not None:
            for i in range(len(batch)):
                batch.all_input_ids_tensor[
                    i,
                    batch.cache_lengths_tensor[i]
                    + batch.input_lengths[i] : batch.cache_lengths_tensor[i]
                    + batch.input_lengths[i]
                    + accepted_ids[i],
                ] = next_input_ids[cu_accepted_ids[i] : cu_accepted_ids[i + 1]]
        else:
            index = batch.cache_lengths_tensor + batch.input_lengths_tensor
            batch_idx = torch.arange(
                0,
                batch.all_input_ids_tensor.shape[0],
                dtype=torch.long,
                device=batch.input_lengths_tensor.device,
            )
            batch.all_input_ids_tensor.index_put_(
                (batch_idx, index.long()), next_input_ids
            )

        # Update values
        # These values can be updated without a HPU -> CPU sync
        if not prefill or (prefill and finished_prefilling):
            batch.input_ids = next_input_ids[cu_accepted_ids[1:] - 1]
            batch.speculative_ids = speculative_ids
            if batch.position_ids.dim() == 2:
                # Qwen2_vl case:
                batch.position_ids += accepted_ids.unsqueeze(-1)
            else:
                batch.position_ids += accepted_ids
            batch.cache_lengths_tensor += batch.input_lengths_tensor + accepted_ids - 1
            batch.input_lengths_tensor = torch.ones_like(batch.input_lengths_tensor)
            batch.slot_indices += accepted_ids

        # Does a HPU <-> CPU sync internally
        if prefill and finished_prefilling:
            # adjust segment lengths to account for all request lengths being 1 during decoding
            adapter_segments, _ = find_segments(batch.adapter_meta.adapter_indices)
            batch.adapter_meta.adapter_segments = torch.tensor(
                adapter_segments,
                dtype=torch.int32,
                device=batch.adapter_meta.adapter_segments.device,
            )

        # HPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = next_input_ids.tolist()
        accepted_ids = accepted_ids.tolist()

        # Update values if we need to continue prefilling
        # This represents the `else` case of the `Update values` if above
        # but since this require the `next_token_ids` to be on CPU, it is better to do it here
        if prefill and not finished_prefilling:
            # Speculation must be ignored while we prefill even with chunking
            # it simplifies everything
            assert batch.speculative_ids is None

            all_postfix_ids = []
            for i, (
                request_prefilling,
                next_token_id,
                all_input_ids,
                cache_length,
                input_length,
                next_chunk_length,
            ) in enumerate(
                zip(
                    batch.prefilling_mask,
                    next_token_ids,
                    batch.all_input_ids,
                    batch.cache_lengths,
                    batch.input_lengths,
                    next_chunk_lengths,
                )
            ):
                if request_prefilling:
                    next_cache_length = cache_length + input_length
                    # Get new prompt IDs to prefill
                    postfix_ids = all_input_ids[
                        next_cache_length : next_cache_length + next_chunk_length
                    ]
                else:
                    # This request is done prefilling, the new id is the one selected the sampling method
                    postfix_ids = [next_token_id]

                all_postfix_ids.append(postfix_ids)

            batch.input_ids = all_postfix_ids
        # HPU->CPU sync
        for prev_batch in prev_batches:
            prev_batch["next_token_logprobs"] = prev_batch[
                "next_token_logprobs"
            ].tolist()
            prev_batch["next_token_ids"] = prev_batch["next_token_ids"].tolist()
            prev_batch["accepted_ids"] = prev_batch["accepted_ids"].tolist()

        start_decode = time.time_ns()

        # Stage 3. Finish and return previous generations
        # Results
        generations: List[Generation] = []
        stopped = True

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.prompt_lengths,
            batch.cache_lengths,
            batch.input_lengths,
            batch.prefix_offsets,
            batch.read_offsets,
            batch.stopping_criterias,
            batch.all_input_ids,
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            current_prefilling_mask,
            batch.prefilling_mask,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        stopped = len(requests_to_generate) > 0
        # Reset max_input_length
        batch.max_input_length = 0
        # For each member of the batch
        index = 0
        for i, (
            request,
            prompt_length,
            cache_length,
            input_length,
            prefix_offset,
            read_offset,
            stopping_criteria,
            all_input_ids,
            do_sample,
            seed,
            top_n_tokens,
            request_was_prefilling,
            request_is_prefilling,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Compute logprobs first as, even though we might skip the token,
            # it can still be required to compute the logprobs
            # modulo on request.id as it is robust to batch.filter whereas the index in the batch is not and we need
            # this state to be stable
            if request.id % self.world_size == self.rank:
                # Prefill
                if request_was_prefilling and request.prefill_logprobs:
                    out_start_index = batch.prefill_cu_outlens[i]
                    out_end_index = batch.prefill_cu_outlens[i + 1]
                    if not request_is_prefilling:
                        # The request is dones prefilling, meaning that we started generating new tokens
                        # The last logprob is a logprob for a generated token that was not part of the prompt
                        # We need to remove it
                        out_end_index -= 1
        indexs = [0] * len(prev_batches)
        idx_accept_ids = [0] * len(prev_batches)
        for i, req_data in enumerate(requests_to_generate):
            idx = req_data["idx"]
            request_id = req_data["request_id"]
            cache_length = req_data["cache_length"]
            input_length = req_data["input_length"]
            prefix_offset = req_data["prefix_offset"]
            read_offset = req_data["read_offset"]
            stopping_criteria = req_data["stopping_criteria"]
            all_input_ids = req_data["all_input_ids"]
            do_sample = req_data["do_sample"]
            seed = req_data["seed"]
            top_n_tokens = req_data["top_n_tokens"]
            n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
            top_token_ids = req_data["top_token_ids"]
            top_token_logprobs = req_data["top_token_logprobs"]

                    request_prefill_logprobs = prefill_logprobs[
                        out_start_index:out_end_index
                    ]
                    # Logprobs generated by the model are for the next token
                    # So we need to translate the id tensor by 1
                    prefill_token_ids = all_input_ids[
                        cache_length + 1 : cache_length + input_length + 1
                    ]

                    past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]

                    if past_prefill_logprob_tokens is None:
                        # add nan for cached prompt tokens/first token
                        request_prefill_logprobs = [float("nan")] * (
                            cache_length + 1
                        ) + request_prefill_logprobs
                        prefill_token_ids = (
                            all_input_ids[: cache_length + 1] + prefill_token_ids
                        )

                    prefill_texts = self.tokenizer.batch_decode(
                        prefill_token_ids,
                        clean_up_tokenization_spaces=False,
                        skip_special_tokens=False,
                    )

                    prefill_logprob_tokens = Tokens(
                        prefill_token_ids,
                        request_prefill_logprobs,
                        prefill_texts,
                        is_special=[],
                    )
                    if past_prefill_logprob_tokens is not None:
                        prefill_logprob_tokens = (
                            past_prefill_logprob_tokens + prefill_logprob_tokens
                        )

                    batch.prefill_logprob_tokens[i] = prefill_logprob_tokens
                else:
                    batch.prefill_logprob_tokens[i] = None

            # If it is, the tokens we decoded should be ignored
            if request_is_prefilling:
                # Make sure that we do not stop as even though this request did not create a token, it is still
                # processing
                stopped = False
                new_input_length = next_chunk_lengths[i]
                new_cache_length = cache_length + input_length
            else:
                new_input_length = 1
                new_cache_length = cache_length + input_length + n_accepted_ids - 1
            # Append next token to all tokens
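
Because Stage 3 iterates over requests_to_generate, which may span several previously started batches, two per-batch cursors are kept: indexs[idx] walks the flattened token list of previous batch idx and idx_accept_ids[idx] walks its per-request accepted counts. A toy walk-through of that bookkeeping with invented data:

prev_batches = [
    {"next_token_ids": [11, 21, 22], "accepted_ids": [1, 2]},      # previous batch 0
    {"next_token_ids": [31, 41, 42, 43], "accepted_ids": [1, 3]},  # previous batch 1
]
requests_to_generate = [
    {"idx": 0, "request_id": 100},
    {"idx": 0, "request_id": 101},
    {"idx": 1, "request_id": 200},
    {"idx": 1, "request_id": 201},
]

indexs = [0] * len(prev_batches)          # cursor into next_token_ids, per previous batch
idx_accept_ids = [0] * len(prev_batches)  # cursor into accepted_ids, per previous batch

for req_data in requests_to_generate:
    idx = req_data["idx"]
    n_accepted_ids = prev_batches[idx]["accepted_ids"][idx_accept_ids[idx]]
    index = indexs[idx]
    tokens = prev_batches[idx]["next_token_ids"][index : index + n_accepted_ids]
    print(req_data["request_id"], tokens)
    indexs[idx] += n_accepted_ids
    idx_accept_ids[idx] += 1
# 100 [11]
# 101 [21, 22]
# 200 [31]
# 201 [41, 42, 43]
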
@@ -2129,9 +2058,10 @@ class FlashCausalLM(Model):
                    log_master(logger.debug, f"speculated ids {n_accepted_ids - 1}")

            current_stopped = False
            index = indexs[idx]
            for j in range(index, index + n_accepted_ids):
                # Generated token
                next_token_id = next_token_ids[j]
                next_token_id = prev_batches[idx]["next_token_ids"][j]
                all_input_ids.append(next_token_id)
                next_token_text, prefix_offset, read_offset = self.decode_token(
                    all_input_ids,
@@ -2153,14 +2083,16 @@ class FlashCausalLM(Model):
                    current_stopped = False
            stopped = stopped and current_stopped

            _next_token_ids = next_token_ids[index : index + n_accepted_ids - left]
            _next_token_logprobs = next_token_logprobs[
            _next_token_ids = prev_batches[idx]["next_token_ids"][
                index : index + n_accepted_ids - left
            ]
            _next_token_logprobs = prev_batches[idx]["next_token_logprobs"][
                index : index + n_accepted_ids - left
            ]

            # Shard generations
            # All generations will be appended in the rust sharded client
            if request.id % self.world_size == self.rank:
            if request_id % self.world_size == self.rank:
                if stop:
                    # Decode generated tokens
                    output_text, _, _ = self.decode_token(
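
Generations are only materialised on the shard whose rank matches request_id % world_size; the Rust router reassembles them, so each rank decodes text for roughly 1/world_size of the requests. A small sketch of that ownership rule, with invented ids and world size:

world_size = 2
request_ids = [7, 8, 9, 10]

for rank in range(world_size):
    owned = [rid for rid in request_ids if rid % world_size == rank]
    print(rank, owned)
# 0 [8, 10]
# 1 [7, 9]
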
@@ -2207,8 +2139,8 @@ class FlashCausalLM(Model):
                    top_tokens = None

                generation = Generation(
                    request.id,
                    batch.prefill_logprob_tokens[i],
                    request_id,
                    None,
                    Tokens(
                        _next_token_ids,
                        _next_token_logprobs,
@@ -2231,7 +2163,8 @@ class FlashCausalLM(Model):
                )

            # Update values
            index += n_accepted_ids
            indexs[idx] += n_accepted_ids
            idx_accept_ids[idx] += 1
            batch.cache_lengths[i] = new_cache_length
            batch.max_input_length = max(batch.max_input_length, new_input_length)
            batch.input_lengths[i] = new_input_length
@@ -2248,14 +2181,6 @@ class FlashCausalLM(Model):
            decode_ns = time.time_ns() - start_decode
            return generations, None, (forward_ns, decode_ns)

        if prefill and finished_prefilling:
            # We do not need prefill tensors anymore
            batch.cu_seqlen_prefill = None
            batch.prefill_cache_indices = None
            batch.prefill_cu_outlens = None
            batch.prefill_head_indices = None
            batch.prefill_next_token_indices = None

        forward_ns = start_decode - start
        decode_ns = time.time_ns() - start_decode
        return generations, batch, (forward_ns, decode_ns)
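
generate_token still reports a (forward_ns, decode_ns) split: everything up to start_decode (forward launch plus the deferred HPU->CPU sync) counts as forward time, and building the Generation objects counts as decode time. A hedged sketch of that measurement pattern, with sleeps standing in for the real work:

import time

start = time.time_ns()
time.sleep(0.01)                  # stands in for the forward pass and deferred bookkeeping
start_decode = time.time_ns()
time.sleep(0.005)                 # stands in for assembling Generation objects
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
print(forward_ns, decode_ns)
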
@@ -456,6 +456,8 @@ class FlashVlmCausalLM(FlashCausalLM):
        for i, (batch_size, block_num) in enumerate(
            reversed(self.bucketing_ctx.decode_buckets)
        ):
            if batch_size > block_num:
                continue
            log_master(
                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
            )
@@ -368,6 +368,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
        for i, (batch_size, block_num) in enumerate(
            reversed(self.bucketing_ctx.decode_buckets)
        ):
            if batch_size > block_num:
                continue
            log_master(
                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
            )