Mirror of https://github.com/huggingface/text-generation-inference.git
Deactivating v2 for sharded.
Exllama v2 fails with an illegal memory access on CUDA when sharding is used. It took a long while to try to fix it:
- All tensors are correct (same as v1).
- Changing the scratch size doesn't help.
- The error only occurs for sequence lengths > 50 (so during warmup most of the time).
- Couldn't figure out why this particular number 50, nor change anything to fix the behavior.
This commit is contained in:
parent b041bf15ae
commit a7ed31cf6c
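The gist of the guard this commit adds, as a standalone sketch (the exact module and its import context are assumed here, not shown in the diff):

import os
from loguru import logger  # assumption: the surrounding module already has a logger set up

# Default to exllama v2 unless EXLLAMA_VERSION overrides it.
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"

# When the model is sharded (WORLD_SIZE > 1), fall back to v1, since v2
# currently triggers an illegal memory access on CUDA in that setup.
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
    logger.warning(
        "Disabling exllama v2 and using v1 instead because there are issues when sharding"
    )
    V2 = False

Setting EXLLAMA_VERSION to anything other than "2" (e.g. "1") forces v1 regardless of world size, which follows directly from the environment-variable read above.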
@@ -51,7 +51,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             w["scales"] = w["scales"].half()
 
         # GPTQ with g_idx (act_order)
-        if "g_idx" in w and not (w["g_idx"] == 0).all().item():
+        if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
             w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype=torch.short, device=w["qweight"].device)
             w["q_invperm"] = torch.empty_like(w["q_perm"])
             # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
@@ -113,12 +113,24 @@ class QuantLinear(nn.Module):
         self.maxq = 2 ** self.bits - 1
         self.infeatures = qweight.shape[0] // self.bits * 32
         self.outfeatures = qweight.shape[1]
+        self.padding = -self.outfeatures % 32
+        self.outfeatures = self.outfeatures + self.padding
+
         self.device = qweight.device
         self.qweight = qweight
         self.qzeros = qzeros
         self.scales = scales
         self.g_idx = g_idx
         self.bias = bias if bias is not None else None
+        self.group_size = groupsize
+
+        infeatures = self.infeatures
+        outfeatures = self.outfeatures
+        assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
+        assert infeatures % self.group_size == 0
+        assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits)
+        assert scales.shape == (infeatures // self.group_size, outfeatures)
+        assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
 
         global FIXED_BYTES, LAYERS
         FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
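For concreteness, the new assertions pin down the packed GPTQ shapes. With hypothetical sizes (not taken from the commit) of bits=4, infeatures=4096, outfeatures=11008, group_size=128, they require:

# Hypothetical example values, purely illustrative:
bits, infeatures, outfeatures, group_size = 4, 4096, 11008, 128

# qweight packs 32 // bits = 8 input rows per int32 row.
assert (infeatures // 32 * bits, outfeatures) == (512, 11008)
# qzeros packs 32 // bits = 8 output columns per int32 column, one row per group.
assert (infeatures // group_size, outfeatures // 32 * bits) == (32, 1376)
# scales: one scale per (group, output feature).
assert (infeatures // group_size, outfeatures) == (32, 11008)
# g_idx: one group index per input feature.
assert (infeatures,) == (4096,)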
@@ -35,6 +35,10 @@ except Exception:
 HAS_EXLLAMA = False
 CAN_EXLLAMA = major >= 8
 V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
+if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+    logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding")
+    V2 = False
+
 if os.getenv("DISABLE_EXLLAMA") == "True":
     HAS_EXLLAMA = False
 elif CAN_EXLLAMA:
@@ -281,20 +281,10 @@ class Weights:
                 logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
 
         if use_exllama:
-            if groupsize >= 0:
-                # Exllama reorders the weights in advance and the activations on the fly, thus
-                # the scales and zero-points do not need to be reordered.
-                qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
-                scales = self.get_sharded(f"{prefix}.scales", dim=0)
-            else:
-                qzeros = self.get_tensor(f"{prefix}.qzeros")
-                scales = self.get_tensor(f"{prefix}.scales")
-
-            # For tp > 1, at this point we know we do not use act-order
-            if self.process_group.size() == 1:
-                g_idx = self.get_tensor(f"{prefix}.g_idx")
-            else:
-                g_idx = None
+            qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
+            scales = self.get_sharded(f"{prefix}.scales", dim=0)
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+            g_idx = g_idx - g_idx[0]
         else:
             # The triton kernel reorders the scales/zero points instead of the weight/activation.
             # Thus, each rank needs the full qzeros/scales.
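A note on the "g_idx = g_idx - g_idx[0]" line above: after sharding along the input dimension, a rank's slice of g_idx starts at some non-zero group id, while its sharded qzeros/scales are indexed from zero, so the slice is rebased. A toy illustration with made-up sizes:

import torch

# Pretend setup: 8192 input features, group_size 128 (64 groups), 2 ranks.
# Rank 1 owns rows 4096..8191, i.e. groups 32..63 of the global tensors.
g_idx = torch.arange(4096, 8192) // 128  # this rank's slice of the global g_idx: values 32..63
g_idx = g_idx - g_idx[0]                 # rebase so it indexes the rank-local scales/qzeros: 0..31
assert g_idx.min().item() == 0 and g_idx.max().item() == 31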