diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 2a88c781..1cd999f0 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -71,8 +71,9 @@ class CausalLMBatch:
                 )
             )
 
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 3302138f..cbfe7ccf 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -83,8 +83,9 @@ class Seq2SeqLMBatch:
             )
 
         # Tokenize batch
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
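For reference, a minimal standalone sketch of what the conditional pad_to_multiple_of changes in practice (not part of the patch; the "gpt2" checkpoint and the sample inputs are arbitrary choices for illustration). With pad_to_multiple_of=None the batch is padded to the longest sequence, while pad_to_multiple_of=8 rounds the padded length up to the next multiple of 8, which benefits tensor cores on GPU but only wastes work on CPU.

import torch
from transformers import AutoTokenizer

# Arbitrary tokenizer for illustration; gpt2 has no pad token, so reuse EOS.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

inputs = ["Hello world", "A noticeably longer example sentence used for padding"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mirrors the patched logic: only align to multiples of 8 when running on GPU.
pad_to_multiple_of = 8 if device.type == "cuda" else None

tokenized_inputs = tokenizer(
    inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
).to(device)

# On CPU the batch is padded to the longest sequence; on GPU the sequence
# dimension is rounded up to the next multiple of 8.
print(tokenized_inputs["input_ids"].shape)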