diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index b1994314..abab3486 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -52,8 +52,8 @@ try: from text_generation_server.models.flash_llama import ( FlashLlama, ) - from text_generation_server.models.flash_golden_gate import ( - FlashGoldenGate, + from text_generation_server.models.flash_gemma import ( + FlashGemma, ) from text_generation_server.models.flash_santacoder import ( FlashSantacoderSharded, @@ -315,9 +315,9 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - if model_type == "golden_gate": + if model_type == "gemma": if FLASH_ATTENTION: - return FlashGoldenGate( + return FlashGemma( model_id, revision, quantize=quantize, @@ -326,7 +326,9 @@ def get_model( use_medusa=use_medusa, ) elif sharded: - raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate")) + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate") + ) else: return CausalLM( model_id, diff --git a/server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py similarity index 65% rename from server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py rename to server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index ca5b5952..bb55f5d5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_golden_gate_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -20,11 +20,16 @@ import torch import torch.distributed +import os +from shutil import copyfile from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple +from tokenizers import processors +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers.utils import logging from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( @@ -37,8 +42,168 @@ from text_generation_server.utils.layers import ( FastRMSNorm, ) +GemmaTokenizer = None -class GoldenGateConfig(PretrainedConfig): +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = { + "vocab_file": "tokenizer.model", + "tokenizer_file": "tokenizer.json", +} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" + + +# fmt: on + + +class GemmaTokenizerFast(PreTrainedTokenizerFast): + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + slow_tokenizer_class = GemmaTokenizer + padding_side = "left" + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + clean_up_tokenization_spaces=False, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + add_bos_token=True, + add_eos_token=False, + use_default_system_prompt=False, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + use_default_system_prompt=use_default_system_prompt, + **kwargs, + ) + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + self.update_post_processor() + self.use_default_system_prompt = use_default_system_prompt + self.vocab_file = vocab_file + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def update_post_processor(self): + """ + Updates the underlying post processor with the current `bos_token` and `eos_token`. + """ + bos = self.bos_token + bos_token_id = self.bos_token_id + if bos is None and self.add_bos_token: + raise ValueError("add_bos_token = True but bos_token = None") + + eos = self.eos_token + eos_token_id = self.eos_token_id + if eos is None and self.add_eos_token: + raise ValueError("add_eos_token = True but eos_token = None") + + single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}" + pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}" + + special_tokens = [] + if self.add_bos_token: + special_tokens.append((bos, bos_token_id)) + if self.add_eos_token: + special_tokens.append((eos, eos_token_id)) + self._tokenizer.post_processor = processors.TemplateProcessing( + single=single, pair=pair, special_tokens=special_tokens + ) + + @property + def add_eos_token(self): + return self._add_eos_token + + @property + def add_bos_token(self): + return self._add_bos_token + + @add_eos_token.setter + def add_eos_token(self, value): + self._add_eos_token = value + self.update_post_processor() + + @add_bos_token.setter + def add_bos_token(self, value): + self._add_bos_token = value + self.update_post_processor() + + def save_vocabulary( + self, save_directory: str, filename_prefix: Optional[str] = None + ) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + @property + # Copied from transformers.models.llama.tokenization_llama.GemmaTokenizer.default_chat_template + def default_chat_template(self): + raise NotImplementedError + + # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers + # Copied from transformers.models.llama.tokenization_llama.GemmaTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + +class GemmaConfig(PretrainedConfig): def __init__( self, vocab_size=256128, @@ -93,7 +258,8 @@ class GoldenGateConfig(PretrainedConfig): **kwargs, ) -class GoldenGateFastRMSNorm(FastRMSNorm): + +class GemmaFastRMSNorm(FastRMSNorm): @classmethod def load(cls, prefix, weights, eps=1e-6): weight = weights.get_tensor(f"{prefix}.weight") + 1 @@ -138,7 +304,7 @@ def _load_gqa(config, prefix: str, weights): ) -class FlashGoldenGateAttention(torch.nn.Module): +class FlashGemmaAttention(torch.nn.Module): def __init__( self, prefix: str, @@ -242,7 +408,7 @@ class FlashGoldenGateAttention(torch.nn.Module): return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) -class GoldenGateMLP(nn.Module): +class GemmaMLP(nn.Module): def __init__(self, prefix, config, weights): super().__init__() act = config.hidden_act @@ -251,9 +417,9 @@ class GoldenGateMLP(nn.Module): if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, - approximate="tanh" - if act in ["gelu_fast", "gelu_pytorch_tanh"] - else "none", + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), ) ) # Fuse gate and up proj @@ -280,19 +446,19 @@ class GoldenGateMLP(nn.Module): return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) -class FlashGoldenGateLayer(nn.Module): +class FlashGemmaLayer(nn.Module): def __init__(self, layer_id, config, weights): super().__init__() prefix = f"model.layers.{layer_id}" - self.self_attn = FlashGoldenGateAttention( + self.self_attn = FlashGemmaAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) - self.mlp = GoldenGateMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) - self.input_layernorm = GoldenGateFastRMSNorm.load( + self.input_layernorm = GemmaFastRMSNorm.load( prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps ) - self.post_attention_layernorm = GoldenGateFastRMSNorm.load( + self.post_attention_layernorm = GemmaFastRMSNorm.load( prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps, @@ -336,14 +502,14 @@ class FlashGoldenGateLayer(nn.Module): return mlp_output, attn_res -class FlashGoldenGateModel(torch.nn.Module): +class FlashGemmaModel(torch.nn.Module): def __init__(self, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - embed_norm = config.hidden_size ** 0.5 + embed_norm = config.hidden_size**0.5 self.embed_tokens = TensorParallelEmbedding( prefix="model.embed_tokens", weights=weights ) @@ -351,7 +517,7 @@ class FlashGoldenGateModel(torch.nn.Module): self.layers = nn.ModuleList( [ - FlashGoldenGateLayer( + FlashGemmaLayer( layer_id, config, weights, @@ -359,7 +525,7 @@ class FlashGoldenGateModel(torch.nn.Module): for layer_id in range(config.num_hidden_layers) ] ) - self.norm = GoldenGateFastRMSNorm.load( + self.norm = GemmaFastRMSNorm.load( prefix="model.norm", weights=weights, eps=config.rms_norm_eps ) @@ -408,11 +574,11 @@ class FlashGoldenGateModel(torch.nn.Module): return hidden_states -class FlashGoldenGateForCausalLM(torch.nn.Module): +class FlashGemmaForCausalLM(torch.nn.Module): def __init__(self, config, weights): super().__init__() - self.model = FlashGoldenGateModel(config, weights) + self.model = FlashGemmaModel(config, weights) self.lm_head = TensorParallelHead.load( config, prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", diff --git a/server/text_generation_server/models/custom_modeling/temp_tok.py b/server/text_generation_server/models/custom_modeling/temp_tok.py deleted file mode 100644 index 06516cbc..00000000 --- a/server/text_generation_server/models/custom_modeling/temp_tok.py +++ /dev/null @@ -1,216 +0,0 @@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from shutil import copyfile -from typing import Optional, Tuple - -from tokenizers import processors - -from transformers.tokenization_utils_fast import PreTrainedTokenizerFast -from transformers.utils import logging -from transformers.utils.versions import require_version - - -require_version("tokenizers>=0.13.3") - -GoldenGateTokenizer = None - -logger = logging.get_logger(__name__) -VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"} - -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", - }, - "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", - }, -} -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>\n", "\n<>\n\n" - -# fmt: off -DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ -answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ - that your responses are socially unbiased and positive in nature. -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ -correct. If you don't know the answer to a question, please don't share false information.""" -# fmt: on - - -class GoldenGateTokenizerFast(PreTrainedTokenizerFast): - """ - Construct a GoldenGate tokenizer. Based on byte-level Byte-Pair-Encoding. - This uses notably ByteFallback and no normalization. - ```python - >>> from transformers import GoldenGateTokenizerFast - >>> tokenizer = GoldenGateTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") - >>> tokenizer.encode("Hello this is a test") - [1, 15043, 445, 338, 263, 1243] - ``` - If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or - call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the - values of the first token and final token of an encoded sequence will not be correct). For more details, checkout - [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation. - This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should - refer to this superclass for more information regarding those methods. - Args: - vocab_file (`str`, *optional*): - [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that - contains the vocabulary necessary to instantiate a tokenizer. - tokenizer_file (`str`, *optional*): - [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that - contains everything needed to load the tokenizer. - clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): - Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like - extra spaces. - unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. - eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): - The end of sequence token. - add_bos_token (`bool`, *optional*, defaults to `True`): - Whether or not to add an `bos_token` at the start of sequences. - add_eos_token (`bool`, *optional*, defaults to `False`): - Whether or not to add an `eos_token` at the end of sequences. - use_default_system_prompt (`bool`, *optional*, defaults to `False`): - Whether or not the default system prompt for GoldenGate should be used. - """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - slow_tokenizer_class = GoldenGateTokenizer - padding_side = "left" - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - vocab_file=None, - tokenizer_file=None, - clean_up_tokenization_spaces=False, - unk_token="", - bos_token="", - eos_token="", - pad_token="", - add_bos_token=True, - add_eos_token=False, - use_default_system_prompt=False, - **kwargs, - ): - super().__init__( - vocab_file=vocab_file, - tokenizer_file=tokenizer_file, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - unk_token=unk_token, - bos_token=bos_token, - eos_token=eos_token, - pad_token=pad_token, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - use_default_system_prompt=use_default_system_prompt, - **kwargs, - ) - self._add_bos_token = add_bos_token - self._add_eos_token = add_eos_token - self.update_post_processor() - self.use_default_system_prompt = use_default_system_prompt - self.vocab_file = vocab_file - - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - - def update_post_processor(self): - """ - Updates the underlying post processor with the current `bos_token` and `eos_token`. - """ - bos = self.bos_token - bos_token_id = self.bos_token_id - if bos is None and self.add_bos_token: - raise ValueError("add_bos_token = True but bos_token = None") - - eos = self.eos_token - eos_token_id = self.eos_token_id - if eos is None and self.add_eos_token: - raise ValueError("add_eos_token = True but eos_token = None") - - single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" - pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" - - special_tokens = [] - if self.add_bos_token: - special_tokens.append((bos, bos_token_id)) - if self.add_eos_token: - special_tokens.append((eos, eos_token_id)) - self._tokenizer.post_processor = processors.TemplateProcessing( - single=single, pair=pair, special_tokens=special_tokens - ) - - @property - def add_eos_token(self): - return self._add_eos_token - - @property - def add_bos_token(self): - return self._add_bos_token - - @add_eos_token.setter - def add_eos_token(self, value): - self._add_eos_token = value - self.update_post_processor() - - @add_bos_token.setter - def add_bos_token(self, value): - self._add_bos_token = value - self.update_post_processor() - - def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - if not self.can_save_slow_tokenizer: - raise ValueError( - "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " - "tokenizer." - ) - - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): - copyfile(self.vocab_file, out_vocab_file) - - return (out_vocab_file,) - - @property - # Copied from transformers.models.llama.tokenization_llama.GoldenGateTokenizer.default_chat_template - def default_chat_template(self): - raise NotImplementedError - - # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers - # Copied from transformers.models.llama.tokenization_llama.GoldenGateTokenizer.build_inputs_with_special_tokens - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = bos_token_id + token_ids_0 + eos_token_id - - if token_ids_1 is not None: - output = output + bos_token_id + token_ids_1 + eos_token_id - - return output \ No newline at end of file diff --git a/server/text_generation_server/models/flash_golden_gate.py b/server/text_generation_server/models/flash_gemma.py similarity index 76% rename from server/text_generation_server/models/flash_golden_gate.py rename to server/text_generation_server/models/flash_gemma.py index ae5940d8..220b3992 100644 --- a/server/text_generation_server/models/flash_golden_gate.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -3,12 +3,12 @@ import torch.distributed from opentelemetry import trace from typing import Optional -from transformers import AutoTokenizer from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_golden_gate_modeling import ( - FlashGoldenGateForCausalLM, - GoldenGateConfig, +from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + GemmaTokenizerFast, + FlashGemmaForCausalLM, + GemmaConfig, ) from text_generation_server.utils import ( initialize_torch_distributed, @@ -19,7 +19,7 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -class FlashGoldenGate(FlashCausalLM): +class FlashGemma(FlashCausalLM): def __init__( self, model_id: str, @@ -32,12 +32,11 @@ class FlashGoldenGate(FlashCausalLM): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype + dtype = torch.bfloat16 if dtype is None else dtype else: - raise NotImplementedError("FlashGoldenGate is only available on GPU") + raise NotImplementedError("FlashGemma is only available on GPU") - from text_generation_server.models.custom_modeling.temp_tok import GoldenGateTokenizerFast - tokenizer = GoldenGateTokenizerFast.from_pretrained( + tokenizer = GemmaTokenizerFast.from_pretrained( model_id, revision=revision, padding_side="left", @@ -47,7 +46,7 @@ class FlashGoldenGate(FlashCausalLM): from_slow=False, ) - config = GoldenGateConfig.from_pretrained( + config = GemmaConfig.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize @@ -59,18 +58,18 @@ class FlashGoldenGate(FlashCausalLM): if config.quantize in ["gptq", "awq"]: weights._set_gptq_params(model_id, revision) - model = FlashGoldenGateForCausalLM(config, weights) + model = FlashGemmaForCausalLM(config, weights) if use_medusa: from text_generation_server.utils.medusa import MedusaModel from huggingface_hub import hf_hub_download import json import os from pathlib import Path - - is_local_model = (Path(use_medusa).exists() and Path(use_medusa).is_dir()) or os.getenv( - "WEIGHTS_CACHE_OVERRIDE", None - ) is not None - + + is_local_model = ( + Path(use_medusa).exists() and Path(use_medusa).is_dir() + ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None + if not is_local_model: medusa_config = hf_hub_download( use_medusa, revision=revision, filename="config.json" @@ -81,7 +80,7 @@ class FlashGoldenGate(FlashCausalLM): else: medusa_config = str(Path(use_medusa) / "config.json") medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") - + with open(medusa_config, "r") as f: config = json.load(f) medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" @@ -92,7 +91,7 @@ class FlashGoldenGate(FlashCausalLM): model.lm_head = MedusaModel(config, weights, lm_head) torch.distributed.barrier(group=self.process_group) - super(FlashGoldenGate, self).__init__( + super(FlashGemma, self).__init__( model=model, tokenizer=tokenizer, num_layers=len(model.model.layers),