Hardcode a few things to make it work.

Nicolas Patry 2024-05-06 14:03:05 +00:00
parent 453e91f755
commit 38d6045443
4 changed files with 179 additions and 100 deletions

View File

@@ -136,7 +136,7 @@ def get_model(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
-    use_medusa = None
+    speculator = None
    if "medusa_num_heads" in config_dict:
        medusa_model_id = model_id
        medusa_revision = revision
@@ -166,11 +166,50 @@ def get_model(
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
-            use_medusa = Path(medusa_config).parent
+            speculator = Path(medusa_config).parent
        else:
-            use_medusa = Path(medusa_model_id)
+            speculator = Path(medusa_model_id)
        method = "medusa"
+    elif config_dict["model_type"] == "mlp_speculator":
+        # TODO make this not hardcoded.
+        mlp_model_id = model_id
+        mlp_revision = revision
+        model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+        revision = "main"
+        speculate_mlp = config_dict["n_predict"]
+        if speculate is not None:
+            if speculate > speculate_mlp:
+                raise RuntimeError(
+                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
+                )
+            else:
+                set_speculate(speculate)
+        else:
+            set_speculate(speculate_mlp)
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+        is_local = Path(mlp_model_id).exists()
+        if not is_local:
+            mlp_speculator_config = hf_hub_download(
+                mlp_model_id, revision=mlp_revision, filename="config.json"
+            )
+            hf_hub_download(
+                mlp_model_id,
+                revision=mlp_revision,
+                filename="model-00001-of-00002.safetensors",
+            )
+            hf_hub_download(
+                mlp_model_id,
+                revision=mlp_revision,
+                filename="model-00002-of-00002.safetensors",
+            )
+            speculator = Path(mlp_speculator_config).parent
+        else:
+            speculator = Path(mlp_model_id)
+        method = "mlp_speculator"
    else:
        method = "n-gram"
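The mlp_speculator branch above swaps the roles of the two ids: the user-supplied model_id becomes the speculator source (mlp_model_id), the base model is pinned to meta-llama/Meta-Llama-3-8B-Instruct, and the speculation depth is clamped to the speculator's n_predict. A minimal sketch of that depth rule, pulled out for clarity (the helper name resolve_speculate is illustrative, not part of the commit):

    def resolve_speculate(requested, n_predict):
        # A speculator trained with n_predict heads cannot propose
        # more than n_predict tokens per step.
        if requested is None:
            return n_predict
        if requested > n_predict:
            raise RuntimeError(
                f"Speculate is set to `{requested}` but this mlp_speculator "
                f"model only has `{n_predict}` heads, please make them match"
            )
        return requested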
@@ -202,7 +241,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -212,7 +251,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -227,7 +266,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -240,7 +279,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -250,7 +289,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -259,7 +298,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -270,7 +309,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -279,7 +318,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -288,7 +327,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -299,7 +338,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -308,7 +347,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -323,7 +362,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -334,7 +373,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -345,7 +384,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -355,7 +394,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -366,7 +405,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -377,7 +416,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -388,7 +427,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -399,7 +438,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -410,7 +449,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -424,7 +463,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -435,7 +474,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -444,7 +483,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -458,7 +497,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -469,7 +508,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -483,7 +522,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -494,7 +533,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -520,7 +559,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -544,7 +583,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -554,7 +593,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -564,7 +603,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -574,7 +613,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -586,7 +625,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -599,7 +638,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -623,7 +662,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -632,7 +671,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -644,7 +683,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )
@@ -653,7 +692,7 @@ def get_model(
            model_id,
            revision,
            quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

View File

@@ -419,5 +419,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
-        logits, speculative_logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states, input_ids)
        return logits, speculative_logits
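The model now threads input_ids into the speculative head: unlike Medusa, which predicts purely from the hidden state, the MLP speculator also embeds the ids of the tokens just produced. A sketch of the resulting call contract (the shapes are assumptions based on the b/k/v comments later in this diff, not stated by the commit):

    # hidden_states: [num_tokens, hidden_size]; input_ids: [num_tokens]
    logits, speculative_logits = self.lm_head(hidden_states, input_ids)
    # logits: [num_tokens, vocab_size]
    # speculative_logits: [num_tokens, n_predict, vocab_size], or None when
    # no speculator is configured (or the batch is too large, see below)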

View File

@@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM):
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
@@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM):
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)
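Setting config.speculator here is what the last file consumes: SpeculativeHead.load reads config.json from that directory, and MLPSpeculatorHead.load reads two hardcoded safetensors shards next to it. The on-disk layout this commit assumes (inferred from the filenames in the diff):

    <speculator directory>/
        config.json                       # "architectures": ["MLPSpeculatorPreTrainedModel"]
        model-00001-of-00002.safetensors
        model-00002-of-00002.safetensors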

View File

@@ -442,6 +442,7 @@ class ResBlock(torch.nn.Module):
    def forward(self, x):
        return x + self.act(self.linear(x))
+

class MLPSpeculatorLayerNorm(nn.Module):
    """
    A L2 normalization implementation
@@ -461,15 +462,14 @@ class MLPSpeculatorLayerNorm(nn.Module):
    def __init__(
        self,
        normalized_shape,
-        elementwise_scale_weight: torch.Tensor,
-        elementwise_shift_bias: torch.Tensor,
+        prefix,
+        config,
+        weights,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.normalized_shape = normalized_shape
-        self.weight = nn.Parameter(elementwise_scale_weight)
-        self.bias = nn.Parameter(elementwise_shift_bias)
+        self.weight = weights.get_tensor(f"{prefix}.weight")
+        self.bias = weights.get_tensor(f"{prefix}.bias")
        self.eps = eps

    def forward(self, x):
@@ -480,48 +480,69 @@ class MLPSpeculatorLayerNorm(nn.Module):
        x = x + self.bias
        return x

class MLPSpeculatorModel(torch.nn.Module):
    def __init__(self, config, prefix, weights):
        super().__init__()
-        self.config = config
-        self.n_predict = config.n_predict
-        self.emb_dim = config.emb_dim
-        inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
-        self.inner_dim = inner_dim
+        self.config = config.vocab_size
+        self.n_predict = get_speculate()
+        self.hidden_size = config.hidden_size
        self.emb = nn.ModuleList(
-            [TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) for i in range(config.n_predict)]
+            [
+                TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
+                for i in range(self.n_predict)
+            ]
        )
-        self.proj = TensorParallelColumnLinear.load_multi(
+        self.proj = [
+            TensorParallelColumnLinear.load(
                config,
-                prefixes=[f"{prefix}.proj.{i}" for i in range(config.n_predict)],
+                prefix=f"{prefix}.proj.{i}",
                weights=weights,
                bias=False,
-                dim=0
            )
+            for i in range(self.n_predict)
+        ]
        self.head = nn.ModuleList(
-            [TensorParallelRowLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) for i in range(config.n_predict)]
+            [
+                TensorParallelRowLinear.load(
+                    config, f"{prefix}.head.{i}", weights, bias=False
+                )
+                for i in range(self.n_predict)
+            ]
        )
        self.ln = nn.ModuleList(
            [
                MLPSpeculatorLayerNorm(
                    config.inner_dim,
-                    elementwise_scale_weight=weights.get_tensor(f"{prefix}.ln.{i}.weight"),
-                    elementwise_shift_bias=weights.get_tensor(f"{prefix}.ln.{i}.bias")
+                    prefix=f"{prefix}.ln.{i}",
+                    config=config,
+                    weights=weights,
                )
-                for i in range(config.n_predict)
+                for i in range(self.n_predict)
            ]
        )
        # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation
        self.state_weight = 0.5 ** (0.5 / self.n_predict)
-        self.emb_weight = math.sqrt(1 - self.state_weight ** 2)
+        self.emb_weight = math.sqrt(1 - self.state_weight**2)
        self.activation = nn.GELU()
+        # TODO
+        self.vsize = 128256
+        self.inner_dim = 3072

-    def forward(self, state: torch.Tensor, ind: torch.Tensor, top_k_tokens_per_head: Optional[List[int]], num_candidates: int = 1):
+    def forward(
+        self,
+        state: torch.Tensor,
+        input_ids: torch.Tensor,
+        top_k_tokens_per_head: Optional[List[int]] = None,
+        num_candidates: int = 1,
+    ):
+        # TODO
+        top_k_tokens_per_head = [1, 1, 1, 1]
        if top_k_tokens_per_head is None:
            top_k_tokens_per_head = self.config.top_k_tokens_per_head
+        ind = input_ids
        # k indicates # of candidates
        # h indicates # of generated tokens
        b = state.size(0)
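The state_weight/emb_weight pair set in __init__ keeps each head's update variance-preserving: state_weight**2 + emb_weight**2 == 1, and after n_predict updates the initial state retains state_weight**(2 * n_predict) = 0.5 of the variance, which is the "50% of state magnitude by final head" the comment promises. A quick numeric check (illustrative, not part of the commit):

    import math

    n_predict = 4
    state_weight = 0.5 ** (0.5 / n_predict)
    emb_weight = math.sqrt(1 - state_weight**2)

    # per-head mixing weights are variance-preserving
    assert abs(state_weight**2 + emb_weight**2 - 1.0) < 1e-12
    # after n_predict heads, state_0 keeps 50% of the variance
    assert abs(state_weight ** (2 * n_predict) - 0.5) < 1e-12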
@@ -537,23 +558,40 @@ class MLPSpeculatorModel(torch.nn.Module):
            z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
            state = self.proj[i](state) * self.state_weight + z
            state = self.activation(self.ln[i](state))  # b k d
-            _probs = F.log_softmax(self.head[i](state), dim=2)  # b k v
-            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=2)  # b k k'
+            _probs = F.log_softmax(self.head[i](state), dim=-1)  # b k v
+            probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'

            # Update candidate set with new predictions
-            out = out.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' h
-            out = torch.cat([out, preds.unsqueeze(3)], dim=3)  # b k k' h+1
+            out = out.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' h
+            try:
+                out = torch.cat(
+                    [out, preds.unsqueeze(2).unsqueeze(3)], dim=-1
+                )  # b k k' h+1
+            except Exception:
+                import ipdb
+
+                ipdb.set_trace()
            out = out.view(b, -1, i + 1)  # b kk' h+1

            # Update distribution set with new logits
-            all_probs = torch.cat([all_probs, _probs.exp().unsqueeze(2)], dim=2)  # b k h+1 v
-            all_probs = all_probs.repeat(1, top_k_tokens_per_head[i], 1, 1)  # b kk' h+1 v
+            all_probs = torch.cat(
+                [all_probs, _probs.exp().unsqueeze(2)], dim=-1
+            )  # b k h+1 v
+            all_probs = all_probs.repeat(
+                1, top_k_tokens_per_head[i], 1, 1
+            )  # b kk' h+1 v

            # Update state, log_probs and ind for new predictions
-            state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1)  # b k k' d
+            state = state.unsqueeze(2).expand(
+                -1, -1, top_k_tokens_per_head[i], -1
+            )  # b k k' d
            state = state.reshape(b, -1, state.size(3))  # b kk' d
            ind = preds.view(b, -1)  # b kk'
-            log_probs = log_probs.unsqueeze(2).expand(b, -1, top_k_tokens_per_head[i])  # b k k'
+            log_probs = log_probs.unsqueeze(2).expand(
+                b, -1, top_k_tokens_per_head[i]
+            )  # b k k'
            log_probs = log_probs.add(probs).reshape(b, -1)  # b kk'

        # Take only top n best guesses
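With the hardcoded top_k_tokens_per_head = [1, 1, 1, 1], the candidate tree above never branches: each head keeps exactly one prediction per path, so the loop degenerates into greedy chain decoding. The shape bookkeeping per head i, following the b/k/h/v comments in the diff (writing k' for top_k_tokens_per_head[i]):

    # state:     [b, k, d]    -> expand/reshape  -> [b, k*k', d]
    # out:       [b, k, h]    -> cat new preds   -> [b, k*k', h+1]
    # all_probs: [b, k, h, v] -> repeat          -> [b, k*k', h+1, v]
    # log_probs: [b, k]       -> add head probs  -> [b, k*k']
    # With k' == 1 everywhere, k*k' stays 1: one greedy chain per sequence.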
@@ -570,14 +608,14 @@ class MLPSpeculatorHead(nn.Module):
        self.mlp_speculator = mlp_speculator

    def forward(
-        self, input: torch.Tensor
+        self, input: torch.Tensor, input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        # If we have too many tokens, we skip speculative logits
        if input.shape[0] > 128:
            return logits, None

-        speculative_logits = self.mlp_speculator(input)
+        speculative_logits = self.mlp_speculator(input, input_ids)
        return logits, speculative_logits

    @staticmethod
@@ -585,10 +623,13 @@ class MLPSpeculatorHead(nn.Module):
        from pathlib import Path
        from safetensors import safe_open

-        speculator_path = speculator_config.use_speculator
-        filename = str(Path(speculator_path) / "*.safetensors")
+        speculator_path = speculator_config.speculator
+        for fname in [
+            "model-00001-of-00002.safetensors",
+            "model-00002-of-00002.safetensors",
+        ]:
+            filename = str(Path(speculator_path) / fname)

        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
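The loop above fills TGI's weight routing: every tensor key found in a shard is mapped to the file that stores it, so later get_tensor calls open the right shard. A standalone sketch of the pattern (hypothetical helper, using the "pt" framework alias of safetensors' safe_open):

    from pathlib import Path
    from safetensors import safe_open

    def build_routing(speculator_path, filenames):
        routing = {}
        for fname in filenames:
            filename = str(Path(speculator_path) / fname)
            with safe_open(filename, framework="pt") as f:
                for k in f.keys():
                    routing[k] = filename  # tensor name -> shard holding it
        return routing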
@@ -776,9 +817,9 @@ class SpeculativeHead(nn.Module):
    @staticmethod
    def load(config, prefix: str, weights):
-        use_speculator = config.use_speculator
-        if use_speculator:
-            speculator_config = str(Path(use_speculator) / "config.json")
+        speculator = config.speculator
+        if speculator:
+            speculator_config = str(Path(speculator) / "config.json")

            with open(speculator_config, "r") as f:
                speculator_config = json.load(f)
@@ -790,8 +831,7 @@ class SpeculativeHead(nn.Module):
            architecture = speculator_config["architectures"][0]

            if architecture == "MLPSpeculatorPreTrainedModel":
-                speculator_config.use_speculator = config.use_speculator
-                speculator = MLPSpeculatorHead.load(speculator_config, prefix, weights)
+                speculator = MLPSpeculatorHead.load(config, prefix, weights)
            else:
                speculator = None
        except KeyError:
@@ -805,10 +845,10 @@ class SpeculativeHead(nn.Module):
        return SpeculativeHead(lm_head, speculator)

    def forward(
-        self, input: torch.Tensor
+        self, input: torch.Tensor, input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        if self.medusa is not None:
-            return self.medusa(input)
+        if self.speculator is not None:
+            return self.speculator(input, input_ids)
        assert self.head is not None
        logits = self.head(input)
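After this commit every path goes through one interface: SpeculativeHead.forward takes (input, input_ids) and returns (logits, speculative_logits), with the generic self.speculator field replacing the Medusa-specific self.medusa. A condensed sketch of the dispatch (the final return is an assumption; the hunk above cuts off before it):

    def forward(self, input, input_ids):
        if self.speculator is not None:
            # Medusa or MLPSpeculatorHead; both return (logits, speculative_logits)
            return self.speculator(input, input_ids)
        assert self.head is not None
        logits = self.head(input)
        return logits, None  # no speculator, no speculative logits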