Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-25 12:02:08 +00:00)
Speedup 2x.

Known issues:
- Still wrong when batched.
- Incorrect returned payload (no multiple ids/logprobs).
Commit 866af9b9fd (parent 8897b89606)
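
The diff wires speculative (draft-and-verify) decoding through the LLaMA forward pass and the next-token chooser: draft tokens are run through the model together with the last verified token, and only the drafts that match the model's own choices are kept. As a rough, self-contained sketch of that acceptance rule (function and variable names here are invented for illustration, not the repository's API):

```python
import torch

def accept_speculated(next_ids: torch.Tensor, speculated_ids: torch.Tensor) -> torch.Tensor:
    """Keep the longest prefix of draft tokens matching the model's own picks,
    plus one extra token produced by the model itself.

    next_ids:       model picks at positions [last_verified, d_1, ..., d_S]  (length S + 1)
    speculated_ids: the S draft tokens that were fed in                      (length S)
    """
    matches = next_ids[:-1] == speculated_ids
    accepted = 1
    for ok in matches.tolist():
        if not ok:
            break
        accepted += 1
    return next_ids[:accepted]

# Two of three drafts match, so three tokens come out of a single forward pass.
print(accept_speculated(torch.tensor([5, 7, 9, 4]), torch.tensor([5, 7, 2])))
# tensor([5, 7, 9])
```

The hunks below thread the resulting per-request accepted count (accepted_ids) through the batch bookkeeping.
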
@@ -301,7 +301,6 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            import ipdb;ipdb.set_trace()
             paged_attention.attention(
                 attn_output,
                 query,
@@ -454,9 +453,20 @@ class FlashLlamaModel(torch.nn.Module):
         speculative_ids: Optional[torch.Tensor]
     ) -> torch.Tensor:
         if speculative_ids is not None:
-            print(speculative_ids.shape, input_ids.shape)
+            speculative_length = speculative_ids.shape[1]
+            new_length = speculative_length + 1
             new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
-            new_position_ids = (position_ids.view((1, -1)).expand(speculative_ids.shape[1] + 1, 1) + torch.arange(speculative_ids.shape[1] + 1).unsqueeze(1).to(device="cuda:0")).squeeze(0).squeeze(-1)
+            new_position_ids = (position_ids.view((1, -1)).expand(new_length, 1) + torch.arange(new_length).unsqueeze(1).to(device=position_ids.device)).squeeze(0).squeeze(-1)
+
+            # Add an extra block just in case
+            block_tables = torch.cat([block_tables, block_tables[:, -1:] + 1], dim=1)
+            # Copy the block tables for all members
+            block_tables = block_tables.expand(new_length, -1).contiguous()
+            slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype).to(device=slots.device)
+            input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype).to(device=input_lengths.device)
+            max_s = max_s + speculative_length
+
             input_ids = new_input_ids
             position_ids = new_position_ids
 
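
The FlashLlamaModel hunk above turns a decode step into a short verification pass over speculative_length + 1 positions: input ids, position ids, slots and input lengths are all expanded so each draft token gets its own KV-cache slot. A toy version of the same arithmetic on standalone tensors (values invented; the real code works on the batch's CUDA tensors and paged-attention block tables):

```python
import torch

# Hypothetical single-sequence decode state (not the real batch objects).
speculative_ids = torch.tensor([[11, 12]])   # S = 2 draft tokens
input_ids = torch.tensor([10])               # last verified token
position_ids = torch.tensor([7])
slots = torch.tensor([42])
input_lengths = torch.tensor([8])

new_length = speculative_ids.shape[1] + 1    # verify S drafts + 1 bonus position

new_input_ids = torch.cat([input_ids.unsqueeze(-1), speculative_ids], dim=1).squeeze(0)
new_position_ids = position_ids + torch.arange(new_length)
new_slots = slots.expand(new_length) + torch.arange(new_length, dtype=slots.dtype)
new_input_lengths = input_lengths.expand(new_length) + torch.arange(new_length, dtype=input_lengths.dtype)

print(new_input_ids)       # tensor([10, 11, 12])
print(new_position_ids)    # tensor([7, 8, 9])
print(new_slots)           # tensor([42, 43, 44])
print(new_input_lengths)   # tensor([8, 9, 10])
```
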
@@ -793,7 +793,7 @@ class FlashCausalLM(Model):
 
         # if next_token_logits.shape[0] == 3:
         # import ipdb;ipdb.set_trace()
-        next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
+        next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, batch.speculative_ids, speculative_logits
         )
 
@@ -835,6 +835,7 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.input_lengths,
             batch.all_input_ids,
+            accepted_ids
         )
 
         # We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@@ -842,10 +843,11 @@ class FlashCausalLM(Model):
         # It is faster if we delay this sync for the maximum amount of time
 
         # For each member of the batch
-        step = 1 + speculative_length
+        index = 0
         for i, (
             input_length,
             all_input_ids,
+            n_accepted_ids
         ) in enumerate(iterator):
             # Indexing metadata
             start_index = cumulative_length
@@ -859,8 +861,6 @@ class FlashCausalLM(Model):
 
             # Initialize position_ids
             # In decode, we do not need this as we can just increment position ids
-            # for j in range(1 + speculative_length):
-            #     next_position_ids[i * step + j] = batch.position_ids[end_index - 1] + j
             next_position_ids[i] = batch.position_ids[end_index - 1]
 
             # Used to gather prefill logprobs
@@ -876,17 +876,29 @@ class FlashCausalLM(Model):
                     start_index + 1 : start_index + out_length
                 ]
 
-            batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
+            for j in range(n_accepted_ids):
+                batch.all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
+                index += 1
 
             cumulative_length += input_length
 
+        # if accepted_ids[0] > 1:
+        # import ipdb;ipdb.set_trace()
+
+        if len(accepted_ids) > 1:
+            raise Exception("Implement the batched behavior")
+
         # Set values in batch
         # batch.input_ids = torch.cat([next_input_ids.unsqueeze(-1), speculative_ids], dim=1).view(-1)
-        batch.input_ids = next_input_ids
-        batch.speculative_ids = speculative_ids
-        batch.position_ids = next_position_ids + 1
-        batch.input_lengths_tensor += 1
-        batch.slot_indices += 1
+        for n_accepted_ids in accepted_ids:
+            # TODO Make this batched
+            batch.input_ids = next_input_ids[-1:]
+            batch.speculative_ids = speculative_ids
+            batch.position_ids = next_position_ids + n_accepted_ids
+            batch.input_lengths_tensor += n_accepted_ids
+            batch.slot_indices += n_accepted_ids
 
         if prefill and prefill_logprobs:
             # Get prefill logprobs
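
After verification, each request advances by however many tokens it accepted rather than by exactly one: the accepted tokens are written into all_input_ids_tensor and the position/length/slot cursors move by n_accepted_ids. A minimal sketch with invented state for a two-request batch (the commit itself still raises on batches larger than one):

```python
import torch

# Hedged sketch of the bookkeeping above, with made-up state for a 2-request batch.
accepted_ids = [2, 1]                       # tokens accepted per request this step
next_input_ids = torch.tensor([5, 7, 9])    # flat: 2 tokens for request 0, 1 for request 1
input_lengths = [4, 6]
all_input_ids_tensor = torch.zeros((2, 16), dtype=torch.long)
position_ids = torch.tensor([3, 5])
slot_indices = torch.tensor([10, 30])

index = 0
for i, (input_length, n_accepted_ids) in enumerate(zip(input_lengths, accepted_ids)):
    for j in range(n_accepted_ids):
        # Append each accepted token right after the request's current length.
        all_input_ids_tensor[i, input_length + j] = next_input_ids[index]
        index += 1

# Cursors move forward by however many tokens were accepted; a batched version
# would need a per-request tensor of offsets, which the commit defers.
position_ids = position_ids + torch.tensor(accepted_ids)
slot_indices = slot_indices + torch.tensor(accepted_ids)

print(all_input_ids_tensor[0, 4:7])   # tensor([5, 7, 0])
print(all_input_ids_tensor[1, 6:8])   # tensor([9, 0])
print(position_ids, slot_indices)     # tensor([5, 6]) tensor([12, 31])
```
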
@@ -899,7 +911,7 @@ class FlashCausalLM(Model):
 
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
-        next_token_ids = batch.input_ids.tolist()
+        next_token_ids = next_input_ids.tolist()
 
         # Zipped iterator
         iterator = zip(
@@ -912,13 +924,15 @@ class FlashCausalLM(Model):
             batch.next_token_chooser.do_sample,
             batch.next_token_chooser.seeds,
             batch.top_n_tokens,
-            next_token_ids,
-            next_token_logprobs,
+            # next_token_ids,
+            # next_token_logprobs,
+            accepted_ids,
             batch_top_token_ids,
             batch_top_token_logprobs,
         )
 
         # For each member of the batch
+        index = 0
         for i, (
             request,
             input_length,
@@ -929,13 +943,16 @@ class FlashCausalLM(Model):
             do_sample,
             seed,
             top_n_tokens,
-            next_token_id,
-            next_token_logprob,
+            # next_token_id,
+            # next_token_logprob,
+            n_accepted_ids,
             top_token_ids,
             top_token_logprobs,
         ) in enumerate(iterator):
             # Append next token to all tokens
-            all_input_ids.append(next_token_id)
+            _next_token_ids = next_token_ids[index: index+n_accepted_ids]
+            _next_token_logprobs = next_token_logprobs[index: index+n_accepted_ids]
+            all_input_ids.extend(_next_token_ids)
 
             # Generated token
             next_token_text, prefix_offset, read_offset = self.decode_token(
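
In the decode loop above, the flat next_token_ids / next_token_logprobs lists are sliced back into per-request chunks with a running index. A standalone illustration (values invented):

```python
# Flat results for a step where request 0 accepted 3 tokens and request 1 accepted 1.
next_token_ids = [11, 12, 13, 99]
next_token_logprobs = [-0.1, -0.2, -0.3, -1.5]
accepted_ids = [3, 1]

index = 0
per_request = []
for n_accepted_ids in accepted_ids:
    _next_token_ids = next_token_ids[index: index + n_accepted_ids]
    _next_token_logprobs = next_token_logprobs[index: index + n_accepted_ids]
    per_request.append((_next_token_ids, _next_token_logprobs))
    index += n_accepted_ids

print(per_request)
# [([11, 12, 13], [-0.1, -0.2, -0.3]), ([99], [-1.5])]
```
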
@@ -945,13 +962,18 @@ class FlashCausalLM(Model):
             )
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(
-                next_token_id,
-                next_token_text,
-            )
 
-            if not stop:
-                stopped = False
+            for next_token_id in _next_token_ids:
+                stop, reason = stopping_criteria(
+                    next_token_id,
+                    next_token_text,
+                )
+
+                if stop:
+                    stopped = True
+                    break
+            if not stop:
+                stopped = False
 
             # Shard generations
             # All generations will be appended in the rust sharded client
@@ -1015,6 +1037,9 @@ class FlashCausalLM(Model):
             else:
                 top_tokens = None
 
+            next_token_ids = _next_token_ids[0]
+            next_token_logprob = _next_token_logprobs[0]
+
             generation = Generation(
                 request.id,
                 prefill_tokens,
@@ -225,19 +225,21 @@ class HeterogeneousNextTokenChooser:
             scores = warper(input_ids, scores)
 
 
+        accepted_ids = []
         next_ids = self.choice(scores)
         if speculated_ids is not None:
-            validate_speculative = next_ids[1:] == speculated_ids[0]
+            validate_speculative = next_ids[:-1] == speculated_ids[0]
             index = 1
             for valid in validate_speculative.tolist():
                 if valid:
                     index += 1
-            print(f"Validated {index - 1}")
+            # print(f"Validated {index - 1}")
             next_ids = next_ids[:index]
             scores = scores[:index]
             speculative_scores = speculative_scores[index - 1:index]
-            if index > 1:
-                import ipdb;ipdb.set_trace()
+            accepted_ids.append(index)
+        else:
+            accepted_ids.append(1)
 
 
         logprobs = torch.log_softmax(scores, -1)
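
The switch from next_ids[1:] to next_ids[:-1] lines the verifier's picks up with the drafts: next_ids[j] is the pick the model makes right after the j-th fed token, so the draft fed at position j + 1 (speculated_ids[0][j]) is valid only if it equals next_ids[j]. A small worked example with invented values:

```python
import torch

# Positions fed this step: [last_verified, d0, d1, d2]; the model's greedy picks
# at each of those positions (invented values):
next_ids = torch.tensor([4, 8, 15, 16])
speculated_ids = torch.tensor([[4, 8, 99]])   # the three draft tokens d0, d1, d2

# d0 is valid iff it matches the pick made after last_verified (next_ids[0]), etc.
validate_speculative = next_ids[:-1] == speculated_ids[0]
print(validate_speculative)   # tensor([ True,  True, False])

index = 1
for valid in validate_speculative.tolist():
    if valid:
        index += 1
print(next_ids[:index])       # tensor([ 4,  8, 15]) -> two drafts kept + one fresh token
```
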
@@ -255,10 +257,12 @@ class HeterogeneousNextTokenChooser:
             # for warper in self.warpers:
             #     speculative_scores = warper(input_ids, speculative_scores)
             speculative_ids = Greedy()(speculative_scores)
+            # # Ignore first head, it seems to be a regular head.
+            # speculative_ids = speculative_ids[:, 1:]
         else:
             speculative_ids = None
 
-        return next_ids, next_logprobs, logprobs, speculative_ids
+        return next_ids, next_logprobs, logprobs, accepted_ids, speculative_ids
 
     def filter(self, indices):
         if self.watermark_processor is not None: