Modified fix.

This commit is contained in:
Nicolas Patry 2023-07-04 11:30:59 +02:00
parent 81f234ec61
commit 742199aa0d

View File

@@ -48,7 +48,13 @@ class FlashRWSharded(FlashCausalLM):
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
-        weights = Weights(filenames, device, dtype, process_group=self.process_group)
+        weights = Weights(
+            filenames,
+            device,
+            dtype,
+            process_group=self.process_group,
+            aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
+        )
         config.quantize = quantize