diff --git a/router/src/validation.rs b/router/src/validation.rs
index 34b9190d..01b2de4b 100644
--- a/router/src/validation.rs
+++ b/router/src/validation.rs
@@ -2,6 +2,7 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
+use std::cmp::max;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -30,6 +31,10 @@ impl Validation {
         max_input_length: usize,
         max_total_tokens: usize,
     ) -> Self {
+        if max_input_length >= max_total_tokens {
+            panic!("`max_input_length` must be < `max_total_tokens`");
+        }
+
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
             // Create channel
@@ -105,6 +110,18 @@ impl Validation {
         }
         // Return inputs without validation
        else {
+            // In this case, we don't know the real length in tokens of the inputs
+            // However, the inputs will be truncated by the python servers
+            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
+
+            // Validate MaxNewTokens
+            if (truncate + max_new_tokens) > self.max_total_tokens {
+                return Err(ValidationError::MaxNewTokens(
+                    self.max_total_tokens - self.max_input_length,
+                    max_new_tokens,
+                ));
+            }
+
             Ok(inputs)
         }
     }
@@ -183,7 +200,7 @@ impl Validation {
             .unwrap_or(Ok(0))?;
 
         if max_new_tokens == 0 {
-            return Err(ValidationError::MaxNewTokens);
+            return Err(ValidationError::NegativeMaxNewTokens);
         }
 
         if stop_sequences.len() > self.max_stop_sequences {
@@ -345,7 +362,9 @@ pub enum ValidationError {
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
-    MaxNewTokens,
+    NegativeMaxNewTokens,
+    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
+    MaxNewTokens(usize, u32),
     #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
     MaxTotalTokens(usize, usize, u32),
     #[error("`inputs` must have less than {0} tokens. Given: {1}")]
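Note: the hunk above adds a bound for requests the router cannot tokenize itself. A minimal standalone sketch of that arithmetic follows, assuming `truncate`, `max_new_tokens`, `max_input_length`, and `max_total_tokens` are all plain token counts; the function name and the `String` error type are hypothetical and only mirror the logic of the added check, they are not part of the router.

    // Hypothetical standalone version of the added check: without a fast tokenizer
    // the router cannot count input tokens, so it bounds truncate + max_new_tokens
    // by max_total_tokens instead of using the real prompt length.
    fn check_max_new_tokens(
        truncate: usize,
        max_new_tokens: usize,
        max_input_length: usize,
        max_total_tokens: usize,
    ) -> Result<(), String> {
        if truncate + max_new_tokens > max_total_tokens {
            // Mirrors ValidationError::MaxNewTokens: report the largest value that
            // a full-length prompt would still allow.
            return Err(format!(
                "`max_new_tokens` must be <= {}. Given: {}",
                max_total_tokens - max_input_length,
                max_new_tokens
            ));
        }
        Ok(())
    }

    // Example: with max_total_tokens = 1512, max_input_length = 1000, truncate = 1000
    // and max_new_tokens = 600, the request is rejected (1000 + 600 > 1512) and the
    // error reports a limit of 512.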
Given: {1}")] diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 0eb260c9..e5c09cbe 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -44,7 +44,7 @@ class LlamaRMSNorm(nn.Module): self.variance_epsilon = eps def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 6144: + if hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual residual = hidden_states diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index efbfa70b..4ff17619 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -38,7 +38,7 @@ from flash_attn.layers.rotary import RotaryEmbedding class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 6144: + if hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual residual = hidden_states diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 799e7054..29c4a5c8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -11,7 +11,7 @@ import dropout_layer_norm class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 6144: + if hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual residual = hidden_states diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index b42605fb..5b82872c 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -31,6 +31,13 @@ class Model(ABC): token_offset: Optional[int] = None, ) -> Tuple[str, Optional[int], Optional[int]]: """Hack to hopefully support generate_stream for the maximum number of tokenizers""" + if all_input_ids[-1] in self.all_special_ids: + return ( + self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False), + None, + None, + ) + if token_offset is None: token_offset = len(all_input_ids) - 3