Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-20 06:12:07 +00:00.
Updating config, removing TODO

Commit 44cdb00bbb (parent 047e2e8163)
@@ -462,7 +462,7 @@ class MllamaVisionModel(nn.Module):
         self.patch_size = config.patch_size
         self.max_num_tiles = config.max_num_tiles
         self.hidden_size = config.hidden_size
-        self.in_channels = config.in_channels
+        self.num_channels = config.num_channels
         self.intermediate_layers_indices = config.intermediate_layers_indices
 
         self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
@@ -470,7 +470,7 @@ class MllamaVisionModel(nn.Module):
         self.dtype = weights.dtype
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=config.in_channels,
+            in_channels=config.num_channels,
             out_channels=self.hidden_size,
             kernel_size=self.patch_size,
             stride=self.patch_size,
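
Both hunks above are one rename: the Mllama vision config exposes its channel-count field as num_channels, so the old config.in_channels lookup failed. Below is a minimal, self-contained sketch of what this patch embedding computes, with illustrative stand-in values for config.vision_config (not TGI's actual code):

import torch
import torch.nn as nn

# Illustrative stand-ins for config.vision_config values (not the real ones).
image_size, patch_size, num_channels, hidden_size = 448, 14, 3, 1280

patch_embedding = nn.Conv2d(
    in_channels=num_channels,   # the field this commit renames from in_channels
    out_channels=hidden_size,
    kernel_size=patch_size,
    stride=patch_size,          # stride == kernel_size: non-overlapping patches
    bias=False,                 # assumed here, as in typical ViT patch embeddings
)

pixel_values = torch.randn(1, num_channels, image_size, image_size)
patches = patch_embedding(pixel_values)       # (1, hidden_size, 32, 32)
patches = patches.flatten(2).transpose(1, 2)  # (1, 1024, hidden_size)

# Matches self.num_patches = (image_size // patch_size) ** 2 + 1 above;
# the +1 accounts for a prepended class-token embedding.
num_patches = (image_size // patch_size) ** 2 + 1
print(patches.shape, num_patches)  # torch.Size([1, 1024, 1280]) 1025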
@@ -810,7 +810,7 @@ class MllamaTextMLP(nn.Module):
             weights=weights,
             bias=False,
         )
-        self.act_fn = ACT2FN[config.hidden_activation]
+        self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x):
         shape = x.shape
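
This hunk fixes another attribute lookup: the text config defines hidden_act (the usual HF config field name), not hidden_activation. ACT2FN is a string-to-module lookup table from transformers.activations; a rough sketch of the pattern, using a hypothetical trimmed-down table:

import torch
import torch.nn as nn

# Hypothetical, trimmed-down stand-in for the ACT2FN table; the real one
# lives in transformers.activations and covers many more activations.
ACT2FN = {
    "gelu": nn.GELU(),
    "relu": nn.ReLU(),
    "silu": nn.SiLU(),
}

hidden_act = "silu"  # Llama-family text configs typically use SiLU
act_fn = ACT2FN[hidden_act]
print(act_fn(torch.tensor([-1.0, 0.0, 1.0])))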
@@ -1187,15 +1187,11 @@ class MllamaTextModel(nn.Module):
             )
         )
 
-        # TODO Should we use this slow norm ?
-        # self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm = MllamaTextRMSNorm.load(
             prefix=f"{prefix}.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
-        # TODO Anything specific ?
-        head_size = config.hidden_size // config.num_attention_heads
         self.rotary_emb = MllamaRotaryEmbedding(config=config, weights=weights)
 
     def forward(
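
The removed head_size assignment was apparently a dead local, and the removed comments referenced a plain constructor path; the kept .load(...) call reads the scale weight straight from the checkpoint. For reference, a minimal standalone sketch of the RMSNorm computation that module performs, assuming the usual Llama-style formulation (shapes and values are illustrative):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Scale each position by the reciprocal root-mean-square over the hidden
    # dimension, then apply the learned per-channel weight.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps))

hidden_size = 4096                  # illustrative; the real value is config.hidden_size
x = torch.randn(2, 8, hidden_size)  # (batch, seq, hidden)
weight = torch.ones(hidden_size)    # in TGI, loaded from f"{prefix}.norm" weights
out = rms_norm(x, weight, eps=1e-5)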
@@ -1383,7 +1379,6 @@ class MllamaForCausalLM(nn.Module):
         num_logits_to_keep: int = 0,
     ):
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        # TODO
         outputs = self.model(
             input_ids=input_ids,
             cross_attention_states=cross_attention_states,
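
The num_logits_to_keep parameter in this signature follows the usual HF convention: 0 means compute logits for every position, while N > 0 restricts the LM head to the last N positions, which is all incremental decoding needs. A small sketch with hypothetical shapes:

import torch
import torch.nn as nn

hidden_states = torch.randn(1, 128, 4096)     # (batch, seq, hidden), illustrative
lm_head = nn.Linear(4096, 32000, bias=False)  # hypothetical vocab size

num_logits_to_keep = 1  # during decoding only the last position's logits matter
if num_logits_to_keep == 0:
    logits = lm_head(hidden_states)                            # (1, 128, 32000)
else:
    logits = lm_head(hidden_states[:, -num_logits_to_keep:])   # (1, 1, 32000)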
@@ -1493,11 +1488,7 @@ class MllamaForConditionalGeneration(nn.Module):
         config.vision_config.speculator = config.speculator
         config.text_config.quantize = config.quantize
         config.text_config.speculator = config.speculator
-        # TODO check how this is determined
         config.text_config._attn_implementation = "sdpa"
-        # self.hidden_size = (
-        #     config.text_config.hidden_size // weights.process_group.size()
-        # )
         self.hidden_size = config.text_config.hidden_size
         self.vision_model = MllamaVisionModel(
             prefix="vision_model", config=config.vision_config, weights=weights
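
The deleted commented-out block would have divided the hidden size by the tensor-parallel world size; the commit keeps the full hidden size, which is consistent with this attribute describing complete hidden vectors rather than per-rank shards (my reading, not stated in the commit). Numerically, with hypothetical values:

world_size = 4      # hypothetical weights.process_group.size()
full_hidden = 4096  # hypothetical config.text_config.hidden_size

# What the removed code would have stored vs. what the commit stores:
sharded_hidden = full_hidden // world_size  # 1024 per rank (removed path)
hidden_size = full_hidden                   # 4096, used as-is (kept path)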