Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-11 12:24:53 +00:00

fix: remove unnecessary code, avoid copies and make deser safer

parent be7835475b
commit f0cdd9c8ea
@@ -146,7 +146,6 @@ class ResponseComparator(JSONSnapshotExtension):
                 response.details, other.details
             )
 
-        # print(serialized_data)
         serialized_data = convert_data(serialized_data)
         snapshot_data = convert_data(snapshot_data)
 
@@ -44,15 +44,13 @@ impl HubTokenizerConfig {
         serde_json::from_str(&content).unwrap_or_default()
     }
 }
 
 mod json_object_or_string_to_string {
-    // This custom deserializer is used to handle the fact that the grammar field can be either a
-    // string or an object. In both cases we handle it as a string, but also provide this convience
-    // to the user to be flexible with the input.
-    use super::*;
-    use serde::de;
-    use serde::Deserializer;
+    use serde::{Deserialize, Deserializer};
     use serde_json::Value;
+
+    // A custom deserializer that treats both strings and objects as strings.
+    // This provides flexibility with input formats for the 'grammar' field.
     pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
     where
         D: Deserializer<'de>,
@@ -61,8 +59,13 @@ mod json_object_or_string_to_string {
 
         match value {
             Value::String(s) => Ok(s),
-            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
-            _ => Err(de::Error::custom("expected string or object for grammar")),
+            // Safely handle serialization and return an error if it fails
+            Value::Object(o) => {
+                serde_json::to_string(&o).map_err(|e| serde::de::Error::custom(e.to_string()))
+            }
+            _ => Err(serde::de::Error::custom(
+                "expected string or object for grammar",
+            )),
         }
     }
 }
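
The Rust change above replaces a panicking `unwrap()` with `map_err`, so a grammar object that fails to serialize becomes a deserialization error instead of a crash. For illustration only, here is a minimal Python sketch of the same string-or-object normalization; `normalize_grammar` is a hypothetical name, not part of this codebase:

    import json

    # Hypothetical Python equivalent of the Rust deserializer: accept the
    # "grammar" field as either a JSON string or a JSON object, normalize
    # both to a string, and raise a clear error instead of crashing.
    def normalize_grammar(value):
        if isinstance(value, str):
            return value
        if isinstance(value, dict):
            try:
                return json.dumps(value)
            except (TypeError, ValueError) as exc:
                raise ValueError(f"failed to serialize grammar object: {exc}")
        raise ValueError("expected string or object for grammar")

    # Both input shapes normalize to the same string form:
    assert normalize_grammar('{"type": "json"}') == '{"type": "json"}'
    assert normalize_grammar({"type": "json"}) == '{"type": "json"}'
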
@@ -1050,13 +1050,11 @@ class FlashCausalLM(Model):
             prefill_logprobs = prefill_logprobs.view(-1).tolist()
 
         # GPU <-> CPU sync
-        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(
-            next_input_ids.tolist(),
-        )
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
         accepted_ids = accepted_ids.tolist()
         start_decode = time.time_ns()
+        batch.next_token_chooser = batch.next_token_chooser.advance_grammar(next_token_ids)
 
         # Zipped iterator
         iterator = zip(
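
The reorder above exists because `next_input_ids.tolist()` already runs a few lines later to build `next_token_ids`; calling it again inside `advance_grammar(...)` forced a second device-to-host transfer of the same tensor. A hedged sketch of the pattern, with a toy tensor rather than the actual TGI call site (assumes a CUDA device; on CPU the copies are cheaper but still duplicated):

    import torch

    # Illustrative stand-in for the tensor named in the diff.
    next_input_ids = torch.tensor([5, 9, 2], device="cuda")

    # Before: the same tensor crossed the device boundary twice.
    grammar_ids = next_input_ids.tolist()    # GPU -> CPU sync and copy #1
    next_token_ids = next_input_ids.tolist() # GPU -> CPU sync and copy #2

    # After: one transfer; advance_grammar reuses the materialized list.
    next_token_ids = next_input_ids.tolist() # single GPU -> CPU transfer
    grammar_ids = next_token_ids             # plain Python reference, no copy
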
@@ -575,16 +575,14 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
         fsm_grammar_states: List[int],
         mask: torch.Tensor,
     ):
+        mask = torch.full_like(logits, -math.inf)
         for i in range(logits.shape[0]):
             fsm = self.fsms[i]
             if fsm_grammar_states[i] == -1 or fsm is None:
                 continue
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-            mask[allowed_tokens] = 0
-            biased_scores = logits[i] + mask
-            mask.fill_(-math.inf)
-            logits[i] = biased_scores
-
+            mask[i, allowed_tokens] = 0
+        logits += mask
         return logits
 
     def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
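
The logits-processor change swaps a per-row scheme (bias one row into a temporary `biased_scores`, reset the shared mask with `mask.fill_(-math.inf)`, write the row back) for a single 2D mask whose rows are filled inside the loop and then applied with one in-place add, avoiding the per-iteration copy. A self-contained sketch of the new pattern with toy shapes, not the actual TGI call sites:

    import math

    import torch

    # Toy stand-ins: 2 sequences, vocab of 6; `allowed` plays the role of
    # fsm.allowed_token_ids(fsm_grammar_states[i]) for each row.
    batch_size, vocab_size = 2, 6
    logits = torch.zeros(batch_size, vocab_size)
    allowed = [[1, 3], [0, 5]]

    # One 2D mask for the whole batch: start everything at -inf, zero out
    # allowed positions row by row, then apply a single in-place add.
    mask = torch.full_like(logits, -math.inf)
    for i in range(batch_size):
        mask[i, allowed[i]] = 0
    logits += mask

    print(logits)  # disallowed positions are -inf, allowed ones unchanged

The in-place `logits += mask` also avoids allocating a biased copy per sequence, which is where the "avoid copies" part of the commit message comes from.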