Mirror of https://github.com/huggingface/text-generation-inference.git
add validation + decode of special tokens

commit 146e0e27ce (parent 273f0ae42c)

@@ -2,6 +2,7 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
 use std::cmp::max;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;

@@ -30,6 +31,10 @@ impl Validation {
         max_input_length: usize,
         max_total_tokens: usize,
     ) -> Self {
+        if max_input_length >= max_total_tokens {
+            panic!("`max_input_length` must be < `max_total_tokens`");
+        }
+
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
             // Create channel

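The added panic rejects configurations in which the prompt budget swallows the whole token budget; the difference max_total_tokens - max_input_length is also the bound reported by the new MaxNewTokens error further down. A minimal Python sketch of the same startup invariant (hypothetical helper, not part of the server code):

    def check_limits(max_input_length: int, max_total_tokens: int) -> int:
        # Mirrors the Rust panic: the input budget must be strictly smaller
        # than the total budget, otherwise no tokens could ever be generated.
        if max_input_length >= max_total_tokens:
            raise ValueError("`max_input_length` must be < `max_total_tokens`")
        # Largest `max_new_tokens` a request could ever be granted.
        return max_total_tokens - max_input_length

    # Example: a 1024-token prompt window inside a 1512-token total budget
    assert check_limits(1024, 1512) == 488
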
@@ -105,6 +110,18 @@ impl Validation {
         }
         // Return inputs without validation
         else {
+            // In this case, we don't know the real length in tokens of the inputs
+            // However, the inputs will be truncated by the python servers
+            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
+
+            // Validate MaxNewTokens
+            if (truncate + max_new_tokens) > self.max_total_tokens {
+                return Err(ValidationError::MaxNewTokens(
+                    self.max_total_tokens - self.max_input_length,
+                    max_new_tokens,
+                ));
+            }
+
             Ok(inputs)
         }
     }

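Without a fast tokenizer the router cannot count the prompt tokens itself, so the new branch bounds the request using only values it does know: the truncation length and the requested max_new_tokens. A rough Python sketch of that check, with hypothetical names since the real logic is the Rust shown above:

    def validate_without_tokenizer(
        truncate: int,
        max_new_tokens: int,
        max_input_length: int,
        max_total_tokens: int,
    ) -> None:
        # The python server will truncate the prompt to `truncate` tokens,
        # so `truncate + max_new_tokens` is an upper bound on the total.
        if truncate + max_new_tokens > max_total_tokens:
            # Report the largest `max_new_tokens` that could ever be accepted,
            # like the Rust MaxNewTokens(usize, u32) variant does.
            allowed = max_total_tokens - max_input_length
            raise ValueError(
                f"`max_new_tokens` must be <= {allowed}. Given: {max_new_tokens}"
            )

    # 1000 + 500 fits in a 1512-token budget, 1000 + 600 does not.
    validate_without_tokenizer(1000, 500, max_input_length=1024, max_total_tokens=1512)
    try:
        validate_without_tokenizer(1000, 600, max_input_length=1024, max_total_tokens=1512)
    except ValueError as err:
        print(err)  # `max_new_tokens` must be <= 488. Given: 600
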
@@ -183,7 +200,7 @@ impl Validation {
             .unwrap_or(Ok(0))?;

         if max_new_tokens == 0 {
-            return Err(ValidationError::MaxNewTokens);
+            return Err(ValidationError::NegativeMaxNewTokens);
         }

         if stop_sequences.len() > self.max_stop_sequences {

@@ -345,7 +362,9 @@ pub enum ValidationError {
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
-    MaxNewTokens,
+    NegativeMaxNewTokens,
+    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
+    MaxNewTokens(usize, u32),
     #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
     MaxTotalTokens(usize, usize, u32),
     #[error("`inputs` must have less than {0} tokens. Given: {1}")]

@@ -44,7 +44,7 @@ class LlamaRMSNorm(nn.Module):
         self.variance_epsilon = eps

     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 6144:
+        if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states

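The threshold bump from 6144 to 8192 narrows the cases that take the unfused branch, presumably because the fused dropout_layer_norm kernel now covers hidden sizes up to 8192. For reference, a sketch of what the unfused RMSNorm fallback computes, reconstructed from the residual handling visible above plus the standard RMSNorm formula; it is not a copy of the file:

    import torch
    from torch import nn

    class PlainRMSNorm(nn.Module):
        """Sketch of the unfused path: fold in the residual, then RMS-normalize."""

        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(self, hidden_states, residual=None):
            # Add the residual first and hand the pre-norm sum back to the caller,
            # matching the branch shown in the hunk above.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            # RMS normalization in float32 for stability, then rescale by weight.
            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
            return self.weight * hidden_states.to(self.weight.dtype), residual

    # 8448 > 8192, so the real LlamaRMSNorm would also take the unfused branch here.
    norm = PlainRMSNorm(8448)
    x = torch.randn(2, 8448)
    out, res = norm(x, residual=torch.zeros_like(x))
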
@@ -38,7 +38,7 @@ from flash_attn.layers.rotary import RotaryEmbedding

 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 6144:
+        if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states

@@ -11,7 +11,7 @@ import dropout_layer_norm

 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 6144:
+        if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states

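Both FastLayerNorm wrappers get the same 8192 cut-off. Above it they fall back to folding in the residual and calling the stock nn.LayerNorm. A runnable sketch of that fallback branch only; the fused kernel call for smaller hidden sizes is omitted because its signature is not part of this diff:

    import torch
    from torch import nn

    class UnfusedLayerNorm(nn.LayerNorm):
        """Sketch of the branch FastLayerNorm takes above 8192 hidden units."""

        def forward(self, hidden_states, residual=None):
            # Fold the residual in first, and return the pre-norm sum so the
            # next layer can reuse it as its own residual input.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states
            return super().forward(hidden_states), residual

    # 9216 > 8192, so the real class would also land on this code path.
    norm = UnfusedLayerNorm(9216)
    x = torch.randn(2, 9216)
    out, res = norm(x)
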
@@ -31,6 +31,13 @@ class Model(ABC):
         token_offset: Optional[int] = None,
     ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        if all_input_ids[-1] in self.all_special_ids:
+            return (
+                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
+                None,
+                None,
+            )
+
         if token_offset is None:
             token_offset = len(all_input_ids) - 3

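The decode change short-circuits as soon as the latest token is a special token: it is decoded on its own with skip_special_tokens=False so its text is not silently dropped, and the offsets are reset. The effect is easy to see with any transformers tokenizer; the gpt2 checkpoint below is only an example:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    eos_id = tokenizer.eos_token_id
    print(eos_id in tokenizer.all_special_ids)                  # True
    print(tokenizer.decode(eos_id, skip_special_tokens=True))   # "" (token dropped)
    print(tokenizer.decode(eos_id, skip_special_tokens=False))  # "<|endoftext|>"
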