From 042180d88f91d4bc9acd42ae4de3c0236d272de4 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 8 Dec 2022 19:37:37 +0100
Subject: [PATCH] fix(server): Only pad to multiple of 8 on GPUs

---
 server/text_generation/models/causal_lm.py  | 3 ++-
 server/text_generation/models/seq2seq_lm.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index 2a88c781..1cd999f0 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -71,8 +71,9 @@ class CausalLMBatch:
             )
         )
 
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
 
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index 3302138f..cbfe7ccf 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -83,8 +83,9 @@ class Seq2SeqLMBatch:
             )
 
         # Tokenize batch
+        pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
+            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
         ).to(device)
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)
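
Note on the device check: assuming `device` is a standard `torch.device`, `str(device)` evaluates to strings like "cpu", "cuda", or "cuda:0", none of which contain the substring "gpu", so the condition as written would never enable padding; `device.type == "cuda"` is the unambiguous test for an NVIDIA GPU in PyTorch. The sketch below illustrates the conditional-padding pattern this patch introduces, using that stricter check. It is a minimal standalone example, not code from the repository; the "gpt2" checkpoint and the sample inputs are placeholders.

    import torch
    from transformers import AutoTokenizer

    # Placeholder checkpoint for illustration; any tokenizer with a pad
    # token works. gpt2 ships without one, so reuse EOS for padding.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pad the batch to a multiple of 8 only on CUDA devices, where tensor
    # cores prefer dimensions divisible by 8; on CPU the extra pad tokens
    # are pure overhead. device.type is "cuda" or "cpu", never "gpu".
    pad_to_multiple_of = 8 if device.type == "cuda" else None

    tokenized_inputs = tokenizer(
        ["Hello world", "A somewhat longer example input"],
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=pad_to_multiple_of,
    ).to(device)

    # On CUDA the sequence dimension is rounded up to a multiple of 8.
    print(tokenized_inputs["input_ids"].shape)

On a GPU the second dimension of `input_ids` is rounded up from the longest sequence in the batch to the next multiple of 8, which is what lets fp16 matmuls hit the tensor-core fast path; on CPU the batch is padded only to the longest sequence, which is the behavior this patch restores there.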