diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index d499fee9..c579fff5 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -146,7 +146,6 @@ class ResponseComparator(JSONSnapshotExtension):
                 response.details, other.details
             )
 
-        # print(serialized_data)
         serialized_data = convert_data(serialized_data)
         snapshot_data = convert_data(snapshot_data)
 
diff --git a/router/src/lib.rs b/router/src/lib.rs
index d3602e24..87873821 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -44,15 +44,13 @@ impl HubTokenizerConfig {
         serde_json::from_str(&content).unwrap_or_default()
     }
 }
+
 mod json_object_or_string_to_string {
-    // This custom deserializer is used to handle the fact that the grammar field can be either a
-    // string or an object. In both cases we handle it as a string, but also provide this convience
-    // to the user to be flexible with the input.
-    use super::*;
-    use serde::de;
-    use serde::Deserializer;
+    use serde::{Deserialize, Deserializer};
     use serde_json::Value;
 
+    // A custom deserializer that treats both strings and objects as strings.
+    // This provides flexibility with input formats for the 'grammar' field.
     pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
     where
         D: Deserializer<'de>,
@@ -61,8 +59,13 @@ mod json_object_or_string_to_string {
 
         match value {
             Value::String(s) => Ok(s),
-            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
-            _ => Err(de::Error::custom("expected string or object for grammar")),
+            // Safely handle serialization and return an error if it fails
+            Value::Object(o) => {
+                serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
+            }
+            _ => Err(serde::de::Error::custom(
+                "expected string or object for grammar",
+            )),
         }
     }
 }
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 0b455fec..25eacf64 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1050,13 +1050,11 @@ class FlashCausalLM(Model):
             prefill_logprobs = prefill_logprobs.view(-1).tolist()
 
         # GPU <-> CPU sync
-        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
-            next_input_ids.tolist(),
-        )
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
         accepted_ids = accepted_ids.tolist()
         start_decode = time.time_ns()
+        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(next_token_ids)
 
         # Zipped iterator
         iterator = zip(
diff --git a/server/text_generation_server/utils/logits_process.py b/server/text_generation_server/utils/logits_process.py
index 718fc12e..ddc151da 100644
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -575,16 +575,14 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         fsm_grammar_states: List[int],
         mask: torch.Tensor,
     ):
+        mask = torch.full_like(logits, -math.inf)
         for i in range(logits.shape[0]):
             fsm = self.fsms[i]
             if fsm_grammar_states[i] == -1 or fsm is None:
                 continue
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-            mask[allowed_tokens] = 0
-            biased_scores = logits[i] + mask
-            mask.fill_(-math.inf)
-            logits[i] = biased_scores
-
+            mask[i, allowed_tokens] = 0
+        logits += mask
        return logits
 
     def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
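
A minimal sketch of the new masking approach in `HeterogeneousGrammarLogitProcessor`: instead of reusing and refilling a single mask row per request, the loop now builds one `(batch, vocab)` mask of `-inf`, zeroes each row's allowed token ids, and biases all rows with a single addition. The tensor sizes and `allowed_token_ids` lists below are made-up stand-ins for the per-request FSM lookups (`fsm.allowed_token_ids(state)`), and the grammar-less (`-1` state) path is omitted.

```python
import math

import torch

# Hypothetical per-request allowed token ids; in the real processor these come
# from each request's grammar FSM.
allowed_token_ids = [[2, 5], [0, 7], [1, 3, 4]]

logits = torch.randn(3, 8)  # (batch, vocab) sizes are illustrative only

# Build one (batch, vocab) mask of -inf, zero the allowed columns per row,
# then bias every row with a single addition.
mask = torch.full_like(logits, -math.inf)
for i, allowed in enumerate(allowed_token_ids):
    mask[i, allowed] = 0
logits += mask

# Only the allowed token ids stay finite; everything else is -inf.
print(torch.isfinite(logits))
```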
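
For the router-side change, the deserializer now propagates serialization failures instead of unwrapping. A rough Python sketch of the intended behavior (a string passes through, an object is re-serialized, anything else is rejected); the `normalize_grammar` name and the use of `ValueError` are invented for illustration, the real logic lives in the Rust `json_object_or_string_to_string` module.

```python
import json
from typing import Union


def normalize_grammar(value: Union[str, dict]) -> str:
    """Accept a grammar given as a string or a JSON object and return a string."""
    if isinstance(value, str):
        return value
    if isinstance(value, dict):
        # Surface serialization errors as a regular error instead of crashing,
        # analogous to map_err(...) replacing unwrap() in the Rust deserializer.
        try:
            return json.dumps(value)
        except (TypeError, ValueError) as err:
            raise ValueError(f"could not serialize grammar object: {err}") from err
    raise ValueError("expected string or object for grammar")


print(normalize_grammar({"type": "object", "properties": {}}))
print(normalize_grammar('{"type": "string"}'))
```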