diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 9c1ea3b06..13c74c916 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -26,7 +26,9 @@ try: FLASH_ATTENTION = torch.cuda.is_available() except ImportError: - logger.opt(exception=True).warning("Could not import Flash Attention enabled models") + logger.opt(exception=True).warning( + "Could not import Flash Attention enabled models" + ) FLASH_ATTENTION = False __all__ = [ @@ -88,10 +90,10 @@ def get_model( raise NotImplementedError( FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder") ) - return FlashSantacoderSharded(model_id, revision=revision) + return FlashSantacoderSharded(model_id, revision, quantize=quantize) else: santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder - return santacoder_cls(model_id, revision, quantize) + return santacoder_cls(model_id, revision, quantize=quantize) config = AutoConfig.from_pretrained(model_id, revision=revision) model_type = config.model_type diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 508b7746c..718438860 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -33,6 +33,12 @@ import dropout_layer_norm from flash_attn.layers.rotary import RotaryEmbedding +HAS_BITS_AND_BYTES = True +try: + from bitsandbytes.nn import Linear8bitLt +except ImportError as e: + HAS_BITS_AND_BYTES = False + class LlamaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -94,14 +100,44 @@ class FastLinear(nn.Linear): dtype=None, ) -> None: super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + self.quantized = False + self.bnb_linear = None - def transpose_weight(self): - self.weight = nn.Parameter(self.weight.T) + def prepare_weights(self, quantize: bool = False): + if quantize: + if not HAS_BITS_AND_BYTES: + raise ImportError( + "bitsandbytes is not available on your machine either because it is not installed " + "or you don't have a GPU.\n" + "You can install it with `pip install bitsandbytes`." 
+ ) + + self.quantized = True + self.bnb_linear = Linear8bitLt( + self.in_features, + self.out_features, + has_fp16_weights=False, + threshold=6.0, + bias=False, + ) + # Copy data to bnb_linear + self.bnb_linear.weight.data = self.weight.data + if self.bias is not None: + self.bnb_linear.bias = nn.Parameter(self.bias) + + # Delete reference to data + self.weight = None + self.bias = None + else: + self.weight = nn.Parameter(self.weight.T) def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.bias is not None: - return torch.addmm(self.bias, input, self.weight) - return torch.matmul(input, self.weight) + if self.quantized: + return self.bnb_linear(input) + else: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) class TensorParallelColumnLinear(FastLinear): @@ -502,15 +538,15 @@ class FlashLlamaModel(torch.nn.Module): self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads - def post_load_weights(self): + def post_load_weights(self, load_in_8bit: bool = False): if isinstance(self.embed_tokens, TensorParallelEmbedding): self.embed_tokens.add_null_idx() for layer in self.layers: layer: FlashLlamaLayer - layer.self_attn.query_key_value.transpose_weight() - layer.self_attn.o_proj.transpose_weight() - layer.mlp.gate_up_proj.transpose_weight() - layer.mlp.down_proj.transpose_weight() + layer.self_attn.query_key_value.prepare_weights(load_in_8bit) + layer.self_attn.o_proj.prepare_weights(load_in_8bit) + layer.mlp.gate_up_proj.prepare_weights(load_in_8bit) + layer.mlp.down_proj.prepare_weights(load_in_8bit) def forward( self, @@ -592,9 +628,9 @@ class FlashLlamaForCausalLM(torch.nn.Module): else: self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - def post_load_weights(self): - self.model.post_load_weights() - self.lm_head.transpose_weight() + def post_load_weights(self, load_in_8bit: bool = False): + self.model.post_load_weights(load_in_8bit) + self.lm_head.prepare_weights() def forward( self, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 16fd40917..8de582e3a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -35,6 +35,12 @@ import dropout_layer_norm from flash_attn.layers.rotary import RotaryEmbedding +HAS_BITS_AND_BYTES = True +try: + from bitsandbytes.nn import Linear8bitLt +except ImportError as e: + HAS_BITS_AND_BYTES = False + class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): @@ -82,14 +88,44 @@ class FastLinear(nn.Linear): dtype=None, ) -> None: super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + self.quantized = False + self.bnb_linear = None - def transpose_weight(self): - self.weight = nn.Parameter(self.weight.T) + def prepare_weights(self, quantize: bool = False): + if quantize: + if not HAS_BITS_AND_BYTES: + raise ImportError( + "bitsandbytes is not available on your machine either because it is not installed " + "or you don't have a GPU.\n" + "You can install it with `pip install bitsandbytes`." 
+ ) + + self.quantized = True + self.bnb_linear = Linear8bitLt( + self.in_features, + self.out_features, + has_fp16_weights=False, + threshold=6.0, + bias=False, + ) + # Copy data to bnb_linear + self.bnb_linear.weight.data = self.weight.data + if self.bias is not None: + self.bnb_linear.bias = nn.Parameter(self.bias) + + # Delete reference to data + self.weight = None + self.bias = None + else: + self.weight = nn.Parameter(self.weight.T) def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.bias is not None: - return torch.addmm(self.bias, input, self.weight) - return torch.matmul(input, self.weight) + if self.quantized: + return self.bnb_linear(input) + else: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) class TensorParallelColumnLinear(FastLinear): @@ -552,23 +588,27 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): self.head_size = self.layers[0].attention.head_size self.num_heads = self.layers[0].attention.num_heads - def post_load_weights(self): + def post_load_weights(self, load_in_8bit=False): if isinstance(self.embed_in, TensorParallelEmbedding): self.embed_in.add_null_idx() for layer in self.layers: layer: FlashNeoXLayer layer.attention.shuffle_qkv_dims() - layer.attention.query_key_value.transpose_weight() - layer.attention.dense.transpose_weight() - layer.mlp.dense_h_to_4h.transpose_weight() - layer.mlp.dense_4h_to_h.transpose_weight() + layer.attention.query_key_value.prepare_weights(load_in_8bit) + layer.attention.dense.prepare_weights(load_in_8bit) + layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit) + layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + # Pop here as we will replace the layer in our own logic and don't want from_pretrained + # to do it for us + load_in_8bit = kwargs.pop("load_in_8bit", False) model = super(FlashGPTNeoXModel, cls).from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs ) - model.post_load_weights() + + model.post_load_weights(load_in_8bit) return model def forward( @@ -653,16 +693,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): config.hidden_size, config.vocab_size, bias=False ) - def post_load_weights(self): - self.gpt_neox.post_load_weights() - self.embed_out.transpose_weight() + def post_load_weights(self, load_in_8bit=False): + self.gpt_neox.post_load_weights(load_in_8bit) + self.embed_out.prepare_weights() @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + # Pop here as we will replace the layer in our own logic and don't want from_pretrained + # to do it for us + load_in_8bit = kwargs.pop("load_in_8bit", False) model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained( - pretrained_model_name_or_path, *model_args, **kwargs + pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs ) - model.post_load_weights() + model.post_load_weights(load_in_8bit) return model def forward( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 8679826b6..793b3d11d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -10,6 +10,12 @@ 
from transformers.activations import ACT2FN import flash_attn_cuda import dropout_layer_norm +HAS_BITS_AND_BYTES = True +try: + from bitsandbytes.nn import Linear8bitLt +except ImportError as e: + HAS_BITS_AND_BYTES = False + class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): @@ -57,14 +63,44 @@ class FastLinear(nn.Linear): dtype=None, ) -> None: super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) + self.quantized = False + self.bnb_linear = None - def transpose_weight(self): - self.weight = nn.Parameter(self.weight.T) + def prepare_weights(self, quantize: bool = False): + if quantize: + if not HAS_BITS_AND_BYTES: + raise ImportError( + "bitsandbytes is not available on your machine either because it is not installed " + "or you don't have a GPU.\n" + "You can install it with `pip install bitsandbytes`." + ) + + self.quantized = True + self.bnb_linear = Linear8bitLt( + self.in_features, + self.out_features, + has_fp16_weights=False, + threshold=6.0, + bias=False, + ) + # Copy data to bnb_linear + self.bnb_linear.weight.data = self.weight.data + if self.bias is not None: + self.bnb_linear.bias = nn.Parameter(self.bias) + + # Delete reference to data + self.weight = None + self.bias = None + else: + self.weight = nn.Parameter(self.weight.T) def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.bias is not None: - return torch.addmm(self.bias, input, self.weight) - return torch.matmul(input, self.weight) + if self.quantized: + return self.bnb_linear(input) + else: + if self.bias is not None: + return torch.addmm(self.bias, input, self.weight) + return torch.matmul(input, self.weight) class TensorParallelColumnLinear(FastLinear): @@ -431,16 +467,16 @@ class FlashSantacoderModel(nn.Module): self.head_size = self.h[0].attn.head_size self.num_heads = self.h[0].attn.num_heads - def post_load_weights(self): + def post_load_weights(self, load_in_8bit: bool = False): if self.tp_embeddings: self.wte.add_null_idx() self.wpe.add_null_idx() for layer in self.h: layer: Block - layer.attn.c_attn.transpose_weight() - layer.attn.c_proj.transpose_weight() - layer.mlp.c_fc.transpose_weight() - layer.mlp.c_proj.transpose_weight() + layer.attn.c_attn.prepare_weights(load_in_8bit) + layer.attn.c_proj.prepare_weights(load_in_8bit) + layer.mlp.c_fc.prepare_weights(load_in_8bit) + layer.mlp.c_proj.prepare_weights(load_in_8bit) def forward( self, @@ -508,9 +544,9 @@ class FlashSantacoderForCausalLM(nn.Module): else: self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) - def post_load_weights(self): - self.transformer.post_load_weights() - self.lm_head.transpose_weight() + def post_load_weights(self, load_in_8bit: bool = False): + self.transformer.post_load_weights(load_in_8bit) + self.lm_head.prepare_weights() def forward( self, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 507fec0fe..3d4273cc9 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -221,9 +221,6 @@ class FlashCausalLM(Model): else: raise NotImplementedError("FlashCausalLM is only available on GPU") - if quantize: - raise NotImplementedError("FlashCausalLM does not support quantization") - tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" ) @@ -232,9 +229,10 @@ class FlashCausalLM(Model): model_id, revision=revision, 
torch_dtype=dtype, + load_in_8bit=quantize, ) .eval() - .cuda() + .to(device) ) super(FlashCausalLM, self).__init__( diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 063910f47..9cbf1b57e 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -35,9 +35,6 @@ class FlashLlama(FlashCausalLM): else: raise NotImplementedError("FlashLlama is only available on GPU") - if quantize: - raise NotImplementedError("FlashLlama does not support quantization") - tokenizer = LlamaTokenizer.from_pretrained( model_id, revision=revision, @@ -61,8 +58,8 @@ class FlashLlama(FlashCausalLM): with init_empty_weights(): model = FlashLlamaForCausalLM(config) - self.load_weights(model, filenames, device, dtype) - self.model = model.eval() + self.load_weights(model, filenames, quantize, device, dtype) + self.model = model.eval().to(device) super(FlashCausalLM, self).__init__( tokenizer=tokenizer, @@ -73,13 +70,14 @@ class FlashLlama(FlashCausalLM): def load_weights( model, filenames: List[Path], + quantize: bool, device: torch.device, dtype: torch.dtype, ): for filename in filenames: state_dict = torch.load(filename, map_location="cpu") for key, value in state_dict.items(): - value = value.to(device).to(dtype) + value = value.to(device if not quantize else "cpu").to(dtype) layer_name = ".".join(key.split(".")[:4]) @@ -139,7 +137,7 @@ class FlashLlama(FlashCausalLM): del value torch.cuda.empty_cache() - model.post_load_weights() + model.post_load_weights(quantize) class FlashLlamaSharded(FlashLlama): @@ -154,9 +152,6 @@ class FlashLlamaSharded(FlashLlama): else: raise NotImplementedError("FlashLlama is only available on GPU") - if quantize: - raise NotImplementedError("FlashLlama does not support quantization") - tokenizer = LlamaTokenizer.from_pretrained( model_id, revision=revision, @@ -185,7 +180,7 @@ class FlashLlamaSharded(FlashLlama): rank=self.rank, world_size=self.world_size, ) - self.model = model.eval() + self.model = model.eval().to(device) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( tokenizer=tokenizer, @@ -300,4 +295,4 @@ class FlashLlamaSharded(FlashLlama): else: module._buffers[param_name] = tensor torch.cuda.empty_cache() - model.post_load_weights() + model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index a8b384654..0cda728a2 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -41,9 +41,6 @@ class FlashNeoXSharded(FlashNeoX): else: raise NotImplementedError("FlashNeoX is only available on GPU") - if quantize: - raise NotImplementedError("FlashNeoX does not support quantization") - tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" ) @@ -63,13 +60,13 @@ class FlashNeoXSharded(FlashNeoX): self.load_weights( model, filenames, + quantize=quantize, device=device, dtype=dtype, rank=self.rank, world_size=self.world_size, ) - model.post_load_weights() - self.model = model.eval() + self.model = model.eval().to(device) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( tokenizer=tokenizer, @@ -80,6 +77,7 @@ class FlashNeoXSharded(FlashNeoX): def load_weights( model, filenames: List[str], + quantize: bool, device: torch.device, dtype: torch.dtype, rank: 
int, @@ -87,7 +85,9 @@ class FlashNeoXSharded(FlashNeoX): ): parameters = dict(model.named_parameters()) for file in filenames: - with safe_open(file, framework="pt", device=str(device)) as f: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: for name in f.keys(): module_name, param_name = name.rsplit(".", 1) module = model.get_submodule(module_name) @@ -146,3 +146,4 @@ class FlashNeoXSharded(FlashNeoX): module._parameters[param_name] = tensor else: module._buffers[param_name] = tensor + model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 39381e929..e3066c982 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -34,9 +34,6 @@ class FlashSantacoder(FlashCausalLM): else: raise NotImplementedError("FlashSantacoder is only available on GPU") - if quantize: - raise NotImplementedError("FlashSantacoder does not support quantization") - tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" ) @@ -58,9 +55,14 @@ class FlashSantacoder(FlashCausalLM): model = FlashSantacoderForCausalLM(config) self.load_weights( - model, filenames, device, dtype, config.architectures[0].startswith("GPT2") + model, + filenames, + quantize, + device, + dtype, + config.architectures[0].startswith("GPT2"), ) - self.model = model.eval() + self.model = model.eval().to(device) super(FlashCausalLM, self).__init__( tokenizer=tokenizer, device=device, decode_buffer=1 @@ -70,6 +72,7 @@ class FlashSantacoder(FlashCausalLM): def load_weights( model: FlashSantacoderForCausalLM, filenames: List[Path], + quantize: bool, device: torch.device, dtype: torch.dtype, transpose: bool, @@ -77,7 +80,7 @@ class FlashSantacoder(FlashCausalLM): for filename in filenames: state_dict = torch.load(filename, map_location="cpu") for key, value in state_dict.items(): - value = value.to(device).to(dtype) + value = value.to(device if not quantize else "cpu").to(dtype) layer_name = ".".join(key.split(".")[:4]) @@ -152,7 +155,7 @@ class FlashSantacoder(FlashCausalLM): del value torch.cuda.empty_cache() - model.post_load_weights() + model.post_load_weights(quantize) def decode(self, generated_ids: List[int]) -> str: # Do not skip special tokens as they are used for custom parsing rules of the generated text @@ -173,11 +176,6 @@ class FlashSantacoderSharded(FlashSantacoder): else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") - if quantize: - raise NotImplementedError( - "FlashSantacoderSharded does not support quantization" - ) - tokenizer = AutoTokenizer.from_pretrained( model_id, revision=revision, padding_side="left", truncation_side="left" ) @@ -197,13 +195,14 @@ class FlashSantacoderSharded(FlashSantacoder): self.load_weights( model, filenames, + quantize=quantize, device=device, dtype=dtype, rank=self.rank, world_size=self.world_size, transpose=config.architectures[0].startswith("GPT2"), ) - self.model = model.eval() + self.model = model.eval().to(device) torch.distributed.barrier(group=self.process_group) super(FlashCausalLM, self).__init__( tokenizer=tokenizer, @@ -214,6 +213,7 @@ class FlashSantacoderSharded(FlashSantacoder): def load_weights( model, filenames: List[str], + quantize: bool, device: torch.device, dtype: torch.dtype, rank: int, @@ -221,7 +221,9 @@ class FlashSantacoderSharded(FlashSantacoder): 
transpose: bool, ): for file in filenames: - with safe_open(file, framework="pt", device=str(device)) as f: + with safe_open( + file, framework="pt", device=str(device) if not quantize else "cpu" + ) as f: for key in f.keys(): slice_ = f.get_slice(key) @@ -363,4 +365,4 @@ class FlashSantacoderSharded(FlashSantacoder): else: module._buffers[param_name] = tensor torch.cuda.empty_cache() - model.post_load_weights() + model.post_load_weights(quantize)
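
Reader aid, not part of the patch: the core of this change is that every FastLinear can now either transpose its weight once for the fast addmm/matmul path or hand it to bitsandbytes' Linear8bitLt, and forward() dispatches accordingly. The sketch below reproduces that quantize=True branch in isolation, assuming only the public bitsandbytes Linear8bitLt API that the patch itself uses; the helper name wrap_linear_8bit, the 4096-dim toy layer and the __main__ demo are illustrative assumptions, not code from this repository.

# Minimal standalone sketch of the prepare_weights(quantize=True) path.
# wrap_linear_8bit and the toy dimensions are hypothetical, for illustration only.
import torch
import torch.nn as nn

try:
    from bitsandbytes.nn import Linear8bitLt

    HAS_BITS_AND_BYTES = True
except ImportError:
    HAS_BITS_AND_BYTES = False


def wrap_linear_8bit(linear: nn.Linear) -> nn.Module:
    """Re-create an fp16 nn.Linear as a bitsandbytes Linear8bitLt.

    Mirrors prepare_weights(quantize=True): the fp16 weights are handed to the
    bnb module while still on CPU; the actual int8 quantization happens when
    the module is later moved to a CUDA device.
    """
    if not HAS_BITS_AND_BYTES:
        raise ImportError(
            "bitsandbytes is not available: install it with `pip install bitsandbytes` "
            "and make sure a GPU is visible."
        )

    bnb_linear = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        has_fp16_weights=False,  # inference-only: drop the fp16 copy after quantization
        threshold=6.0,           # LLM.int8() outlier threshold, as in the patch
        bias=False,
    )
    # Hand over the existing weights/bias, as prepare_weights does.
    bnb_linear.weight.data = linear.weight.data
    if linear.bias is not None:
        bnb_linear.bias = nn.Parameter(linear.bias)
    return bnb_linear


if __name__ == "__main__":
    layer = nn.Linear(4096, 4096, bias=False, dtype=torch.float16)
    layer_8bit = wrap_linear_8bit(layer)
    if torch.cuda.is_available():
        layer_8bit = layer_8bit.cuda()  # weights are quantized to int8 on this move
        x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
        print(layer_8bit(x).shape)  # torch.Size([2, 4096])

Two details worth noting from the diff itself: has_fp16_weights=False makes this an inference-only conversion whose quantization is triggered by the device move, which is why the loaders now keep tensors on CPU when quantize is set and only call model.eval().to(device) afterwards; and the lm_head / embed_out projections are deliberately left in fp16, since their post_load_weights still calls prepare_weights() without the quantize flag.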