From da7e10424113cbfe3a5a4b8892ddeea2aae2fc61 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 30 Jun 2023 07:47:06 +0000
Subject: [PATCH] Add the option to force another dtype than `f16`.

Adds a new flag propagated everywhere. Disjoint from `--quantize`, which also
changes the actual dtype of layers.

Fixes #490
---
 launcher/src/main.rs | 24 ++
 server/text_generation_server/cli.py | 15 +-
 .../text_generation_server/models/__init__.py | 55 ++-
 server/text_generation_server/models/bloom.py | 3 +-
 .../models/causal_lm.py | 3 +-
 .../custom_modeling/flash_mpt_modeling.py | 361 ++++++++++++++++++
 .../models/flash_llama.py | 3 +-
 .../models/flash_mpt.py | 73 ++++
 .../models/flash_neox.py | 3 +-
 .../text_generation_server/models/flash_rw.py | 3 +-
 .../models/flash_santacoder.py | 10 +-
 .../models/galactica.py | 3 +-
 .../text_generation_server/models/gpt_neox.py | 3 +-
 server/text_generation_server/models/opt.py | 3 +-
 server/text_generation_server/models/rw.py | 3 +-
 .../models/santacoder.py | 3 +-
 .../models/seq2seq_lm.py | 3 +-
 server/text_generation_server/models/t5.py | 3 +-
 server/text_generation_server/server.py | 6 +-
 19 files changed, 558 insertions(+), 22 deletions(-)
 create mode 100644 server/text_generation_server/models/custom_modeling/flash_mpt_modeling.py
 create mode 100644 server/text_generation_server/models/flash_mpt.py

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 2e2bc7a5..6a194bf3 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -36,6 +36,26 @@ impl std::fmt::Display for Quantization {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Dtype {
+    Float16,
+    BFloat16,
+}
+
+impl std::fmt::Display for Dtype {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Keep in sync with `server`.
+        match self {
+            Dtype::Float16 => {
+                write!(f, "float16")
+            }
+            Dtype::BFloat16 => {
+                write!(f, "bfloat16")
+            }
+        }
+    }
+}
+
 /// App Configuration
 #[derive(Parser, Debug)]
 #[clap(author, version, about, long_about = None)]
@@ -71,6 +91,10 @@ struct Args {
     #[clap(long, env, value_enum)]
     quantize: Option<Quantization>,
 
+    /// The dtype to be forced upon the model. This option cannot be used with `--quantize`.
+    #[clap(long, env, value_enum)]
+    dtype: Option<Dtype>,
+
     /// Whether you want to execute hub modelling code. Explicitly passing a `revision` is
     /// encouraged when loading a model with custom code to ensure no malicious code has been
     /// contributed in a newer revision.
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index aeb1f13b..3463049a 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -16,12 +16,18 @@ class Quantization(str, Enum):
     gptq = "gptq"
 
 
+class Dtype(str, Enum):
+    float16 = "float16"
+    bfloat16 = "bfloat16"
+
+
 @app.command()
 def serve(
     model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: Optional[Quantization] = None,
+    dtype: Optional[Dtype] = None,
     trust_remote_code: bool = False,
     uds_path: Path = "/tmp/text-generation-server",
     logger_level: str = "INFO",
@@ -64,7 +70,14 @@ def serve(
 
     # Downgrade enum into str for easier management later on
     quantize = None if quantize is None else quantize.value
-    server.serve(model_id, revision, sharded, quantize, trust_remote_code, uds_path)
+    dtype = None if dtype is None else dtype.value
+    if dtype is not None and quantize is not None:
+        raise RuntimeError(
+            "Only one of `dtype` and `quantize` can be set, as they both determine the dtype of the final model."
+        )
+    server.serve(
+        model_id, revision, sharded, quantize, dtype, trust_remote_code, uds_path
+    )
 
 
 @app.command()
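Taken together, the flow this patch sets up is: the launcher exposes `--dtype {float16,bfloat16}`, the Python CLI downgrades the enum to a plain string and refuses to combine it with `--quantize`, and `get_model` (next hunk) maps that string onto a `torch.dtype`, defaulting to `float16`. A minimal sketch of that validation and mapping, using a hypothetical `resolve_dtype` helper rather than the patched functions themselves:

```python
from typing import Optional

import torch

# Hypothetical helper mirroring the cli.py / models/__init__.py logic in this
# patch; it is not part of the patch itself.
_DTYPES = {
    None: torch.float16,       # default stays float16
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}


def resolve_dtype(dtype: Optional[str], quantize: Optional[str]) -> torch.dtype:
    # `--dtype` and `--quantize` are mutually exclusive: both decide the
    # final dtype of the weights.
    if dtype is not None and quantize is not None:
        raise RuntimeError("Only one of `dtype` and `quantize` can be set.")
    try:
        return _DTYPES[dtype]
    except KeyError:
        raise RuntimeError(f"Unknown dtype {dtype}") from None


# Example: `text-generation-launcher --dtype bfloat16` ends up calling
# get_model(..., dtype="bfloat16", ...), i.e.:
assert resolve_dtype("bfloat16", None) is torch.bfloat16
```

Each model constructor below then applies the forced dtype only on the GPU path (`torch.float16 if dtype is None else dtype`); the CPU fallbacks keep pinning `torch.float32`.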
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 2abde685..e45e198a 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -100,11 +100,25 @@ def get_model(
     revision: Optional[str],
     sharded: bool,
     quantize: Optional[str],
+    dtype: Optional[str],
     trust_remote_code: bool,
 ) -> Model:
+    if dtype is None:
+        dtype = torch.float16
+    elif dtype == "float16":
+        dtype = torch.float16
+    elif dtype == "bfloat16":
+        dtype = torch.bfloat16
+    else:
+        raise RuntimeError(f"Unknown dtype {dtype}")
+
     if "facebook/galactica" in model_id:
         return GalacticaSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     if model_id.startswith("bigcode/"):
@@ -113,6 +127,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -124,6 +139,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -138,6 +154,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -149,12 +166,17 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
     if model_type == "bloom":
         return BLOOMSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "gpt_neox":
@@ -163,6 +185,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -170,6 +193,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     else:
@@ -177,6 +201,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -186,6 +211,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     elif sharded:
@@ -195,6 +221,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -210,6 +237,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     raise NotImplementedError(
@@ -221,6 +249,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
     else:
@@ -228,12 +257,17 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "opt":
         return OPTSharded(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
         )
 
     elif model_type == "t5":
@@ -241,6 +275,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
@@ -253,11 +288,19 @@ def get_model(
 
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
-            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            
trust_remote_code=trust_remote_code, ) if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: return Seq2SeqLM( - model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, ) auto_map = config_dict.get("auto_map", None) @@ -267,6 +310,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) if "AutoModelForSeq2SeqLM" in auto_map.keys(): @@ -274,6 +318,7 @@ def get_model( model_id, revision, quantize=quantize, + dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 50b3b76a..101da207 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -42,12 +42,13 @@ class BLOOMSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ba0853f5..1ff4c514 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -454,11 +454,12 @@ class CausalLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/custom_modeling/flash_mpt_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mpt_modeling.py new file mode 100644 index 00000000..59c5ced0 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_mpt_modeling.py @@ -0,0 +1,361 @@ +"""A simple, flexible implementation of a GPT model. 
+ +Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py +""" +# import math +# import warnings +# from typing import List, Optional, Tuple, Union +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast +# from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +# from .attention import attn_bias_shape, build_attn_bias +# from .blocks import MPTBlock +# from .custom_embedding import SharedEmbedding +# from .norm import NORM_CLASS_REGISTRY +# from .configuration_mpt import MPTConfig +# from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising +# from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm +# from .meta_init_context import init_empty_weights +# from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_ +# try: +# from .flash_attn_triton import flash_attn_func +# except: +# pass + +"""GPT Blocks used for the GPT Model.""" +from typing import Dict, Optional, Tuple +import torch +import torch.nn as nn +import math + +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + TensorParallelHead, + FastLayerNorm, +) + +EPS = 1e-5 + +def _gen_slopes(n_heads, alibi_bias_max=8, device=None): + _n_heads = 2 ** math.ceil(math.log2(n_heads)) + m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) + m = m.mul(alibi_bias_max / _n_heads) + slopes = 1.0 / torch.pow(2, m) + if _n_heads != n_heads: + slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] + return slopes.view(1, n_heads, 1, 1) + +def _build_alibi_bias(n_heads, seq_len, device, dtype, alibi_bias_max): + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len) + slopes = _gen_slopes(n_heads, alibi_bias_max, device=device) + alibi_bias = alibi_bias * slopes + return alibi_bias.to(dtype=dtype) + +ALIBI = None + +def build_alibi_bias(n_heads, seq_len, device, dtype, alibi_bias_max=8): + global ALIBI + if ALIBI is None or seq_len > ALIBI.shape[-1]: + ALIBI = _build_alibi_bias(n_heads, seq_len, device, dtype, alibi_bias_max=alibi_bias_max) + return ALIBI[:, :, :, :seq_len] + + +class MPTAttention(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + + self.num_heads = config.n_heads + self.hidden_size = config.d_model + self.head_size = self.hidden_size // self.num_heads + self.Wqkv = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.Wqkv", + weights=weights, + bias=False, + ) + self.out_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.out_proj", + weights=weights, + bias=False, + ) + + def forward(self, + hidden_states, + alibi, + start_seq, + end_seq, + start_seq_q, + end_seq_q, + max_s, + past_key_values, + past_present_indices, + prefill, + ): + qkv = self.Wqkv(hidden_states) + qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + + # Todo + raise Exception("Apply alibi ?"); + + # Prefill + if prefill: + # Copy to layer past + layer_past[...] 
= qkv[:, 1:] + + # output + attn_output = torch.empty_like(qkv[:, 0]) + # flash attention + flash_attn_cuda.fwd( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + attn_output, + start_seq, + end_seq, + start_seq, + end_seq, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + False, + 0, + None, + ) + # Decode + else: + query = qkv[:, 0] + # Add present to the layer_past tensor at the correct indices + layer_past[past_present_indices] = qkv[:, 1:] + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + layer_past[:, 0], + layer_past[:, 1], + attn_output, + start_seq_q, + end_seq_q, + start_seq, + end_seq, + 1, + max_s, + 0.0, + self.softmax_scale, + False, + False, + False, + 0, + None, + ) + + return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) + +class MPTMLP(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + + self.up_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.up_proj", + weights=weights, + bias=False, + ) + self.act = nn.GELU(approximate='none') + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + def forward(self, x): + return self.down_proj(self.act(self.up_proj(x))) + +class MPTBlock(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.norm_1 = FastLayerNorm.load_no_bias(prefix=f"{prefix}.norm_1", weights=weights, eps=EPS) + self.attn = MPTAttention(config, prefix=f"{prefix}.attn", weights=weights) + self.norm_2 = FastLayerNorm.load_no_bias(prefix=f"{prefix}.norm_2", weights=weights, eps=EPS) + self.ffn = MPTMLP(config, prefix=f"{prefix}.ffn", weights=weights) + + def forward(self, + hidden_states, + residual, + alibi, + start_seq, + end_seq, + start_seq_q, + end_seq_q, + max_s, + past_key_values, + past_present_indices, + prefill, + ): + residual = hidden_states + hidden_states, _ = self.norm_1(hidden_states) + # (hidden_states, attn_weights) = self.attn( + hidden_states = self.attn( + hidden_states, + alibi, + start_seq, + end_seq, + start_seq_q, + end_seq_q, + max_s, + past_key_values, + past_present_indices, + prefill, + ) + hidden_states += residual + residual = hidden_states + hidden_states, _ = self.norm_2(hidden_states) + hidden_states = self.ffn(hidden_states) + hidden_states += residual + return (x, attn_weights) + +class MPTModel(nn.Module): + def __init__(self, config, weights): + super().__init__() + self.wte = TensorParallelEmbedding( + prefix="transformer.wte", weights=weights + ) + self.num_heads = config.n_heads + self.hidden_size = config.d_model + self.head_size = self.hidden_size // self.num_heads + self.blocks = nn.ModuleList([MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) for i in range(config.n_layers)]) + self.norm_f = FastLayerNorm.load_no_bias( + prefix="transformer.norm_f", weights=weights, eps=EPS + ) + + # Create a default sizeable global alibi + build_alibi_bias(n_heads=self.num_heads, seq_len=1024,device=weights.device, dtype = weights.dtype) + + def forward( + self, + input_ids, + position_ids, + start_seq, + end_seq, + start_seq_q, + end_seq_q, + max_s, + past_present_indices, + past_key_values=None, + pre_allocate_past_size: Optional[int] = None, + ): + hidden_states = self.wte(input_ids) + + + + # Prefill + if past_key_values is None: + assert pre_allocate_past_size is not None + + prefill = True + + # Create past tensor + # We create a tensor of the same size as input_ids as we don't 
want to slice at every layer + past_key_values = hidden_states.new_empty( + ( + len(input_ids), + len(self.blocks), + 2, + self.num_heads, + self.head_size, + ) + ) + # Decode + else: + prefill = False + + alibi = build_alibi_bias(n_heads=self.num_heads, seq_len=max_s,device=hidden_states.device, dtype = hidden_states.dtype) + # Cast alibi into correct shape + alibi = alibi[:, :, :, position_ids] + + residual = None + for i, layer in enumerate(self.blocks): + hidden_states, residual = layer( + hidden_states, + residual, + alibi, + start_seq, + end_seq, + start_seq_q, + end_seq_q, + max_s, + past_key_values[:, i], + past_present_indices, + prefill, + ) + + if prefill: + present = past_key_values + # Create padded past tensor + past_key_values = hidden_states.new_empty( + ( + pre_allocate_past_size, + len(self.blocks), + 2, + self.num_heads, + self.head_size, + ) + ) + # We slice only once instead of at every layer + past_key_values[past_present_indices] = present + + hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states, past_key_values + +class MPTForCausalLM(nn.Module): + def __init__(self, config, weights): + super().__init__() + self.transformer = MPTModel(config, weights) + self.lm_head = TensorParallelHead.load( + config, + prefix="transformer.wte", + weights=weights, + ) + + def forward( + self, + input_ids, + position_ids, + start_seq, + end_seq, + start_seq_q, + end_seq_q, + max_s, + past_present_indices, + past_key_values: Optional[torch.Tensor] = None, + pre_allocate_past_size: Optional[int] = None, + lm_head_indices: Optional[torch.Tensor] = None, + ): + hidden_states, present = self.transformer( + input_ids, + position_ids, + start_seq, + end_seq, + start_seq_q, + end_seq_q, + max_s, + past_present_indices, + past_key_values, + pre_allocate_past_size, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits = self.lm_head(hidden_states) + return logits, present diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index a80d58cb..a2cc0a82 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -25,12 +25,13 @@ class FlashLlama(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashLlama is only available on GPU") diff --git a/server/text_generation_server/models/flash_mpt.py b/server/text_generation_server/models/flash_mpt.py new file mode 100644 index 00000000..61c362fa --- /dev/null +++ b/server/text_generation_server/models/flash_mpt.py @@ -0,0 +1,73 @@ +import torch +import torch.distributed + +from opentelemetry import trace +from transformers import AutoConfig, AutoTokenizer, PretrainedConfig +from typing import Optional +from huggingface_hub import hf_hub_download +import json + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_mpt_modeling import ( + MPTForCausalLM, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + + +class MPTSharded(FlashCausalLM): 
+ def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 + else: + raise NotImplementedError("FlashMPT is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + filename = hf_hub_download(model_id, revision=revision, filename="config.json") + with open(filename, "r") as f: + config = json.load(f) + config = PretrainedConfig(**config) + config.quantize = quantize + # config = AutoConfig.from_pretrained( + # # model_id, revision=revision, trust_remote_code=trust_remote_code + # model_id, revision=revision, trust_remote_code=False + # ) + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + + config.quantize = quantize + model = MPTForCausalLM(config, weights) + + torch.distributed.barrier(group=self.process_group) + super(FlashCausalLM, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 4847571d..c19bd1da 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -24,12 +24,13 @@ class FlashNeoXSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashNeoX is only available on GPU") diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 5f963bfb..200d7fb8 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -25,12 +25,13 @@ class FlashRWSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashRW is only available on GPU") diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index a71c0061..654e9946 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -24,12 +24,13 @@ class FlashSantacoderSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() 
if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") @@ -52,8 +53,11 @@ class FlashSantacoderSharded(FlashCausalLM): torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group, - aliases = {"transformer.wte.weight": ["lm_head.weight"]} + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, ) model = FlashSantacoderForCausalLM(config, weights) diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 01e1c773..01e58bad 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -158,12 +158,13 @@ class GalacticaSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 0abf0239..91877fa0 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -24,12 +24,13 @@ class GPTNeoxSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 16cb48b7..d407b44a 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -22,12 +22,13 @@ class OPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index 2b1e4959..92bb135b 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -12,11 +12,12 @@ class RW(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: if quantize: raise 
ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index d0fd3070..a2b38737 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -19,11 +19,12 @@ class SantaCoder(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 3ad5698c..5c436d30 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -504,11 +504,12 @@ class Seq2SeqLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): if torch.cuda.is_available(): device = torch.device("cuda") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: if quantize: raise ValueError("quantization is not available on CPU") diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index c89462fc..1b7073af 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -25,12 +25,13 @@ class T5Sharded(Seq2SeqLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 + dtype = torch.float16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index e1bd8412..5d2702d0 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -99,6 +99,7 @@ def serve( revision: Optional[str], sharded: bool, quantize: Optional[str], + dtype: Optional[str], trust_remote_code: bool, uds_path: Path, ): @@ -107,6 +108,7 @@ def serve( revision: Optional[str], sharded: bool = False, quantize: Optional[str] = None, + dtype: Optional[str] = None, trust_remote_code: bool = False, ): unix_socket_template = "unix://{}-{}" @@ -121,7 +123,9 @@ def serve( server_urls = [local_url] try: - model = get_model(model_id, revision, sharded, quantize, trust_remote_code) + model = get_model( + model_id, revision, sharded, quantize, dtype, trust_remote_code + ) except Exception: logger.exception("Error when initializing model") raise
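For reference, the ALiBi machinery introduced in `flash_mpt_modeling.py` above (a global bias tensor built once and sliced per batch) boils down to a per-head slope applied to token distances. A self-contained sketch of the same construction, with the tensor-parallel and caching details stripped out and illustrative function names:

```python
import math

import torch


def gen_slopes(n_heads: int, alibi_bias_max: int = 8) -> torch.Tensor:
    # Same recipe as `_gen_slopes` in the new modeling file: power-of-two
    # geometric slopes, interleaved when n_heads is not a power of two.
    n = 2 ** math.ceil(math.log2(n_heads))
    m = torch.arange(1, n + 1, dtype=torch.float32) * (alibi_bias_max / n)
    slopes = 1.0 / torch.pow(2, m)
    if n != n_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)


def build_alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # Token distances 1 - seq_len .. 0, scaled per head:
    # result has shape (1, n_heads, 1, seq_len).
    distances = torch.arange(1 - seq_len, 1, dtype=torch.float32).view(1, 1, 1, seq_len)
    return distances * gen_slopes(n_heads)


# For 4 heads and a sequence of 5 tokens, head 0 gets the steepest slope (1/4)
# and the bias for the current token (last column) is always 0.
bias = build_alibi_bias(n_heads=4, seq_len=5)
print(bias[0, 0, 0])  # tensor([-1.0000, -0.7500, -0.5000, -0.2500,  0.0000])
```

Note that the patch leaves the application of this bias inside flash attention unresolved (the `raise Exception("Apply alibi ?")` placeholder in `MPTAttention.forward`), so `MPTSharded` should be treated as work in progress.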