mirror of https://github.com/huggingface/text-generation-inference.git
commit 6d8d5b6d1d (parent 20e0117e7c)

    fmt
@@ -285,11 +285,15 @@ mod tests {
 
     pub(crate) async fn get_tokenizer() -> Tokenizer {
         if !std::path::Path::new("tokenizer.json").exists() {
-            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json").await.unwrap().bytes().await.unwrap();
+            let content = reqwest::get("https://huggingface.co/gpt2/raw/main/tokenizer.json")
+                .await
+                .unwrap()
+                .bytes()
+                .await
+                .unwrap();
             let mut file = std::fs::File::create("tokenizer.json").unwrap();
             file.write_all(&content).unwrap();
         }
         Tokenizer::from_file("tokenizer.json").unwrap()
     }
 }
-
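For orientation, a minimal usage sketch of the `get_tokenizer` helper reformatted above (a hypothetical test, not part of this commit; it assumes the `tokenizers` crate's `Tokenizer::encode` API, which is the type the helper returns):

#[tokio::test]
async fn downloads_and_loads_gpt2_tokenizer() {
    // The first call downloads tokenizer.json into the working directory;
    // later calls reuse the cached file.
    let tokenizer = get_tokenizer().await;

    // Encoding a short string should yield a non-empty id sequence.
    let encoding = tokenizer.encode("Hello world", false).unwrap();
    assert!(!encoding.get_ids().is_empty());
}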
@@ -141,7 +141,6 @@ impl State {
 
     // Get the next batch
    fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
-
        if self.entries.is_empty() {
            return None;
        }
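The hunk above only drops a blank line inside `next_batch`, but the signature it shows is the interesting part: the queue hands out a batch subject to an optional minimum size and a token budget. A toy sketch of that kind of selection loop (stand-in types and made-up logic, not the router's actual implementation):

// Toy stand-ins, not the real router structs: each pending entry carries an
// estimate of how many tokens it will need.
struct Entry {
    tokens: u32,
}

struct State {
    entries: Vec<Entry>,
}

impl State {
    // Greedy selection under a token budget, mirroring only the shape of the
    // signature shown in the diff: bail out when the queue is empty or the
    // minimum batch size cannot be met.
    fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<Vec<Entry>> {
        if self.entries.is_empty() {
            return None;
        }
        let mut batch = Vec::new();
        let mut used = 0u32;
        while let Some(entry) = self.entries.first() {
            let needed = entry.tokens;
            if used + needed > token_budget {
                break;
            }
            used += needed;
            batch.push(self.entries.remove(0));
        }
        if batch.len() < min_size.unwrap_or(0) {
            // Not enough work to justify a batch: put the entries back in order.
            while let Some(entry) = batch.pop() {
                self.entries.insert(0, entry);
            }
            return None;
        }
        Some(batch)
    }
}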
@@ -741,4 +741,3 @@ impl From<InferError> for Event {
             .unwrap()
     }
 }
-
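The hunk above only trims a trailing blank line; for context, the surrounding pattern is an error-to-SSE-event conversion. A sketch with a stand-in error type, assuming the `Event` here is axum's `axum::response::sse::Event` (the framework the router is built on) and its `json_data` constructor:

use axum::response::sse::Event;
use serde::Serialize;

// Stand-in types for illustration; the real InferError and error payload
// live in the router crate and carry more fields than this.
#[derive(Debug)]
struct DemoError(String);

#[derive(Serialize)]
struct ErrorPayload {
    error: String,
}

impl From<DemoError> for Event {
    fn from(err: DemoError) -> Self {
        // Serialize the error as the JSON body of a server-sent event,
        // mirroring the `.json_data(...).unwrap()` chain the hunk ends on.
        Event::default()
            .json_data(ErrorPayload { error: err.0 })
            .unwrap()
    }
}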
@@ -447,19 +447,28 @@ mod tests {
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters {
-                best_of: Some(2),
-                do_sample: false,
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    best_of: Some(2),
+                    do_sample: false,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::BestOfSampling) => (),
-            _ => panic!("Unexpected not best of sampling")
+            _ => panic!("Unexpected not best of sampling"),
         }
-
     }
 
     #[tokio::test]
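The test in this hunk pins down one validation rule: `best_of > 1` only makes sense when sampling, since greedy decoding would produce identical candidates, so a request with `best_of: Some(2)` and `do_sample: false` must be rejected. A self-contained sketch of that rule (hypothetical code, not the router's actual validation logic):

#[derive(Debug, PartialEq)]
enum ValidationError {
    BestOfSampling,
}

// best_of generates several candidates and keeps the best one; without
// sampling every candidate would be the same, so the combination is rejected.
fn check_best_of(best_of: Option<usize>, do_sample: bool) -> Result<usize, ValidationError> {
    match best_of {
        Some(n) if n > 1 && !do_sample => Err(ValidationError::BestOfSampling),
        Some(n) => Ok(n),
        None => Ok(1),
    }
}

fn main() {
    assert_eq!(check_best_of(Some(2), false), Err(ValidationError::BestOfSampling));
    assert_eq!(check_best_of(Some(2), true), Ok(2));
    assert_eq!(check_best_of(None, false), Ok(1));
}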
@@ -470,41 +479,55 @@ mod tests {
         let max_input_length = 4;
         let max_total_tokens = 5;
         let workers = 1;
-        let validation = Validation::new(workers, tokenizer, max_best_of, max_stop_sequence, max_input_length, max_total_tokens);
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters {
-                top_p: Some(1.0),
-                ..default_parameters()
-            }
-        }).await{
+        let validation = Validation::new(
+            workers,
+            tokenizer,
+            max_best_of,
+            max_stop_sequence,
+            max_input_length,
+            max_total_tokens,
+        );
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(1.0),
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Err(ValidationError::TopP) => (),
-            _ => panic!("Unexpected top_p")
+            _ => panic!("Unexpected top_p"),
         }
 
-        match validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters {
-                top_p: Some(0.99),
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await{
+        match validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: Some(0.99),
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+        {
             Ok(_) => (),
-            _ => panic!("Unexpected top_p error")
+            _ => panic!("Unexpected top_p error"),
         }
 
-        let valid_request = validation.validate(GenerateRequest{
-            inputs: "Hello".to_string(),
-            parameters: GenerateParameters {
-                top_p: None,
-                max_new_tokens: 1,
-                ..default_parameters()
-            }
-        }).await.unwrap();
+        let valid_request = validation
+            .validate(GenerateRequest {
+                inputs: "Hello".to_string(),
+                parameters: GenerateParameters {
+                    top_p: None,
+                    max_new_tokens: 1,
+                    ..default_parameters()
+                },
+            })
+            .await
+            .unwrap();
         // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
         assert_eq!(valid_request.parameters.top_p, 1.0);
-
-
     }
 }
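The last test in this hunk encodes a rule worth spelling out: an explicit `top_p: Some(1.0)` is rejected, `Some(0.99)` is accepted, and `None` resolves to an internal default of `1.0`, exactly as the comment in the diff says. A self-contained sketch of that resolution rule (hypothetical code, not the router's actual implementation):

#[derive(Debug, PartialEq)]
enum ValidationError {
    TopP,
}

// top_p must lie strictly inside (0, 1) when given; a missing value falls
// back to 1.0, which leaves nucleus filtering effectively disabled.
fn resolve_top_p(requested: Option<f32>) -> Result<f32, ValidationError> {
    match requested {
        None => Ok(1.0),
        Some(p) if p > 0.0 && p < 1.0 => Ok(p),
        Some(_) => Err(ValidationError::TopP),
    }
}

fn main() {
    assert_eq!(resolve_top_p(Some(1.0)), Err(ValidationError::TopP));
    assert_eq!(resolve_top_p(Some(0.99)), Ok(0.99));
    assert_eq!(resolve_top_p(None), Ok(1.0));
}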