Mirror of https://github.com/huggingface/text-generation-inference.git
Mostly straightforward changes to existing code:

* Wrap quantizer parameters in a small wrapper to avoid passing around untyped tuples and needing to repack them as a dict.
* Move scratch space computation to warmup, because we need the maximum input sequence length to avoid allocating huge scratch buffers that OOM (a sketch of this idea follows below).
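The second bullet is the less obvious change. A minimal sketch of the idea, using a hypothetical Exl2Layer and warmup method (not TGI's actual API), might look like this:

import torch


class Exl2Layer:
    """Hypothetical layer holding exl2 weights that needs temporary scratch space."""

    def __init__(self, out_features: int):
        self.out_features = out_features
        self.scratch = None  # deferred: the right size depends on the max input length

    def warmup(self, max_input_len: int, device: torch.device) -> None:
        # Allocate once, sized from the real maximum sequence length,
        # instead of guessing an upper bound at load time and risking OOM.
        self.scratch = torch.empty(
            (max_input_len, self.out_features),
            dtype=torch.float16,
            device=device,
        )


layer = Exl2Layer(out_features=4096)
layer.warmup(max_input_len=1024, device=torch.device("cpu"))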
24 lines · 469 B · Python
import torch

from dataclasses import dataclass


@dataclass
class Exl2Weight:
    """
    Exllama2 exl2 quantized weights.
    """

    q_weight: torch.Tensor
    q_scale: torch.Tensor
    q_invperm: torch.Tensor
    q_scale_max: torch.Tensor
    q_groups: torch.Tensor

    def __post_init__(self):
        # The stored scale maxima are pre-multiplied by 256; undo that here.
        self.q_scale_max /= 256
        # Store the inverse permutation compactly as int16.
        self.q_invperm = self.q_invperm.short()

    @property
    def device(self) -> torch.device:
        return self.q_weight.device
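To illustrate the first bullet of the commit message, here is a small, hedged usage example (with placeholder tensors standing in for real quantizer output) of how the dataclass replaces an untyped tuple:

import torch

# Hypothetical placeholder tensors, not real exl2 quantizer output.
q_weight = torch.zeros((128, 16), dtype=torch.int32)
q_scale = torch.zeros((4, 2), dtype=torch.int32)
q_invperm = torch.arange(128, dtype=torch.int32)
q_scale_max = torch.full((4,), 256.0)
q_groups = torch.zeros((8,), dtype=torch.int16)

# Before the change, these travelled as an untyped tuple that every call
# site had to unpack by position (or repack as a dict).
weight_tuple = (q_weight, q_scale, q_invperm, q_scale_max, q_groups)

# With the wrapper, the fields are named and typed, and normalization
# happens once in __post_init__.
weight = Exl2Weight(
    q_weight=q_weight,
    q_scale=q_scale,
    q_invperm=q_invperm,
    q_scale_max=q_scale_max,
    q_groups=q_groups,
)
print(weight.device)           # cpu
print(weight.q_invperm.dtype)  # torch.int16, converted in __post_init__

A nice side effect of the dataclass is that the normalization (the /256 rescale and the int16 conversion) lives in one place instead of being repeated at every call site.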