From a7ed31cf6c9d5835ef8ca9979a47def445d74b8e Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Sat, 25 Nov 2023 11:30:20 +0000
Subject: [PATCH] Deactivating v2 for sharded.

Exllama v2 fails with an illegal memory access on CUDA when sharding is
used. Took a long while to try and fix it:

- All tensors are correct (same as v1).
- Changing the scratch size doesn't help.
- The error only occurs for sequence lengths > 50, so it shows up during
  warmup most of the time.

Couldn't figure out why that particular number, nor find any change that
fixes the behavior. As a workaround, exllama v2 is disabled whenever more
than one shard is used, and v1 is used instead.
---
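Note for reviewers (kept out of the commit message): a minimal sketch of why the
sharded g_idx is re-based with `g_idx = g_idx - g_idx[0]` in the weights.py hunk
below. The sizes, world size and rank are made-up example values, not anything
taken from this patch.

    # Illustration: assume each rank holds a contiguous, group-aligned slice of
    # g_idx along dim 0, so its entries start at some non-zero group id.
    # Subtracting the first entry re-bases them so they index the rank-local
    # scales/qzeros shards. Hypothetical sizes below.
    import torch

    infeatures, groupsize, world_size, rank = 128, 32, 2, 1
    g_idx = torch.arange(infeatures, dtype=torch.int32) // groupsize  # [0, 0, ..., 3, 3]

    shard = infeatures // world_size
    g_idx_shard = g_idx[rank * shard : (rank + 1) * shard]  # roughly what get_sharded(..., dim=0) yields
    g_idx_local = g_idx_shard - g_idx_shard[0]              # the re-basing done in the patch
    assert g_idx_local[0] == 0 and g_idx_local[-1] == shard // groupsize - 1

As the removed comment in the weights.py hunk notes, act-order is not used at
this point for tp > 1, so the slice is monotone and the subtraction is enough to
line the groups up with the sharded scales/qzeros.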
 .../utils/gptq/exllamav2.py                    | 14 +++++++++++++-
 server/text_generation_server/utils/layers.py  |  4 ++++
 server/text_generation_server/utils/weights.py | 18 ++++--------------
 3 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py
index 25d90e97..1945338b 100644
--- a/server/text_generation_server/utils/gptq/exllamav2.py
+++ b/server/text_generation_server/utils/gptq/exllamav2.py
@@ -51,7 +51,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             w["scales"] = w["scales"].half()
 
         # GPTQ with g_idx (act_order)
-        if "g_idx" in w and not (w["g_idx"] == 0).all().item():
+        if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
             w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
             w["q_invperm"] = torch.empty_like(w["q_perm"])
             # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
@@ -113,12 +113,24 @@ class QuantLinear(nn.Module):
         self.maxq = 2 ** self.bits - 1
         self.infeatures = qweight.shape[0] // self.bits * 32
         self.outfeatures = qweight.shape[1]
+        self.padding = - self.outfeatures % 32
+        self.outfeatures = self.outfeatures + self.padding
+
         self.device = qweight.device
         self.qweight = qweight
         self.qzeros = qzeros
         self.scales = scales
         self.g_idx = g_idx
         self.bias = bias if bias is not None else None
+        self.group_size = groupsize
+
+        infeatures = self.infeatures
+        outfeatures = self.outfeatures
+        assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
+        assert infeatures % self.group_size == 0
+        assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits)
+        assert scales.shape == (infeatures // self.group_size, outfeatures)
+        assert g_idx.shape == (infeatures, ), f"{g_idx.shape}, {infeatures}"
 
         global FIXED_BYTES, LAYERS
         FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 13bd422a..e6a90116 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -35,6 +35,10 @@ except Exception:
 HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
+if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+    logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding")
+    V2 = False
+
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
 elif CAN_EXLLAMA:
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index f03892ba..f3344988 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -281,20 +281,10 @@ class Weights:
                     logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
 
             if use_exllama:
-                if groupsize >= 0:
-                    # Exllama reorders the weights in advance and the activations on the fly, thus
-                    # the scales and zero-points do not need to be reordered.
-                    qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-                    scales = self.get_sharded(f"{prefix}.scales", dim=0)
-                else:
-                    qzeros = self.get_tensor(f"{prefix}.qzeros")
-                    scales = self.get_tensor(f"{prefix}.scales")
-
-                # For tp > 1, at this point we know we do not use act-order
-                if self.process_group.size() == 1:
-                    g_idx = self.get_tensor(f"{prefix}.g_idx")
-                else:
-                    g_idx = None
+                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
+                scales = self.get_sharded(f"{prefix}.scales", dim=0)
+                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+                g_idx = g_idx - g_idx[0]
             else:
                 # The triton kernel reorders the scales/zero points instead of the weight/activation.
                 # Thus, each rank needs the full qzeros/scales.
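
Note (appended for the reader, outside the patch itself): a minimal standalone
sketch of how the environment gating added in layers.py selects the kernel
version. The loguru import and the final HAS_EXLLAMA assignment are assumptions
standing in for what the surrounding module does on a successful kernel import;
only the EXLLAMA_VERSION / WORLD_SIZE / DISABLE_EXLLAMA logic comes from the diff.

    # Sketch only: EXLLAMA_VERSION defaults to "2", but v2 is forced back to v1
    # when more than one shard (WORLD_SIZE > 1) is used, and DISABLE_EXLLAMA="True"
    # turns the exllama kernels off entirely.
    import os
    from loguru import logger

    CAN_EXLLAMA = True  # stands in for `major >= 8` (Ampere or newer)

    V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
    if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
        logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding")
        V2 = False

    HAS_EXLLAMA = False
    if os.getenv("DISABLE_EXLLAMA") == "True":
        HAS_EXLLAMA = False
    elif CAN_EXLLAMA:
        HAS_EXLLAMA = "2" if V2 else "1"  # assumed stand-in for the kernel import succeeding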