mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-09 11:24:53 +00:00
fix santacoder sharded
This commit is contained in:
parent
b47edb28af
commit
0fc4f99379
@ -90,10 +90,10 @@ def get_model(
|
||||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
|
||||
)
|
||||
return FlashSantacoderSharded(model_id, revision=revision)
|
||||
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
|
||||
else:
|
||||
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
|
||||
return santacoder_cls(model_id, revision, quantize)
|
||||
return santacoder_cls(model_id, revision, quantize=quantize)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_id, revision=revision)
|
||||
model_type = config.model_type
|
||||
|
@ -75,9 +75,9 @@ class FlashLlama(FlashCausalLM):
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, "cpu")
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
value = value.to(dtype).to(device if not quantize else "cpu")
|
||||
value = value.to(device if not quantize else "cpu").to(dtype)
|
||||
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user