From 44cdb00bbbcb45037ce79d087588049aadcf66cd Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 25 Sep 2024 10:48:03 +0200
Subject: [PATCH] Updating config, removing TODO

---
 .../models/custom_modeling/mllama.py | 15 +++------------
 1 file changed, 3 insertions(+), 12 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/mllama.py b/server/text_generation_server/models/custom_modeling/mllama.py
index 8e6428291..b59b623ab 100644
--- a/server/text_generation_server/models/custom_modeling/mllama.py
+++ b/server/text_generation_server/models/custom_modeling/mllama.py
@@ -462,7 +462,7 @@ class MllamaVisionModel(nn.Module):
         self.patch_size = config.patch_size
         self.max_num_tiles = config.max_num_tiles
         self.hidden_size = config.hidden_size
-        self.in_channels = config.in_channels
+        self.num_channels = config.num_channels
         self.intermediate_layers_indices = config.intermediate_layers_indices
 
         self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
@@ -470,7 +470,7 @@ class MllamaVisionModel(nn.Module):
         self.dtype = weights.dtype
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=config.in_channels,
+            in_channels=config.num_channels,
             out_channels=self.hidden_size,
             kernel_size=self.patch_size,
             stride=self.patch_size,
@@ -810,7 +810,7 @@ class MllamaTextMLP(nn.Module):
             weights=weights,
             bias=False,
         )
-        self.act_fn = ACT2FN[config.hidden_activation]
+        self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
         shape = x.shape
@@ -1187,15 +1187,11 @@ class MllamaTextModel(nn.Module):
                 )
             )
 
-        # TODO Should we use this slow norm ?
-        # self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm = MllamaTextRMSNorm.load(
             prefix=f"{prefix}.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
-        # TODO Anything specific ?
-        head_size = config.hidden_size // config.num_attention_heads
         self.rotary_emb = MllamaRotaryEmbedding(config=config, weights=weights)
 
     def forward(
@@ -1383,7 +1379,6 @@ class MllamaForCausalLM(nn.Module):
         num_logits_to_keep: int = 0,
     ):
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        # TODO
        outputs = self.model(
            input_ids=input_ids,
            cross_attention_states=cross_attention_states,
@@ -1493,11 +1488,7 @@ class MllamaForConditionalGeneration(nn.Module):
         config.vision_config.speculator = config.speculator
         config.text_config.quantize = config.quantize
         config.text_config.speculator = config.speculator
-        # TODO check how this is determined
         config.text_config._attn_implementation = "sdpa"
-        # self.hidden_size = (
-        #     config.text_config.hidden_size // weights.process_group.size()
-        # )
         self.hidden_size = config.text_config.hidden_size
         self.vision_model = MllamaVisionModel(
             prefix="vision_model", config=config.vision_config, weights=weights