Black + ruff + T5 w0 quant.

Author: Ubuntu, 2023-05-24 09:35:29 +00:00 (committed by Nicolas Patry)
Parent: 15bf3d4944
Commit: 2362a80a4f
11 changed files with 137 additions and 71 deletions


@@ -1,15 +1,14 @@
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 setup(
-    name='custom_kernels',
+    name="custom_kernels",
     ext_modules=[
         CUDAExtension(
             name="custom_kernels.fused_bloom_attention_cuda",
-            sources=['custom_kernels/fused_bloom_attention_cuda.cu'],
+            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
             extra_compile_args=["-arch=compute_80", "-std=c++17"],
         )
     ],
-    cmdclass={
-        'build_ext': BuildExtension
-    }
+    cmdclass={"build_ext": BuildExtension},
 )


@@ -37,7 +37,6 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     TensorParallelHead,
-    FastLinear
 )
 
 CUSTOM_KERNELS_ENABLED = False


@@ -60,12 +60,18 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
     weight = weights.get_sharded(f"{prefix}.weight", dim=0)
     bias = weights.get_sharded(f"{prefix}.bias", dim=0)
 
-    weight = weight.view(
-        num_heads, 3, head_size, hidden_size,
-    ).permute(1, 0, 2, 3).reshape(-1, hidden_size)
+    weight = (
+        weight.view(
+            num_heads,
+            3,
+            head_size,
+            hidden_size,
+        )
+        .permute(1, 0, 2, 3)
+        .reshape(-1, hidden_size)
+    )
     bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
 
     linear = get_linear(weight, bias, config.quantize)
     if config.use_parallel_residual:
         return linear
@@ -88,12 +94,18 @@ class FlashNeoxAttention(torch.nn.Module):
         rotary_ndims = int(self.head_size * rotary_pct)
         self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
-        self.rotary_emb.inv_freq = nn.Parameter(weights.get_tensor(f"{prefix}.rotary_emb.inv_freq"))
+        self.rotary_emb.inv_freq = nn.Parameter(
+            weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
+        )
         self.softmax_scale = self.head_size ** (-0.5)
 
         self.query_key_value = load_qkv(
-            config, prefix=f"{prefix}.query_key_value", weights=weights,
-            num_heads = self.num_heads, head_size = self.head_size, hidden_size = self.hidden_size
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            num_heads=self.num_heads,
+            head_size=self.head_size,
+            hidden_size=self.hidden_size,
         )
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=True


@@ -3,7 +3,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional, List
+from typing import Optional
 
 # Flash attention imports
 import flash_attn_cuda
@@ -17,8 +17,9 @@ from text_generation_server.utils.layers import (
 )
 
-def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size):
+def load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
@@ -55,30 +56,35 @@ def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_head
         if config.transpose:
             w = [
                 weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
-                weights.get_tensor(f"{prefix}.kv_attn.weight").T
+                weights.get_tensor(f"{prefix}.kv_attn.weight").T,
             ]
             weight = torch.cat(w, dim=0)
         else:
             w = [
                 weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
-                weights.get_tensor(f"{prefix}.kv_attn.weight")
+                weights.get_tensor(f"{prefix}.kv_attn.weight"),
             ]
             weight = torch.cat(w, dim=1)
 
         if bias:
             b = [
                 weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
-                weights.get_tensor(f"{prefix}.kv_attn.bias")
+                weights.get_tensor(f"{prefix}.kv_attn.bias"),
             ]
             bias = torch.cat(b, dim=0)
         else:
             bias = None
 
     weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-    assert list(weight.shape) == [(num_heads + 2) * head_size, hidden_size], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
+    assert list(weight.shape) == [
+        (num_heads + 2) * head_size,
+        hidden_size,
+    ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
     if bias is not None:
         bias = bias.to(dtype=weights.dtype).to(device=weights.device)
-        assert list(bias.shape) == [(num_heads + 2) * head_size], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
+        assert list(bias.shape) == [
+            (num_heads + 2) * head_size
+        ], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
     return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
@@ -106,7 +112,9 @@ def load_row(config, prefix: str, weights, bias: bool):
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
-    return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group)
+    return TensorParallelRowLinear(
+        get_linear(weight, bias, config.quantize), process_group=weights.process_group
+    )
 
 class FlashMQAttention(torch.nn.Module):
@@ -131,7 +139,7 @@ class FlashMQAttention(torch.nn.Module):
             bias=True,
             head_size=self.head_size,
             hidden_size=hidden_size,
-            num_heads=self.num_heads
+            num_heads=self.num_heads,
         )
         self.c_proj = load_row(
             config, prefix=f"{prefix}.c_proj", weights=weights, bias=True


@@ -109,9 +109,21 @@ class T5DenseActDense(nn.Module):
         self.wi = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.wi", weights=weights, bias=False
         )
+        ### XXX: T5 models do not handle well both f16 and quantization.
+        ### Overidding specifically this layer for that reason.
+        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
+        ### https://github.com/huggingface/transformers/issues/20287
+        _q = config.quantize
+        _dtype = weights.dtype
+        weights.dtype = torch.float32
+        config.quantize = None
+        self.wo_cast = (torch.float32, _dtype)
         self.wo = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.wo", weights=weights, bias=False
         )
+        weights.dtype = _dtype
+        config.quantize = _q
+
         self.dropout = nn.Dropout(config.dropout_rate)
         self.act = (
@@ -124,7 +136,10 @@
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
+        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
+        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
@@ -137,9 +152,20 @@ class T5DenseGatedActDense(nn.Module):
         self.wi_1 = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.wi_1", weights=weights, bias=False
         )
+        ### XXX: T5 models do not handle well both f16 and quantization.
+        ### Overidding specifically this layer for that reason.
+        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
+        ### https://github.com/huggingface/transformers/issues/20287
+        _q = config.quantize
+        _dtype = weights.dtype
+        weights.dtype = torch.float32
+        config.quantize = None
+        self.wo_cast = (torch.float32, _dtype)
         self.wo = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.wo", weights=weights, bias=False
         )
+        weights.dtype = _dtype
+        config.quantize = _q
 
         self.dropout = nn.Dropout(config.dropout_rate)
         self.act = (
@@ -154,18 +180,9 @@
         hidden_states = hidden_gelu * hidden_linear
         hidden_states = self.dropout(hidden_states)
-        # TODO Support this again mayber
-        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
-        # See https://github.com/huggingface/transformers/issues/20287
-        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
-        # if (
-        #     isinstance(self.wo.weight, torch.Tensor)
-        #     and hidden_states.dtype != self.wo.weight.dtype
-        #     and self.wo.weight.dtype != torch.int8
-        # ):
-        #     hidden_states = hidden_states.to(self.wo.weight.dtype)
+        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
+        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states
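The T5 hunks above work around the fact that T5 checkpoints misbehave when wo is run in float16 or quantized: weights.dtype and config.quantize are overridden only while loading wo, so that single layer stays in float32, and self.wo_cast records the dtypes used to up-cast activations before the matmul and cast them back afterwards. A minimal sketch of the same pattern, with a plain nn.Linear standing in for TensorParallelRowLinear (class and parameter names here are illustrative, not the repo's API):

import torch
from torch import nn


class DenseWithFp32Out(nn.Module):
    # Keep the sensitive output projection in float32 while the rest of the
    # block runs in a lower-precision dtype, mirroring the wo override above.
    def __init__(self, d_model: int, d_ff: int, model_dtype=torch.bfloat16):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False).to(model_dtype)
        self.wo = nn.Linear(d_ff, d_model, bias=False).to(torch.float32)
        # (cast before wo, cast after wo), like self.wo_cast in the diff.
        self.wo_cast = (torch.float32, model_dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = torch.relu(self.wi(hidden_states))
        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
        hidden_states = self.wo(hidden_states)
        return hidden_states.to(dtype=self.wo_cast[1])


x = torch.randn(1, 8, dtype=torch.bfloat16)
print(DenseWithFp32Out(d_model=8, d_ff=16)(x).dtype)  # torch.bfloat16

(bfloat16 is used here only so the sketch also runs on CPU; the server would typically run in float16 on GPU.)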


@@ -26,7 +26,7 @@ HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
     from bitsandbytes.nn import Int8Params
-except Exception as e:
+except Exception:
     HAS_BITS_AND_BYTES = False


@@ -40,10 +40,11 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
-        )
-        config = AutoConfig.from_pretrained(model_id, revision=revision,
+        config = AutoConfig.from_pretrained(
+            model_id,
+            revision=revision,
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
         tokenizer.pad_token_id = config.pad_token_id


@@ -3,6 +3,7 @@ import torch
 from datetime import timedelta
 
+
 class FakeBarrier:
     def wait(self):
         pass
@@ -17,7 +18,9 @@ class FakeGroup:
         return FakeBarrier()
 
     def allgather(self, inputs, local_tensor, **kwargs):
-        assert len(inputs[0]) == len(local_tensor) == 1, f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
+        assert (
+            len(inputs[0]) == len(local_tensor) == 1
+        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
         for input_ in inputs:
             input_[0].data = local_tensor[0].data
         return FakeBarrier()


@@ -10,8 +10,7 @@ from huggingface_hub import HfApi, hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import (
     LocalEntryNotFoundError,
-    EntryNotFoundError,
-    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
+    EntryNotFoundError,  # Import here to ease try/except in other part of the lib
 )
 
 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)


@@ -2,14 +2,14 @@ import torch
 from torch import nn
 from torch.nn import functional as F
-from typing import Optional, List
+from typing import List
 
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
     from bitsandbytes.nn import Int8Params
-except ImportError as e:
+except ImportError:
     HAS_BITS_AND_BYTES = False
 
 from accelerate import init_empty_weights
@@ -27,14 +27,16 @@ def load_layer_norm(cls, prefix, weights, eps):
     ln.bias = nn.Parameter(bias)
     return ln
 
 torch.nn.LayerNorm.load = load_layer_norm
 
+
 class FastLinear(nn.Module):
     def __init__(
         self,
-        weight, bias,
+        weight,
+        bias,
     ) -> None:
         super().__init__()
         self.weight = nn.Parameter(weight)
         if bias is not None:
@@ -56,10 +58,19 @@
 class Linear8bitLt(nn.Module):
-    def __init__(self, weight, bias, has_fp16_weights=True,
-                 memory_efficient_backward=False, threshold=0.0, index=None):
+    def __init__(
+        self,
+        weight,
+        bias,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
         super().__init__()
-        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
+        assert (
+            not memory_efficient_backward
+        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index
@@ -70,7 +81,11 @@ class Linear8bitLt(nn.Module):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
-        self.weight = Int8Params(weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self.weight = Int8Params(
+            weight.data,
+            has_fp16_weights=has_fp16_weights,
+            requires_grad=has_fp16_weights,
+        )
         self.weight.cuda(weight.device)
         self.bias = bias
@@ -105,7 +120,8 @@ def get_linear(weight, bias, quantize):
         linear = FastLinear(weight, bias)
     elif quantize == "bitsandbytes":
         linear = Linear8bitLt(
-            weight, bias,
+            weight,
+            bias,
             has_fp16_weights=False,
             threshold=6.0,
         )
@@ -114,7 +130,9 @@ def get_linear(weight, bias, quantize):
     elif quantize == "gptq":
         raise NotImplementedError("Soon")
     else:
-        raise NotImplementedError(f"Quantization `{config.quantize}` is not implemented yet.")
+        raise NotImplementedError(
+            f"Quantization `{config.quantize}` is not implemented yet."
+        )
     return linear
@@ -126,6 +144,7 @@ class SuperLayer(nn.Module):
     def forward(self, x):
         return self.linear.forward(x)
 
+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group):
         super().__init__(linear)
@@ -134,12 +153,17 @@ class TensorParallelHead(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-        return TensorParallelHead(get_linear(weight, bias=None, quantize=config.quantize), process_group = weights.process_group)
+        return TensorParallelHead(
+            get_linear(weight, bias=None, quantize=config.quantize),
+            process_group=weights.process_group,
+        )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         output = super().forward(input)
         # Logits are sharded, so we need to gather them
-        world_output = [torch.empty_like(output) for _ in range(self.process_group.size())]
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
         torch.distributed.all_gather(world_output, output, group=self.process_group)
         world_output = torch.cat(world_output, dim=-1)
         return world_output
@@ -181,13 +205,17 @@ class TensorParallelRowLinear(SuperLayer):
             bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
-        return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group)
+        return TensorParallelRowLinear(
+            get_linear(weight, bias, config.quantize),
+            process_group=weights.process_group,
+        )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
         torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
+
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
@@ -222,6 +250,7 @@ class TensorParallelEmbedding(nn.Module):
             torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
+
 try:
     import dropout_layer_norm
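As background for the TensorParallelHead and TensorParallelRowLinear hunks above: a column-parallel layer shards output features, so each rank holds a slice of the result and the full output is the concatenation of the slices (hence the all_gather of sharded logits in TensorParallelHead.forward), whereas a row-parallel layer shards input features, so each rank computes a partial product and the full output is the element-wise sum (hence the all_reduce in TensorParallelRowLinear.forward). A single-process sketch of that arithmetic with plain tensors, simulating two ranks without torch.distributed:

import torch

torch.manual_seed(0)
x = torch.randn(2, 8)      # (batch, in_features)
w = torch.randn(6, 8)      # full weight, (out_features, in_features)
full = x @ w.T             # reference output, (2, 6)

# Column parallel: split out_features across 2 "ranks", concat partial outputs.
col_out = torch.cat([x @ shard.T for shard in w.chunk(2, dim=0)], dim=-1)

# Row parallel: split in_features; each rank sees its slice of x, partials are summed.
row_out = sum(xi @ shard.T for xi, shard in zip(x.chunk(2, dim=1), w.chunk(2, dim=1)))

print(torch.allclose(col_out, full), torch.allclose(row_out, full, atol=1e-5))  # True True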


@@ -1,7 +1,8 @@
 from pathlib import Path
-from typing import Optional, List
+from typing import List
 from safetensors import safe_open
 
+
 class Weights:
     def __init__(self, filenames: List[Path], device, dtype, process_group):
         routing = {}
@@ -26,8 +27,6 @@ class Weights:
         return self._handles[filename]
 
-
-
     def get_filename(self, tensor_name: str) -> str:
         filename = self.routing.get(tensor_name, None)
         if filename is None:
@@ -63,7 +62,9 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size
 
-        assert size % world_size == 0, f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        assert (
+            size % world_size == 0
+        ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
 
         if dim == 0:
             tensor = slice_[start:stop]
@@ -74,5 +75,3 @@ class Weights:
         tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
-
-
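The get_sharded assert above guards standard block sharding: the dimension being split must divide evenly by the world size, and rank r takes the contiguous block [r * block_size, (r + 1) * block_size). A standalone sketch of that slicing (plain tensors and hypothetical rank/world_size values, not the repo's Weights class):

import torch


def shard_block(tensor: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Return this rank's contiguous block of `tensor` along `dim`.
    size = tensor.shape[dim]
    assert (
        size % world_size == 0
    ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
    block_size = size // world_size
    start, stop = rank * block_size, (rank + 1) * block_size
    if dim == 0:
        return tensor[start:stop]
    elif dim == 1:
        return tensor[:, start:stop]
    raise NotImplementedError("This sketch only handles dim 0 and 1")


weight = torch.arange(24).reshape(6, 4)
print(shard_block(weight, dim=0, rank=1, world_size=2))  # rows 3, 4 and 5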