Nicolas Patry 2023-09-18 13:37:59 +00:00
parent 94a0bf1bbc
commit a2e9ccbb10
4 changed files with 74 additions and 7 deletions

server/text_generation_server/models/flash_causal_lm.py

@@ -46,6 +46,7 @@ class FlashCausalLMBatch(Batch):
     # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
+    cu_seqlen_speculative: Optional[torch.Tensor]

     # Paged Attention values
@@ -121,6 +122,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = []
         cu_seqlen_prefill = [0]
+        cu_seqlen_speculative = [0]
         needed_blocks_slots = []
         start_slots = []
         slot_indices = []
@ -160,9 +162,17 @@ class FlashCausalLMBatch(Batch):
tokenized_input = tokenized_input[-r.truncate :] tokenized_input = tokenized_input[-r.truncate :]
# TODO remove this
# Scaffolding to speculate some ids
speculate_ids = [1, 2]
tokenized_input.extend([1, 2])
input_length = len(tokenized_input) input_length = len(tokenized_input)
input_lengths.append(input_length) input_lengths.append(input_length)
prefix_offsets.append(input_length - 5) prefix_offsets.append(input_length - 5)
read_offsets.append(input_length) read_offsets.append(input_length)
@ -174,6 +184,7 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs # Add cumulative lengths of all previous inputs
cu_seqlen_prefill.append(cumulative_length + input_length) cu_seqlen_prefill.append(cumulative_length + input_length)
cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
next_token_chooser_parameters.append(r.parameters) next_token_chooser_parameters.append(r.parameters)
@ -255,6 +266,9 @@ class FlashCausalLMBatch(Batch):
cu_seqlen_prefill = torch.tensor( cu_seqlen_prefill = torch.tensor(
cu_seqlen_prefill, device=device, dtype=torch.int32 cu_seqlen_prefill, device=device, dtype=torch.int32
) )
cu_seqlen_speculative = torch.tensor(
cu_seqlen_speculative, device=device, dtype=torch.int32
)
position_ids = position_ids.to(device) position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device) slot_indices = slot_indices.to(device)
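
For intuition, here is a minimal sketch (not part of the commit) of how the two cumulative arrays relate, using the diff's own formulas with made-up input lengths for a batch of two requests; remember the scaffolding above appends two hard-coded speculative ids to each input:

    input_lengths = [5, 7]  # each already includes the 2 speculated ids
    speculate_ids = [1, 2]

    cu_seqlen_prefill = [0]
    cu_seqlen_speculative = [0]
    cumulative_length = 0
    for input_length in input_lengths:
        cu_seqlen_prefill.append(cumulative_length + input_length)
        cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
        cumulative_length += input_length

    print(cu_seqlen_prefill)      # [0, 5, 12] -> boundaries including speculated ids
    print(cu_seqlen_speculative)  # [0, 3, 10] -> boundaries of the real prompt tokens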
@@ -287,6 +301,7 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            cu_seqlen_speculative=cu_seqlen_speculative,
             start_slots=start_slots,
             slot_indices=slot_indices,
             needed_blocks_slots=needed_blocks_slots,
@@ -752,15 +767,29 @@ class FlashCausalLM(Model):
             del batch
             raise e

+        # A Medusa lm_head returns an Output dataclass; plain models still
+        # return a bare logits tensor, so unwrap defensively.
+        try:
+            out, speculative_logits = out.logits, out.speculative_logits
+        except AttributeError:
+            speculative_logits = None
+
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
+            if speculative_logits is not None:
+                speculative_logits = (
+                    speculative_logits[batch.prefill_next_token_indices]
+                    if prefill_logprobs
+                    else speculative_logits
+                )
         else:
             next_token_logits = out

-        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
+        next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, speculative_logits
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
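
A sketch of the prefill gather (shapes are assumptions, not from the diff): during prefill the model output is flattened across the batch, so the last-position logits of each sequence are selected with per-sequence indices, and the same gather now applies unchanged to the extra head dimension of the speculative logits:

    import torch

    vocab, n_heads = 32, 2
    # two sequences of lengths 3 and 4, flattened to 7 positions
    out = torch.randn(7, vocab)
    speculative_logits = torch.randn(7, n_heads, vocab)
    prefill_next_token_indices = torch.tensor([2, 6])  # last position of each sequence

    next_token_logits = out[prefill_next_token_indices]                # (2, vocab)
    next_speculative = speculative_logits[prefill_next_token_indices]  # (2, n_heads, vocab)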
@@ -773,12 +802,20 @@ class FlashCausalLM(Model):
             # When batch == 1, we will just use the batch.input_ids values directly
             prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            next_position_ids = batch.position_ids.new_empty(len(batch))
+            if speculative_ids is not None:
+                # TODO: account for the speculated tokens, i.e.
+                # length = len(batch) * speculative_ids.shape[1]
+                length = len(batch)
+            else:
+                length = len(batch)
+            next_position_ids = batch.position_ids.new_empty(length)
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
         else:
             prefill_logprobs = None
             next_position_ids = batch.position_ids

         # Cumulative length
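
The TODO above marks where speculation will change the position bookkeeping; as a hedged illustration of the commented-out formula (the ids and shapes here are made up, and the commit itself leaves the final multiplier open):

    import torch

    speculative_ids = torch.tensor([[11, 12], [21, 22]])  # (len(batch)=2, spec_length=2)

    # the diff's commented-out intent: one position slot per speculated token
    length = speculative_ids.shape[0] * speculative_ids.shape[1]  # 4
    next_position_ids = torch.empty(length, dtype=torch.int64)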

server/text_generation_server/models/flash_llama.py

@@ -77,7 +77,8 @@ class FlashLlama(FlashCausalLM):
             medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
             medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
             weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
-            model.lm_head = MedusaModel(config, weights)
+            lm_head = model.lm_head
+            model.lm_head = MedusaModel(config, weights, lm_head)

         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
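
The key change here: the base model's original lm_head is captured before being replaced, so MedusaModel can keep producing the regular logits alongside the speculative ones. A minimal sketch of the wrapping pattern with toy modules (these stand-ins are not the real classes):

    import torch

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lm_head = torch.nn.Linear(8, 100)  # stand-in for the real head

    class Wrapper(torch.nn.Module):
        def __init__(self, lm_head):
            super().__init__()
            self.lm_head = lm_head  # original head kept, not discarded

        def forward(self, x):
            return self.lm_head(x)

    model = TinyModel()
    lm_head = model.lm_head            # capture before replacing
    model.lm_head = Wrapper(lm_head)   # same pattern as the diff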

server/text_generation_server/utils/medusa.py

@@ -1,6 +1,12 @@
 import torch
+from dataclasses import dataclass
 from text_generation_server.utils.layers import TensorParallelHead, FastLinear

+
+@dataclass
+class Output:
+    logits: torch.FloatTensor = None
+    speculative_logits: torch.FloatTensor = None
+

 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix, weights):
@@ -16,12 +22,19 @@ class MedusaModel(torch.nn.Module):
     def __init__(
         self,
         config,
-        weights
+        weights,
+        lm_head
     ):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
         )
+        self.lm_head = lm_head
+
+    def forward(self, x):
+        logits = self.lm_head(x)
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return Output(logits=logits, speculative_logits=speculative_logits)


 class MedusaHead(torch.nn.Module):
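
Shape-wise, a sketch under assumed sizes (not from the commit): for a flattened batch of s positions and hidden size h, the captured lm_head maps (s, h) -> (s, vocab), while stacking the n Medusa heads on dim=1 yields (s, n, vocab), matching the Output fields above:

    import torch

    s, h, vocab, n_heads = 4, 8, 100, 3
    x = torch.randn(s, h)

    lm_head = torch.nn.Linear(h, vocab)
    heads = torch.nn.ModuleList([torch.nn.Linear(h, vocab) for _ in range(n_heads)])

    logits = lm_head(x)                                                   # (4, 100)
    speculative_logits = torch.stack([head(x) for head in heads], dim=1)  # (4, 3, 100)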

server/text_generation_server/utils/tokens.py

@@ -215,7 +215,7 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculative_scores: Optional[torch.Tensor] = None):
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
@@ -224,11 +224,27 @@ class HeterogeneousNextTokenChooser:
         for warper in self.warpers:
             scores = warper(input_ids, scores)

         next_ids = self.choice(scores)
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        return next_ids, next_logprobs, logprobs
+        if speculative_scores is not None:
+            # TODO: run the speculative scores through the same processors and
+            # warpers as the regular scores; for now they are decoded greedily.
+            # length, spec_length, vocab_size = speculative_scores.shape
+            # speculative_scores = speculative_scores.view((-1, vocab_size))
+            # if self.watermark_processor is not None:
+            #     speculative_scores = self.watermark_processor(input_ids, speculative_scores)
+            # if self.repetition_processor is not None:
+            #     speculative_scores = self.repetition_processor(input_ids, speculative_scores)
+            # speculative_scores = speculative_scores.view((length, spec_length, vocab_size))
+            # for warper in self.warpers:
+            #     speculative_scores = warper(input_ids, speculative_scores)
+            speculative_ids = Greedy()(speculative_scores)
+        else:
+            speculative_ids = None
+
+        return next_ids, next_logprobs, logprobs, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None:
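
End to end, the new four-value return looks like this sketch (shapes are made up; Greedy in this codebase argmaxes over the vocab dimension, which is what the inline computation below mimics):

    import torch

    batch, vocab, spec = 2, 100, 3
    scores = torch.randn(batch, vocab)
    speculative_scores = torch.randn(batch, spec, vocab)

    next_ids = scores.argmax(dim=-1)                     # (2,)
    logprobs = torch.log_softmax(scores, -1)
    next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
    speculative_ids = speculative_scores.argmax(dim=-1)  # (2, 3): one id per Medusa head

    # mirrors: next_ids, next_logprobs, logprobs, speculative_ids = chooser(...)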