From f764bc1b52b7a8c537fc1b11eeb2ec2a5aff9b44 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 12 Jul 2023 12:46:02 +0000
Subject: [PATCH] Fixing OOM on non sharded.

---
 server/text_generation_server/utils/layers.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index caa7d62d..4f65446e 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -181,12 +181,16 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
-        try:
-            weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-            should_gather = True
-        except AssertionError:
-            # If the vocab size is not divisible by number of shards
-            # just load the entire thing.
+        if weights.process_group.size() > 1:
+            try:
+                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+                should_gather = True
+            except AssertionError:
+                # If the vocab size is not divisible by number of shards
+                # just load the entire thing.
+                weight = weights.get_tensor(f"{prefix}.weight")
+                should_gather = False
+        else:
             weight = weights.get_tensor(f"{prefix}.weight")
             should_gather = False
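
Note (editor sketch, not part of the patch): the key change is that the
sharded-load path is only attempted when there is more than one shard.
Previously, a single-process run still took the `try` branch (any vocab size
is divisible by one shard), so `should_gather` ended up `True` and the
forward pass presumably went through the gather logic and its extra
full-vocab output buffers, hence the OOM. Below is a minimal, self-contained
sketch of the same branching; `FakeWeights` and `load_head_weight` are
hypothetical illustrations, not part of text_generation_server.

import torch


class FakeWeights:
    """Hypothetical stand-in for the server's Weights object (illustration only)."""

    class _ProcessGroup:
        def __init__(self, size: int):
            self._size = size

        def size(self) -> int:
            return self._size

    def __init__(self, vocab_size: int, hidden_size: int, world_size: int):
        self._full = torch.empty(vocab_size, hidden_size)
        self.process_group = self._ProcessGroup(world_size)

    def get_sharded(self, name: str, dim: int) -> torch.Tensor:
        # Mirrors the AssertionError the real get_sharded raises when the
        # dimension does not divide evenly across shards.
        size = self._full.shape[dim]
        world_size = self.process_group.size()
        assert size % world_size == 0, f"{size} not divisible by {world_size}"
        return self._full.chunk(world_size, dim=dim)[0]

    def get_tensor(self, name: str) -> torch.Tensor:
        return self._full


def load_head_weight(weights, prefix: str):
    # Same branching as the patched TensorParallelHead.load: only try the
    # sharded load when there is actually more than one shard; a lone
    # process loads the full tensor and never gathers.
    if weights.process_group.size() > 1:
        try:
            weight = weights.get_sharded(f"{prefix}.weight", dim=0)
            should_gather = True
        except AssertionError:
            # Vocab size not divisible by the number of shards: fall back
            # to loading the whole tensor on every rank.
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False
    else:
        weight = weights.get_tensor(f"{prefix}.weight")
        should_gather = False
    return weight, should_gather


# Single process: full weight, no gather (the case this patch fixes).
w, gather = load_head_weight(FakeWeights(10, 4, world_size=1), "lm_head")
assert w.shape == (10, 4) and gather is False

# Two shards, divisible vocab: sharded weight, gather afterwards.
w, gather = load_head_weight(FakeWeights(10, 4, world_size=2), "lm_head")
assert w.shape == (5, 4) and gather is True

# Three shards, non-divisible vocab: fall back to the full tensor.
w, gather = load_head_weight(FakeWeights(10, 4, world_size=3), "lm_head")
assert w.shape == (10, 4) and gather is False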