Fixing medusa (still wrong outputs, but functional).

Nicolas Patry 2024-08-13 11:58:55 +02:00
parent b31ec3bc8c
commit 549f0e9ca7
5 changed files with 47 additions and 15 deletions

@@ -81,6 +81,7 @@
grpcio-status
grpcio-tools
hf-transfer
ipdb
loguru
mamba-ssm
marlin-kernels

router.nix (new file, 20 lines)

@@ -0,0 +1,20 @@
{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }:
buildRustPackage {
name = "text-generation-router";
src = ./.;
sourceDir = ./backends/v3;
cargoLock = {
lockFile = ./Cargo.lock;
};
nativeBuildInputs = [ pkg-config ];
buildInputs = [ openssl.dev protobuf ];
doCheck = false;
}

@@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module):
)
def forward(self, x):
if not self.heads:
return None
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return speculative_logits
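
For context on the hunk above, a minimal self-contained sketch of the guard it introduces; TinyMedusaHeads is a hypothetical stand-in, not the actual TGI class. Without the early return, a model configured with zero speculative heads would call torch.stack on an empty list and raise a RuntimeError.

import torch

class TinyMedusaHeads(torch.nn.Module):
    # Hypothetical stand-in for MedusaModel: one linear head per speculative position.
    def __init__(self, hidden_size: int, n_heads: int, vocab_size: int):
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_size, vocab_size) for _ in range(n_heads)]
        )

    def forward(self, x: torch.Tensor):
        if not self.heads:
            # No speculation configured: return None instead of stacking an empty list.
            return None
        # (batch, n_heads, vocab_size): one set of speculative logits per head.
        return torch.stack([head(x) for head in self.heads], dim=1)

print(TinyMedusaHeads(16, 0, 100)(torch.randn(2, 16)))  # None rather than a RuntimeError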

@@ -552,6 +552,7 @@ class FlashCausalLMBatch(Batch):
prefix_ids = []
input_lengths = []
prefix_lens = []
prefix_offsets = []
read_offsets = []
@@ -573,12 +574,14 @@
# Get length
request_input_length = self.input_lengths[idx]
prefix_len = self.prefix_lens[idx]
max_seqlen = max(max_seqlen, request_input_length)
all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])
input_lengths.append(request_input_length)
prefix_lens.append(prefix_len)
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
@@ -623,6 +626,7 @@
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
prefix_lens_tensor = self.prefix_lens_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
speculative_ids = (
@@ -658,6 +662,8 @@
prefill_cu_outlens=None,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_lens=prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
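
A standalone illustration (made-up values, not TGI code) of the filtering pattern the hunks above extend: when a batch is narrowed to the kept indices, every per-request field must be sliced the same way, Python lists by iteration and tensors by indexing, and the new prefix_lens / prefix_lens_tensor fields now follow that rule.

import torch

indices = [0, 2]                                 # requests kept after filtering
prefix_lens = [4, 0, 7]                          # per-request prefix lengths (Python list)
prefix_lens_tensor = torch.tensor(prefix_lens)   # same data as a tensor

filtered_prefix_lens = [prefix_lens[i] for i in indices]   # [4, 7]
filtered_prefix_lens_tensor = prefix_lens_tensor[indices]  # tensor([4, 7])
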
@@ -1033,12 +1039,11 @@ class FlashCausalLM(Model):
device=device
)
if not CUDA_GRAPHS:
self.decode_state = create_decode_state(
device=device,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
self.decode_state = create_decode_state(
device=device,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
super().__init__(
model_id=model_id,
@@ -1404,6 +1409,10 @@ class FlashCausalLM(Model):
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
prefix_lens_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+ arange_int
).view(-1)
# Add Copy the block tables for all members
block_tables = (
@@ -1424,6 +1433,7 @@ class FlashCausalLM(Model):
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor
max_s = batch.max_seqlen
lm_head_indices = batch.prefill_head_indices
@@ -1448,7 +1458,7 @@ class FlashCausalLM(Model):
input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths,
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=batch.prefix_lens_tensor,
prefix_lens_tensor=prefix_lens_tensor,
):
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
@@ -1487,18 +1497,17 @@ class FlashCausalLM(Model):
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + batch.prefix_lens_tensor
input_lengths + prefix_lens_tensor
)
state = cuda_graph.get("state")
with self._forward_context(
block_tables=block_tables,
block_tables=cuda_graph["block_tables"],
cu_seqlen_prefill=None,
input_lengths=batch.input_lengths,
input_lengths_tensor=input_lengths,
input_lengths_tensor=cuda_graph["input_lengths"],
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=batch.prefix_lens_tensor,
state=state,
prefix_lens_tensor=prefix_lens_tensor,
state=cuda_graph["state"],
):
# Replay the graph
cuda_graph["graph"].replay()
@@ -1929,7 +1938,7 @@ class FlashCausalLM(Model):
assert input_lengths_tensor is not None
return use_decode_state(
state=state if state is not None else self.decode_state,
input_lengths=input_lengths_tensor + prefix_lens_tensor,
input_lengths=input_lengths_tensor,
block_tables=block_tables,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
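
A standalone sketch (made-up sizes, not TGI code) of the unsqueeze/expand/view pattern the speculative-decoding hunk above applies: each request's length is repeated new_length times and offset by 0..new_length-1 so that every speculative position gets its own effective sequence length. The commit applies the same expansion to prefix_lens_tensor, which is then passed to _forward_context in place of the unexpanded batch tensor.

import torch

B, new_length = 3, 4                              # 3 requests, 4 speculative positions each
input_lengths = torch.tensor([7, 12, 5], dtype=torch.int32)
arange_int = torch.arange(new_length, dtype=torch.int32)

expanded = (
    input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
print(expanded)  # tensor([ 7,  8,  9, 10, 12, 13, 14, 15,  5,  6,  7,  8], dtype=torch.int32)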

@@ -5,7 +5,7 @@ from typing import Dict, Optional
from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
_expected = {"paged", "flashdecoding", "flashinfer"}
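
The last hunk fixes a boolean-parsing pitfall; a standalone reproduction (not TGI code) is below. os.getenv returns a string whenever the variable is set, so the old default of False only applied when the variable was missing, and any set value, including "0" or "false", was truthy.

import os

os.environ["USE_PREFIX_CACHING"] = "0"

old_value = os.getenv("USE_PREFIX_CACHING", False)
print(bool(old_value))  # True: "0" is a non-empty string, so the flag looked enabled

new_value = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
print(new_value)        # False: explicit parsing yields the intended boolean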