fix: fix gpt-q with groupsize = -1 (#1358)

Authored by OlivierDehaene on 2023-12-18 16:07:05 +01:00; committed by Karol Damaszke
parent 5ff9e81952
commit b7299e1b7f
6 changed files with 30 additions and 30 deletions


@@ -213,6 +213,9 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     repeated Batch batches = 1;
+    uint32 max_input_length = 2;
+    uint32 max_prefill_tokens = 3;
+    uint32 max_total_tokens = 4;
 }
 
 /// Empty response
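
These fields let the Rust client pass its configured limits to the Python server at warmup time. A minimal sketch of how the regenerated Python stubs would expose them (the import path is assumed from the server package layout; the values are placeholders, not defaults taken from this commit):

    from text_generation_server.pb import generate_pb2  # assumed location of the generated stubs

    request = generate_pb2.WarmupRequest(
        batches=[],               # warmup batches, built elsewhere
        max_input_length=1024,    # placeholder value
        max_prefill_tokens=4096,  # placeholder value
        max_total_tokens=2048,    # placeholder value
    )
    print(request.max_total_tokens)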


@@ -167,7 +167,12 @@ impl Client {
             );
             num_batches
         ];
-        let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
+        let request = tonic::Request::new(WarmupRequest {
+            batches,
+            max_input_length,
+            max_prefill_tokens,
+            max_total_tokens,
+        }).inject_context();
         let _response = self.stub.warmup(request).await?.into_inner();
     }
 
@@ -188,7 +193,12 @@ impl Client {
             );
             num_batches
         ];
-        let request = tonic::Request::new(WarmupRequest { batches }).inject_context();
+        let request = tonic::Request::new(WarmupRequest {
+            batches,
+            max_input_length,
+            max_prefill_tokens,
+            max_total_tokens,
+        }).inject_context();
         let _response = self.stub.warmup(request).await?.into_inner();
     }
     Ok(None) // No support for maximum total tokens


@@ -21,7 +21,12 @@ from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
-    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
+    def __init__(
+        self,
+        model: Model,
+        cache: Cache,
+        server_urls: List[str],
+    ):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls


@@ -37,19 +37,12 @@ def set_device(device):
     DEVICE = device
 
 
-def create_exllama_buffers():
+def create_exllama_buffers(max_total_tokens: int):
     global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
 
     assert DEVICE is not None, "call set_device first"
 
-    if ACT_ORDER:
-        # TODO: this should be set to rust side `max_total_tokens`, but TGI
-        # does not offer an API to expose this variable to python, as this variable
-        # is handled by the client but it appears the model is initialized by the server.
-        # An alternative could be to initialize the buffers during warmup.
-        # Dummy
-        max_total_tokens = 2048
-    else:
+    if not ACT_ORDER:
         max_total_tokens = 1
 
     # This temp_state buffer is required to reorder X in the act-order case.
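
In the act-order case the kernel reorders the activations X row by row, so the scratch buffer must hold one row per token that can be in flight; the removed TODO papered over the missing value with a hard-coded 2048, while the new signature takes the limit from warmup. A simplified sketch of that sizing (buffer shape and width are illustrative, not the exact ones used by the exllama kernels):

    import torch

    def allocate_temp_state(max_total_tokens: int, act_order: bool, max_inner: int = 4096):
        # Without act-order there is nothing to reorder, so a single dummy row suffices.
        if not act_order:
            max_total_tokens = 1
        # One fp16 row of width max_inner per token that can be in flight.
        return torch.zeros((max_total_tokens, max_inner), dtype=torch.float16)

    # Previously sized with a hard-coded 2048 guess; now the warmup request supplies the value.
    temp_state = allocate_temp_state(max_total_tokens=4096, act_order=True)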


@@ -101,7 +101,7 @@ def set_device(device):
     DEVICE = device
 
 
-def create_exllama_buffers():
+def create_exllama_buffers(max_total_tokens: int):
     global FIXED_BYTES, LAYERS, DEVICE
 
     temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
@@ -138,17 +138,6 @@ class QuantLinear(nn.Module):
         self.bias = bias if bias is not None else None
         self.group_size = groupsize
 
-        infeatures = self.infeatures
-        outfeatures = self.outfeatures
-        assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
-        assert infeatures % self.group_size == 0
-        assert qzeros.shape == (
-            infeatures // self.group_size,
-            outfeatures // 32 * self.bits,
-        )
-        assert scales.shape == (infeatures // self.group_size, outfeatures)
-        assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
-
         global FIXED_BYTES, LAYERS
         FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
         LAYERS.append(self)
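
The deleted asserts derive the expected qzeros/scales shapes from `infeatures // self.group_size`, which only makes sense for a real group size; with groupsize = -1 a single group covers the whole input dimension and the shape checks can never pass. A worked example (values are illustrative):

    infeatures, outfeatures, group_size = 4096, 4096, -1

    expected_scales_rows = infeatures // group_size   # -4096: the row count the old assert demands
    actual_scales_rows = 1                            # one group spanning the whole input dim
    print(expected_scales_rows, actual_scales_rows)   # the old shape assert can never hold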


@@ -281,17 +281,17 @@ class Weights:
             else:
                 logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
 
-            if use_exllama:
+            if use_exllama and groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-                g_idx = g_idx - g_idx[0]
             else:
-                # The triton kernel reorders the scales/zero points instead of the weight/activation.
-                # Thus, each rank needs the full qzeros/scales.
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+            if use_exllama:
+                g_idx = g_idx - g_idx[0]
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
     elif quantize == "awq":
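
This is the core of the fix: with groupsize = -1 the checkpoint stores a single row of qzeros/scales shared by every input feature, so it cannot be row-sliced per tensor-parallel rank (get_sharded along dim 0); only a real per-group layout can be sharded that way. g_idx has one entry per input feature, so it is always sharded and, when the exllama kernel is used, rebased so the local shard starts at group zero. A sketch of that rule (shapes illustrative, tensor-parallel world size of 2 assumed):

    import torch

    infeatures, outfeatures, world_size = 4096, 4096, 2

    def scales_shape(groupsize: int):
        groups = 1 if groupsize == -1 else infeatures // groupsize
        return (groups, outfeatures)

    # groupsize=128 -> (32, 4096): each rank can take its own 16-row slice of scales/qzeros.
    # groupsize=-1  -> (1, 4096): nothing to slice, so every rank loads the full tensors,
    #                  which is what the new get_tensor branch does.
    print(scales_shape(128), scales_shape(-1))

    # g_idx maps each input feature to its group, so it is always sharded along dim 0;
    # with exllama it is rebased so the local shard starts at zero (g_idx - g_idx[0]).
    g_idx = torch.arange(infeatures) // 128
    local = g_idx[infeatures // world_size:]   # shard owned by rank 1
    local = local - local[0]
    print(local[:4])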