diff --git a/docs/openapi.json b/docs/openapi.json
index 903f7426..22b06720 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -101,6 +101,47 @@
         }
       }
     },
+    "/chat_tokenize": {
+      "post": {
+        "tags": [
+          "Text Generation Inference"
+        ],
+        "summary": "Template and tokenize ChatRequest",
+        "operationId": "get_chat_tokenize",
+        "requestBody": {
+          "content": {
+            "application/json": {
+              "schema": {
+                "$ref": "#/components/schemas/ChatRequest"
+              }
+            }
+          },
+          "required": true
+        },
+        "responses": {
+          "200": {
+            "description": "Templated and tokenized ChatRequest",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ChatTokenizeResponse"
+                }
+              }
+            }
+          },
+          "404": {
+            "description": "Failed to tokenize ChatRequest",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/ErrorResponse"
+                }
+              }
+            }
+          }
+        }
+      }
+    },
     "/generate": {
       "post": {
         "tags": [
@@ -1092,6 +1133,21 @@
           }
         }
       },
+      "ChatTokenizeResponse": {
+        "type": "object",
+        "required": [
+          "tokenize_response",
+          "templated_text"
+        ],
+        "properties": {
+          "templated_text": {
+            "type": "string"
+          },
+          "tokenize_response": {
+            "$ref": "#/components/schemas/TokenizeResponse"
+          }
+        }
+      },
       "Chunk": {
         "type": "object",
         "required": [
diff --git a/flake.lock b/flake.lock
index 69ce6cd5..5246f424 100644
--- a/flake.lock
+++ b/flake.lock
@@ -978,16 +978,15 @@
         "nixpkgs": "nixpkgs_6"
       },
       "locked": {
-        "lastModified": 1729761651,
-        "narHash": "sha256-GYykQ9Fxji2EuXCGcPn0dx8Qx8VQBJTkRdcCytp4A/k=",
+        "lastModified": 1730724647,
+        "narHash": "sha256-SVv+50CGaCoU4zZwsg6ZAaOi/D5QJBL1P2SIB+3CEf4=",
         "owner": "huggingface",
         "repo": "text-generation-inference-nix",
-        "rev": "f7e3c4fa67d70590ed9ee47feeab645bd9ba81b1",
+        "rev": "1512898a1e5ad9eff025205fa9c4d33a44506cf3",
         "type": "github"
       },
       "original": {
         "owner": "huggingface",
-        "ref": "marlin-kernels-0.3.1",
         "repo": "text-generation-inference-nix",
         "type": "github"
       }
     }
diff --git a/flake.nix b/flake.nix
index 45441cae..f26a983e 100644
--- a/flake.nix
+++ b/flake.nix
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1";
+    tgi-nix.url = "github:huggingface/text-generation-inference-nix";
     nixpkgs.follows = "tgi-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 19a79115..64f4f515 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1687,13 +1687,6 @@ fn main() -> Result<(), LauncherError> {
     let max_position_embeddings = if let Some(config) = &config {
         if let Some(max_position_embeddings) = config.max_position_embeddings {
             if max_position_embeddings > max_default {
-                let max = max_position_embeddings;
-                if args.max_input_tokens.is_none()
-                    && args.max_total_tokens.is_none()
-                    && args.max_batch_prefill_tokens.is_none()
-                {
-                    tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1);
-                }
                 max_default
             } else {
                 max_position_embeddings
diff --git a/router/src/server.rs b/router/src/server.rs
index 863607b1..7d8d518c 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -181,12 +181,16 @@ async fn openai_get_model_info(info: Extension<Info>) -> Json<ModelsInfo> {
     })
 }
 
+/// Template and tokenize ChatRequest
 #[utoipa::path(
     post,
     tag = "Text Generation Inference",
     path = "/chat_tokenize",
     request_body = ChatRequest,
-    responses((status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse))
+    responses(
+        (status = 200, description = "Templated and tokenized ChatRequest", body = ChatTokenizeResponse),
+        (status = 404, description = "Failed to tokenize ChatRequest", body = ErrorResponse),
+    )
 )]
 async fn get_chat_tokenize(
     Extension(infer): Extension<Infer>,
@@ -1501,6 +1505,7 @@ tokenize,
 metrics,
 openai_get_model_info,
 sagemaker_compatibility,
+get_chat_tokenize,
 ),
 components(
 schemas(
@@ -1558,6 +1563,7 @@ Function,
 FunctionDefinition,
 ToolChoice,
 ModelInfo,
+ChatTokenizeResponse,
 )
 ),
 tags(
diff --git a/server/text_generation_server/layers/awq/quantize/ipex.py b/server/text_generation_server/layers/awq/quantize/ipex.py
index 84cd7a21..842e9623 100644
--- a/server/text_generation_server/layers/awq/quantize/ipex.py
+++ b/server/text_generation_server/layers/awq/quantize/ipex.py
@@ -44,5 +44,4 @@ class WQLinear(nn.Module):
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
         out = self.woq_linear(x.reshape(-1, x.shape[-1]))
-        out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
diff --git a/server/text_generation_server/layers/gptq/ipex.py b/server/text_generation_server/layers/gptq/ipex.py
index ab9c9e24..48584e90 100644
--- a/server/text_generation_server/layers/gptq/ipex.py
+++ b/server/text_generation_server/layers/gptq/ipex.py
@@ -122,5 +122,4 @@ class QuantLinear(nn.Module):
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.outfeatures,)
         out = self.woq_linear(x.reshape(-1, x.shape[-1]))
-        out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 52ab5d6a..bb908fd0 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -887,11 +887,12 @@ class FlashCausalLMBatch(Batch):
             fsm_grammar_states=fsm_grammar_states,
         )
 
-        speculative_ids = (
-            torch.cat([b.speculative_ids for b in batches], dim=0)
-            if batches[0].speculative_ids is not None
-            else None
-        )
+        # We skip computing the speculative_ids when the batch size is too large, so
+        # we must check that all batches have them, otherwise they must be discarded
+        if get_speculate() > 0 and all(b.speculative_ids is not None for b in batches):
+            speculative_ids = torch.cat([b.speculative_ids for b in batches], dim=0)
+        else:
+            speculative_ids = None
 
         if adapter_segment_builder is not None:
             adapter_segments, adapter_segment_indices = adapter_segment_builder.build()
@@ -1532,8 +1533,6 @@ class FlashCausalLM(Model):
                 self.kv_cache_dtype,
                 self.device,
             )
-            max_bt = batch.max_blocks
-            max_s = max_bt * BLOCK_SIZE
             batch_num_blocks = batch.num_blocks
 
             if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
@@ -1651,7 +1650,7 @@ class FlashCausalLM(Model):
                 # Warmup cuda graphs
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
-                        self.cuda_graph_warmup(bs, max_s, max_bt)
+                        self.cuda_graph_warmup(bs, max_total_tokens, max_total_tokens)
             except torch.cuda.OutOfMemoryError:
                 logger.exception("Decode cuda graph warmup failed")
         else:
@@ -1726,7 +1725,15 @@ class FlashCausalLM(Model):
             new_position_ids = (
                 position_ids.unsqueeze(-1).expand(B, new_length) + arange
             ).view(-1)
-            slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
+
+            # Slots can be discontiguous when prefix caching is enabled, so we need to expand the slot_indices,
+            # then update the slots with the additional indices to ensure we're grabbing the ones that have been
+            # allocated
+            slot_indices = (
+                batch.slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int
+            ).view(-1)
+            slots = batch.slots[slot_indices]
+
             input_lengths = (
                 input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
             ).view(-1)
diff --git a/server/text_generation_server/models/metadata_kernels.py b/server/text_generation_server/models/metadata_kernels.py
index b3e2160d..783aab80 100644
--- a/server/text_generation_server/models/metadata_kernels.py
+++ b/server/text_generation_server/models/metadata_kernels.py
@@ -55,7 +55,7 @@ def block_tables_to_ragged(
     cache_lengths: List[int],
     input_lengths_tensor: torch.Tensor,
    cache_lengths_tensor: torch.Tensor,
-    max_current_length: int
+    max_current_length: int,
 ) -> torch.Tensor:
     """Convert block table to ragged format compatible with FlashInfer."""
     assert len(input_lengths) == len(cache_lengths)
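
A minimal sketch of exercising the new `/chat_tokenize` route documented above. The server address is a hypothetical assumption; the request fields follow the OpenAI-style `ChatRequest` schema referenced in the OpenAPI spec, and the response fields (`templated_text`, `tokenize_response`) come from the `ChatTokenizeResponse` schema added in this diff.

```python
# Sketch only: assumes a TGI instance is listening on localhost:3000.
import requests

chat_request = {
    # "model" is optional for TGI's OpenAI-compatible ChatRequest; "messages" is required.
    "model": "tgi",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is deep learning?"},
    ],
}

resp = requests.post("http://localhost:3000/chat_tokenize", json=chat_request)
resp.raise_for_status()  # a 404 with an ErrorResponse body is returned if templating/tokenization fails

body = resp.json()
# The prompt after the chat template has been applied, and its tokenization.
print(body["templated_text"])
print(body["tokenize_response"])
```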
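A toy illustration of the slot gathering introduced in `flash_causal_lm.py`: instead of assuming slots are contiguous per request, the per-request `slot_indices` are expanded over the speculative length and used to index into the flat `slots` tensor. Shapes and values below are made up for demonstration.

```python
import torch

B, new_length = 2, 3  # two requests, speculative_length + 1 == 3
# Base index of each request's current slot within the flattened `slots` tensor.
slot_indices = torch.tensor([0, 5])
# Flattened slot allocation; the two requests' slots are not contiguous with each other.
slots = torch.tensor([10, 11, 12, 13, 14, 40, 41, 42, 43, 44])

arange_int = torch.arange(new_length, dtype=torch.int32).unsqueeze(0)
expanded = (slot_indices.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
per_token_slots = slots[expanded]
print(per_token_slots)  # tensor([10, 11, 12, 40, 41, 42])
```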