move exllama buffer init to the top level

Felix Marty 2023-07-12 16:09:26 +00:00
parent 4462854e1b
commit 67a46b7361
3 changed files with 33 additions and 26 deletions
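
For orientation before the hunks: the exllama scratch-buffer setup for GPTQ models moves out of FlashSantacoderForCausalLM.__init__ (first hunk below) and into the shared Model base class (second file), so it runs once for any GPTQ-quantized model rather than per architecture. The relocated block is equivalent to the following stand-alone sketch; the helper name and the max_seq_len keyword are illustrative only, and the real code lives inline in Model.__init__ as shown below.

import torch

from custom_kernels.exllama import prepare_buffers, set_tuning_params
from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
from text_generation_server.utils.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)


def init_exllama_buffers(model, config, device, max_seq_len=2048):
    """Illustrative stand-alone version of the block this commit relocates."""
    # Size the dequantization scratch buffer for the largest Ex4bitLinear layer.
    max_dq_buffer_size = 0
    for _, submodule in model.named_modules():
        if isinstance(
            submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)
        ) and isinstance(submodule.linear, Ex4bitLinear):
            max_dq_buffer_size = max(
                max_dq_buffer_size, submodule.linear.qweight.numel() * 8
            )

    # Buffers must stay alive for as long as the kernels may use them.
    buffers = {
        "temp_state": torch.zeros(
            (max_seq_len, config.n_inner), dtype=torch.float16, device=device
        ),
        "temp_dq": torch.zeros(
            (1, max_dq_buffer_size), dtype=torch.float16, device=device
        ),
    }
    prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"])

    # Tuning values copied verbatim from the diff: reconstruction threshold 8,
    # no fused remap, half2 left enabled.
    set_tuning_params(8, False, False)

    torch.cuda.empty_cache()
    return buffers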

@@ -467,30 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
        self.lm_head = TensorParallelHead.load(
            config, prefix="transformer.wte", weights=weights
        )

        # Buffers need to be persistent to avoid any bug.
        self.buffers = {}
        if config.quantize == "gptq":
            max_dq_buffer_size = 0
            for name, submodule in self.named_modules():
                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)

            intermediate_size = config.n_inner
            max_seq_len = 2048  # TODO: we should be able to set it

            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=weights.device)
            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=weights.device)

            prepare_buffers(weights.device, self.buffers["temp_state"], self.buffers["temp_dq"])

            # TODO: ability to set them
            matmul_recons_thd = 8
            matmul_fused_remap = False
            matmul_no_half2 = False
            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

            torch.cuda.empty_cache()

        self.config = config

    def forward(
        self,

@@ -8,8 +8,15 @@ from transformers import PreTrainedTokenizerBase
from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
from custom_kernels.exllama import prepare_buffers, set_tuning_params
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear
)
B = TypeVar("B", bound=Batch)


class Model(ABC):
    def __init__(
@@ -39,6 +46,30 @@ class Model(ABC):
            is not None
        )

        if model.config.quantize == "gptq":
            # Buffers need to be persistent to avoid any bug.
            self.buffers = {}

            max_dq_buffer_size = 0
            for name, submodule in self.model.named_modules():
                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)

            intermediate_size = model.config.n_inner
            max_seq_len = 2048  # TODO: we should be able to set it

            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=device)
            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)

            prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])

            # TODO: ability to set them
            matmul_recons_thd = 8
            matmul_fused_remap = False
            matmul_no_half2 = False
            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)

            torch.cuda.empty_cache()

        self.check_initialized()

    @property
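
A note on how these buffer sizes come about (the reasoning is implicit in the code above): GPTQ's qweight packs eight 4-bit weights into each int32, so fully dequantizing the largest layer needs at most qweight.numel() * 8 fp16 values (temp_dq), while temp_state reserves a max_seq_len x intermediate_size fp16 activation-sized workspace. A quick check with an illustrative 4096x4096 layer (not a value taken from the diff):

# Hypothetical 4096x4096 4-bit layer; fp16 is 2 bytes per element.
in_features, out_features = 4096, 4096
qweight_numel = (in_features // 8) * out_features  # packed int32 words
temp_dq_elems = qweight_numel * 8                   # 16_777_216 fp16 values
print(temp_dq_elems * 2 / 2**20)                    # -> 32.0 MiB of scratch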

@@ -2,7 +2,6 @@ from pathlib import Path
from typing import List, Dict, Optional
from safetensors import safe_open
import torch
from loguru import logger
class Weights:
    def __init__(