From 549f0e9ca73de45e4fd9739d47591acd907c1d34 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Tue, 13 Aug 2024 11:58:55 +0200
Subject: [PATCH] Fixing medusa (still wrong outputs, but functional).

---
 flake.nix                                      |  1 +
 router.nix                                     | 20 ++++++++++
 .../text_generation_server/layers/medusa.py    |  2 +
 .../models/flash_causal_lm.py                  | 37 ++++++++++++-------
 .../text_generation_server/models/globals.py   |  2 +-
 5 files changed, 47 insertions(+), 15 deletions(-)
 create mode 100644 router.nix

diff --git a/flake.nix b/flake.nix
index cf05746a..b5eb0eed 100644
--- a/flake.nix
+++ b/flake.nix
@@ -81,6 +81,7 @@
           grpcio-status
           grpcio-tools
           hf-transfer
+          ipdb
           loguru
           mamba-ssm
           marlin-kernels
diff --git a/router.nix b/router.nix
new file mode 100644
index 00000000..2147ccc9
--- /dev/null
+++ b/router.nix
@@ -0,0 +1,20 @@
+{ buildRustPackage, importCargoLock, pkg-config, protobuf, openssl }:
+
+buildRustPackage {
+  name = "text-generation-router";
+
+  src = ./.;
+
+  sourceDir = ./backends/v3;
+
+  cargoLock = {
+    lockFile = ./Cargo.lock;
+  };
+
+  nativeBuildInputs = [ pkg-config ];
+
+  buildInputs = [ openssl.dev protobuf ];
+
+  doCheck = false;
+
+}
diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py
index 7579ccdb..139c4dc2 100644
--- a/server/text_generation_server/layers/medusa.py
+++ b/server/text_generation_server/layers/medusa.py
@@ -32,6 +32,8 @@ class MedusaModel(torch.nn.Module):
         )
 
     def forward(self, x):
+        if not self.heads:
+            return None
         speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
         return speculative_logits
 
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 7cab1c4b..387118c2 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -552,6 +552,7 @@ class FlashCausalLMBatch(Batch):
         prefix_ids = []
 
         input_lengths = []
+        prefix_lens = []
         prefix_offsets = []
         read_offsets = []
 
@@ -573,12 +574,14 @@ class FlashCausalLMBatch(Batch):
 
             # Get length
             request_input_length = self.input_lengths[idx]
+            prefix_len = self.prefix_lens[idx]
             max_seqlen = max(max_seqlen, request_input_length)
 
             all_input_ids.append(self.all_input_ids[idx])
             prefix_ids.append(self.prefix_ids[idx])
 
             input_lengths.append(request_input_length)
+            prefix_lens.append(prefix_len)
             prefix_offsets.append(self.prefix_offsets[idx])
             read_offsets.append(self.read_offsets[idx])
 
@@ -623,6 +626,7 @@ class FlashCausalLMBatch(Batch):
         block_tables_tensor = self.block_tables_tensor[indices]
         input_lengths_tensor = self.input_lengths_tensor[indices]
         slots = self.slots[slot_filtering_indices]
+        prefix_lens_tensor = self.prefix_lens_tensor[indices]
         next_token_chooser = self.next_token_chooser.filter(indices)
         top_n_tokens_tensor = self.top_n_tokens_tensor[indices]
         speculative_ids = (
@@ -658,6 +662,8 @@ class FlashCausalLMBatch(Batch):
             prefill_cu_outlens=None,
             input_lengths=input_lengths,
             input_lengths_tensor=input_lengths_tensor,
+            prefix_lens=prefix_lens,
+            prefix_lens_tensor=prefix_lens_tensor,
             prefix_offsets=prefix_offsets,
             read_offsets=read_offsets,
             all_input_ids=all_input_ids,
@@ -1033,12 +1039,11 @@ class FlashCausalLM(Model):
                 device=device
             )
 
-            if not CUDA_GRAPHS:
-                self.decode_state = create_decode_state(
-                    device=device,
-                    num_heads=self.num_heads,
-                    num_kv_heads=self.num_kv_heads,
-                )
+            self.decode_state = create_decode_state(
+                device=device,
+                num_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+            )
 
         super().__init__(
             model_id=model_id,
@@ -1404,6 +1409,10 @@ class FlashCausalLM(Model):
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
+            prefix_lens_tensor = (
+                batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length)
+                + arange_int
+            ).view(-1)
 
             # Add Copy the block tables for all members
             block_tables = (
@@ -1424,6 +1433,7 @@ class FlashCausalLM(Model):
             block_tables = batch.block_tables_tensor
             slots = batch.slots[batch.slot_indices]
             input_lengths = batch.input_lengths_tensor
+            prefix_lens_tensor = batch.prefix_lens_tensor
             max_s = batch.max_seqlen
             lm_head_indices = batch.prefill_head_indices
 
@@ -1448,7 +1458,7 @@ class FlashCausalLM(Model):
                 input_lengths=batch.input_lengths,
                 input_lengths_tensor=input_lengths,
                 prefix_lens=batch.prefix_lens,
-                prefix_lens_tensor=batch.prefix_lens_tensor,
+                prefix_lens_tensor=prefix_lens_tensor,
             ):
                 input_lengths = Seqlen(input_lengths=input_lengths)
                 logits, speculative_logits = self.model.forward(
@@ -1487,18 +1497,17 @@ class FlashCausalLM(Model):
         cuda_graph["slots"][: slots.shape[0]] = slots
         cuda_graph["input_lengths"].zero_()
         cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
-            input_lengths + batch.prefix_lens_tensor
+            input_lengths + prefix_lens_tensor
         )
 
-        state = cuda_graph.get("state")
         with self._forward_context(
-            block_tables=block_tables,
+            block_tables=cuda_graph["block_tables"],
             cu_seqlen_prefill=None,
             input_lengths=batch.input_lengths,
-            input_lengths_tensor=input_lengths,
+            input_lengths_tensor=cuda_graph["input_lengths"],
             prefix_lens=batch.prefix_lens,
-            prefix_lens_tensor=batch.prefix_lens_tensor,
-            state=state,
+            prefix_lens_tensor=prefix_lens_tensor,
+            state=cuda_graph["state"],
         ):
             # Replay the graph
             cuda_graph["graph"].replay()
@@ -1929,7 +1938,7 @@ class FlashCausalLM(Model):
             assert input_lengths_tensor is not None
             return use_decode_state(
                 state=state if state is not None else self.decode_state,
-                input_lengths=input_lengths_tensor + prefix_lens_tensor,
+                input_lengths=input_lengths_tensor,
                 block_tables=block_tables,
                 num_heads=self.num_heads,
                 num_kv_heads=self.num_kv_heads,
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index fbff1cec..d5133f5e 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -5,7 +5,7 @@ from typing import Dict, Optional
 
 from text_generation_server.utils.log import log_master
 
-PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", False)
+PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"}
 log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
 ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged")
 _expected = {"paged", "flashdecoding", "flashinfer"}
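
Note on the globals.py hunk above: USE_PREFIX_CACHING is now parsed as a string
flag rather than passed straight through os.getenv with a False default. A
minimal standalone sketch of the same idiom, assuming a hypothetical helper
name env_flag (not part of the patch):

    import os

    def env_flag(name: str, default: str = "0") -> bool:
        # Only "1" or "true" (case-insensitive) enable the flag;
        # unset or any other value is treated as disabled.
        return os.getenv(name, default).lower() in {"1", "true"}

    # e.g. USE_PREFIX_CACHING=true -> True, USE_PREFIX_CACHING=0 -> False
    PREFIX_CACHING = env_flag("USE_PREFIX_CACHING")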