diff --git a/proto/generate.proto b/proto/generate.proto
index 02f3b2e84..1f30df383 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -213,6 +213,9 @@ message DecodeResponse {
 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
+    uint32 max_input_length = 2;
+    uint32 max_prefill_tokens = 3;
+    uint32 max_total_tokens = 4;
 }
 
 /// Empty response
diff --git a/router/client/src/client.rs b/router/client/src/client.rs
index 898e2b117..4723d6641 100644
--- a/router/client/src/client.rs
+++ b/router/client/src/client.rs
@@ -145,7 +145,13 @@ impl Client {
             max_tokens: 0,
         };
 
-        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(WarmupRequest {
+            batch: Some(batch),
+            max_input_length,
+            max_prefill_tokens,
+            max_total_tokens,
+        })
+        .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
         Ok(response.max_supported_total_tokens)
     }
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index a65138c91..d5adbd32a 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -19,9 +19,16 @@ from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
 
 
 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
-    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
+    def __init__(
+        self,
+        model: Model,
+        cache: Cache,
+        quantize: Optional[str],
+        server_urls: List[str],
+    ):
         self.cache = cache
         self.model = model
+        self.quantize = quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         if model.device.type == "cuda":
@@ -56,6 +63,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
 
     async def Warmup(self, request, context):
+        if self.quantize == "gptq":
+            try:
+                # When using GPTQ, the Exllama kernels need some global buffers
+                # whose final shapes are only known once the model has loaded.
+                # This will allocate those buffers.
+                from text_generation_server.utils.layers import (
+                    create_exllama_buffers,
+                    set_device,
+                )
+
+                set_device(self.model.device)
+                create_exllama_buffers(request.max_prefill_tokens)
+            except ImportError:
+                pass
+
         if (
             self.model.batch_type == IdeficsCausalLMBatch
         ):  # Hack, i would rather use kwargs in the `from_pb` call
@@ -184,21 +206,6 @@ def serve(
             logger.exception("Error when initializing model")
             raise
 
-        if quantize == "gptq":
-            try:
-                # When using GPTQ, Exllama kernels need some global kernels
-                # For which we have the finale shapes only after the model has loaded
-                # This will allocate those buffers.
-                from text_generation_server.utils.layers import (
-                    create_exllama_buffers,
-                    set_device,
-                )
-
-                set_device(model.device)
-                create_exllama_buffers()
-            except ImportError:
-                pass
-
         server = aio.server(
             interceptors=[
                 ExceptionInterceptor(),
@@ -206,7 +213,7 @@ def serve(
             ]
         )
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
-            TextGenerationService(model, Cache(), server_urls), server
+            TextGenerationService(model, Cache(), quantize, server_urls), server
         )
         SERVICE_NAMES = (
             generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
diff --git a/server/text_generation_server/utils/gptq/exllama.py b/server/text_generation_server/utils/gptq/exllama.py
index 7353afb57..32f817dba 100644
--- a/server/text_generation_server/utils/gptq/exllama.py
+++ b/server/text_generation_server/utils/gptq/exllama.py
@@ -37,19 +37,12 @@ def set_device(device):
     DEVICE = device
 
 
-def create_exllama_buffers():
+def create_exllama_buffers(max_total_tokens: int):
     global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
 
     assert DEVICE is not None, "call set_device first"
 
-    if ACT_ORDER:
-        # TODO: this should be set to rust side `max_total_tokens`, but TGI
-        # does not offer an API to expose this variable to python, as this variable
-        # is handled by the client but it appears the model is initialized by the server.
-        # An alternative could be to initialize the buffers during warmup.
-        # Dummy
-        max_total_tokens = 2048
-    else:
+    if not ACT_ORDER:
         max_total_tokens = 1
 
     # This temp_state buffer is required to reorder X in the act-order case.
diff --git a/server/text_generation_server/utils/gptq/exllamav2.py b/server/text_generation_server/utils/gptq/exllamav2.py
index f820f0d9b..dd41b2697 100644
--- a/server/text_generation_server/utils/gptq/exllamav2.py
+++ b/server/text_generation_server/utils/gptq/exllamav2.py
@@ -101,7 +101,7 @@ def set_device(device):
     DEVICE = device
 
 
-def create_exllama_buffers():
+def create_exllama_buffers(max_total_tokens: int):
     global FIXED_BYTES, LAYERS, DEVICE
 
     temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
@@ -138,17 +138,6 @@ class QuantLinear(nn.Module):
         self.bias = bias if bias is not None else None
         self.group_size = groupsize
 
-        infeatures = self.infeatures
-        outfeatures = self.outfeatures
-        assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
-        assert infeatures % self.group_size == 0
-        assert qzeros.shape == (
-            infeatures // self.group_size,
-            outfeatures // 32 * self.bits,
-        )
-        assert scales.shape == (infeatures // self.group_size, outfeatures)
-        assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
-
         global FIXED_BYTES, LAYERS
         FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
         LAYERS.append(self)
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 67fda5114..a2cca2ea3 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -281,17 +281,17 @@ class Weights:
             else:
                 logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
 
-            if use_exllama:
+            if use_exllama and groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-                g_idx = g_idx - g_idx[0]
             else:
-                # The triton kernel reorders the scales/zero points instead of the weight/activation.
-                # Thus, each rank needs the full qzeros/scales.
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            if use_exllama:
+                g_idx = g_idx - g_idx[0]
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":