Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-07-23 08:10:18 +00:00)
commit e1e9a18433
parent 50d5a3c11e

    Dummy but working version.
@@ -181,6 +181,22 @@ class EETQLinear(nn.Module):
         output = output + self.bias if self.bias is not None else output
         return output
 
 
+def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
+    device = weight.device
+    # weight, scale = quant_weights(weight, torch.int8, False)
+    finfo = torch.finfo(qdtype)
+    # Calculate the scale as dtype max divided by absmax
+    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
+    # scale and clamp the tensor to bring it to
+    # the representative range of float8 data type
+    # (as default cast is unsaturated)
+    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+    # Return both float8 data and the inverse scale (as float),
+    # as both required as inputs to torch._scaled_mm
+    qweight = qweight.to(qdtype)
+    scale = scale.float().reciprocal()
+    return qweight, scale
+
+
 class Fp8Linear(nn.Module):
     def __init__(
         self,
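The new fp8_quantize helper performs per-tensor dynamic quantization: it computes a single scale that maps the tensor's absmax to the float8 dtype's maximum, saturates (clamps) before casting since the default cast is unsaturated, and returns the quantized tensor together with the inverse scale that torch._scaled_mm expects for dequantization. A minimal round-trip sketch of the same idea, assuming PyTorch >= 2.1 (which provides torch.float8_e4m3fn and basic CPU casts for it); the condensed helper and the w/qw names below are illustrative, not part of the commit:

import torch

# Condensed version of the fp8_quantize helper from the hunk above.
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # Per-tensor scale: dtype max divided by absmax (clamped to avoid div by zero)
    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
    # Saturate before casting, then return the inverse scale for dequantization
    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    return qweight, scale.float().reciprocal()

w = torch.randn(64, 64, dtype=torch.float16)
qw, inv_scale = fp8_quantize(w)
w_hat = qw.to(torch.float16) * inv_scale  # dequantize: qweight * (1 / scale)
print((w - w_hat).abs().max())  # small error expected: e4m3 has only 3 mantissa bits

Note that e4m3fn has no inf and a maximum of 448, so saturating before the cast avoids NaNs that an unsaturated conversion would produce for out-of-range values.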
@@ -188,34 +204,21 @@ class Fp8Linear(nn.Module):
         bias,
     ) -> None:
         super().__init__()
         device = weight.device
-        # weight, scale = quant_weights(weight, torch.int8, False)
-        finfo = torch.finfo(weight.dtype)
-        qdtype = torch.float8_e4m3fn
-        # Calculate the scale as dtype max divided by absmax
-        scale = finfo.max / weight.abs().max().clamp(min=1e-12)
-        # scale and clamp the tensor to bring it to
-        # the representative range of float8 data type
-        # (as default cast is unsaturated)
-        x_scl_sat = (weight * scale).clamp(min=finfo.min, max=finfo.max)
-        # Return both float8 data and the inverse scale (as float),
-        # as both required as inputs to torch._scaled_mm
-        self.qweight = x_scl_sat.to(qdtype).to(device=device)
-        self.scale = scale.float().reciprocal().to(device=device)
+        self.dtype = weight.dtype
+        self.qweight, self.scale = fp8_quantize(weight)
         self.bias = bias.cuda(device) if bias is not None else None
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        finfo = torch.finfo(input.dtype)
-        scale = finfo.max / input.abs().max().clamp(min=1e-12)
-        qinput = (input * scale).clamp(min=finfo.min, max=finfo.max)
-
-        output, _ = torch._scaled_mm(qinput, self.qweight, out_dtype=torch.float16,
-                                     scale_a=scale, scale_b=self.scale)
-        output = output + self.bias if self.bias is not None else output
+        qinput, scale = fp8_quantize(input)
+        seqlen = qinput.shape[0]
+        if seqlen % 16 != 0:
+            missing = 16 - seqlen % 16
+            qinput = F.pad(qinput, (0, 0, 0, missing), "constant", value=0)
+        output, _ = torch._scaled_mm(qinput, self.qweight.t(), out_dtype=self.dtype,
+                                     scale_a=scale, scale_b=self.scale, bias=self.bias)
+        output = output[:seqlen]
         return output
 
 
 class Linear8bitLt(nn.Module):
     def __init__(
         self,
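The rewritten forward quantizes activations dynamically with the same helper and pads the sequence dimension to a multiple of 16 before the matmul: the cuBLASLt FP8 kernels behind torch._scaled_mm require 16-aligned shapes, so zero rows are appended and then sliced off the result. torch._scaled_mm is a private API; in the PyTorch releases this commit targets it returns an (output, amax) tuple, which is why the result is unpacked as `output, _ = ...` (newer releases return the output tensor alone). A standalone sketch of the padded call under those assumptions; fp8_matmul is a hypothetical name, and an FP8-capable GPU (e.g. H100) is required:

import torch
import torch.nn.functional as F

# Hypothetical helper mirroring the padded FP8 matmul in Fp8Linear.forward.
# Assumes a CUDA device with FP8 support and a torch._scaled_mm that returns
# an (output, amax) tuple, as the commit's unpacking implies.
def fp8_matmul(qinput, qweight, scale_a, scale_b, bias=None, out_dtype=torch.float16):
    seqlen = qinput.shape[0]
    if seqlen % 16 != 0:
        # Pad the batch/sequence dimension with zero rows to reach 16 alignment.
        missing = 16 - seqlen % 16
        qinput = F.pad(qinput, (0, 0, 0, missing), "constant", value=0)
    output, _ = torch._scaled_mm(
        qinput,           # (padded_seqlen, in_features), float8
        qweight.t(),      # (in_features, out_features), float8
        out_dtype=out_dtype,
        scale_a=scale_a,  # inverse scale of the activations
        scale_b=scale_b,  # inverse scale of the weights
        bias=bias,
    )
    return output[:seqlen]  # drop the zero-padded rows

Storing the weight row-major and transposing at call time also satisfies torch._scaled_mm's requirement that its second operand be column-major.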