mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
Fixing medusa (still wrong outputs, but functional).
This commit is contained in:
  parent b31ec3bc8c
  commit 549f0e9ca7
@@ -81,6 +81,7 @@
 grpcio-status
 grpcio-tools
 hf-transfer
+ipdb
 loguru
 mamba-ssm
 marlin-kernels
router.nix (new file, 20 lines)
@@ -0,0 +1,20 @@
+{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }:
+
+buildRustPackage {
+  name = "text-generation-router";
+
+  src = ./.;
+
+  sourceDir = ./backends/v3;
+
+  cargoLock = {
+    lockFile = ./Cargo.lock;
+  };
+
+  nativeBuildInputs = [ pkg-config ];
+
+  buildInputs = [ openssl.dev protobuf ];
+
+  doCheck = false;
+
+}
@@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module):
         )
 
     def forward(self, x):
+        if not self.heads:
+            return None
         speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
         return speculative_logits
 
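
For context, a minimal, self-contained sketch (not the repository's class; TinyMedusa and the sizes below are invented) of what the guarded forward above does: with zero Medusa heads it returns None so the caller can fall back to plain decoding, otherwise each head produces one logits tensor and they are stacked along dim=1.

import torch
import torch.nn as nn

class TinyMedusa(nn.Module):
    def __init__(self, hidden_size, vocab_size, n_heads):
        super().__init__()
        # One linear "Medusa head" per speculative position.
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_size, vocab_size) for _ in range(n_heads)]
        )

    def forward(self, x):
        # No heads configured: nothing to speculate, mirror the diff's early return.
        if not self.heads:
            return None
        # Each head maps [tokens, hidden] -> [tokens, vocab];
        # stacking on dim=1 gives [tokens, n_heads, vocab].
        return torch.stack([head(x) for head in self.heads], dim=1)

x = torch.randn(4, 16)                  # 4 token positions, hidden size 16
print(TinyMedusa(16, 32, 3)(x).shape)   # torch.Size([4, 3, 32])
print(TinyMedusa(16, 32, 0)(x))         # None
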
@@ -552,6 +552,7 @@ class FlashCausalLMBatch(Batch):
         prefix_ids = []
 
         input_lengths = []
+        prefix_lens = []
         prefix_offsets = []
         read_offsets = []
 
@@ -573,12 +574,14 @@ class FlashCausalLMBatch(Batch):
 
             # Get length
             request_input_length = self.input_lengths[idx]
+            prefix_len = self.prefix_lens[idx]
             max_seqlen = max(max_seqlen, request_input_length)
 
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
             input_lengths.append(request_input_length)
+            prefix_lens.append(prefix_len)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
 
@@ -623,6 +626,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
+        prefix_lens_tensor = self.prefix_lens_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
@@ -658,6 +662,8 @@ class FlashCausalLMBatch(Batch):
             prefill_cu_outlens=None,
             input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
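
Taken together, the four hunks above thread prefix_lens (a Python list) and prefix_lens_tensor (a tensor) through FlashCausalLMBatch.filter so both stay aligned with the surviving requests. A rough, standalone sketch of that pattern with an invented three-request batch (not the actual class or field set):

import torch

# Toy "batch": per-request Python lists plus a parallel tensor.
input_lengths = [7, 3, 9]
prefix_lens = [4, 0, 2]
prefix_lens_tensor = torch.tensor(prefix_lens)

def filter_batch(keep_indices):
    # Lists are rebuilt entry by entry and tensors are sliced with the same
    # indices, so list[i] and tensor[i] still describe the same request.
    kept_lengths = [input_lengths[i] for i in keep_indices]
    kept_prefix_lens = [prefix_lens[i] for i in keep_indices]
    kept_prefix_lens_tensor = prefix_lens_tensor[torch.tensor(keep_indices)]
    return kept_lengths, kept_prefix_lens, kept_prefix_lens_tensor

print(filter_batch([0, 2]))
# ([7, 9], [4, 2], tensor([4, 2]))
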
@@ -1033,12 +1039,11 @@ class FlashCausalLM(Model):
             device=device
         )
 
-        if not CUDA_GRAPHS:
-            self.decode_state = create_decode_state(
-                device=device,
-                num_heads=self.num_heads,
-                num_kv_heads=self.num_kv_heads,
-            )
+        self.decode_state = create_decode_state(
+            device=device,
+            num_heads=self.num_heads,
+            num_kv_heads=self.num_kv_heads,
+        )
 
         super().__init__(
             model_id=model_id,
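
With this change self.decode_state is built unconditionally in __init__ rather than only when CUDA graphs are disabled; the later hunk at -1929 still uses it as a fallback via "state if state is not None else self.decode_state". A hedged toy sketch of that "default state with optional per-call override" pattern, using a stand-in DecodeState instead of the real create_decode_state:

from dataclasses import dataclass
from typing import Optional

@dataclass
class DecodeState:          # stand-in for whatever create_decode_state returns
    label: str

class TinyModel:
    def __init__(self):
        # Always build a default decode state; callers with their own state
        # (e.g. one per captured graph) can still pass it explicitly.
        self.decode_state = DecodeState("default")

    def forward_context(self, state: Optional[DecodeState] = None) -> DecodeState:
        # Mirrors: state=state if state is not None else self.decode_state
        return state if state is not None else self.decode_state

m = TinyModel()
print(m.forward_context().label)                         # default
print(m.forward_context(DecodeState("graph-32")).label)  # graph-32
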
@@ -1404,6 +1409,10 @@ class FlashCausalLM(Model):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+                + arange_int
+            ).view(-1)
 
             # Add Copy the block tables for all members
             block_tables = (
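
A small numeric illustration of the expansion pattern used above, with made-up values rather than the real batch: each request's length is broadcast over the speculative positions, an arange is added, and the result is flattened.

import torch

B, new_length = 2, 3                      # 2 requests, 3 speculative slots each
input_lengths = torch.tensor([5, 9])      # current length per request
arange_int = torch.arange(new_length)     # 0, 1, 2

expanded = (
    input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
print(expanded.tolist())                  # [5, 6, 7, 9, 10, 11]
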
@@ -1424,6 +1433,7 @@ class FlashCausalLM(Model):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
@@ -1448,7 +1458,7 @@ class FlashCausalLM(Model):
                 input_lengths=batch.input_lengths,
                 input_lengths_tensor=input_lengths,
                 prefix_lens=batch.prefix_lens,
-                prefix_lens_tensor=batch.prefix_lens_tensor,
+                prefix_lens_tensor=prefix_lens_tensor,
             ):
                 input_lengths = Seqlen(input_lengths=input_lengths)
                 logits, speculative_logits = self.model.forward(
@@ -1487,18 +1497,17 @@ class FlashCausalLM(Model):
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + batch.prefix_lens_tensor
+            input_lengths + prefix_lens_tensor
         )
 
-        state = cuda_graph.get("state")
         with self._forward_context(
-            block_tables=block_tables,
+            block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
             input_lengths=batch.input_lengths,
-            input_lengths_tensor=input_lengths,
+            input_lengths_tensor=cuda_graph["input_lengths"],
             prefix_lens=batch.prefix_lens,
-            prefix_lens_tensor=batch.prefix_lens_tensor,
-            state=state,
+            prefix_lens_tensor=prefix_lens_tensor,
+            state=cuda_graph["state"],
         ):
             # Replay the graph
             cuda_graph["graph"].replay()
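
For readers unfamiliar with the pattern: a captured CUDA graph replays fixed memory addresses, so fresh per-batch values must be copied into the graph's static buffers (slots, input_lengths, block_tables, state) before replay(), which is what the hunk above does. A generic, heavily simplified sketch of that copy-then-replay idea (not TGI code; the buffer names and sizes are invented):

import torch

if torch.cuda.is_available():
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    # Warm up on a side stream before capture (standard torch.cuda.graph recipe).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_in * 2)
    torch.cuda.current_stream().wait_stream(s)

    # Capture a tiny computation that only reads/writes the static buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out.copy_(static_in * 2)

    # "Replay time": new data is copied into the captured buffers first,
    # analogous to cuda_graph["slots"][...] = slots above, then replayed.
    static_in[:4] = torch.arange(4.0, device="cuda")
    graph.replay()
    print(static_out[:4].tolist())  # [0.0, 2.0, 4.0, 6.0]
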
@@ -1929,7 +1938,7 @@ class FlashCausalLM(Model):
             assert input_lengths_tensor is not None
             return use_decode_state(
                 state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                input_lengths=input_lengths_tensor,
                 block_tables=block_tables,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
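
One plausible reading of this hunk together with the -1487 hunk, not confirmed by the commit itself: the caller now fills the graph's length buffer with input_lengths + prefix_lens_tensor, so adding prefix_lens_tensor again inside use_decode_state would count the cached prefix twice. A toy illustration with made-up numbers:

import torch

input_lengths = torch.tensor([3, 5])   # tokens being decoded per request
prefix_lens = torch.tensor([8, 0])     # cached prefix tokens per request

already_summed = input_lengths + prefix_lens   # what the caller now passes in
double_counted = already_summed + prefix_lens  # what re-adding the prefix would yield
print(already_summed.tolist())   # [11, 5]
print(double_counted.tolist())   # [19, 5]
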
@@ -5,7 +5,7 @@ from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
 
-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
 _expected = {"paged", "flashdecoding", "flashinfer"}
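
The PREFIX_CACHING change above fixes boolean parsing of the environment variable: os.getenv returns a string whenever the variable is set, so the old os.getenv("USE_PREFIX_CACHING", False) was truthy for any non-empty value, including "0" and "false". A small standalone demonstration of the difference (the variable name is reused here purely for illustration):

import os

os.environ["USE_PREFIX_CACHING"] = "0"    # user intends "disabled"

old_style = os.getenv("USE_PREFIX_CACHING", False)
new_style = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}

print(bool(old_style))   # True  -- "0" is a non-empty string, so truthy
print(new_style)         # False -- parsed explicitly
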