Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-26 12:32:10 +00:00
Speculative medusa (illegal address Paged).
This commit is contained in:
parent a2e9ccbb10
commit 8897b89606
@@ -301,6 +301,7 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
+            import ipdb;ipdb.set_trace()
             paged_attention.attention(
                 attn_output,
                 query,
@@ -450,7 +451,15 @@ class FlashLlamaModel(torch.nn.Module):
         slots: torch.Tensor,
         input_lengths: torch.Tensor,
         max_s: int,
+        speculative_ids: Optional[torch.Tensor]
     ) -> torch.Tensor:
+        if speculative_ids is not None:
+            print(speculative_ids.shape, input_ids.shape)
+            new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
+            new_position_ids = (position_ids.view((1, -1)).expand(speculative_ids.shape[1] + 1, 1) + torch.arange(speculative_ids.shape[1] + 1).unsqueeze(1).to(device="cuda:0")).squeeze(0).squeeze(-1)
+            input_ids = new_input_ids
+            position_ids = new_position_ids
+
         hidden_states = self.embed_tokens(input_ids)

         # Get rotary cos and sin for this forward
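For orientation, a minimal standalone sketch of what the block above does during a speculative decode step: the committed token is concatenated with the speculated candidates, and position ids are extended so each candidate is scored at its own offset. All tensor values below are made up and everything runs on CPU purely for illustration.

import torch

# Hypothetical single-sequence decode step (batch size 1); values are illustrative.
input_ids = torch.tensor([42])               # the committed token
position_ids = torch.tensor([7])             # its position in the sequence
speculative_ids = torch.tensor([[11, 12]])   # two speculated candidate tokens

# Same concatenation as in the diff: [committed, candidate_1, candidate_2]
new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)

# Consecutive positions 7, 8, 9 so every candidate attends with the right offset
n = speculative_ids.shape[1] + 1
new_position_ids = (position_ids.view((1, -1)).expand(n, 1) + torch.arange(n).unsqueeze(1)).squeeze(-1)

print(new_input_ids)     # tensor([42, 11, 12])
print(new_position_ids)  # tensor([7, 8, 9])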
@@ -501,6 +510,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
+        speculative_ids: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
@@ -511,6 +521,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             slots,
             input_lengths,
             max_s,
+            speculative_ids,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
@@ -41,6 +41,7 @@ class FlashCausalLMBatch(Batch):
     # Decoder values
     input_ids: torch.Tensor
     position_ids: torch.Tensor
+    speculative_ids: torch.Tensor

     # Flash Attention values

@@ -121,6 +122,7 @@ class FlashCausalLMBatch(Batch):
         )["input_ids"]

         position_ids = []
+        speculative_ids = []
         cu_seqlen_prefill = [0]
         cu_seqlen_speculative = [0]
         needed_blocks_slots = []
@@ -162,10 +164,11 @@ class FlashCausalLMBatch(Batch):

             tokenized_input = tokenized_input[-r.truncate :]

-            # TODO remove this
-            # Scaffolding to speculate some ids
-            speculate_ids = [1, 2]
-            tokenized_input.extend([1, 2])
+            # # TODO remove this
+            # # Scaffolding to speculate some ids
+            # speculate_ids = [1, 2]
+            # tokenized_input.extend([1, 2])
+            speculate_ids = []


             input_length = len(tokenized_input)
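The scaffolding that is commented out above forced two fixed candidate ids onto every prompt so the speculative path could be exercised before a real Medusa head produced candidates. A hedged sketch of that idea, with made-up token ids:

# Illustrative only: append dummy speculated ids to the tokenized prompt so the
# speculative code path can be exercised end to end without a Medusa head.
tokenized_input = [101, 2023, 2003, 1037, 3231]   # made-up prompt token ids
speculate_ids = [1, 2]                            # dummy candidates, as in the scaffolding
tokenized_input.extend(speculate_ids)

input_length = len(tokenized_input)               # now also counts the dummy candidates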
@@ -324,6 +327,7 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens_tensor=top_n_tokens_tensor,
             blocks=blocks,
             max_blocks=max_blocks,
+            speculative_ids=None,
         )

     @tracer.start_as_current_span("filter")
@@ -739,6 +743,7 @@ class FlashCausalLM(Model):
             input_lengths=batch.input_lengths_tensor,
             max_s=batch.max_seqlen,
             lm_head_indices=batch.prefill_head_indices,
+            speculative_ids =batch.speculative_ids
         )

     @tracer.start_as_current_span("generate_token")
@@ -786,16 +791,17 @@ class FlashCausalLM(Model):
             next_token_logits = out



+        import ipdb;ipdb.set_trace()
         # if next_token_logits.shape[0] == 3:
         # import ipdb;ipdb.set_trace()
         next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, speculative_logits
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
         )

+        speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]
         if prefill:
             if len(batch) > 1 and prefill_logprobs:
                 # We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
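As a rough sketch of the shapes involved (illustrative names and sizes, not the real batch object): during a speculative decode step each sequence contributes one row of logits for the committed token plus one per candidate, which is where speculative_length enters the bookkeeping.

import torch

speculative_ids = torch.tensor([[11, 12], [21, 22]])  # hypothetical candidates for 2 sequences
speculative_length = 0 if speculative_ids is None else speculative_ids.shape[1]

# One row per committed token plus one per candidate, flattened across the batch.
rows_per_sequence = 1 + speculative_length
batch_size, vocab_size = speculative_ids.shape[0], 32
next_token_logits = torch.randn(batch_size * rows_per_sequence, vocab_size)
print(next_token_logits.shape)  # torch.Size([6, 32])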
@@ -803,13 +809,13 @@ class FlashCausalLM(Model):
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

             if speculative_ids is not None:
                 # TODO
-                # length = len(batch) * speculative_ids.shape[1]
+                # length = len(batch) * (1 + speculative_length)
                 length = len(batch)
             else:
                 length = len(batch)
             # import ipdb;ipdb.set_trace()
             next_position_ids = batch.position_ids.new_empty(length)
             # Keep only 1 slot index, TODO make sure we recover the speculated ids slots later
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
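A small self-contained sketch of the slot bookkeeping just above (made-up values): cu_seqlen_prefill holds cumulative prefill lengths, so cu_seqlen_prefill[1:] - 1 indexes the last prefill token of each sequence, which is the only slot the next decode step needs to keep.

import torch

# Two prefilled sequences of length 3 and 5, flattened into one dimension.
cu_seqlen_prefill = torch.tensor([0, 3, 8])
slot_indices = torch.arange(100, 108)                # hypothetical KV-cache slot per token

last_token_positions = cu_seqlen_prefill[1:] - 1     # tensor([2, 7])
slot_indices = slot_indices[last_token_positions]
print(slot_indices)                                  # tensor([102, 107])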
@@ -836,6 +842,7 @@ class FlashCausalLM(Model):
         # It is faster if we delay this sync for the maximum amount of time

         # For each member of the batch
+        step = 1 + speculative_length
         for i, (
             input_length,
             all_input_ids,
@@ -852,6 +859,8 @@ class FlashCausalLM(Model):

                 # Initialize position_ids
                 # In decode, we do not need this as we can just increment position ids
+                # for j in range(1 + speculative_length):
+                # next_position_ids[i * step + j] = batch.position_ids[end_index - 1] + j
                 next_position_ids[i] = batch.position_ids[end_index - 1]

                 # Used to gather prefill logprobs
@@ -872,7 +881,9 @@ class FlashCausalLM(Model):
             cumulative_length += input_length

         # Set values in batch
+        # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
         batch.input_ids = next_input_ids
+        batch.speculative_ids = speculative_ids
         batch.position_ids = next_position_ids + 1
         batch.input_lengths_tensor += 1
         batch.slot_indices += 1
@@ -1031,6 +1042,8 @@ class FlashCausalLM(Model):
         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None
+        if prefill:
+            batch.max_seqlen += speculative_length
         batch.max_seqlen = batch.max_seqlen + 1

         return generations, batch
@@ -215,7 +215,7 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculative_scores: Optional[torch.Tensor] = None):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculated_ids: Optional[torch.Tensor] = None, speculative_scores: Optional[torch.Tensor] = None):
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
|
||||
|
||||
|
||||
next_ids = self.choice(scores)
|
||||
if speculated_ids is not None:
|
||||
validate_speculative = next_ids[1:] == speculated_ids[0]
|
||||
index = 1
|
||||
for valid in validate_speculative.tolist():
|
||||
if valid:
|
||||
index += 1
|
||||
print(f"Validated {index - 1}")
|
||||
next_ids = next_ids[:index]
|
||||
scores = scores[:index]
|
||||
speculative_scores = speculative_scores[index - 1:index]
|
||||
if index > 1:
|
||||
import ipdb;ipdb.set_trace()
|
||||
|
||||
|
||||
logprobs = torch.log_softmax(scores, -1)
|
||||
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
|
||||
|
||||
|
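For reference, a standalone sketch of the validation introduced above, with illustrative tensors only and with the usual stop-at-first-mismatch rule added (the loop in the diff does not break yet): freshly sampled ids are compared position by position with the previously speculated ids, and only the matching prefix plus one new token is kept.

import torch

# Ids sampled after re-scoring the committed token and the candidates, versus the
# ids that were speculated at the previous step. All values are hypothetical.
next_ids = torch.tensor([5, 11, 12, 99])        # 1 committed token + 3 candidate positions
speculated_ids = torch.tensor([[11, 12, 13]])   # shape (1, speculative_length)

# Accept candidates only while they agree with what the model actually sampled.
validate_speculative = next_ids[1:] == speculated_ids[0]
index = 1
for valid in validate_speculative.tolist():
    if not valid:
        break
    index += 1

accepted_ids = next_ids[:index]
print(accepted_ids)  # tensor([ 5, 11, 12]) -> three tokens committed in a single step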