Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00

Commit f4eec092c6 ("fix")
Parent: 885591acdb
@@ -47,6 +47,7 @@ class GoldenGateConfig(PretrainedConfig):
         num_hidden_layers=28,
         num_attention_heads=16,
         num_key_value_heads=16,
+        head_dim=256,
         hidden_act="gelu",
         max_position_embeddings=8192,
         initializer_range=0.02,
@@ -65,6 +66,7 @@ class GoldenGateConfig(PretrainedConfig):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
         self.hidden_size = hidden_size
+        self.head_dim = head_dim
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
@@ -91,6 +93,12 @@ class GoldenGateConfig(PretrainedConfig):
             **kwargs,
         )
 
+class GoldenGateFastRMSNorm(FastRMSNorm):
+    @classmethod
+    def load(cls, prefix, weights, eps=1e-6):
+        weight = weights.get_tensor(f"{prefix}.weight") + 1
+        return cls(weight, eps)
+
 
 def load_attention(config, prefix, weights):
     if config.num_attention_heads != config.num_key_value_heads:
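The only functional detail in `GoldenGateFastRMSNorm` is the `+ 1` applied to the stored norm weight, which suggests the checkpoint keeps a zero-centered scale (effective scale `1 + w`). Folding the offset in once at load time lets the existing `FastRMSNorm` forward be reused unchanged, which is also why the layernorm constructions further down only swap the loader. A minimal sketch of that equivalence, assuming a standard RMSNorm forward (the actual `FastRMSNorm` internals are not shown in this diff):

```python
import torch

def rms_norm(x, weight, eps=1e-6):
    # Generic RMSNorm forward: normalize by the root-mean-square, then scale by `weight`.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

# Stand-ins: `w` plays the role of weights.get_tensor(f"{prefix}.weight").
x = torch.randn(2, 8)
w = torch.randn(8) * 0.02

# Folding the offset at load time (w + 1) reproduces the zero-centered convention
# normalized(x) * (1 + w) while keeping the generic forward above unchanged.
folded = rms_norm(x, w + 1)
explicit = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * (1 + w)
assert torch.allclose(folded, explicit)
```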
@@ -106,7 +114,6 @@ def load_attention(config, prefix, weights):
 
 
 def _load_gqa(config, prefix: str, weights):
-    assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
 
     weight = weights.get_multi_weights_col(
@@ -118,7 +125,7 @@ def _load_gqa(config, prefix: str, weights):
     if config.quantize not in ["gptq", "awq"]:
         weight = weight.to(dtype=weights.dtype).to(device=weights.device)
 
-    head_size = config.hidden_size // config.num_attention_heads
+    head_size = config.head_dim
     num_heads = config.num_attention_heads // weights.process_group.size()
     num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
     assert list(weight.shape) == [
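With `head_dim` now read from the config, the per-head size is no longer assumed to equal `hidden_size // num_attention_heads`, which keeps the projection shapes consistent with checkpoints where `num_attention_heads * head_dim` differs from `hidden_size`. A small sketch of the difference; `hidden_size=3072` is an assumed value for illustration, and only `head_dim=256`, `num_attention_heads=16`, and `num_key_value_heads=16` come from the defaults added above:

```python
# Assumed/illustrative values: hidden_size is a placeholder, the rest are the
# defaults added in this diff.
hidden_size = 3072
num_attention_heads = 16
num_key_value_heads = 16
head_dim = 256

derived_head_size = hidden_size // num_attention_heads  # 192 -- the old computation
print(derived_head_size, head_dim)                       # 192 256: the two can disagree

# Assuming the usual fused-QKV layout checked by the assert above, the expected
# row count of the sharded weight becomes (world size 1):
qkv_rows = (num_attention_heads + 2 * num_key_value_heads) * head_dim
print(qkv_rows)  # 12288
```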
@@ -140,8 +147,7 @@ class FlashGoldenGateAttention(torch.nn.Module):
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
-        self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        self.head_size = config.head_dim
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
@@ -283,10 +289,10 @@ class FlashGoldenGateLayer(nn.Module):
         )
         self.mlp = GoldenGateMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
-        self.input_layernorm = FastRMSNorm.load(
+        self.input_layernorm = GoldenGateFastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
-        self.post_attention_layernorm = FastRMSNorm.load(
+        self.post_attention_layernorm = GoldenGateFastRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
@@ -353,7 +359,7 @@ class FlashGoldenGateModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastRMSNorm.load(
+        self.norm = GoldenGateFastRMSNorm.load(
             prefix="model.norm", weights=weights, eps=config.rms_norm_eps
         )
 
@@ -45,7 +45,6 @@ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
 answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
 that your responses are socially unbiased and positive in nature.
 
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
 correct. If you don't know the answer to a question, please don't share false information."""
 # fmt: on
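For context on how the constants in this hunk are typically consumed, here is a hedged sketch of Llama-2-chat-style prompt assembly. `B_INST`/`E_INST` and the exact composition are assumptions; only `B_SYS`, `E_SYS`, and `DEFAULT_SYSTEM_PROMPT` appear above, and the system prompt is shortened to a stand-in:

```python
# Illustration only: B_INST / E_INST and the composition below are assumptions;
# only B_SYS, E_SYS and DEFAULT_SYSTEM_PROMPT are visible in the hunk above.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."  # shortened stand-in

def build_prompt(user_message: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    # Llama-2-chat convention: the system prompt sits inside <<SYS>> markers
    # within the first [INST] ... [/INST] block.
    return f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{user_message} {E_INST}"

print(build_prompt("Hello this is a test"))
```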
@@ -54,26 +53,19 @@ correct. If you don't know the answer to a question, please don't share false in
 class GoldenGateTokenizerFast(PreTrainedTokenizerFast):
     """
     Construct a GoldenGate tokenizer. Based on byte-level Byte-Pair-Encoding.
 
     This uses notably ByteFallback and no normalization.
 
     ```python
     >>> from transformers import GoldenGateTokenizerFast
 
     >>> tokenizer = GoldenGateTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
     >>> tokenizer.encode("Hello this is a test")
     [1, 15043, 445, 338, 263, 1243]
     ```
 
     If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
     call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
     values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
     [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
 
 
     This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
     refer to this superclass for more information regarding those methods.
 
     Args:
         vocab_file (`str`, *optional*):
             [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that