diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py
index 62c720e1..fa4382e9 100644
--- a/server/custom_kernels/setup.py
+++ b/server/custom_kernels/setup.py
@@ -1,15 +1,14 @@
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
 setup(
-    name='custom_kernels',
+    name="custom_kernels",
     ext_modules=[
         CUDAExtension(
-            name="custom_kernels.fused_bloom_attention_cuda",
-            sources=['custom_kernels/fused_bloom_attention_cuda.cu'],
-            extra_compile_args=["-arch=compute_80", "-std=c++17"],
+            name="custom_kernels.fused_bloom_attention_cuda",
+            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
+            extra_compile_args=["-arch=compute_80", "-std=c++17"],
         )
     ],
-    cmdclass={
-        'build_ext': BuildExtension
-    }
+    cmdclass={"build_ext": BuildExtension},
 )
diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
index 554cab9f..e5e87645 100644
--- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py
@@ -37,7 +37,6 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     TensorParallelHead,
-    FastLinear
 )

 CUSTOM_KERNELS_ENABLED = False
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index c1273267..24004e8a 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -60,12 +60,18 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size):
     weight = weights.get_sharded(f"{prefix}.weight", dim=0)
     bias = weights.get_sharded(f"{prefix}.bias", dim=0)

-    weight = weight.view(
-        num_heads, 3, head_size, hidden_size,
-    ).permute(1, 0, 2, 3).reshape(-1, hidden_size)
+    weight = (
+        weight.view(
+            num_heads,
+            3,
+            head_size,
+            hidden_size,
+        )
+        .permute(1, 0, 2, 3)
+        .reshape(-1, hidden_size)
+    )
     bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)

-
     linear = get_linear(weight, bias, config.quantize)
     if config.use_parallel_residual:
         return linear
@@ -88,17 +94,23 @@ class FlashNeoxAttention(torch.nn.Module):
         rotary_ndims = int(self.head_size * rotary_pct)
         self.rotary_emb = PositionRotaryEmbedding(rotary_ndims, base=rotary_emb_base)
-        self.rotary_emb.inv_freq = nn.Parameter(weights.get_tensor(f"{prefix}.rotary_emb.inv_freq"))
+        self.rotary_emb.inv_freq = nn.Parameter(
+            weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
+        )

         self.softmax_scale = self.head_size ** (-0.5)

         self.query_key_value = load_qkv(
-            config, prefix=f"{prefix}.query_key_value", weights=weights,
-            num_heads = self.num_heads, head_size = self.head_size, hidden_size = self.hidden_size
+            config,
+            prefix=f"{prefix}.query_key_value",
+            weights=weights,
+            num_heads=self.num_heads,
+            head_size=self.head_size,
+            hidden_size=self.hidden_size,
         )
         self.dense = load_row(
             config, prefix=f"{prefix}.dense", weights=weights, bias=True
         )
-
+
     def forward(
         self,
         hidden_states,
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 21b3f039..888a6066 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -3,7 +3,7 @@ import torch.distributed

 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional, List
+from typing import Optional

 # Flash attention imports
 import flash_attn_cuda
@@ -17,8 +17,9 @@ from text_generation_server.utils.layers import (
 )


-
-def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size):
+def load_multi_mqa(
+    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
+):
     if any("c_attn" in k for k in weights.routing.keys()):
         slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
         shape = slice_.get_shape()
@@ -55,30 +56,35 @@ def load_multi_mqa(config, prefix: str, weights, bias: bool, head_size, num_head
         if config.transpose:
             w = [
                 weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
-                weights.get_tensor(f"{prefix}.kv_attn.weight").T
+                weights.get_tensor(f"{prefix}.kv_attn.weight").T,
             ]
             weight = torch.cat(w, dim=0)
         else:
             w = [
                 weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
-                weights.get_tensor(f"{prefix}.kv_attn.weight")
+                weights.get_tensor(f"{prefix}.kv_attn.weight"),
             ]
             weight = torch.cat(w, dim=1)

         if bias:
             b = [
                 weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
-                weights.get_tensor(f"{prefix}.kv_attn.bias")
+                weights.get_tensor(f"{prefix}.kv_attn.bias"),
             ]
             bias = torch.cat(b, dim=0)
         else:
             bias = None

     weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-    assert list(weight.shape) == [(num_heads + 2) * head_size, hidden_size], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
+    assert list(weight.shape) == [
+        (num_heads + 2) * head_size,
+        hidden_size,
+    ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
     if bias is not None:
         bias = bias.to(dtype=weights.dtype).to(device=weights.device)
-        assert list(bias.shape) == [(num_heads + 2) * head_size], f"{weight.shape} != {[(num_heads + 2) * head_size]}"
+        assert list(bias.shape) == [
+            (num_heads + 2) * head_size
+        ], f"{bias.shape} != {[(num_heads + 2) * head_size]}"
     return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))


@@ -106,7 +112,9 @@ def load_row(config, prefix: str, weights, bias: bool):
         bias = weights.get_tensor(f"{prefix}.bias")
     else:
         bias = None
-    return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group)
+    return TensorParallelRowLinear(
+        get_linear(weight, bias, config.quantize), process_group=weights.process_group
+    )


 class FlashMQAttention(torch.nn.Module):
@@ -131,7 +139,7 @@ class FlashMQAttention(torch.nn.Module):
             bias=True,
             head_size=self.head_size,
             hidden_size=hidden_size,
-            num_heads=self.num_heads
+            num_heads=self.num_heads,
         )
         self.c_proj = load_row(
             config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py
index 6fa09b09..c5ce9bfc 100644
--- a/server/text_generation_server/models/custom_modeling/t5_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py
@@ -109,9 +109,21 @@ class T5DenseActDense(nn.Module):
         self.wi = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.wi", weights=weights, bias=False
         )
+
+        ### XXX: T5 models do not handle both f16 and quantization well.
+        ### Overriding this specific layer for that reason.
+        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
+        ### https://github.com/huggingface/transformers/issues/20287
+        _q = config.quantize
+        _dtype = weights.dtype
+        weights.dtype = torch.float32
+        config.quantize = None
+        self.wo_cast = (torch.float32, _dtype)
         self.wo = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.wo", weights=weights, bias=False
         )
+        weights.dtype = _dtype
+        config.quantize = _q

         self.dropout = nn.Dropout(config.dropout_rate)
         self.act = (
@@ -124,7 +136,10 @@ class T5DenseActDense(nn.Module):
         hidden_states = self.wi(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
+
+        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
+        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states


@@ -137,9 +152,20 @@ class T5DenseGatedActDense(nn.Module):
         self.wi_1 = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.wi_1", weights=weights, bias=False
         )
+        ### XXX: T5 models do not handle both f16 and quantization well.
+        ### Overriding this specific layer for that reason.
+        ### https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L316
+        ### https://github.com/huggingface/transformers/issues/20287
+        _q = config.quantize
+        _dtype = weights.dtype
+        weights.dtype = torch.float32
+        config.quantize = None
+        self.wo_cast = (torch.float32, _dtype)
         self.wo = TensorParallelRowLinear.load(
             config, prefix=f"{prefix}.wo", weights=weights, bias=False
         )
+        weights.dtype = _dtype
+        config.quantize = _q

         self.dropout = nn.Dropout(config.dropout_rate)
         self.act = (
@@ -154,18 +180,9 @@ class T5DenseGatedActDense(nn.Module):
         hidden_states = hidden_gelu * hidden_linear
         hidden_states = self.dropout(hidden_states)

-        # TODO Support this again mayber
-        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
-        # See https://github.com/huggingface/transformers/issues/20287
-        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
-        # if (
-        #     isinstance(self.wo.weight, torch.Tensor)
-        #     and hidden_states.dtype != self.wo.weight.dtype
-        #     and self.wo.weight.dtype != torch.int8
-        # ):
-        #     hidden_states = hidden_states.to(self.wo.weight.dtype)
-
+        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
         hidden_states = self.wo(hidden_states)
+        hidden_states = hidden_states.to(dtype=self.wo_cast[1])
         return hidden_states


diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 5ab8a624..4d0e4730 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -26,7 +26,7 @@ HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
     from bitsandbytes.nn import Int8Params
-except Exception as e:
+except Exception:
     HAS_BITS_AND_BYTES = False


diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 185937e6..16cb48b7 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -40,10 +40,11 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )

-        )
-        config = AutoConfig.from_pretrained(model_id, revision=revision,
+        config = AutoConfig.from_pretrained(
+            model_id,
+            revision=revision,
             trust_remote_code=trust_remote_code,
-        )
+        )
         config.quantize = quantize
         tokenizer.pad_token_id = config.pad_token_id

diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 9be51f74..fe9c3b7b 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -3,6 +3,7 @@ import torch

 from datetime import timedelta

+
 class FakeBarrier:
     def wait(self):
         pass
@@ -17,7 +18,9 @@ class FakeGroup:
         return FakeBarrier()

     def allgather(self, inputs, local_tensor, **kwargs):
-        assert len(inputs[0]) == len(local_tensor) == 1, f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
+        assert (
+            len(inputs[0]) == len(local_tensor) == 1
+        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
         for input_ in inputs:
             input_[0].data = local_tensor[0].data
         return FakeBarrier()
diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py
index 2ed7673c..965cae99 100644
--- a/server/text_generation_server/utils/hub.py
+++ b/server/text_generation_server/utils/hub.py
@@ -10,8 +10,7 @@ from huggingface_hub import HfApi, hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import (
     LocalEntryNotFoundError,
-    EntryNotFoundError,
-    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
+    EntryNotFoundError,  # Import here to ease try/except in other part of the lib
 )

 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index ea9a1469..0146e5c3 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -2,14 +2,14 @@ import torch

 from torch import nn
 from torch.nn import functional as F

-from typing import Optional, List
+from typing import List

 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
     from bitsandbytes.nn import Int8Params
-except ImportError as e:
+except ImportError:
     HAS_BITS_AND_BYTES = False

 from accelerate import init_empty_weights
@@ -27,14 +27,16 @@ def load_layer_norm(cls, prefix, weights, eps):
     ln.bias = nn.Parameter(bias)
     return ln

+
 torch.nn.LayerNorm.load = load_layer_norm


 class FastLinear(nn.Module):
     def __init__(
         self,
-        weight, bias,
-    ) -> None:
+        weight,
+        bias,
+    ) -> None:
         super().__init__()
         self.weight = nn.Parameter(weight)
         if bias is not None:
@@ -44,9 +46,9 @@ class FastLinear(nn.Module):

     @staticmethod
     def load(config, prefix: str, weights, bias: bool):
-        weight = weights.get_tensor(f"{prefix}.weight")
+        weight = weights.get_tensor(f"{prefix}.weight")
         if bias:
-            bias = weights.get_tensor(f"{prefix}.bias")
+            bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
         return FastLinear(weight, bias)
@@ -56,10 +58,19 @@ class FastLinear(nn.Module):


 class Linear8bitLt(nn.Module):
-    def __init__(self, weight, bias, has_fp16_weights=True,
-                       memory_efficient_backward=False, threshold=0.0, index=None):
+    def __init__(
+        self,
+        weight,
+        bias,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
         super().__init__()
-        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
+        assert (
+            not memory_efficient_backward
+        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index

@@ -70,7 +81,11 @@ class Linear8bitLt(nn.Module):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True

-        self.weight = Int8Params(weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
+        self.weight = Int8Params(
+            weight.data,
+            has_fp16_weights=has_fp16_weights,
+            requires_grad=has_fp16_weights,
+        )
         self.weight.cuda(weight.device)
         self.bias = bias

@@ -105,7 +120,8 @@ def get_linear(weight, bias, quantize):
         linear = FastLinear(weight, bias)
     elif quantize == "bitsandbytes":
         linear = Linear8bitLt(
-            weight, bias,
+            weight,
+            bias,
             has_fp16_weights=False,
             threshold=6.0,
         )
@@ -114,7 +130,9 @@
     elif quantize == "gptq":
         raise NotImplementedError("Soon")
     else:
-        raise NotImplementedError(f"Quantization `{config.quantize}` is not implemented yet.")
+        raise NotImplementedError(
+            f"Quantization `{config.quantize}` is not implemented yet."
+        )
     return linear


@@ -126,6 +144,7 @@ class SuperLayer(nn.Module):
     def forward(self, x):
         return self.linear.forward(x)

+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group):
         super().__init__(linear)
@@ -133,13 +152,18 @@ class TensorParallelHead(SuperLayer):

     @staticmethod
     def load(config, prefix: str, weights):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
-        return TensorParallelHead(get_linear(weight, bias=None, quantize=config.quantize), process_group = weights.process_group)
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        return TensorParallelHead(
+            get_linear(weight, bias=None, quantize=config.quantize),
+            process_group=weights.process_group,
+        )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         output = super().forward(input)
         # Logits are sharded, so we need to gather them
-        world_output = [torch.empty_like(output) for _ in range(self.process_group.size())]
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
         torch.distributed.all_gather(world_output, output, group=self.process_group)
         world_output = torch.cat(world_output, dim=-1)
         return world_output
@@ -148,9 +172,9 @@ class TensorParallelHead(SuperLayer):
 class TensorParallelColumnLinear(SuperLayer):
     @staticmethod
     def load(config, prefix: str, weights, bias: bool):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
         if bias:
-            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
+            bias = weights.get_sharded(f"{prefix}.bias", dim=0)
         else:
             bias = None
         return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
@@ -175,23 +199,27 @@ class TensorParallelRowLinear(SuperLayer):

     @staticmethod
     def load(config, prefix: str, weights, bias: bool):
-        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
+        weight = weights.get_sharded(f"{prefix}.weight", dim=1)
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
-            bias = weights.get_tensor(f"{prefix}.bias")
+            bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
-        return TensorParallelRowLinear(get_linear(weight, bias, config.quantize), process_group=weights.process_group)
+        return TensorParallelRowLinear(
+            get_linear(weight, bias, config.quantize),
+            process_group=weights.process_group,
+        )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
         torch.distributed.all_reduce(out, group=self.process_group)
         return out

+
 class TensorParallelEmbedding(nn.Module):
     def __init__(self, prefix: str, weights, reduce=True):
         super().__init__()
-        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
+        weight = weights.get_sharded(f"{prefix}.weight", dim=0)
         num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

         process_group = weights.process_group
@@ -222,6 +250,7 @@ class TensorParallelEmbedding(nn.Module):
             torch.distributed.all_reduce(out, group=self.process_group)
         return out

+
 try:
     import dropout_layer_norm

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index fc01d937..2a410ca3 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -1,7 +1,8 @@
 from pathlib import Path
-from typing import Optional, List
+from typing import List
 from safetensors import safe_open

+
 class Weights:
     def __init__(self, filenames: List[Path], device, dtype, process_group):
         routing = {}
@@ -26,8 +27,6 @@ class Weights:

         return self._handles[filename]

-
-
     def get_filename(self, tensor_name: str) -> str:
         filename = self.routing.get(tensor_name, None)
         if filename is None:
@@ -63,7 +62,9 @@ class Weights:
         start = rank * block_size
         stop = (rank + 1) * block_size

-        assert size % world_size == 0, f"The choosen size {size} is not compatible with sharding on {world_size} shards"
+        assert (
+            size % world_size == 0
+        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"

         if dim == 0:
             tensor = slice_[start:stop]
@@ -74,5 +75,3 @@ class Weights:
         tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
-
-
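Note on the T5 changes above (editorial sketch, not part of the diff): the `wo` projection is kept in float32 and unquantized by temporarily overriding `weights.dtype` and `config.quantize` while that one layer is loaded, and the activations are cast up before `wo` and back down afterwards via `self.wo_cast`. The standalone Python sketch below illustrates the same cast-around-a-float32-layer pattern; the class and dimension names are hypothetical, and a plain `nn.Linear` stands in for the tensor-parallel layers from `text_generation_server`.

import torch
from torch import nn


class SketchDenseActDense(nn.Module):
    """Hypothetical stand-in for T5DenseActDense: everything in low precision
    except `wo`, which stays in float32 (mirrors the wo_cast trick in the diff)."""

    def __init__(self, d_model: int, d_ff: int, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False).to(dtype)
        # The numerically sensitive projection is kept in float32.
        self.wo = nn.Linear(d_ff, d_model, bias=False).to(torch.float32)
        # (cast-in dtype, cast-back dtype), as in self.wo_cast above.
        self.wo_cast = (torch.float32, dtype)
        self.act = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.act(self.wi(hidden_states))
        # Cast up for the float32 matmul, then back down for the rest of the model.
        hidden_states = hidden_states.to(dtype=self.wo_cast[0])
        hidden_states = self.wo(hidden_states)
        return hidden_states.to(dtype=self.wo_cast[1])


if __name__ == "__main__":
    layer = SketchDenseActDense(d_model=8, d_ff=32)
    x = torch.randn(2, 8, dtype=torch.bfloat16)
    print(layer(x).dtype)  # torch.bfloat16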