Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-25 12:02:08 +00:00.
Speedup 2x.
- Still wrong when batched.
- Incorrect returned payload (no multiple ids/logprobs).
This commit is contained in:
parent
8897b89606
commit
866af9b9fd
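For orientation on the "multiple ids/logprobs" note above: speculative decoding drafts a few candidate tokens, verifies them in one forward pass, and accepts the longest prefix the target model agrees with, so a single step can advance a sequence by more than one token. A minimal greedy-verification sketch, with purely illustrative names and no ties to the repository's API:

def accept_draft(target_choices, draft_tokens):
    """Return the tokens kept for this step: the agreed-upon prefix of the draft
    plus one token chosen by the target model.

    target_choices: target-model picks at positions [last_token, *draft_tokens]
                    (length len(draft_tokens) + 1)
    draft_tokens:   candidate tokens proposed by the draft mechanism
    """
    kept = []
    for i, draft in enumerate(draft_tokens):
        kept.append(target_choices[i])
        if target_choices[i] != draft:      # first disagreement stops acceptance
            return kept
    kept.append(target_choices[-1])         # bonus token when the whole draft matched
    return kept

# The whole draft [7, 9] matches, so three tokens are produced in one step.
assert accept_draft([7, 9, 4], [7, 9]) == [7, 9, 4]
# Disagreement at the first draft position: only the target's token is kept.
assert accept_draft([8, 9, 4], [7, 9]) == [8]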
@@ -301,7 +301,6 @@ class FlashLlamaAttention(torch.nn.Module):
            )
        # Decode
        else:
            import ipdb;ipdb.set_trace()
            paged_attention.attention(
                attn_output,
                query,
@@ -454,9 +453,20 @@ class FlashLlamaModel(torch.nn.Module):
        speculative_ids: Optional[torch.Tensor]
    ) -> torch.Tensor:
        if speculative_ids is not None:
            print(speculative_ids.shape, input_ids.shape)
            speculative_length = speculative_ids.shape[1]
            new_length = speculative_length + 1
            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
            new_position_ids = (position_ids.view((1, -1)).expand(speculative_ids.shape[1] + 1, 1) + torch.arange(speculative_ids.shape[1] + 1).unsqueeze(1).to(device="cuda:0")).squeeze(0).squeeze(-1)
            new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1)

            # Add an extra block just in case
            block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
            # Copy the block tables for all members
            block_tables = block_tables.expand(new_length, -1).contiguous()
            slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device)
            input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device)
            max_s = max_s + speculative_length

            input_ids = new_input_ids
            position_ids = new_position_ids

@@ -793,7 +793,7 @@ class FlashCausalLM(Model):

        # if next_token_logits.shape[0] == 3:
        #     import ipdb;ipdb.set_trace()
        next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
        next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
        )

@@ -835,6 +835,7 @@ class FlashCausalLM(Model):
        iterator = zip(
            batch.input_lengths,
            batch.all_input_ids,
            accepted_ids
        )

        # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@@ -842,10 +843,11 @@ class FlashCausalLM(Model):
        # It is faster if we delay this sync for the maximum amount of time

        # For each member of the batch
        step = 1 + speculative_length
        index = 0
        for i, (
            input_length,
            all_input_ids,
            n_accepted_ids
        ) in enumerate(iterator):
            # Indexing metadata
            start_index = cumulative_length
@@ -859,8 +861,6 @@ class FlashCausalLM(Model):

            # Initialize position_ids
            # In decode, we do not need this as we can just increment position ids
            # for j in range(1 + speculative_length):
            #     next_position_ids[i * step + j] = batch.position_ids[end_index - 1] + j
            next_position_ids[i] = batch.position_ids[end_index - 1]

            # Used to gather prefill logprobs
@@ -876,17 +876,29 @@ class FlashCausalLM(Model):
                    start_index + 1 : start_index + out_length
                ]

            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
            for j in range(n_accepted_ids):
                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
                index += 1

            cumulative_length += input_length

        # if accepted_ids[0] > 1:
        #     import ipdb;ipdb.set_trace()

        if len(accepted_ids) > 1:
            raise Exception("Implement the batched behavior")

        # Set values in batch
        # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
        batch.input_ids = next_input_ids
        batch.speculative_ids = speculative_ids
        batch.position_ids = next_position_ids + 1
        batch.input_lengths_tensor += 1
        batch.slot_indices += 1

        for n_accepted_ids in accepted_ids:
            # TODO Make this batched
            batch.input_ids = next_input_ids[-1:]
            batch.speculative_ids = speculative_ids
            batch.position_ids = next_position_ids + n_accepted_ids
            batch.input_lengths_tensor += n_accepted_ids
            batch.slot_indices += n_accepted_ids

        if prefill and prefill_logprobs:
            # Get prefill logprobs
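The bookkeeping above walks a flat tensor of produced tokens with a running index, since each request contributes n_accepted_ids entries. A small sketch of that flattened indexing pattern (illustrative only; plain Python lists stand in for the batch tensors):

def scatter_accepted(all_input_ids, input_lengths, flat_next_ids, accepted_ids):
    # Flat results for the whole batch, in request order; request i owns the
    # next accepted_ids[i] entries of flat_next_ids.
    index = 0
    for seq, (length, n_accepted) in enumerate(zip(input_lengths, accepted_ids)):
        for j in range(n_accepted):
            all_input_ids[seq][length + j] = flat_next_ids[index]
            index += 1
        input_lengths[seq] = length + n_accepted   # the sequence advanced by n_accepted tokens
    return all_input_ids, input_lengths

ids = [[1, 2, 0, 0, 0], [3, 4, 5, 0, 0]]
lengths = [2, 3]
ids, lengths = scatter_accepted(ids, lengths, [10, 11, 12], accepted_ids=[2, 1])
assert ids == [[1, 2, 10, 11, 0], [3, 4, 5, 12, 0]] and lengths == [4, 4]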
@@ -899,7 +911,7 @@ class FlashCausalLM(Model):

        # GPU <-> CPU sync
        next_token_logprobs = next_token_logprobs.tolist()
        next_token_ids = batch.input_ids.tolist()
        next_token_ids = next_input_ids.tolist()

        # Zipped iterator
        iterator = zip(
@@ -912,13 +924,15 @@ class FlashCausalLM(Model):
            batch.next_token_chooser.do_sample,
            batch.next_token_chooser.seeds,
            batch.top_n_tokens,
            next_token_ids,
            next_token_logprobs,
            # next_token_ids,
            # next_token_logprobs,
            accepted_ids,
            batch_top_token_ids,
            batch_top_token_logprobs,
        )

        # For each member of the batch
        index = 0
        for i, (
            request,
            input_length,
@@ -929,13 +943,16 @@ class FlashCausalLM(Model):
            do_sample,
            seed,
            top_n_tokens,
            next_token_id,
            next_token_logprob,
            # next_token_id,
            # next_token_logprob,
            n_accepted_ids,
            top_token_ids,
            top_token_logprobs,
        ) in enumerate(iterator):
            # Append next token to all tokens
            all_input_ids.append(next_token_id)
            _next_token_ids = next_token_ids[index: index+n_accepted_ids]
            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
            all_input_ids.extend(_next_token_ids)

            # Generated token
            next_token_text, prefix_offset, read_offset = self.decode_token(
@@ -945,13 +962,18 @@ class FlashCausalLM(Model):
            )

            # Evaluate stopping criteria
            stop, reason = stopping_criteria(
                next_token_id,
                next_token_text,
            )

            if not stop:
                stopped = False
            for next_token_id in _next_token_ids:
                stop, reason = stopping_criteria(
                    next_token_id,
                    next_token_text,
                )

                if stop:
                    stopped = True
                    break
            if not stop:
                stopped = False

            # Shard generations
            # All generations will be appended in the rust sharded client
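Because a step can now emit several tokens per request, the stopping criteria are evaluated for each accepted token in order and evaluation stops at the first hit. A hedged sketch of that loop (the criteria callable here is hypothetical):

def check_stop(next_token_ids, stopping_criteria):
    """Apply the stopping criteria to each accepted token in order.
    Returns (stopped, reason); tokens after the first stop are discarded."""
    for token_id in next_token_ids:
        stop, reason = stopping_criteria(token_id)
        if stop:
            return True, reason
    return False, None

# Example with a toy criterion: stop on an EOS id of 2.
stopped, reason = check_stop([5, 2, 7], lambda t: (t == 2, "eos" if t == 2 else None))
assert stopped and reason == "eos"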
@@ -1015,6 +1037,9 @@ class FlashCausalLM(Model):
            else:
                top_tokens = None

            next_token_ids = _next_token_ids[0]
            next_token_logprob = _next_token_logprobs[0]

            generation = Generation(
                request.id,
                prefill_tokens,
@@ -225,19 +225,21 @@ class HeterogeneousNextTokenChooser:
            scores = warper(input_ids, scores)

        accepted_ids = []
        next_ids = self.choice(scores)
        if speculated_ids is not None:
            validate_speculative = next_ids[1:] == speculated_ids[0]
            validate_speculative = next_ids[:-1] == speculated_ids[0]
            index = 1
            for valid in validate_speculative.tolist():
                if valid:
                    index += 1
            print(f"Validated {index - 1}")
            # print(f"Validated {index - 1}")
            next_ids = next_ids[:index]
            scores = scores[:index]
            speculative_scores = speculative_scores[index - 1:index]
            if index > 1:
                import ipdb;ipdb.set_trace()
            accepted_ids.append(index)
        else:
            accepted_ids.append(1)

        logprobs = torch.log_softmax(scores, -1)
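One way to read the validation above: next_ids are the target model's picks over the [last token + draft] positions, and the accepted count is the length of the matching draft prefix plus one. A hedged, tensor-based sketch of that counting (assumed shapes; this version stops at the first mismatch, which may differ from the exact loop above):

import torch

def count_accepted(next_ids: torch.Tensor, speculated_ids: torch.Tensor) -> int:
    """next_ids: (num_draft + 1,) target picks; speculated_ids: (num_draft,) draft tokens.
    Returns how many of next_ids are kept: the matching draft prefix plus one."""
    matches = next_ids[:-1] == speculated_ids   # elementwise agreement with the draft
    accepted = 1
    for ok in matches.tolist():
        if not ok:                              # stop at the first disagreement
            break
        accepted += 1
    return accepted

assert count_accepted(torch.tensor([7, 9, 4]), torch.tensor([7, 9])) == 3
assert count_accepted(torch.tensor([8, 9, 4]), torch.tensor([7, 9])) == 1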
@@ -255,10 +257,12 @@ class HeterogeneousNextTokenChooser:
            # for warper in self.warpers:
            #     speculative_scores = warper(input_ids, speculative_scores)
            speculative_ids = Greedy()(speculative_scores)
            # # Ignore first head, it seems to be a regular head.
            # speculative_ids = speculative_ids[:, 1:]
        else:
            speculative_ids = None

        return next_ids, next_logprobs, logprobs, speculative_ids
        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids

    def filter(self, indices):
        if self.watermark_processor is not None: