diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index dfc42a5e5..3e5340122 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -467,31 +467,8 @@ class FlashSantacoderForCausalLM(nn.Module):
         self.lm_head = TensorParallelHead.load(
             config, prefix="transformer.wte", weights=weights
         )
+        self.config = config
 
-        # Buffers need to be persistent to avoid any bug.
-        self.buffers = {}
-        if config.quantize == "gptq":
-            max_dq_buffer_size = 0
-            for name, submodule in self.named_modules():
-                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
-                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
-
-            intermediate_size = config.n_inner
-            max_seq_len = 2048  # TODO: we should be able to set it
-
-            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=weights.device)
-            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=weights.device)
-
-            prepare_buffers(weights.device, self.buffers["temp_state"], self.buffers["temp_dq"])
-
-            # TODO: ability to set them
-            matmul_recons_thd = 8
-            matmul_fused_remap = False
-            matmul_no_half2 = False
-            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
-
-            torch.cuda.empty_cache()
-
     def forward(
         self,
         input_ids: torch.Tensor,
diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py
index f8460fc2f..95dd54478 100644
--- a/server/text_generation_server/models/model.py
+++ b/server/text_generation_server/models/model.py
@@ -8,8 +8,15 @@ from transformers import PreTrainedTokenizerBase
 from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse
 
-B = TypeVar("B", bound=Batch)
+from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
+from custom_kernels.exllama import prepare_buffers, set_tuning_params
+from text_generation_server.utils.layers import (
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear
+)
+
+B = TypeVar("B", bound=Batch)
 
 
 class Model(ABC):
     def __init__(
@@ -39,6 +46,30 @@ class Model(ABC):
             is not None
         )
 
+        if model.config.quantize == "gptq":
+            # Buffers need to be persistent to avoid any bug.
+            self.buffers = {}
+            max_dq_buffer_size = 0
+            for name, submodule in self.model.named_modules():
+                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
+                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
+
+            intermediate_size = model.config.n_inner
+            max_seq_len = 2048  # TODO: we should be able to set it
+
+            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=device)
+            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
+
+            prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
+
+            # TODO: ability to set them
+            matmul_recons_thd = 8
+            matmul_fused_remap = False
+            matmul_no_half2 = False
+            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+            torch.cuda.empty_cache()
+
         self.check_initialized()
 
     @property
diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index 48a227853..6ad085dc6 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -2,7 +2,6 @@ from pathlib import Path
 from typing import List, Dict, Optional
 from safetensors import safe_open
 import torch
-from loguru import logger
 
 class Weights:
     def __init__(