Fixing RW code (it's remote code, so the Arch checking doesn't work to see which weights to keep).
Nicolas Patry 2023-07-10 18:40:09 +00:00
parent b4024edd45
commit d9ed7b9274

@@ -49,7 +49,13 @@ class FlashRWSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
+        )
         config.quantize = quantize
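
For context, the `aliases` argument lets the loader resolve a tensor name that is missing from the safetensors files (here `lm_head.weight`) to a name that is actually stored (the tied `transformer.word_embeddings.weight`), which is needed because the RW model uses remote code and the usual architecture-based weight filtering cannot be applied. Below is a minimal sketch of how such an alias lookup might work; `SimpleWeights`, its `routing` dict, and `get_filename` are illustrative stand-ins, not the actual `Weights` implementation in the repository.

```python
from typing import Dict, List, Optional, Tuple


class SimpleWeights:
    """Illustrative stand-in for a safetensors-backed weight loader with aliases."""

    def __init__(
        self,
        routing: Dict[str, str],
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
        # routing: tensor name -> file that actually contains it
        self.routing = routing
        # aliases: stored tensor name -> other names that should resolve to it
        self.aliases = aliases or {}

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
        # Direct hit: the requested tensor is stored under its own name.
        filename = self.routing.get(tensor_name)
        if filename is not None:
            return filename, tensor_name
        # Fallback: if some stored tensor lists this name as an alias,
        # load that stored tensor instead.
        for stored_name, alias_names in self.aliases.items():
            if tensor_name in alias_names and stored_name in self.routing:
                return self.routing[stored_name], stored_name
        raise RuntimeError(f"weight {tensor_name} not found in any file")


# Usage sketch: the checkpoint only stores the embedding weight, so a request
# for "lm_head.weight" resolves to "transformer.word_embeddings.weight".
weights = SimpleWeights(
    routing={"transformer.word_embeddings.weight": "model-00001.safetensors"},
    aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
)
print(weights.get_filename("lm_head.weight"))
# ('model-00001.safetensors', 'transformer.word_embeddings.weight')
```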