OlivierDehaene 2024-01-25 19:34:14 +01:00
parent 885591acdb
commit f4eec092c6
2 changed files with 14 additions and 16 deletions

View File

@@ -47,6 +47,7 @@ class GoldenGateConfig(PretrainedConfig):
         num_hidden_layers=28,
         num_attention_heads=16,
         num_key_value_heads=16,
+        head_dim=256,
         hidden_act="gelu",
         max_position_embeddings=8192,
         initializer_range=0.02,
@@ -65,6 +66,7 @@ class GoldenGateConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
+        self.head_dim = head_dim
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
@@ -91,6 +93,12 @@ class GoldenGateConfig(PretrainedConfig):
             **kwargs,
         )
 
+
+class GoldenGateFastRMSNorm(FastRMSNorm):
+    @classmethod
+    def load(cls, prefix, weights, eps=1e-6):
+        weight = weights.get_tensor(f"{prefix}.weight") + 1
+        return cls(weight, eps)
 
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
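The `+ 1` in `GoldenGateFastRMSNorm.load` above follows the Gemma-style checkpoint convention (an assumption here) of storing RMSNorm weights as an offset from 1, so the effective scale is `1 + stored_weight`; folding the offset in at load time lets the inherited `FastRMSNorm` forward run unchanged. A minimal reference sketch of that equivalence in plain PyTorch, not the fused kernel:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm: normalize by the root mean square, then scale by `weight`.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

stored = torch.zeros(8)    # hypothetical checkpoint tensor, i.e. gamma - 1
x = torch.randn(2, 8)
# Applying `stored + 1`, as load() does, recovers the intended gamma scaling;
# an all-zero stored weight then acts as the identity scale of 1.
out = rms_norm(x, stored + 1)
```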
@@ -106,7 +114,6 @@ def load_attention(config, prefix, weights):
 
 
 def _load_gqa(config, prefix: str, weights):
-    assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
     weight = weights.get_multi_weights_col(
@@ -118,7 +125,7 @@ def _load_gqa(config, prefix: str, weights):
     if config.quantize not in ["gptq", "awq"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
-    head_size = config.hidden_size // config.num_attention_heads
+    head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
     assert list(weight.shape) == [
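The `head_size` hunk above is needed because an explicit `head_dim` in the config means the per-head width no longer has to equal `hidden_size // num_attention_heads`. A quick arithmetic sketch; 16 heads and `head_dim=256` match the config defaults in the first hunk, while `hidden_size=3072` is an assumed illustrative value:

```python
hidden_size = 3072          # assumed for illustration only
num_attention_heads = 16    # default from the config hunk above
head_dim = 256              # default from the config hunk above

derived = hidden_size // num_attention_heads
print(derived, head_dim)    # 192 256 -- deriving head_size would mis-shape q/k/v
assert derived != head_dim
```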
@@ -140,8 +147,7 @@ class FlashGoldenGateAttention(torch.nn.Module):
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
-        self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        self.head_size = config.head_dim
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -283,10 +289,10 @@ class FlashGoldenGateLayer(nn.Module):
         )
         self.mlp = GoldenGateMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
-        self.input_layernorm = FastRMSNorm.load(
+        self.input_layernorm = GoldenGateFastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
-        self.post_attention_layernorm = FastRMSNorm.load(
+        self.post_attention_layernorm = GoldenGateFastRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
@@ -353,7 +359,7 @@ class FlashGoldenGateModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
+        self.norm = GoldenGateFastRMSNorm.load(
             prefix="model.norm", weights=weights, eps=config.rms_norm_eps
         )

View File

@@ -45,7 +45,6 @@ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
 answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
 that your responses are socially unbiased and positive in nature.
-
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
 correct. If you don't know the answer to a question, please don't share false information."""
 # fmt: on
@@ -54,26 +53,19 @@ correct. If you don't know the answer to a question, please don't share false in
 class GoldenGateTokenizerFast(PreTrainedTokenizerFast):
     """
     Construct a GoldenGate tokenizer. Based on byte-level Byte-Pair-Encoding.
 
     This uses notably ByteFallback and no normalization.
 
-    ```python
-    >>> from transformers import GoldenGateTokenizerFast
-
-    >>> tokenizer = GoldenGateTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
-    >>> tokenizer.encode("Hello this is a test")
-    [1, 15043, 445, 338, 263, 1243]
-    ```
 
     If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
     call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
     values of the first token and final token of an encoded sequence will not be correct). For more details, check out
     the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
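A minimal sketch of what that paragraph prescribes, reusing the import and test checkpoint from the example removed above; this assumes `GoldenGateTokenizerFast` mirrors the usual `PreTrainedTokenizerFast` behavior, and the override values are only illustrative:

```python
from transformers import GoldenGateTokenizerFast

# Preferred: pass the special tokens at init so the post-processor template
# is built with the right ids from the start.
tokenizer = GoldenGateTokenizerFast.from_pretrained(
    "hf-internal-testing/llama-tokenizer", bos_token="<s>", eos_token="</s>"
)

# If a token is changed after the fact, refresh the template by hand so
# encoded sequences still start and end with the correct ids.
tokenizer.eos_token = "</s>"
tokenizer.update_post_processor()
```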
 
     This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
     refer to this superclass for more information regarding those methods.
 
     Args:
         vocab_file (`str`, *optional*):
             [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that