From 885591acdbe37700b646fd679fdc19cab869795e Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 23 Jan 2024 10:30:17 +0100
Subject: [PATCH] feat: add support for golden gate
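
Add a FlashGoldenGate flash-attention causal LM with its custom modeling
code and a temporary fast tokenizer, and register the "golden_gate"
model_type in get_model.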
---
.../text_generation_server/models/__init__.py | 23 +
.../flash_golden_gate_modeling.py | 441 ++++++++++++++++++
.../models/custom_modeling/temp_tok.py | 224 +++++++++
.../models/flash_golden_gate.py | 105 +++++
4 files changed, 793 insertions(+)
create mode 100644 server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py
create mode 100644 server/text_generation_server/models/custom_modeling/temp_tok.py
create mode 100644 server/text_generation_server/models/flash_golden_gate.py
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index da7d8416..b1994314 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -52,6 +52,9 @@ try:
from text_generation_server.models.flash_llama import (
FlashLlama,
)
+ from text_generation_server.models.flash_golden_gate import (
+ FlashGoldenGate,
+ )
from text_generation_server.models.flash_santacoder import (
FlashSantacoderSharded,
)
@@ -312,6 +315,26 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
+ if model_type == "golden_gate":
+ if FLASH_ATTENTION:
+ return FlashGoldenGate(
+ model_id,
+ revision,
+ quantize=quantize,
+ dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ use_medusa=use_medusa,
+ )
+ elif sharded:
+ raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate"))
+ else:
+ return CausalLM(
+ model_id,
+ revision,
+ quantize=quantize,
+ dtype=dtype,
+ trust_remote_code=trust_remote_code,
+ )
if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
if sharded:
diff --git a/server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py b/server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py
new file mode 100644
index 00000000..4d80f951
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py
@@ -0,0 +1,441 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.distributed
+
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from typing import Optional, List, Tuple
+
+from text_generation_server.utils import paged_attention, flash_attn
+from text_generation_server.utils.layers import (
+ TensorParallelRowLinear,
+ TensorParallelColumnLinear,
+ TensorParallelEmbedding,
+ PositionRotaryEmbedding,
+ TensorParallelHead,
+ get_linear,
+ FastRMSNorm,
+)
+
+
+class GoldenGateConfig(PretrainedConfig):
+ def __init__(
+ self,
+ vocab_size=256128,
+ hidden_size=3072,
+ intermediate_size=24576,
+ num_hidden_layers=28,
+ num_attention_heads=16,
+ num_key_value_heads=16,
+ hidden_act="gelu",
+ max_position_embeddings=8192,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+
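+# Build the fused QKV column-parallel projection; the GQA branch additionally
+# checks that the concatenated, sharded weight has the expected shape.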
+def load_attention(config, prefix, weights):
+ if config.num_attention_heads != config.num_key_value_heads:
+ return _load_gqa(config, prefix, weights)
+ else:
+ return TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ dim=0,
+ weights=weights,
+ bias=False,
+ )
+
+
+def _load_gqa(config, prefix: str, weights):
+ assert config.hidden_size % config.num_attention_heads == 0
+ assert config.num_attention_heads % weights.process_group.size() == 0
+
+ weight = weights.get_multi_weights_col(
+ prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+ quantize=config.quantize,
+ dim=0,
+ )
+
+ if config.quantize not in ["gptq", "awq"]:
+ weight = weight.to(dtype=weights.dtype).to(device=weights.device)
+
+ head_size = config.hidden_size // config.num_attention_heads
+ num_heads = config.num_attention_heads // weights.process_group.size()
+ num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
+ assert list(weight.shape) == [
+ (num_heads + 2 * num_key_value_heads) * head_size,
+ config.hidden_size,
+    ], f"{list(weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"
+
+ return TensorParallelColumnLinear(
+ get_linear(weight, bias=None, quantize=config.quantize)
+ )
+
+
+class FlashGoldenGateAttention(torch.nn.Module):
+ def __init__(
+ self,
+ prefix: str,
+ config,
+ weights,
+ ):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ self.head_size = self.hidden_size // self.num_heads
+
+ self.rotary_emb = PositionRotaryEmbedding.static(
+ config=config,
+ dim=self.head_size,
+ base=config.rope_theta,
+ device=weights.device,
+ )
+
+ self.softmax_scale = self.head_size**-0.5
+
+ if self.num_heads % weights.process_group.size() != 0:
+ raise ValueError(
+ f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
+                f"and `num_shards`: {weights.process_group.size()})"
+ )
+ self.num_heads = self.num_heads // weights.process_group.size()
+ self.num_key_value_heads = (
+ config.num_key_value_heads // weights.process_group.size()
+ )
+
+ self.query_key_value = load_attention(config, prefix, weights)
+
+ self.o_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.o_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.num_groups = self.num_heads // self.num_key_value_heads
+ self.kv_head_mapping = torch.arange(
+ 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
+ ).repeat_interleave(self.num_groups)
+
+ def forward(
+ self,
+ hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ block_tables,
+ slots,
+ input_lengths,
+ max_s,
+ ):
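+        # Project hidden states to fused QKV, then split into query and packed key/value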
+ qkv = self.query_key_value(hidden_states)
+ query, kv = qkv.split(
+ [
+ self.head_size * self.num_heads,
+ 2 * self.head_size * self.num_key_value_heads,
+ ],
+ dim=1,
+ )
+ query = query.view(-1, self.num_heads, self.head_size)
+ kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
+
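+        # Apply rotary position embeddings to the query and the key slice of the packed KV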
+ self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+
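+        # Write the new key/value vectors into the paged KV cache at the assigned slots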
+ paged_attention.reshape_and_cache(
+ kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
+ )
+
+ # output tensor
+ attn_output = torch.empty_like(query)
+
+ # Prefill
+ if cu_seqlen_prefill is not None:
+ # flash attention
+ flash_attn.attention(
+ query,
+ torch.select(kv, dim=1, index=0),
+ torch.select(kv, dim=1, index=1),
+ attn_output,
+ cu_seqlen_prefill,
+ max_s,
+ self.softmax_scale,
+ )
+ # Decode
+ else:
+ paged_attention.attention(
+ attn_output,
+ query,
+ kv_cache[0],
+ kv_cache[1],
+ self.kv_head_mapping,
+ self.softmax_scale,
+ block_tables,
+ input_lengths,
+ max_s,
+ )
+
+ return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+
+
+class GoldenGateMLP(nn.Module):
+ def __init__(self, prefix, config, weights):
+ super().__init__()
+ act = config.hidden_act
+ self.act = (
+ ACT2FN[act]
+ if "gelu" not in act
+ else lambda x: torch.nn.functional.gelu(
+ x,
+ approximate="tanh"
+ if act in ["gelu_fast", "gelu_pytorch_tanh"]
+ else "none",
+ )
+ )
+ # Fuse gate and up proj
+ self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+ config,
+ prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
+ weights=weights,
+ dim=0,
+ bias=False,
+ )
+ self.down_proj = TensorParallelRowLinear.load(
+ config,
+ prefix=f"{prefix}.down_proj",
+ weights=weights,
+ bias=False,
+ )
+ self.intermediate_size = (
+ config.intermediate_size // weights.process_group.size()
+ )
+
+ def forward(self, hidden_states):
+ gate_up_states = self.gate_up_proj(hidden_states)
+ gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
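+        # GeGLU: activate the gate projection and multiply by the up projection before projecting down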
+ return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+
+
+class FlashGoldenGateLayer(nn.Module):
+ def __init__(self, layer_id, config, weights):
+ super().__init__()
+ prefix = f"model.layers.{layer_id}"
+ self.self_attn = FlashGoldenGateAttention(
+ prefix=f"{prefix}.self_attn", config=config, weights=weights
+ )
+ self.mlp = GoldenGateMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+
+ self.input_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+ )
+ self.post_attention_layernorm = FastRMSNorm.load(
+ prefix=f"{prefix}.post_attention_layernorm",
+ weights=weights,
+ eps=config.rms_norm_eps,
+ )
+
+ def forward(
+ self,
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ block_tables,
+ slots,
+ input_lengths,
+ max_s,
+ ):
+ normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+
+ # Self Attention
+ attn_output = self.self_attn(
+ normed_hidden_states,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache,
+ block_tables,
+ slots,
+ input_lengths,
+ max_s,
+ )
+
+ # faster post attention rms norm
+ normed_attn_res_output, attn_res = self.post_attention_layernorm(
+ attn_output, res
+ )
+
+ mlp_output = self.mlp(normed_attn_res_output)
+
+ return mlp_output, attn_res
+
+
+class FlashGoldenGateModel(torch.nn.Module):
+ def __init__(self, config, weights):
+ super().__init__()
+
+ process_group = weights.process_group
+ self.tp_rank = process_group.rank()
+ self.tp_world_size = process_group.size()
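+        # GoldenGate scales token embeddings by sqrt(hidden_size); fold the factor into the weights at load time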
+        embed_norm = config.hidden_size**0.5
+ self.embed_tokens = TensorParallelEmbedding(
+ prefix="model.embed_tokens", weights=weights
+ )
+ self.embed_tokens.weight *= embed_norm
+
+ self.layers = nn.ModuleList(
+ [
+ FlashGoldenGateLayer(
+ layer_id,
+ config,
+ weights,
+ )
+ for layer_id in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = FastRMSNorm.load(
+ prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+ )
+
+ self.gradient_checkpointing = False
+
+ self.head_size = self.layers[0].self_attn.head_size
+ self.num_heads = self.layers[0].self_attn.num_heads
+ self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ block_tables: torch.Tensor,
+ slots: torch.Tensor,
+ input_lengths: torch.Tensor,
+ max_s: int,
+ ) -> torch.Tensor:
+ hidden_states = self.embed_tokens(input_ids)
+
+ # Get rotary cos and sin for this forward
+        # Avoid indexing in each layer
+ cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
+ position_ids, max_s, hidden_states.dtype
+ )
+
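+        # Run the decoder layers, threading the residual through the fused add + RMSNorm calls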
+ residual = None
+ for i, layer in enumerate(self.layers):
+ hidden_states, residual = layer(
+ hidden_states,
+ residual,
+ cos,
+ sin,
+ cu_seqlen_prefill,
+ kv_cache[i],
+ block_tables,
+ slots,
+ input_lengths,
+ max_s,
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+class FlashGoldenGateForCausalLM(torch.nn.Module):
+ def __init__(self, config, weights):
+ super().__init__()
+
+ self.model = FlashGoldenGateModel(config, weights)
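+        # With tied word embeddings, load the LM head from the embedding weights; otherwise from lm_head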
+ self.lm_head = TensorParallelHead.load(
+ config,
+ prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
+ weights=weights,
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ cu_seqlen_prefill: Optional[torch.Tensor],
+ kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
+ block_tables: torch.Tensor,
+ slots: torch.Tensor,
+ input_lengths: torch.Tensor,
+ max_s: int,
+ lm_head_indices: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids,
+ position_ids,
+ cu_seqlen_prefill,
+ kv_cache,
+ block_tables,
+ slots,
+ input_lengths,
+ max_s,
+ )
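+        # Optionally keep only the positions whose logits are needed before the LM head projection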
+ if lm_head_indices is not None:
+ hidden_states = hidden_states[lm_head_indices]
+ logits = self.lm_head(hidden_states)
+ return logits
diff --git a/server/text_generation_server/models/custom_modeling/temp_tok.py b/server/text_generation_server/models/custom_modeling/temp_tok.py
new file mode 100644
index 00000000..3c4fa8d9
--- /dev/null
+++ b/server/text_generation_server/models/custom_modeling/temp_tok.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from shutil import copyfile
+from typing import Optional, Tuple
+
+from tokenizers import processors
+
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+from transformers.utils import logging
+from transformers.utils.versions import require_version
+
+
+require_version("tokenizers>=0.13.3")
+
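+# No slow (sentencepiece-based) tokenizer class ships with this temporary module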
+GoldenGateTokenizer = None
+
+logger = logging.get_logger(__name__)
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+ "vocab_file": {
+ "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
+ },
+ "tokenizer_file": {
+ "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
+ },
+}
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+ that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+class GoldenGateTokenizerFast(PreTrainedTokenizerFast):
+ """
+ Construct a GoldenGate tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+    Notably, this uses ByteFallback and no normalization.
+
+ ```python
+ >>> from transformers import GoldenGateTokenizerFast
+
+ >>> tokenizer = GoldenGateTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
+ >>> tokenizer.encode("Hello this is a test")
+ [1, 15043, 445, 338, 263, 1243]
+ ```
+
+ If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
+ call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
+    values of the first token and final token of an encoded sequence will not be correct). For more details, check out
+    the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
+
+
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
+ refer to this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`, *optional*):
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
+ contains the vocabulary necessary to instantiate a tokenizer.
+ tokenizer_file (`str`, *optional*):
+ [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
+ contains everything needed to load the tokenizer.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
+            extra spaces.
+        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
+            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
+            The end of sequence token.
+        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
+            The padding token used when batching sequences of different lengths.
+        add_bos_token (`bool`, *optional*, defaults to `True`):
+            Whether or not to add a `bos_token` at the start of sequences.
+        add_eos_token (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an `eos_token` at the end of sequences.
+        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
+            Whether or not the default system prompt for GoldenGate should be used.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ slow_tokenizer_class = GoldenGateTokenizer
+ padding_side = "left"
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file=None,
+ tokenizer_file=None,
+ clean_up_tokenization_spaces=False,
+        unk_token="<unk>",
+        bos_token="<bos>",
+        eos_token="<eos>",
+        pad_token="<pad>",
+ add_bos_token=True,
+ add_eos_token=False,
+ use_default_system_prompt=False,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file=vocab_file,
+ tokenizer_file=tokenizer_file,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ use_default_system_prompt=use_default_system_prompt,
+ **kwargs,
+ )
+ self._add_bos_token = add_bos_token
+ self._add_eos_token = add_eos_token
+ self.update_post_processor()
+ self.use_default_system_prompt = use_default_system_prompt
+ self.vocab_file = vocab_file
+
+ @property
+ def can_save_slow_tokenizer(self) -> bool:
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+ def update_post_processor(self):
+ """
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
+ """
+ bos = self.bos_token
+ bos_token_id = self.bos_token_id
+ if bos is None and self.add_bos_token:
+ raise ValueError("add_bos_token = True but bos_token = None")
+
+ eos = self.eos_token
+ eos_token_id = self.eos_token_id
+ if eos is None and self.add_eos_token:
+ raise ValueError("add_eos_token = True but eos_token = None")
+
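+        # TemplateProcessing templates: "$A"/"$B" are the sequences, ":0"/":1" the type ids,
+        # with BOS/EOS specials added according to add_bos_token / add_eos_token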
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
+
+ special_tokens = []
+ if self.add_bos_token:
+ special_tokens.append((bos, bos_token_id))
+ if self.add_eos_token:
+ special_tokens.append((eos, eos_token_id))
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=single, pair=pair, special_tokens=special_tokens
+ )
+
+ @property
+ def add_eos_token(self):
+ return self._add_eos_token
+
+ @property
+ def add_bos_token(self):
+ return self._add_bos_token
+
+ @add_eos_token.setter
+ def add_eos_token(self, value):
+ self._add_eos_token = value
+ self.update_post_processor()
+
+ @add_bos_token.setter
+ def add_bos_token(self, value):
+ self._add_bos_token = value
+ self.update_post_processor()
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+ "tokenizer."
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
+
+ @property
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template
+ def default_chat_template(self):
+ raise NotImplementedError
+
+ # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+ output = bos_token_id + token_ids_0 + eos_token_id
+
+ if token_ids_1 is not None:
+ output = output + bos_token_id + token_ids_1 + eos_token_id
+
+ return output
diff --git a/server/text_generation_server/models/flash_golden_gate.py b/server/text_generation_server/models/flash_golden_gate.py
new file mode 100644
index 00000000..ae5940d8
--- /dev/null
+++ b/server/text_generation_server/models/flash_golden_gate.py
@@ -0,0 +1,105 @@
+import torch
+import torch.distributed
+
+from opentelemetry import trace
+from typing import Optional
+from transformers import AutoTokenizer
+
+from text_generation_server.models import FlashCausalLM
+from text_generation_server.models.custom_modeling.flash_golden_gate_modeling import (
+ FlashGoldenGateForCausalLM,
+ GoldenGateConfig,
+)
+from text_generation_server.utils import (
+ initialize_torch_distributed,
+ weight_files,
+ Weights,
+)
+
+tracer = trace.get_tracer(__name__)
+
+
+class FlashGoldenGate(FlashCausalLM):
+ def __init__(
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ dtype: Optional[torch.dtype] = None,
+ trust_remote_code: bool = False,
+ use_medusa: Optional[str] = None,
+ ):
+ self.process_group, rank, world_size = initialize_torch_distributed()
+ if torch.cuda.is_available():
+ device = torch.device(f"cuda:{rank}")
+ dtype = torch.float16 if dtype is None else dtype
+ else:
+ raise NotImplementedError("FlashGoldenGate is only available on GPU")
+
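+        # Use the fast tokenizer bundled in temp_tok.py rather than AutoTokenizer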
+        from text_generation_server.models.custom_modeling.temp_tok import (
+            GoldenGateTokenizerFast,
+        )
+ tokenizer = GoldenGateTokenizerFast.from_pretrained(
+ model_id,
+ revision=revision,
+ padding_side="left",
+ truncation_side="left",
+ trust_remote_code=trust_remote_code,
+ use_fast=True,
+ from_slow=False,
+ )
+
+ config = GoldenGateConfig.from_pretrained(
+ model_id, revision=revision, trust_remote_code=trust_remote_code
+ )
+ config.quantize = quantize
+
+ torch.distributed.barrier(group=self.process_group)
+
+ filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+ weights = Weights(filenames, device, dtype, process_group=self.process_group)
+ if config.quantize in ["gptq", "awq"]:
+ weights._set_gptq_params(model_id, revision)
+
+ model = FlashGoldenGateForCausalLM(config, weights)
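+        # Optionally replace the LM head with a Medusa speculative-decoding head
+        # loaded from a local path or the Hugging Face Hub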
+ if use_medusa:
+ from text_generation_server.utils.medusa import MedusaModel
+ from huggingface_hub import hf_hub_download
+ import json
+ import os
+ from pathlib import Path
+
+ is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
+ "WEIGHTS_CACHE_OVERRIDE", None
+ ) is not None
+
+ if not is_local_model:
+ medusa_config = hf_hub_download(
+ use_medusa, revision=revision, filename="config.json"
+ )
+ medusa_head = hf_hub_download(
+ use_medusa, revision=revision, filename="medusa_lm_head.pt"
+ )
+ else:
+ medusa_config = str(Path(use_medusa) / "config.json")
+ medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+
+ with open(medusa_config, "r") as f:
+ config = json.load(f)
+ medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
+ weights = Weights(
+ [medusa_sf], device, dtype, process_group=self.process_group
+ )
+ lm_head = model.lm_head
+ model.lm_head = MedusaModel(config, weights, lm_head)
+
+ torch.distributed.barrier(group=self.process_group)
+ super(FlashGoldenGate, self).__init__(
+ model=model,
+ tokenizer=tokenizer,
+ num_layers=len(model.model.layers),
+ num_kv_heads=model.model.num_key_value_heads,
+ head_size=model.model.head_size,
+ dtype=dtype,
+ device=device,
+ rank=rank,
+ world_size=world_size,
+ )