Updating config, removing TODO

Nicolas Patry 2024-09-25 10:48:03 +02:00
parent 047e2e8163
commit 44cdb00bbb


@@ -462,7 +462,7 @@ class MllamaVisionModel(nn.Module):
         self.patch_size = config.patch_size
         self.max_num_tiles = config.max_num_tiles
         self.hidden_size = config.hidden_size
-        self.in_channels = config.in_channels
+        self.num_channels = config.num_channels
         self.intermediate_layers_indices = config.intermediate_layers_indices

         self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
@@ -470,7 +470,7 @@ class MllamaVisionModel(nn.Module):
         self.dtype = weights.dtype

         self.patch_embedding = nn.Conv2d(
-            in_channels=config.in_channels,
+            in_channels=config.num_channels,
             out_channels=self.hidden_size,
             kernel_size=self.patch_size,
             stride=self.patch_size,
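Note: the rename matches the field actually defined on the upstream Hugging Face Mllama vision config, which exposes num_channels rather than in_channels. A minimal sketch to check the attribute name, assuming a transformers release that ships Mllama (>= 4.45); MllamaVisionConfig and its defaults are the only things assumed here:

from transformers import MllamaVisionConfig

vision_config = MllamaVisionConfig()          # default values, no weights needed
print(vision_config.num_channels)             # 3 -- the field is num_channels, not in_channels
print(hasattr(vision_config, "in_channels"))  # False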
@@ -810,7 +810,7 @@ class MllamaTextMLP(nn.Module):
             weights=weights,
             bias=False,
         )
-        self.act_fn = ACT2FN[config.hidden_activation]
+        self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
         shape = x.shape
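Note: ACT2FN is the activation registry in transformers.activations that this modeling code already imports; it maps the hidden_act string from the text config to a ready-to-call module, and the Mllama text config defines hidden_act, not hidden_activation. A small sketch of the lookup pattern ("silu" is assumed as the typical value):

import torch
from transformers.activations import ACT2FN

act_fn = ACT2FN["silu"]   # returns an instantiated SiLU module
x = torch.randn(2, 8)
print(act_fn(x).shape)    # torch.Size([2, 8])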
@@ -1187,15 +1187,11 @@ class MllamaTextModel(nn.Module):
                     )
                 )
-        # TODO Should we use this slow norm ?
-        # self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm = MllamaTextRMSNorm.load(
             prefix=f"{prefix}.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
-        # TODO Anything specific ?
-        head_size = config.hidden_size // config.num_attention_heads
         self.rotary_emb = MllamaRotaryEmbedding(config=config, weights=weights)

     def forward(
@@ -1383,7 +1379,6 @@ class MllamaForCausalLM(nn.Module):
         num_logits_to_keep: int = 0,
     ):
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        # TODO
         outputs = self.model(
             input_ids=input_ids,
             cross_attention_states=cross_attention_states,
@@ -1493,11 +1488,7 @@ class MllamaForConditionalGeneration(nn.Module):
         config.vision_config.speculator = config.speculator
         config.text_config.quantize = config.quantize
         config.text_config.speculator = config.speculator
-        # TODO check how this is determined
         config.text_config._attn_implementation = "sdpa"
-        # self.hidden_size = (
-        #     config.text_config.hidden_size // weights.process_group.size()
-        # )
         self.hidden_size = config.text_config.hidden_size
         self.vision_model = MllamaVisionModel(
             prefix="vision_model", config=config.vision_config, weights=weights