Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-06-19 15:52:08 +00:00
add validation + decode of special tokens
parent 273f0ae42c
commit 146e0e27ce
@@ -2,6 +2,7 @@ use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}
 /// Payload validation logic
 use crate::{GenerateParameters, GenerateRequest};
 use rand::{thread_rng, Rng};
+use std::cmp::max;
 use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
 use thiserror::Error;
 use tokenizers::tokenizer::Tokenizer;
@@ -30,6 +31,10 @@ impl Validation {
         max_input_length: usize,
         max_total_tokens: usize,
     ) -> Self {
+        if max_input_length >= max_total_tokens {
+            panic!("`max_input_length` must be < `max_total_tokens`");
+        }
+
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
             // Create channel
@@ -105,6 +110,18 @@ impl Validation {
         }
         // Return inputs without validation
        else {
+            // In this case, we don't know the real length in tokens of the inputs
+            // However, the inputs will be truncated by the python servers
+            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
+
+            // Validate MaxNewTokens
+            if (truncate + max_new_tokens) > self.max_total_tokens {
+                return Err(ValidationError::MaxNewTokens(
+                    self.max_total_tokens - self.max_input_length,
+                    max_new_tokens,
+                ));
+            }
+
             Ok(inputs)
         }
     }
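The branch above covers the case where the router has no fast tokenizer and therefore cannot count the input tokens; instead it bounds the worst case with `truncate + max_new_tokens <= max_total_tokens`. A minimal sketch of that budget rule, in Python for illustration only (the constant, function name, and numbers are hypothetical, not the router's API):

# Hypothetical sketch of the new budget rule (names and numbers are illustrative).
MAX_TOTAL_TOKENS = 2048  # plays the role of `self.max_total_tokens`

def check_budget(truncate: int, max_new_tokens: int) -> None:
    # Without a tokenizer the router cannot count input tokens, so it bounds
    # the worst case: the truncated prompt plus the requested new tokens must
    # fit in the total token budget.
    if truncate + max_new_tokens > MAX_TOTAL_TOKENS:
        raise ValueError(
            f"`truncate` + `max_new_tokens` must be <= {MAX_TOTAL_TOKENS}, "
            f"got {truncate} + {max_new_tokens}"
        )

check_budget(truncate=1024, max_new_tokens=512)   # ok: 1536 <= 2048
check_budget(truncate=1800, max_new_tokens=512)   # raises: 2312 > 2048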
@@ -183,7 +200,7 @@ impl Validation {
             .unwrap_or(Ok(0))?;
 
         if max_new_tokens == 0 {
-            return Err(ValidationError::MaxNewTokens);
+            return Err(ValidationError::NegativeMaxNewTokens);
         }
 
         if stop_sequences.len() > self.max_stop_sequences {
@@ -345,7 +362,9 @@ pub enum ValidationError {
     #[error("`typical_p` must be > 0.0 and < 1.0")]
     TypicalP,
     #[error("`max_new_tokens` must be strictly positive")]
-    MaxNewTokens,
+    NegativeMaxNewTokens,
+    #[error("`max_new_tokens` must be <= {0}. Given: {1}")]
+    MaxNewTokens(usize, u32),
     #[error("`inputs` tokens + `max_new_tokens` must be <= {0}. Given: {1} `inputs` tokens and {2} `max_new_tokens`")]
     MaxTotalTokens(usize, usize, u32),
     #[error("`inputs` must have less than {0} tokens. Given: {1}")]
@@ -44,7 +44,7 @@ class LlamaRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 6144:
+        if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
@@ -38,7 +38,7 @@ from flash_attn.layers.rotary import RotaryEmbedding
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 6144:
+        if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
@@ -11,7 +11,7 @@ import dropout_layer_norm
 
 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 6144:
+        if hidden_states.shape[-1] > 8192:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states
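All three Python hunks raise the same threshold from 6144 to 8192, presumably because the fused residual-add + normalization kernel supports hidden sizes up to 8192; only wider hidden states take the manual fallback shown in the context lines. A rough sketch of that fallback pattern in plain PyTorch (class and constant names are illustrative, and the fused branch is omitted):

import torch
import torch.nn as nn
import torch.nn.functional as F

FUSED_MAX_HIDDEN = 8192  # threshold raised from 6144 in this commit

class LayerNormWithResidual(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):
        if hidden_states.shape[-1] > FUSED_MAX_HIDDEN:
            # Too wide for the fused kernel: add the residual by hand and
            # fall back to the stock layer norm.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states
            normed = F.layer_norm(
                hidden_states, self.normalized_shape, self.weight, self.bias, self.eps
            )
            return normed, residual
        # A fused residual-add + layer-norm kernel (e.g. dropout_layer_norm)
        # would handle the narrower case here.
        raise NotImplementedError("fused path omitted in this sketch")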
@@ -31,6 +31,13 @@ class Model(ABC):
         token_offset: Optional[int] = None,
     ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        if all_input_ids[-1] in self.all_special_ids:
+            return (
+                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
+                None,
+                None,
+            )
+
         if token_offset is None:
             token_offset = len(all_input_ids) - 3
 
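The new early return short-circuits decoding when the last generated id is a special token: it is decoded on its own with `skip_special_tokens=False`, so its text is not silently dropped, and both offsets are reset to `None`. For illustration, with a Hugging Face tokenizer (the model name is just an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
eos_id = tokenizer.eos_token_id  # 50256, a member of all_special_ids

print(tokenizer.decode(eos_id, skip_special_tokens=True))   # "" -> the token would vanish
print(tokenizer.decode(eos_id, skip_special_tokens=False))  # "<|endoftext|>"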