Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-25 03:52:08 +00:00)

feat: remove uncompiled grammar and improve logit processor logic

parent c52a0f679e
commit 1f7be736d2

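In short, this commit drops the raw `grammar` string and the `GrammarType` enum from the gRPC protocol and the token-chooser APIs: the router validates and compiles the grammar once, and only the precomputed FSM transition table (`states_to_token_maps`) travels to the shards, so the Python logit processors no longer recompile anything. The sketch below is illustrative rather than part of the commit; it shows the data shape involved, with the table carried as three parallel arrays and folded back into the nested `{state: {token_id: next_state}}` dict that outlines' RegexFSM expects, mirroring the logit-processor code further down.

# Illustrative sketch (not part of the diff): fold the proto's parallel
# arrays back into the nested dict form used by outlines' RegexFSM.
start_states = [0, 0, 1]  # FSM state each transition leaves from
tokens = [5, 7, 9]        # token id consumed by the transition
end_states = [1, 2, 2]    # FSM state each transition enters

states_to_token_maps = {}
for start, token, end in zip(start_states, tokens, end_states):
    states_to_token_maps.setdefault(start, {})[token] = end

assert states_to_token_maps == {0: {5: 1, 7: 2}, 1: {9: 2}}
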
@@ -250,5 +250,7 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base
 
+ENV LD_LIBRARY_PATH=/opt/conda/lib/:$LD_LIBRARY_PATH
+
 ENTRYPOINT ["text-generation-launcher"]
 CMD ["--json-output"]

@@ -8,7 +8,7 @@ use crate::app::App;
 use crate::event::Event;
 use crossterm::ExecutableCommand;
 use std::io;
-use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
+use text_generation_client::{NextTokenChooserParameters, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::sync::{broadcast, mpsc};
 use tui::backend::CrosstermBackend;

@@ -45,8 +45,6 @@ pub async fn run(
        repetition_penalty: repetition_penalty.unwrap_or(1.0),
        frequency_penalty: frequency_penalty.unwrap_or(0.0),
        watermark,
-       grammar: String::new(),
-       grammar_type: GrammarType::None as i32,
        states_to_token_maps: None,
    };
 

@@ -51,12 +51,6 @@ message ClearCacheRequest {
 /// Empty response
 message ClearCacheResponse {}
 
-enum GrammarType {
-    GRAMMAR_TYPE_NONE = 0;
-    GRAMMAR_TYPE_JSON = 1;
-    GRAMMAR_TYPE_REGEX = 2;
-}
-
 message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;

@@ -76,10 +70,6 @@ message NextTokenChooserParameters {
     float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
-    /// grammar (applied if not empty)
-    string grammar = 10;
-    /// grammar type
-    GrammarType grammar_type = 11;
     /// states to token maps
     StatesToTokenMaps states_to_token_maps = 12;
 }

@@ -135,8 +135,6 @@ impl Client {
                repetition_penalty: 1.2,
                frequency_penalty: 0.1,
                watermark: true,
-               grammar: String::new(),
-               grammar_type: GrammarType::None as i32,
                states_to_token_maps: None,
            }),
            stopping_parameters: Some(StoppingCriteriaParameters {

@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
-    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
-    NextTokenChooserParameters, Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
+    Request, StatesToTokenMaps, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

@@ -1,6 +1,5 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
-use text_generation_client::GrammarType as ProtoGrammarType;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };

@@ -46,8 +45,6 @@ impl Health {
                repetition_penalty: 1.0,
                frequency_penalty: 0.0,
                watermark: false,
-               grammar: String::new(),
-               grammar_type: ProtoGrammarType::None as i32,
                states_to_token_maps: None,
            }),
            stopping_parameters: Some(StoppingCriteriaParameters {

@@ -343,9 +343,7 @@ enum QueueCommand {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use text_generation_client::{
-        GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
-    };
+    use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
     use tracing::info_span;
 
     fn default_entry() -> (

@@ -370,8 +368,6 @@ mod tests {
                repetition_penalty: 0.0,
                frequency_penalty: 0.0,
                watermark: false,
-               grammar: String::new(),
-               grammar_type: ProtoGrammarType::None as i32,
                states_to_token_maps: None,
            },
            stopping_parameters: StoppingCriteriaParameters {

@@ -7,8 +7,7 @@ use jsonschema::{Draft, JSONSchema};
 use rand::{thread_rng, Rng};
 use serde_json::Value;
 use text_generation_client::{
-    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StatesToTokenMaps,
-    StoppingCriteriaParameters,
+    NextTokenChooserParameters, StatesToTokenMaps, StoppingCriteriaParameters,
 };
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;

@@ -368,7 +367,7 @@ impl Validation {
        // compiler and use that to build the FSM here.
 
        // Validate grammar and unpack the grammar and type for the proto message
-       let (grammar, grammar_type, states_to_token_maps) = match grammar {
+       let states_to_token_maps = match grammar {
            Some(grammar) => {
                // Ensure that grammar is not set if it's not supported
                if self.disable_grammar_support {

@@ -392,7 +391,7 @@ impl Validation {
                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;
 
                    // NOTE: this is the first step to compile the grammar
-                   let (regex_compiled_grammar, _states_to_token_maps) = self
+                   let (_regex_compiled_grammar, _states_to_token_maps) = self
                        .compile_grammar(serde_json::to_string(&json).unwrap())
                        .await
                        .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

@@ -416,16 +415,12 @@ impl Validation {
                        end_states,
                    };
 
-                   (
-                       regex_compiled_grammar,
-                       ProtoGrammarType::Regex.into(),
-                       Some(stm),
-                   )
+                   Some(stm)
                }
-               GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into(), None),
+               GrammarType::Regex(_regex) => None,
            }
        }
-       None => (String::new(), ProtoGrammarType::None.into(), None),
+       None => None,
    };
 
    let parameters = NextTokenChooserParameters {

@@ -438,8 +433,6 @@ impl Validation {
        do_sample,
        seed,
        watermark,
-       grammar,
-       grammar_type,
        states_to_token_maps,
    };
    let stopping_parameters = StoppingCriteriaParameters {

@@ -567,7 +560,6 @@ fn compile_grammar(
        r#"
 from outlines.fsm.fsm import RegexFSM
 from outlines.fsm.json_schema import build_regex_from_schema
-import time
 from transformers.file_utils import SPIECE_UNDERLINE
 
 class Tokenizer:

@@ -589,13 +581,10 @@ class Tokenizer:
        return " ".join(tokens)
 
 def adapt_tokenizer(vocab, special_tokens):
-    start_time = time.time()
     tokenizer = Tokenizer(vocab, special_tokens)
 
     def convert_token_to_string(token: str) -> str:
-
         string = tokenizer.convert_tokens_to_string([token])
-
         # A hack to handle missing spaces to HF's Llama tokenizers
         if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
             return " " + string

@@ -603,26 +592,16 @@ def adapt_tokenizer(vocab, special_tokens):
         return string
 
     tokenizer.convert_token_to_string = convert_token_to_string
-    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
     return tokenizer
 
 def compile_regex_grammar(inputs, vocab, special_tokens):
-    start_time = time.time()
-    print("🔥 starting compile_regex_grammar", inputs)
     schema = build_regex_from_schema(inputs)
-    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
     tokenizer = adapt_tokenizer(vocab, special_tokens)
-    print(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
     fsm = RegexFSM(schema, tokenizer)
-    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
     return fsm
 
 def convert_grammar_to_regex(inputs):
-    start_time = time.time()
-    print("🔥 starting convert_grammar_to_regex", inputs)
-    schema = build_regex_from_schema(inputs)
-    print(f"Compiled grammar in {time.time() - start_time:.2f}s")
-    return schema
+    return build_regex_from_schema(inputs)
 "#,
 "",
 "",

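After the cleanup, the embedded Python helpers keep only the compile path itself: JSON schema to regex, regex to token-level FSM. Below is a minimal sketch of that pipeline, assuming the same outlines API the helpers import (`build_regex_from_schema`, `RegexFSM`); `my_schema` and the commented tokenizer lines are illustrative, not part of the commit.

# Illustrative sketch (not part of the diff): the surviving compile path,
# using the same outlines imports as the embedded helpers above.
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema

my_schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
regex = build_regex_from_schema(my_schema)  # JSON schema -> regex string
# my_tokenizer = adapt_tokenizer(vocab, special_tokens)  # as defined above
# fsm = RegexFSM(regex, my_tokenizer)  # regex -> FSM; fsm.states_to_token_maps
#                                      # is the table serialized to the shards
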
@@ -645,19 +624,7 @@ def convert_grammar_to_regex(inputs):
            .getattr("states_to_token_maps")?
            .extract::<StateTokenMaps>()?;
 
-       println!("🔥 elapsed: {:?}", start_time.elapsed());
-
-       // size of serialized states_to_token_maps
-       let states_to_token_maps_json = serde_json::to_string(&states_to_token_maps).unwrap();
-       println!(
-           "🔥 states_to_token_maps size: {:.2}MB",
-           states_to_token_maps_json.len() as f64 / 1024.0 / 1024.0
-       );
-
        let result = regex_fsm.into_ref(py).extract().unwrap();
-
-       println!("result: {:?}", result);
-
        Ok((result, states_to_token_maps))
    })
    .map_err(|e| ValidationError::InvalidGrammar(e.to_string()))?;

@@ -475,22 +475,26 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device, grammar, grammar_type, states_to_token_maps):
+    def __init__(self, tokenizer, device, states_to_token_maps):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
 
-        # TODO: use the precompiled grammar here
-        self.states_to_token_maps = states_to_token_maps
-        precompiled_grammar = RegexFSM.precompiled(
-            states_to_token_maps=states_to_token_maps,
-            empty_token_ids=None,
-            vocabulary=None,
-            eos_token_id=None,
-        )
+        start_states = states_to_token_maps.start_states
+        tokens = states_to_token_maps.tokens
+        end_states = states_to_token_maps.end_states
 
-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
-            grammar_type, grammar, self.tokenizer
-        )
+        _states_to_token_maps = {}
+        for i in range(len(start_states)):
+            if start_states[i] in _states_to_token_maps:
+                _states_to_token_maps[start_states[i]][tokens[i]] = end_states[i]
+            else:
+                _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
+
+        fsm = object.__new__(RegexFSM)
+        fsm.states_to_token_maps = _states_to_token_maps
+        fsm.empty_token_ids = None
+        fsm.vocabulary = list(tokenizer.get_vocab().values())
+        fsm.eos_token_id = tokenizer.eos_token_id
+        self.fsm = fsm
 
     def __call__(
         self,

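Both logit processors now rebuild the FSM by allocating a `RegexFSM` with `object.__new__` and attaching the precomputed attributes by hand, which skips `RegexFSM.__init__` and therefore skips regex compilation entirely. A generic sketch of that idiom follows; the `Expensive` class is hypothetical and stands in for the costly constructor being bypassed.

# Illustrative sketch (not part of the diff): object.__new__ allocates an
# instance without running __init__, so expensive setup is bypassed and the
# precomputed state is attached directly.
class Expensive:
    def __init__(self):
        raise RuntimeError("would recompile the grammar")

obj = object.__new__(Expensive)  # no __init__ call, no exception raised
obj.table = {0: {5: 1}}          # install precomputed state by hand
assert obj.table[0][5] == 1
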
@@ -560,17 +564,14 @@ class GrammarLogitProcessor(LogitsProcessor):
 
 
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(
-        self, tokenizer, device, grammars, grammar_types, states_to_token_maps
-    ):
+    def __init__(self, tokenizer, device, states_to_token_maps):
         self.device = device
-        self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
         self.fsms = []
 
-        for grammar, grammar_type in zip(grammars, grammar_types):
-            start_states = states_to_token_maps[0].start_states
-            tokens = states_to_token_maps[0].tokens
-            end_states = states_to_token_maps[0].end_states
+        for states_to_token_map in states_to_token_maps:
+            start_states = states_to_token_map.start_states
+            tokens = states_to_token_map.tokens
+            end_states = states_to_token_map.end_states
 
             _states_to_token_maps = {}
             for i in range(len(start_states)):

@@ -579,18 +580,11 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
                 else:
                     _states_to_token_maps[start_states[i]] = {tokens[i]: end_states[i]}
 
-            # TODO: cleanup how precompiled grammars are handled
-            precompiled_grammar = RegexFSM.precompiled(
-                states_to_token_maps=_states_to_token_maps,
-                empty_token_ids=None,
-                vocabulary=list(tokenizer.get_vocab().values()),
-                eos_token_id=self.tokenizer.eos_token_id,
-            )
-            # fsm = GrammarLogitProcessor._cached_compile_fsm(
-            #     grammar_type, grammar, self.tokenizer
-            # )
-
-            fsm = precompiled_grammar
+            fsm = object.__new__(RegexFSM)
+            fsm.states_to_token_maps = _states_to_token_maps
+            fsm.empty_token_ids = None
+            fsm.vocabulary = list(tokenizer.get_vocab().values())
+            fsm.eos_token_id = tokenizer.eos_token_id
             self.fsms.append(fsm)
 
     def __call__(

@@ -36,8 +36,6 @@ class NextTokenChooser:
         seed: int = 0,
         device: str = "cpu",
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
-        grammar: str = "",
-        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
         states_to_token_maps: List[List[int]] = None,
     ):

@@ -55,10 +53,8 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(
-                tokenizer, device, grammar, grammar_type, states_to_token_maps
-            )
-            if grammar != ""
+            GrammarLogitProcessor(tokenizer, device, states_to_token_maps)
+            if states_to_token_maps
             else None
         )
         self.tokenizer = tokenizer

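Construction of the grammar processor is now keyed off the presence of the transition table rather than a non-empty grammar string: a missing or empty `states_to_token_maps` is falsy, so no processor is built. A tiny sketch of the guard, where `make_processor` is a hypothetical stand-in for the conditional above:

# Illustrative sketch (not part of the diff): build a processor only when a
# transition table is present; None or an empty map is falsy.
def make_processor(states_to_token_maps):
    return ("grammar-processor", states_to_token_maps) if states_to_token_maps else None

assert make_processor(None) is None
assert make_processor({0: {5: 1}}) is not None
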
@@ -80,7 +76,6 @@ class NextTokenChooser:
 
         self.choice = Sampling(seed, device) if sampling else Greedy()
         self.fsm_grammar_state = fsm_grammar_state
-        self.grammar = grammar
         self.states_to_token_maps = states_to_token_maps
 
     def __call__(self, input_ids, scores):

@@ -128,8 +123,6 @@ class NextTokenChooser:
             seed=pb.seed,
             device=device,
             tokenizer=tokenizer,
-            grammar=pb.grammar,
-            grammar_type=pb.grammar_type,
             states_to_token_maps=pb.states_to_token_maps,
         )
 

@@ -236,8 +229,6 @@ class HeterogeneousNextTokenChooser:
         do_sample: List[bool],
         seeds: List[int],
         tokenizer: PreTrainedTokenizerBase,
-        grammars: List[str],
-        grammar_types: List[int],
         fsm_grammar_states: List[int],
         states_to_token_maps: List[List[List[int]]],
     ):

@@ -272,10 +263,8 @@ class HeterogeneousNextTokenChooser:
         )
 
         self.grammar_processor = (
-            HeterogeneousGrammarLogitProcessor(
-                tokenizer, device, grammars, grammar_types, states_to_token_maps
-            )
-            if any([grammar != "" for grammar in grammars])
+            HeterogeneousGrammarLogitProcessor(tokenizer, device, states_to_token_maps)
+            if any(states_to_token_maps)
             else None
         )
 

@@ -312,8 +301,6 @@ class HeterogeneousNextTokenChooser:
         self.device = device
         self.tokenizer = tokenizer
         self.fsm_grammar_states = fsm_grammar_states
-        self.grammars = grammars
-        self.grammar_types = grammar_types
         self.states_to_token_maps = states_to_token_maps
 
     def __call__(

@@ -447,17 +434,11 @@ class HeterogeneousNextTokenChooser:
         self.seeds = [self.seeds[i] for i in indices]
         self.do_sample = [self.do_sample[i] for i in indices]
 
-        new_grammars = []
         new_fsm_grammar_states = []
-        new_grammar_types = []
         for i in indices:
-            new_grammars.append(self.grammars[i])
             new_fsm_grammar_states.append(self.fsm_grammar_states[i])
-            new_grammar_types.append(self.grammar_types[i])
 
-        self.grammars = new_grammars
         self.fsm_grammar_states = new_fsm_grammar_states
-        self.grammar_types = new_grammar_types
 
         if any(self.do_sample):
             self.choice.filter(indices)

@@ -488,8 +469,6 @@ class HeterogeneousNextTokenChooser:
             device=device,
             dtype=dtype,
             tokenizer=tokenizer,
-            grammars=[pb_.grammar for pb_ in pb],
-            grammar_types=[pb_.grammar_type for pb_ in pb],
             fsm_grammar_states=(
                 fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
             ),