Nicolas Patry 2024-02-23 15:10:08 +00:00
parent f592df5234
commit a0095b5b8d
4 changed files with 15 additions and 14 deletions

View File

@@ -140,13 +140,17 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
-        medusa_config = hf_hub_download(
-            medusa_model_id, revision=revision, filename="config.json"
-        )
-        hf_hub_download(
-            medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
-        )
-        use_medusa = Path(medusa_config).parent
+        is_local = Path(medusa_model_id).exists()
+        if not is_local:
+            medusa_config = hf_hub_download(
+                medusa_model_id, revision=revision, filename="config.json"
+            )
+            hf_hub_download(
+                medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
+            )
+            use_medusa = Path(medusa_config).parent
+        else:
+            use_medusa = Path(medusa_model_id)
         method = "medusa"
     else:
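The change above lets `medusa_model_id` point at a local directory: when `Path(medusa_model_id).exists()`, the Hub download is skipped and the directory itself becomes `use_medusa`. A minimal sketch of that resolution logic as a standalone helper; the name `resolve_medusa_path` is illustrative, not a function in the repo:

```python
from pathlib import Path
from typing import Optional

from huggingface_hub import hf_hub_download


def resolve_medusa_path(medusa_model_id: str, revision: Optional[str] = None) -> Path:
    """Return the directory holding config.json and the Medusa head weights."""
    local = Path(medusa_model_id)
    if local.exists():
        # Local checkout: use it as-is, no network access needed.
        return local
    # Otherwise fetch both files from the Hub; hf_hub_download returns the
    # cached file path, so its parent directory holds both artifacts.
    config_path = hf_hub_download(
        medusa_model_id, revision=revision, filename="config.json"
    )
    hf_hub_download(medusa_model_id, revision=revision, filename="medusa_lm_head.pt")
    return Path(config_path).parent
```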

View File

@@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
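With a speculative head wired in, `self.lm_head(hidden_states)` now yields a `(logits, speculative_logits)` pair, and `speculative_logits` is presumably `None` whenever no Medusa head is loaded (hence the `Optional` in the return annotation). A hedged caller-side sketch; the function name and tensor shapes are assumptions for illustration, not repo code:

```python
from typing import Optional

import torch


def greedy_next_tokens(
    logits: torch.Tensor,
    speculative_logits: Optional[torch.Tensor],
) -> torch.Tensor:
    # Assumed shapes: logits (batch, vocab);
    # speculative_logits (batch, n_speculative, vocab), one slice per Medusa head.
    next_ids = logits.argmax(dim=-1, keepdim=True)  # (batch, 1)
    if speculative_logits is None:
        return next_ids
    # Greedy candidate per head, to be verified on the next forward pass.
    speculative_ids = speculative_logits.argmax(dim=-1)  # (batch, n_speculative)
    return torch.cat([next_ids, speculative_ids], dim=-1)
```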

View File

@@ -968,8 +968,6 @@ class FlashCausalLM(Model):
             speculative_logits,
         )
-        logger.info(f"Accepted ids {accepted_ids}")
-
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
         )
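The only functional change here is dropping a leftover debug log. For context, `accepted_ids` holds how many speculative tokens each request accepted this step, and `batch_top_tokens` needs it to expand per-request values into per-token rows. A toy illustration of that expansion, not the helper's actual implementation:

```python
import torch

# Per-request setting (here: top-n) and per-request accepted token counts.
top_n = torch.tensor([5, 0, 3])
accepted_ids = torch.tensor([2, 1, 3])

# Each accepted token gets its own copy of its request's value: 2 + 1 + 3 = 6.
expanded = top_n.repeat_interleave(accepted_ids)
print(expanded)  # tensor([5, 5, 0, 3, 3, 3])
```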

View File

@@ -444,11 +444,10 @@ class SpeculativeHead(nn.Module):
             import json

             medusa_config = str(Path(use_medusa) / "config.json")
-            medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
             with open(medusa_config, "r") as f:
                 config = json.load(f)
-            filename = medusa_head[: -len(".pt")] + ".safetensors"
             routing = weights.routing
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
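Rather than deriving the safetensors name from a `.pt` path, the loader now points straight at `medusa_lm_head.safetensors`; safetensors files load without pickle and support lazy, per-tensor reads, which is presumably why the indirection was dropped. A self-contained sketch of that load path under the same directory layout (`config.json` beside the weights; the path below is hypothetical):

```python
import json
from pathlib import Path

from safetensors import safe_open

use_medusa = Path("/path/to/medusa-head")  # hypothetical local directory

with open(use_medusa / "config.json", "r") as f:
    config = json.load(f)

# Tensors are read lazily: keys can be listed before any data is materialized.
tensors = {}
with safe_open(str(use_medusa / "medusa_lm_head.safetensors"), framework="pytorch") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)
```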