Hardcode a few things to make it work.

Nicolas Patry 2024-05-06 14:03:05 +00:00
parent 453e91f755
commit 38d6045443
4 changed files with 179 additions and 100 deletions

View File

@@ -136,7 +136,7 @@ def get_model(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
-    use_medusa = None
+    speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
         medusa_revision = revision
@@ -166,11 +166,50 @@ def get_model(
                 revision=medusa_revision,
                 filename="medusa_lm_head.safetensors",
             )
-            use_medusa = Path(medusa_config).parent
+            speculator = Path(medusa_config).parent
         else:
-            use_medusa = Path(medusa_model_id)
+            speculator = Path(medusa_model_id)
         method = "medusa"
+    elif config_dict["model_type"] == "mlp_speculator":
+        # TODO make this not hardcoded.
+        mlp_model_id = model_id
+        mlp_revision = revision
+        model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+        revision = "main"
+        speculate_mlp = config_dict["n_predict"]
+        if speculate is not None:
+            if speculate > speculate_mlp:
+                raise RuntimeError(
+                    f"Speculate is set to `{speculate}` but this mlp_speculator model only has `{speculate_mlp}` heads, please make them match"
+                )
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_mlp)
+
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        is_local = Path(mlp_model_id).exists()
+        if not is_local:
+            mlp_speculator_config = hf_hub_download(
+                mlp_model_id, revision=mlp_revision, filename="config.json"
+            )
+            hf_hub_download(
+                mlp_model_id,
+                revision=mlp_revision,
+                filename="model-00001-of-00002.safetensors",
+            )
+            hf_hub_download(
+                mlp_model_id,
+                revision=mlp_revision,
+                filename="model-00002-of-00002.safetensors",
+            )
+            speculator = Path(mlp_speculator_config).parent
+        else:
+            speculator = Path(mlp_model_id)
+        method = "mlp_speculator"
     else:
         method = "n-gram"
@@ -202,7 +241,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -212,7 +251,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -227,7 +266,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -240,7 +279,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -250,7 +289,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -259,7 +298,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -270,7 +309,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -279,7 +318,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -288,7 +327,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -299,7 +338,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -308,7 +347,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -323,7 +362,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -334,7 +373,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -345,7 +384,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -355,7 +394,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -366,7 +405,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -377,7 +416,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -388,7 +427,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -399,7 +438,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -410,7 +449,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -424,7 +463,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -435,7 +474,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -444,7 +483,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -458,7 +497,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -469,7 +508,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -483,7 +522,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -494,7 +533,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -520,7 +559,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -544,7 +583,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -554,7 +593,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -564,7 +603,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -574,7 +613,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -586,7 +625,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -599,7 +638,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -623,7 +662,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -632,7 +671,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -644,7 +683,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -653,7 +692,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

View File

@@ -419,5 +419,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states, input_ids)
         return logits, speculative_logits
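
The extra argument exists because the MLP speculator conditions each head on the embedding of the most recent token, not just on the base model's hidden state; here is a toy sketch of that recurrence (stand-in dimensions and mixing weights, not the diff's actual values):

    import torch
    import torch.nn as nn

    emb = nn.Embedding(32, 8)              # per-head token embedding
    proj = nn.Linear(8, 8, bias=False)     # per-head state projection
    state = torch.randn(1, 8)              # hidden state from the base model
    last_token = torch.tensor([5])         # last accepted token id
    z = emb(last_token)                    # why forward() now needs input_ids
    state = proj(state) * 0.9 + z * 0.44   # weighted mix of state and token
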

View File

@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None, speculator: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa config.speculator = speculator
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)

View File

@@ -442,6 +442,7 @@ class ResBlock(torch.nn.Module):
     def forward(self, x):
         return x + self.act(self.linear(x))
+

 class MLPSpeculatorLayerNorm(nn.Module):
     """
     A L2 normalization implementation
@@ -461,15 +462,14 @@ class MLPSpeculatorLayerNorm(nn.Module):
     def __init__(
         self,
-        normalized_shape,
-        elementwise_scale_weight: torch.Tensor,
-        elementwise_shift_bias: torch.Tensor,
+        prefix,
+        config,
+        weights,
         eps=1e-06,
     ):
         super(MLPSpeculatorLayerNorm, self).__init__()
-        self.normalized_shape = normalized_shape
-        self.weight = nn.Parameter(elementwise_scale_weight)
-        self.bias = nn.Parameter(elementwise_shift_bias)
+        self.weight = weights.get_tensor(f"{prefix}.weight")
+        self.bias = weights.get_tensor(f"{prefix}.bias")
         self.eps = eps

     def forward(self, x):
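
The next hunk shows only the tail of `forward` (the learned scale and shift); here is a self-contained sketch of the full L2 normalization the docstring describes, where the rsqrt-of-mean-square body is an assumption reconstructed from the class name rather than something shown in this diff:

    import torch

    def l2_layer_norm(x, weight, bias, eps=1e-06):
        # Assumed body: normalize each vector to unit RMS over the last dim.
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        # Visible in the hunk: learned elementwise scale, then shift.
        x = weight * x
        x = x + bias
        return x

    h = torch.randn(2, 8)
    print(l2_layer_norm(h, torch.ones(8), torch.zeros(8)).shape)  # (2, 8)
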
@@ -480,48 +480,69 @@ class MLPSpeculatorLayerNorm(nn.Module):
         x = x + self.bias
         return x


 class MLPSpeculatorModel(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.config = config
-        self.n_predict = config.n_predict
-        self.emb_dim = config.emb_dim
-        inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
-        self.inner_dim = inner_dim
-        self.config = config.vocab_size
+        self.n_predict = get_speculate()
+        self.hidden_size = config.hidden_size
+
         self.emb = nn.ModuleList(
-            [TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) for i in range(config.n_predict)]
+            [
+                TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
+                for i in range(self.n_predict)
+            ]
         )
-        self.proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.proj.{i}" for i in range(config.n_predict)],
-            weights=weights,
-            bias=False,
-            dim=0
-        )
+        self.proj = [
+            TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.proj.{i}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_predict)
+        ]
         self.head = nn.ModuleList(
-            [TensorParallelRowLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) for i in range(config.n_predict)]
+            [
+                TensorParallelRowLinear.load(
+                    config, f"{prefix}.head.{i}", weights, bias=False
+                )
+                for i in range(self.n_predict)
+            ]
         )
         self.ln = nn.ModuleList(
             [
                 MLPSpeculatorLayerNorm(
-                    config.inner_dim,
-                    elementwise_scale_weight=weights.get_tensor(f"{prefix}.ln.{i}.weight"),
-                    elementwise_shift_bias=weights.get_tensor(f"{prefix}.ln.{i}.bias")
+                    prefix=f"{prefix}.ln.{i}",
+                    config=config,
+                    weights=weights,
                 )
-                for i in range(config.n_predict)
+                for i in range(self.n_predict)
             ]
         )
         # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
         self.state_weight = 0.5 ** (0.5 / self.n_predict)
-        self.emb_weight = math.sqrt(1 - self.state_weight ** 2)
+        self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
+        # TODO
+        self.vsize = 128256
+        self.inner_dim = 3072

-    def forward(self, state: torch.Tensor, ind: torch.Tensor, top_k_tokens_per_head: Optional[List[int]], num_candidates: int = 1):
+    def forward(
+        self,
+        state: torch.Tensor,
+        input_ids: torch.Tensor,
+        top_k_tokens_per_head: Optional[List[int]] = None,
+        num_candidates: int = 1,
+    ):
+        # TODO
+        top_k_tokens_per_head = [1, 1, 1, 1]
         if top_k_tokens_per_head is None:
             top_k_tokens_per_head = self.config.top_k_tokens_per_head
+        ind = input_ids

         # k indicates # of candidates
         # h indicates # of generated tokens
         b = state.size(0)
@@ -529,7 +550,7 @@ class MLPSpeculatorModel(torch.nn.Module):
         log_probs = torch.zeros(b, 1, device=state.device)  # b k
         all_probs = torch.empty(b, 1, 0, self.vsize, device=state.device)  # b k h v

         assert (
             len(top_k_tokens_per_head) == self.n_predict
         ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"

         for i in range(self.n_predict):
             # Project and predict
@@ -537,23 +558,40 @@ class MLPSpeculatorModel(torch.nn.Module):
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
-            _probs = F.log_softmax(self.head[i](state), dim=2)  # b k v
-            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=2)  # b k k'
+            _probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
+            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'

             # Update candidate set with new predictions
-            out = out.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' h
-            out = torch.cat([out, preds.unsqueeze(3)], dim=3)  # b k k' h+1
+            out = out.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' h
+            try:
+                out = torch.cat(
+                    [out, preds.unsqueeze(2).unsqueeze(3)], dim=-1
+                )  # b k k' h+1
+            except Exception:
+                import ipdb
+
+                ipdb.set_trace()
             out = out.view(b, -1, i + 1)  # b kk' h+1

             # Update distribution set with new logits
-            all_probs = torch.cat([all_probs, _probs.exp().unsqueeze(2)], dim=2)  # b k h+1 v
-            all_probs = all_probs.repeat(1, top_k_tokens_per_head[i], 1, 1)  # b kk' h+1 v
+            all_probs = torch.cat(
+                [all_probs, _probs.exp().unsqueeze(2)], dim=-1
+            )  # b k h+1 v
+            all_probs = all_probs.repeat(
+                1, top_k_tokens_per_head[i], 1, 1
+            )  # b kk' h+1 v

             # Update state, log_probs and ind for new predictions
-            state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' d
+            state = state.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' d
             state = state.reshape(b, -1, state.size(3))  # b kk' d
             ind = preds.view(b, -1)  # b kk'
-            log_probs = log_probs.unsqueeze(2).expand(b, -1, top_k_tokens_per_head[i])  # b k k'
+            log_probs = log_probs.unsqueeze(2).expand(
+                b, -1, top_k_tokens_per_head[i]
+            )  # b k k'
             log_probs = log_probs.add(probs).reshape(b, -1)  # b kk'

         # Take only top n best guesses
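
A toy walk-through of the candidate bookkeeping in this loop, with b=1, the `[1, 1, 1, 1]` top-k values the TODO hardcodes, and a tiny vocabulary; it uses random logits and the pre-change `preds.unsqueeze(3)` concatenation, so it exercises only the shapes, not the model:

    import torch

    b, vsize, n_predict = 1, 16, 4
    top_k = [1, 1, 1, 1]
    out = torch.empty(b, 1, 0, dtype=torch.long)  # b k h: candidate token ids
    log_probs = torch.zeros(b, 1)                 # b k: candidate scores
    for i in range(n_predict):
        _probs = torch.log_softmax(torch.randn(b, out.size(1), vsize), dim=-1)
        probs, preds = _probs.topk(top_k[i], dim=-1)                # b k k'
        out = out.unsqueeze(2).expand(-1, -1, top_k[i], -1)         # b k k' h
        out = torch.cat([out, preds.unsqueeze(3)], dim=-1)          # b k k' h+1
        out = out.view(b, -1, i + 1)                                # b kk' h+1
        log_probs = log_probs.unsqueeze(2).expand(b, -1, top_k[i])  # b k k'
        log_probs = log_probs.add(probs).reshape(b, -1)             # b kk'
    print(out.shape)  # torch.Size([1, 1, 4]): one 4-token speculative candidate
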
@@ -570,14 +608,14 @@ class MLPSpeculatorHead(nn.Module):
         self.mlp_speculator = mlp_speculator

     def forward(
-        self, input: torch.Tensor
+        self, input: torch.Tensor, input_ids: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         # If we have too many tokens, we skip speculative logits
         if input.shape[0] > 128:
             return logits, None

-        speculative_logits = self.mlp_speculator(input)
+        speculative_logits = self.mlp_speculator(input, input_ids)
         return logits, speculative_logits

     @staticmethod
@@ -585,18 +623,21 @@ class MLPSpeculatorHead(nn.Module):
         from pathlib import Path
         from safetensors import safe_open

-        speculator_path = speculator_config.use_speculator
-        filename = str(Path(speculator_path) / "*.safetensors")
-
-        routing = weights.routing
-        with safe_open(filename, framework="pytorch") as f:
-            for k in f.keys():
-                if k in routing and routing[k] != filename:
-                    raise RuntimeError(
-                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
-                    )
-                routing[k] = filename
+        speculator_path = speculator_config.speculator
+
+        for fname in [
+            "model-00001-of-00002.safetensors",
+            "model-00002-of-00002.safetensors",
+        ]:
+            filename = str(Path(speculator_path) / fname)
+            routing = weights.routing
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing and routing[k] != filename:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    routing[k] = filename

         mlp_speculator = MLPSpeculatorModel(speculator_config, "speculator", weights)
         lm_head = TensorParallelHead.load(speculator_config, prefix, weights)
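
For reference, the `routing` dict this loop fills is just a tensor-key-to-shard index with duplicate detection; a minimal standalone sketch with made-up shard names:

    routing = {}
    shards = {"shard-00001": ["w.0", "w.1"], "shard-00002": ["w.2"]}
    for fname, keys in shards.items():
        for k in keys:
            if k in routing and routing[k] != fname:
                raise RuntimeError(
                    f"Key {k} was found in multiple files: {fname} and {routing[k]}"
                )
            routing[k] = fname
    assert routing["w.2"] == "shard-00002"
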
@@ -776,9 +817,9 @@ class SpeculativeHead(nn.Module):
     @staticmethod
     def load(config, prefix: str, weights):
-        use_speculator = config.use_speculator
-        if use_speculator:
-            speculator_config = str(Path(use_speculator) / "config.json")
+        speculator = config.speculator
+        if speculator:
+            speculator_config = str(Path(speculator) / "config.json")

             with open(speculator_config, "r") as f:
                 speculator_config = json.load(f)
@@ -790,8 +831,7 @@ class SpeculativeHead(nn.Module):
                 architecture = speculator_config["architectures"][0]

                 if architecture == "MLPSpeculatorPreTrainedModel":
-                    speculator_config.use_speculator = config.use_speculator
-                    speculator = MLPSpeculatorHead.load(speculator_config, prefix, weights)
+                    speculator = MLPSpeculatorHead.load(config, prefix, weights)
                 else:
                     speculator = None
         except KeyError:
@@ -805,10 +845,10 @@ class SpeculativeHead(nn.Module):
         return SpeculativeHead(lm_head, speculator)

     def forward(
-        self, input: torch.Tensor
+        self, input: torch.Tensor, input_ids: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.medusa is not None:
-            return self.medusa(input)
+        if self.speculator is not None:
+            return self.speculator(input, input_ids)
         assert self.head is not None
         logits = self.head(input)