Nicolas Patry 2024-02-05 14:39:45 +01:00
parent e1dc168188
commit 29a8d5a3a1

@@ -19,12 +19,13 @@ pub struct Validation {
     max_top_n_tokens: u32,
     max_input_length: usize,
     max_total_tokens: usize,
-    batched_dimension: bool,
+    batch_dimension: bool,
     /// Channel to communicate with the background tokenization task
     sender: Option<mpsc::UnboundedSender<TokenizerRequest>>,
 }
 
 impl Validation {
+    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new(
         workers: usize,
         tokenizer: Option<Tokenizer>,
@@ -33,7 +34,7 @@ impl Validation {
         max_top_n_tokens: u32,
         max_input_length: usize,
         max_total_tokens: usize,
-        batched_dimension: bool,
+        batch_dimension: bool,
     ) -> Self {
         // If we have a fast tokenizer
         let sender = if let Some(tokenizer) = tokenizer {
@@ -68,7 +69,7 @@ impl Validation {
             max_top_n_tokens,
             max_input_length,
             max_total_tokens,
-            batched_dimension
+            batch_dimension
         }
     }
 
@@ -106,7 +107,7 @@ impl Validation {
     ) -> Result<(String, usize, u32), ValidationError> {
         // If we have a fast tokenizer
         if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
-            if self.batched_dimension {
+            if self.batch_dimension {
                 let input_length = encoding.len();
 
                 // Get total tokens
@@ -139,7 +140,7 @@ impl Validation {
             return Ok((inputs, input_length, max_new_tokens));
             }
         }
 
-        // Either we don't have a tokenizer or batched_dimension purposefully
+        // Either we don't have a tokenizer or batch_dimension purposefully
         // will ignore the actual length in order to schedule the job correctly.
         // In this case, we don't know the real length in tokens of the inputs
         // However, the inputs will be truncated by the python servers
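
For context, here is a minimal, self-contained sketch of the behaviour the renamed flag gates. The names, the simplified signature, and the `validate_length` helper are hypothetical stand-ins, not the actual router code: when `batch_dimension` is set and a fast tokenizer produced an encoding, the real token count is validated against the budget; otherwise the configured maximum is reported so the job is scheduled conservatively, and the Python servers truncate the inputs downstream.

```rust
/// Hypothetical, simplified stand-in for the router's `Validation` struct;
/// only the fields needed for this sketch are kept.
struct Validation {
    batch_dimension: bool,
    max_input_length: usize,
    max_total_tokens: usize,
}

impl Validation {
    /// Returns `(input_length, max_new_tokens)` on success.
    /// `encoding_len` is `Some(..)` when a fast tokenizer produced an encoding.
    fn validate_length(
        &self,
        encoding_len: Option<usize>,
        max_new_tokens: u32,
    ) -> Result<(usize, u32), String> {
        if let Some(input_length) = encoding_len {
            if self.batch_dimension {
                // Tokenizer available: enforce the real token budget.
                let total_tokens = input_length + max_new_tokens as usize;
                if total_tokens > self.max_total_tokens {
                    return Err(format!(
                        "`inputs` + `max_new_tokens` = {total_tokens} > {}",
                        self.max_total_tokens
                    ));
                }
                return Ok((input_length, max_new_tokens));
            }
        }
        // Either we don't have a tokenizer, or the flag tells us to ignore
        // the actual length for scheduling; the Python servers will truncate
        // the inputs, so report the configured maximum instead.
        Ok((self.max_input_length, max_new_tokens))
    }
}

fn main() {
    let validation = Validation {
        batch_dimension: true,
        max_input_length: 1024,
        max_total_tokens: 2048,
    };
    // Flag set and a tokenized length available: the real length is used.
    assert_eq!(validation.validate_length(Some(100), 64), Ok((100, 64)));
}
```

Note that the commit itself only renames `batched_dimension` to `batch_dimension` and silences clippy's `too_many_arguments` lint on the constructor's long argument list; the control flow sketched above is unchanged by it.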