fix: remove unnecessary code, avoid copies and make deser safer

drbh 2024-02-14 15:54:32 +00:00
parent be7835475b
commit f0cdd9c8ea
4 changed files with 15 additions and 17 deletions

integration-tests/conftest.py

@@ -146,7 +146,6 @@ class ResponseComparator(JSONSnapshotExtension):
             response.details, other.details
         )

-        # print(serialized_data)
         serialized_data = convert_data(serialized_data)
         snapshot_data = convert_data(snapshot_data)

router/src/lib.rs

@@ -44,15 +44,13 @@ impl HubTokenizerConfig {
         serde_json::from_str(&content).unwrap_or_default()
     }
 }

 mod json_object_or_string_to_string {
-    // This custom deserializer is used to handle the fact that the grammar field can be either a
-    // string or an object. In both cases we handle it as a string, but also provide this convience
-    // to the user to be flexible with the input.
     use super::*;
-    use serde::de;
-    use serde::Deserializer;
+    use serde::{Deserialize, Deserializer};
     use serde_json::Value;
+    // A custom deserializer that treats both strings and objects as strings.
+    // This provides flexibility with input formats for the 'grammar' field.
     pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
     where
         D: Deserializer<'de>,
@@ -61,8 +59,13 @@ mod json_object_or_string_to_string {
         match value {
             Value::String(s) => Ok(s),
-            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
-            _ => Err(de::Error::custom("expected string or object for grammar")),
+            // Safely handle serialization and return an error if it fails
+            Value::Object(o) => {
+                serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
+            }
+            _ => Err(serde::de::Error::custom(
+                "expected string or object for grammar",
+            )),
         }
     }
 }
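The safer arm is easy to see in isolation: an object is re-serialized, and any failure surfaces as a deserialization error instead of panicking through unwrap. For readers who don't write Rust, here is a minimal Python sketch of the same string-or-object normalization; the function name normalize_grammar is hypothetical and not part of the codebase:

import json
from typing import Union

def normalize_grammar(value: Union[str, dict]) -> str:
    # Strings pass through unchanged (the Value::String arm).
    if isinstance(value, str):
        return value
    # Objects are re-serialized; json.dumps can raise, so the caller
    # gets an error instead of a hard crash (the safer Value::Object arm).
    if isinstance(value, dict):
        return json.dumps(value)
    # Anything else is rejected, mirroring the catch-all arm.
    raise ValueError("expected string or object for grammar")

# Both spellings of a grammar normalize to the same string form:
assert normalize_grammar({"type": "object"}) == '{"type": "object"}'
assert normalize_grammar('{"type": "object"}') == '{"type": "object"}'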

server/text_generation_server/models/flash_causal_lm.py

@@ -1050,13 +1050,11 @@ class FlashCausalLM(Model):
         prefill_logprobs = prefill_logprobs.view(-1).tolist()

         # GPU <-> CPU sync
-        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
-            next_input_ids.tolist(),
-        )
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
         accepted_ids = accepted_ids.tolist()
         start_decode = time.time_ns()
+        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(next_token_ids)

         # Zipped iterator
         iterator = zip(
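The motivation here is that every .tolist() call is a blocking GPU-to-CPU copy, and the old code synced next_input_ids twice. A small illustrative sketch of the pattern, with made-up tensor contents and advance_grammar standing in for the real call:

import torch

next_input_ids = torch.tensor([3, 7, 42])  # stand-in for GPU-resident ids

# Before: two separate device-to-host syncs of the same tensor.
#   advance_grammar(next_input_ids.tolist())   # sync 1
#   next_token_ids = next_input_ids.tolist()   # sync 2

# After: sync once and reuse the plain Python list.
next_token_ids = next_input_ids.tolist()       # the only sync
# advance_grammar(next_token_ids) now consumes the already-copied list.
print(next_token_ids)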

server/text_generation_server/utils/logits_process.py

@@ -575,16 +575,14 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         logits: torch.Tensor,
         fsm_grammar_states: List[int],
-        mask: torch.Tensor,
     ):
+        mask = torch.full_like(logits, -math.inf)
         for i in range(logits.shape[0]):
             fsm = self.fsms[i]
             if fsm_grammar_states[i] == -1 or fsm is None:
                 continue
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-            mask[allowed_tokens] = 0
-            biased_scores = logits[i] + mask
-            mask.fill_(-math.inf)
-            logits[i] = biased_scores
+            mask[i, allowed_tokens] = 0
+        logits += mask
         return logits

     def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
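The rewritten processor applies grammar constraints with one batch-shaped additive mask and a single in-place addition, rather than materializing a per-row biased_scores copy and resetting the mask on every iteration. A self-contained toy version of that technique, with made-up shapes and allowed-token lists in place of the real FSM output:

import math
import torch

logits = torch.randn(2, 6)          # (batch, vocab) toy scores
allowed = [[0, 3], [1, 2, 5]]       # per-row allowed token ids (fake FSM output)

mask = torch.full_like(logits, -math.inf)
for i, tokens in enumerate(allowed):
    mask[i, tokens] = 0             # open only the allowed positions

logits += mask                      # one vectorized, in-place bias
probs = torch.softmax(logits, dim=-1)
print(probs)                        # disallowed tokens now have probability 0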