From f4eec092c6ace89f738de7759072d7098332ae8e Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 25 Jan 2024 19:34:14 +0100
Subject: [PATCH] fix

---
 .../flash_golden_gate_modeling.py      | 20 +++++++++++++-------
 .../models/custom_modeling/temp_tok.py | 10 +---------
 2 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py b/server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py
index 4d80f951..ca5b5952 100644
--- a/server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py
@@ -47,6 +47,7 @@ class GoldenGateConfig(PretrainedConfig):
         num_hidden_layers=28,
         num_attention_heads=16,
         num_key_value_heads=16,
+        head_dim=256,
         hidden_act="gelu",
         max_position_embeddings=8192,
         initializer_range=0.02,
@@ -65,6 +66,7 @@ class GoldenGateConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
+        self.head_dim = head_dim
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
@@ -91,6 +93,12 @@ class GoldenGateConfig(PretrainedConfig):
             **kwargs,
         )
 
+class GoldenGateFastRMSNorm(FastRMSNorm):
+    @classmethod
+    def load(cls, prefix, weights, eps=1e-6):
+        weight = weights.get_tensor(f"{prefix}.weight") + 1
+        return cls(weight, eps)
+
 
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
@@ -106,7 +114,6 @@ def load_attention(config, prefix, weights):
 
 
 def _load_gqa(config, prefix: str, weights):
-    assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
     weight = weights.get_multi_weights_col(
@@ -118,7 +125,7 @@ def _load_gqa(config, prefix: str, weights):
     if config.quantize not in ["gptq", "awq"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
-        head_size = config.hidden_size // config.num_attention_heads
+        head_size = config.head_dim
         num_heads = config.num_attention_heads // weights.process_group.size()
         num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
         assert list(weight.shape) == [
@@ -140,8 +147,7 @@ class FlashGoldenGateAttention(torch.nn.Module):
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
-        self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        self.head_size = config.head_dim
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -283,10 +289,10 @@ class FlashGoldenGateLayer(nn.Module):
         )
         self.mlp = GoldenGateMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
-        self.input_layernorm = FastRMSNorm.load(
+        self.input_layernorm = GoldenGateFastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
-        self.post_attention_layernorm = FastRMSNorm.load(
+        self.post_attention_layernorm = GoldenGateFastRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
@@ -353,7 +359,7 @@ class FlashGoldenGateModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
+        self.norm = GoldenGateFastRMSNorm.load(
             prefix="model.norm", weights=weights, eps=config.rms_norm_eps
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/temp_tok.py b/server/text_generation_server/models/custom_modeling/temp_tok.py
index 3c4fa8d9..06516cbc 100644
--- a/server/text_generation_server/models/custom_modeling/temp_tok.py
+++ b/server/text_generation_server/models/custom_modeling/temp_tok.py
@@ -45,7 +45,6 @@ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
 answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
  that your responses are socially unbiased and positive in nature.
-
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
 correct. If you don't know the answer to a question, please don't share false information."""
 # fmt: on
@@ -54,26 +53,19 @@ correct. If you don't know the answer to a question, please don't share false in
 class GoldenGateTokenizerFast(PreTrainedTokenizerFast):
     """
     Construct a GoldenGate tokenizer. Based on byte-level Byte-Pair-Encoding.
-
     This uses notably ByteFallback and no normalization.
-
     ```python
     >>> from transformers import GoldenGateTokenizerFast
-
     >>> tokenizer = GoldenGateTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
     >>> tokenizer.encode("Hello this is a test")
     [1, 15043, 445, 338, 263, 1243]
     ```
-
     If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
     call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
     values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
     [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
-
-
     This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
     refer to this superclass for more information regarding those methods.
-
     Args:
         vocab_file (`str`, *optional*):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
@@ -221,4 +213,4 @@ class GoldenGateTokenizerFast(PreTrainedTokenizerFast):
         if token_ids_1 is not None:
             output = output + bos_token_id + token_ids_1 + eos_token_id
 
-        return output
+        return output
\ No newline at end of file
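
For reference, a minimal sketch of the normalization the new loader implies; it is not part of the patch and all names in it are illustrative. GoldenGateFastRMSNorm.load adds 1 to the checkpoint tensor, which assumes the RMSNorm scale is stored as an offset from 1, so the effective scale is (1 + w) while the inherited FastRMSNorm forward is reused unchanged. The head_dim change similarly decouples the attention head size from hidden_size // num_attention_heads.

```python
# Illustrative sketch only, not part of the patch. It assumes the checkpoint
# stores each RMSNorm scale as an offset from 1, which is why load() does
# `weights.get_tensor(f"{prefix}.weight") + 1` before reusing FastRMSNorm.
import torch


def rms_norm_with_offset_scale(
    hidden_states: torch.Tensor, stored_weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Normalize in float32 for numerical stability, then cast back.
    hidden_f32 = hidden_states.to(torch.float32)
    variance = hidden_f32.pow(2).mean(-1, keepdim=True)
    normed = hidden_f32 * torch.rsqrt(variance + eps)
    # Effective scale is (1 + stored_weight); the patch bakes the +1 into the
    # weight at load time so the existing forward pass needs no change.
    return ((1.0 + stored_weight.to(torch.float32)) * normed).to(hidden_states.dtype)
```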