diff --git a/Dockerfile b/Dockerfile
index 056f2f2b..42d01479 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -138,7 +138,7 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.lin
 # Copy build artifacts from transformers builder
 COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-39/custom_kernels /usr/src/custom-kernels/src/custom_kernels

-# Install transformers dependencies
+# Install flash-attention dependencies
 RUN pip install einops --no-cache-dir

 # Install server
diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 1a02cc91..8f59d75a 100644
--- a/integration-tests/conftest.py
+++ b/integration-tests/conftest.py
@@ -249,7 +249,6 @@ def launcher(event_loop):
         ) as process:
             yield ProcessLauncherHandle(process, port)

-            process.terminate()
             process.wait(60)

@@ -261,6 +260,7 @@ def launcher(event_loop):
         if not use_flash_attention:
             del env["USE_FLASH_ATTENTION"]

+    @contextlib.contextmanager
     def docker_launcher(
         model_id: str,
diff --git a/integration-tests/models/test_neox.py b/integration-tests/models/test_neox.py
index 8d949ddb..7b88f86a 100644
--- a/integration-tests/models/test_neox.py
+++ b/integration-tests/models/test_neox.py
@@ -3,7 +3,9 @@ import pytest

 @pytest.fixture(scope="module")
 def neox_handle(launcher):
-    with launcher("stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False) as handle:
+    with launcher(
+        "stabilityai/stablelm-tuned-alpha-3b", num_shard=1, use_flash_attention=False
+    ) as handle:
         yield handle

diff --git a/integration-tests/models/test_neox_sharded.py b/integration-tests/models/test_neox_sharded.py
index fd691a1a..8cee8765 100644
--- a/integration-tests/models/test_neox_sharded.py
+++ b/integration-tests/models/test_neox_sharded.py
@@ -3,7 +3,9 @@ import pytest

 @pytest.fixture(scope="module")
 def neox_sharded_handle(launcher):
-    with launcher("OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False) as handle:
+    with launcher(
+        "OpenAssistant/oasst-sft-1-pythia-12b", num_shard=2, use_flash_attention=False
+    ) as handle:
         yield handle

diff --git a/server/custom_kernels/setup.py b/server/custom_kernels/setup.py
index fe45b631..43b8ee4e 100644
--- a/server/custom_kernels/setup.py
+++ b/server/custom_kernels/setup.py
@@ -13,7 +13,7 @@ setup(
             name="custom_kernels.fused_attention_cuda",
             sources=["custom_kernels/fused_attention_cuda.cu"],
             extra_compile_args=["-arch=compute_80", "-std=c++17"],
-        )
+        ),
     ],
     cmdclass={"build_ext": BuildExtension},
 )
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 19b0ce63..f1b84a53 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -19,7 +19,10 @@ from text_generation_server.models.t5 import T5Sharded
 from text_generation_server.models.gpt_neox import GPTNeoxSharded

 try:
-    if torch.cuda.is_available() and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
+    if (
+        torch.cuda.is_available()
+        and not os.getenv("USE_FLASH_ATTENTION", "").lower() == "false"
+    ):
         major, minor = torch.cuda.get_device_capability()
         is_sm75 = major == 7 and minor == 5
         is_sm8x = major == 8 and minor >= 0
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 9b3353e9..8a35ffa8 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -46,7 +46,6 @@ class LlamaRMSNorm(nn.Module):
         super().__init__()

         weight = weights.get_tensor(f"{prefix}.weight")
-        # assert weight.shape == (hidden_size,)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps

