From 63258dfab8ed9f297b20e66c71e02e1f458847d5 Mon Sep 17 00:00:00 2001
From: Nick Hill <nickhill@us.ibm.com>
Date: Tue, 27 Dec 2022 16:08:59 -0800
Subject: [PATCH] fix: Check for device type correctly when determining initial
 padding

There is no torch device type called "gpu" — CUDA devices report
`device.type == "cuda"` — so the previous substring check on `str(device)`
never matched and padding was never applied on GPU. Check `device.type`
directly instead.
---
 server/text_generation/models/causal_lm.py  | 2 +-
 server/text_generation/models/seq2seq_lm.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/text_generation/models/causal_lm.py b/server/text_generation/models/causal_lm.py
index aeecf127..3ad36219 100644
--- a/server/text_generation/models/causal_lm.py
+++ b/server/text_generation/models/causal_lm.py
@@ -65,7 +65,7 @@ class CausalLMBatch:
             )
             all_logprobs.append(None)
 
-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
diff --git a/server/text_generation/models/seq2seq_lm.py b/server/text_generation/models/seq2seq_lm.py
index fc80c60c..4095db92 100644
--- a/server/text_generation/models/seq2seq_lm.py
+++ b/server/text_generation/models/seq2seq_lm.py
@@ -77,7 +77,7 @@ class Seq2SeqLMBatch:
             decoder_logprobs.append(None)
 
         # Tokenize batch
-        pad_to_multiple_of = 8 if "gpu" in str(device) else None
+        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",