mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
Merge branch 'main' into fix_replaying_requests
This commit is contained in:
commit
bed5634ead
@@ -131,7 +131,7 @@ fn validation_worker(
     }

     // Get the number of tokens in the input
-    match tokenizer.encode(request.inputs.clone(), false) {
+    match tokenizer.encode(request.inputs.clone(), true) {
         Ok(inputs) => {
             let input_length = inputs.len();
@@ -65,7 +65,7 @@ class CausalLMBatch:
        )
        all_logprobs.append(None)

-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
@@ -77,7 +77,7 @@ class Seq2SeqLMBatch:
            decoder_logprobs.append(None)

        # Tokenize batch
-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
Loading…
Reference in New Issue
Block a user