@@ -103,7 +102,9 @@ class FlashLlamaAttention(torch.nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+        self.rotary_emb = PositionRotaryEmbedding.load(
+            prefix=f"{prefix}.rotary_emb", weights=weights
+        )

         self.softmax_scale = self.head_size ** (-0.5)

diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index d30095ef..0fe43bcb 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -90,10 +90,9 @@ class FlashNeoxAttention(torch.nn.Module):
         self.head_size = hidden_size // num_heads

         self.num_heads = self.num_heads // weights.process_group.size()
-        rotary_pct = config.rotary_pct
-
-        rotary_ndims = int(self.head_size * rotary_pct)
-        self.rotary_emb = PositionRotaryEmbedding.load(prefix=f"{prefix}.rotary_emb", weights=weights)
+        self.rotary_emb = PositionRotaryEmbedding.load(
+            prefix=f"{prefix}.rotary_emb", weights=weights
+        )

         self.softmax_scale = self.head_size ** (-0.5)

diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 443c636b..55195162 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -1,5 +1,3 @@
-import os
-
 import torch
 import torch.distributed

@@ -104,7 +102,6 @@ class FlashRWAttention(torch.nn.Module):
         config,
         prefix,
         weights,
-        reduce=True,
     ):
         super().__init__()
         self.num_heads = config.n_head
@@ -395,7 +392,6 @@ class FlashRWLayer(nn.Module):
             config,
             prefix=f"{prefix}.self_attention",
             weights=weights,
