Nicolas Patry 2024-02-23 15:10:08 +00:00
parent f592df5234
commit a0095b5b8d
4 changed files with 15 additions and 14 deletions

View File

@@ -140,6 +140,8 @@ def get_model(
         config_dict, _ = PretrainedConfig.get_config_dict(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
+        is_local = Path(medusa_model_id).exists()
+        if not is_local:
             medusa_config = hf_hub_download(
                 medusa_model_id, revision=revision, filename="config.json"
             )
@@ -147,6 +149,8 @@ def get_model(
                 medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
             )
             use_medusa = Path(medusa_config).parent
+        else:
+            use_medusa = Path(medusa_model_id)
         method = "medusa"
     else:
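Taken together, the two hunks above let medusa_model_id point at a local directory: if the path exists on disk it is used as-is, otherwise config.json and the medusa head are fetched from the Hub and their shared cache directory is used instead. A minimal standalone sketch of that resolution logic (hf_hub_download is the real huggingface_hub API; the helper name is illustrative):

from pathlib import Path

from huggingface_hub import hf_hub_download


def resolve_medusa_dir(medusa_model_id: str, revision: str = None) -> Path:
    # Local checkout: use the directory directly, no network access needed.
    if Path(medusa_model_id).exists():
        return Path(medusa_model_id)
    # Hub repo id: download the config and the head, then reuse the
    # cache directory both files land in.
    medusa_config = hf_hub_download(
        medusa_model_id, revision=revision, filename="config.json"
    )
    hf_hub_download(
        medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
    )
    return Path(medusa_config).parent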

View File

@@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         input_lengths: torch.Tensor,
         max_s: int,
         lm_head_indices: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,
             position_ids,
@@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits = self.lm_head(hidden_states)
-        return logits
+        logits, speculative_logits = self.lm_head(hidden_states)
+        return logits, speculative_logits
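This brings FlashGemmaForCausalLM in line with the other flash models: forward now always returns a (logits, speculative_logits) pair, with speculative_logits being None when no medusa head is attached. A toy head with the same contract (purely illustrative, not TGI's SpeculativeHead):

from typing import Optional, Tuple

import torch
import torch.nn as nn


class ToySpeculativeHead(nn.Module):
    # Illustrative only: mirrors the (logits, speculative_logits) contract.
    def __init__(self, hidden_size: int, vocab_size: int, n_medusa_heads: int = 0):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.medusa = (
            nn.ModuleList(
                nn.Linear(hidden_size, vocab_size, bias=False)
                for _ in range(n_medusa_heads)
            )
            if n_medusa_heads
            else None
        )

    def forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden_states)
        if self.medusa is None:
            # No medusa head attached: callers still get a pair.
            return logits, None
        # One extra logits tensor per medusa head, stacked on a new dim.
        speculative_logits = torch.stack(
            [m(hidden_states) for m in self.medusa], dim=1
        )
        return logits, speculative_logits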

View File

@@ -968,8 +968,6 @@ class FlashCausalLM(Model):
                 speculative_logits,
             )
-            logger.info(f"Accepted ids {accepted_ids}")
             batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
                 batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
             )
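The dropped logger.info was a leftover debug line from developing speculation; accepted_ids still feeds batch_top_tokens, which gathers each request's top-n candidate tokens from the logprobs. A rough sketch of that gathering step (the real helper also handles per-request top_n values and padding; the function name and body here are assumptions):

import torch


def top_tokens_sketch(top_n: int, logprobs: torch.Tensor):
    # Illustrative only: top-n token ids and logprobs for every row.
    values, indices = logprobs.topk(top_n, dim=-1)
    return indices, values


# With speculation, several rows of logprobs can belong to one request
# (one per accepted token), which is why accepted_ids is threaded through.
logprobs = torch.log_softmax(torch.randn(3, 32), dim=-1)
top_ids, top_logprobs = top_tokens_sketch(2, logprobs)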

View File

@@ -444,11 +444,10 @@ class SpeculativeHead(nn.Module):
             import json
             medusa_config = str(Path(use_medusa) / "config.json")
-            medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
             with open(medusa_config, "r") as f:
                 config = json.load(f)
-            filename = medusa_head[: -len(".pt")] + ".safetensors"
             routing = weights.routing
             with safe_open(filename, framework="pytorch") as f:
                 for k in f.keys():
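With this change the .pt checkpoint is out of the load path entirely: the safetensors filename is built directly instead of being derived from the torch one. A minimal sketch of reading tensors that way (safetensors' safe_open and get_tensor are the real API, with "pt" as the documented framework string; the path is assumed to exist):

from safetensors import safe_open

filename = "medusa_lm_head.safetensors"  # built as in the hunk above

tensors = {}
with safe_open(filename, framework="pt", device="cpu") as f:
    for k in f.keys():
        # Lazy reads: only the requested tensor is materialized in memory.
        tensors[k] = f.get_tensor(k)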