From cbffddcc069cb23caf37331f6e6cfcdb00070055 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Tue, 30 May 2023 10:41:10 +0200 Subject: [PATCH] wip --- server/tests/models/test_bloom.py | 4 +- server/tests/models/test_causal_lm.py | 4 +- server/tests/models/test_seq2seq_lm.py | 4 +- .../text_generation_server/models/__init__.py | 19 +- .../custom_modeling/flash_rw_modeling.py | 446 ++++++++++++++---- .../text_generation_server/models/flash_rw.py | 263 +++++------ server/text_generation_server/models/rw.py | 24 +- 7 files changed, 509 insertions(+), 255 deletions(-) diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 105b3573..4c9a659c 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -286,7 +286,9 @@ def test_batch_concatenate( == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id]) + next_batch = next_batch.filter( + [next_batch.requests[0].id, next_batch.requests[1].id] + ) for _ in range( default_bloom_batch.stopping_criterias[0].max_new_tokens diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index d8d1bd16..cc5af070 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -285,7 +285,9 @@ def test_batch_concatenate( == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens ) - next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id]) + next_batch = next_batch.filter( + [next_batch.requests[0].id, next_batch.requests[1].id] + ) for _ in range( default_causal_lm_batch.stopping_criterias[0].max_new_tokens diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 8fdeee60..a98f6fb6 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -323,7 +323,9 @@ def test_batch_concatenate( ) assert generations[2].generated_text.generated_tokens == 5 - next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id]) + next_batch = next_batch.filter( + [next_batch.requests[0].id, next_batch.requests[1].id] + ) generations, next_batch = default_seq2seq_lm.generate_token(next_batch) assert next_batch is not None diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 50b5a83d..7db63a5c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -31,7 +31,7 @@ try: ) from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded - from text_generation_server.models.flash_rw import FlashRW + from text_generation_server.models.flash_rw import FlashRW, FlashRWSharded from text_generation_server.models.flash_llama import ( FlashLlama, FlashLlamaSharded, @@ -71,6 +71,7 @@ if FLASH_ATTENTION: __all__.append(FlashNeoX) __all__.append(FlashNeoXSharded) __all__.append(FlashRW) + __all__.append(FlashRWSharded) __all__.append(FlashSantacoder) __all__.append(FlashSantacoderSharded) __all__.append(FlashLlama) @@ -202,13 +203,15 @@ def get_model( if FLASH_ATTENTION: if config.alibi: raise NotImplementedError("sharded is not supported for this model") - # return FlashRWSharded( - # model_id, - # revision, - # quantize=quantize, - # trust_remote_code=trust_remote_code, - # ) - raise 
NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")) + return FlashRWSharded( + model_id, + revision, + quantize=quantize, + trust_remote_code=trust_remote_code, + ) + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb") + ) else: if FLASH_ATTENTION and not config.alibi: return FlashRW( diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index c13ffa7c..41be97fc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -27,28 +27,30 @@ class RWConfig(PretrainedConfig): } def __init__( - self, - model_type="RefinedWeb", - vocab_size=250880, - hidden_size=64, - n_layer=2, - n_head=8, - layer_norm_epsilon=1e-5, - initializer_range=0.02, - use_cache=True, - bos_token_id=1, - eos_token_id=2, - hidden_dropout=0.0, - attention_dropout=0.0, - n_head_kv=None, - multi_query=False, - alibi=False, - bias=False, - parallel_attn=False, - **kwargs, + self, + model_type="RefinedWeb", + vocab_size=250880, + hidden_size=64, + n_layer=2, + n_head=8, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + hidden_dropout=0.0, + attention_dropout=0.0, + n_head_kv=None, + multi_query=False, + alibi=False, + bias=False, + parallel_attn=False, + **kwargs, ): if alibi: - raise NotImplementedError("alibi is not supported by this version of the model") + raise NotImplementedError( + "alibi is not supported by this version of the model" + ) self.model_type = model_type self.alibi = False @@ -81,13 +83,13 @@ class RWConfig(PretrainedConfig): class FlashRWAttention(torch.nn.Module): def __init__( - self, - num_heads, - num_heads_kv, - hidden_size, - bias, - process_group=None, - reduce=True, + self, + num_heads, + num_heads_kv, + hidden_size, + bias, + process_group=None, + reduce=True, ): super().__init__() self.num_heads = num_heads @@ -99,33 +101,44 @@ class FlashRWAttention(torch.nn.Module): self.softmax_scale = self.head_size ** (-0.5) if process_group is None: - self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv), - bias=bias) + self.query_key_value = FastLinear( + hidden_size, + self.head_size * (self.num_heads + 2 * self.num_heads_kv), + bias=bias, + ) self.dense = FastLinear(hidden_size, hidden_size, bias=bias) else: self.num_heads = self.num_heads // process_group.size() - self.query_key_value = FastLinear(hidden_size, self.head_size * (self.num_heads + 2 * self.num_heads_kv), - bias=bias) + self.query_key_value = FastLinear( + hidden_size, + self.head_size * (self.num_heads + 2 * self.num_heads_kv), + bias=bias, + ) self.dense = TensorParallelRowLinear( - hidden_size, hidden_size, bias=bias, process_group=process_group, reduce=reduce + hidden_size, + hidden_size, + bias=bias, + process_group=process_group, + reduce=reduce, ) def forward( - self, - hidden_states, - cos, - sin, - cu_seqlens, - max_s, - layer_past, - layer_past_present_indices, - cu_seqlens_q, + self, + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, ): qkv = self.query_key_value(hidden_states) # Split query from key_value query, kv = qkv.split( - [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], dim=1 + [self.head_size * self.num_heads, 2 * self.head_size * self.num_heads_kv], + 
dim=1, ) # Prepare query and key_value for indexing @@ -194,11 +207,149 @@ class FlashRWAttention(torch.nn.Module): return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) -class FlashMLP(nn.Module): +class FlashRWLargeAttention(torch.nn.Module): def __init__( - self, hidden_size, bias, process_group=None, reduce=True + self, + num_heads, + num_heads_kv, + hidden_size, + bias, + process_group=None, + reduce=True, ): super().__init__() + + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + + self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000) + self.softmax_scale = self.head_size ** (-0.5) + + self.num_groups = num_heads // (num_heads_kv * 2) + self.num_heads = num_heads // self.num_groups + self.num_heads_kv = num_heads_kv // self.num_groups + + if process_group is None: + self.query_key_value = FastLinear( + hidden_size, + self.num_groups * + self.head_size + * (self.num_heads + 2 * self.num_heads_kv), + bias=bias, + ) + self.dense = FastLinear(hidden_size, hidden_size, bias=bias) + else: + self.query_key_value = TensorParallelColumnLinear( + hidden_size, + self.num_groups + * self.head_size + * (self.num_heads + 2 * self.num_heads_kv), + bias=bias, + process_group=process_group, + ) + self.dense = TensorParallelRowLinear( + hidden_size, + hidden_size, + bias=bias, + process_group=process_group, + reduce=reduce, + ) + + self.num_groups = self.num_groups // process_group.size() + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + qkv = self.query_key_value(hidden_states) + qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) + + # Split query from key_value + query, kv = qkv.split( + [self.num_heads, 2], + dim=2, + ) + + # Prepare query and key_value for indexing + query = query.reshape(-1, self.num_groups * self.num_heads, self.head_size) + kv = kv.transpose(1, 2) + + # Inplace rotary + self.rotary_emb(query, cos, sin) + self.rotary_emb(kv[:, 0], cos, sin) + + # Prefill + if layer_past_present_indices is None: + # Copy to layer past + layer_past[...] 
= kv + k, v = kv.split(1, dim=1) + # Expand to query shape + k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size) + v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size) + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + k, + v, + attn_output, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + self.softmax_scale, + False, + True, + False, + 0, + None, + ) + # Decode + else: + # Add present to the layer_past tensor at the correct indices + layer_past[layer_past_present_indices] = kv + k, v = layer_past.split(1, dim=1) + # Expand to query shape + k = k.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size) + v = v.expand(-1, self.num_heads, self.num_groups, self.head_size).reshape(-1, self.num_groups * self.num_heads, self.head_size) + + # output + attn_output = torch.empty_like(query) + # flash attention + flash_attn_cuda.fwd( + query, + k, + v, + attn_output, + cu_seqlens_q, + cu_seqlens, + 1, + max_s, + 0.0, + self.softmax_scale, + False, + False, + False, + 0, + None, + ) + + return self.dense(attn_output.view(-1, self.num_heads * self.num_groups * self.head_size)) + + +class FlashMLP(nn.Module): + def __init__(self, hidden_size, bias, process_group=None, reduce=True): + super().__init__() self.act = torch.nn.functional.gelu if process_group is None: @@ -207,12 +358,14 @@ class FlashMLP(nn.Module): else: self.dense_h_to_4h = TensorParallelColumnLinear( hidden_size, - 4 * hidden_size, bias=bias, + 4 * hidden_size, + bias=bias, process_group=process_group, ) self.dense_4h_to_h = TensorParallelRowLinear( 4 * hidden_size, - hidden_size, bias=bias, + hidden_size, + bias=bias, process_group=process_group, reduce=reduce, ) @@ -227,37 +380,44 @@ class FlashMLP(nn.Module): class FlashRWLayer(nn.Module): def __init__( - self, - num_heads, - num_heads_kv, - hidden_size, - layer_norm_eps, - parallel_attn, - process_group=None, + self, + num_heads, + num_heads_kv, + hidden_size, + bias, + layer_norm_eps, + parallel_attn, + process_group=None, ): super().__init__() self.parallel_attn = parallel_attn self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) - self.self_attention = FlashRWAttention(num_heads, num_heads_kv, hidden_size, process_group, reduce=False) - self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) if not parallel_attn else None + self.self_attention = FlashRWAttention( + num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False + ) + self.post_attention_layernorm = ( + FastLayerNorm(hidden_size, eps=layer_norm_eps) + if not parallel_attn + else None + ) - self.mlp = FlashMLP(hidden_size, process_group, reduce=False) + self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False) self.process_group = process_group def forward( - self, - hidden_states, - residual, - cos, - sin, - cu_seqlens, - max_s, - layer_past, - layer_past_present_indices, - cu_seqlens_q, + self, + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, ): if self.parallel_attn: ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -303,6 +463,68 @@ class FlashRWLayer(nn.Module): return mlp_output, residual + +class FlashRWLargeLayer(nn.Module): + def __init__( + self, 
+ num_heads, + num_heads_kv, + hidden_size, + bias, + layer_norm_eps, + process_group=None, + ): + super().__init__() + self.ln_attn = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.ln_mlp = FastLayerNorm(hidden_size, eps=layer_norm_eps) + + self.self_attention = FlashRWLargeAttention( + num_heads, num_heads_kv, hidden_size, bias, process_group=process_group, reduce=False + ) + + self.mlp = FlashMLP(hidden_size, bias, process_group=process_group, reduce=False) + + self.process_group = process_group + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ): + ln_attn, residual = self.ln_attn(hidden_states, residual) + ln_mlp, _ = self.ln_mlp(hidden_states, residual) + + # Self attention. + attn_output = self.self_attention( + ln_attn, + cos, + sin, + cu_seqlens, + max_s, + layer_past, + layer_past_present_indices, + cu_seqlens_q, + ) + + # MLP. + mlp_output = self.mlp(ln_mlp) + + intermediate = attn_output + mlp_output + + # Only reduce once and after the addition instead of once per layer + if self.process_group is not None: + torch.distributed.all_reduce(intermediate, group=self.process_group) + + return intermediate, residual + + class FlashRWPreTrainedModel(PreTrainedModel): config_class = RWConfig supports_gradient_checkpointing = False @@ -328,27 +550,47 @@ class FlashRWModel(FlashRWPreTrainedModel): else: self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.h = nn.ModuleList( - [ - FlashRWLayer( - config.n_head, - config.n_head_kv, - config.hidden_size, - config.layer_norm_epsilon, - config.parallel_attn, - process_group, - ) - for _ in range(config.num_hidden_layers) - ] - ) - self.ln_f = FastLayerNorm( - config.hidden_size, eps=config.layer_norm_epsilon - ) + if config.model_type == "RefinedWebModel": + self.h = nn.ModuleList( + [ + FlashRWLayer( + config.n_head, + config.n_head_kv, + config.hidden_size, + config.bias, + config.layer_norm_epsilon, + config.parallel_attn, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.kv_size = self.h[0].self_attention.num_heads_kv + elif config.model_type == "RefinedWeb": + self.h = nn.ModuleList( + [ + FlashRWLargeLayer( + config.n_head, + config.n_head_kv, + config.hidden_size, + config.bias, + config.layer_norm_epsilon, + process_group, + ) + for _ in range(config.num_hidden_layers) + ] + ) + self.kv_size = self.h[0].self_attention.num_groups + else: + raise NotImplementedError( + f"model_type {config.model_type} is not supported." 
+ ) + + self.ln_f = FastLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.gradient_checkpointing = False self.head_size = self.h[0].self_attention.head_size - self.num_heads_kv = self.h[0].self_attention.num_heads_kv def post_load_weights(self, quantize: Optional[str] = None): if isinstance(self.word_embeddings, TensorParallelEmbedding): @@ -373,14 +615,14 @@ class FlashRWModel(FlashRWPreTrainedModel): return model def forward( - self, - input_ids, - position_ids, - cu_seqlens, - cu_seqlens_q, - max_s, - past_key_values=None, - pre_allocate_past_size: Optional[int] = None, + self, + input_ids, + position_ids, + cu_seqlens, + cu_seqlens_q, + max_s, + past_key_values=None, + pre_allocate_past_size: Optional[int] = None, ): hidden_states = self.word_embeddings(input_ids) @@ -394,7 +636,7 @@ class FlashRWModel(FlashRWPreTrainedModel): if pre_allocate_past_size is None else pre_allocate_past_size, 2, - self.num_heads_kv, + self.kv_size, self.head_size, ) ) @@ -457,9 +699,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): bias=False, ) else: - self.lm_head = FastLinear( - config.hidden_size, config.vocab_size, bias=False - ) + self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) def post_load_weights(self, quantize: Optional[str] = None): self.transformer.post_load_weights(quantize) @@ -477,14 +717,14 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): return model def forward( - self, - input_ids, - position_ids, - cu_seqlens, - cu_seqlens_q, - max_s, - past_key_values: Optional[torch.Tensor] = None, - pre_allocate_past_size: Optional[int] = None, + self, + input_ids, + position_ids, + cu_seqlens, + cu_seqlens_q, + max_s, + past_key_values: Optional[torch.Tensor] = None, + pre_allocate_past_size: Optional[int] = None, ): hidden_states, present = self.transformer( input_ids, diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 8e38c847..94efc833 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -53,8 +53,6 @@ class FlashRW(FlashCausalLM): model_id, revision=revision, ) - from loguru import logger - logger.error(config.model_type) # We do not use from_pretrained as we modified the model internal module layout try: @@ -114,133 +112,134 @@ class FlashRW(FlashCausalLM): torch.cuda.empty_cache() model.post_load_weights(quantize) -# -# class FlashNeoXSharded(FlashNeoX): -# def __init__( -# self, -# model_id: str, -# revision: Optional[str] = None, -# quantize: Optional[str] = None, -# trust_remote_code: bool = False, -# ): -# self.process_group, rank, world_size = initialize_torch_distributed() -# if torch.cuda.is_available(): -# device = torch.device(f"cuda:{rank}") -# dtype = torch.float16 -# else: -# raise NotImplementedError("FlashNeoX is only available on GPU") -# -# tokenizer = AutoTokenizer.from_pretrained( -# model_id, -# revision=revision, -# padding_side="left", -# truncation_side="left", -# trust_remote_code=trust_remote_code, -# ) -# -# config = AutoConfig.from_pretrained( -# model_id, revision=revision, trust_remote_code=trust_remote_code -# ) -# -# torch.distributed.barrier(group=self.process_group) -# filenames = weight_files(model_id, revision=revision, extension=".safetensors") -# -# with init_empty_weights(): -# model = FlashGPTNeoXForCausalLM(config, self.process_group) -# -# torch.distributed.barrier(group=self.process_group) -# self.load_weights( -# model, -# filenames, -# quantize=quantize, -# 
device=device, -# dtype=dtype, -# rank=rank, -# world_size=world_size, -# ) -# torch.distributed.barrier(group=self.process_group) -# super(FlashCausalLM, self).__init__( -# model=model.to(device), -# tokenizer=tokenizer, -# requires_padding=False, -# dtype=dtype, -# device=device, -# rank=rank, -# world_size=world_size, -# ) -# -# @staticmethod -# def load_weights( -# model, -# filenames: List[str], -# quantize: Optional[str], -# device: torch.device, -# dtype: torch.dtype, -# rank: int, -# world_size: int, -# ): -# parameters = dict(model.named_parameters()) -# for file in filenames: -# with safe_open( -# file, framework="pt", device=str(device) if quantize is None else "cpu" -# ) as f: -# for name in f.keys(): -# module_name, param_name = name.rsplit(".", 1) -# module = model.get_submodule(module_name) -# -# current_parameter_tensor = parameters.get(name, None) -# -# slice_ = f.get_slice(name) -# -# if isinstance(module, TensorParallelColumnLinear): -# size = slice_.get_shape()[0] -# block_size = size // world_size -# start = rank * block_size -# stop = (rank + 1) * block_size -# tensor = slice_[start:stop] -# elif isinstance(module, TensorParallelRowLinear): -# if param_name == "weight": -# size = slice_.get_shape()[1] -# block_size = size // world_size -# start = rank * block_size -# stop = (rank + 1) * block_size -# tensor = slice_[:, start:stop] -# else: -# tensor = slice_[:] -# # XXX: Hack for Rowlinear to add the bias only once. -# if rank != 0: -# tensor = torch.zeros_like(tensor) -# elif isinstance(module, TensorParallelEmbedding): -# size = slice_.get_shape()[0] -# block_size = size // world_size -# start = rank * block_size -# stop = (rank + 1) * block_size -# tensor = slice_[start:stop] -# elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings: -# size = slice_.get_shape()[0] -# block_size = size // world_size -# start = rank * block_size -# stop = (rank + 1) * block_size -# tensor = slice_[start:stop] -# else: -# try: -# tensor = slice_[:] -# except: -# tensor = f.get_tensor(name) -# -# if ( -# current_parameter_tensor is not None -# and current_parameter_tensor.shape != tensor.shape -# ): -# raise ValueError( -# f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" -# ) -# -# tensor = tensor.contiguous().to(dtype) -# -# if current_parameter_tensor is not None: -# module._parameters[param_name] = tensor -# else: -# module._buffers[param_name] = tensor -# -# model.post_load_weights(quantize) + + +class FlashRWSharded(FlashRW): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 + else: + raise NotImplementedError("FlashRW is only available on GPU") + + tokenizer = AutoTokenizer.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = RWConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + + with init_empty_weights(): + model = FlashRWForCausalLM(config, self.process_group) + + torch.distributed.barrier(group=self.process_group) + self.load_weights( + model, + filenames, + quantize=quantize, + 
device=device, + dtype=dtype, + rank=rank, + world_size=world_size, + ) + torch.distributed.barrier(group=self.process_group) + super(FlashCausalLM, self).__init__( + model=model.to(device), + tokenizer=tokenizer, + requires_padding=False, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @staticmethod + def load_weights( + model, + filenames: List[str], + quantize: Optional[str], + device: torch.device, + dtype: torch.dtype, + rank: int, + world_size: int, + ): + parameters = dict(model.named_parameters()) + for file in filenames: + with safe_open( + file, framework="pt", device=str(device) if quantize is None else "cpu" + ) as f: + for name in f.keys(): + module_name, param_name = name.rsplit(".", 1) + module = model.get_submodule(module_name) + + current_parameter_tensor = parameters.get(name, None) + + slice_ = f.get_slice(name) + + if isinstance(module, TensorParallelColumnLinear): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif isinstance(module, TensorParallelRowLinear): + if param_name == "weight": + size = slice_.get_shape()[1] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[:, start:stop] + else: + tensor = slice_[:] + # XXX: Hack for Rowlinear to add the bias only once. + if rank != 0: + tensor = torch.zeros_like(tensor) + elif isinstance(module, TensorParallelEmbedding): + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + elif name == "lm_head.weight" and model.transformer.tp_embeddings: + size = slice_.get_shape()[0] + block_size = size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + tensor = slice_[start:stop] + else: + try: + tensor = slice_[:] + except: + tensor = f.get_tensor(name) + + if ( + current_parameter_tensor is not None + and current_parameter_tensor.shape != tensor.shape + ): + raise ValueError( + f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" + ) + + tensor = tensor.contiguous().to(dtype) + + if current_parameter_tensor is not None: + module._parameters[param_name] = tensor + else: + module._buffers[param_name] = tensor + + model.post_load_weights(quantize) diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index 6500ac37..dd389027 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -8,11 +8,11 @@ from text_generation_server.models import CausalLM class RW(CausalLM): def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - trust_remote_code: bool = False, + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + trust_remote_code: bool = False, ): if torch.cuda.is_available(): device = torch.device("cuda") @@ -63,7 +63,7 @@ class RW(CausalLM): ) def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: # Model Forward if past_key_values is not None: @@ -71,10 +71,16 @@ class RW(CausalLM): for layer in past_key_values: past_keys, past_values = layer reshaped_past_key_values.append( - (past_keys.view(-1, 
*past_keys.shape[-2:]), past_values.view(-1, *past_values.shape[-2:])) + ( + past_keys.view(-1, *past_keys.shape[-2:]), + past_values.view(-1, *past_values.shape[-2:]), + ) ) past_key_values = reshaped_past_key_values - outputs = self.model.forward(input_ids=input_ids, attention_mask=attention_mask, - past_key_values=past_key_values) + outputs = self.model.forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + ) return outputs.logits, outputs.past_key_values
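
A shape-level walkthrough of the grouped key/value layout that FlashRWLargeAttention introduces above, for reference: the fused QKV projection is viewed as (tokens, num_groups, heads_per_group + 2, head_size), the per-group query heads are split from the single K and V head of each group, and K/V are then broadcast so every query head in a group sees its group's key and value. The sketch below reproduces only that reshape/split/broadcast with plain torch tensors; the Falcon-40B-style sizes are assumptions for illustration, rotary embeddings and the flash-attention kernel call are omitted, and the broadcast is written group-major (matching the query reshape) rather than copying the patch's exact expand call.

# Shape-only sketch of the grouped key/value layout used by FlashRWLargeAttention.
# Sizes are illustrative (Falcon-40B-like); rotary embeddings and the
# flash-attention kernel call are intentionally left out.
import torch

num_heads = 128          # total query heads
num_heads_kv = 8         # shared key/value heads
head_size = 64
num_tokens = 5

# Same grouping arithmetic as the patch: every group owns
# heads_per_group query heads plus one K head and one V head.
num_groups = num_heads // (num_heads_kv * 2)     # 8
heads_per_group = num_heads // num_groups        # 16

qkv = torch.randn(num_tokens, num_groups * (heads_per_group + 2) * head_size)
qkv = qkv.view(num_tokens, num_groups, heads_per_group + 2, head_size)

# Split the per-group query heads from the single K and V heads.
query, kv = qkv.split([heads_per_group, 2], dim=2)
query = query.reshape(num_tokens, num_groups * heads_per_group, head_size)
kv = kv.transpose(1, 2)  # (tokens, 2, num_groups, head_size), the cache layout

# Broadcast each group's K/V head across that group's query heads, keeping the
# same group-major ordering as the query reshape above.
k = kv[:, 0].unsqueeze(2).expand(num_tokens, num_groups, heads_per_group, head_size)
v = kv[:, 1].unsqueeze(2).expand(num_tokens, num_groups, heads_per_group, head_size)
k = k.reshape(num_tokens, num_groups * heads_per_group, head_size)
v = v.reshape(num_tokens, num_groups * heads_per_group, head_size)

assert query.shape == k.shape == v.shape == (num_tokens, num_heads, head_size)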
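
The other user-visible change in flash_rw_modeling.py is that the pre-allocated past tensor is sized with self.kv_size instead of num_heads_kv, because the two variants store a different number of K/V vectors per token: num_heads_kv for the multi-query RefinedWebModel layers and num_groups for the RefinedWeb large layers. The sketch below shows that relationship with illustrative sizes; it is a single-GPU view (the sharded path divides heads and groups further), and the pre-allocation budget and scatter indices are assumed values, not ones taken from the repository.

# How the per-token KV width of the pre-allocated cache follows the model
# variant in this patch (single-GPU view; illustrative sizes).
import torch


def kv_cache_width(model_type: str, n_head: int, n_head_kv: int) -> int:
    if model_type == "RefinedWebModel":      # e.g. Falcon-7B, multi-query layers
        return n_head_kv                     # kv_size = num_heads_kv
    if model_type == "RefinedWeb":           # e.g. Falcon-40B, grouped-KV layers
        return n_head // (n_head_kv * 2)     # kv_size = num_groups, as in the patch
    raise NotImplementedError(f"model_type {model_type} is not supported.")


head_size = 64
kv_size = kv_cache_width("RefinedWeb", n_head=128, n_head_kv=8)        # 8
layer_past = torch.zeros(32, 2, kv_size, head_size)                    # 32 = assumed pre-allocation budget

# Prefill writes the whole prompt's K/V at once...
layer_past[:5] = torch.randn(5, 2, kv_size, head_size)
# ...and each decode step scatters one token at layer_past_present_indices.
layer_past[torch.tensor([5])] = torch.randn(1, 2, kv_size, head_size)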
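
FlashRWSharded.load_weights re-enables the sharded loader that was previously commented out: safetensors slices are cut along dim 0 for column-parallel modules, embeddings, and the tensor-parallel lm_head, along dim 1 for row-parallel weights, and a row-parallel bias is kept only on rank 0 so the later all-reduce adds it exactly once. The helper below condenses those rules; plain tensors stand in for safetensors slices, and the explicit "kind" argument is an assumption for illustration, since the real loader infers it from the module class.

# Condensed version of the per-tensor sharding rules in FlashRWSharded.load_weights.
import torch


def shard_tensor(
    name: str, tensor: torch.Tensor, kind: str, rank: int, world_size: int
) -> torch.Tensor:
    if kind in ("column_parallel", "embedding", "tp_lm_head"):
        # Output features (or vocabulary rows) are split across ranks along dim 0.
        block = tensor.shape[0] // world_size
        return tensor[rank * block : (rank + 1) * block]
    if kind == "row_parallel":
        if name.endswith(".weight"):
            # Input features are split across ranks along dim 1.
            block = tensor.shape[1] // world_size
            return tensor[:, rank * block : (rank + 1) * block]
        # XXX: the bias is applied after the all-reduce, so only rank 0 keeps it.
        return tensor if rank == 0 else torch.zeros_like(tensor)
    return tensor  # replicated parameters (layer norms, etc.)


# Example: a row-parallel dense weight split across two ranks.
w = torch.randn(16, 8)
assert shard_tensor("h.0.self_attention.dense.weight", w, "row_parallel", 1, 2).shape == (16, 4)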