mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 04:14:52 +00:00
fix: remove unnecessary code, avoid copies and make deser safer
This commit is contained in:
parent
be7835475b
commit
f0cdd9c8ea
@ -146,7 +146,6 @@ class ResponseComparator(JSONSnapshotExtension):
|
||||
response.details, other.details
|
||||
)
|
||||
|
||||
# print(serialized_data)
|
||||
serialized_data = convert_data(serialized_data)
|
||||
snapshot_data = convert_data(snapshot_data)
|
||||
|
||||
|
@ -44,15 +44,13 @@ impl HubTokenizerConfig {
|
||||
serde_json::from_str(&content).unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
mod json_object_or_string_to_string {
|
||||
// This custom deserializer is used to handle the fact that the grammar field can be either a
|
||||
// string or an object. In both cases we handle it as a string, but also provide this convience
|
||||
// to the user to be flexible with the input.
|
||||
use super::*;
|
||||
use serde::de;
|
||||
use serde::Deserializer;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use serde_json::Value;
|
||||
|
||||
// A custom deserializer that treats both strings and objects as strings.
|
||||
// This provides flexibility with input formats for the 'grammar' field.
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
@ -61,8 +59,13 @@ mod json_object_or_string_to_string {
|
||||
|
||||
match value {
|
||||
Value::String(s) => Ok(s),
|
||||
Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
|
||||
_ => Err(de::Error::custom("expected string or object for grammar")),
|
||||
// Safely handle serialization and return an error if it fails
|
||||
Value::Object(o) => {
|
||||
serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
|
||||
}
|
||||
_ => Err(serde::de::Error::custom(
|
||||
"expected string or object for grammar",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1050,13 +1050,11 @@ class FlashCausalLM(Model):
|
||||
prefill_logprobs = prefill_logprobs.view(-1).tolist()
|
||||
|
||||
# GPU <-> CPU sync
|
||||
batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
|
||||
next_input_ids.tolist(),
|
||||
)
|
||||
next_token_logprobs = next_token_logprobs.tolist()
|
||||
next_token_ids = next_input_ids.tolist()
|
||||
accepted_ids = accepted_ids.tolist()
|
||||
start_decode = time.time_ns()
|
||||
batch.next_token_chooser = batch.next_token_chooser.advance_grammar(next_token_ids)
|
||||
|
||||
# Zipped iterator
|
||||
iterator = zip(
|
||||
|
@ -575,16 +575,14 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
|
||||
fsm_grammar_states: List[int],
|
||||
mask: torch.Tensor,
|
||||
):
|
||||
mask = torch.full_like(logits, -math.inf)
|
||||
for i in range(logits.shape[0]):
|
||||
fsm = self.fsms[i]
|
||||
if fsm_grammar_states[i] == -1 or fsm is None:
|
||||
continue
|
||||
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
|
||||
mask[allowed_tokens] = 0
|
||||
biased_scores = logits[i] + mask
|
||||
mask.fill_(-math.inf)
|
||||
logits[i] = biased_scores
|
||||
|
||||
mask[i, allowed_tokens] = 0
|
||||
logits += mask
|
||||
return logits
|
||||
|
||||
def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
|
||||
|
Loading…
Reference in New Issue
Block a user