Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-04-21 23:12:07 +00:00
move exllama buffer init to the top level
This commit is contained in:
parent 4462854e1b
commit 67a46b7361
@@ -467,30 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         self.lm_head = TensorParallelHead.load(
             config, prefix="transformer.wte", weights=weights
         )

+        self.config = config
-        # Buffers need to be persistent to avoid any bug.
-        self.buffers = {}
-        if config.quantize == "gptq":
-            max_dq_buffer_size = 0
-            for name, submodule in self.named_modules():
-                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
-                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
-
-            intermediate_size = config.n_inner
-            max_seq_len = 2048  # TODO: we should be able to set it
-
-            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=weights.device)
-            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=weights.device)
-
-            prepare_buffers(weights.device, self.buffers["temp_state"], self.buffers["temp_dq"])
-
-            # TODO: ability to set them
-            matmul_recons_thd = 8
-            matmul_fused_remap = False
-            matmul_no_half2 = False
-            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
-
-            torch.cuda.empty_cache()
-
     def forward(
         self,
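The hunk above strips the per-model exllama buffer setup out of FlashSantacoderForCausalLM, leaving only `self.config = config` so the shared base class can read `config.quantize` later. As a minimal sketch of the sizing pass being moved, the helper below scans a module tree for packed GPTQ weights; `max_dequant_buffer_size` is a hypothetical name, and it assumes exllama's q4 layout in which each int32 element of `qweight` packs eight 4-bit values, hence the `numel() * 8` fp16 elements after dequantization:

import torch
import torch.nn as nn


def max_dequant_buffer_size(model: nn.Module) -> int:
    """Largest dequantized weight size, in fp16 elements, over all
    4-bit GPTQ linears; mirrors the scan performed in the diff."""
    max_size = 0
    for _, submodule in model.named_modules():
        qweight = getattr(submodule, "qweight", None)
        if isinstance(qweight, torch.Tensor) and qweight.dtype == torch.int32:
            # One packed int32 holds eight 4-bit weights.
            max_size = max(max_size, qweight.numel() * 8)
    return max_size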
@@ -8,8 +8,15 @@ from transformers import PreTrainedTokenizerBase
 from text_generation_server.models.types import Batch, GeneratedText
 from text_generation_server.pb.generate_pb2 import InfoResponse

-B = TypeVar("B", bound=Batch)
+from text_generation_server.utils.gptq.quant_linear import Ex4bitLinear
+from custom_kernels.exllama import prepare_buffers, set_tuning_params
+
+from text_generation_server.utils.layers import (
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear
+)
+
+B = TypeVar("B", bound=Batch)

 class Model(ABC):
     def __init__(
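The import hunk above pulls the GPTQ linear class and the compiled exllama kernels into the base-class module at import time. One design note, as a hedged sketch: `custom_kernels.exllama` is a CUDA extension, so an unconditional import makes this module fail to load on installs without the kernels built. A guarded import such as the following is a common way to keep the module importable; the `HAS_EXLLAMA` flag is illustrative and not part of this commit:

# Illustrative only: the commit itself imports the kernels unconditionally.
try:
    from custom_kernels.exllama import prepare_buffers, set_tuning_params
    HAS_EXLLAMA = True
except ImportError:
    # CPU-only install or kernels not compiled; exllama paths must be skipped.
    prepare_buffers = set_tuning_params = None
    HAS_EXLLAMA = False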
@@ -39,6 +46,30 @@ class Model(ABC):
             is not None
         )

+        if model.config.quantize == "gptq":
+            # Buffers need to be persistent to avoid any bug.
+            self.buffers = {}
+            max_dq_buffer_size = 0
+            for name, submodule in self.model.named_modules():
+                if isinstance(submodule, (TensorParallelColumnLinear, TensorParallelRowLinear)) and isinstance(submodule.linear, Ex4bitLinear):
+                    max_dq_buffer_size = max(max_dq_buffer_size, submodule.linear.qweight.numel() * 8)
+
+            intermediate_size = model.config.n_inner
+            max_seq_len = 2048  # TODO: we should be able to set it
+
+            self.buffers["temp_state"] = torch.zeros((max_seq_len, intermediate_size), dtype=torch.float16, device=device)
+            self.buffers["temp_dq"] = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device)
+
+            prepare_buffers(device, self.buffers["temp_state"], self.buffers["temp_dq"])
+
+            # TODO: ability to set them
+            matmul_recons_thd = 8
+            matmul_fused_remap = False
+            matmul_no_half2 = False
+            set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+            torch.cuda.empty_cache()
+
         self.check_initialized()

     @property
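This hunk is the core of the change: the buffer init now runs once in the `Model` base-class constructor, so every GPTQ model shares a single setup path instead of repeating it per architecture. A self-contained sketch of the allocation step follows; the helper name `init_exllama_buffers` is an assumption for illustration, while the shapes mirror the diff, with `temp_state` sized by sequence length times intermediate size and `temp_dq` by the largest dequantized weight:

import torch


def init_exllama_buffers(
    max_seq_len: int,
    intermediate_size: int,
    max_dq_buffer_size: int,
    device: torch.device,
) -> dict:
    # Held in a dict owned by the model so the tensors stay alive for its
    # whole lifetime ("Buffers need to be persistent to avoid any bug").
    return {
        "temp_state": torch.zeros(
            (max_seq_len, intermediate_size), dtype=torch.float16, device=device
        ),
        "temp_dq": torch.zeros(
            (1, max_dq_buffer_size), dtype=torch.float16, device=device
        ),
    }

In the diff, `prepare_buffers` then hands these tensors to the exllama kernels and `set_tuning_params` fixes `matmul_recons_thd`, `matmul_fused_remap`, and `matmul_no_half2` globally, which is why the TODO notes they should eventually be configurable.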
@@ -2,7 +2,6 @@ from pathlib import Path
 from typing import List, Dict, Optional
 from safetensors import safe_open
 import torch
-from loguru import logger

 class Weights:
     def __init__(