OlivierDehaene 2024-01-25 19:34:14 +01:00
parent 885591acdb
commit f4eec092c6
2 changed files with 14 additions and 16 deletions

View File

@@ -47,6 +47,7 @@ class GoldenGateConfig(PretrainedConfig):
         num_hidden_layers=28,
         num_attention_heads=16,
         num_key_value_heads=16,
+        head_dim=256,
         hidden_act="gelu",
         max_position_embeddings=8192,
         initializer_range=0.02,
@@ -65,6 +66,7 @@ class GoldenGateConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
+        self.head_dim = head_dim
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
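Editor's note on the config change above: `head_dim` is now an explicit field read from the checkpoint config rather than being derived from the hidden size, which matters for architectures where the per-head width is decoupled from `hidden_size`. The back-of-the-envelope sketch below only illustrates the point; `hidden_size=3072` is an assumed value, not taken from this diff.

```python
# Illustrative sketch: hidden_size=3072 is assumed, not from this commit.
hidden_size, num_attention_heads, head_dim = 3072, 16, 256

derived = hidden_size // num_attention_heads   # 192 -- what the old derivation would give
print(derived == head_dim)                     # False: head_dim has to come from the config
print(num_attention_heads * head_dim)          # 4096: attention width need not equal hidden_size
```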
@@ -91,6 +93,12 @@ class GoldenGateConfig(PretrainedConfig):
             **kwargs,
         )
 
+
+class GoldenGateFastRMSNorm(FastRMSNorm):
+    @classmethod
+    def load(cls, prefix, weights, eps=1e-6):
+        weight = weights.get_tensor(f"{prefix}.weight") + 1
+        return cls(weight, eps)
 
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
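For context on the `GoldenGateFastRMSNorm.load` override introduced above: adding 1 to the stored weight at load time folds a `(1 + weight)` scaling into the parameter itself, the usual handling for checkpoints that store the RMSNorm gain as an offset from 1. The snippet below is a minimal reference sketch assuming the norm scales the normalized activations by `weight` directly; it is not the repo's `FastRMSNorm` kernel.

```python
import torch

def rms_norm(x: torch.Tensor, stored_weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension.
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Checkpoints storing gamma as (gamma - 1) need the +1 restored; doing it once at
    # load time (as GoldenGateFastRMSNorm.load does) avoids an extra op on every forward.
    return x * (stored_weight + 1.0)
```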
@@ -106,7 +114,6 @@ def load_attention(config, prefix, weights):
 
 
 def _load_gqa(config, prefix: str, weights):
-    assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
     weight = weights.get_multi_weights_col(
@@ -118,7 +125,7 @@ def _load_gqa(config, prefix: str, weights):
     if config.quantize not in ["gptq", "awq"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
-        head_size = config.hidden_size // config.num_attention_heads
+        head_size = config.head_dim
         num_heads = config.num_attention_heads // weights.process_group.size()
         num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
         assert list(weight.shape) == [
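The shape check that the hunk above ends on can be read with a quick arithmetic sketch: the fused QKV weight per tensor-parallel shard has `(num_heads + 2 * num_key_value_heads) * head_size` output rows. The numbers below assume a single shard and the defaults shown earlier (`hidden_size=3072` is an assumption), so treat them as illustrative only.

```python
# Fused QKV shape per shard, assuming world_size=1 and the defaults above.
num_heads, num_key_value_heads, head_size, hidden_size = 16, 16, 256, 3072  # hidden_size assumed
rows = (num_heads + 2 * num_key_value_heads) * head_size  # 48 * 256 = 12288
expected_shape = [rows, hidden_size]
print(expected_shape)  # [12288, 3072]
```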
@@ -140,8 +147,7 @@ class FlashGoldenGateAttention(torch.nn.Module):
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
-        self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        self.head_size = config.head_dim
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -283,10 +289,10 @@ class FlashGoldenGateLayer(nn.Module):
         )
         self.mlp = GoldenGateMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
-        self.input_layernorm = FastRMSNorm.load(
+        self.input_layernorm = GoldenGateFastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
-        self.post_attention_layernorm = FastRMSNorm.load(
+        self.post_attention_layernorm = GoldenGateFastRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
@@ -353,7 +359,7 @@ class FlashGoldenGateModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
+        self.norm = GoldenGateFastRMSNorm.load(
             prefix="model.norm", weights=weights, eps=config.rms_norm_eps
         )

View File

@@ -45,7 +45,6 @@ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
 answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
  that your responses are socially unbiased and positive in nature.
-
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
 correct. If you don't know the answer to a question, please don't share false information."""
 # fmt: on
@@ -54,26 +53,19 @@ correct. If you don't know the answer to a question, please don't share false in
 class GoldenGateTokenizerFast(PreTrainedTokenizerFast):
     """
     Construct a GoldenGate tokenizer. Based on byte-level Byte-Pair-Encoding.
-
     This uses notably ByteFallback and no normalization.
-
     ```python
     >>> from transformers import GoldenGateTokenizerFast
-
     >>> tokenizer = GoldenGateTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
     >>> tokenizer.encode("Hello this is a test")
     [1, 15043, 445, 338, 263, 1243]
     ```
-
     If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
     call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
     values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
     [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
-
-
     This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
     refer to this superclass for more information regarding those methods.
-
     Args:
         vocab_file (`str`, *optional*):
             [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that