Mirror of https://github.com/huggingface/text-generation-inference.git
Commit 67a46b7361 (parent 4462854e1b): move exllama buffer init to the top level

The GPTQ/exllama scratch-buffer setup that each model previously ran in its own __init__ (deleted below from FlashSantacoderForCausalLM) moves into the shared Model base class, so every GPTQ-quantized model gets the same initialization.
@@ -467,30 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         self.lm_head = TensorParallelHead.load(
             config, prefix="transformer.wte", weights=weights
         )
 
-        # Buffers need to be persistent to avoid any bug.
-        self.buffers = {}
-        if config.quantize == "gptq":
-            max_dq_buffer_size = 0
-            for name, submodule in self.named_modules():
-                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
-                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
-
-            intermediate_size = config.n_inner
-            max_seq_len = 2048  # TODO: we should be able to set it
-
-            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=weights.device)
-            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=weights.device)
-
-            prepare_buffers(weights.device, self.buffers["temp_state"], self.buffers["temp_dq"])
-
-            # TODO: ability to set them
-            matmul_recons_thd = 8
-            matmul_fused_remap = False
-            matmul_no_half2 = False
-            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
-
-            torch.cuda.empty_cache()
         self.config = config
 
     def forward(
         self,
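For context on the sizes above: GPTQ 4-bit packs eight quantized weights into each int32 element of qweight, so a scratch buffer able to hold the largest layer's dequantized fp16 weights needs qweight.numel() * 8 elements, while temp_state holds one fp16 activation row per sequence position. A minimal sketch of that arithmetic, using a hypothetical 4096x11008 layer (the shapes and prints are illustrative, not from this commit):

import torch

# Hypothetical 4-bit GPTQ layer: in_features=4096, out_features=11008.
# qweight packs 8 four-bit values per int32, so it has in_features // 8 rows.
in_features, out_features = 4096, 11008
qweight = torch.zeros((in_features // 8, out_features), dtype=torch.int32)

# Largest fp16 tensor the kernel may dequantize into (element count):
max_dq_buffer_size = qweight.numel() * 8  # == in_features * out_features

# Activation scratch: one fp16 row per sequence position.
max_seq_len, intermediate_size = 2048, out_features

print(f"temp_dq:    {max_dq_buffer_size * 2 / 2**20:.1f} MiB")  # fp16 = 2 bytes
print(f"temp_state: {max_seq_len * intermediate_size * 2 / 2**20:.1f} MiB")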
@@ -8,8 +8,15 @@ from transformers import PreTrainedTokenizerBase
 from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse
 
-B = TypeVar("B", bound=Batch)
+from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
+from custom_kernels.exllama import prepare_buffers, set_tuning_params
 
+from text_generation_server.utils.layers import (
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear
+)
+
+B = TypeVar("B", bound=Batch)
 
 class Model(ABC):
     def __init__(
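One thing to note about this hunk: the new imports are unconditional, so merely importing Model now requires the compiled exllama extension, even when serving a non-GPTQ model. A hedged sketch of one way to soften that, assuming only the two names imported above (the HAS_EXLLAMA flag is an illustrative name, not part of this commit):

# Illustrative guard, not part of this commit.
try:
    from custom_kernels.exllama import prepare_buffers, set_tuning_params

    HAS_EXLLAMA = True
except ImportError:
    HAS_EXLLAMA = False

The gptq branch in __init__ could then raise a clear error, or fall back to a non-exllama path, when HAS_EXLLAMA is False instead of failing at import time.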
@@ -39,6 +46,30 @@ class Model(ABC):
             is not None
         )
 
+        if model.config.quantize == "gptq":
+            # Buffers need to be persistent to avoid any bug.
+            self.buffers = {}
+            max_dq_buffer_size = 0
+            for name, submodule in self.model.named_modules():
+                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
+                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
+
+            intermediate_size = model.config.n_inner
+            max_seq_len = 2048  # TODO: we should be able to set it
+
+            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=device)
+            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
+
+            prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
+
+            # TODO: ability to set them
+            matmul_recons_thd = 8
+            matmul_fused_remap = False
+            matmul_no_half2 = False
+            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+            torch.cuda.empty_cache()
+
         self.check_initialized()
 
     @property
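A caveat worth flagging now that this code lives in the generic Model base class: n_inner is a GPT-2-family config attribute (and may be None, in which case those models default to 4 * hidden size), while Llama-style configs call the same quantity intermediate_size. A minimal fallback sketch, assuming only standard transformers config attribute names (the helper itself is illustrative, not from this commit):

def intermediate_size_of(config) -> int:
    # GPT-2-style configs: n_inner, possibly None (=> 4 * hidden size).
    if getattr(config, "n_inner", None) is not None:
        return config.n_inner
    # Llama-style configs name the MLP width intermediate_size.
    if hasattr(config, "intermediate_size"):
        return config.intermediate_size
    return 4 * config.hidden_size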
@@ -2,7 +2,6 @@ from pathlib import Path
 from typing import List, Dict, Optional
 from safetensors import safe_open
 import torch
-from loguru import logger
 
 class Weights:
     def __init__(