mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Hardcode a few stuff to make it work.
This commit is contained in:
parent
453e91f755
commit
38d6045443
@ -136,7 +136,7 @@ def get_model(
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
|
|
||||||
use_medusa = None
|
speculator = None
|
||||||
if "medusa_num_heads" in config_dict:
|
if "medusa_num_heads" in config_dict:
|
||||||
medusa_model_id = model_id
|
medusa_model_id = model_id
|
||||||
medusa_revision = revision
|
medusa_revision = revision
|
||||||
@ -166,11 +166,50 @@ def get_model(
|
|||||||
revision=medusa_revision,
|
revision=medusa_revision,
|
||||||
filename="medusa_lm_head.safetensors",
|
filename="medusa_lm_head.safetensors",
|
||||||
)
|
)
|
||||||
use_medusa = Path(medusa_config).parent
|
speculator = Path(medusa_config).parent
|
||||||
else:
|
else:
|
||||||
use_medusa = Path(medusa_model_id)
|
speculator = Path(medusa_model_id)
|
||||||
|
|
||||||
method = "medusa"
|
method = "medusa"
|
||||||
|
elif config_dict["model_type"] == "mlp_speculator":
|
||||||
|
# TODO make this not hardcoded.
|
||||||
|
mlp_model_id = model_id
|
||||||
|
mlp_revision = revision
|
||||||
|
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
revision = "main"
|
||||||
|
speculate_mlp = config_dict["n_predict"]
|
||||||
|
if speculate is not None:
|
||||||
|
if speculate > speculate_mlp:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
set_speculate(speculate)
|
||||||
|
else:
|
||||||
|
set_speculate(speculate_mlp)
|
||||||
|
|
||||||
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
is_local = Path(mlp_model_id).exists()
|
||||||
|
if not is_local:
|
||||||
|
mlp_speculator_config = hf_hub_download(
|
||||||
|
mlp_model_id, revision=mlp_revision, filename="config.json"
|
||||||
|
)
|
||||||
|
hf_hub_download(
|
||||||
|
mlp_model_id,
|
||||||
|
revision=mlp_revision,
|
||||||
|
filename="model-00001-of-00002.safetensors",
|
||||||
|
)
|
||||||
|
hf_hub_download(
|
||||||
|
mlp_model_id,
|
||||||
|
revision=mlp_revision,
|
||||||
|
filename="model-00002-of-00002.safetensors",
|
||||||
|
)
|
||||||
|
speculator = Path(mlp_speculator_config).parent
|
||||||
|
else:
|
||||||
|
speculator = Path(mlp_model_id)
|
||||||
|
method = "mlp_speculator"
|
||||||
else:
|
else:
|
||||||
method = "n-gram"
|
method = "n-gram"
|
||||||
|
|
||||||
@ -202,7 +241,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -212,7 +251,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -227,7 +266,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -240,7 +279,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -250,7 +289,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -259,7 +298,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -270,7 +309,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -279,7 +318,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -288,7 +327,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -299,7 +338,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -308,7 +347,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -323,7 +362,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -334,7 +373,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -345,7 +384,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -355,7 +394,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -366,7 +405,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -377,7 +416,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -388,7 +427,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -399,7 +438,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -410,7 +449,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -424,7 +463,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -435,7 +474,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -444,7 +483,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -458,7 +497,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -469,7 +508,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -483,7 +522,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -494,7 +533,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -520,7 +559,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -544,7 +583,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -554,7 +593,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -564,7 +603,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -574,7 +613,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -586,7 +625,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -599,7 +638,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -623,7 +662,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -632,7 +671,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -644,7 +683,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
@ -653,7 +692,7 @@ def get_model(
|
|||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -419,5 +419,5 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
logits, speculative_logits = self.lm_head(hidden_states)
|
logits, speculative_logits = self.lm_head(hidden_states, input_ids)
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
@ -442,6 +442,7 @@ class ResBlock(torch.nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x + self.act(self.linear(x))
|
return x + self.act(self.linear(x))
|
||||||
|
|
||||||
|
|
||||||
class MLPSpeculatorLayerNorm(nn.Module):
|
class MLPSpeculatorLayerNorm(nn.Module):
|
||||||
"""
|
"""
|
||||||
A L2 normalization implementation
|
A L2 normalization implementation
|
||||||
@ -461,15 +462,14 @@ class MLPSpeculatorLayerNorm(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
normalized_shape,
|
prefix,
|
||||||
elementwise_scale_weight: torch.Tensor,
|
config,
|
||||||
elementwise_shift_bias: torch.Tensor,
|
weights,
|
||||||
eps=1e-06,
|
eps=1e-06,
|
||||||
):
|
):
|
||||||
super(MLPSpeculatorLayerNorm, self).__init__()
|
super(MLPSpeculatorLayerNorm, self).__init__()
|
||||||
self.normalized_shape = normalized_shape
|
self.weight = weights.get_tensor(f"{prefix}.weight")
|
||||||
self.weight = nn.Parameter(elementwise_scale_weight)
|
self.bias = weights.get_tensor(f"{prefix}.bias")
|
||||||
self.bias = nn.Parameter(elementwise_shift_bias)
|
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -480,36 +480,44 @@ class MLPSpeculatorLayerNorm(nn.Module):
|
|||||||
x = x + self.bias
|
x = x + self.bias
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class MLPSpeculatorModel(torch.nn.Module):
|
class MLPSpeculatorModel(torch.nn.Module):
|
||||||
def __init__(self, config, prefix, weights):
|
def __init__(self, config, prefix, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.n_predict = config.n_predict
|
self.n_predict = get_speculate()
|
||||||
self.emb_dim = config.emb_dim
|
self.hidden_size = config.hidden_size
|
||||||
inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim
|
|
||||||
self.inner_dim = inner_dim
|
|
||||||
self.config = config.vocab_size
|
|
||||||
self.emb = nn.ModuleList(
|
self.emb = nn.ModuleList(
|
||||||
[TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) for i in range(config.n_predict)]
|
[
|
||||||
|
TensorParallelEmbedding(f"{prefix}.emb.{i}", weights)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.proj = TensorParallelColumnLinear.load_multi(
|
self.proj = [
|
||||||
|
TensorParallelColumnLinear.load(
|
||||||
config,
|
config,
|
||||||
prefixes=[f"{prefix}.proj.{i}" for i in range(config.n_predict)],
|
prefix=f"{prefix}.proj.{i}",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=False,
|
||||||
dim=0
|
|
||||||
)
|
)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
self.head = nn.ModuleList(
|
self.head = nn.ModuleList(
|
||||||
[TensorParallelRowLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) for i in range(config.n_predict)]
|
[
|
||||||
|
TensorParallelRowLinear.load(
|
||||||
|
config, f"{prefix}.head.{i}", weights, bias=False
|
||||||
|
)
|
||||||
|
for i in range(self.n_predict)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
self.ln = nn.ModuleList(
|
self.ln = nn.ModuleList(
|
||||||
[
|
[
|
||||||
MLPSpeculatorLayerNorm(
|
MLPSpeculatorLayerNorm(
|
||||||
config.inner_dim,
|
prefix=f"{prefix}.ln.{i}",
|
||||||
elementwise_scale_weight=weights.get_tensor(f"{prefix}.ln.{i}.weight"),
|
config=config,
|
||||||
elementwise_shift_bias=weights.get_tensor(f"{prefix}.ln.{i}.bias")
|
weights=weights,
|
||||||
)
|
)
|
||||||
for i in range(config.n_predict)
|
for i in range(self.n_predict)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -517,11 +525,24 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||||||
self.state_weight = 0.5 ** (0.5 / self.n_predict)
|
self.state_weight = 0.5 ** (0.5 / self.n_predict)
|
||||||
self.emb_weight = math.sqrt(1 - self.state_weight**2)
|
self.emb_weight = math.sqrt(1 - self.state_weight**2)
|
||||||
self.activation = nn.GELU()
|
self.activation = nn.GELU()
|
||||||
|
# TODO
|
||||||
|
self.vsize = 128256
|
||||||
|
self.inner_dim = 3072
|
||||||
|
|
||||||
def forward(self, state: torch.Tensor, ind: torch.Tensor, top_k_tokens_per_head: Optional[List[int]], num_candidates: int = 1):
|
def forward(
|
||||||
|
self,
|
||||||
|
state: torch.Tensor,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
top_k_tokens_per_head: Optional[List[int]] = None,
|
||||||
|
num_candidates: int = 1,
|
||||||
|
):
|
||||||
|
# TODO
|
||||||
|
top_k_tokens_per_head = [1, 1, 1, 1]
|
||||||
if top_k_tokens_per_head is None:
|
if top_k_tokens_per_head is None:
|
||||||
top_k_tokens_per_head = self.config.top_k_tokens_per_head
|
top_k_tokens_per_head = self.config.top_k_tokens_per_head
|
||||||
|
|
||||||
|
ind = input_ids
|
||||||
|
|
||||||
# k indicates # of candidates
|
# k indicates # of candidates
|
||||||
# h indicates # of generated tokens
|
# h indicates # of generated tokens
|
||||||
b = state.size(0)
|
b = state.size(0)
|
||||||
@ -537,23 +558,40 @@ class MLPSpeculatorModel(torch.nn.Module):
|
|||||||
z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
|
z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
|
||||||
state = self.proj[i](state) * self.state_weight + z
|
state = self.proj[i](state) * self.state_weight + z
|
||||||
state = self.activation(self.ln[i](state)) # b k d
|
state = self.activation(self.ln[i](state)) # b k d
|
||||||
_probs = F.log_softmax(self.head[i](state), dim=2) # b k v
|
_probs = F.log_softmax(self.head[i](state), dim=-1) # b k v
|
||||||
probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=2) # b k k'
|
probs, preds = _probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k'
|
||||||
|
|
||||||
# Update candidate set with new predictions
|
# Update candidate set with new predictions
|
||||||
out = out.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1) # b k k' h
|
out = out.unsqueeze(2).expand(
|
||||||
out = torch.cat([out, preds.unsqueeze(3)], dim=3) # b k k' h+1
|
-1, -1, top_k_tokens_per_head[i], -1
|
||||||
|
) # b k k' h
|
||||||
|
try:
|
||||||
|
out = torch.cat(
|
||||||
|
[out, preds.unsqueeze(2).unsqueeze(3)], dim=-1
|
||||||
|
) # b k k' h+1
|
||||||
|
except Exception:
|
||||||
|
import ipdb
|
||||||
|
|
||||||
|
ipdb.set_trace()
|
||||||
out = out.view(b, -1, i + 1) # b kk' h+1
|
out = out.view(b, -1, i + 1) # b kk' h+1
|
||||||
|
|
||||||
# Update distribution set with new logits
|
# Update distribution set with new logits
|
||||||
all_probs = torch.cat([all_probs, _probs.exp().unsqueeze(2)], dim=2) # b k h+1 v
|
all_probs = torch.cat(
|
||||||
all_probs = all_probs.repeat(1, top_k_tokens_per_head[i], 1, 1) # b kk' h+1 v
|
[all_probs, _probs.exp().unsqueeze(2)], dim=-1
|
||||||
|
) # b k h+1 v
|
||||||
|
all_probs = all_probs.repeat(
|
||||||
|
1, top_k_tokens_per_head[i], 1, 1
|
||||||
|
) # b kk' h+1 v
|
||||||
|
|
||||||
# Update state, log_probs and ind for new predictions
|
# Update state, log_probs and ind for new predictions
|
||||||
state = state.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1) # b k k' d
|
state = state.unsqueeze(2).expand(
|
||||||
|
-1, -1, top_k_tokens_per_head[i], -1
|
||||||
|
) # b k k' d
|
||||||
state = state.reshape(b, -1, state.size(3)) # b kk' d
|
state = state.reshape(b, -1, state.size(3)) # b kk' d
|
||||||
ind = preds.view(b, -1) # b kk'
|
ind = preds.view(b, -1) # b kk'
|
||||||
log_probs = log_probs.unsqueeze(2).expand(b, -1, top_k_tokens_per_head[i]) # b k k'
|
log_probs = log_probs.unsqueeze(2).expand(
|
||||||
|
b, -1, top_k_tokens_per_head[i]
|
||||||
|
) # b k k'
|
||||||
log_probs = log_probs.add(probs).reshape(b, -1) # b kk'
|
log_probs = log_probs.add(probs).reshape(b, -1) # b kk'
|
||||||
|
|
||||||
# Take only top n best guesses
|
# Take only top n best guesses
|
||||||
@ -570,14 +608,14 @@ class MLPSpeculatorHead(nn.Module):
|
|||||||
self.mlp_speculator = mlp_speculator
|
self.mlp_speculator = mlp_speculator
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input: torch.Tensor
|
self, input: torch.Tensor, input_ids: torch.Tensor
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
logits = self.lm_head(input)
|
logits = self.lm_head(input)
|
||||||
# If we have too many tokens, we skip speculative logits
|
# If we have too many tokens, we skip speculative logits
|
||||||
if input.shape[0] > 128:
|
if input.shape[0] > 128:
|
||||||
return logits, None
|
return logits, None
|
||||||
|
|
||||||
speculative_logits = self.mlp_speculator(input)
|
speculative_logits = self.mlp_speculator(input, input_ids)
|
||||||
return logits, speculative_logits
|
return logits, speculative_logits
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -585,10 +623,13 @@ class MLPSpeculatorHead(nn.Module):
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
|
|
||||||
speculator_path = speculator_config.use_speculator
|
speculator_path = speculator_config.speculator
|
||||||
|
|
||||||
filename = str(Path(speculator_path) / "*.safetensors")
|
|
||||||
|
|
||||||
|
for fname in [
|
||||||
|
"model-00001-of-00002.safetensors",
|
||||||
|
"model-00002-of-00002.safetensors",
|
||||||
|
]:
|
||||||
|
filename = str(Path(speculator_path) / fname)
|
||||||
routing = weights.routing
|
routing = weights.routing
|
||||||
with safe_open(filename, framework="pytorch") as f:
|
with safe_open(filename, framework="pytorch") as f:
|
||||||
for k in f.keys():
|
for k in f.keys():
|
||||||
@ -776,9 +817,9 @@ class SpeculativeHead(nn.Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(config, prefix: str, weights):
|
def load(config, prefix: str, weights):
|
||||||
use_speculator = config.use_speculator
|
speculator = config.speculator
|
||||||
if use_speculator:
|
if speculator:
|
||||||
speculator_config = str(Path(use_speculator) / "config.json")
|
speculator_config = str(Path(speculator) / "config.json")
|
||||||
|
|
||||||
with open(speculator_config, "r") as f:
|
with open(speculator_config, "r") as f:
|
||||||
speculator_config = json.load(f)
|
speculator_config = json.load(f)
|
||||||
@ -790,8 +831,7 @@ class SpeculativeHead(nn.Module):
|
|||||||
architecture = speculator_config["architectures"][0]
|
architecture = speculator_config["architectures"][0]
|
||||||
|
|
||||||
if architecture == "MLPSpeculatorPreTrainedModel":
|
if architecture == "MLPSpeculatorPreTrainedModel":
|
||||||
speculator_config.use_speculator = config.use_speculator
|
speculator = MLPSpeculatorHead.load(config, prefix, weights)
|
||||||
speculator = MLPSpeculatorHead.load(speculator_config, prefix, weights)
|
|
||||||
else:
|
else:
|
||||||
speculator = None
|
speculator = None
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -805,10 +845,10 @@ class SpeculativeHead(nn.Module):
|
|||||||
return SpeculativeHead(lm_head, speculator)
|
return SpeculativeHead(lm_head, speculator)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input: torch.Tensor
|
self, input: torch.Tensor, input_ids: torch.Tensor
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
if self.medusa is not None:
|
if self.speculator is not None:
|
||||||
return self.medusa(input)
|
return self.speculator(input, input_ids)
|
||||||
|
|
||||||
assert self.head is not None
|
assert self.head is not None
|
||||||
logits = self.head(input)
|
logits = self.head(input)
|
||||||
|
Loading…
Reference in New Issue
Block a user