mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-22 15:32:08 +00:00)

Change add_special_tokens in order to have the correct tokens for chat input: templated chat input already contains the special tokens, so they must not be added a second time (this is super important with the prefixing now).

This commit is contained in:
parent f1c0735453
commit 7f1816a4e1
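For context (not part of the commit): chat endpoints hand the tokenizer an already-templated prompt, and many chat templates already account for the BOS token. Below is a minimal sketch of why the flag matters, assuming the `tokenizers` crate and a local tokenizer.json whose post-processor adds a BOS token; the file path and prompt are only examples, not repository code.

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Example path: any fast tokenizer whose post-processor prepends BOS.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // A chat handler has already applied the chat template, so the text may
    // already start with the BOS marker (e.g. "<s>").
    let templated = "<s>[INST] Hello [/INST]";

    // add_special_tokens = true would prepend a second BOS, changing the very
    // first tokens of the sequence and therefore the cacheable prefix.
    let doubled = tokenizer.encode(templated, true)?;
    let clean = tokenizer.encode(templated, false)?;

    println!("with special tokens:    {:?}", doubled.get_tokens());
    println!("without special tokens: {:?}", clean.get_tokens());
    Ok(())
}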
@@ -120,10 +120,11 @@ impl Infer {
     ) -> Result<Option<tokenizers::Encoding>, InferError> {
         // Tokenize request
         let inputs = request.inputs;
+        let add_special_tokens = request.add_special_tokens;
         let truncate = request.parameters.truncate;
         let encoding = self
             .validation
-            .tokenize(inputs, truncate)
+            .tokenize(inputs, add_special_tokens, truncate)
             .await
             .map_err(|err| {
                 tracing::error!("Tokenization {err}");
@@ -1082,6 +1082,16 @@ pub(crate) struct GenerateRequest {
     pub inputs: String,
     #[serde(default = "default_parameters")]
     pub parameters: GenerateParameters,
+
+    /// This is used internally because some requests
+    /// already contain the templated input therefore
+    /// we shouldn't add the special tokens.
+    #[serde(default = "default_true")]
+    pub add_special_tokens: bool,
 }
+
+fn default_true() -> bool {
+    true
+}
 
 #[derive(Clone, Debug, Deserialize, ToSchema)]
@@ -1099,6 +1109,7 @@ impl From<CompatGenerateRequest> for GenerateRequest {
     fn from(req: CompatGenerateRequest) -> Self {
         Self {
             inputs: req.inputs,
+            add_special_tokens: true,
             parameters: req.parameters,
         }
     }
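Side note (not from the diff): a self-contained sketch of how the #[serde(default = "default_true")] attribute behaves. API clients that omit the field keep the old behaviour (special tokens are added), while internal call sites construct the struct with false explicitly. The Request type below is a stand-in, not the real GenerateRequest; it assumes the serde and serde_json crates.

use serde::Deserialize;

fn default_true() -> bool {
    true
}

#[derive(Deserialize, Debug)]
struct Request {
    inputs: String,
    #[serde(default = "default_true")]
    add_special_tokens: bool,
}

fn main() {
    // Field omitted by the client: defaults to true, preserving old behaviour.
    let from_api: Request = serde_json::from_str(r#"{"inputs": "Hello"}"#).unwrap();
    assert!(from_api.add_special_tokens);

    // Internal chat path: the handler builds the struct directly with false.
    let internal = Request {
        inputs: "<templated prompt>".to_string(),
        add_special_tokens: false,
    };
    assert!(!internal.add_special_tokens);
    println!("{from_api:?}\n{internal:?}");
}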
@@ -158,6 +158,7 @@ async fn get_chat_tokenize(
 
     let generate_request = GenerateRequest {
         inputs,
+        add_special_tokens: false,
         parameters: GenerateParameters {
             best_of: None,
             temperature,
@@ -754,6 +755,7 @@ async fn completions(
         .iter()
         .map(|prompt| GenerateRequest {
             inputs: prompt.to_string(),
+            add_special_tokens: true,
             parameters: GenerateParameters {
                 best_of: None,
                 temperature,
@@ -1180,6 +1182,7 @@ async fn chat_completions(
     // build the request passing some parameters
     let generate_request = GenerateRequest {
         inputs: inputs.to_string(),
+        add_special_tokens: false,
         parameters: GenerateParameters {
             best_of: None,
             temperature,
@@ -1386,6 +1389,7 @@ async fn vertex_compatibility(
         .map(|instance| {
             let generate_request = GenerateRequest {
                 inputs: instance.inputs.clone(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     do_sample: true,
                     max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens),
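Taken together, the handler changes above encode one rule: inputs that have already been through the chat template (chat completions, the chat tokenize route) must not get special tokens again, while raw prompts (compat generate, completions, vertex instances) keep them. A hypothetical helper, purely illustrative and not present in the router, stating the same rule:

// Hypothetical helper: special tokens are added only when the input has NOT
// already been chat-templated.
fn should_add_special_tokens(already_templated: bool) -> bool {
    !already_templated
}

fn main() {
    assert!(should_add_special_tokens(false)); // completions / vertex / compat generate
    assert!(!should_add_special_tokens(true)); // chat completions / chat tokenize
    println!("rule holds");
}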
@@ -95,6 +95,7 @@ impl Validation {
     pub async fn tokenize(
         &self,
         inputs: String,
+        add_special_tokens: bool,
         truncate: Option<usize>,
     ) -> Result<Option<(tokenizers::Encoding, Vec<Chunk>)>, ValidationError> {
         // If we have a fast tokenizer
@@ -104,7 +105,11 @@ impl Validation {
             // Send request to the background validation task
             // Unwrap is safe here
             sender
-                .send(((inputs, truncate), response_sender, Span::current()))
+                .send((
+                    (inputs, add_special_tokens, truncate),
+                    response_sender,
+                    Span::current(),
+                ))
                 .unwrap();
 
             // Await on response channel
@@ -121,11 +126,15 @@ impl Validation {
     async fn validate_input(
         &self,
         inputs: String,
+        add_special_tokens: bool,
         truncate: Option<usize>,
         max_new_tokens: Option<u32>,
     ) -> Result<(Vec<Chunk>, Option<Vec<u32>>, usize, u32), ValidationError> {
         // If we have a fast tokenizer
-        if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
+        if let Some((encoding, inputs)) = self
+            .tokenize(inputs.clone(), add_special_tokens, truncate)
+            .await?
+        {
             // Create response channel
             let input_length = if let Some(truncate) = truncate {
                 std::cmp::min(encoding.len(), truncate)
@@ -324,7 +333,12 @@ impl Validation {
 
         // Validate inputs
         let (inputs, input_ids, input_length, max_new_tokens) = self
-            .validate_input(request.inputs, truncate, max_new_tokens)
+            .validate_input(
+                request.inputs,
+                request.add_special_tokens,
+                truncate,
+                max_new_tokens,
+            )
             .await?;
 
         // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar
@@ -449,12 +463,15 @@ fn tokenizer_worker(
     mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
 ) {
     // Loop over requests
-    while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() {
+    while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) =
+        receiver.blocking_recv()
+    {
         parent_span.in_scope(|| {
             response_tx
                 .send(prepare_input(
                     inputs,
                     truncate,
+                    add_special_tokens,
                     &tokenizer,
                     config.as_ref(),
                     preprocessor_config.as_ref(),
@@ -591,6 +608,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String {
 fn prepare_input(
     inputs: String,
     _truncate: Option<usize>,
+    add_special_tokens: bool,
     tokenizer: &Tokenizer,
     config: Option<&Config>,
     preprocessor_config: Option<&HubPreprocessorConfig>,
@@ -628,14 +646,14 @@ fn prepare_input(
 
     // Get the number of tokens in the input
     let encoding = tokenizer
-        .encode(tokenizer_query, true)
+        .encode(tokenizer_query, add_special_tokens)
        .map_err(|err| ValidationError::Tokenizer(err.to_string()))?;
 
     Ok((encoding, input_chunks))
 }
 
 type TokenizerRequest = (
-    (String, Option<usize>),
+    (String, bool, Option<usize>),
     oneshot::Sender<Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError>>,
     Span,
 );
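The worker changes above simply thread the boolean through the tokenizer channel. A standalone sketch of the same shape using tokio's mpsc and oneshot directly; the tuple layout mirrors the new TokenizerRequest, but the response type and the word-count stand-in for encode() are assumptions, not the router's real types.

use tokio::sync::{mpsc, oneshot};

// Simplified request shape: (inputs, add_special_tokens, truncate) plus a reply channel.
type TokenizerRequest = ((String, bool, Option<usize>), oneshot::Sender<usize>);

fn worker(mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>) {
    // Blocking loop on a dedicated thread, like the router's tokenizer worker.
    while let Some(((inputs, add_special_tokens, _truncate), response_tx)) =
        receiver.blocking_recv()
    {
        // Stand-in for `tokenizer.encode(inputs, add_special_tokens)`.
        let token_count = inputs.split_whitespace().count() + usize::from(add_special_tokens);
        let _ = response_tx.send(token_count);
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    std::thread::spawn(move || worker(rx));

    let (resp_tx, resp_rx) = oneshot::channel();
    tx.send((("Hello world".to_string(), true, None), resp_tx)).unwrap();
    println!("tokens: {}", resp_rx.await.unwrap());
}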
@@ -826,7 +844,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
             .await
         {
             // Err(ValidationError::MaxNewTokens(1, 10)) => (),
@@ -861,7 +879,7 @@ mod tests {
 
         let max_new_tokens = 10;
         match validation
-            .validate_input("Hello".to_string(), None, Some(max_new_tokens))
+            .validate_input("Hello".to_string(), true, None, Some(max_new_tokens))
             .await
         {
             Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (),
@@ -895,6 +913,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     best_of: Some(2),
                     do_sample: false,
@@ -934,6 +953,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: Some(1.0),
                     max_new_tokens: Some(5),
@@ -949,6 +969,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: Some(0.99),
                     max_new_tokens: Some(5),
@@ -964,6 +985,7 @@ mod tests {
         let valid_request = validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_p: None,
                     max_new_tokens: Some(5),
@@ -1002,6 +1024,7 @@ mod tests {
         match validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(5),
                     max_new_tokens: Some(5),
@@ -1017,6 +1040,7 @@ mod tests {
         validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(4),
                     max_new_tokens: Some(5),
@@ -1029,6 +1053,7 @@ mod tests {
         validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: Some(0),
                     max_new_tokens: Some(5),
@@ -1041,6 +1066,7 @@ mod tests {
         let valid_request = validation
             .validate(GenerateRequest {
                 inputs: "Hello".to_string(),
+                add_special_tokens: true,
                 parameters: GenerateParameters {
                     top_n_tokens: None,
                     max_new_tokens: Some(5),
@@ -1089,6 +1115,7 @@ mod tests {
         let chunks = match validation
             .tokenize(
                 format!("test![](data:image/gif;base64,{})", PIXEL_GIF),
+                true,
                 None,
             )
             .await
@@ -1148,6 +1175,7 @@ mod tests {
                     "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})",
                     PIXEL_GIF, PIXEL_GIF
                 ),
+                true,
                 None,
             )
             .await
@@ -266,6 +266,7 @@ class FlashCausalLMBatch(Batch):
             orig_input_length = len(tokenized_input)
 
             prefix_len = r.prefix_len
+            assert prefix_len <= orig_input_length
             if prefix_len == orig_input_length:
                 assert prefix_len > 0
                 prefix_len -= 1
@@ -282,6 +283,7 @@ class FlashCausalLMBatch(Batch):
             all_input_ids.append(tokenized_input)
 
             # Position ids
+            print(f"Prefix {prefix_len} - Orig {orig_input_length}")
             request_position_ids = torch.arange(
                 prefix_len, orig_input_length, dtype=torch.int32
             )