Merge branch 'main' into fix_replaying_requests

This commit is contained in:
Nicolas Patry 2023-01-02 10:55:21 +01:00 committed by GitHub
commit bed5634ead
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 3 deletions

View File

@ -131,7 +131,7 @@ fn validation_worker(
}
// Get the number of tokens in the input
match tokenizer.encode(request.inputs.clone(), false) {
match tokenizer.encode(request.inputs.clone(), true) {
Ok(inputs) => {
let input_length = inputs.len();

View File

@ -65,7 +65,7 @@ class CausalLMBatch:
)
all_logprobs.append(None)
pad_to_multiple_of = 8 if "gpu" in str(device) else None
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",

View File

@ -77,7 +77,7 @@ class Seq2SeqLMBatch:
decoder_logprobs.append(None)
# Tokenize batch
pad_to_multiple_of = 8 if "gpu" in str(device) else None
pad_to_multiple_of = 8 if device.type == "cuda" else None
tokenized_inputs = tokenizer(
inputs,
return_tensors="pt",