mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
Fixing OOM in the non-sharded case.
This commit is contained in:
parent
6193512c4b
commit
f764bc1b52
@ -181,6 +181,7 @@ class TensorParallelHead(SuperLayer):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(config, prefix: str, weights):
|
def load(config, prefix: str, weights):
|
||||||
|
if weights.process_group.size() > 1:
|
||||||
try:
|
try:
|
||||||
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
|
||||||
should_gather = True
|
should_gather = True
|
||||||
@ -189,6 +190,9 @@ class TensorParallelHead(SuperLayer):
|
|||||||
# just load the entire thing.
|
# just load the entire thing.
|
||||||
weight = weights.get_tensor(f"{prefix}.weight")
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
should_gather = False
|
should_gather = False
|
||||||
|
else:
|
||||||
|
weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
|
should_gather = False
|
||||||
|
|
||||||
# GPTQ doesn't quantize heads (nor embeddings)
|
# GPTQ doesn't quantize heads (nor embeddings)
|
||||||
if config.quantize == "gptq":
|
if config.quantize == "gptq":
|
||||||
|
Loading…
Reference in New Issue
Block a user