Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: include grammar type in request, avoid alloc and improve proto types

Commit d849641b28 (parent 8b9430fb68)
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -51,6 +51,12 @@ message ClearCacheRequest {
 /// Empty response
 message ClearCacheResponse {}
 
+enum GrammarType {
+    GRAMMAR_TYPE_NONE = 0;
+    GRAMMAR_TYPE_JSON = 1;
+    GRAMMAR_TYPE_REGEX = 2;
+}
+
 message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;
@@ -72,6 +78,8 @@ message NextTokenChooserParameters {
     bool watermark = 8;
     /// grammar (applied if not empty)
     string grammar = 10;
+    /// grammar type
+    GrammarType grammar_type = 11;
 }
 
 message StoppingCriteriaParameters {
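For orientation, here is a minimal sketch of how the new field is populated from Python, assuming the bindings generated from this proto (the same `generate_pb2` module the server imports further down); the field values are illustrative.

```python
# Sketch: building the proto message with the new grammar fields, assuming
# the generated bindings (text_generation_server.pb.generate_pb2).
from text_generation_server.pb import generate_pb2
from text_generation_server.pb.generate_pb2 import GrammarType

params = generate_pb2.NextTokenChooserParameters(
    temperature=1.0,
    watermark=False,
    grammar=r"\d{4}-\d{2}-\d{2}",  # the grammar payload, a regex in this case
    grammar_type=GrammarType.GRAMMAR_TYPE_REGEX,
)
# GRAMMAR_TYPE_NONE is 0, the proto3 default, so existing callers that never
# set grammar_type keep their old behavior.
```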
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -129,6 +129,7 @@ impl Client {
                 frequency_penalty: 0.1,
                 watermark: true,
                 grammar: String::new(),
+                grammar_type: GrammarType::None as i32,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: max_total_tokens - truncate,
--- a/router/client/src/lib.rs
+++ b/router/client/src/lib.rs
@@ -9,8 +9,8 @@ pub use client::Client;
 pub use pb::generate::v2::HealthResponse;
 pub use pb::generate::v2::InfoResponse as ShardInfo;
 pub use pb::generate::v2::{
-    Batch, CachedBatch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters,
-    Request, StoppingCriteriaParameters, Tokens,
+    Batch, CachedBatch, FinishReason, GeneratedText, Generation, GrammarType,
+    NextTokenChooserParameters, Request, StoppingCriteriaParameters, Tokens,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -1,5 +1,6 @@
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
+use text_generation_client::GrammarType as ProtoGrammarType;
 use text_generation_client::{
     Batch, NextTokenChooserParameters, Request, ShardedClient, StoppingCriteriaParameters,
 };
@@ -46,6 +47,7 @@ impl Health {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                grammar_type: ProtoGrammarType::None as i32,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
                 max_new_tokens: 1,
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -44,6 +44,40 @@ impl HubTokenizerConfig {
         serde_json::from_str(&content).unwrap_or_default()
     }
 }
 
+mod json_object_or_string_to_string {
+    // This custom deserializer is used to handle the fact that the grammar field can be either a
+    // string or an object. In both cases we handle it as a string, but also provide this
+    // convenience to the user to be flexible with the input.
+    use super::*;
+    use serde::de;
+    use serde::Deserializer;
+    use serde_json::Value;
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+
+        match value {
+            Value::String(s) => Ok(s),
+            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
+            _ => Err(de::Error::custom("expected string or object for grammar")),
+        }
+    }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+#[serde(tag = "type", content = "value")]
+pub(crate) enum GrammarType {
+    #[serde(
+        rename = "json",
+        deserialize_with = "json_object_or_string_to_string::deserialize"
+    )]
+    Json(String),
+    #[serde(rename = "regex")]
+    Regex(String),
+}
+
 mod token_serde {
     use super::*;
@@ -201,31 +235,8 @@ pub(crate) struct GenerateParameters {
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 5)]
     pub top_n_tokens: Option<u32>,
-    #[serde(
-        default,
-        deserialize_with = "json_object_or_string_to_string::deserialize"
-    )]
-    pub grammar: String,
-}
-
-mod json_object_or_string_to_string {
-    use super::*;
-    use serde::de;
-    use serde::Deserializer;
-    use serde_json::Value;
-
-    pub fn deserialize<'de, D>(deserializer: D) -> Result<String, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let value = Value::deserialize(deserializer)?;
-
-        match value {
-            Value::String(s) => Ok(s),
-            Value::Object(o) => Ok(serde_json::to_string(&o).unwrap()),
-            _ => Err(de::Error::custom("expected string or object for grammar")),
-        }
-    }
+    #[serde(default)]
+    pub grammar: Option<GrammarType>,
 }
 
 fn default_max_new_tokens() -> Option<u32> {
@@ -251,7 +262,7 @@ fn default_parameters() -> GenerateParameters {
         decoder_input_details: false,
         seed: None,
         top_n_tokens: None,
-        grammar: String::new(),
+        grammar: None,
     }
 }
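With `#[serde(tag = "type", content = "value")]`, the request-side grammar becomes a tagged object rather than a bare string. A sketch of payloads the router accepts after this change (the schema contents are made up; `inputs`/`parameters` follow the existing generate API shape):

```python
# Sketch of /generate request payloads. The "value" of a json grammar may be
# a JSON object or a string thanks to the custom deserializer above; both
# normalize to a string internally.
import json

json_grammar = {
    "type": "json",
    "value": {"properties": {"name": {"type": "string"}}, "required": ["name"]},
}
regex_grammar = {"type": "regex", "value": r"\d{4}-\d{2}-\d{2}"}

payload = {
    "inputs": "Give me a date:",
    "parameters": {"max_new_tokens": 20, "grammar": regex_grammar},
}
print(json.dumps(payload, indent=2))
```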
--- a/router/src/queue.rs
+++ b/router/src/queue.rs
@@ -343,7 +343,9 @@ enum QueueCommand {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+    use text_generation_client::{
+        GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+    };
     use tracing::info_span;
 
     fn default_entry() -> (
@@ -369,6 +371,7 @@ mod tests {
                 frequency_penalty: 0.0,
                 watermark: false,
                 grammar: String::new(),
+                grammar_type: ProtoGrammarType::None as i32,
             },
             stopping_parameters: StoppingCriteriaParameters {
                 ignore_eos_token: false,
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -614,7 +614,7 @@ async fn chat_completions(
             decoder_input_details: !stream,
             seed,
             top_n_tokens: None,
-            grammar: String::new(),
+            grammar: None,
         },
     };
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -1,8 +1,10 @@
 /// Payload validation logic
 use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
-use crate::{GenerateParameters, GenerateRequest};
+use crate::{GenerateParameters, GenerateRequest, GrammarType};
 use rand::{thread_rng, Rng};
-use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
+use text_generation_client::{
+    GrammarType as ProtoGrammarType, NextTokenChooserParameters, StoppingCriteriaParameters,
+};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::TruncationDirection;
@@ -296,10 +298,27 @@ impl Validation {
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
-        // Ensure that grammar is not set if it's not supported
-        if !grammar.is_empty() && !self.grammar_support {
-            return Err(ValidationError::Grammar);
-        }
+        // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
+        // NOTE: this is currently difficult because we need the tokenizer in Python to build
+        // the FSM and we'd have to load a copy of the tokenizer into our Pyo3 instance which
+        // may be slow and memory intensive. Best case is to have a Rust implementation of the FSM
+        // compiler and use that to build the FSM here.
+
+        // Validate grammar and unpack the grammar and type for the proto message
+        let (grammar, grammar_type) = match grammar {
+            Some(grammar) => {
+                // Ensure that grammar is not set if it's not supported
+                if !self.grammar_support {
+                    return Err(ValidationError::Grammar);
+                }
+                match grammar {
+                    // currently both are handled the same way since compilation is done in Python
+                    GrammarType::Json(json) => (json, ProtoGrammarType::Json.into()),
+                    GrammarType::Regex(regex) => (regex, ProtoGrammarType::Regex.into()),
+                }
+            }
+            None => (String::new(), ProtoGrammarType::None.into()),
+        };
 
         let parameters = NextTokenChooserParameters {
             temperature,
@@ -312,6 +331,7 @@ impl Validation {
             seed,
             watermark,
             grammar,
+            grammar_type,
         };
         let stopping_parameters = StoppingCriteriaParameters {
             max_new_tokens,
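The match above flattens the optional grammar into the two proto fields. The same mapping, restated as a small hypothetical Python helper for clarity (the names are stand-ins, not part of the codebase):

```python
# Hypothetical restatement of the router's grammar unpacking: an absent
# grammar becomes ("", NONE); otherwise the payload string plus its type tag.
GRAMMAR_TYPE_NONE, GRAMMAR_TYPE_JSON, GRAMMAR_TYPE_REGEX = 0, 1, 2

def to_proto_fields(grammar):
    """grammar: None, or a dict like {"type": "json" | "regex", "value": str}."""
    if grammar is None:
        return "", GRAMMAR_TYPE_NONE
    kind = GRAMMAR_TYPE_JSON if grammar["type"] == "json" else GRAMMAR_TYPE_REGEX
    return grammar["value"], kind

assert to_proto_fields(None) == ("", GRAMMAR_TYPE_NONE)
assert to_proto_fields({"type": "regex", "value": r"\d+"}) == (r"\d+", GRAMMAR_TYPE_REGEX)
```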
--- a/server/text_generation_server/utils/logits_process.py
+++ b/server/text_generation_server/utils/logits_process.py
@@ -5,6 +5,7 @@ import json
 from loguru import logger
 from functools import lru_cache
 from typing import Optional, List, Dict, Union
+from text_generation_server.pb.generate_pb2 import GrammarType
 
 from outlines.fsm.fsm import RegexFSM
 from outlines.fsm.json_schema import build_regex_from_object
@@ -476,10 +477,12 @@ class GrammarLogitProcessor(LogitsProcessor):
     fsm_state: DefaultDict[int, int]
     fsm: RegexFSM
 
-    def __init__(self, tokenizer, device, grammar):
+    def __init__(self, tokenizer, device, grammar, grammar_type):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
-        self.fsm = GrammarLogitProcessor._cached_compile_fsm(grammar, self.tokenizer)
+        self.fsm = GrammarLogitProcessor._cached_compile_fsm(
+            grammar_type, grammar, self.tokenizer
+        )
 
     def __call__(
         self,
@@ -508,17 +511,12 @@ class GrammarLogitProcessor(LogitsProcessor):
     # TODO: move grammar compilation into the router
     @staticmethod
     @lru_cache(maxsize=32, typed=True)
-    def _cached_compile_fsm(schema, tokenizer):
+    def _cached_compile_fsm(grammar_type, schema, tokenizer):
         start_time = time.time()
-        # Detect if schema is a json object before converting it to regex.
-        # We need to check if it's a valid json object before converting it to regex
-        # and cannot simply test if it starts with '{' and ends with '}' because there
-        # are valid regexes that start and end with curly braces.
-        try:
-            json.loads(schema)  # check if schema is a valid json
-            schema = build_regex_from_object(schema)  # convert schema to regex
-        except json.JSONDecodeError:
-            pass
+        if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
+            schema = build_regex_from_object(schema)
+        elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
+            pass  # schema is already a regex just here for clarity
         fsm = RegexFSM(schema, tokenizer)
         logger.debug(f"Compiled FSM in {time.time() - start_time:.2f}s")
         return fsm
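Passing the type into `_cached_compile_fsm` does two things: it removes the old `json.loads` sniffing, and it makes the type part of the `lru_cache` key, so an identical string declared once as JSON and once as regex compiles to two distinct cached FSMs. A minimal runnable sketch of that cache behavior, with stand-ins for the outlines calls:

```python
# Why grammar_type belongs in the cache key: the same schema text compiles
# differently depending on its declared type. Stand-ins replace
# build_regex_from_object / RegexFSM so the sketch runs on its own.
from functools import lru_cache

GRAMMAR_TYPE_JSON, GRAMMAR_TYPE_REGEX = 1, 2

@lru_cache(maxsize=32, typed=True)
def compile_fsm(grammar_type, schema):
    if grammar_type == GRAMMAR_TYPE_JSON:
        schema = f"regex-from-schema({schema})"  # stand-in for build_regex_from_object
    return f"FSM({schema})"                      # stand-in for RegexFSM(schema, tokenizer)

# Identical text, different declared types -> two distinct cache entries:
assert compile_fsm(GRAMMAR_TYPE_JSON, "{}") != compile_fsm(GRAMMAR_TYPE_REGEX, "{}")
```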
@@ -561,32 +559,32 @@ class GrammarLogitProcessor(LogitsProcessor):
 
 
 class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
-    def __init__(self, tokenizer, device, grammars):
+    def __init__(self, tokenizer, device, grammars, grammar_type):
         self.device = device
         self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
-        self.fsms = [
-            (
-                GrammarLogitProcessor._cached_compile_fsm(g, self.tokenizer)
-                if g
-                else None
-            )
-            for g in grammars
-        ]
+        self.fsms = []
+        for i in range(len(grammars)):
+            fsm = GrammarLogitProcessor._cached_compile_fsm(
+                grammar_type[i], grammars[i], self.tokenizer
+            )
+            self.fsms.append(fsm)
 
     def __call__(
         self,
         logits: torch.Tensor,
         fsm_grammar_states: List[int],
+        mask: torch.Tensor,
     ):
         for i in range(logits.shape[0]):
             fsm = self.fsms[i]
             if fsm_grammar_states[i] == -1 or fsm is None:
                 continue
             allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
-            mask = torch.full((logits.shape[-1],), -math.inf, device=self.device)
             mask[allowed_tokens] = 0
             biased_scores = logits[i] + mask
+            mask.fill_(-math.inf)
             logits[i] = biased_scores
 
         return logits
 
     def advance_batch(self, next_token_ids, fsm_grammar_states, grammars):
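The `__call__` change is the "avoid alloc" part of the commit title: instead of building a fresh `-inf` tensor for every batch row, the caller passes in one mask that is opened for the row's allowed tokens and refilled in place afterwards. A self-contained sketch of the pattern (shapes and allowed ids are made up):

```python
# Reuse one mask tensor across rows: set allowed ids to 0, bias the row,
# then reset the mask to -inf in place for the next row.
import math
import torch

logits = torch.randn(4, 8)               # 4 sequences, vocab of 8
allowed = [[0, 1], [2], [3, 4, 5], [7]]  # allowed token ids per sequence

mask = torch.full((logits.shape[-1],), -math.inf)  # allocated once
for i in range(logits.shape[0]):
    mask[allowed[i]] = 0          # open the allowed token ids
    logits[i] = logits[i] + mask  # everything else goes to -inf
    mask.fill_(-math.inf)         # in-place reset, no new allocation
```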
--- a/server/text_generation_server/utils/tokens.py
+++ b/server/text_generation_server/utils/tokens.py
@@ -1,9 +1,10 @@
 import re
 from typing import List, Optional, Tuple
 
+import math
 import torch
 from text_generation_server.pb import generate_pb2
-from text_generation_server.pb.generate_pb2 import FinishReason
+from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
 from text_generation_server.utils.logits_process import (
     FrequencyPenaltyLogitsProcessor,
     GrammarLogitProcessor,
@@ -36,6 +37,7 @@ class NextTokenChooser:
         device: str = "cpu",
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
         grammar: str = "",
+        grammar_type: GrammarType = GrammarType.GRAMMAR_TYPE_NONE,
         fsm_grammar_state: int = 0,
     ):
         self.watermark_processor = (
@@ -52,7 +54,9 @@ class NextTokenChooser:
             else None
         )
         self.grammar_processor = (
-            GrammarLogitProcessor(tokenizer, device, grammar) if grammar != "" else None
+            GrammarLogitProcessor(tokenizer, device, grammar, grammar_type)
+            if grammar != ""
+            else None
         )
         self.tokenizer = tokenizer
 
@@ -121,6 +125,7 @@ class NextTokenChooser:
             device=device,
             tokenizer=tokenizer,
             grammar=pb.grammar,
+            grammar_type=pb.grammar_type,
         )
 
@@ -227,6 +232,7 @@ class HeterogeneousNextTokenChooser:
         seeds: List[int],
         tokenizer: PreTrainedTokenizerBase,
         grammars: List[str],
+        grammar_types: List[GrammarType],
         fsm_grammar_states=List[int],
     ):
         warpers = []
@@ -260,7 +266,9 @@ class HeterogeneousNextTokenChooser:
         )
 
         self.grammar_processor = (
-            HeterogeneousGrammarLogitProcessor(tokenizer, device, grammars)
+            HeterogeneousGrammarLogitProcessor(
+                tokenizer, device, grammars, grammar_types
+            )
             if any([grammar != "" for grammar in grammars])
             else None
         )
@@ -319,6 +327,8 @@ class HeterogeneousNextTokenChooser:
         scores = scores.view(B, S, -1)
 
         next_ids = torch.zeros((B, S), device=scores.device, dtype=torch.long)
+        mask = torch.full((scores.shape[-1],), -math.inf, device=self.device)
+
         for j in range(S):
             _scores = scores[:, j]
             if self.watermark_processor is not None:
@@ -330,7 +340,7 @@ class HeterogeneousNextTokenChooser:
             for warper in self.warpers:
                 _scores = warper(input_ids, _scores)
             if self.grammar_processor is not None:
-                _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
+                _scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask)
             _next_ids = self.choice(_scores)
             scores[:, j] = _scores
             next_ids[:, j] = _next_ids
@@ -421,12 +431,15 @@ class HeterogeneousNextTokenChooser:
 
         new_grammars = []
         new_fsm_grammar_states = []
+        new_grammar_types = []
         for i in indices:
             new_grammars.append(self.grammars[i])
             new_fsm_grammar_states.append(self.fsm_grammar_states[i])
+            new_grammar_types.append(self.grammar_types[i])
 
         self.grammars = new_grammars
         self.fsm_grammar_states = new_fsm_grammar_states
+        self.grammar_types = new_grammar_types
 
         if any(self.do_sample):
             self.choice.filter(indices)
@@ -457,6 +470,7 @@ class HeterogeneousNextTokenChooser:
             dtype=dtype,
             tokenizer=tokenizer,
             grammars=[pb_.grammar for pb_ in pb],
+            grammar_types=[pb_.grammar_type for pb_ in pb],
             fsm_grammar_states=[0] * len(pb),
         )
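`filter` keeps the per-request lists index-aligned: `grammars`, `fsm_grammar_states`, and now `grammar_types` must all be filtered by the same indices so position `i` still describes the same request. A stand-alone illustration of that invariant:

```python
# Filtering parallel per-request lists together, as filter() does above.
indices = [0, 2]                      # requests that stay in the batch
grammars = ["", '{"type": "object"}', r"\d+"]
grammar_types = [0, 1, 2]             # NONE, JSON, REGEX
fsm_grammar_states = [0, 5, 3]

grammars = [grammars[i] for i in indices]
grammar_types = [grammar_types[i] for i in indices]
fsm_grammar_states = [fsm_grammar_states[i] for i in indices]

assert grammar_types == [0, 2]        # row i still matches its request
```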