post process exllama model

This commit is contained in:
IlyasMoutawwakil 2024-02-01 12:48:17 +01:00
parent 12c1f54525
commit 5766c55b7a

View File

@ -63,20 +63,27 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context): async def Warmup(self, request, context):
if self.quantize == "gptq": if self.quantize in ["gptq", "awq"]:
try: has_exllama_layers = False
# When using GPTQ, Exllama kernels need some global kernels for _, module in self.model.model.named_modules():
# For which we have the final shapes only after the model has loaded if hasattr(module, "QUANT_TYPE"):
# This will allocate those buffers. has_exllama_layers = True
from text_generation_server.utils.layers import ( break
create_exllama_buffers,
set_device,
)
set_device(self.model.device) if has_exllama_layers:
create_exllama_buffers(request.max_prefill_tokens) try:
except ImportError: # When using GPTQ or AWQ, Exllama kernels need some global kernels
pass # For which we have the final shapes only after the model has loaded
# This will allocate those buffers.
from text_generation_server.utils.layers import (
create_exllama_buffers,
set_device,
)
set_device(self.model.device)
create_exllama_buffers(request.max_prefill_tokens)
except ImportError:
pass
if ( if (
self.model.batch_type == IdeficsCausalLMBatch self.model.batch_type == IdeficsCausalLMBatch