diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index f1a4854f9..b1474e0dd 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -46,6 +46,7 @@ class FlashCausalLMBatch(Batch):
     # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
+    cu_seqlen_speculative: Optional[torch.Tensor]

     # Paged Attention values

@@ -121,6 +122,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = []
         cu_seqlen_prefill = [0]
+        cu_seqlen_speculative = [0]
         needed_blocks_slots = []
         start_slots = []
         slot_indices = []
@@ -160,9 +162,17 @@ class FlashCausalLMBatch(Batch):
                 tokenized_input = tokenized_input[-r.truncate :]

+            # TODO remove this
+            # Scaffolding to speculate some ids
+            speculate_ids = [1, 2]
+            tokenized_input.extend([1, 2])
+
+
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
+
+
             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)

@@ -174,6 +184,7 @@ class FlashCausalLMBatch(Batch):

             # Add cumulative lengths of all previous inputs
             cu_seqlen_prefill.append(cumulative_length + input_length)
+            cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))

             next_token_chooser_parameters.append(r.parameters)

@@ -255,6 +266,9 @@ class FlashCausalLMBatch(Batch):
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
+        cu_seqlen_speculative = torch.tensor(
+            cu_seqlen_speculative, device=device, dtype=torch.int32
+        )

         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
@@ -287,6 +301,7 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            cu_seqlen_speculative=cu_seqlen_speculative,
             start_slots=start_slots,
             slot_indices=slot_indices,
             needed_blocks_slots=needed_blocks_slots,
@@ -752,15 +767,29 @@ class FlashCausalLM(Model):
             del batch
             raise e

+        try:
+            out, speculative_logits = out.logits, out.speculative_logits
+        except Exception:
+            out = out
+            speculative_logits = None
+
+
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
+            if speculative_logits is not None:
+                speculative_logits = (
+                    speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
+                )
         else:
             next_token_logits = out

-        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
+
+
+        import ipdb;ipdb.set_trace()
+        next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, speculative_logits
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
@@ -773,12 +802,20 @@ class FlashCausalLM(Model):
             # When batch == 1, we will just use the batch.input_ids values directly
             prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            next_position_ids = batch.position_ids.new_empty(len(batch))
+            if speculative_ids is not None:
+                # TODO
+                # length = len(batch) * speculative_ids.shape[1]
+                length = len(batch)
+            else:
+                length = len(batch)
+            # import ipdb;ipdb.set_trace()
+            next_position_ids = batch.position_ids.new_empty(length)
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
         else:
             prefill_logprobs = None
+            # import ipdb;ipdb.set_trace()
             next_position_ids = batch.position_ids

         # Cumulative length
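The batch changes above keep a second set of cumulative lengths so the prefill boundaries can still be recovered without the scaffolded speculative ids. As a rough illustration of the arithmetic in the `cu_seqlen_prefill` / `cu_seqlen_speculative` hunk (a standalone sketch, not part of the diff; the tokenized lengths 3 and 5 are made-up example values, while `speculate_ids = [1, 2]` mirrors the hard-coded scaffolding):

```python
# Sketch of the cumulative-length bookkeeping in FlashCausalLMBatch construction.
speculate_ids = [1, 2]        # hard-coded scaffolding, as in the diff
tokenized_lengths = [3, 5]    # two example requests after tokenization (illustrative)

cu_seqlen_prefill = [0]
cu_seqlen_speculative = [0]
cumulative_length = 0
for length in tokenized_lengths:
    input_length = length + len(speculate_ids)   # speculative ids are appended to the prompt
    cu_seqlen_prefill.append(cumulative_length + input_length)
    # the speculative boundary excludes the ids appended to the *current* request
    cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
    cumulative_length += input_length

print(cu_seqlen_prefill)      # [0, 5, 12]
print(cu_seqlen_speculative)  # [0, 3, 10]
```

Each request's speculative boundary therefore sits `len(speculate_ids)` tokens before its prefill boundary. The remaining files wire the speculative logits through the model and the token chooser: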
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 42a82a1f2..3a84b1b6a 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -77,7 +77,8 @@ class FlashLlama(FlashCausalLM):
             medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
             medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
             weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
-            model.lm_head = MedusaModel(config, weights)
+            lm_head = model.lm_head
+            model.lm_head = MedusaModel(config, weights, lm_head)

         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py
index a62300165..ce9083333 100644
--- a/server/text_generation_server/utils/medusa.py
+++ b/server/text_generation_server/utils/medusa.py
@@ -1,6 +1,12 @@
 import torch
+from dataclasses import dataclass
 from text_generation_server.utils.layers import TensorParallelHead, FastLinear

+@dataclass
+class Output:
+    logits: torch.FloatTensor = None
+    speculative_logits: torch.FloatTensor = None
+

 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
@@ -16,12 +22,19 @@ class MedusaModel(torch.nn.Module):
     def __init__(
         self,
         config,
-        weights
+        weights,
+        lm_head
     ):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
         )
+        self.lm_head = lm_head
+
+    def forward(self, x):
+        logits = self.lm_head(x)
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return Output(logits=logits, speculative_logits=speculative_logits)


 class MedusaHead(torch.nn.Module):
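The `MedusaModel.forward` added above returns both the regular logits from the wrapped `lm_head` and a stacked tensor of per-head speculative logits. A minimal sketch of the resulting shapes, using plain `torch.nn.Linear` layers as stand-ins for the real `TensorParallelHead` / `MedusaHead` modules (all sizes below are made-up example values, not taken from the diff):

```python
import torch

# Stand-in sizes; the real values come from the model config and the Medusa weights.
num_tokens, hidden_size, vocab_size, n_medusa_heads = 4, 8, 32, 3

x = torch.randn(num_tokens, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)   # stand-in for model.lm_head
heads = [torch.nn.Linear(hidden_size, vocab_size, bias=False)    # stand-ins for MedusaHead
         for _ in range(n_medusa_heads)]

logits = lm_head(x)                                                   # (num_tokens, vocab_size)
speculative_logits = torch.stack([head(x) for head in heads], dim=1)  # (num_tokens, n_heads, vocab_size)

print(logits.shape)              # torch.Size([4, 32])
print(speculative_logits.shape)  # torch.Size([4, 3, 32])
```

Wrapping the original `lm_head` (saved in `flash_llama.py` before it is replaced) is what keeps `Output.logits` identical to the base model's logits, with the Medusa heads contributing only `speculative_logits`. The token chooser then has to accept those extra scores: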
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 0ff074171..eee4f660d 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -215,7 +215,7 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculative_scores: Optional[torch.Tensor] = None):
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
@@ -224,11 +224,27 @@ class HeterogeneousNextTokenChooser:
         for warper in self.warpers:
             scores = warper(input_ids, scores)

+
         next_ids = self.choice(scores)
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        return next_ids, next_logprobs, logprobs
+        if speculative_scores is not None:
+            # length, spec_length, vocab_size = speculative_scores.shape
+            # speculative_scores = speculative_scores.view((-1, vocab_size))
+            # if self.watermark_processor is not None:
+            #     speculative_scores = self.watermark_processor(input_ids, speculative_scores)
+            # if self.repetition_processor is not None:
+            #     speculative_scores = self.repetition_processor(input_ids, speculative_scores)
+
+            # speculative_scores = speculative_scores.view((length, spec_length, vocab_size))
+            # for warper in self.warpers:
+            #     speculative_scores = warper(input_ids, speculative_scores)
+            speculative_ids = Greedy()(speculative_scores)
+        else:
+            speculative_ids = None
+
+        return next_ids, next_logprobs, logprobs, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None:
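At the chooser level the speculative scores are currently resolved with a plain greedy pick, while the watermark/repetition/warper path stays commented out. Assuming `Greedy` is the argmax-over-vocabulary chooser defined elsewhere in `tokens.py` (re-declared here only to keep the sketch self-contained; the sizes are example values), the shapes work out as follows:

```python
import torch

class Greedy:
    # Assumed behaviour: pick the highest-scoring token id along the vocab dimension.
    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1)

batch_size, spec_length, vocab_size = 2, 3, 32   # made-up example sizes
speculative_scores = torch.randn(batch_size, spec_length, vocab_size)

speculative_ids = Greedy()(speculative_scores)
print(speculative_ids.shape)  # torch.Size([2, 3]): one speculated id per head, per sequence
```

Because the argmax is taken per element along the last dimension, it handles the extra `spec_length` axis directly, which is presumably why the processors and warpers (written against 2-D score tensors) are left commented out for now.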