Mirror of https://github.com/huggingface/text-generation-inference.git
Speculative medusa (illegal address Paged).
parent a2e9ccbb10
commit 8897b89606
@@ -301,6 +301,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
+            import ipdb;ipdb.set_trace()
             paged_attention.attention(
                 attn_output,
                 query,
@@ -450,7 +451,15 @@ class FlashLlamaModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        speculative_ids: Optional[torch.Tensor]
     ) -> torch.Tensor:
+        if speculative_ids is not None:
+            print(speculative_ids.shape, input_ids.shape)
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
+            new_position_ids = (position_ids.view((1, -1)).expand(speculative_ids.shape[1] + 1, 1) + torch.arange(speculative_ids.shape[1] + 1).unsqueeze(1).to(device="cuda:0")).squeeze(0).squeeze(-1)
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
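Not part of the commit: a minimal CPU-only sketch of what the added decode-time expansion in FlashLlamaModel.forward above is doing, assuming batch size 1 (which the squeeze(0) relies on). All tensor values below are illustrative.

import torch

# The last accepted token plus S speculated tokens are fed through the model
# in a single forward pass; position ids continue contiguously from the last
# accepted position.
input_ids = torch.tensor([42])                 # last accepted token, shape [1]
speculative_ids = torch.tensor([[7, 13]])      # S = 2 speculated ids, shape [1, S]
position_ids = torch.tensor([10])              # position of the last accepted token

s = speculative_ids.shape[1]
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
new_position_ids = position_ids + torch.arange(s + 1)

print(new_input_ids)      # tensor([42,  7, 13])
print(new_position_ids)   # tensor([10, 11, 12])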
@@ -501,6 +510,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
+        speculative_ids: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
@@ -511,6 +521,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
+            speculative_ids,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
@@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
+    speculative_ids: torch.Tensor

     # Flash Attention values

@@ -121,6 +122,7 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]

         position_ids = []
+        speculative_ids = []
         cu_seqlen_prefill = [0]
         cu_seqlen_speculative = [0]
         needed_blocks_slots = []
@@ -162,10 +164,11 @@ class FlashCausalLMBatch(Batch):

             tokenized_input = tokenized_input[-r.truncate :]

-            # TODO remove this
-            # Scaffolding to speculate some ids
-            speculate_ids = [1, 2]
-            tokenized_input.extend([1, 2])
+            # # TODO remove this
+            # # Scaffolding to speculate some ids
+            # speculate_ids = [1, 2]
+            # tokenized_input.extend([1, 2])
+            speculate_ids = []


             input_length = len(tokenized_input)
@@ -324,6 +327,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=None,
         )

     @tracer.start_as_current_span("filter")
@@ -739,6 +743,7 @@ class FlashCausalLM(Model):
             input_lengths=batch.input_lengths_tensor,
             max_s=batch.max_seqlen,
             lm_head_indices=batch.prefill_head_indices,
+            speculative_ids =batch.speculative_ids
         )

     @tracer.start_as_current_span("generate_token")
@@ -786,16 +791,17 @@ class FlashCausalLM(Model):
             next_token_logits = out


-        import ipdb;ipdb.set_trace()
+        # if next_token_logits.shape[0] == 3:
+        # import ipdb;ipdb.set_trace()
         next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, speculative_logits
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )

+        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
@@ -803,13 +809,13 @@ class FlashCausalLM(Model):
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

             if speculative_ids is not None:
-                # TODO
-                # length = len(batch) * speculative_ids.shape[1]
+                # length = len(batch) * (1 + speculative_length)
                 length = len(batch)
             else:
                 length = len(batch)
             # import ipdb;ipdb.set_trace()
             next_position_ids = batch.position_ids.new_empty(length)
+            # Keep only 1 slot index, TODO make sure we recover the speculated ids slots later
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
@@ -836,6 +842,7 @@ class FlashCausalLM(Model):
         # It is faster if we delay this sync for the maximum amount of time

         # For each member of the batch
+        step = 1 + speculative_length
         for i, (
             input_length,
             all_input_ids,
@@ -852,6 +859,8 @@ class FlashCausalLM(Model):

                 # Initialize position_ids
                 # In decode, we do not need this as we can just increment position ids
+                # for j in range(1 + speculative_length):
+                # next_position_ids[i * step + j] = batch.position_ids[end_index - 1] + j
                 next_position_ids[i] = batch.position_ids[end_index - 1]

                 # Used to gather prefill logprobs
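Not part of the commit: a small sketch of the row layout the commented-out loop above implies. With step = 1 + speculative_length, request i would own rows i * step through i * step + speculative_length of the flattened outputs, and its position ids would continue from its last prompt position (the later +1 still applies). All values below are illustrative.

import torch

# Two requests, two speculated ids each: request 0 owns rows 0..2, request 1 rows 3..5.
speculative_length = 2
step = 1 + speculative_length
last_positions = torch.tensor([9, 14])  # last prompt position per request (made up)
next_position_ids = torch.empty(2 * step, dtype=torch.long)
for i in range(2):
    for j in range(step):
        next_position_ids[i * step + j] = last_positions[i] + j
print(next_position_ids)  # tensor([ 9, 10, 11, 14, 15, 16])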
@@ -872,7 +881,9 @@ class FlashCausalLM(Model):
             cumulative_length += input_length

         # Set values in batch
+        # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
         batch.input_ids = next_input_ids
+        batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + 1
         batch.input_lengths_tensor += 1
         batch.slot_indices += 1
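Not part of the commit: a sketch of the flattening the commented-out line above hints at. For a batch of B requests with S speculated ids each, the freshly sampled token of each request is interleaved with its speculated ids so that the next forward pass receives B * (1 + S) input ids. Values are illustrative.

import torch

next_input_ids = torch.tensor([11, 22])           # one sampled token per request, B = 2
speculative_ids = torch.tensor([[3, 4], [5, 6]])  # S = 2 speculated ids per request
flat = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
print(flat)  # tensor([11,  3,  4, 22,  5,  6])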
@@ -1031,6 +1042,8 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
+        if prefill:
+            batch.max_seqlen += speculative_length
         batch.max_seqlen = batch.max_seqlen + 1

         return generations, batch
@@ -215,7 +215,7 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculative_scores: Optional[torch.Tensor] = None):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
@@ -226,6 +226,20 @@ class HeterogeneousNextTokenChooser:


         next_ids = self.choice(scores)
+        if speculated_ids is not None:
+            validate_speculative = next_ids[1:] == speculated_ids[0]
+            index = 1
+            for valid in validate_speculative.tolist():
+                if valid:
+                    index += 1
+            print(f"Validated {index - 1}")
+            next_ids = next_ids[:index]
+            scores = scores[:index]
+            speculative_scores = speculative_scores[index - 1:index]
+            if index > 1:
+                import ipdb;ipdb.set_trace()
+
+
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

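Not part of the commit: a hedged sketch of the acceptance rule the validation block above appears to be building toward. The speculated ids are compared against the ids the chooser actually picked; only the matching prefix is kept, plus one corrected token. Note the committed loop counts every match without breaking, so matches occurring after the first mismatch would also be counted; the sketch below stops at the first mismatch instead. The helper name and values are illustrative.

import torch

def accepted_length(next_ids: torch.Tensor, speculated_ids: torch.Tensor) -> int:
    # next_ids: 1 + S ids chosen from the logits at the sampled and speculated positions
    # speculated_ids: the S ids that were speculated on the previous step
    matches = next_ids[1:] == speculated_ids
    index = 1
    for valid in matches.tolist():
        if not valid:
            break
        index += 1
    return index

next_ids = torch.tensor([5, 9, 4, 2])
speculated_ids = torch.tensor([9, 7, 2])
print(accepted_length(next_ids, speculated_ids))  # 2: the 9 matched, the 7 did not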