Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 23:42:06 +00:00)
Change add_special_tokens so that chat input, which already contains the templated special tokens, does not get them added again, while plain input still does (this matters a lot now with the prefixing).
parent f1c0735453
commit 7f1816a4e1
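In short: chat-style endpoints hand the tokenizer an input that has already been run through the chat template, so the template's own special tokens (such as a leading BOS) are already in the string; encoding it again with add_special_tokens = true would duplicate them, which now breaks prefix matching. A stand-alone sketch of the effect, assuming a llama-style tokenizer.json on disk; the file path and templated string are illustrative and not part of this diff:

// Stand-alone sketch, not part of the diff. Assumes a local `tokenizer.json`
// for a llama-style model; the templated string below is illustrative only.
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Raw completion input: let the tokenizer add BOS/EOS itself.
    let raw = tokenizer.encode("Hello", true)?;

    // Chat input after templating already carries its special tokens.
    let templated = "<s>[INST] Hello [/INST]";
    // add_special_tokens = true would prepend a second BOS here ...
    let doubled = tokenizer.encode(templated, true)?;
    // ... so templated inputs are now encoded with add_special_tokens = false.
    let deduped = tokenizer.encode(templated, false)?;

    println!("raw:     {:?}", raw.get_ids());
    println!("doubled: {:?}", doubled.get_ids());
    println!("deduped: {:?}", deduped.get_ids());
    Ok(())
}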
@@ -120,10 +120,11 @@ impl Infer {
     ) -> Result<Option<tokenizers::Encoding>, InferError> {
         // Tokenize request
         let inputs = request.inputs;
+        let add_special_tokens = request.add_special_tokens;
         let truncate = request.parameters.truncate;
         let encoding = self
             .validation
-            .tokenize(inputs, truncate)
+            .tokenize(inputs, add_special_tokens, truncate)
             .await
             .map_err(|err| {
                 tracing::error!("Tokenization {err}");
@@ -1082,6 +1082,16 @@ pub(crate) struct GenerateRequest {
     pub inputs: String,
     #[serde(default = "default_parameters")]
     pub parameters: GenerateParameters,
+
+    /// This is used internally because some requests
+    /// already contain the templated input therefore
+    /// we shouldn't add the special tokens.
+    #[serde(default = "default_true")]
+    pub add_special_tokens: bool,
+}
+
+fn default_true() -> bool {
+    true
 }
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
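The new field defaults to true via default_true, so the JSON API stays backward compatible: clients that never send add_special_tokens keep the old behaviour, and only the internal chat paths switch it off. A stripped-down sketch of that serde behaviour, keeping only the two relevant fields; the JSON payloads are invented for illustration:

// Stand-alone sketch of the serde default above; requires serde (derive) + serde_json.
use serde::Deserialize;

fn default_true() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct GenerateRequest {
    inputs: String,
    #[serde(default = "default_true")]
    add_special_tokens: bool,
}

fn main() {
    // Field omitted: defaults to true, so existing clients are unaffected.
    let implicit: GenerateRequest = serde_json::from_str(r#"{"inputs": "Hello"}"#).unwrap();
    assert!(implicit.add_special_tokens);

    // Internal chat path: the input is already templated, so special tokens are skipped.
    let explicit: GenerateRequest = serde_json::from_str(
        r#"{"inputs": "<s>[INST] Hello [/INST]", "add_special_tokens": false}"#,
    )
    .unwrap();
    assert!(!explicit.add_special_tokens);
    println!("{:?} / {:?}", implicit, explicit);
}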
@@ -1099,6 +1109,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
     fn from(req: CompatGenerateRequest) -> Self {
         Self {
             inputs: req.inputs,
+            add_special_tokens: true,
             parameters: req.parameters,
         }
     }
@@ -158,6 +158,7 @@ async fn get_chat_tokenize(
 
     let generate_request = GenerateRequest {
         inputs,
+        add_special_tokens: false,
         parameters: GenerateParameters {
             best_of: None,
             temperature,
@@ -754,6 +755,7 @@ async fn completions(
         .iter()
         .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
+            add_special_tokens: true,
             parameters: GenerateParameters {
                 best_of: None,
                 temperature,
@@ -1180,6 +1182,7 @@ async fn chat_completions(
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
+        add_special_tokens: false,
         parameters: GenerateParameters {
             best_of: None,
             temperature,
@@ -1386,6 +1389,7 @@ async fn vertex_compatibility(
         .map(|instance| {
             let generate_request = GenerateRequest {
                 inputs: instance.inputs.clone(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     do_sample: true,
                     max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
@@ -95,6 +95,7 @@ impl Validation {
     pub async fn tokenize(
         &self,
         inputs: String,
+        add_special_tokens: bool,
         truncate: Option<usize>,
     ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
         // If we have a fast tokenizer
@@ -104,7 +105,11 @@ impl Validation {
             // Send request to the background validation task
             // Unwrap is safe here
             sender
-                .send(((inputs, truncate), response_sender, Span::current()))
+                .send((
+                    (inputs, add_special_tokens, truncate),
+                    response_sender,
+                    Span::current(),
+                ))
                 .unwrap();
 
             // Await on response channel
@@ -121,11 +126,15 @@ impl Validation {
     async fn validate_input(
         &self,
         inputs: String,
+        add_special_tokens: bool,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
     ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
         // If we have a fast tokenizer
-        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
+        if let Some((encoding, inputs)) = self
+            .tokenize(inputs.clone(), add_special_tokens, truncate)
+            .await?
+        {
             // Create response channel
             let input_length = if let Some(truncate) = truncate {
                 std::cmp::min(encoding.len(), truncate)
@@ -324,7 +333,12 @@ impl Validation {
 
         // Validate inputs
         let (inputs, input_ids, input_length, max_new_tokens) = self
-            .validate_input(request.inputs, truncate, max_new_tokens)
+            .validate_input(
+                request.inputs,
+                request.add_special_tokens,
+                truncate,
+                max_new_tokens,
+            )
             .await?;
 
         // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
@@ -449,12 +463,15 @@ fn tokenizer_worker(
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
     // Loop over requests
-    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
+    while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+        receiver.blocking_recv()
+    {
         parent_span.in_scope(|| {
             response_tx
                 .send(prepare_input(
                     inputs,
                     truncate,
+                    add_special_tokens,
                     &tokenizer,
                     config.as_ref(),
                     preprocessor_config.as_ref(),
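Context for the tuple change: tokenization runs on dedicated worker threads that receive requests over an unbounded mpsc channel and answer on a oneshot channel; the request payload simply grows from (inputs, truncate) to (inputs, add_special_tokens, truncate). A self-contained sketch of that pattern, with the real prepare_input/Span plumbing replaced by a fake length computation; the names and the usize return type are illustrative, not TGI's:

// Stand-alone sketch of the worker/channel pattern; requires tokio with the
// "rt-multi-thread", "macros" and "sync" features.
use tokio::sync::{mpsc, oneshot};

// Mirrors the widened request tuple: (inputs, add_special_tokens, truncate).
type TokenizeRequest = ((String, bool, Option<usize>), oneshot::Sender<usize>);

fn tokenizer_worker(mut receiver: mpsc::UnboundedReceiver<TokenizeRequest>) {
    // Loop over requests on the dedicated thread.
    while let Some(((inputs, add_special_tokens, _truncate), response_tx)) =
        receiver.blocking_recv()
    {
        // Stand-in for prepare_input: report a count that depends on the flag.
        let fake_len = inputs.split_whitespace().count() + add_special_tokens as usize;
        let _ = response_tx.send(fake_len);
    }
}

#[tokio::main]
async fn main() {
    let (sender, receiver) = mpsc::unbounded_channel();
    std::thread::spawn(move || tokenizer_worker(receiver));

    let (response_tx, response_rx) = oneshot::channel();
    sender
        .send((("Hello world".to_string(), false, None), response_tx))
        .unwrap();
    println!("tokens: {}", response_rx.await.unwrap());
}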
@@ -591,6 +608,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
 fn prepare_input(
     inputs: String,
     _truncate: Option<usize>,
+    add_special_tokens: bool,
     tokenizer: &Tokenizer,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
@@ -628,14 +646,14 @@ fn prepare_input(
 
     // Get the number of tokens in the input
     let encoding = tokenizer
-        .encode(tokenizer_query, true)
+        .encode(tokenizer_query, add_special_tokens)
         .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     Ok((encoding, input_chunks))
 }
 
 type TokenizerRequest = (
-    (String, Option<usize>),
+    (String, bool, Option<usize>),
     oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
     Span,
 );
@@ -826,7 +844,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
             .await
         {
             // Err(ValidationError::MaxNewTokens(1, 10)) => (),
@@ -861,7 +879,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
             .await
         {
             Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@@ -895,6 +913,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     best_of: Some(2),
                     do_sample: false,
@@ -934,6 +953,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: Some(1.0),
                     max_new_tokens: Some(5),
@@ -949,6 +969,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: Some(0.99),
                     max_new_tokens: Some(5),
@@ -964,6 +985,7 @@ mod tests {
         let valid_request = validation
            .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: None,
                     max_new_tokens: Some(5),
@@ -1002,6 +1024,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(5),
                     max_new_tokens: Some(5),
@@ -1017,6 +1040,7 @@ mod tests {
         validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(4),
                     max_new_tokens: Some(5),
@@ -1029,6 +1053,7 @@ mod tests {
         validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(0),
                     max_new_tokens: Some(5),
@@ -1041,6 +1066,7 @@ mod tests {
         let valid_request = validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: None,
                     max_new_tokens: Some(5),
@@ -1089,6 +1115,7 @@ mod tests {
         let chunks = match validation
             .tokenize(
                 format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
+                true,
                 None,
             )
             .await
@@ -1148,6 +1175,7 @@ mod tests {
                     "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
                     PIXEL_GIF, PIXEL_GIF
                 ),
+                true,
                 None,
             )
             .await
@@ -266,6 +266,7 @@ class FlashCausalLMBatch(Batch):
             orig_input_length = len(tokenized_input)
 
             prefix_len = r.prefix_len
+            assert prefix_len <= orig_input_length
             if prefix_len == orig_input_length:
                 assert prefix_len > 0
                 prefix_len -= 1
@@ -282,6 +283,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(tokenized_input)
 
             # Position ids
+            print(f"Prefix {prefix_len} - Orig {orig_input_length}")
             request_position_ids = torch.arange(
                 prefix_len, orig_input_length, dtype=torch.int32
             )