Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 20:04:52 +00:00
Directly load GPTBigCode to specified device
This PR loads GPTBigCode directly onto the specified device, avoiding moving the model between devices after loading.
This commit is contained in:
parent c58a0c185b
commit 3e5165c3ed
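For context, a minimal sketch of the loading pattern this change adopts, assuming PyTorch >= 2.0 (where a torch.device can be used as a context manager) and using an illustrative model_id; in the actual code, model_id, revision, dtype, and the quantize flag are parameters of SantaCoder.__init__:

import torch
from transformers import AutoModelForCausalLM

# Illustrative values for the sketch only; the real ones come from the
# arguments of SantaCoder.__init__.
model_id = "bigcode/santacoder"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Entering the device context makes newly created tensors default to
# `device`, so from_pretrained materializes the GPTBigCode weights there
# directly instead of loading on CPU and copying with .to(device).
with device:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        trust_remote_code=True,
    )

The diff below applies the same pattern inside SantaCoder and drops the trailing .to(device) call.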
@@ -51,14 +51,14 @@ class SantaCoder(CausalLM):
                 "pad_token": EOD,
             }
         )
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            revision=revision,
-            torch_dtype=dtype,
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-        ).to(device)
+        with device:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                revision=revision,
+                torch_dtype=dtype,
+                load_in_8bit=quantize == "bitsandbytes",
+                trust_remote_code=trust_remote_code,
+            )
 
         super(CausalLM, self).__init__(
             model=model,
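As a quick sanity check (reusing the hypothetical model and device names from the sketch above), the parameters should already live on the target device after loading:

# The weights were materialized inside the device context, so no extra
# transfer is needed; this should print e.g. "cuda:0" on a GPU machine.
print(next(model.parameters()).device)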