From e9669a4085a0f031df10585090fe656d4675d492 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 23 May 2023 19:12:12 +0200 Subject: [PATCH] feat(server): do not use device_map auto on single GPU (#362) --- server/text_generation_server/models/causal_lm.py | 5 ++++- server/text_generation_server/models/seq2seq_lm.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 90b1e5ee..ab92feed 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -468,9 +468,12 @@ class CausalLM(Model): model_id, revision=revision, torch_dtype=dtype, - device_map="auto" if torch.cuda.is_available() else None, + device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, load_in_8bit=quantize == "bitsandbytes", ) + if torch.cuda.is_available() and torch.cuda.device_count() == 1: + model = model.cuda() + tokenizer.pad_token_id = ( model.config.pad_token_id if model.config.pad_token_id is not None diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 4f55b22f..f8b404a9 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -518,9 +518,12 @@ class Seq2SeqLM(Model): model_id, revision=revision, torch_dtype=dtype, - device_map="auto" if torch.cuda.is_available() else None, + device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, load_in_8bit=quantize == "bitsandbytes", ) + if torch.cuda.is_available() and torch.cuda.device_count() == 1: + model = model.cuda() + tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" )