From 63a18c141402117b633e34bd12ac566aae7de9ed Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Mon, 29 May 2023 11:56:19 +0200
Subject: [PATCH] feat(server): support RefinedWeb models

---
 .../text_generation_server/models/__init__.py |  31 ++
 .../custom_modeling/flash_llama_modeling.py   |  19 +-
 .../custom_modeling/flash_neox_modeling.py    |  19 +-
 .../custom_modeling/flash_rw_modeling.py      | 507 ++++++++++++++++++
 .../text_generation_server/models/flash_rw.py | 246 +++++++++
 server/text_generation_server/models/rw.py    |  80 +++
 server/text_generation_server/utils/layers.py |  13 +-
 7 files changed, 891 insertions(+), 24 deletions(-)
 create mode 100644 server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
 create mode 100644 server/text_generation_server/models/flash_rw.py
 create mode 100644 server/text_generation_server/models/rw.py

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index bf7a2849..50b5a83d 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -10,6 +10,7 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOM, BLOOMSharded
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
+from text_generation_server.models.rw import RW
 from text_generation_server.models.opt import OPT, OPTSharded
 from text_generation_server.models.galactica import Galactica, GalacticaSharded
 from text_generation_server.models.santacoder import SantaCoder
@@ -30,6 +31,7 @@ try:
             )
 
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
+    from text_generation_server.models.flash_rw import FlashRW
     from text_generation_server.models.flash_llama import (
         FlashLlama,
         FlashLlamaSharded,
@@ -68,6 +70,7 @@ __all__ = [
 if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
+    __all__.append(FlashRW)
     __all__.append(FlashSantacoder)
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
@@ -194,6 +197,34 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
+    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+        if sharded:
+            if FLASH_ATTENTION:
+                if config.alibi:
+                    raise NotImplementedError("sharded is not supported for this model")
+                # return FlashRWSharded(
+                #     model_id,
+                #     revision,
+                #     quantize=quantize,
+                #     trust_remote_code=trust_remote_code,
+                # )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb"))
+        else:
+            if FLASH_ATTENTION and not config.alibi:
+                return FlashRW(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
+            else:
+                return RW(
+                    model_id,
+                    revision,
+                    quantize=quantize,
+                    trust_remote_code=trust_remote_code,
+                )
+
     if model_type == "llama":
         if sharded:
             if FLASH_ATTENTION:
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 54670b79..2dcb6ed8 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -134,20 +134,23 @@ class FlashLlamaAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
-        qkv_rot = self.rotary_emb(qkv, cos, sin)
+
+        # Inplace rotary
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = qkv_rot[:, 1:]
+            layer_past[...] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(qkv_rot[:, 0])
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                qkv_rot[:, 0],
-                qkv_rot[:, 1],
-                qkv_rot[:, 2],
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -163,9 +166,9 @@ class FlashLlamaAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv_rot[:, 0]
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
+            layer_past[layer_past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index b7834157..26e21753 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -101,20 +101,23 @@ class FlashNeoxAttention(torch.nn.Module):
     ):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
-        qkv_rot = self.rotary_emb(qkv, cos, sin)
+
+        # Inplace rotary
+        self.rotary_emb(qkv[:, 0], cos, sin)
+        self.rotary_emb(qkv[:, 1], cos, sin)
 
         # Prefill
         if layer_past_present_indices is None:
             # Copy to layer past
-            layer_past[...] = qkv_rot[:, 1:]
+            layer_past[...] = qkv[:, 1:]
 
             # output
-            attn_output = torch.empty_like(qkv_rot[:, 0])
+            attn_output = torch.empty_like(qkv[:, 0])
             # flash attention
             flash_attn_cuda.fwd(
-                qkv_rot[:, 0],
-                qkv_rot[:, 1],
-                qkv_rot[:, 2],
+                qkv[:, 0],
+                qkv[:, 1],
+                qkv[:, 2],
                 attn_output,
                 cu_seqlens,
                 cu_seqlens,
@@ -130,9 +133,9 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         # Decode
         else:
-            query = qkv_rot[:, 0]
+            query = qkv[:, 0]
             # Add present to the layer_past tensor at the correct indices
-            layer_past[layer_past_present_indices] = qkv_rot[:, 1:]
+            layer_past[layer_past_present_indices] = qkv[:, 1:]
 
             # output
             attn_output = torch.empty_like(query)
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
new file mode 100644
index 00000000..2cacc518
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -0,0 +1,507 @@
+import torch
+import torch.distributed
+
+from loguru import logger
+from torch import nn
+from transformers.modeling_utils import PreTrainedModel
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional
+
+# Flash attention imports
+import flash_attn_cuda
+
+from text_generation_server.utils.layers import (
+    FastLinear,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
+    FastLayerNorm,
+    PositionRotaryEmbedding,
+)
+
+
+class RWConfig(PretrainedConfig):
+    attribute_map = {
+        "num_hidden_layers": "n_layer",
+        "num_attention_heads": "n_head",
+    }
+
+    def __init__(
+        self,
+        model_type="RefinedWeb",
+        vocab_size=250880,
+        hidden_size=64,
+        n_layer=2,
+        n_head=8,
+        layer_norm_epsilon=1e-5,
+        initializer_range=0.02,
+        use_cache=True,
+        bos_token_id=1,
+        eos_token_id=2,
+        hidden_dropout=0.0,
+        attention_dropout=0.0,
+        n_head_kv=None,
+        multi_query=False,
+        alibi=False,
+        bias=False,
+        parallel_attn=False,
+        **kwargs,
+    ):
+        if alibi:
+            raise NotImplementedError(
+                "alibi is not supported by this version of the model"
+            )
+
+        self.model_type = model_type
+        self.alibi = False
+        self.rotary = True
+
+        self.vocab_size = vocab_size
+        # Backward compatibility with n_embed kwarg
+        n_embed = kwargs.pop("n_embed", None)
+        self.hidden_size = hidden_size if n_embed is None else n_embed
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.initializer_range = initializer_range
+        self.use_cache = use_cache
+        self.hidden_dropout = hidden_dropout
+        self.attention_dropout = attention_dropout
+        self.bias = bias
+        self.parallel_attn = parallel_attn
+
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+        if n_head_kv is not None:
+            self.n_head_kv = n_head_kv
+        else:
+            self.n_head_kv = 1 if multi_query else n_head
+
+        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+class FlashRWAttention(torch.nn.Module):
+    def __init__(
+        self,
+        num_heads,
+        num_heads_kv,
+        hidden_size,
+        bias,
+        process_group=None,
+        reduce=True,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.num_heads_kv = num_heads_kv
+        self.hidden_size = hidden_size
+        self.head_size = hidden_size // num_heads
+
+        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
+        self.softmax_scale = self.head_size ** (-0.5)
+
+        if process_group is None:
+            self.query_key_value = FastLinear(
+                hidden_size,
+                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+            )
+            self.dense = FastLinear(hidden_size, hidden_size, bias=bias)
+        else:
+            self.num_heads = self.num_heads // process_group.size()
+            self.query_key_value = FastLinear(
+                hidden_size,
+                self.head_size * (self.num_heads + 2 * self.num_heads_kv),
+                bias=bias,
+            )
+            self.dense = TensorParallelRowLinear(
+                hidden_size,
+                hidden_size,
+                bias=bias,
+                process_group=process_group,
+                reduce=reduce,
+            )
+
+    def forward(
+        self,
+        hidden_states,
+        cos,
+        sin,
+        cu_seqlens,
+        max_s,
+        layer_past,
+        layer_past_present_indices,
+        cu_seqlens_q,
+    ):
+        qkv = self.query_key_value(hidden_states)
+
+        # Split query from key_value
+        query, kv = qkv.split(
+            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+        )
+
+        # Prepare query and key_value for indexing
+        query = query.view(-1, self.num_heads, self.head_size)
+        kv = kv.view(-1, 2, 1, self.head_size)
+
+        # Inplace rotary
+        self.rotary_emb(query, cos, sin)
+        self.rotary_emb(kv[:, 0], cos, sin)
+
+        # Prefill
+        if layer_past_present_indices is None:
+            # Copy to layer past
+            layer_past[...] = kv
+            # Expand to query shape
+            kv = kv.expand(-1, 2, query.shape[1], self.head_size)
+
+            # output
+            attn_output = torch.empty_like(query)
+            # flash attention
+            flash_attn_cuda.fwd(
+                query,
+                kv[:, 0],
+                kv[:, 1],
+                attn_output,
+                cu_seqlens,
+                cu_seqlens,
+                max_s,
+                max_s,
+                0.0,
+                self.softmax_scale,
+                False,
+                True,
+                False,
+                0,
+                None,
+            )
+        # Decode
+        else:
+            # Add present to the layer_past tensor at the correct indices
+            layer_past[layer_past_present_indices] = kv
+            # Expand to query shape
+            kv = layer_past.expand(-1, 2, query.shape[1], self.head_size)
+
+            # output
+            attn_output = torch.empty_like(query)
+            # flash attention
+            flash_attn_cuda.fwd(
+                query,
+                kv[:, 0],
+                kv[:, 1],
+                attn_output,
+                cu_seqlens_q,
+                cu_seqlens,
+                1,
+                max_s,
+                0.0,
+                self.softmax_scale,
+                False,
+                False,
+                False,
+                0,
+                None,
+            )
+
+        return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class FlashMLP(nn.Module):
+    def __init__(self, hidden_size, bias, process_group=None, reduce=True):
+        super().__init__()
+        self.act = torch.nn.functional.gelu
+
+        if process_group is None:
+            self.dense_h_to_4h = FastLinear(hidden_size, 4 * hidden_size, bias=bias)
+            self.dense_4h_to_h = FastLinear(4 * hidden_size, hidden_size, bias=bias)
+        else:
+            self.dense_h_to_4h = TensorParallelColumnLinear(
+                hidden_size,
+                4 * hidden_size,
+                bias=bias,
+                process_group=process_group,
+            )
+            self.dense_4h_to_h = TensorParallelRowLinear(
+                4 * hidden_size,
+                hidden_size,
+                bias=bias,
+                process_group=process_group,
+                reduce=reduce,
+            )
+        self.process_group = process_group
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense_h_to_4h(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.dense_4h_to_h(hidden_states)
+        return hidden_states
+
+
+class FlashRWLayer(nn.Module):
+    def __init__(
+        self,
+        num_heads,
+        num_heads_kv,
+        hidden_size,
+        bias,
+        layer_norm_eps,
+        parallel_attn,
+        process_group=None,
+    ):
+        super().__init__()
+
+        self.parallel_attn = parallel_attn
+
+        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.self_attention = FlashRWAttention(
+            num_heads,
+            num_heads_kv,
+            hidden_size,
+            bias,
+            process_group=process_group,
+            reduce=False,
+        )
+        self.post_attention_layernorm = (
+            FastLayerNorm(hidden_size, eps=layer_norm_eps)
+            if not parallel_attn
+            else None
+        )
+
+        self.mlp = FlashMLP(
+            hidden_size, bias, process_group=process_group, reduce=False
+        )
+
+        self.process_group = process_group
+
+    def forward(
+        self,
+        hidden_states,
+        residual,
+        cos,
+        sin,
+        cu_seqlens,
+        max_s,
+        layer_past,
+        layer_past_present_indices,
+        cu_seqlens_q,
+    ):
+        if self.parallel_attn:
+            ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            attn_output = self.self_attention(
+                ln_hidden_states,
+                cos,
+                sin,
+                cu_seqlens,
+                max_s,
+                layer_past,
+                layer_past_present_indices,
+                cu_seqlens_q,
+            )
+
+            mlp_output = self.mlp(ln_hidden_states)
+            intermediate = mlp_output + attn_output
+
+            # Only reduce once and after the addition instead of once per layer
+            if self.process_group is not None:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+            return intermediate, residual
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            hidden_states = self.self_attention(
+                hidden_states,
+                cos,
+                sin,
+                cu_seqlens,
+                max_s,
+                layer_past,
+                layer_past_present_indices,
+                cu_seqlens_q,
+            )
+
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
+
+            mlp_output = self.mlp(hidden_states)
+
+            return mlp_output, residual
+
+
+class FlashRWPreTrainedModel(PreTrainedModel):
+    config_class = RWConfig
+    supports_gradient_checkpointing = False
+    _no_split_modules = None
+
+
+class FlashRWModel(FlashRWPreTrainedModel):
+    def __init__(self, config, process_group=None):
+        super().__init__(config)
+        self.config = config
+
+        self.tp_embeddings = False
+        if process_group is not None:
+            self.tp_rank = process_group.rank()
+            self.tp_world_size = process_group.size()
+            if config.vocab_size % self.tp_world_size == 0:
+                self.tp_embeddings = True
+
+        if self.tp_embeddings:
+            self.word_embeddings = TensorParallelEmbedding(
+                config.vocab_size, config.hidden_size, process_group=process_group
+            )
+        else:
+            self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+
+        self.h = nn.ModuleList(
+            [
+                FlashRWLayer(
+                    config.n_head,
+                    config.n_head_kv,
+                    config.hidden_size,
+                    config.bias,
+                    config.layer_norm_epsilon,
+                    config.parallel_attn,
+                    process_group,
+                )
+                for _ in range(config.num_hidden_layers)
+            ]
+        )
+        self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+
+        self.gradient_checkpointing = False
+
+        self.head_size = self.h[0].self_attention.head_size
+        self.num_heads_kv = self.h[0].self_attention.num_heads_kv
+
+    def post_load_weights(self, quantize: Optional[str] = None):
+        if isinstance(self.word_embeddings, TensorParallelEmbedding):
+            self.word_embeddings.add_null_idx()
+        for layer in self.h:
+            layer: FlashRWLayer
+            layer.self_attention.query_key_value.prepare_weights(quantize)
+            layer.self_attention.dense.prepare_weights(quantize)
+            layer.mlp.dense_h_to_4h.prepare_weights(quantize)
+            layer.mlp.dense_4h_to_h.prepare_weights(quantize)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # Pop here as we will replace the layer in our own logic and don't want
+        # from_pretrained to do it for us
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
+        model = super(FlashRWModel, cls).from_pretrained(
+            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
+        )
+
+        model.post_load_weights("bitsandbytes" if load_in_8bit else None)
+        return model
+
+    def forward(
+        self,
+        input_ids,
+        position_ids,
+        cu_seqlens,
+        cu_seqlens_q,
+        max_s,
+        past_key_values=None,
+        pre_allocate_past_size: Optional[int] = None,
+    ):
+        hidden_states = self.word_embeddings(input_ids)
+
+        # Prefill
+        if past_key_values is None:
+            # Create past tensor
+            past_key_values = hidden_states.new_empty(
+                (
+                    len(self.h),
+                    len(hidden_states)
+                    if pre_allocate_past_size is None
+                    else pre_allocate_past_size,
+                    2,
+                    self.num_heads_kv,
+                    self.head_size,
+                )
+            )
+            layer_past_present_indices = None
+            slice_past_index = len(hidden_states)
+        # Decode
+        else:
+            # Create indices from cumulative sequence lengths
+            layer_past_present_indices = cu_seqlens[1:] - 1
+            slice_past_index = None
+
+        # Get rotary cos and sin for this forward
+        # Avoid indexing in each layer
+        cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(
+            position_ids, max_s, hidden_states.dtype
+        )
+
+        residual = None
+        for i, layer in enumerate(self.h):
+            # We added padding that we now need to slice
+            layer_past_key_values = (
+                past_key_values[i]
+                if slice_past_index is None
+                else past_key_values[i, :slice_past_index]
+            )
+
+            hidden_states, residual = layer(
+                hidden_states,
+                residual,
+                cos,
+                sin,
+                cu_seqlens,
+                max_s,
+                layer_past_key_values,
+                layer_past_present_indices,
+                cu_seqlens_q,
+            )
+
+        hidden_states, _ = self.ln_f(hidden_states, residual)
+
+        return hidden_states, past_key_values
+
+
+class FlashRWForCausalLM(FlashRWPreTrainedModel):
+    def __init__(self, config, process_group=None):
+        super().__init__(config)
+
+        self.process_group = process_group
+        if self.process_group is not None:
+            self.world_size = self.process_group.size()
+        else:
+            self.world_size = 1
+
+        self.transformer = FlashRWModel(config, process_group)
+
+        if self.transformer.tp_embeddings:
+            self.lm_head = FastLinear(
+                config.hidden_size,
+                config.vocab_size // process_group.size(),
+                bias=False,
+            )
+        else:
+            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
+
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.transformer.post_load_weights(quantize)
+        self.lm_head.prepare_weights()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # Pop here as we will replace the layer in our own logic and don't want
+        # from_pretrained to do it for us
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
+        model = super(FlashRWForCausalLM, cls).from_pretrained(
+            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
+        )
+        model.post_load_weights("bitsandbytes" if load_in_8bit else None)
+        return model
+
+    def forward(
+        self,
+        input_ids,
+        position_ids,
+        cu_seqlens,
+        cu_seqlens_q,
+        max_s,
+        past_key_values: Optional[torch.Tensor] = None,
+        pre_allocate_past_size: Optional[int] = None,
+    ):
+        hidden_states, present = self.transformer(
+            input_ids,
+            position_ids,
+            cu_seqlens,
+            cu_seqlens_q,
+            max_s,
+            past_key_values,
+            pre_allocate_past_size,
+        )
+        logits = self.lm_head(hidden_states)
+
+        if self.transformer.tp_embeddings:
+            # Logits are sharded, so we need to gather them
+            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
+            world_logits = torch.cat(world_logits, dim=1)
+
+            return world_logits, present
+        return logits, present
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
new file mode 100644
index 00000000..8e38c847
--- /dev/null
+++ b/server/text_generation_server/models/flash_rw.py
@@ -0,0 +1,246 @@
+import torch
+import torch.distributed
+
+from pathlib import Path
+from accelerate import init_empty_weights
+from opentelemetry import trace
+from safetensors import safe_open
+from transformers import AutoTokenizer, AutoConfig
+from typing import Optional, List
+
+from text_generation_server.models import FlashCausalLM
+from text_generation_server.models.custom_modeling.flash_rw_modeling import (
+    RWConfig,
+    FlashRWForCausalLM,
+    TensorParallelEmbedding,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+)
+from text_generation_server.utils import (
+    initialize_torch_distributed,
+    weight_files,
+    download_weights,
+    weight_hub_files,
+    LocalEntryNotFoundError,
+)
+
+tracer = trace.get_tracer(__name__)
+
+
+class FlashRW(FlashCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
+    ):
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            dtype = torch.float16
+        else:
+            raise NotImplementedError("RW is only available on GPU")
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+
+        config = RWConfig.from_pretrained(
+            model_id,
+            revision=revision,
+        )
+
+        # We do not use from_pretrained as we modified the model's internal module layout
+        try:
+            filenames = weight_files(model_id, revision, ".bin")
+        # Local files not found
+        except LocalEntryNotFoundError:
+            hub_files = weight_hub_files(model_id, revision, ".bin")
+            filenames = download_weights(hub_files, model_id, revision)
+
+        with init_empty_weights():
+            model = FlashRWForCausalLM(config)
+
+        self.load_weights(
+            model,
+            filenames,
+            quantize,
+            device,
+            dtype,
+        )
+
+        super(FlashCausalLM, self).__init__(
+            model=model.to(device),
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+        )
+
+    @staticmethod
+    def load_weights(
+        model: FlashRWForCausalLM,
+        filenames: List[Path],
+        quantize: Optional[str],
+        device: torch.device,
+        dtype: torch.dtype,
+    ):
+        for filename in filenames:
+            state_dict = torch.load(filename, map_location="cpu")
+            for key, value in state_dict.items():
+                value = value.to(device if quantize is None else "cpu").to(dtype)
+
+                module_name, param_name = key.rsplit(".", 1)
+                module = model.get_submodule(module_name)
+
+                try:
+                    current_parameter_tensor = module._parameters[param_name]
+                    if current_parameter_tensor.shape != value.shape:
+                        raise ValueError(
+                            f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
+                        )
+                    module._parameters[param_name] = value
+                except KeyError:
+                    module._buffers[param_name] = value
+
+                del value
+
+        torch.cuda.empty_cache()
+        model.post_load_weights(quantize)
+
+
+#
+# class FlashNeoXSharded(FlashNeoX):
+#     def __init__(
+#         self,
+#         model_id: str,
+#         revision: Optional[str] = None,
+#         quantize: Optional[str] = None,
+#         trust_remote_code: bool = False,
+#     ):
+#         self.process_group, rank, world_size = initialize_torch_distributed()
+#         if torch.cuda.is_available():
+#             device = torch.device(f"cuda:{rank}")
+#             dtype = torch.float16
+#         else:
+#             raise NotImplementedError("FlashNeoX is only available on GPU")
+#
+#         tokenizer = AutoTokenizer.from_pretrained(
+#             model_id,
+#             revision=revision,
+#             padding_side="left",
+#             truncation_side="left",
+#             trust_remote_code=trust_remote_code,
+#         )
+#
+#         config = AutoConfig.from_pretrained(
+#             model_id, revision=revision, trust_remote_code=trust_remote_code
+#         )
+#
+#         torch.distributed.barrier(group=self.process_group)
+#         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+#
+#         with init_empty_weights():
+#             model = FlashGPTNeoXForCausalLM(config, self.process_group)
+#
+#         torch.distributed.barrier(group=self.process_group)
+#         self.load_weights(
+#             model,
+#             filenames,
+#             quantize=quantize,
+#             device=device,
+#             dtype=dtype,
+#             rank=rank,
+#             world_size=world_size,
+#         )
+#         torch.distributed.barrier(group=self.process_group)
+#         super(FlashCausalLM, self).__init__(
+#             model=model.to(device),
+#             tokenizer=tokenizer,
+#             requires_padding=False,
+#             dtype=dtype,
+#             device=device,
+#             rank=rank,
+#             world_size=world_size,
+#         )
+#
+#     @staticmethod
+#     def load_weights(
+#         model,
+#         filenames: List[str],
+#         quantize: Optional[str],
+#         device: torch.device,
+#         dtype: torch.dtype,
+#         rank: int,
+#         world_size: int,
+#     ):
+#         parameters = dict(model.named_parameters())
+#         for file in filenames:
+#             with safe_open(
+#                 file, framework="pt", device=str(device) if quantize is None else "cpu"
+#             ) as f:
+#                 for name in f.keys():
+#                     module_name, param_name = name.rsplit(".", 1)
+#                     module = model.get_submodule(module_name)
+#
+#                     current_parameter_tensor = parameters.get(name, None)
+#
+#                     slice_ = f.get_slice(name)
+#
+#                     if isinstance(module, TensorParallelColumnLinear):
+#                         size = slice_.get_shape()[0]
+#                         block_size = size // world_size
+#                         start = rank * block_size
+#                         stop = (rank + 1) * block_size
+#                         tensor = slice_[start:stop]
+#                     elif isinstance(module, TensorParallelRowLinear):
+#                         if param_name == "weight":
+#                             size = slice_.get_shape()[1]
+#                             block_size = size // world_size
+#                             start = rank * block_size
+#                             stop = (rank + 1) * block_size
+#                             tensor = slice_[:, start:stop]
+#                         else:
+#                             tensor = slice_[:]
+#                             # XXX: Hack for Rowlinear to add the bias only once.
+#                             if rank != 0:
+#                                 tensor = torch.zeros_like(tensor)
+#                     elif isinstance(module, TensorParallelEmbedding):
+#                         size = slice_.get_shape()[0]
+#                         block_size = size // world_size
+#                         start = rank * block_size
+#                         stop = (rank + 1) * block_size
+#                         tensor = slice_[start:stop]
+#                     elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
+#                         size = slice_.get_shape()[0]
+#                         block_size = size // world_size
+#                         start = rank * block_size
+#                         stop = (rank + 1) * block_size
+#                         tensor = slice_[start:stop]
+#                     else:
+#                         try:
+#                             tensor = slice_[:]
+#                         except:
+#                             tensor = f.get_tensor(name)
+#
+#                     if (
+#                         current_parameter_tensor is not None
+#                         and current_parameter_tensor.shape != tensor.shape
+#                     ):
+#                         raise ValueError(
+#                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+#                         )
+#
+#                     tensor = tensor.contiguous().to(dtype)
+#
+#                     if current_parameter_tensor is not None:
+#                         module._parameters[param_name] = tensor
+#                     else:
+#                         module._buffers[param_name] = tensor
+#
+#         model.post_load_weights(quantize)
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
new file mode 100644
index 00000000..6500ac37
--- /dev/null
+++ b/server/text_generation_server/models/rw.py
@@ -0,0 +1,80 @@
+import torch
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from typing import List, Optional, Tuple
+
+from text_generation_server.models import CausalLM
+
+
+class RW(CausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        trust_remote_code: bool = False,
+    ):
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            dtype = torch.bfloat16
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+        )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()
+
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        super(CausalLM, self).__init__(
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+        )
+
+    def forward(
+        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
+    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        # Model Forward
+        if past_key_values is not None:
+            reshaped_past_key_values = []
+            for layer in past_key_values:
+                past_keys, past_values = layer
+                reshaped_past_key_values.append(
+                    (
+                        past_keys.view(-1, *past_keys.shape[-2:]),
+                        past_values.view(-1, *past_values.shape[-2:]),
+                    )
+                )
+            past_key_values = reshaped_past_key_values
+
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+        )
+        return outputs.logits, outputs.past_key_values
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 7605639d..127f9ba4 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -262,16 +262,13 @@ try:
             sin = torch.index_select(self._sin_cached, 0, position_ids)
             return cos.unsqueeze(1), sin.unsqueeze(1)
 
-        def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+        def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
             rotary_dim = cos.shape[-1]
-            q1 = qkv[:, 0, :, :rotary_dim]
-            q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
-            k1 = qkv[:, 1, :, :rotary_dim]
-            k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]
+            x1 = x[..., :rotary_dim]
+            x2 = x[..., rotary_dim : 2 * rotary_dim]
 
-            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
-            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-            return qkv
+            rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
+            return x
 
 except ImportError:
     pass