fix: error when calling model.cuda() on a load_in_8bit model

mmnga 2023-10-16 11:03:32 +09:00
parent 3af1a11401
commit ac531c8d0a


@@ -511,7 +511,7 @@ class CausalLM(Model):
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
             model = model.cuda()
         if tokenizer.pad_token_id is None:
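
For context: when a model is loaded with load_in_8bit=True, bitsandbytes quantizes the weights and places them on the GPU during loading, and transformers raises an error if .cuda() is then called on the quantized model. The guard added here skips the manual move in that case. Below is a minimal standalone sketch of the same pattern, not the repo's actual code; the model id "gpt2", the device_map="auto" argument, and the local quantize variable are placeholder assumptions mirroring the diff.

    # Minimal sketch of the guarded pattern (assumptions: placeholder model id
    # "gpt2", device_map="auto", and a `quantize` string mirroring the diff).
    import torch
    from transformers import AutoModelForCausalLM

    quantize = "bitsandbytes"

    # With load_in_8bit=True, bitsandbytes places the quantized weights on the
    # GPU at load time, so the model must not be moved again afterwards.
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",
        load_in_8bit=quantize == "bitsandbytes",
        device_map="auto",
    )

    # The commit's guard: only move the model manually when it was not loaded
    # in 8-bit; calling .cuda() on an 8-bit quantized model raises an error.
    if torch.cuda.is_available() and torch.cuda.device_count() == 1 and quantize != "bitsandbytes":
        model = model.cuda()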