OlivierDehaene 2024-02-20 15:43:44 +01:00
parent f4eec092c6
commit 5da4b01cc5
4 changed files with 208 additions and 257 deletions

server/text_generation_server/models/__init__.py

@@ -52,8 +52,8 @@ try:
     from text_generation_server.models.flash_llama import (
         FlashLlama,
     )
-    from text_generation_server.models.flash_golden_gate import (
-        FlashGoldenGate,
+    from text_generation_server.models.flash_gemma import (
+        FlashGemma,
     )
     from text_generation_server.models.flash_santacoder import (
         FlashSantacoderSharded,
@@ -315,9 +315,9 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-    if model_type == "golden_gate":
+    if model_type == "gemma":
         if FLASH_ATTENTION:
-            return FlashGoldenGate(
+            return FlashGemma(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -326,7 +326,9 @@ def get_model(
                 use_medusa=use_medusa,
             )
         elif sharded:
-            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate"))
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate")
+            )
         else:
             return CausalLM(
                 model_id,
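The three hunks above wire the new model type into get_model's dispatch. Below is a condensed sketch of the resulting control flow: FlashGemma, CausalLM, FLASH_ATTENTION and FLASH_ATT_ERROR_MESSAGE are the names visible in the diff, the helper name is ours, and the CausalLM argument list is assumed to mirror the FlashGemma branch (the hunk truncates it):

    # Sketch only: mirrors the branch added above, not a drop-in replacement.
    def dispatch_gemma(model_id, revision, sharded, **kwargs):
        if FLASH_ATTENTION:
            # Fast path: the custom flash-attention implementation.
            return FlashGemma(model_id, revision, **kwargs)
        if sharded:
            # No sharded fallback exists without flash attention.
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate")
            )
        # Last resort: the generic transformers-backed implementation.
        return CausalLM(model_id, revision, **kwargs)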

server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py

@@ -20,11 +20,16 @@
 import torch
 import torch.distributed
+import os
+from shutil import copyfile
 
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from tokenizers import processors
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+from transformers.utils import logging
 
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
@@ -37,8 +42,168 @@ from text_generation_server.utils.layers import (
     FastRMSNorm,
 )
 
+GemmaTokenizer = None
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "tokenizer.model",
+    "tokenizer_file": "tokenizer.json",
+}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
+    },
+    "tokenizer_file": {
+        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
+    },
+}
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+# fmt: off
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
+answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
+that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
+correct. If you don't know the answer to a question, please don't share false information."""
+# fmt: on
+
+
+class GemmaTokenizerFast(PreTrainedTokenizerFast):
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    slow_tokenizer_class = GemmaTokenizer
+    padding_side = "left"
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file=None,
+        tokenizer_file=None,
+        clean_up_tokenization_spaces=False,
+        unk_token="<unk>",
+        bos_token="<bos>",
+        eos_token="<eos>",
+        pad_token="<pad>",
+        add_bos_token=True,
+        add_eos_token=False,
+        use_default_system_prompt=False,
+        **kwargs,
+    ):
+        super().__init__(
+            vocab_file=vocab_file,
+            tokenizer_file=tokenizer_file,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            add_bos_token=add_bos_token,
+            add_eos_token=add_eos_token,
+            use_default_system_prompt=use_default_system_prompt,
+            **kwargs,
+        )
+        self._add_bos_token = add_bos_token
+        self._add_eos_token = add_eos_token
+        self.update_post_processor()
+        self.use_default_system_prompt = use_default_system_prompt
+        self.vocab_file = vocab_file
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+    def update_post_processor(self):
+        """
+        Updates the underlying post processor with the current `bos_token` and `eos_token`.
+        """
+        bos = self.bos_token
+        bos_token_id = self.bos_token_id
+        if bos is None and self.add_bos_token:
+            raise ValueError("add_bos_token = True but bos_token = None")
+
+        eos = self.eos_token
+        eos_token_id = self.eos_token_id
+        if eos is None and self.add_eos_token:
+            raise ValueError("add_eos_token = True but eos_token = None")
+
+        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
+        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
+
+        special_tokens = []
+        if self.add_bos_token:
+            special_tokens.append((bos, bos_token_id))
+        if self.add_eos_token:
+            special_tokens.append((eos, eos_token_id))
+        self._tokenizer.post_processor = processors.TemplateProcessing(
+            single=single, pair=pair, special_tokens=special_tokens
+        )
+
+    @property
+    def add_eos_token(self):
+        return self._add_eos_token
+
+    @property
+    def add_bos_token(self):
+        return self._add_bos_token
+
+    @add_eos_token.setter
+    def add_eos_token(self, value):
+        self._add_eos_token = value
+        self.update_post_processor()
+
+    @add_bos_token.setter
+    def add_bos_token(self, value):
+        self._add_bos_token = value
+        self.update_post_processor()
+
+    def save_vocabulary(
+        self, save_directory: str, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory,
+            (filename_prefix + "-" if filename_prefix else "")
+            + VOCAB_FILES_NAMES["vocab_file"],
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
+    @property
+    # Copied from transformers.models.llama.tokenization_llama.GemmaTokenizer.default_chat_template
+    def default_chat_template(self):
+        raise NotImplementedError
+
+    # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
+    # Copied from transformers.models.llama.tokenization_llama.GemmaTokenizer.build_inputs_with_special_tokens
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
+        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
+
+        output = bos_token_id + token_ids_0 + eos_token_id
+
+        if token_ids_1 is not None:
+            output = output + bos_token_id + token_ids_1 + eos_token_id
+
+        return output
+
+
-class GoldenGateConfig(PretrainedConfig):
+class GemmaConfig(PretrainedConfig):
     def __init__(
         self,
         vocab_size=256128,
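A note on what the newly added update_post_processor builds: with the defaults above (add_bos_token=True, add_eos_token=False), the two f-strings expand to plain tokenizers templates. A minimal standalone sketch, with a placeholder token ID:

    from tokenizers import processors

    # With add_bos_token=True and add_eos_token=False the f-strings yield:
    single = "<bos>:0 $A:0"             # single sequence: BOS prepended, type id 0
    pair = "<bos>:0 $A:0 <bos>:1 $B:1"  # pair: BOS before each member

    # 2 is a placeholder ID; the real one comes from the loaded vocabulary.
    post_processor = processors.TemplateProcessing(
        single=single, pair=pair, special_tokens=[("<bos>", 2)]
    )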
@@ -93,7 +258,8 @@ class GoldenGateConfig(PretrainedConfig):
             **kwargs,
         )
 
-class GoldenGateFastRMSNorm(FastRMSNorm):
+
+class GemmaFastRMSNorm(FastRMSNorm):
     @classmethod
     def load(cls, prefix, weights, eps=1e-6):
         weight = weights.get_tensor(f"{prefix}.weight") + 1
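The + 1 on the loaded tensor is the one real change in this subclass: Gemma checkpoints store RMSNorm scales as an offset from one (a zero-initialized weight means identity scaling), so folding the offset in at load time lets the stock FastRMSNorm forward run unchanged. A plain-PyTorch sketch of the equivalent computation, for reference only:

    import torch

    def gemma_rmsnorm(x, weight, eps=1e-6):
        # Standard RMS normalization over the hidden dimension...
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        # ...scaled by (1 + weight), which `get_tensor(...) + 1` bakes into the weight.
        return x * (1.0 + weight)

    x = torch.randn(4, 8)
    # A zero weight is an identity scale: the output is just normalized x.
    assert torch.allclose(
        gemma_rmsnorm(x, torch.zeros(8)),
        x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6),
    )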
@@ -138,7 +304,7 @@ def _load_gqa(config, prefix: str, weights):
     )
 
 
-class FlashGoldenGateAttention(torch.nn.Module):
+class FlashGemmaAttention(torch.nn.Module):
     def __init__(
         self,
         prefix: str,
@@ -242,7 +408,7 @@ class FlashGoldenGateAttention(torch.nn.Module):
         return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
 
 
-class GoldenGateMLP(nn.Module):
+class GemmaMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         act = config.hidden_act
@@ -251,9 +417,9 @@ class GoldenGateMLP(nn.Module):
             if "gelu" not in act
             else lambda x: torch.nn.functional.gelu(
                 x,
-                approximate="tanh"
-                if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else "none",
+                approximate=(
+                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                ),
             )
         )
         # Fuse gate and up proj
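The reflow above is purely cosmetic, but the conditional it wraps deserves a gloss: approximate="tanh" selects PyTorch's cheaper tanh-based GELU when hidden_act is gelu_fast or gelu_pytorch_tanh, while any other gelu variant gets the exact erf-based form. The two agree closely but not bitwise:

    import torch

    x = torch.randn(8)
    exact = torch.nn.functional.gelu(x, approximate="none")  # 0.5 * x * (1 + erf(x / sqrt(2)))
    fast = torch.nn.functional.gelu(x, approximate="tanh")   # tanh approximation
    print((exact - fast).abs().max())  # small but nonzero difference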
@@ -280,19 +446,19 @@ class GoldenGateMLP(nn.Module):
         return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
 
 
-class FlashGoldenGateLayer(nn.Module):
+class FlashGemmaLayer(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"model.layers.{layer_id}"
-        self.self_attn = FlashGoldenGateAttention(
+        self.self_attn = FlashGemmaAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
         )
-        self.mlp = GoldenGateMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
-        self.input_layernorm = GoldenGateFastRMSNorm.load(
+        self.input_layernorm = GemmaFastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
         )
-        self.post_attention_layernorm = GoldenGateFastRMSNorm.load(
+        self.post_attention_layernorm = GemmaFastRMSNorm.load(
             prefix=f"{prefix}.post_attention_layernorm",
             weights=weights,
             eps=config.rms_norm_eps,
@@ -336,14 +502,14 @@ class FlashGoldenGateLayer(nn.Module):
         return mlp_output, attn_res
 
 
-class FlashGoldenGateModel(torch.nn.Module):
+class FlashGemmaModel(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
 
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
-        embed_norm = config.hidden_size ** 0.5
+        embed_norm = config.hidden_size**0.5
         self.embed_tokens = TensorParallelEmbedding(
             prefix="model.embed_tokens", weights=weights
         )
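embed_norm (only the ** spacing changes here) is Gemma's embedding scale: token embeddings are multiplied by sqrt(hidden_size) before the first decoder layer. The forward pass falls outside these hunks, so the following is a hedged sketch of how the scale is presumably applied, with deliberately small illustrative sizes:

    import torch

    hidden_size = 16              # illustrative; the real value comes from the config
    embed_norm = hidden_size**0.5

    embed_tokens = torch.nn.Embedding(32, hidden_size)  # real vocab size is 256128
    input_ids = torch.tensor([2, 5, 9])                 # placeholder token IDs

    hidden_states = embed_tokens(input_ids) * embed_norm  # scale by sqrt(hidden_size)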
@@ -351,7 +517,7 @@ class FlashGoldenGateModel(torch.nn.Module):
 
         self.layers = nn.ModuleList(
             [
-                FlashGoldenGateLayer(
+                FlashGemmaLayer(
                     layer_id,
                     config,
                     weights,
@@ -359,7 +525,7 @@ class FlashGoldenGateModel(torch.nn.Module):
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = GoldenGateFastRMSNorm.load(
+        self.norm = GemmaFastRMSNorm.load(
             prefix="model.norm", weights=weights, eps=config.rms_norm_eps
         )
@@ -408,11 +574,11 @@ class FlashGoldenGateModel(torch.nn.Module):
         return hidden_states
 
 
-class FlashGoldenGateForCausalLM(torch.nn.Module):
+class FlashGemmaForCausalLM(torch.nn.Module):
     def __init__(self, config, weights):
         super().__init__()
 
-        self.model = FlashGoldenGateModel(config, weights)
+        self.model = FlashGemmaModel(config, weights)
         self.lm_head = TensorParallelHead.load(
             config,
             prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
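The prefix switch on the last line is the usual weight-tying pattern: when config.tie_word_embeddings is set, the output head loads the embedding matrix instead of a separate lm_head tensor. Independent of TGI's TensorParallelHead, the idea reduces to:

    import torch

    vocab_size, hidden_size = 32, 16  # illustrative; Gemma's vocab is 256128
    embed = torch.nn.Embedding(vocab_size, hidden_size)

    lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    lm_head.weight = embed.weight     # shared storage: no second vocab-sized matrix

    logits = lm_head(torch.randn(1, hidden_size))  # shape (1, vocab_size)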

server/text_generation_server/models/custom_modeling/temp_tok.py

@@ -1,216 +0,0 @@
-# coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os
-from shutil import copyfile
-from typing import Optional, Tuple
-
-from tokenizers import processors
-from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
-from transformers.utils import logging
-from transformers.utils.versions import require_version
-
-require_version("tokenizers>=0.13.3")
-
-GoldenGateTokenizer = None
-
-logger = logging.get_logger(__name__)
-
-VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
-
-PRETRAINED_VOCAB_FILES_MAP = {
-    "vocab_file": {
-        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
-    },
-    "tokenizer_file": {
-        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
-    },
-}
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
-# fmt: off
-DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
-answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
-that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
-correct. If you don't know the answer to a question, please don't share false information."""
-# fmt: on
-
-
-class GoldenGateTokenizerFast(PreTrainedTokenizerFast):
-    """
-    Construct a GoldenGate tokenizer. Based on byte-level Byte-Pair-Encoding.
-
-    This uses notably ByteFallback and no normalization.
-
-    ```python
-    >>> from transformers import GoldenGateTokenizerFast
-
-    >>> tokenizer = GoldenGateTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
-    >>> tokenizer.encode("Hello this is a test")
-    [1, 15043, 445, 338, 263, 1243]
-    ```
-
-    If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
-    call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
-    values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
-    [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
-
-    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
-    refer to this superclass for more information regarding those methods.
-
-    Args:
-        vocab_file (`str`, *optional*):
-            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
-            contains the vocabulary necessary to instantiate a tokenizer.
-        tokenizer_file (`str`, *optional*):
-            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
-            contains everything needed to load the tokenizer.
-        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
-            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
-            extra spaces.
-        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
-            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
-            token instead.
-        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
-            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
-        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
-            The end of sequence token.
-        add_bos_token (`bool`, *optional*, defaults to `True`):
-            Whether or not to add an `bos_token` at the start of sequences.
-        add_eos_token (`bool`, *optional*, defaults to `False`):
-            Whether or not to add an `eos_token` at the end of sequences.
-        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
-            Whether or not the default system prompt for GoldenGate should be used.
-    """
-
-    vocab_files_names = VOCAB_FILES_NAMES
-    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
-    slow_tokenizer_class = GoldenGateTokenizer
-    padding_side = "left"
-    model_input_names = ["input_ids", "attention_mask"]
-
-    def __init__(
-        self,
-        vocab_file=None,
-        tokenizer_file=None,
-        clean_up_tokenization_spaces=False,
-        unk_token="<unk>",
-        bos_token="<bos>",
-        eos_token="<eos>",
-        pad_token="<pad>",
-        add_bos_token=True,
-        add_eos_token=False,
-        use_default_system_prompt=False,
-        **kwargs,
-    ):
-        super().__init__(
-            vocab_file=vocab_file,
-            tokenizer_file=tokenizer_file,
-            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            unk_token=unk_token,
-            bos_token=bos_token,
-            eos_token=eos_token,
-            pad_token=pad_token,
-            add_bos_token=add_bos_token,
-            add_eos_token=add_eos_token,
-            use_default_system_prompt=use_default_system_prompt,
-            **kwargs,
-        )
-        self._add_bos_token = add_bos_token
-        self._add_eos_token = add_eos_token
-        self.update_post_processor()
-        self.use_default_system_prompt = use_default_system_prompt
-        self.vocab_file = vocab_file
-
-    @property
-    def can_save_slow_tokenizer(self) -> bool:
-        return os.path.isfile(self.vocab_file) if self.vocab_file else False
-
-    def update_post_processor(self):
-        """
-        Updates the underlying post processor with the current `bos_token` and `eos_token`.
-        """
-        bos = self.bos_token
-        bos_token_id = self.bos_token_id
-        if bos is None and self.add_bos_token:
-            raise ValueError("add_bos_token = True but bos_token = None")
-
-        eos = self.eos_token
-        eos_token_id = self.eos_token_id
-        if eos is None and self.add_eos_token:
-            raise ValueError("add_eos_token = True but eos_token = None")
-
-        single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
-        pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
-
-        special_tokens = []
-        if self.add_bos_token:
-            special_tokens.append((bos, bos_token_id))
-        if self.add_eos_token:
-            special_tokens.append((eos, eos_token_id))
-        self._tokenizer.post_processor = processors.TemplateProcessing(
-            single=single, pair=pair, special_tokens=special_tokens
-        )
-
-    @property
-    def add_eos_token(self):
-        return self._add_eos_token
-
-    @property
-    def add_bos_token(self):
-        return self._add_bos_token
-
-    @add_eos_token.setter
-    def add_eos_token(self, value):
-        self._add_eos_token = value
-        self.update_post_processor()
-
-    @add_bos_token.setter
-    def add_bos_token(self, value):
-        self._add_bos_token = value
-        self.update_post_processor()
-
-    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        if not self.can_save_slow_tokenizer:
-            raise ValueError(
-                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
-                "tokenizer."
-            )
-
-        if not os.path.isdir(save_directory):
-            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
-            return
-        out_vocab_file = os.path.join(
-            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
-        )
-
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
-            copyfile(self.vocab_file, out_vocab_file)
-
-        return (out_vocab_file,)
-
-    @property
-    # Copied from transformers.models.llama.tokenization_llama.GoldenGateTokenizer.default_chat_template
-    def default_chat_template(self):
-        raise NotImplementedError
-
-    # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
-    # Copied from transformers.models.llama.tokenization_llama.GoldenGateTokenizer.build_inputs_with_special_tokens
-    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
-        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
-        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
-
-        output = bos_token_id + token_ids_0 + eos_token_id
-
-        if token_ids_1 is not None:
-            output = output + bos_token_id + token_ids_1 + eos_token_id
-
-        return output
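The deleted module survives nearly verbatim inside flash_gemma_modeling.py, so its one subtle behavior still applies: the add_bos_token/add_eos_token setters rebuild the post-processor, meaning a toggle immediately changes what encode returns. A hypothetical usage sketch (the checkpoint path is a placeholder):

    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        GemmaTokenizerFast,
    )

    tok = GemmaTokenizerFast.from_pretrained("some/gemma-checkpoint")  # placeholder
    ids = tok.encode("Hello")  # BOS is prepended by default (add_bos_token=True)

    tok.add_eos_token = True   # the setter re-runs update_post_processor()
    ids = tok.encode("Hello")  # now also ends with the EOS id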

server/text_generation_server/models/flash_gemma.py

@@ -3,12 +3,12 @@ import torch.distributed
 
 from opentelemetry import trace
 from typing import Optional
 
-from transformers import AutoTokenizer
 from text_generation_server.models import FlashCausalLM
-from text_generation_server.models.custom_modeling.flash_golden_gate_modeling import (
-    FlashGoldenGateForCausalLM,
-    GoldenGateConfig,
+from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
+    GemmaTokenizerFast,
+    FlashGemmaForCausalLM,
+    GemmaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -19,7 +19,7 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)
 
 
-class FlashGoldenGate(FlashCausalLM):
+class FlashGemma(FlashCausalLM):
     def __init__(
         self,
         model_id: str,
@@ -32,12 +32,11 @@ class FlashGoldenGate(FlashCausalLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.float16 if dtype is None else dtype
+            dtype = torch.bfloat16 if dtype is None else dtype
         else:
-            raise NotImplementedError("FlashGoldenGate is only available on GPU")
+            raise NotImplementedError("FlashGemma is only available on GPU")
 
-        from text_generation_server.models.custom_modeling.temp_tok import GoldenGateTokenizerFast
-        tokenizer = GoldenGateTokenizerFast.from_pretrained(
+        tokenizer = GemmaTokenizerFast.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
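The default dtype flips from float16 to bfloat16 in this hunk. bfloat16 keeps float32's exponent range at the cost of mantissa precision, so activation magnitudes that would overflow float16's limit of 65504 stay finite:

    import torch

    print(torch.finfo(torch.float16).max)   # 65504.0
    print(torch.finfo(torch.bfloat16).max)  # ~3.39e38, float32's exponent range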
@@ -47,7 +46,7 @@ class FlashGoldenGate(FlashCausalLM):
             from_slow=False,
         )
 
-        config = GoldenGateConfig.from_pretrained(
+        config = GemmaConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
@@ -59,18 +58,18 @@ class FlashGoldenGate(FlashCausalLM):
         if config.quantize in ["gptq", "awq"]:
             weights._set_gptq_params(model_id, revision)
 
-        model = FlashGoldenGateForCausalLM(config, weights)
+        model = FlashGemmaForCausalLM(config, weights)
         if use_medusa:
             from text_generation_server.utils.medusa import MedusaModel
             from huggingface_hub import hf_hub_download
             import json
             import os
             from pathlib import Path
 
-            is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv(
-                "WEIGHTS_CACHE_OVERRIDE", None
-            ) is not None
+            is_local_model = (
+                Path(use_medusa).exists() and Path(use_medusa).is_dir()
+            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
 
             if not is_local_model:
                 medusa_config = hf_hub_download(
                     use_medusa, revision=revision, filename="config.json"
@@ -81,7 +80,7 @@ class FlashGoldenGate(FlashCausalLM):
             else:
                 medusa_config = str(Path(use_medusa) / "config.json")
                 medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
+
             with open(medusa_config, "r") as f:
                 config = json.load(f)
             medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
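Condensing the Medusa plumbing in the last two hunks: the speculative head's config and weights are resolved either from a local directory or from the Hub, and the head is ultimately read from a .safetensors file derived from the .pt name. A sketch under the assumption that the elided Hub branch downloads the head the same way it downloads the config:

    from pathlib import Path
    from huggingface_hub import hf_hub_download

    def resolve_medusa(use_medusa: str, revision=None):
        if Path(use_medusa).is_dir():
            # Local checkout: read files straight from the directory.
            medusa_config = str(Path(use_medusa) / "config.json")
            medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
        else:
            # Hub repo: the diff shows the config download; the head download
            # is assumed to follow the same pattern.
            medusa_config = hf_hub_download(
                use_medusa, revision=revision, filename="config.json"
            )
            medusa_head = hf_hub_download(
                use_medusa, revision=revision, filename="medusa_lm_head.pt"
            )
        # The serving path loads the safetensors twin of the head checkpoint.
        medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
        return medusa_config, medusa_sf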
@@ -92,7 +91,7 @@ class FlashGoldenGate(FlashCausalLM):
             model.lm_head = MedusaModel(config, weights, lm_head)
 
         torch.distributed.barrier(group=self.process_group)
-        super(FlashGoldenGate, self).__init__(
+        super(FlashGemma, self).__init__(
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),