Mirror of https://github.com/huggingface/text-generation-inference.git
Hardcode a few things to make it work.
This commit is contained in:
parent 453e91f755
commit 38d6045443
@@ -136,7 +136,7 @@ def get_model(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
 
-    use_medusa = None
+    speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
         medusa_revision = revision
@@ -166,11 +166,50 @@ def get_model(
                 revision=medusa_revision,
                 filename="medusa_lm_head.safetensors",
             )
-            use_medusa = Path(medusa_config).parent
+            speculator = Path(medusa_config).parent
         else:
-            use_medusa = Path(medusa_model_id)
+            speculator = Path(medusa_model_id)
 
         method = "medusa"
+    elif config_dict["model_type"] == "mlp_speculator":
+        # TODO make this not hardcoded.
+        mlp_model_id = model_id
+        mlp_revision = revision
+        model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+        revision = "main"
+        speculate_mlp = config_dict["n_predict"]
+        if speculate is not None:
+            if speculate > speculate_mlp:
+                raise RuntimeError(
+                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
+                )
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_mlp)
+
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        is_local = Path(mlp_model_id).exists()
+        if not is_local:
+            mlp_speculator_config = hf_hub_download(
+                mlp_model_id, revision=mlp_revision, filename="config.json"
+            )
+            hf_hub_download(
+                mlp_model_id,
+                revision=mlp_revision,
+                filename="model-00001-of-00002.safetensors",
+            )
+            hf_hub_download(
+                mlp_model_id,
+                revision=mlp_revision,
+                filename="model-00002-of-00002.safetensors",
+            )
+            speculator = Path(mlp_speculator_config).parent
+        else:
+            speculator = Path(mlp_model_id)
+        method = "mlp_speculator"
     else:
         method = "n-gram"
@@ -202,7 +241,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -212,7 +251,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -227,7 +266,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -240,7 +279,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -250,7 +289,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -259,7 +298,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -270,7 +309,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -279,7 +318,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -288,7 +327,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -299,7 +338,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -308,7 +347,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -323,7 +362,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -334,7 +373,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -345,7 +384,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -355,7 +394,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -366,7 +405,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -377,7 +416,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -388,7 +427,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -399,7 +438,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -410,7 +449,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -424,7 +463,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -435,7 +474,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -444,7 +483,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -458,7 +497,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -469,7 +508,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -483,7 +522,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -494,7 +533,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -520,7 +559,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -544,7 +583,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -554,7 +593,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -564,7 +603,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -574,7 +613,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -586,7 +625,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -599,7 +638,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -623,7 +662,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -632,7 +671,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -644,7 +683,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -653,7 +692,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -419,5 +419,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states, input_ids)
         return logits, speculative_logits
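
With this change every call site passes the current token ids alongside the hidden states, because the MLP speculator embeds the last accepted tokens rather than reusing only the hidden state. A sketch of the new calling convention with a stub head (all names and sizes illustrative):

    import torch

    def lm_head(hidden_states: torch.Tensor, input_ids: torch.Tensor):
        # Stand-in for SpeculativeHead.forward's new two-argument signature.
        vocab = 128256  # matches the vsize hardcoded later in this commit
        logits = hidden_states @ torch.randn(hidden_states.shape[-1], vocab)
        return logits, None  # speculative_logits when a speculator is configured

    hidden_states = torch.randn(4, 16)  # (num_tokens, hidden_size)
    input_ids = torch.tensor([101, 102, 103, 104])  # (num_tokens,)
    logits, speculative_logits = lm_head(hidden_states, input_ids)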
@@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
@@ -442,6 +442,7 @@ class ResBlock(torch.nn.Module):
     def forward(self, x):
         return x + self.act(self.linear(x))
 
+
 class MLPSpeculatorLayerNorm(nn.Module):
     """
     A L2 normalization implementation
@@ -461,15 +462,14 @@ class MLPSpeculatorLayerNorm(nn.Module):
 
     def __init__(
         self,
-        normalized_shape,
-        elementwise_scale_weight: torch.Tensor,
-        elementwise_shift_bias: torch.Tensor,
+        prefix,
+        config,
+        weights,
         eps=1e-06,
     ):
         super(MLPSpeculatorLayerNorm, self).__init__()
-        self.normalized_shape = normalized_shape
-        self.weight = nn.Parameter(elementwise_scale_weight)
-        self.bias = nn.Parameter(elementwise_shift_bias)
+        self.weight = weights.get_tensor(f"{prefix}.weight")
+        self.bias = weights.get_tensor(f"{prefix}.bias")
         self.eps = eps
 
     def forward(self, x):
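
For context, the module rewired here is an L2-style norm; assuming the `weight`/`bias` tensors loaded in `__init__` above, its forward pass amounts to roughly the following sketch (only the tail, `x = x + self.bias; return x`, is visible in the next hunk):

    import torch

    def mlp_speculator_layer_norm(x: torch.Tensor, weight, bias, eps=1e-06):
        # Divide by the RMS over the feature dimension, then apply the
        # learned elementwise scale and shift.
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        x = weight * x
        x = x + bias
        return x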
@@ -480,36 +480,44 @@ class MLPSpeculatorLayerNorm(nn.Module):
         x = x + self.bias
         return x
 
 
 class MLPSpeculatorModel(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
         self.config = config
-        self.n_predict = config.n_predict
-        self.emb_dim = config.emb_dim
-        inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
-        self.inner_dim = inner_dim
-        self.config = config.vocab_size
+        self.n_predict = get_speculate()
+        self.hidden_size = config.hidden_size
         self.emb = nn.ModuleList(
-            [TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) for i in range(config.n_predict)]
+            [
+                TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
+                for i in range(self.n_predict)
+            ]
         )
-        self.proj = TensorParallelColumnLinear.load_multi(
-            config,
-            prefixes=[f"{prefix}.proj.{i}" for i in range(config.n_predict)],
-            weights=weights,
-            bias=False,
-            dim=0
-        )
+        self.proj = [
+            TensorParallelColumnLinear.load(
+                config,
+                prefix=f"{prefix}.proj.{i}",
+                weights=weights,
+                bias=False,
+            )
+            for i in range(self.n_predict)
+        ]
         self.head = nn.ModuleList(
-            [TensorParallelRowLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) for i in range(config.n_predict)]
+            [
+                TensorParallelRowLinear.load(
+                    config, f"{prefix}.head.{i}", weights, bias=False
+                )
+                for i in range(self.n_predict)
+            ]
         )
         self.ln = nn.ModuleList(
             [
                 MLPSpeculatorLayerNorm(
                     config.inner_dim,
-                    elementwise_scale_weight=weights.get_tensor(f"{prefix}.ln.{i}.weight"),
-                    elementwise_shift_bias=weights.get_tensor(f"{prefix}.ln.{i}.bias")
+                    prefix=f"{prefix}.ln.{i}",
+                    config=config,
+                    weights=weights,
                 )
-                for i in range(config.n_predict)
+                for i in range(self.n_predict)
             ]
         )
@@ -517,11 +525,24 @@ class MLPSpeculatorModel(torch.nn.Module):
         self.state_weight = 0.5 ** (0.5 / self.n_predict)
         self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
+        # TODO
+        self.vsize = 128256
+        self.inner_dim = 3072
 
-    def forward(self, state: torch.Tensor, ind: torch.Tensor, top_k_tokens_per_head: Optional[List[int]], num_candidates: int = 1):
+    def forward(
+        self,
+        state: torch.Tensor,
+        input_ids: torch.Tensor,
+        top_k_tokens_per_head: Optional[List[int]] = None,
+        num_candidates: int = 1,
+    ):
+        # TODO
+        top_k_tokens_per_head = [1, 1, 1, 1]
         if top_k_tokens_per_head is None:
             top_k_tokens_per_head = self.config.top_k_tokens_per_head
 
+        ind = input_ids
+
         # k indicates # of candidates
         # h indicates # of generated tokens
         b = state.size(0)
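
The mixing weights set in `__init__` above are chosen so that the state and embedding contributions are variance-preserving: `state_weight**2 + emb_weight**2 == 1` by construction. A quick numeric check for a 4-head speculator:

    import math

    n_predict = 4
    state_weight = 0.5 ** (0.5 / n_predict)      # ~0.917
    emb_weight = math.sqrt(1 - state_weight**2)  # ~0.399
    print(state_weight**2 + emb_weight**2)       # 1.0 up to float rounding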
@@ -537,23 +558,40 @@ class MLPSpeculatorModel(torch.nn.Module):
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
             state = self.activation(self.ln[i](state))  # b k d
-            _probs = F.log_softmax(self.head[i](state), dim=2)  # b k v
-            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=2)  # b k k'
+            _probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
+            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
 
             # Update candidate set with new predictions
-            out = out.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' h
-            out = torch.cat([out, preds.unsqueeze(3)], dim=3)  # b k k' h+1
+            out = out.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' h
+            try:
+                out = torch.cat(
+                    [out, preds.unsqueeze(2).unsqueeze(3)], dim=-1
+                )  # b k k' h+1
+            except Exception:
+                import ipdb
+
+                ipdb.set_trace()
             out = out.view(b, -1, i + 1)  # b kk' h+1
 
             # Update distribution set with new logits
-            all_probs = torch.cat([all_probs, _probs.exp().unsqueeze(2)], dim=2)  # b k h+1 v
-            all_probs = all_probs.repeat(1, top_k_tokens_per_head[i], 1, 1)  # b kk' h+1 v
+            all_probs = torch.cat(
+                [all_probs, _probs.exp().unsqueeze(2)], dim=-1
+            )  # b k h+1 v
+            all_probs = all_probs.repeat(
+                1, top_k_tokens_per_head[i], 1, 1
+            )  # b kk' h+1 v
 
             # Update state, log_probs and ind for new predictions
-            state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' d
+            state = state.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' d
             state = state.reshape(b, -1, state.size(3))  # b kk' d
             ind = preds.view(b, -1)  # b kk'
-            log_probs = log_probs.unsqueeze(2).expand(b, -1, top_k_tokens_per_head[i])  # b k k'
+            log_probs = log_probs.unsqueeze(2).expand(
+                b, -1, top_k_tokens_per_head[i]
+            )  # b k k'
             log_probs = log_probs.add(probs).reshape(b, -1)  # b kk'
 
             # Take only top n best guesses
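
Each head multiplies the candidate count by its per-head top-k, so the candidate tree grows as the running product of `top_k_tokens_per_head`; with the `[1, 1, 1, 1]` hardcoded above, the tree stays a single chain. A small illustrative trace:

    top_k_tokens_per_head = [1, 1, 1, 1]  # hardcoded above
    b, k = 2, 1  # batch size, current candidate count
    for i, kp in enumerate(top_k_tokens_per_head):
        k *= kp
        # out has shape (b, k, i + 1): k candidate sequences of i + 1 tokens
        print(f"head {i}: {k} candidate(s) of {i + 1} token(s)")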
@@ -570,14 +608,14 @@ class MLPSpeculatorHead(nn.Module):
         self.mlp_speculator = mlp_speculator
 
     def forward(
-        self, input: torch.Tensor
+        self, input: torch.Tensor, input_ids: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         # If we have too many tokens, we skip speculative logits
         if input.shape[0] > 128:
             return logits, None
 
-        speculative_logits = self.mlp_speculator(input)
+        speculative_logits = self.mlp_speculator(input, input_ids)
         return logits, speculative_logits
 
     @staticmethod
@@ -585,10 +623,13 @@ class MLPSpeculatorHead(nn.Module):
         from pathlib import Path
         from safetensors import safe_open
 
-        speculator_path = speculator_config.use_speculator
+        speculator_path = speculator_config.speculator
 
-        filename = str(Path(speculator_path) / "*.safetensors")
+        for fname in [
+            "model-00001-of-00002.safetensors",
+            "model-00002-of-00002.safetensors",
+        ]:
+            filename = str(Path(speculator_path) / fname)
         routing = weights.routing
         with safe_open(filename, framework="pytorch") as f:
             for k in f.keys():
@@ -776,9 +817,9 @@ class SpeculativeHead(nn.Module):
 
     @staticmethod
     def load(config, prefix: str, weights):
-        use_speculator = config.use_speculator
-        if use_speculator:
-            speculator_config = str(Path(use_speculator) / "config.json")
+        speculator = config.speculator
+        if speculator:
+            speculator_config = str(Path(speculator) / "config.json")
 
             with open(speculator_config, "r") as f:
                 speculator_config = json.load(f)
@@ -790,8 +831,7 @@ class SpeculativeHead(nn.Module):
             architecture = speculator_config["architectures"][0]
 
             if architecture == "MLPSpeculatorPreTrainedModel":
-                speculator_config.use_speculator = config.use_speculator
-                speculator = MLPSpeculatorHead.load(speculator_config, prefix, weights)
+                speculator = MLPSpeculatorHead.load(config, prefix, weights)
             else:
                 speculator = None
         except KeyError:
@@ -805,10 +845,10 @@ class SpeculativeHead(nn.Module):
         return SpeculativeHead(lm_head, speculator)
 
     def forward(
-        self, input: torch.Tensor
+        self, input: torch.Tensor, input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.medusa is not None:
-            return self.medusa(input)
+        if self.speculator is not None:
+            return self.speculator(input, input_ids)
 
         assert self.head is not None
         logits = self.head(input)