(fix) fp8 scaling for cuda

Mohit Sharma 2024-12-06 14:16:39 +00:00
parent e2454dba40
commit e22cb47fe3
3 changed files with 46 additions and 32 deletions

View File

@@ -116,17 +116,17 @@ def paged_attention(
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
+        tmp_output = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions, head_size),
             dtype=out.dtype,
             device=out.device,
         )
-        exp_sums = torch.empty(
+        exp_sums = torch.zeros(
             size=(num_seqs, num_heads, max_num_partitions),
             dtype=torch.float32,
             device=out.device,
         )
-        max_logits = torch.empty_like(exp_sums)
+        max_logits = torch.zeros_like(exp_sums)
         if not use_custom:
             ops.paged_attention_v2(

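The scratch buffers for the PagedAttention V2 path are now zero-initialized instead of allocated with torch.empty. A plausible reason is that partitions the kernel never writes would otherwise contribute uninitialized memory (possibly NaN/Inf) to the final reduction over partitions. The sketch below illustrates that failure mode with a toy stand-in; reduce_partitions and the shapes are hypothetical, not the real kernel:

import torch

# Toy stand-in for the V2 partition reduction; hypothetical, not the real kernel.
def reduce_partitions(tmp_output: torch.Tensor, exp_sums: torch.Tensor) -> torch.Tensor:
    # Weighted sum over the partition axis (dim=1).
    weights = exp_sums / exp_sums.sum(dim=-1, keepdim=True).clamp(min=1e-6)
    return (tmp_output * weights.unsqueeze(-1)).sum(dim=1)

num_heads, max_num_partitions, head_size = 2, 4, 8
written = 2  # suppose the kernel only fills the first two partitions for this sequence

# With torch.empty the untouched partitions hold arbitrary bytes and can poison the
# reduction; with torch.zeros they contribute nothing.
tmp_output = torch.zeros(num_heads, max_num_partitions, head_size)
exp_sums = torch.zeros(num_heads, max_num_partitions)
tmp_output[:, :written] = torch.randn(num_heads, written, head_size)
exp_sums[:, :written] = torch.rand(num_heads, written)

out = reduce_partitions(tmp_output, exp_sums)  # finite, unaffected by unwritten partitions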
View File

@@ -3,8 +3,13 @@ from typing import List, Optional, Union
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationType
-from text_generation_server.layers.fp8 import Fp8Weight, _load_scalar_or_matrix_scale
+from text_generation_server.layers.fp8 import (
+    Fp8Weight,
+    _load_scalar_or_matrix_scale,
+    requantize_with_max_scale,
+)
 from text_generation_server.utils.weights import Weights, WeightsLoader
+from text_generation_server.utils.import_utils import SYSTEM


 class W8ANFpLoader(WeightsLoader):
@@ -47,11 +52,10 @@ class W8ANFpLoader(WeightsLoader):
         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])

         input_scale = None
         if self.load_input_scale:
@@ -87,7 +91,8 @@ class W8ANFpLoader(WeightsLoader):
                 block_sizes=block_sizes,
                 to_dtype=False,
             )
-            weight_scale = weight_scale.reshape(-1).expand(w.shape[0])
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])

         input_scale = None
         if self.load_input_scale:
@ -127,6 +132,12 @@ class W8ANFpLoader(WeightsLoader):
] ]
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1) weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
if weight_scale.numel() == len(prefixes):
logical_widths = [x[0] for x in shapes]
w, weight_scale = requantize_with_max_scale(
w, weight_scale.to(weights.device), logical_widths
)
input_scale = None input_scale = None
if self.load_input_scale: if self.load_input_scale:
input_scale = [ input_scale = [
@@ -153,11 +164,10 @@ class W8ANFpLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)

         weight_scale = None
         if self.load_weight_scale:
-            weight_scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            )
+            weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                weight_scale = weight_scale.reshape(-1).expand(w.shape[0])

         input_scale = None
         if self.load_input_scale:

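Across W8ANFpLoader, each load path now reads the raw weight_scale tensor and only broadcasts it to one scale per output row when SYSTEM == "cuda"; on ROCm the scale keeps its loaded (typically per-tensor) shape. Below is a hypothetical helper capturing the repeated pattern; the actual code inlines this logic at every call site, and SYSTEM is the import from text_generation_server.utils.import_utils:

import torch

def _normalize_weight_scale(weight_scale: torch.Tensor, out_features: int, system: str) -> torch.Tensor:
    # Hypothetical helper; mirrors the per-call-site logic in the diff.
    if system == "cuda":
        # Broadcast to one scale per output row; a no-op if the scale is already per-row.
        return weight_scale.reshape(-1).expand(out_features)
    # ROCm: keep the scale exactly as loaded (per-tensor scales stay scalar).
    return weight_scale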
View File

@@ -167,11 +167,10 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            ).max()
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@ -206,6 +205,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn: if w.dtype == torch.float8_e4m3fn:
# FP8 branch # FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if scale.numel() > 1: if scale.numel() > 1:
scale = weights.get_packed_sharded( scale = weights.get_packed_sharded(
f"{prefix}.weight_scale", f"{prefix}.weight_scale",
@@ -213,7 +213,8 @@ class HybridFP8UnquantLoader(WeightsLoader):
                     block_sizes=block_sizes,
                     to_dtype=False,
                 )
-            scale = scale.reshape(-1).expand(w.shape[0]).max()
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
@@ -255,15 +256,15 @@ class HybridFP8UnquantLoader(WeightsLoader):
         if w.dtype == torch.float8_e4m3fn:
             scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
-                .max()
-                .unsqueeze(0)
                 for p, shape in zip(prefixes, shapes)
             ]
-            scale = torch.cat(scale).to(weights.device)
-
-            logical_widths = [x[0] for x in shapes]
-
-            w, scale = requantize_with_max_scale(w, scale, logical_widths)
+            scale = torch.cat(scale, dim=0).reshape(-1)
+            if scale.numel() == len(prefixes):
+                logical_widths = [x[0] for x in shapes]
+                w, scale = requantize_with_max_scale(
+                    w, scale.to(weights.device), logical_widths
+                )

             input_scale = [
                 _load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
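When several prefixes are fused into one matrix (e.g. q/k/v or gate/up projections), each shard may carry its own per-tensor scale. The new path concatenates them and, when there is exactly one scale per prefix, calls requantize_with_max_scale so the fused weight ends up with a single shared scale. The sketch below is one plausible shape such a helper could take; it is an assumption for illustration only, and the repository's actual implementation may differ in dtype handling, clamping, and returned shapes:

from typing import List, Tuple
import torch

def requantize_with_max_scale_sketch(
    weight: torch.Tensor,        # fused FP8 weight, shape (sum(logical_widths), in_features)
    weight_scale: torch.Tensor,  # one per-tensor scale per fused slice
    logical_widths: List[int],   # rows contributed by each slice
) -> Tuple[torch.Tensor, torch.Tensor]:
    max_scale = weight_scale.max()
    dequant = torch.empty(weight.shape, dtype=torch.float32, device=weight.device)
    start = 0
    for width, scale in zip(logical_widths, weight_scale):
        end = start + width
        # Dequantize each slice with its own scale...
        dequant[start:end] = weight[start:end].to(torch.float32) * scale
        start = end
    # ...then requantize everything against the shared max scale, clamped to the e4m3 range.
    requantized = (dequant / max_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return requantized, max_scale.reshape(1)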
@@ -293,11 +294,11 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)

         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = (
-                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
-                .reshape(-1)
-                .expand(w.shape[0])
-            ).max()
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if SYSTEM == "cuda":
+                scale = scale.reshape(-1).expand(w.shape[0])

             input_scale = None
             if weights.has_tensor(f"{prefix}.input_scale"):
                 input_scale = (
@@ -479,6 +480,9 @@ class Fp8Linear(torch.nn.Module):
 def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
     scale = weights.get_tensor(prefix, to_dtype=False)
     if scale.numel() > 1:
         scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    elif SYSTEM == "rocm":
+        return scale.reshape(-1)
+
     return scale.reshape(-1).expand(shape[0])
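Finally, _load_scalar_or_matrix_scale keeps the old behaviour for per-channel (matrix) scales and for CUDA, but a scalar scale on ROCm is now returned as a one-element tensor rather than being expanded across the output rows. A simplified mirror of that logic (tensor-parallel sharding omitted), just to make the resulting shapes concrete:

import torch

def load_scale_shapes(scale: torch.Tensor, out_features: int, system: str) -> torch.Tensor:
    # Simplified mirror of _load_scalar_or_matrix_scale; sharding of matrix scales is omitted.
    if scale.numel() > 1:
        return scale.reshape(-1).expand(out_features)  # per-channel scales: unchanged behaviour
    if system == "rocm":
        return scale.reshape(-1)                       # ROCm keeps the per-tensor scalar, shape (1,)
    return scale.reshape(-1).expand(out_features)      # CUDA: one scale per output row

assert load_scale_shapes(torch.tensor(0.02), 4096, "rocm").shape == (1,)
assert load_scale_shapes(torch.tensor(0.02), 4096, "cuda").shape == (4096,)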