Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 06:42:10 +00:00)
Commit a2e9ccbb10 (parent 94a0bf1bbc)
Tmp.
@@ -46,6 +46,7 @@ class FlashCausalLMBatch(Batch):
     # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
+    cu_seqlen_speculative: Optional[torch.Tensor]

     # Paged Attention values
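For orientation, a hedged sketch (not part of the commit) of what these cumulative-length tensors hold for a hypothetical batch of two prompts of token lengths 3 and 5, each extended with 2 speculative ids as in the scaffolding further down:

import torch

# Hypothetical batch: prompts of 3 and 5 tokens, each extended with 2 speculative ids.
# Entries i and i + 1 delimit sequence i in the flattened prefill batch.
cu_seqlen_prefill = torch.tensor([0, 5, 12], dtype=torch.int32)      # includes the speculative ids
cu_seqlen_speculative = torch.tensor([0, 3, 10], dtype=torch.int32)  # stops before the speculative ids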
@@ -121,6 +122,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = []
         cu_seqlen_prefill = [0]
+        cu_seqlen_speculative = [0]
         needed_blocks_slots = []
         start_slots = []
         slot_indices = []
@@ -160,9 +162,17 @@ class FlashCausalLMBatch(Batch):
             tokenized_input = tokenized_input[-r.truncate :]

+            # TODO remove this
+            # Scaffolding to speculate some ids
+            speculate_ids = [1, 2]
+            tokenized_input.extend([1, 2])
+
             input_length = len(tokenized_input)
             input_lengths.append(input_length)

             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)
@@ -174,6 +184,7 @@ class FlashCausalLMBatch(Batch):
             # Add cumulative lengths of all previous inputs
             cu_seqlen_prefill.append(cumulative_length + input_length)
+            cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))

             next_token_chooser_parameters.append(r.parameters)
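A toy rerun of this bookkeeping with assumed prompt lengths (illustration only, outside the batch code) makes the subtraction concrete:

speculate_ids = [1, 2]                       # hard-coded scaffolding, as above
raw_prompt_lengths = [3, 5]                  # hypothetical tokenized prompt lengths

cu_seqlen_prefill = [0]
cu_seqlen_speculative = [0]
cumulative_length = 0
for raw_length in raw_prompt_lengths:
    input_length = raw_length + len(speculate_ids)   # prompt plus appended speculative ids
    cu_seqlen_prefill.append(cumulative_length + input_length)
    cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
    cumulative_length += input_length

print(cu_seqlen_prefill)      # [0, 5, 12]
print(cu_seqlen_speculative)  # [0, 3, 10]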
@@ -255,6 +266,9 @@ class FlashCausalLMBatch(Batch):
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
+        cu_seqlen_speculative = torch.tensor(
+            cu_seqlen_speculative, device=device, dtype=torch.int32
+        )

         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
@@ -287,6 +301,7 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            cu_seqlen_speculative=cu_seqlen_speculative,
             start_slots=start_slots,
             slot_indices=slot_indices,
             needed_blocks_slots=needed_blocks_slots,
@@ -752,15 +767,29 @@ class FlashCausalLM(Model):
                 del batch
                 raise e

+        try:
+            out, speculative_logits = out.logits, out.speculative_logits
+        except Exception:
+            out = out
+            speculative_logits = None
+
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
+            if speculative_logits is not None:
+                speculative_logits = (
+                    speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
+                )
         else:
             next_token_logits = out

-        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
+        import ipdb;ipdb.set_trace()
+        next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, speculative_logits
         )

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
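The try/except above is a quick way to accept both a plain logits tensor and the new Output wrapper. An equivalent, more explicit sketch of the same unpacking (illustrative helper name, not the commit's code):

import torch

def split_output(out):
    # Accept either a raw logits tensor or an object carrying
    # .logits / .speculative_logits, as the Medusa head now returns.
    logits = getattr(out, "logits", out)
    speculative_logits = getattr(out, "speculative_logits", None)
    return logits, speculative_logits

logits, spec = split_output(torch.zeros(4, 32000))   # plain tensor: spec is None
assert spec is None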
@@ -773,12 +802,20 @@ class FlashCausalLM(Model):
                 # When batch == 1, we will just use the batch.input_ids values directly
                 prefill_tokens_indices = batch.input_ids.new_zeros(len(out))

-            next_position_ids = batch.position_ids.new_empty(len(batch))
+            if speculative_ids is not None:
+                # TODO
+                # length = len(batch) * speculative_ids.shape[1]
+                length = len(batch)
+            else:
+                length = len(batch)
+            # import ipdb;ipdb.set_trace()
+            next_position_ids = batch.position_ids.new_empty(length)
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
         else:
             prefill_logprobs = None
+            # import ipdb;ipdb.set_trace()
             next_position_ids = batch.position_ids

         # Cumulative length
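The commented-out length hints at where this is heading: once the decode step carries the speculated tokens, each request would contribute one position per speculative slot rather than one. A hedged illustration of that sizing (assumption read from the TODO, not behaviour in this commit):

import torch

batch_size, spec_length = 2, 2
speculative_ids = torch.zeros(batch_size, spec_length, dtype=torch.long)  # dummy ids

length_now = batch_size                                   # what the commit allocates
length_later = batch_size * speculative_ids.shape[1]      # what the TODO sketches
print(length_now, length_later)                           # 2 4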
@@ -77,7 +77,8 @@ class FlashLlama(FlashCausalLM):
             medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
             medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
             weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
-            model.lm_head = MedusaModel(config, weights)
+            lm_head = model.lm_head
+            model.lm_head = MedusaModel(config, weights, lm_head)

         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
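The point of this change is to keep the model's original lm_head and hand it to the Medusa wrapper instead of discarding it. A minimal stand-alone sketch of that swap pattern with stand-in modules (not the TGI classes):

import torch

class SpeculativeHeadWrapper(torch.nn.Module):
    # Stand-in for MedusaModel: keeps the original head and delegates to it.
    def __init__(self, lm_head):
        super().__init__()
        self.lm_head = lm_head

    def forward(self, x):
        return self.lm_head(x)

model = torch.nn.Module()
model.lm_head = torch.nn.Linear(16, 100, bias=False)   # original head

lm_head = model.lm_head                           # keep a reference to the original head
model.lm_head = SpeculativeHeadWrapper(lm_head)   # swap in the wrapper that reuses it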
@@ -1,6 +1,12 @@
 import torch
 from dataclasses import dataclass
 from text_generation_server.utils.layers import TensorParallelHead, FastLinear

+
+@dataclass
+class Output:
+    logits: torch.FloatTensor = None
+    speculative_logits: torch.FloatTensor = None
+

 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
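The Output container simply pairs the regular next-token logits with the stacked Medusa logits; a quick usage sketch with dummy shapes:

import torch
from dataclasses import dataclass

@dataclass
class Output:
    logits: torch.FloatTensor = None
    speculative_logits: torch.FloatTensor = None

out = Output(
    logits=torch.zeros(3, 32000),                 # [num_tokens, vocab_size]
    speculative_logits=torch.zeros(3, 2, 32000),  # [num_tokens, medusa_num_heads, vocab_size]
)
print(out.logits.shape, out.speculative_logits.shape)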
@@ -16,12 +22,19 @@ class MedusaModel(torch.nn.Module):
     def __init__(
         self,
         config,
-        weights
+        weights,
+        lm_head
     ):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
         )
+        self.lm_head = lm_head

     def forward(self, x):
         logits = self.lm_head(x)
         speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
         return Output(logits=logits, speculative_logits=speculative_logits)


 class MedusaHead(torch.nn.Module):
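Shape-wise, with hidden states x of shape [num_tokens, hidden_size] and medusa_num_heads extra heads, the forward above yields per-token logits [num_tokens, vocab_size] next to stacked speculative logits [num_tokens, medusa_num_heads, vocab_size]. A toy check under those assumed dimensions:

import torch

num_tokens, hidden_size, vocab_size, n_heads = 3, 16, 100, 4

lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
heads = [torch.nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(n_heads)]

x = torch.randn(num_tokens, hidden_size)
logits = lm_head(x)
speculative_logits = torch.stack([head(x) for head in heads], dim=1)

print(logits.shape)              # torch.Size([3, 100])
print(speculative_logits.shape)  # torch.Size([3, 4, 100])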
@@ -215,7 +215,7 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device

-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculative_scores: Optional[torch.Tensor] = None):
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
@@ -224,11 +224,27 @@ class HeterogeneousNextTokenChooser:
         for warper in self.warpers:
             scores = warper(input_ids, scores)

         next_ids = self.choice(scores)
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

-        return next_ids, next_logprobs, logprobs
+        if speculative_scores is not None:
+            # length, spec_length, vocab_size = speculative_scores.shape
+            # speculative_scores = speculative_scores.view((-1, vocab_size))
+            # if self.watermark_processor is not None:
+            #     speculative_scores = self.watermark_processor(input_ids, speculative_scores)
+            # if self.repetition_processor is not None:
+            #     speculative_scores = self.repetition_processor(input_ids, speculative_scores)
+
+            # speculative_scores = speculative_scores.view((length, spec_length, vocab_size))
+            # for warper in self.warpers:
+            #     speculative_scores = warper(input_ids, speculative_scores)
+            speculative_ids = Greedy()(speculative_scores)
+        else:
+            speculative_ids = None
+
+        return next_ids, next_logprobs, logprobs, speculative_ids

     def filter(self, indices):
         if self.watermark_processor is not None:
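For now the speculative path skips the processors and warpers (they are left commented out) and picks ids greedily, which for the Greedy choice is effectively an argmax over the last (vocab) dimension. A sketch on a dummy tensor with an assumed layout:

import torch

# Dummy speculative scores: [num_tokens, spec_length, vocab_size] (assumed layout)
speculative_scores = torch.randn(2, 3, 100)

# Greedy selection: one id per (token, speculative position).
speculative_ids = speculative_scores.argmax(dim=-1)
print(speculative_ids.shape)   # torch.Size([2, 3])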