| 
									
										
										
										
											2024-07-24 14:33:26 +00:00
										 |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | import torch.nn as nn | 
					
						
							|  |  |  | from loguru import logger | 
					
						
							|  |  |  | from text_generation_server.layers.fp8 import fp8_quantize | 
					
						
							|  |  |  | from text_generation_server.layers.marlin.gptq import _check_valid_shape | 
					
						
							|  |  |  | from text_generation_server.layers.marlin.util import ( | 
					
						
							|  |  |  |     _check_marlin_kernels, | 
					
						
							|  |  |  |     permute_scales, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | from text_generation_server.utils.log import log_once | 
					
						
							|  |  |  | 
 | 
					
						
# `marlin_kernels` is treated as an optional dependency: when the import
# fails, the name is bound to None so that this module can still be imported.
try:
    import marlin_kernels
except ImportError:
    marlin_kernels = None


# Multiplied with the leading dimension of the repacked weight to recover
# `in_features` in `GPTQMarlinFP8Linear.__init__`.
MARLIN_TILE_SIZE = 16
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class GPTQMarlinFP8Linear(nn.Module): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     FP8 GPTQ-Marlin linear layer. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         qweight: torch.Tensor, | 
					
						
							|  |  |  |         scales: torch.Tensor, | 
					
						
							|  |  |  |         bias: Optional[torch.Tensor], | 
					
						
							|  |  |  |     ) -> None: | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         _check_marlin_kernels() | 
					
						
							|  |  |  |         assert marlin_kernels is not None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         scales = scales.unsqueeze(0) | 
					
						
							|  |  |  |         if scales.shape[1] == 1: | 
					
						
							|  |  |  |             out_features, in_features = qweight.shape | 
					
						
							|  |  |  |             scales = scales.repeat(1, out_features) | 
					
						
							|  |  |  |         qweight, scales = repack_fp8_for_marlin(qweight, scales) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         in_features = qweight.shape[0] * MARLIN_TILE_SIZE | 
					
						
							|  |  |  |         out_features = scales.shape[1] | 
					
						
							|  |  |  |         _check_valid_shape(in_features=in_features, out_features=out_features) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.qweight = qweight | 
					
						
							|  |  |  |         self.scales = scales | 
					
						
							|  |  |  |         self.bias = bias if bias is not None else None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.workspace = torch.zeros( | 
					
						
							|  |  |  |             out_features // 64 * 16, dtype=torch.int, device=qweight.device | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							|  |  |  |     def from_unquant(cls, weight, bias, dtype): | 
					
						
							|  |  |  |         qweight, scales = fp8_quantize(weight) | 
					
						
							|  |  |  |         return cls(qweight=qweight, scales=scales.to(dtype), bias=bias) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @classmethod | 
					
						
							| 
									
										
										
										
											2024-10-16 07:54:50 +00:00
										 |  |  |     def from_fp8( | 
					
						
							|  |  |  |         cls, | 
					
						
							|  |  |  |         weight: torch.Tensor, | 
					
						
							|  |  |  |         scale: torch.Tensor, | 
					
						
							|  |  |  |         bias: torch.Tensor, | 
					
						
							|  |  |  |         dtype: torch.dtype, | 
					
						
							|  |  |  |         **kwargs, | 
					
						
							|  |  |  |     ): | 
					
						
							| 
									
										
										
										
											2024-07-24 14:33:26 +00:00
										 |  |  |         return cls(qweight=weight, scales=scale.to(dtype), bias=bias) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, A: torch.Tensor) -> torch.Tensor: | 
					
						
							|  |  |  |         assert marlin_kernels is not None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         A_flat = A.view(-1, A.shape[-1]) | 
					
						
							|  |  |  |         C = marlin_kernels.fp8_marlin_gemm( | 
					
						
							|  |  |  |             A_flat, | 
					
						
							|  |  |  |             self.qweight, | 
					
						
							|  |  |  |             self.scales, | 
					
						
							|  |  |  |             self.workspace, | 
					
						
							|  |  |  |             8, | 
					
						
							|  |  |  |             A_flat.shape[0], | 
					
						
							|  |  |  |             self.scales.shape[1], | 
					
						
							|  |  |  |             A_flat.shape[1], | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if self.bias is not None: | 
					
						
							|  |  |  |             C += self.bias | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return C | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Repack FP8 weights to gptq format (packed int32 elements). | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     assert fp8_tensor.dtype == torch.float8_e4m3fn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if fp8_tensor.shape[0] % 4 != 0: | 
					
						
							|  |  |  |         raise ValueError( | 
					
						
							|  |  |  |             f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}" | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Reshape to prepare for packing | 
					
						
							|  |  |  |     reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Convert fp8 to uint8 (byte) representation | 
					
						
							|  |  |  |     byte_tensor = reshaped.view(torch.uint8) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Pack 4 uint8 values into one int32 | 
					
						
							|  |  |  |     packed = torch.zeros( | 
					
						
							|  |  |  |         fp8_tensor.shape[0] // 4, | 
					
						
							|  |  |  |         fp8_tensor.shape[1], | 
					
						
							|  |  |  |         dtype=torch.int32, | 
					
						
							|  |  |  |         device=fp8_tensor.device, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for i in range(4): | 
					
						
							|  |  |  |         packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return packed | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Repack FP8 tensor for GPTQ-Marlin. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     out_features, in_features = weight.shape | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Torch linear layers weights with shape [out_features, in_features], | 
					
						
							|  |  |  |     # GPTQ-quantized weights use [in_feateres/pack_factor, in_features], | 
					
						
							|  |  |  |     # so transpose before packing. | 
					
						
							|  |  |  |     qweight = pack_fp8_as_int32(weight.t()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     perm = torch.empty(0, dtype=torch.int, device=qweight.device) | 
					
						
							|  |  |  |     repacked = marlin_kernels.gptq_marlin_repack( | 
					
						
							|  |  |  |         qweight, perm, in_features, out_features, 8 | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     scales = permute_scales(scales) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return repacked, scales |