mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Modified fix.
This commit is contained in:
parent
81f234ec61
commit
742199aa0d
@ -48,7 +48,13 @@ class FlashRWSharded(FlashCausalLM):
|
|||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
weights = Weights(
|
||||||
|
filenames,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
process_group=self.process_group,
|
||||||
|
aliases={"transformer.word_embeddings.weight": ["lm_head.weight"]},
|
||||||
|
)
|
||||||
|
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user