mirror of https://github.com/huggingface/text-generation-inference.git
Removed a bunch of hardcodes.
This commit is contained in:
parent 1a8a18d541
commit b884899086
@@ -1,9 +1,10 @@
 import torch
+import os
 
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
 
@@ -166,9 +167,15 @@ def get_model(
                 revision=medusa_revision,
                 filename="medusa_lm_head.safetensors",
             )
-            speculator = Path(medusa_config).parent
+            speculator = {
+                "path": Path(medusa_config).parent,
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }
         else:
-            speculator = Path(medusa_model_id)
+            speculator = {
+                "path": Path(medusa_model_id),
+                "model_paths": ["medusa_lm_head.safetensors"],
+            }
 
         method = "medusa"
     elif config_dict["model_type"] == "mlp_speculator":
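In the hunk above, the speculator stops being a bare Path and becomes a small dict carrying both the local directory and the explicit weight file names. A minimal sketch of that layout and how a consumer can resolve it (the directory below is illustrative, not from the commit):

from pathlib import Path

# Shape of the new speculator value introduced above; the directory is hypothetical.
speculator = {
    "path": Path("/data/medusa-head"),
    "model_paths": ["medusa_lm_head.safetensors"],
}

# Downstream code resolves concrete weight files by joining the two fields.
weight_files = [speculator["path"] / name for name in speculator["model_paths"]]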
@@ -192,23 +199,36 @@ def get_model(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         is_local = Path(mlp_model_id).exists()
+        extension = ".safetensors"
         if not is_local:
             mlp_speculator_config = hf_hub_download(
                 mlp_model_id, revision=mlp_revision, filename="config.json"
             )
-            hf_hub_download(
-                mlp_model_id,
-                revision=mlp_revision,
-                filename="model-00001-of-00002.safetensors",
-            )
-            hf_hub_download(
-                mlp_model_id,
-                revision=mlp_revision,
-                filename="model-00002-of-00002.safetensors",
-            )
-            speculator = Path(mlp_speculator_config).parent
+            api = HfApi()
+            info = api.model_info(mlp_model_id, revision=mlp_revision)
+            filenames = [
+                s.rfilename
+                for s in info.siblings
+                if s.rfilename.endswith(extension)
+                and len(s.rfilename.split("/")) == 1
+                and "arguments" not in s.rfilename
+                and "args" not in s.rfilename
+                and "training" not in s.rfilename
+            ]
+            for filename in filenames:
+                hf_hub_download(
+                    mlp_model_id,
+                    revision=mlp_revision,
+                    filename=filename,
+                )
+            speculator = {
+                "path": Path(mlp_speculator_config).parent,
+                "model_paths": filenames,
+            }
         else:
             speculator = Path(mlp_model_id)
+            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
+            speculator = {"path": speculator, "model_paths": filenames}
         method = "mlp_speculator"
     else:
         method = "n-gram"
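The hunk above discovers shard names at runtime instead of hardcoding them: HfApi.model_info() lists the repository's sibling files, and their rfilename values are filtered down to top-level *.safetensors entries before downloading. A standalone sketch of that pattern; the repo id and revision are placeholders:

from huggingface_hub import HfApi, hf_hub_download

api = HfApi()
info = api.model_info("some-org/some-mlp-speculator", revision="main")  # placeholder repo

# Keep only top-level *.safetensors files and skip training artifacts,
# mirroring the filter in the hunk above.
filenames = [
    s.rfilename
    for s in info.siblings
    if s.rfilename.endswith(".safetensors")
    and len(s.rfilename.split("/")) == 1
    and "arguments" not in s.rfilename
    and "args" not in s.rfilename
    and "training" not in s.rfilename
]

for filename in filenames:
    hf_hub_download("some-org/some-mlp-speculator", revision="main", filename=filename)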
@@ -525,10 +525,9 @@ class MLPSpeculatorModel(torch.nn.Module):
         self.emb_weight = math.sqrt(1 - self.state_weight**2)
         self.activation = nn.GELU()
         # TODO
-        self.vsize = 128256
-        self.inner_dim = 3072
+        self.vsize = config.vocab_size
+        self.inner_dim = config.speculator_config["inner_dim"]
         self.top_k_tokens_per_head = [1] * self.n_predict
-        self.candidates = 1
 
     def forward(
         self,
@@ -536,27 +535,20 @@ class MLPSpeculatorModel(torch.nn.Module):
         input_ids: torch.Tensor,
     ):
         top_k_tokens_per_head = self.top_k_tokens_per_head
-        num_candidates = self.candidates
 
-        # if state.shape[0] > 1:
-        #     state = state[:1]
-
         # k indicates # of candidates
         # h indicates # of generated tokens
         state = hidden_states
         b = state.size(0)
         ind = input_ids.unsqueeze(0)
-        out = torch.empty(1, b, self.n_predict, device=state.device).int()  # b k h
-        # log_probs = torch.zeros(1, b, device=state.device)  # b k
         all_probs = torch.empty(
-            1, b, self.n_predict, self.vsize, device=state.device
+            b, self.n_predict, self.vsize, device=state.device
         )  # b k h v
         assert (
             len(top_k_tokens_per_head) == self.n_predict
         ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)"
         for i in range(self.n_predict):
             # Project and predict
-            # print(ind)
             z = self.emb[i](ind)
             z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2))  # b k d
             state = self.proj[i](state) * self.state_weight + z
@@ -565,10 +557,9 @@ class MLPSpeculatorModel(torch.nn.Module):
             _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1)  # b k k'
 
             # Update candidate set with new predictions
-            out[:, :, i : i + 1] = preds
 
             # Update distribution set with new logits
-            all_probs[:, :, i] = probs.exp()
+            all_probs[:, i] = probs.exp()
 
             # Update state, log_probs and ind for new predictions
             state = state.unsqueeze(2).expand(
@@ -576,20 +567,8 @@ class MLPSpeculatorModel(torch.nn.Module):
             )  # b k k' d
             state = state.reshape(-1, b, state.size(3))  # b kk' d
             ind = preds.view(-1, b)  # b kk'
-            # log_probs = log_probs.unsqueeze(2).expand(
-            #     -1, b, top_k_tokens_per_head[i]
-            # )  # b k k'
-            # log_probs = log_probs.add(probs).reshape(-1, b)  # b kk'
 
-        # print("done")
-        # Take only top n best guesses
-        # best_guesses = log_probs.topk(num_candidates, dim=1)[1]  # b k
-        # speculative_logits = all_probs.gather(
-        #     1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize)
-        # ).squeeze(0)
-        speculative_logits = all_probs[0]
+        speculative_logits = all_probs
 
-        # assert list(speculative_logits.shape) == [hidden_states.shape[0], self.n_predict, self.vsize], f"{speculative_logits.shape}, {hidden_states.shape[0]} {self.n_predict} {self.vsize}"
-        # TODO Why is this shift existing, are speculative logits also including the natural next token ?
         return speculative_logits
 
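The forward() hunks above drop the unused candidate dimension, so all_probs is allocated directly as [batch, n_predict, vocab_size] and returned without the trailing [0] index. A toy, self-contained shape check under assumed sizes:

import torch

b, n_predict, vsize = 2, 4, 32  # toy sizes, not the model's real configuration
all_probs = torch.empty(b, n_predict, vsize)

for i in range(n_predict):
    probs = torch.log_softmax(torch.randn(b, vsize), dim=-1)
    all_probs[:, i] = probs.exp()  # same indexing as the updated forward()

speculative_logits = all_probs  # already [b, n_predict, vsize]; no [0] needed
assert speculative_logits.shape == (b, n_predict, vsize)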
@@ -612,16 +591,13 @@ class MLPSpeculatorHead(nn.Module):
         return logits, speculative_logits
 
     @staticmethod
-    def load(speculator_config, prefix: str, weights):
+    def load(config, prefix: str, weights):
         from pathlib import Path
         from safetensors import safe_open
 
-        speculator_path = speculator_config.speculator
+        speculator_path = config.speculator["path"]
 
-        for fname in [
-            "model-00001-of-00002.safetensors",
-            "model-00002-of-00002.safetensors",
-        ]:
+        for fname in config.speculator["model_paths"]:
             filename = str(Path(speculator_path) / fname)
             routing = weights.routing
             with safe_open(filename, framework="pytorch") as f:
@@ -632,8 +608,8 @@ class MLPSpeculatorHead(nn.Module):
                         )
                     routing[k] = filename
 
-        mlp_speculator = MLPSpeculatorModel(speculator_config, "speculator", weights)
-        lm_head = TensorParallelHead.load(speculator_config, prefix, weights)
+        mlp_speculator = MLPSpeculatorModel(config, "speculator", weights)
+        lm_head = TensorParallelHead.load(config, prefix, weights)
         return MLPSpeculatorHead(lm_head, mlp_speculator)
 
 
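With load() now iterating over speculator["model_paths"], the key-to-file routing works for any number of shards instead of exactly two hardcoded names. A standalone sketch of that routing pattern; the directory and file list below are hypothetical:

from pathlib import Path
from safetensors import safe_open

# Hypothetical speculator dict; in the server it comes from config.speculator.
speculator = {
    "path": Path("/data/mlp-speculator"),
    "model_paths": ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"],
}

routing = {}
for fname in speculator["model_paths"]:
    filename = str(speculator["path"] / fname)
    with safe_open(filename, framework="pytorch") as f:
        for k in f.keys():
            if k in routing:
                raise RuntimeError(f"Key {k} was found in multiple files: {filename} and {routing[k]}")
            routing[k] = filename  # tensor key -> shard that holds it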
@@ -726,8 +702,9 @@ class MedusaHeadV2(nn.Module):
 
         speculator = config.speculator
 
-        medusa_config = str(Path(speculator) / "config.json")
-        filename = str(Path(speculator) / "medusa_lm_head.safetensors")
+        path = Path(speculator["path"])
+        medusa_config = str(path / "config.json")
+        filename = path / speculator["model_paths"][0]
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
@@ -812,11 +789,14 @@ class SpeculativeHead(nn.Module):
     def load(config, prefix: str, weights):
         speculator = config.speculator
         if speculator:
-            speculator_config = str(Path(speculator) / "config.json")
+            speculator_path = config.speculator["path"]
+            speculator_config = str(speculator_path / "config.json")
+
             with open(speculator_config, "r") as f:
                 speculator_config = json.load(f)
             lm_head = None
+            config.speculator_config = speculator_config
 
             # currently medusa does not have an architecture specified, so try-except for now
             # this should really be handled in a better way though (maybe the classname can be part of the config)
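SpeculativeHead.load now reads the speculator's config.json from speculator["path"] and attaches it to the model config as speculator_config, which is how MLPSpeculatorModel can take inner_dim from that file and vocab_size from the base config instead of the old 3072 and 128256 hardcodes. A minimal sketch of that flow, using a stand-in config object and a hypothetical path:

import json
from pathlib import Path
from types import SimpleNamespace

config = SimpleNamespace(vocab_size=128256)     # stand-in for the base model config
speculator_path = Path("/data/mlp-speculator")  # hypothetical local speculator directory

with open(speculator_path / "config.json", "r") as f:
    config.speculator_config = json.load(f)

vsize = config.vocab_size                           # replaces the 128256 hardcode
inner_dim = config.speculator_config["inner_dim"]   # replaces the 3072 hardcode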