mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Style.
This commit is contained in:
parent
66195d832c
commit
a352563ee0
@ -181,6 +181,7 @@ class EETQLinear(nn.Module):
|
|||||||
output = output + self.bias if self.bias is not None else output
|
output = output + self.bias if self.bias is not None else output
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
||||||
device = weight.device
|
device = weight.device
|
||||||
# weight, scale = quant_weights(weight, torch.int8, False)
|
# weight, scale = quant_weights(weight, torch.int8, False)
|
||||||
@ -197,6 +198,7 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
|
|||||||
scale = scale.float().reciprocal()
|
scale = scale.float().reciprocal()
|
||||||
return qweight, scale
|
return qweight, scale
|
||||||
|
|
||||||
|
|
||||||
class Fp8Linear(nn.Module):
|
class Fp8Linear(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -206,14 +208,22 @@ class Fp8Linear(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtype = weight.dtype
|
self.dtype = weight.dtype
|
||||||
self.qweight, self.scale = fp8_quantize(weight)
|
self.qweight, self.scale = fp8_quantize(weight)
|
||||||
|
|
||||||
self.bias = bias.cuda(device) if bias is not None else None
|
self.bias = bias.cuda(device) if bias is not None else None
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Run an fp8 scaled matmul: dynamically quantize the activation, then
    multiply it against the pre-quantized weight with torch._scaled_mm.

    Args:
        input: activation tensor; quantized to fp8 on the fly each call.

    Returns:
        The matmul result cast back to ``self.dtype`` (the original weight
        dtype recorded at construction), with ``self.bias`` fused in when
        present.
    """
    # Per-call dynamic quantization of the activation (weights were
    # quantized once, up front, in __init__).
    quantized_input, input_scale = fp8_quantize(input)
    weight_t = self.qweight.t()
    # NOTE(review): this torch version's _scaled_mm returns a 2-tuple;
    # only the first element (the output tensor) is used here.
    result, _ = torch._scaled_mm(
        quantized_input,
        weight_t,
        out_dtype=self.dtype,
        scale_a=input_scale,
        scale_b=self.scale,
        bias=self.bias,
    )
    return result
|
||||||
|
|
||||||
|
|
||||||
class Linear8bitLt(nn.Module):
|
class Linear8bitLt(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user