fix: add adapter_data param and avoid missing layers

drbh 2024-06-07 03:03:15 +00:00
parent 91f407226d
commit b1169273fd
5 changed files with 19 additions and 12 deletions

@@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,
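
The same one-line signature change is repeated for FlashSantacoderForCausalLM and FlashStarcoder2ForCausalLM in the next two hunks: these models never consume adapter data themselves, but a shared caller evidently forwards adapter_data to every model's forward uniformly, so each signature must accept it. A minimal sketch of that pattern, with hypothetical class and caller names:

from typing import Optional

import torch


class NoAdapterModel:
    """Hypothetical model with no adapter support of its own."""

    def forward(
        self,
        input_ids: torch.Tensor,
        lm_head_indices: Optional[torch.Tensor] = None,
        # Accepted only so a generic caller can pass adapter_data=...
        # to every model the same way; this model simply ignores it.
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return input_ids  # stand-in for the real computation


def run_model(model, input_ids: torch.Tensor, adapter_data=None):
    # The dispatcher needs no per-model special cases:
    return model.forward(input_ids, adapter_data=adapter_data)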

@@ -488,6 +488,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,

@@ -525,6 +525,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         true_max_s = max_s
         if prefill_cache_indices is not None:

@@ -147,6 +147,9 @@ class BaseFlashMistral(FlashCausalLM):
                 layer.self_attn.o_proj,
             )
-            layer_weights[(i, "gate_proj")] = (
-                f"{prefix}.{i}.mlp.gate_proj",
-                layer.mlp.gate_up_proj,
+            # TODO: this is a hack to avoid the gate_proj for
+            # FlashStarcoder2 that doesn't have these layers
+            if hasattr(layer.mlp, "gate_up_proj"):
+                layer_weights[(i, "gate_proj")] = (
+                    f"{prefix}.{i}.mlp.gate_proj",
+                    layer.mlp.gate_up_proj,
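
This hasattr guard is the "avoid missing layers" half of the commit: an adapter target is registered only when the module actually exists on the layer, so model families without that projection are skipped instead of raising AttributeError. A self-contained sketch of the pattern, assuming layers that each carry an mlp submodule (all names here are illustrative):

from typing import Dict, Tuple

import torch.nn as nn


def collect_gate_proj_targets(
    layers, prefix: str = "model.layers"
) -> Dict[Tuple[int, str], Tuple[str, nn.Module]]:
    """Map (layer index, target name) -> (weight key, module), but only
    for layers that actually own a gate_up_proj module."""
    targets: Dict[Tuple[int, str], Tuple[str, nn.Module]] = {}
    for i, layer in enumerate(layers):
        # Check before touching the attribute: MLPs without a fused
        # gate/up projection are silently left out of the target map.
        if hasattr(layer.mlp, "gate_up_proj"):
            targets[(i, "gate_proj")] = (
                f"{prefix}.{i}.mlp.gate_proj",
                layer.mlp.gate_up_proj,
            )
    return targets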

@@ -634,6 +634,7 @@ class IdeficsCausalLM(Model):
         tokenizer.add_special_tokens({"pad_token": "<unk>"})
         super(IdeficsCausalLM, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
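
The final hunk forwards model_id to the base Model constructor, which presumably gained a model_id parameter elsewhere in this change set, so each subclass call site has to pass it through. A hypothetical sketch of the shape of that change:

class Model:
    # Hypothetical base class: __init__ now records which checkpoint
    # the model was loaded from.
    def __init__(self, model_id: str, model, tokenizer, requires_padding: bool):
        self.model_id = model_id
        self.model = model
        self.tokenizer = tokenizer
        self.requires_padding = requires_padding


class IdeficsLikeModel(Model):
    def __init__(self, model_id: str, model, tokenizer):
        super().__init__(
            model_id=model_id,  # newly forwarded to the base class
            model=model,
            tokenizer=tokenizer,
            requires_padding=True,
        )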