Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)

Commit a2e9ccbb10 (parent 94a0bf1bbc)
Commit message: Tmp.
@@ -46,6 +46,7 @@ class FlashCausalLMBatch(Batch):
 
     # tensor of length b containing the cumulative sequence lengths of the sequences in the batch, only used in prefill
     cu_seqlen_prefill: Optional[torch.Tensor]
+    cu_seqlen_speculative: Optional[torch.Tensor]
 
     # Paged Attention values
 
@@ -121,6 +122,7 @@ class FlashCausalLMBatch(Batch):
         position_ids = []
         cu_seqlen_prefill = [0]
+        cu_seqlen_speculative = [0]
         needed_blocks_slots = []
         start_slots = []
         slot_indices = []
@@ -160,9 +162,17 @@ class FlashCausalLMBatch(Batch):
                 tokenized_input = tokenized_input[-r.truncate :]
 
+            # TODO remove this
+            # Scaffolding to speculate some ids
+            speculate_ids = [1, 2]
+            tokenized_input.extend([1, 2])
+
             input_length = len(tokenized_input)
             input_lengths.append(input_length)
 
             prefix_offsets.append(input_length - 5)
             read_offsets.append(input_length)
 
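
Note (not part of the diff): the scaffolding above hard-codes two speculated token ids per request so the rest of the batch plumbing can be exercised before real Medusa predictions exist. A toy sketch of its effect on one tokenized input, using an arbitrary made-up prompt:

    # Hypothetical tokenized prompt; the ids are arbitrary.
    tokenized_input = [101, 2023, 2003]

    # Same move as the scaffolding: append two dummy speculated ids.
    speculate_ids = [1, 2]
    tokenized_input.extend(speculate_ids)

    input_length = len(tokenized_input)
    print(input_length)  # 5: three real tokens plus the two speculated placeholders
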
@@ -174,6 +184,7 @@ class FlashCausalLMBatch(Batch):
 
             # Add cumulative lengths of all previous inputs
             cu_seqlen_prefill.append(cumulative_length + input_length)
+            cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
 
             next_token_chooser_parameters.append(r.parameters)
 
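
Note (not part of the diff): a minimal sketch of how the two cumulative-length lists relate, assuming every request was extended with speculate_ids = [1, 2] as in the scaffolding above. cu_seqlen_prefill ends each segment after the speculated ids, while cu_seqlen_speculative ends it just before them.

    # Toy batch: two requests with 3 and 5 real tokens, each extended with 2 speculated ids.
    speculate_ids = [1, 2]
    real_lengths = [3, 5]

    cu_seqlen_prefill = [0]
    cu_seqlen_speculative = [0]
    cumulative_length = 0
    for real_length in real_lengths:
        input_length = real_length + len(speculate_ids)
        cu_seqlen_prefill.append(cumulative_length + input_length)
        cu_seqlen_speculative.append(cumulative_length + input_length - len(speculate_ids))
        cumulative_length += input_length

    print(cu_seqlen_prefill)      # [0, 5, 12]
    print(cu_seqlen_speculative)  # [0, 3, 10]
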
@@ -255,6 +266,9 @@ class FlashCausalLMBatch(Batch):
         cu_seqlen_prefill = torch.tensor(
             cu_seqlen_prefill, device=device, dtype=torch.int32
         )
+        cu_seqlen_speculative = torch.tensor(
+            cu_seqlen_speculative, device=device, dtype=torch.int32
+        )
 
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
@@ -287,6 +301,7 @@ class FlashCausalLMBatch(Batch):
             input_ids=input_ids,
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
+            cu_seqlen_speculative=cu_seqlen_speculative,
             start_slots=start_slots,
             slot_indices=slot_indices,
             needed_blocks_slots=needed_blocks_slots,
@@ -752,15 +767,29 @@ class FlashCausalLM(Model):
                 del batch
                 raise e
 
+        try:
+            out, speculative_logits = out.logits, out.speculative_logits
+        except Exception:
+            out = out
+            speculative_logits = None
+
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
             )
+            if speculative_logits is not None:
+                speculative_logits = (
+                    speculative_logits[batch.prefill_next_token_indices] if prefill_logprobs else speculative_logits
+                )
         else:
             next_token_logits = out
 
-        next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
-            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
+        import ipdb;ipdb.set_trace()
+        next_input_ids, next_token_logprobs, logprobs, speculative_ids = batch.next_token_chooser(
+            batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, speculative_logits
         )
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
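
Note (not part of the diff): the try/except above duck-types the model output, which is either a plain logits tensor (models without Medusa) or the Output dataclass added further down in this commit. A minimal stand-alone sketch of that dispatch with made-up tensor shapes; the hunk catches a bare Exception, while the sketch narrows it to AttributeError:

    import torch
    from dataclasses import dataclass

    @dataclass
    class Output:
        logits: torch.FloatTensor = None
        speculative_logits: torch.FloatTensor = None

    def unpack(out):
        # Fall back to (out, None) when the output has no .logits attribute.
        try:
            return out.logits, out.speculative_logits
        except AttributeError:
            return out, None

    plain = torch.randn(4, 32000)                                      # vanilla lm_head output
    medusa = Output(torch.randn(4, 32000), torch.randn(4, 3, 32000))   # wrapper output
    print(unpack(plain)[1] is None)   # True
    print(unpack(medusa)[1].shape)    # torch.Size([4, 3, 32000])
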
@@ -773,12 +802,20 @@ class FlashCausalLM(Model):
             # When batch == 1, we will just use the batch.input_ids values directly
             prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
 
-            next_position_ids = batch.position_ids.new_empty(len(batch))
+            if speculative_ids is not None:
+                # TODO
+                # length = len(batch) * speculative_ids.shape[1]
+                length = len(batch)
+            else:
+                length = len(batch)
+            # import ipdb;ipdb.set_trace()
+            next_position_ids = batch.position_ids.new_empty(length)
             batch.slot_indices = batch.slot_indices[batch.cu_seqlen_prefill[1:] - 1]
             # We do not need cu_seqlen_prefill anymore
             batch.cu_seqlen_prefill = None
         else:
             prefill_logprobs = None
+            # import ipdb;ipdb.set_trace()
             next_position_ids = batch.position_ids
 
         # Cumulative length
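
Note (not part of the diff): the commented-out sizing hints that, once speculation reaches decode, each request would need one position per speculated token rather than one per request. A toy sketch of the two sizings, with made-up batch and head counts:

    import torch

    batch_size = 4
    speculative_ids = torch.zeros(batch_size, 3, dtype=torch.long)  # e.g. 3 speculated tokens per request

    length_now = batch_size                                # what the hunk uses: len(batch)
    length_todo = batch_size * speculative_ids.shape[1]    # the commented-out TODO sizing

    print(torch.empty(length_now).shape, torch.empty(length_todo).shape)  # torch.Size([4]) torch.Size([12])
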
@@ -77,7 +77,8 @@ class FlashLlama(FlashCausalLM):
             medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
             medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
             weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
-            model.lm_head = MedusaModel(config, weights)
+            lm_head = model.lm_head
+            model.lm_head = MedusaModel(config, weights, lm_head)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
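
Note (not part of the diff): the change keeps the model's original language-model head and hands it to MedusaModel, so the wrapper can emit both regular and speculative logits. A minimal sketch of that wrap with hypothetical stand-in modules (not the real TGI classes):

    import torch

    class TinyModel(torch.nn.Module):
        def __init__(self, hidden=16, vocab=100):
            super().__init__()
            self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)

    class Wrapper(torch.nn.Module):
        # Stand-in for MedusaModel: keep the original head, add extra ones.
        def __init__(self, lm_head, num_heads=3):
            super().__init__()
            self.lm_head = lm_head
            self.heads = torch.nn.ModuleList(
                [torch.nn.Linear(lm_head.in_features, lm_head.out_features, bias=False) for _ in range(num_heads)]
            )

    model = TinyModel()
    lm_head = model.lm_head           # save the original head first ...
    model.lm_head = Wrapper(lm_head)  # ... then swap in the wrapper, as the hunk does
    print(model.lm_head.lm_head is lm_head)  # True
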
@@ -1,6 +1,12 @@
 import torch
+from dataclasses import dataclass
 from text_generation_server.utils.layers import TensorParallelHead, FastLinear
 
 
+@dataclass
+class Output:
+    logits: torch.FloatTensor = None
+    speculative_logits: torch.FloatTensor = None
+
+
 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
@@ -16,12 +22,19 @@ class MedusaModel(torch.nn.Module):
     def __init__(
         self,
         config,
-        weights
+        weights,
+        lm_head
     ):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
         )
+        self.lm_head = lm_head
+
+    def forward(self, x):
+        logits = self.lm_head(x)
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return Output(logits=logits, speculative_logits=speculative_logits)
 
 
 class MedusaHead(torch.nn.Module):
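
Note (not part of the diff): with the forward added here, one call produces both the regular logits and one extra set of logits per Medusa head. A self-contained sketch of the resulting shapes, using plain nn.Linear heads as stand-ins (the real MedusaHead stacks ResBlocks):

    import torch

    hidden, vocab, num_heads, batch = 16, 100, 3, 4
    lm_head = torch.nn.Linear(hidden, vocab, bias=False)
    heads = [torch.nn.Linear(hidden, vocab, bias=False) for _ in range(num_heads)]

    x = torch.randn(batch, hidden)
    logits = lm_head(x)                                                    # [batch, vocab]
    speculative_logits = torch.stack([head(x) for head in heads], dim=1)   # [batch, num_heads, vocab]
    print(logits.shape, speculative_logits.shape)  # torch.Size([4, 100]) torch.Size([4, 3, 100])
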
@@ -215,7 +215,7 @@ class HeterogeneousNextTokenChooser:
         self.dtype = dtype
         self.device = device
 
-    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, speculative_scores: Optional[torch.Tensor] = None):
         if self.watermark_processor is not None:
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
@@ -224,11 +224,27 @@ class HeterogeneousNextTokenChooser:
         for warper in self.warpers:
             scores = warper(input_ids, scores)
 
 
         next_ids = self.choice(scores)
         logprobs = torch.log_softmax(scores, -1)
         next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)
 
-        return next_ids, next_logprobs, logprobs
+        if speculative_scores is not None:
+            # length, spec_length, vocab_size = speculative_scores.shape
+            # speculative_scores = speculative_scores.view((-1, vocab_size))
+            # if self.watermark_processor is not None:
+            #     speculative_scores = self.watermark_processor(input_ids, speculative_scores)
+            # if self.repetition_processor is not None:
+            #     speculative_scores = self.repetition_processor(input_ids, speculative_scores)
+
+            # speculative_scores = speculative_scores.view((length, spec_length, vocab_size))
+            # for warper in self.warpers:
+            #     speculative_scores = warper(input_ids, speculative_scores)
+            speculative_ids = Greedy()(speculative_scores)
+        else:
+            speculative_ids = None
+
+        return next_ids, next_logprobs, logprobs, speculative_ids
 
     def filter(self, indices):
         if self.watermark_processor is not None:
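
Note (not part of the diff): for now the speculative scores bypass the watermark/repetition processors and warpers (the commented-out block) and are picked greedily per head. A minimal sketch of that pick, assuming Greedy() amounts to an argmax over the vocabulary dimension:

    import torch

    batch, num_heads, vocab = 4, 3, 100
    speculative_scores = torch.randn(batch, num_heads, vocab)

    # Greedy choice per (request, head): argmax over the last (vocab) dimension.
    speculative_ids = speculative_scores.argmax(dim=-1)
    print(speculative_ids.shape)  # torch.Size([4, 3])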