Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00

commit c05cabc730
parent 8ec1b87f16

    Modify the default for max_new_tokens
@@ -107,8 +107,8 @@ pub(crate) struct GenerateParameters {
     #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
-    #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
-    pub max_new_tokens: u32,
+    #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "null")]
+    pub max_new_tokens: Option<u32>,
     #[serde(default)]
     #[schema(nullable = true, default = "null", example = false)]
     pub return_full_text: Option<bool>,
@@ -140,8 +140,8 @@ pub(crate) struct GenerateParameters {
     pub top_n_tokens: Option<u32>,
 }
 
-fn default_max_new_tokens() -> u32 {
-    20
+fn default_max_new_tokens() -> Option<u32> {
+    None
 }
 
 fn default_parameters() -> GenerateParameters {
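The two hunks above change the request schema: `max_new_tokens` becomes an `Option<u32>` whose serde default is `None` rather than a hard-coded 20, so an omitted field is resolved later during validation instead of silently capping generation at 20 tokens. A minimal sketch of the new deserialization behavior, assuming the `serde` and `serde_json` crates and a stripped-down stand-in for `GenerateParameters`:

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Params {
    // Mirrors the diff: an omitted field now deserializes to None.
    #[serde(default = "default_max_new_tokens")]
    max_new_tokens: Option<u32>,
}

fn default_max_new_tokens() -> Option<u32> {
    None
}

fn main() {
    let omitted: Params = serde_json::from_str("{}").unwrap();
    assert_eq!(omitted.max_new_tokens, None); // previously this defaulted to 20

    let explicit: Params = serde_json::from_str(r#"{"max_new_tokens": 5}"#).unwrap();
    assert_eq!(explicit.max_new_tokens, Some(5));
}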
@@ -67,8 +67,8 @@ impl Validation {
         &self,
         inputs: String,
         truncate: Option<usize>,
-        max_new_tokens: u32,
-    ) -> Result<(String, usize), ValidationError> {
+        max_new_tokens: Option<u32>,
+    ) -> Result<(String, usize, u32), ValidationError> {
         // If we have a fast tokenizer
         if let Some(sender) = &self.sender {
             // Create response channel
@@ -84,6 +84,11 @@ impl Validation {
             let (inputs, input_length) = response_receiver.await.unwrap()?;
 
             // Get total tokens
+            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
+                max_new_tokens
+            } else {
+                self.max_total_tokens.saturating_sub(input_length) as u32
+            };
             let total_tokens = input_length + max_new_tokens as usize;
 
             // Validate MaxTotalTokens
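This is where an omitted `max_new_tokens` gets its value: `None` is replaced by the remaining token budget, `max_total_tokens - input_length`, with `saturating_sub` guarding against underflow when the prompt alone exceeds the budget. The same rule extracted into a free function (a sketch; in the commit the logic lives inline in `Validation::validate_input` and reads `self.max_total_tokens`):

fn resolve_max_new_tokens(
    requested: Option<u32>,
    input_length: usize,
    max_total_tokens: usize,
) -> u32 {
    requested.unwrap_or_else(|| max_total_tokens.saturating_sub(input_length) as u32)
}

fn main() {
    // An explicit value always wins.
    assert_eq!(resolve_max_new_tokens(Some(20), 10, 2048), 20);
    // Omitted: fill the remaining budget, 2048 - 10 = 2038.
    assert_eq!(resolve_max_new_tokens(None, 10, 2048), 2038);
    // saturating_sub prevents underflow when the prompt already fills the budget.
    assert_eq!(resolve_max_new_tokens(None, 4096, 2048), 0);
}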
@@ -104,7 +109,7 @@ impl Validation {
             }
 
             metrics::histogram!("tgi_request_input_length", input_length as f64);
-            Ok((inputs, input_length))
+            Ok((inputs, input_length, max_new_tokens))
         }
         // Return inputs without validation
         else {
@@ -112,6 +117,11 @@ impl Validation {
             // However, the inputs will be truncated by the python servers
             // We make sure that truncate + max_new_tokens <= self.max_total_tokens
             let input_length = truncate.unwrap_or(self.max_input_length);
+            let max_new_tokens: u32 = if let Some(max_new_tokens) = max_new_tokens {
+                max_new_tokens
+            } else {
+                self.max_total_tokens.saturating_sub(input_length) as u32
+            };
 
             // Validate MaxNewTokens
             if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
@@ -121,7 +131,7 @@ impl Validation {
                 ));
             }
 
-            Ok((inputs, input_length))
+            Ok((inputs, input_length, max_new_tokens))
         }
     }
 
@@ -200,7 +210,7 @@ impl Validation {
             })
             .unwrap_or(Ok(0))?;
 
-        if max_new_tokens == 0 {
+        if max_new_tokens == Some(0) {
             return Err(ValidationError::NegativeMaxNewTokens);
         }
 
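One subtlety in the hunk above: comparing against `Some(0)` rejects only an explicitly requested zero, while `None` passes through and is later filled by the budget fallback. A toy illustration, with the error type simplified to a string:

fn check_zero(max_new_tokens: Option<u32>) -> Result<(), &'static str> {
    // Mirrors `if max_new_tokens == Some(0)` from the hunk above.
    if max_new_tokens == Some(0) {
        return Err("max_new_tokens must be strictly positive");
    }
    Ok(())
}

fn main() {
    assert!(check_zero(Some(0)).is_err()); // explicit zero is rejected
    assert!(check_zero(None).is_ok());     // omitted value defers to the fallback
    assert!(check_zero(Some(10)).is_ok());
}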
@@ -247,7 +257,7 @@ impl Validation {
             .unwrap_or(Ok(None))?;
 
         // Validate inputs
-        let (inputs, input_length) = self
+        let (inputs, input_length, max_new_tokens) = self
             .validate_input(request.inputs, truncate, max_new_tokens)
             .await?;
 
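The resolved value is threaded back to the caller through the widened tuple, so downstream code uses the validated number rather than the raw optional field. A self-contained sketch of that call-site shape (synchronous, with `inputs.len()` standing in for real tokenization and 2048 as an assumed `max_total_tokens`):

fn validate_input(inputs: String, max_new_tokens: Option<u32>) -> (String, usize, u32) {
    let input_length = inputs.len(); // stand-in for real tokenization
    let resolved =
        max_new_tokens.unwrap_or_else(|| 2048u32.saturating_sub(input_length as u32));
    (inputs, input_length, resolved)
}

fn main() {
    // Callers now destructure three values instead of two.
    let (inputs, input_length, max_new_tokens) = validate_input("Hello".to_string(), None);
    assert_eq!(inputs, "Hello");
    assert_eq!(input_length, 5);
    assert_eq!(max_new_tokens, 2043); // 2048 - 5
}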
@@ -426,7 +436,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
             .await
         {
             Err(ValidationError::MaxNewTokens(1, 10)) => (),
@@ -455,7 +465,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, max_new_tokens)
+            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
             .await
         {
             Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@@ -534,7 +544,6 @@ mod tests {
             inputs: "Hello".to_string(),
             parameters: GenerateParameters {
                 top_p: Some(0.99),
-                max_new_tokens: 1,
                 ..default_parameters()
             },
         })
@@ -549,7 +558,6 @@ mod tests {
             inputs: "Hello".to_string(),
             parameters: GenerateParameters {
                 top_p: None,
-                max_new_tokens: 1,
                 ..default_parameters()
             },
         })
@@ -596,7 +604,6 @@ mod tests {
             inputs: "Hello".to_string(),
             parameters: GenerateParameters {
                 top_n_tokens: Some(4),
-                max_new_tokens: 1,
                 ..default_parameters()
             },
         })
@@ -608,7 +615,6 @@ mod tests {
             inputs: "Hello".to_string(),
             parameters: GenerateParameters {
                 top_n_tokens: Some(0),
-                max_new_tokens: 1,
                 ..default_parameters()
             },
         })
@@ -620,7 +626,6 @@ mod tests {
             inputs: "Hello".to_string(),
             parameters: GenerateParameters {
                 top_n_tokens: None,
-                max_new_tokens: 1,
                 ..default_parameters()
             },
         })
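Finally, the five test hunks above all delete `max_new_tokens: 1` from `GenerateParameters` literals. With the field retyped as `Option<u32>`, a bare `1` no longer type-checks, so the tests rely on `..default_parameters()` and exercise the new `None` default instead. A sketch of the effect on struct literals, with the field set trimmed for brevity:

#[derive(Debug)]
struct Params {
    max_new_tokens: Option<u32>, // was u32 before this commit
    top_p: Option<f32>,
}

fn default_params() -> Params {
    Params { max_new_tokens: None, top_p: None }
}

fn main() {
    // After the change, `max_new_tokens: 1` is a type error; a test either
    // writes Some(1) or, as in these hunks, drops the field entirely:
    let p = Params { top_p: Some(0.99), ..default_params() };
    assert_eq!(p.max_new_tokens, None);
    println!("{p:?}");
}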