Load GPTBigCode directly onto the specified device

This PR loads GPTBigCode directly onto the specified device, avoiding moving the model between devices.
This commit is contained in:
Yang, Bo 2023-07-15 00:32:46 -07:00 committed by GitHub
parent c58a0c185b
commit 3e5165c3ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -51,14 +51,14 @@ class SantaCoder(CausalLM):
"pad_token": EOD,
}
)
with device:
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
).to(device)
)
super(CausalLM, self).__init__(
model=model,