mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Fixing.
This commit is contained in:
parent
f592df5234
commit
a0095b5b8d
@ -140,6 +140,8 @@ def get_model(
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
is_local = Path(medusa_model_id).exists()
|
||||
if not is_local:
|
||||
medusa_config = hf_hub_download(
|
||||
medusa_model_id, revision=revision, filename="config.json"
|
||||
)
|
||||
@ -147,6 +149,8 @@ def get_model(
|
||||
medusa_model_id, revision=revision, filename="medusa_lm_head.pt"
|
||||
)
|
||||
use_medusa = Path(medusa_config).parent
|
||||
else:
|
||||
use_medusa = Path(medusa_model_id)
|
||||
|
||||
method = "medusa"
|
||||
else:
|
||||
|
@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
@ -968,8 +968,6 @@ class FlashCausalLM(Model):
|
||||
speculative_logits,
|
||||
)
|
||||
|
||||
logger.info(f"Accepted ids {accepted_ids}")
|
||||
|
||||
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
|
||||
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
|
||||
)
|
||||
|
@ -444,11 +444,10 @@ class SpeculativeHead(nn.Module):
|
||||
import json
|
||||
|
||||
medusa_config = str(Path(use_medusa) / "config.json")
|
||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
||||
|
||||
with open(medusa_config, "r") as f:
|
||||
config = json.load(f)
|
||||
filename = medusa_head[: -len(".pt")] + ".safetensors"
|
||||
routing = weights.routing
|
||||
with safe_open(filename, framework="pytorch") as f:
|
||||
for k in f.keys():
|
||||
|
Loading…
Reference in New Issue
Block a user