Speculative Medusa (illegal address in paged attention).

Nicolas Patry 2023-11-28 22:23:03 +00:00
parent a2e9ccbb10
commit 8897b89606
3 changed files with 48 additions and 10 deletions

View File

@@ -301,6 +301,7 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
import ipdb;ipdb.set_trace()
paged_attention.attention(
attn_output,
query,
@@ -450,7 +451,15 @@ class FlashLlamaModel(torch.nn.Module):
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
speculative_ids: Optional[torch.Tensor]
) -> torch.Tensor:
if speculative_ids is not None:
print(speculative_ids.shape, input_ids.shape)
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
new_position_ids = (position_ids.view((1, -1)).expand(speculative_ids.shape[1] + 1, 1) + torch.arange(speculative_ids.shape[1] + 1).unsqueeze(1).to(device="cuda:0")).squeeze(0).squeeze(-1)
input_ids = new_input_ids
position_ids = new_position_ids
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
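For reference, a minimal sketch of what this new decode-time expansion computes, using hypothetical values and plain CPU tensors (the hunk itself builds the positions with an expand/arange on cuda:0, but the arithmetic is the same, assuming a batch of one sequence):

import torch

# One sequence in decode mode: the current token plus S = 2 Medusa-speculated ids.
input_ids = torch.tensor([42])               # shape (1,)
position_ids = torch.tensor([17])            # shape (1,)
speculative_ids = torch.tensor([[7, 9]])     # shape (1, S)

S = speculative_ids.shape[1]

# Run the verified token and the speculated ones through the model together.
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
# Each of the 1 + S tokens gets its own consecutive position.
new_position_ids = position_ids + torch.arange(S + 1)

print(new_input_ids)     # tensor([42,  7,  9])
print(new_position_ids)  # tensor([17, 18, 19])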
@@ -501,6 +510,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
speculative_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
@@ -511,6 +521,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
slots,
input_lengths,
max_s,
speculative_ids,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
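For context, lm_head_indices narrows the hidden states to the positions whose logits are actually needed (typically each sequence's last token in a prefill batch); a minimal sketch with hypothetical shapes:

import torch

# Prefill batch of two sequences of lengths 3 and 2, flattened to (5, hidden_size).
hidden_states = torch.randn(5, 8)
# Index of each sequence's last prefill token; only these need next-token logits.
lm_head_indices = torch.tensor([2, 4])
hidden_states = hidden_states[lm_head_indices]   # shape (2, hidden_size)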

View File

@@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
# Decoder values
input_ids: torch.Tensor
position_ids: torch.Tensor
speculative_ids: torch.Tensor
# Flash Attention values
@@ -121,6 +122,7 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
position_ids = []
speculative_ids = []
cu_seqlen_prefill = [0]
cu_seqlen_speculative = [0]
needed_blocks_slots = []
@@ -162,10 +164,11 @@
tokenized_input = tokenized_input[-r.truncate :]
# TODO remove this
# Scaffolding to speculate some ids
speculate_ids = [1, 2]
tokenized_input.extend([1, 2])
# # TODO remove this
# # Scaffolding to speculate some ids
# speculate_ids = [1, 2]
# tokenized_input.extend([1, 2])
speculate_ids = []
input_length = len(tokenized_input)
@@ -324,6 +327,7 @@
top_n_tokens_tensor=top_n_tokens_tensor,
blocks=blocks,
max_blocks=max_blocks,
speculative_ids=None,
)
@tracer.start_as_current_span("filter")
@@ -739,6 +743,7 @@ class FlashCausalLM(Model):
input_lengths=batch.input_lengths_tensor,
max_s=batch.max_seqlen,
lm_head_indices=batch.prefill_head_indices,
speculative_ids =batch.speculative_ids
)
@tracer.start_as_current_span("generate_token")
@@ -786,16 +791,17 @@
next_token_logits = out
import ipdb;ipdb.set_trace()
# if next_token_logits.shape[0] == 3:
# import ipdb;ipdb.set_trace()
next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, speculative_logits
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
)
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
)
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
if prefill:
if len(batch) > 1 and prefill_logprobs:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -803,13 +809,13 @@
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
if speculative_ids is not None:
# TODO
# length = len(batch) * speculative_ids.shape[1]
# length = len(batch) * (1 + speculative_length)
length = len(batch)
else:
length = len(batch)
# import ipdb;ipdb.set_trace()
next_position_ids = batch.position_ids.new_empty(length)
# Keep only 1 slot index, TODO make sure we recover the speculated ids slots later
batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
# We do not need cu_seqlen_prefill anymore
batch.cu_seqlen_prefill = None
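The slot filtering above keeps a single KV-cache slot index per sequence, namely the slot of its last prefill token; a small sketch with hypothetical values:

import torch

# Two sequences of prefill lengths 3 and 2; cu_seqlen_prefill marks their boundaries.
cu_seqlen_prefill = torch.tensor([0, 3, 5])
slot_indices = torch.tensor([100, 101, 102, 200, 201])   # one slot per prefill token

# Keep only the slot of each sequence's last prefill token; the slots of the
# speculated tokens are not tracked yet (hence the TODO above).
slot_indices = slot_indices[cu_seqlen_prefill[1:] - 1]    # tensor([102, 201])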
@@ -836,6 +842,7 @@
# It is faster if we delay this sync for the maximum amount of time
# For each member of the batch
step = 1 + speculative_length
for i, (
input_length,
all_input_ids,
@@ -852,6 +859,8 @@
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
# for j in range(1 + speculative_length):
# next_position_ids[i * step + j] = batch.position_ids[end_index - 1] + j
next_position_ids[i] = batch.position_ids[end_index - 1]
# Used to gather prefill logprobs
@@ -872,7 +881,9 @@
cumulative_length += input_length
# Set values in batch
# batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
batch.input_ids = next_input_ids
batch.speculative_ids = speculative_ids
batch.position_ids = next_position_ids + 1
batch.input_lengths_tensor += 1
batch.slot_indices += 1
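The commented-out line above would interleave each sequence's freshly sampled token with its speculative ids into one flat decode input; a sketch of that layout with hypothetical values:

import torch

next_input_ids = torch.tensor([11, 21])                  # one sampled token per sequence
speculative_ids = torch.tensor([[12, 13], [22, 23]])     # S = 2 speculated ids each

flat = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
# tensor([11, 12, 13, 21, 22, 23]): 1 + S input ids per sequence, flattened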
@@ -1031,6 +1042,8 @@
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
if prefill:
batch.max_seqlen += speculative_length
batch.max_seqlen = batch.max_seqlen + 1
return generations, batch

View File

@@ -215,7 +215,7 @@ class HeterogeneousNextTokenChooser:
self.dtype = dtype
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculative_scores: Optional[torch.Tensor] = None):
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
if self.watermark_processor is not None:
scores = self.watermark_processor(input_ids, scores)
if self.repetition_processor is not None:
@@ -226,6 +226,20 @@
next_ids = self.choice(scores)
if speculated_ids is not None:
validate_speculative = next_ids[1:] == speculated_ids[0]
index = 1
for valid in validate_speculative.tolist():
if valid:
index += 1
print(f"Validated {index - 1}")
next_ids = next_ids[:index]
scores = scores[:index]
speculative_scores = speculative_scores[index - 1:index]
if index > 1:
import ipdb;ipdb.set_trace()
logprobs = torch.log_softmax(scores, -1)
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
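For reference, a self-contained sketch of the acceptance rule this hunk implements for the greedy case, with hypothetical ids and an explicit stop at the first mismatch (the usual prefix-acceptance rule in speculative decoding):

import torch

next_ids = torch.tensor([5, 9, 3])        # the model's own picks at positions 0..S
speculated_ids = torch.tensor([[9, 7]])   # Medusa's guesses for positions 1..S

matches = next_ids[1:] == speculated_ids[0]   # tensor([True, False])

accepted = 1                               # the freshly sampled token is always kept
for valid in matches.tolist():
    if not valid:
        break
    accepted += 1

next_ids = next_ids[:accepted]             # tensor([5, 9]): one speculated id accepted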