fix santacoder sharded

OlivierDehaene 2023-04-19 12:32:25 +02:00
parent b47edb28af
commit 0fc4f99379
2 changed files with 4 additions and 4 deletions

@@ -90,10 +90,10 @@ def get_model(
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision=revision)
+            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize)
+            return santacoder_cls(model_id, revision, quantize=quantize)
 
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type

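Note (reviewer sketch, not part of the commit): before this hunk the sharded path dropped the quantize flag entirely and the non-sharded path passed it positionally. The snippet below is a simplified, self-contained illustration of why the keyword spelling is the safer form when the class is picked at runtime; the class bodies, flag and values are stand-ins, not the real text_generation_server signatures.

    # Simplified stand-ins for illustration only; the real constructors take
    # more parameters.
    FLASH_ATTENTION = True
    model_id, revision, quantize = "bigcode/santacoder", None, True

    class SantaCoder:
        def __init__(self, model_id, revision=None, quantize=False):
            self.quantize = quantize

    class FlashSantacoder(SantaCoder):
        pass

    santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
    # Keyword form: the flag reaches whichever class was picked, and the call
    # stays correct if a constructor later gains another positional parameter.
    model = santacoder_cls(model_id, revision, quantize=quantize)
    assert model.quantize is True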
@@ -75,9 +75,9 @@ class FlashLlama(FlashCausalLM):
         dtype: torch.dtype,
     ):
         for filename in filenames:
-            state_dict = torch.load(filename, "cpu")
+            state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(dtype).to(device if not quantize else "cpu")
+                value = value.to(device if not quantize else "cpu").to(dtype)
                 layer_name = ".".join(key.split(".")[:4])
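
Note (reviewer sketch, not part of the commit): torch.load already accepts map_location as its second positional parameter, so both spellings load the checkpoint onto CPU and the keyword form only makes that explicit. The reordered .to() calls presumably let the dtype cast run on the tensor's final device instead of always on CPU. A small standalone example with assumed values for device, quantize and dtype:

    import torch

    # Assumed example values mirroring the parameters of FlashLlama.load_weights.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    quantize = False
    dtype = torch.float16

    value = torch.zeros(4, 4)  # stands in for one tensor from the checkpoint
    # Move first, then cast: the conversion to `dtype` happens on `device`
    # (or stays on CPU when quantize is set).
    value = value.to(device if not quantize else "cpu").to(dtype)
    assert value.dtype == torch.float16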