Speedup 2x.

- Still wrong when batched
- Incorrect returned payload (no multiple ids/logprobs)
Nicolas Patry 2023-11-29 14:36:17 +00:00
parent 8897b89606
commit 866af9b9fd
3 changed files with 69 additions and 30 deletions


@@ -301,7 +301,6 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
import ipdb;ipdb.set_trace()
paged_attention.attention(
attn_output,
query,
@@ -454,9 +453,20 @@ class FlashLlamaModel(torch.nn.Module):
speculative_ids: Optional[torch.Tensor]
) -> torch.Tensor:
if speculative_ids is not None:
print(speculative_ids.shape, input_ids.shape)
speculative_length = speculative_ids.shape[1]
new_length = speculative_length + 1
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
new_position_ids = (position_ids.view((1, -1)).expand(speculative_ids.shape[1] + 1, 1) + torch.arange(speculative_ids.shape[1] + 1).unsqueeze(1).to(device="cuda:0")).squeeze(0).squeeze(-1)
new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1)
# Add an extra block just in case
block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
# Copy the block tables for all members
block_tables = block_tables.expand(new_length, -1).contiguous()
slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device)
input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
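
A minimal, runnable sketch of what this expansion does for a single sequence (batch size 1, toy values, CPU tensors; only the names that also appear in the hunk are the commit's own): the current token and the speculated tokens become one flat query of length speculative_length + 1, and position ids, slots, input lengths and block tables are stretched to match.

import torch

# Sketch only: mirrors the expansion above with toy values, batch size 1.
input_ids = torch.tensor([42])               # current token
speculative_ids = torch.tensor([[7, 8, 9]])  # 3 speculated tokens
position_ids = torch.tensor([10])            # next position to decode
slots = torch.tensor([100])                  # next KV-cache slot
input_lengths = torch.tensor([11])           # sequence length including the current token
block_tables = torch.tensor([[3, 4]])        # blocks already assigned to this sequence

speculative_length = speculative_ids.shape[1]  # 3
new_length = speculative_length + 1            # 4

# One flat query: [current, spec0, spec1, spec2]
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
# Consecutive positions; the hunk's expand/squeeze form gives the same result for batch size 1
new_position_ids = position_ids + torch.arange(new_length)
# Each query position gets its own slot and its own effective length
new_slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype)
new_input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype)
# One spare block in case the speculation spills over, then one row per query position
block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
block_tables = block_tables.expand(new_length, -1).contiguous()

print(new_input_ids.tolist())      # [42, 7, 8, 9]
print(new_position_ids.tolist())   # [10, 11, 12, 13]
print(new_slots.tolist())          # [100, 101, 102, 103]
print(new_input_lengths.tolist())  # [11, 12, 13, 14]
print(block_tables.tolist())       # [[3, 4, 5], [3, 4, 5], [3, 4, 5], [3, 4, 5]]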


@@ -793,7 +793,7 @@ class FlashCausalLM(Model):
# if next_token_logits.shape[0] == 3:
# import ipdb;ipdb.set_trace()
next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
)
@@ -835,6 +835,7 @@ class FlashCausalLM(Model):
iterator = zip(
batch.input_lengths,
batch.all_input_ids,
accepted_ids
)
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@@ -842,10 +843,11 @@ class FlashCausalLM(Model):
# It is faster if we delay this sync for the maximum amount of time
# For each member of the batch
step = 1 + speculative_length
index = 0
for i, (
input_length,
all_input_ids,
n_accepted_ids
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
@@ -859,8 +861,6 @@ class FlashCausalLM(Model):
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
# for j in range(1 + speculative_length):
# next_position_ids[i * step + j] = batch.position_ids[end_index - 1] + j
next_position_ids[i] = batch.position_ids[end_index - 1]
# Used to gather prefill logprobs
@@ -876,17 +876,29 @@ class FlashCausalLM(Model):
start_index + 1 : start_index + out_length
]
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
index += 1
cumulative_length += input_length
# if accepted_ids[0] > 1:
# import ipdb;ipdb.set_trace()
if len(accepted_ids) > 1:
raise Exception("Implement the batched behavior")
# Set values in batch
# batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
batch.input_ids = next_input_ids
for n_accepted_ids in accepted_ids:
# TODO Make this batched
batch.input_ids = next_input_ids[-1:]
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + 1
batch.input_lengths_tensor += 1
batch.slot_indices += 1
batch.position_ids = next_position_ids + n_accepted_ids
batch.input_lengths_tensor += n_accepted_ids
batch.slot_indices += n_accepted_ids
if prefill and prefill_logprobs:
# Get prefill logprobs
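
A toy sketch of the write-back pattern this hunk introduces. The commit itself still raises for more than one request; two requests are used here only to make the flat indexing visible, and every value is illustrative. After the write-back, the batch offsets above (position_ids, input_lengths_tensor, slot_indices) each advance by n_accepted_ids rather than by 1.

import torch

# Sketch: flat next_input_ids carries accepted tokens for all requests back to back;
# accepted_ids says how many belong to each request.
accepted_ids = [2, 1]                     # request 0 accepted 2 tokens, request 1 accepted 1
next_input_ids = torch.tensor([5, 6, 9])  # [req0_tok0, req0_tok1, req1_tok0]
input_lengths = [4, 7]                    # tokens already present per request
all_input_ids_tensor = torch.zeros((2, 16), dtype=torch.long)

index = 0
for i, (input_length, n_accepted_ids) in enumerate(zip(input_lengths, accepted_ids)):
    for j in range(n_accepted_ids):
        # Append the j-th accepted token right after the existing tokens of request i
        all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
        index += 1

print(all_input_ids_tensor[0, :8].tolist())   # [0, 0, 0, 0, 5, 6, 0, 0]
print(all_input_ids_tensor[1, :10].tolist())  # [0, 0, 0, 0, 0, 0, 0, 9, 0, 0]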
@@ -899,7 +911,7 @@ class FlashCausalLM(Model):
# GPU <-> CPU sync
next_token_logprobs = next_token_logprobs.tolist()
next_token_ids = batch.input_ids.tolist()
next_token_ids = next_input_ids.tolist()
# Zipped iterator
iterator = zip(
@@ -912,13 +924,15 @@ class FlashCausalLM(Model):
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
next_token_ids,
next_token_logprobs,
# next_token_ids,
# next_token_logprobs,
accepted_ids,
batch_top_token_ids,
batch_top_token_logprobs,
)
# For each member of the batch
index = 0
for i, (
request,
input_length,
@@ -929,13 +943,16 @@ class FlashCausalLM(Model):
do_sample,
seed,
top_n_tokens,
next_token_id,
next_token_logprob,
# next_token_id,
# next_token_logprob,
n_accepted_ids,
top_token_ids,
top_token_logprobs,
) in enumerate(iterator):
# Append next token to all tokens
all_input_ids.append(next_token_id)
_next_token_ids = next_token_ids[index: index+n_accepted_ids]
_next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
all_input_ids.extend(_next_token_ids)
# Generated token
next_token_text, prefix_offset, read_offset = self.decode_token(
@@ -945,11 +962,16 @@ class FlashCausalLM(Model):
)
# Evaluate stopping criteria
for next_token_id in _next_token_ids:
stop, reason = stopping_criteria(
next_token_id,
next_token_text,
)
if stop:
stopped = True
break
if not stop:
stopped = False
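
A sketch of the per-request slicing and stop check in the two hunks above, with illustrative values. The index increment is not visible in the hunks shown here, and the stop test below stands in for the real stopping_criteria call; note that the hunk that follows still packs only the first accepted id/logprob into the payload, which is the "no multiple ids/logprobs" limitation from the commit message.

# Sketch: slice the flat CPU lists per request, then stop as soon as any
# accepted token trips the stopping criteria.
accepted_ids = [3, 1]
next_token_ids = [11, 12, 2, 40]                # flat accepted tokens for the whole batch
next_token_logprobs = [-0.1, -0.3, -0.05, -0.2]
stop_token_id = 2                               # illustrative stand-in for stopping_criteria

index = 0
for n_accepted_ids in accepted_ids:
    _next_token_ids = next_token_ids[index: index + n_accepted_ids]
    _next_token_logprobs = next_token_logprobs[index: index + n_accepted_ids]
    index += n_accepted_ids

    stopped = False
    for next_token_id in _next_token_ids:
        if next_token_id == stop_token_id:
            stopped = True
            break
    print(_next_token_ids, _next_token_logprobs, "stopped" if stopped else "continues")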
@@ -1015,6 +1037,9 @@ class FlashCausalLM(Model):
else:
top_tokens = None
next_token_ids = _next_token_ids[0]
next_token_logprob = _next_token_logprobs[0]
generation = Generation(
request.id,
prefill_tokens,


@@ -225,19 +225,21 @@ class HeterogeneousNextTokenChooser:
scores = warper(input_ids, scores)
accepted_ids = []
next_ids = self.choice(scores)
if speculated_ids is not None:
validate_speculative = next_ids[1:] == speculated_ids[0]
validate_speculative = next_ids[:-1] == speculated_ids[0]
index = 1
for valid in validate_speculative.tolist():
if valid:
index += 1
print(f"Validated {index - 1}")
# print(f"Validated {index - 1}")
next_ids = next_ids[:index]
scores = scores[:index]
speculative_scores = speculative_scores[index - 1:index]
if index > 1:
import ipdb;ipdb.set_trace()
accepted_ids.append(index)
else:
accepted_ids.append(1)
logprobs = torch.log_softmax(scores, -1)
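
A runnable sketch of the acceptance rule above, on toy tensors: the ids chosen this step are compared position by position against last step's speculated ids, and the longest validated prefix plus one freshly generated token is kept. The hunk does not show a break on the first mismatch; the sketch stops there, which is the usual acceptance rule for greedy speculation.

import torch

next_ids = torch.tensor([7, 8, 5, 3])       # chosen at positions [current, spec0, spec1, spec2]
speculated_ids = torch.tensor([[7, 8, 9]])  # what the previous step speculated

validate_speculative = next_ids[:-1] == speculated_ids[0]  # [True, True, False]
index = 1  # the token following the last real token is always kept
for valid in validate_speculative.tolist():
    if valid:
        index += 1
    else:
        break

next_ids = next_ids[:index]
accepted_ids = [index]
print(next_ids.tolist(), accepted_ids)  # [7, 8, 5] [3]: two validated speculations plus the correction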
@@ -255,10 +257,12 @@ class HeterogeneousNextTokenChooser:
# for warper in self.warpers:
# speculative_scores = warper(input_ids, speculative_scores)
speculative_ids = Greedy()(speculative_scores)
# # Ignore first head, it seems to be a regular head.
# speculative_ids = speculative_ids[:, 1:]
else:
speculative_ids = None
return next_ids, next_logprobs, logprobs, speculative_ids
return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
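
A small sketch of the producer side of the new five-value return, with toy shapes; Greedy() in the hunk behaves, in effect, like an argmax over the vocabulary dimension, and every shape below is illustrative.

import torch

vocab_size = 100
speculative_length = 3
# Kept row(s) of the speculative head's scores: [batch, speculative_length, vocab]
speculative_scores = torch.randn(1, speculative_length, vocab_size)

speculative_ids = speculative_scores.argmax(dim=-1)  # candidates to verify next step, [1, 3]
accepted_ids = [1]                                   # e.g. nothing was validated this step

# The method now returns: next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
print(speculative_ids.shape, accepted_ids)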
def filter(self, indices):
if self.watermark_processor is not None: