Black + ruff + T5 w0 quant.

Ubuntu 2023-05-24 09:35:29 +00:00 committed by Nicolas Patry
parent 15bf3d4944
commit 2362a80a4f
11 changed files with 137 additions and 71 deletions

View File

@ -1,15 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='custom_kernels',
name="custom_kernels",
ext_modules=[
CUDAExtension(
name="custom_kernels.fused_bloom_attention_cuda",
sources=['custom_kernels/fused_bloom_attention_cuda.cu'],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
name="custom_kernels.fused_bloom_attention_cuda",
sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
extra_compile_args=["-arch=compute_80", "-std=c++17"],
)
],
cmdclass={
'build_ext': BuildExtension
}
cmdclass={"build_ext": BuildExtension},
)

View File

@ -37,7 +37,6 @@ from text_generation_server.utils.layers import (
TensorParallelEmbedding,
TensorParallelRowLinear,
TensorParallelHead,
FastLinear
)
CUSTOM_KERNELS_ENABLED = False

View File

@ -60,12 +60,18 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
bias = weights.get_sharded(f"{prefix}.bias", dim=0)
weight = weight.view(
num_heads, 3, head_size, hidden_size,
).permute(1, 0, 2, 3).reshape(-1, hidden_size)
weight = (
weight.view(
num_heads,
3,
head_size,
hidden_size,
)
.permute(1, 0, 2, 3)
.reshape(-1, hidden_size)
)
bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)
linear = get_linear(weight, bias, config.quantize)
if config.use_parallel_residual:
return linear
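
The reformatted chain above is the same transform as before, just wrapped by Black: the fused QKV weight is viewed as (num_heads, 3, head_size, hidden_size), permuted so the Q/K/V groups lead, then flattened back. A minimal sketch of what that does to a dummy tensor (sizes are made up for illustration and not taken from any checkpoint):

import torch

# Illustrative sizes only; real values come from the model config.
num_heads, head_size, hidden_size = 4, 8, 32

# Fused QKV weight as stored: per head, rows are grouped as [q, k, v].
fused = torch.randn(num_heads * 3 * head_size, hidden_size)

# Same chain as load_qkv: regroup rows so all Q rows (across heads) come
# first, then all K rows, then all V rows.
weight = (
    fused.view(num_heads, 3, head_size, hidden_size)
    .permute(1, 0, 2, 3)
    .reshape(-1, hidden_size)
)
assert weight.shape == (3 * num_heads * head_size, hidden_size)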
@ -88,12 +94,18 @@ class FlashNeoxAttention(torch.nn.Module):
rotary_ndims = int(self.head_size * rotary_pct)
self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
self.rotary_emb.inv_freq = nn.Parameter(weights.get_tensor(f"{prefix}.rotary_emb.inv_freq"))
self.rotary_emb.inv_freq = nn.Parameter(
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
)
self.softmax_scale = self.head_size ** (-0.5)
self.query_key_value = load_qkv(
config, prefix=f"{prefix}.query_key_value", weights=weights,
num_heads = self.num_heads, head_size = self.head_size, hidden_size = self.hidden_size
config,
prefix=f"{prefix}.query_key_value",
weights=weights,
num_heads=self.num_heads,
head_size=self.head_size,
hidden_size=self.hidden_size,
)
self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True

View File

@ -3,7 +3,7 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List
from typing import Optional
# Flash attention imports
import flash_attn_cuda
@ -17,8 +17,9 @@ from text_generation_server.utils.layers import (
)
def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size):
def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
if any("c_attn" in k for k in weights.routing.keys()):
slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
shape = slice_.get_shape()
@ -55,30 +56,35 @@ def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_head
if config.transpose:
w = [
weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
weights.get_tensor(f"{prefix}.kv_attn.weight").T
weights.get_tensor(f"{prefix}.kv_attn.weight").T,
]
weight = torch.cat(w, dim=0)
else:
w = [
weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
weights.get_tensor(f"{prefix}.kv_attn.weight")
weights.get_tensor(f"{prefix}.kv_attn.weight"),
]
weight = torch.cat(w, dim=1)
if bias:
b = [
weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
weights.get_tensor(f"{prefix}.kv_attn.bias")
weights.get_tensor(f"{prefix}.kv_attn.bias"),
]
bias = torch.cat(b, dim=0)
else:
bias = None
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
assert list(weight.shape) == [(num_heads + 2) * head_size, hidden_size], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
assert list(weight.shape) == [
(num_heads + 2) * head_size,
hidden_size,
], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
if bias is not None:
bias = bias.to(dtype=weights.dtype).to(device=weights.device)
assert list(bias.shape) == [(num_heads + 2) * head_size], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
assert list(bias.shape) == [
(num_heads + 2) * head_size
], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
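
The multi-line asserts introduced here spell out the multi-query layout: the fused projection holds this shard's num_heads query heads plus one shared key head and one shared value head, hence (num_heads + 2) * head_size output rows. A rough sketch of that invariant with toy tensors (names and sizes are illustrative, not the library's):

import torch

# Toy sizes for illustration only.
num_heads, head_size, hidden_size = 6, 64, 768

# This shard's query projection plus the shared key/value projection.
q_w = torch.randn(num_heads * head_size, hidden_size)
kv_w = torch.randn(2 * head_size, hidden_size)

weight = torch.cat([q_w, kv_w], dim=0)
assert list(weight.shape) == [
    (num_heads + 2) * head_size,
    hidden_size,
], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"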
@ -106,7 +112,9 @@ def load_row(config, prefix: str, weights, bias: bool):
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group)
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize), process_group=weights.process_group
)
class FlashMQAttention(torch.nn.Module):
@ -131,7 +139,7 @@ class FlashMQAttention(torch.nn.Module):
bias=True,
head_size=self.head_size,
hidden_size=hidden_size,
num_heads=self.num_heads
num_heads=self.num_heads,
)
self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True

View File

@ -109,9 +109,21 @@ class T5DenseActDense(nn.Module):
self.wi = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.wi", weights=weights, bias=False
)
### XXX: T5 models do not handle both f16 and quantization well.
### Overriding specifically this layer for that reason.
### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
### https://github.com/huggingface/transformers/issues/20287
_q = config.quantize
_dtype = weights.dtype
weights.dtype = torch.float32
config.quantize = None
self.wo_cast = (torch.float32, _dtype)
self.wo = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.wo", weights=weights, bias=False
)
weights.dtype = _dtype
config.quantize = _q
self.dropout = nn.Dropout(config.dropout_rate)
self.act = (
@ -124,7 +136,10 @@ class T5DenseActDense(nn.Module):
hidden_states = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states.to(dtype=self.wo_cast[0])
hidden_states = self.wo(hidden_states)
hidden_states = hidden_states.to(dtype=self.wo_cast[1])
return hidden_states
@ -137,9 +152,20 @@ class T5DenseGatedActDense(nn.Module):
self.wi_1 = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.wi_1", weights=weights, bias=False
)
### XXX: T5 models do not handle both f16 and quantization well.
### Overriding specifically this layer for that reason.
### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
### https://github.com/huggingface/transformers/issues/20287
_q = config.quantize
_dtype = weights.dtype
weights.dtype = torch.float32
config.quantize = None
self.wo_cast = (torch.float32, _dtype)
self.wo = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.wo", weights=weights, bias=False
)
weights.dtype = _dtype
config.quantize = _q
self.dropout = nn.Dropout(config.dropout_rate)
self.act = (
@ -154,18 +180,9 @@ class T5DenseGatedActDense(nn.Module):
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
# TODO Support this again maybe
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
# See https://github.com/huggingface/transformers/issues/20287
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
# if (
# isinstance(self.wo.weight, torch.Tensor)
# and hidden_states.dtype != self.wo.weight.dtype
# and self.wo.weight.dtype != torch.int8
# ):
# hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = hidden_states.to(dtype=self.wo_cast[0])
hidden_states = self.wo(hidden_states)
hidden_states = hidden_states.to(dtype=self.wo_cast[1])
return hidden_states
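
Both T5 feed-forward variants now share the same workaround: wo is loaded in float32 with quantization disabled, and the activations are cast up to float32 just before wo and back to the model dtype right after, because T5 checkpoints misbehave under fp16 combined with 8-bit quantization (see the linked transformers issue). A standalone sketch of the pattern, with a plain nn.Linear standing in for TensorParallelRowLinear and purely illustrative names:

import torch
from torch import nn

class DenseActDenseSketch(nn.Module):
    """Minimal stand-in for the T5 feed-forward block touched by this diff."""

    def __init__(self, d_model: int, d_ff: int, model_dtype=torch.float16):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False).to(model_dtype)
        # Keep wo in float32 and leave it unquantized, mirroring the
        # weights.dtype / config.quantize override around wo's load.
        self.wo = nn.Linear(d_ff, d_model, bias=False).to(torch.float32)
        self.wo_cast = (torch.float32, model_dtype)
        self.act = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states is expected in model_dtype (e.g. float16).
        hidden_states = self.act(self.wi(hidden_states))
        # Up-cast before wo, down-cast after, as in the diff.
        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
        hidden_states = self.wo(hidden_states)
        return hidden_states.to(dtype=self.wo_cast[1])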

View File

@ -26,7 +26,7 @@ HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except Exception as e:
except Exception:
HAS_BITS_AND_BYTES = False

View File

@ -40,10 +40,11 @@ class OPTSharded(CausalLM):
trust_remote_code=trust_remote_code,
)
)
config = AutoConfig.from_pretrained(model_id, revision=revision,
config = AutoConfig.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
)
)
config.quantize = quantize
tokenizer.pad_token_id = config.pad_token_id

View File

@ -3,6 +3,7 @@ import torch
from datetime import timedelta
class FakeBarrier:
def wait(self):
pass
@ -17,7 +18,9 @@ class FakeGroup:
return FakeBarrier()
def allgather(self, inputs, local_tensor, **kwargs):
assert len(inputs[0]) == len(local_tensor) == 1, f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
assert (
len(inputs[0]) == len(local_tensor) == 1
), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
for input_ in inputs:
input_[0].data = local_tensor[0].data
return FakeBarrier()

View File

@ -10,8 +10,7 @@ from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.utils import (
LocalEntryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError, # Import here to ease try/except in other parts of the lib
EntryNotFoundError, # Import here to ease try/except in other parts of the lib
)
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)

View File

@ -2,14 +2,14 @@ import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional, List
from typing import List
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
from bitsandbytes.nn import Int8Params
except ImportError as e:
except ImportError:
HAS_BITS_AND_BYTES = False
from accelerate import init_empty_weights
@ -27,14 +27,16 @@ def load_layer_norm(cls, prefix, weights, eps):
ln.bias = nn.Parameter(bias)
return ln
torch.nn.LayerNorm.load = load_layer_norm
class FastLinear(nn.Module):
def __init__(
self,
weight, bias,
) -> None:
weight,
bias,
) -> None:
super().__init__()
self.weight = nn.Parameter(weight)
if bias is not None:
@ -56,10 +58,19 @@ class FastLinear(nn.Module):
class Linear8bitLt(nn.Module):
def __init__(self, weight, bias, has_fp16_weights=True,
memory_efficient_backward=False, threshold=0.0, index=None):
def __init__(
self,
weight,
bias,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
):
super().__init__()
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
assert (
not memory_efficient_backward
), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index
@ -70,7 +81,11 @@ class Linear8bitLt(nn.Module):
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
self.weight = Int8Params(
weight.data,
has_fp16_weights=has_fp16_weights,
requires_grad=has_fp16_weights,
)
self.weight.cuda(weight.device)
self.bias = bias
@ -105,7 +120,8 @@ def get_linear(weight, bias, quantize):
linear = FastLinear(weight, bias)
elif quantize == "bitsandbytes":
linear = Linear8bitLt(
weight, bias,
weight,
bias,
has_fp16_weights=False,
threshold=6.0,
)
@ -114,7 +130,9 @@ def get_linear(weight, bias, quantize):
elif quantize == "gptq":
raise NotImplementedError("Soon")
else:
raise NotImplementedError(f"Quantization `{config.quantize}` is not implemented yet.")
raise NotImplementedError(
f"Quantization `{config.quantize}` is not implemented yet."
)
return linear
@ -126,6 +144,7 @@ class SuperLayer(nn.Module):
def forward(self, x):
return self.linear.forward(x)
class TensorParallelHead(SuperLayer):
def __init__(self, linear, process_group):
super().__init__(linear)
@ -134,12 +153,17 @@ class TensorParallelHead(SuperLayer):
@staticmethod
def load(config, prefix: str, weights):
weight = weights.get_sharded(f"{prefix}.weight", dim=0)
return TensorParallelHead(get_linear(weight, bias=None, quantize=config.quantize), process_group = weights.process_group)
return TensorParallelHead(
get_linear(weight, bias=None, quantize=config.quantize),
process_group=weights.process_group,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = super().forward(input)
# Logits are sharded, so we need to gather them
world_output = [torch.empty_like(output) for _ in range(self.process_group.size())]
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
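
The reformatted forward keeps the gather step intact: each rank computes its shard of the vocabulary logits, allocates one buffer per rank, all_gathers into them, and concatenates along the last dimension. A self-contained sketch of that gather on a single-process gloo group (world size 1, illustrative only, not how the server initializes distributed state):

import torch
import torch.distributed as dist

# Single-process "world" so the sketch runs standalone.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", world_size=1, rank=0
)

output = torch.randn(2, 8)  # this rank's shard of the logits
world_output = [torch.empty_like(output) for _ in range(dist.get_world_size())]
dist.all_gather(world_output, output)
logits = torch.cat(world_output, dim=-1)  # full (unsharded) vocab dimension

dist.destroy_process_group()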
@ -181,13 +205,17 @@ class TensorParallelRowLinear(SuperLayer):
bias = weights.get_tensor(f"{prefix}.bias")
else:
bias = None
return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group)
return TensorParallelRowLinear(
get_linear(weight, bias, config.quantize),
process_group=weights.process_group,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super().forward(input)
torch.distributed.all_reduce(out, group=self.process_group)
return out
class TensorParallelEmbedding(nn.Module):
def __init__(self, prefix: str, weights, reduce=True):
super().__init__()
@ -222,6 +250,7 @@ class TensorParallelEmbedding(nn.Module):
torch.distributed.all_reduce(out, group=self.process_group)
return out
try:
import dropout_layer_norm

View File

@ -1,7 +1,8 @@
from pathlib import Path
from typing import Optional, List
from typing import List
from safetensors import safe_open
class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group):
routing = {}
@ -26,8 +27,6 @@ class Weights:
return self._handles[filename]
def get_filename(self, tensor_name: str) -> str:
filename = self.routing.get(tensor_name, None)
if filename is None:
@ -63,7 +62,9 @@ class Weights:
start = rank * block_size
stop = (rank + 1) * block_size
assert size % world_size == 0, f"The chosen size {size} is not compatible with sharding on {world_size} shards"
assert (
size % world_size == 0
), f"The chosen size {size} is not compatible with sharding on {world_size} shards"
if dim == 0:
tensor = slice_[start:stop]
@ -74,5 +75,3 @@ class Weights:
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
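
The expanded assert above documents the precondition for get_sharded: the dimension being split must divide evenly by the world size, since each rank takes one contiguous block. A tiny sketch of the same indexing with plain tensor slicing and made-up rank numbers:

import torch

# Illustrative sharding of dim 0 across a made-up 4-rank world.
world_size, rank = 4, 1
tensor = torch.arange(32).reshape(8, 4)

size = tensor.shape[0]
assert (
    size % world_size == 0
), f"The chosen size {size} is not compatible with sharding on {world_size} shards"

block_size = size // world_size
start, stop = rank * block_size, (rank + 1) * block_size
shard = tensor[start:stop]  # rank 1 gets rows 2 and 3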