add validation + decode of special tokens

OlivierDehaene 2023-04-07 11:12:16 +02:00
parent 273f0ae42c
commit 146e0e27ce
5 changed files with 31 additions and 5 deletions

View File

@@ -2,6 +2,7 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}
/// Payload validation logic
use crate::{GenerateParameters, GenerateRequest};
use rand::{thread_rng, Rng};
use std::cmp::max;
use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
@@ -30,6 +31,10 @@ impl Validation {
max_input_length: usize,
max_total_tokens: usize,
) -> Self {
if max_input_length >= max_total_tokens {
panic!("`max_input_length` must be < `max_total_tokens`");
}
// If we have a fast tokenizer
let sender = if let Some(tokenizer) = tokenizer {
// Create channel
@@ -105,6 +110,18 @@ impl Validation {
}
// Return inputs without validation
else {
// In this case, we don't know the real length in tokens of the inputs
// However, the inputs will be truncated by the python servers
// We make sure that truncate + max_new_tokens <= self.max_total_tokens
// Validate MaxNewTokens
if (truncate + max_new_tokens) > self.max_total_tokens {
return Err(ValidationError::MaxNewTokens(
self.max_total_tokens - self.max_input_length,
max_new_tokens,
));
}
Ok(inputs)
}
}
@@ -183,7 +200,7 @@ impl Validation {
.unwrap_or(Ok(0))?;
if max_new_tokens == 0 {
- return Err(ValidationError::MaxNewTokens);
+ return Err(ValidationError::NegativeMaxNewTokens);
}
if stop_sequences.len() > self.max_stop_sequences {
@@ -345,7 +362,9 @@ pub enum ValidationError {
#[error("`typical_p` must be > 0.0 and < 1.0")]
TypicalP,
#[error("`max_new_tokens` must be strictly positive")]
- MaxNewTokens,
+ NegativeMaxNewTokens,
#[error("`max_new_tokens` must be <= {0}. Given: {1}")]
MaxNewTokens(usize, u32),
#[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
MaxTotalTokens(usize, usize, u32),
#[error("`inputs` must have less than {0} tokens. Given: {1}")]

View File

@@ -44,7 +44,7 @@ class LlamaRMSNorm(nn.Module):
self.variance_epsilon = eps
def forward(self, hidden_states, residual=None):
- if hidden_states.shape[-1] > 6144:
+ if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states

View File

@@ -38,7 +38,7 @@ from flash_attn.layers.rotary import RotaryEmbedding
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
- if hidden_states.shape[-1] > 6144:
+ if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states

View File

@@ -11,7 +11,7 @@ import dropout_layer_norm
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
- if hidden_states.shape[-1] > 6144:
+ if hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states

View File

@@ -31,6 +31,13 @@ class Model(ABC):
token_offset: Optional[int] = None,
) -> Tuple[str, Optional[int], Optional[int]]:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
if all_input_ids[-1] in self.all_special_ids:
return (
self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
None,
None,
)
if token_offset is None:
token_offset = len(all_input_ids) - 3
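The special-token branch above is what the commit message means by "decode of special tokens": when the last generated id is a special token, it is decoded on its own with skip_special_tokens=False and returned directly, bypassing the offset-based incremental decoding below (both offsets come back as None). A minimal standalone sketch with a Hugging Face tokenizer ("gpt2" is only an example model; the surrounding streaming logic is omitted):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
all_special_ids = set(tokenizer.all_special_ids)

last_id = tokenizer.eos_token_id  # pretend this id was just generated
if last_id in all_special_ids:
    # Decode the single special token by itself, keeping special tokens.
    print(tokenizer.decode(last_id, skip_special_tokens=False))  # "<|endoftext|>"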