fix santacoder sharded

This commit is contained in:
OlivierDehaene 2023-04-19 12:32:25 +02:00
parent b47edb28af
commit 0fc4f99379
2 changed files with 4 additions and 4 deletions

View File

@@ -90,10 +90,10 @@ def get_model(
             raise NotImplementedError(
                 FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
             )
-            return FlashSantacoderSharded(model_id, revision=revision)
+            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize)
+            return santacoder_cls(model_id, revision, quantize=quantize)
     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type

View File

@@ -75,9 +75,9 @@ class FlashLlama(FlashCausalLM):
         dtype: torch.dtype,
     ):
         for filename in filenames:
-            state_dict = torch.load(filename, "cpu")
+            state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(dtype).to(device if not quantize else "cpu")
+                value = value.to(device if not quantize else "cpu").to(dtype)
                 layer_name = ".".join(key.split(".")[:4])