Mirror of https://github.com/huggingface/text-generation-inference.git
Directly load GPTBigCode to specified device
This PR loads GPTBigCode directly onto the specified device, avoiding moving the model between devices.
parent c58a0c185b
commit 3e5165c3ed
@@ -51,14 +51,14 @@ class SantaCoder(CausalLM):
                 "pad_token": EOD,
             }
         )

-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            revision=revision,
-            torch_dtype=dtype,
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-        ).to(device)
+        with device:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                revision=revision,
+                torch_dtype=dtype,
+                load_in_8bit=quantize == "bitsandbytes",
+                trust_remote_code=trust_remote_code,
+            )

         super(CausalLM, self).__init__(
             model=model,
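To illustrate the pattern outside the repository, below is a minimal, hypothetical sketch (not the TGI code itself). It assumes PyTorch >= 2.0, where a torch.device can be used as a context manager, and uses a placeholder checkpoint name; the revision and quantization arguments from the diff are omitted for brevity.

# Minimal sketch of the change, assuming PyTorch >= 2.0 and an example
# GPTBigCode checkpoint; adjust model_id and dtype for your setup.
import torch
from transformers import AutoModelForCausalLM

model_id = "bigcode/santacoder"  # example checkpoint, not fixed by the PR
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Before: weights are materialized on CPU first, then copied to the GPU.
# model = AutoModelForCausalLM.from_pretrained(
#     model_id, torch_dtype=dtype, trust_remote_code=True
# ).to(device)

# After: inside the device context, newly created tensors (including the
# model's parameters) default to `device`, so the weights are placed on the
# target device directly instead of being moved there afterwards.
with device:
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=dtype, trust_remote_code=True
    )

print(next(model.parameters()).device)  # e.g. cuda:0

Entering the device context changes the default device for tensor creation, which is what lets from_pretrained build the model on the target device and skip the extra CPU allocation and host-to-device copy that .to(device) incurred.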