From bada345055b66f11afb7631df5db6e8b7a8c6cd5 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 20 Feb 2024 14:55:30 +0100
Subject: [PATCH] ???

---
 integration-tests/models/test_flash_awq.py     |  3 --
 .../models/test_flash_awq_sharded.py           |  2 -
 integration-tests/models/test_flash_medusa.py  |  3 --
 .../models/test_flash_mistral.py               |  3 --
 integration-tests/models/test_flash_phi.py     |  3 --
 .../models/test_flash_starcoder_gptq.py        |  3 --
 .../models/test_grammar_llama.py               |  5 ---
 integration-tests/models/test_mamba.py         |  3 --
 router/src/lib.rs                              | 38 +------------------
 router/src/validation.rs                       | 15 +++++++-
 server/text_generation_server/utils/tokens.py  |  5 +--
 11 files changed, 18 insertions(+), 65 deletions(-)

diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py
index 62a95f48..ead918c3 100644
--- a/integration-tests/models/test_flash_awq.py
+++ b/integration-tests/models/test_flash_awq.py
@@ -18,7 +18,6 @@ async def flash_llama_awq(flash_llama_awq_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
     response = await flash_llama_awq.generate(
         "What is Deep Learning?",
@@ -55,7 +53,6 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
     responses = await generate_load(
         flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py
index 1c687fc9..a83614ac 100644
--- a/integration-tests/models/test_flash_awq_sharded.py
+++ b/integration-tests/models/test_flash_awq_sharded.py
@@ -18,7 +18,6 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
     response = await flash_llama_awq_sharded.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -33,7 +32,6 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_awq_load_sharded(
     flash_llama_awq_sharded, generate_load, response_snapshot
 ):
diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py
index a0ce0570..e0cc1039 100644
--- a/integration-tests/models/test_flash_medusa.py
+++ b/integration-tests/models/test_flash_medusa.py
@@ -14,7 +14,6 @@ async def flash_medusa(flash_medusa_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_simple(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
@@ -25,7 +24,6 @@ async def test_flash_medusa_simple(flash_medusa, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
     response = await flash_medusa.generate(
         "What is Deep Learning?",
@@ -48,7 +46,6 @@ async def test_flash_medusa_all_params(flash_medusa, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_medusa_load(flash_medusa, generate_load, response_snapshot):
     responses = await generate_load(
         flash_medusa, "What is Deep Learning?", max_new_tokens=10, n=4
diff --git a/integration-tests/models/test_flash_mistral.py b/integration-tests/models/test_flash_mistral.py
index ace3328b..52b51928 100644
--- a/integration-tests/models/test_flash_mistral.py
+++ b/integration-tests/models/test_flash_mistral.py
@@ -14,7 +14,6 @@ async def flash_mistral(flash_mistral_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_mistral(flash_mistral, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
     response = await flash_mistral.generate(
         "Test request",
@@ -49,7 +47,6 @@ async def test_flash_mistral_all_params(flash_mistral, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_mistral_load(flash_mistral, generate_load, response_snapshot):
     responses = await generate_load(
         flash_mistral, "Test request", max_new_tokens=10, n=4
diff --git a/integration-tests/models/test_flash_phi.py b/integration-tests/models/test_flash_phi.py
index 0987b3a1..9d6ca566 100644
--- a/integration-tests/models/test_flash_phi.py
+++ b/integration-tests/models/test_flash_phi.py
@@ -14,7 +14,6 @@ async def flash_phi(flash_phi_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -26,7 +25,6 @@ async def test_flash_phi(flash_phi, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi_all_params(flash_phi, response_snapshot):
     response = await flash_phi.generate(
         "Test request",
@@ -50,7 +48,6 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
     responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
 
diff --git a/integration-tests/models/test_flash_starcoder_gptq.py b/integration-tests/models/test_flash_starcoder_gptq.py
index 5e448d55..329158b7 100644
--- a/integration-tests/models/test_flash_starcoder_gptq.py
+++ b/integration-tests/models/test_flash_starcoder_gptq.py
@@ -14,7 +14,6 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
     response = await flash_starcoder_gptq.generate(
         "def geometric_mean(L: List[float]):",
@@ -26,7 +25,6 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq_default_params(
     flash_starcoder_gptq, generous_response_snapshot
 ):
@@ -43,7 +41,6 @@ async def test_flash_starcoder_gptq_default_params(
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_starcoder_gptq_load(
     flash_starcoder_gptq, generate_load, generous_response_snapshot
 ):
diff --git a/integration-tests/models/test_grammar_llama.py b/integration-tests/models/test_grammar_llama.py
index f068496c..ba123999 100644
--- a/integration-tests/models/test_grammar_llama.py
+++ b/integration-tests/models/test_grammar_llama.py
@@ -19,7 +19,6 @@ async def flash_llama_grammar(flash_llama_grammar_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Test request", max_new_tokens=10, decoder_input_details=True
@@ -30,7 +29,6 @@ async def test_flash_llama_grammar(flash_llama_grammar, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "Whats Googles DNS",
@@ -49,7 +47,6 @@ async def test_flash_llama_grammar_regex(flash_llama_grammar, response_snapshot)
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
     response = await flash_llama_grammar.generate(
         "info: david holtz like trees and has two cats. ",
@@ -98,7 +95,6 @@ async def test_flash_llama_grammar_json(flash_llama_grammar, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_load(
     flash_llama_grammar, generate_load, response_snapshot
 ):
@@ -130,7 +126,6 @@ async def test_flash_llama_grammar_load(
 # this is the same as the above test, but only fires off a single request
 # this is only to ensure that the parallel and single inference produce the same result
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_flash_llama_grammar_single_load_instance(
     flash_llama_grammar, generate_load, response_snapshot
 ):
diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py
index 5ec2ec31..bf3701b4 100644
--- a/integration-tests/models/test_mamba.py
+++ b/integration-tests/models/test_mamba.py
@@ -14,7 +14,6 @@ async def fused_kernel_mamba(fused_kernel_mamba_handle):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "What is Deep Learning?", max_new_tokens=10
@@ -26,7 +25,6 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
     response = await fused_kernel_mamba.generate(
         "blue, red, yellow, ",
@@ -53,7 +51,6 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
 
 
 @pytest.mark.asyncio
-@pytest.mark.private
 async def test_mamba_load(
     fused_kernel_mamba, generate_load, generous_response_snapshot
 ):
diff --git a/router/src/lib.rs b/router/src/lib.rs
index c1da6572..7bf51f5d 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -45,37 +45,6 @@ impl HubTokenizerConfig {
     }
 }
 
-mod json_object_or_string_to_string {
-    use jsonschema::{Draft, JSONSchema};
-    use serde::{Deserialize, Deserializer};
-    use serde_json::Value;
-
-    // A custom deserializer that treats both strings and objects as strings.
-    // This provides flexibility with input formats for the 'grammar' field.
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-
-        JSONSchema::options()
-            .with_draft(Draft::Draft202012)
-            .compile(&value)
-            .map_err(|e| serde::de::Error::custom(format!("invalid JSONSchema: {e}")))?;
-
-        match value {
-            Value::String(s) => Ok(s),
-            // Safely handle serialization and return an error if it fails
-            Value::Object(o) => {
-                serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
-            }
-            _ => Err(serde::de::Error::custom(
-                "expected string or object for grammar",
-            )),
-        }
-    }
-}
-
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 #[serde(tag = "type", content = "value")]
 pub(crate) enum GrammarType {
@@ -83,12 +52,9 @@ pub(crate) enum GrammarType {
     ///
     /// JSON Schema is a declarative language that allows to annotate JSON documents
     /// with types and descriptions.
-    #[serde(
-        rename = "json",
-        deserialize_with = "json_object_or_string_to_string::deserialize"
-    )]
+    #[serde(rename = "json")]
     #[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
-    Json(String),
+    Json(serde_json::Value),
     #[serde(rename = "regex")]
     Regex(String),
 }
diff --git a/router/src/validation.rs b/router/src/validation.rs
index bf85b12f..f350d15e 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -314,7 +314,18 @@ impl Validation {
                 }
                 match grammar {
                     // currently both are handled the same way since compilation is done in Python
-                    GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
+                    GrammarType::Json(json) => {
+                        // JSONSchema::options()
+                        //     .with_draft(Draft::Draft202012)
+                        //     .compile(&json)
+                        //     .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
+
+                        (
+                            serde_json::to_string(&json)
+                                .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?,
+                            ProtoGrammarType::Json.into(),
+                        )
+                    }
                     GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
                 }
             }
@@ -486,6 +497,8 @@ pub enum ValidationError {
     Tokenizer(String),
     #[error("grammar is not supported")]
     Grammar,
+    #[error("grammar is not a valid JSONSchema: {0}")]
+    InvalidGrammar(String),
 }
 
 #[cfg(test)]
diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py
index 72c6c21c..32789850 100644
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -328,7 +328,6 @@ class HeterogeneousNextTokenChooser:
             scores = scores.view(B, S, -1)
 
         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
-        mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
 
         for j in range(S):
             _scores = scores[:, j]
@@ -338,10 +337,10 @@ class HeterogeneousNextTokenChooser:
                 _scores = self.repetition_processor(input_ids, _scores)
             if self.frequency_processor is not None:
                 _scores = self.frequency_processor(input_ids, _scores)
-            for warper in self.warpers:
-                _scores = warper(input_ids, _scores)
             if self.grammar_processor is not None:
                 _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+            for warper in self.warpers:
+                _scores = warper(input_ids, _scores)
             _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
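
For review context, a small self-contained sketch of how the reworked grammar type behaves after this patch: the enum below mirrors the patched `GrammarType` in `router/src/lib.rs` (minus the `ToSchema` derive), and the `main` function is purely illustrative. The point is that a JSON-Schema grammar now stays a `serde_json::Value` at deserialization time and is only serialized to a `String` later, in `Validation::validate`.

    use serde::Deserialize;
    use serde_json::{json, Value};

    // Mirrors the enum after this patch (ToSchema derive omitted here).
    #[derive(Clone, Debug, Deserialize)]
    #[serde(tag = "type", content = "value")]
    enum GrammarType {
        #[serde(rename = "json")]
        Json(Value),
        #[serde(rename = "regex")]
        Regex(String),
    }

    fn main() {
        // A JSON Schema supplied as an object is kept as-is at parse time...
        let grammar: GrammarType = serde_json::from_value(json!({
            "type": "json",
            "value": {"properties": {"location": {"type": "string"}}}
        }))
        .unwrap();

        // ...and only turned into a String later, mirroring what
        // Validation::validate now does before building the proto request.
        if let GrammarType::Json(schema) = grammar {
            let as_string = serde_json::to_string(&schema).unwrap();
            assert!(as_string.contains("\"location\""));
            println!("{as_string}");
        }
    }

One behavioral nuance worth double-checking: the removed `json_object_or_string_to_string` deserializer passed a grammar given as a raw string through verbatim, whereas a string now survives as `Value::String` and gets re-serialized with surrounding quotes.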
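The schema sanity check that used to live in the deleted deserializer is only kept as a commented-out block inside `Validation::validate`. If it gets re-enabled later, it would presumably look something like the sketch below; it assumes the `jsonschema` crate the removed code was already using, and the error is reduced to a plain `String` instead of `ValidationError::InvalidGrammar` so the snippet stays standalone.

    use jsonschema::{Draft, JSONSchema};
    use serde_json::json;

    fn main() {
        let schema = json!({"properties": {"location": {"type": "string"}}});

        // Same call chain as the commented-out block in Validation::validate:
        // compile the user-supplied grammar as a Draft 2020-12 JSON Schema and
        // turn a failure into the message InvalidGrammar would carry.
        let result = JSONSchema::options()
            .with_draft(Draft::Draft202012)
            .compile(&schema)
            .map(|_| ())
            .map_err(|e| format!("grammar is not a valid JSONSchema: {e}"));

        assert!(result.is_ok());
        println!("grammar compiled fine");
    }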