Directly load GPTBigCode to specified device

This PR directly load GPTBigCode to specified device, avoiding moving model between devices.
This commit is contained in:
Yang, Bo 2023-07-15 00:32:46 -07:00 committed by GitHub
parent c58a0c185b
commit 3e5165c3ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -51,14 +51,14 @@ class SantaCoder(CausalLM):
"pad_token": EOD, "pad_token": EOD,
} }
) )
with device:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes", load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
).to(device) )
super(CausalLM, self).__init__( super(CausalLM, self).__init__(
model=model, model=model,