-            reduce=False,
         )
         self.post_attention_layernorm = (
             FastLayerNorm.load(
@@ -548,18 +544,7 @@ class FlashRWModel(FlashRWPreTrainedModel):
         if config.model_type == "RefinedWebModel":
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(
-                        layer_id,
-                        config,
-                        weights
-                        # config.n_head,
-                        # config.n_head_kv,
-                        # config.hidden_size,
-                        # config.bias,
-                        # config.layer_norm_epsilon,
-                        # config.parallel_attn,
-                        # process_group,
-                    )
+                    FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py
index 79fa1915..bf2656d1 100644
--- a/server/text_generation_server/models/custom_modeling/neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py
@@ -48,7 +48,6 @@ from text_generation_server.utils.layers import (
 )


-
 CUSTOM_KERNELS_ENABLED = False
 if not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
     try:
@@ -62,7 +61,6 @@ if not CUSTOM_KERNELS_ENABLED:
     logger.warning("We're not using custom kernels.")


-
 def make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 ) -> torch.BoolTensor:
@@ -70,10 +68,16 @@ def make_causal_mask(
     Make causal mask used for self-attention.
     """
     batch_size, target_length = input_ids_shape
-    mask = torch.ones((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
+    mask = torch.ones(
+        (target_length, target_length + past_key_values_length),
+        dtype=torch.bool,
+        device=device,
+    )
     mask = mask.triu(1 + past_key_values_length)
-    expanded_mask = mask.unsqueeze(0).expand(batch_size, target_length, target_length + past_key_values_length)
+    expanded_mask = mask.unsqueeze(0).expand(
+        batch_size, target_length, target_length + past_key_values_length
+    )
     return expanded_mask


@@ -89,7 +93,9 @@ def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:


 def prepare_attn_mask(
-    attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+    attention_mask: torch.Tensor,
+    input_shape: Tuple[int, int],
+    past_key_values_length: int,
 ) -> torch.BoolTensor:
     # create causal mask
     # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
@@ -105,7 +111,9 @@ def prepare_attn_mask(
     # [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
     expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
     combined_attention_mask = (
-        expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
+        expanded_attn_mask
+        if combined_attention_mask is None
+        else expanded_attn_mask | combined_attention_mask
     )

     return combined_attention_mask
@@ -118,7 +126,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
     """


-
 class GPTNeoXAttention(nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -136,17 +143,21 @@ class GPTNeoXAttention(nn.Module):
         # )
         # self.register_buffer("masked_bias", torch.tensor(-1e9))
         self.rotary_emb = RotaryEmbedding(
-            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
+            self.rotary_ndims,
+            config.max_position_embeddings,
+            base=config.rotary_emb_base,
         )
         self.rotary_emb.inv_freq = nn.Parameter(
             weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
         )
-        self.inv_norm_factor = 1.0 / torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(
-            torch.get_default_dtype()
-        )
+        self.inv_norm_factor = 1.0 / torch.sqrt(
+            torch.tensor(self.head_size, dtype=torch.float32)
+        ).to(torch.get_default_dtype())

         assert self.num_attention_heads % weights.process_group.size() == 0
-        self.num_attention_heads = self.num_attention_heads // weights.process_group.size()
+        self.num_attention_heads = (
+            self.num_attention_heads // weights.process_group.size()
+        )
         self.query_key_value = TensorParallelColumnLinear.load(
             config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
         )
@@ -214,10 +225,14 @@ class GPTNeoXAttention(nn.Module):
         present = (key, value) if use_cache else None

         # Compute attention
-        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+        attn_output, attn_weights = self._attn(
+            query, key, value, attention_mask, head_mask
+        )

         # Reshape outputs
-        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
+        attn_output = self._merge_heads(
+            attn_output, self.num_attention_heads, self.head_size
+        )

         attn_output = self.dense(attn_output)

@@ -248,7 +263,9 @@ class GPTNeoXAttention(nn.Module):
         # tensor [bs, num_attention_heads, seq_len, attn_head_size]
         tensor = tensor.permute(0, 2, 1, 3).contiguous()
         # -> [bs, seq_len, num_attention_heads, attn_head_size]
-        tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
+        tensor = tensor.view(
+            tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
+        )
         # -> [bs, seq_len, hidden_size]
         return tensor

@@ -258,7 +275,9 @@ class GPTNeoXAttention(nn.Module):
         batch_size, num_attention_heads, query_length, attn_head_size = query.size()
         key_length = key.size(-2)

-        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
+        query = query.view(
+            batch_size * num_attention_heads, query_length, attn_head_size
+        )
         key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
         attn_scores = torch.zeros(
             1,
@@ -277,8 +296,12 @@ class GPTNeoXAttention(nn.Module):
         input_dtype = attn_scores.dtype
         if input_dtype in [torch.float16, torch.bfloat16]:
             attn_scores = attn_scores.to(torch.float)
-        attn_scores = torch.where(attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores)
-        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+        attn_scores = torch.where(
+            attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
+        )
+        attn_scores = attn_scores.view(
+            batch_size, num_attention_heads, query_length, key_length
+        )

         attn_weights = nn.functional.softmax(attn_scores, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
@@ -294,7 +317,9 @@ class GPTNeoXAttention(nn.Module):
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings, base=10000, device=None):
         super().__init__()
-        self.true_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.true_inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2).float().to(device) / dim)
+        )
         self.register_buffer("inv_freq", self.true_inv_freq)

         # Build here to make `torch.jit.trace` work.
@@ -311,7 +336,9 @@ class RotaryEmbedding(torch.nn.Module):

     @staticmethod
     def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
-        t = torch.arange(max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype)
+        t = torch.arange(
+            max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
+        )
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
@@ -319,7 +346,11 @@ class RotaryEmbedding(torch.nn.Module):

     def forward(self, q, k, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached or self.cos_cached is None or self.sin_cached is None:
+        if (
+            seq_len > self.max_seq_len_cached
+            or self.cos_cached is None
+            or self.sin_cached is None
+        ):
             if seq_len > self.max_seq_len_cached:
                 self.max_seq_len_cached = seq_len
                 self.cos_cached, self.sin_cached = self._create_cos_sin(
@@ -371,11 +402,22 @@ class GPTNeoXLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
-        self.input_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm.load(prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps)
-        self.attention = GPTNeoXAttention(config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights)
-        self.mlp = GPTNeoXMLP(config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights)
-
+        self.input_layernorm = nn.LayerNorm.load(
+            prefix=f"gpt_neox.layers.{layer_id}.input_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
+        self.post_attention_layernorm = nn.LayerNorm.load(
+            prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )
+        self.attention = GPTNeoXAttention(
+            config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights
+        )
+        self.mlp = GPTNeoXMLP(
+            config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights
+        )

     def forward(
         self,
@@ -396,7 +438,9 @@ class GPTNeoXLayer(nn.Module):
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
+        attn_output = attention_layer_outputs[
+            0
+        ]  # output_attn: attn_output, present, (attn_weights)
         outputs = attention_layer_outputs[1:]

         if self.use_parallel_residual:
@@ -413,7 +457,9 @@ class GPTNeoXLayer(nn.Module):
             hidden_states = mlp_output + attn_output

         if use_cache:
-            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
+            outputs = (
+                hidden_states,
+            ) + outputs  # hidden_states, present, (attn_weights)
         else:
             outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)

@@ -427,12 +473,22 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
         self.num_attention_heads = config.num_attention_heads

-        self.embed_in = TensorParallelEmbedding(prefix="gpt_neox.embed_in", weights=weights)
-        self.layers = nn.ModuleList([GPTNeoXLayer(layer_id, config, weights) for layer_id in range(config.num_hidden_layers)])
-        self.final_layer_norm = nn.LayerNorm.load(prefix="gpt_neox.final_layer_norm", weights=weights, eps=config.layer_norm_eps)
+        self.embed_in = TensorParallelEmbedding(
+            prefix="gpt_neox.embed_in", weights=weights
+        )
+        self.layers = nn.ModuleList(
+            [
+                GPTNeoXLayer(layer_id, config, weights)
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.final_layer_norm = nn.LayerNorm.load(
+            prefix="gpt_neox.final_layer_norm",
+            weights=weights,
+            eps=config.layer_norm_eps,
+        )

         self.tp_world_size = weights.process_group.size()

-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -456,15 +512,25 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
             `past_key_values`).
         """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache

         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
         elif input_ids is not None:
             input_shape = input_ids.size()
         elif inputs_embeds is not None:
@@ -482,7 +548,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(
+                past_length, seq_length + past_length, dtype=torch.long, device=device
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
             position_ids = position_ids.view(-1, seq_length).long()
@@ -499,7 +567,9 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             past_key_values_length = past_key_values[0][0].shape[-1]
             seq_length_with_past = seq_length_with_past + past_key_values_length
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), device=hidden_states.device
+            )
         else:
             attention_mask = attention_mask.to(hidden_states.device)

@@ -548,7 +618,11 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
             all_hidden_states = all_hidden_states + (hidden_states,)

         if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_attentions]
+                if v is not None
+            )

         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
@@ -564,7 +638,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
         self.gpt_neox = GPTNeoXModel(config, weights)
-        self.embed_out = TensorParallelHead.load(config, prefix="embed_out", weights=weights)
+        self.embed_out = TensorParallelHead.load(
+            config, prefix="embed_out", weights=weights
+        )

     def forward(
         self,
@@ -619,7 +695,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
         >>> prediction_logits = outputs.logits
         ```"""
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

         outputs = self.gpt_neox(
             input_ids,
@@ -645,7 +723,9 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
             shift_logits = lm_logits[:, :-1, :].contiguous()
             labels = labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss()
-            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
+            lm_loss = loss_fct(
+                shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+            )

         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -660,7 +740,12 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        **kwargs,
     ):
         input_shape = input_ids.shape

@@ -700,6 +785,10 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
         reordered_past = ()
         for layer_past in past_key_values:
             reordered_past += (
-                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
+                tuple(
+                    past_state.index_select(0, beam_idx)
+                    for past_state in layer_past[:2]
+                )
+                + layer_past[2:],
             )
         return reordered_past
diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py
index a4e6249b..51862e3c 100644
--- a/server/text_generation_server/models/custom_modeling/t5_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py
@@ -845,7 +845,6 @@ class T5Stack(T5PreTrainedModel):
             ), "You have to initialize the model with valid token embeddings"
             inputs_embeds = self.embed_tokens(input_ids)

-
         batch_size, seq_length = input_shape

         # required mask seq length can be calculated via length of past
@@ -1026,7 +1025,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             embed_tokens=self.shared,
         )

-        self.lm_head = TensorParallelHead.load(config, prefix="lm_head", weights=weights)
+        self.lm_head = TensorParallelHead.load(
+            config, prefix="lm_head", weights=weights
+        )

     def forward(
         self,
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 846b9051..5f963bfb 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -1,28 +1,19 @@
 import torch
 import torch.distributed

-from pathlib import Path
-from accelerate import init_empty_weights
 from opentelemetry import trace
-from safetensors import safe_open
-from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, List
+from transformers import AutoTokenizer
+from typing import Optional

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_rw_modeling import (
     RWConfig,
     FlashRWForCausalLM,
-    TensorParallelEmbedding,
-    TensorParallelRowLinear,
-    TensorParallelColumnLinear,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
-    download_weights,
-    weight_hub_files,
     Weights,
-    LocalEntryNotFoundError,
 )

 tracer = trace.get_tracer(__name__)
@@ -73,79 +64,3 @@ class FlashRWSharded(FlashCausalLM):
             rank=rank,
             world_size=world_size,
         )
-
-    # @staticmethod
-    # def load_weights(
-    #     model,
-    #     filenames: List[str],
-    #     quantize: Optional[str],
-    #     device: torch.device,
-    #     dtype: torch.dtype,
-    #     rank: int,
-    #     world_size: int,
-    # ):
-    #     parameters = dict(model.named_parameters())
-    #     for file in filenames:
-    #         with safe_open(
-    #             file, framework="pt", device=str(device) if quantize is None else "cpu"
-    #         ) as f:
-    #             for name in f.keys():
-    #                 module_name, param_name = name.rsplit(".", 1)
-    #                 module = model.get_submodule(module_name)
-
-    #                 current_parameter_tensor = parameters.get(name, None)
-
-    #                 slice_ = f.get_slice(name)
-
-    #                 if isinstance(module, TensorParallelColumnLinear):
-    #                     size = slice_.get_shape()[0]
-    #                     block_size = size // world_size
-    #                     start = rank * block_size
-    #                     stop = (rank + 1) * block_size
-    #                     tensor = slice_[start:stop]
-    #                 elif isinstance(module, TensorParallelRowLinear):
-    #                     if param_name == "weight":
-    #                         size = slice_.get_shape()[1]
-    #                         block_size = size // world_size
-    #                         start = rank * block_size
-    #                         stop = (rank + 1) * block_size
-    #                         tensor = slice_[:, start:stop]
-    #                     else:
-    #                         tensor = slice_[:]
-    #                         # XXX: Hack for Rowlinear to add the bias only once.
-    #                         if rank != 0:
-    #                             tensor = torch.zeros_like(tensor)
-    #                 elif isinstance(module, TensorParallelEmbedding):
-    #                     size = slice_.get_shape()[0]
-    #                     block_size = size // world_size
-    #                     start = rank * block_size
-    #                     stop = (rank + 1) * block_size
-    #                     tensor = slice_[start:stop]
-    #                 elif name == "lm_head.weight" and model.transformer.tp_embeddings:
-    #                     size = slice_.get_shape()[0]
-    #                     block_size = size // world_size
-    #                     start = rank * block_size
-    #                     stop = (rank + 1) * block_size
-    #                     tensor = slice_[start:stop]
-    #                 else:
-    #                     try:
-    #                         tensor = slice_[:]
-    #                     except:
-    #                         tensor = f.get_tensor(name)

-    #                 if (
-    #                     current_parameter_tensor is not None
-    #                     and current_parameter_tensor.shape != tensor.shape
-    #                 ):
-    #                     raise ValueError(
-    #                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
-    #                     )

-    #                 tensor = tensor.contiguous().to(dtype)

-    #                 if current_parameter_tensor is not None:
-    #                     module._parameters[param_name] = tensor
-    #                 else:
-    #                     module._buffers[param_name] = tensor

-    #     model.post_load_weights(quantize)
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index a907ee6c..01e1c773 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -182,6 +182,7 @@ class GalacticaSharded(CausalLM):
             tp_parallel=True,
             trust_remote_code=trust_remote_code,
         )
+        config.quantize = quantize
         tokenizer.pad_token_id = config.pad_token_id

         torch.distributed.barrier(group=self.process_group)
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 5c854348..0abf0239 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -1,13 +1,10 @@
 import torch
 import torch.distributed

-from typing import List, Optional
+from typing import Optional

-from accelerate import init_empty_weights
-from safetensors import safe_open
 from transformers import (
     AutoTokenizer,
-    AutoModelForCausalLM,
     AutoConfig,
 )
 from text_generation_server.models import CausalLM
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index e844c36f..c89462fc 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -35,7 +35,9 @@ class T5Sharded(Seq2SeqLM):
             device = torch.device("cpu")
             dtype = torch.float32

-        config = AutoConfig.from_pretrained(model_id, revision=revision,
+        config = AutoConfig.from_pretrained(
+            model_id,
+            revision=revision,
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py
index 9443d21b..2ed7673c 100644
--- a/server/text_generation_server/utils/hub.py
+++ b/server/text_generation_server/utils/hub.py
@@ -10,8 +10,8 @@ from huggingface_hub import HfApi, hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from huggingface_hub.utils import (
     LocalEntryNotFoundError,
-    EntryNotFoundError,  # Import here to ease try/except in other part of the lib
-    RevisionNotFoundError
+    EntryNotFoundError,
+    RevisionNotFoundError,  # Import here to ease try/except in other part of the lib
 )

 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 5945f210..ee32a0dc 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -1,4 +1,5 @@
 import torch
+import torch.distributed

 from torch import nn
 from torch.nn import functional as F
@@ -44,14 +45,14 @@ class FastLinear(nn.Module):
         else:
             self.bias = None

-    @staticmethod
-    def load(config, prefix: str, weights, bias: bool):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_tensor(f"{prefix}.weight")
         if bias:
             bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
-        return FastLinear(weight, bias)
+        return cls(weight, bias)

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self.weight, self.bias)
@@ -130,9 +131,7 @@ def get_linear(weight, bias, quantize):
     elif quantize == "gptq":
         raise NotImplementedError("Soon")
     else:
-        raise NotImplementedError(
-            f"Quantization `{config.quantize}` is not implemented yet."
-        )
+        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear


@@ -170,17 +169,17 @@ class TensorParallelHead(SuperLayer):


 class TensorParallelColumnLinear(SuperLayer):
-    @staticmethod
-    def load(config, prefix: str, weights, bias: bool):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_sharded(f"{prefix}.weight", dim=0)
         if bias:
             bias = weights.get_sharded(f"{prefix}.bias", dim=0)
         else:
             bias = None
-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+        return cls(get_linear(weight, bias, config.quantize))

-    @staticmethod
-    def load_multi(config, prefixes: List[str], weights, bias: bool, dim: int):
+    @classmethod
+    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
         w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
         weight = torch.cat(w, dim=dim)

@@ -189,7 +188,7 @@ class TensorParallelColumnLinear(SuperLayer):
             bias = torch.cat(b, dim=0)
         else:
             bias = None
-        return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
+        return cls(get_linear(weight, bias, config.quantize))


 class TensorParallelRowLinear(SuperLayer):
@@ -197,15 +196,15 @@ class TensorParallelRowLinear(SuperLayer):
         super().__init__(linear)
         self.process_group = process_group

-    @staticmethod
-    def load(config, prefix: str, weights, bias: bool):
+    @classmethod
+    def load(cls, config, prefix: str, weights, bias: bool):
         weight = weights.get_sharded(f"{prefix}.weight", dim=1)
         if bias and weights.process_group.rank() == 0:
             # Rank is only on the first rank process
             bias = weights.get_tensor(f"{prefix}.bias")
         else:
             bias = None
-        return TensorParallelRowLinear(
+        return cls(
             get_linear(weight, bias, config.quantize),
             process_group=weights.process_group,
         )
@@ -308,22 +307,22 @@ try:
             self._cos_k_cached = None
             self._sin_k_cached = None

-        @staticmethod
-        def static(dim, base, device):
+        @classmethod
+        def static(cls, dim, base, device):
             inv_freq = 1.0 / (
                 base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
             )
-            return PositionRotaryEmbedding(inv_freq)
+            return cls(inv_freq)

-        @staticmethod
-        def load(prefix, weights):
+        @classmethod
+        def load(cls, prefix, weights):
             # XXX: Always load this in float32 !
             dtype = weights.dtype
             weights.dtype = torch.float32
             inv_freq = weights.get_tensor(f"{prefix}.inv_freq")
             weights.dtype = dtype
-            return PositionRotaryEmbedding(inv_freq)
+            return cls(inv_freq)

         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,