Mirror of https://github.com/huggingface/text-generation-inference.git
Tmp work for medusa.
This commit is contained in: parent 3c71c656c7, commit 94a0bf1bbc
@@ -116,7 +116,7 @@ def download_weights(
         logger.info("Files are already present on the host. " "Skipping download.")
         return
     # Local files not found
-    except (utils.LocalEntryNotFoundError, FileNotFoundError):
+    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
         pass
 
     is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
@@ -137,6 +137,29 @@ def download_weights(
         except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
             pass
 
+        try:
+            import json
+            medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
+            if auto_convert:
+                medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
+                if not medusa_sf.exists():
+                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+            medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+
+            model_id = config["base_model_name_or_path"]
+            revision = "main"
+            try:
+                utils.weight_files(model_id, revision, extension)
+                logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
+                return
+            # Local files not found
+            except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+                pass
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+
         # Try to download weights from the hub
         try:
             filenames = utils.weight_hub_files(model_id, revision, extension)
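The new download step above assumes the Medusa repository on the Hub ships a medusa_lm_head.pt checkpoint next to a config.json whose base_model_name_or_path names the underlying model. A minimal standalone sketch of the filename derivation it uses before converting the checkpoint to safetensors (the path is hypothetical, and no Hub access is involved):

from pathlib import Path

# Hypothetical local path, standing in for what hf_hub_download would return.
medusa_head = "/data/medusa-adapter/medusa_lm_head.pt"

# Same derivation as in the diff: replace the ".pt" suffix with ".safetensors".
# utils.convert_files would then write the converted weights to this path.
medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
print(medusa_sf)  # /data/medusa-adapter/medusa_lm_head.safetensors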
@@ -129,6 +129,17 @@ def get_model(
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
+
+    use_medusa = None
+    if "medusa_num_heads" in config_dict:
+        use_medusa = model_id
+        medusa_config = config_dict
+        model_id = config_dict["base_model_name_or_path"]
+        revision = "main"
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+
     model_type = config_dict["model_type"]
 
     if model_type == "gpt_bigcode":
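For illustration, a Medusa adapter's config.json might look like the following; only the three keys this commit actually reads (base_model_name_or_path, medusa_num_heads, medusa_num_layers) come from the diff, and the values are made up:

# Illustrative adapter config; values are hypothetical.
adapter_config = {
    "base_model_name_or_path": "some-org/base-llama-7b",  # hypothetical base model id
    "medusa_num_heads": 4,
    "medusa_num_layers": 1,
}

# get_model keys its detection off the presence of "medusa_num_heads",
# then re-resolves the config of the base model named above.
assert "medusa_num_heads" in adapter_config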
@@ -204,6 +215,7 @@ def get_model(
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                use_medusa=use_medusa
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -28,6 +28,7 @@ class FlashLlama(FlashCausalLM):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -66,6 +67,17 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id)
 
         model = FlashLlamaForCausalLM(config, weights)
+        if use_medusa:
+            from text_generation_server.utils.medusa import MedusaModel
+            from huggingface_hub import hf_hub_download
+            import json
+            medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+            medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
+            medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
+            weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
+            model.lm_head = MedusaModel(config, weights)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
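Because the Medusa head weights are loaded through FastLinear.load with the prefixes used in medusa.py below, the converted medusa_lm_head.safetensors is expected to expose one set of tensors per head, indexed by head and block. Assuming two heads with one residual block each, and assuming FastLinear.load reads "{prefix}.weight" and "{prefix}.bias", the names would look like:

# Illustrative tensor names only; the real counts depend on the adapter config.
expected_tensor_names = [
    "0.0.linear.weight", "0.0.linear.bias",  # head 0, residual block 0
    "0.1.weight",                            # head 0, output projection (bias=False)
    "1.0.linear.weight", "1.0.linear.bias",  # head 1, residual block 0
    "1.1.weight",                            # head 1, output projection
]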
server/text_generation_server/utils/medusa.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+import torch
+from text_generation_server.utils.layers import TensorParallelHead, FastLinear
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+    def __init__(
+        self,
+        config,
+        weights
+    ):
+        super().__init__()
+        self.heads = torch.nn.ModuleList(
+            [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
+        )
+
+
+class MedusaHead(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])])
+        n = len(self.blocks)
+        self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.out(x)
+        return x
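Note that MedusaModel defines no forward in this temporary commit, so assigning it to model.lm_head is not yet usable end to end. For reference, a self-contained sketch of the architecture the module encodes, in plain PyTorch with nn.Linear standing in for FastLinear; the toy sizes and the final stacking of per-head logits are assumptions, not part of the diff:

import torch


class ResBlockSketch(torch.nn.Module):
    # Residual SiLU block: x + SiLU(W x + b), same shape in and out.
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)
        self.act = torch.nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))


class MedusaHeadSketch(torch.nn.Module):
    # A stack of residual blocks followed by an unbiased projection to the vocabulary.
    def __init__(self, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.blocks = torch.nn.ModuleList(ResBlockSketch(hidden_size) for _ in range(num_layers))
        self.out = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.out(x)


# Assumed usage: each head produces logits for one additional future token position.
hidden_size, vocab_size = 64, 1000  # toy sizes
heads = torch.nn.ModuleList(MedusaHeadSketch(hidden_size, vocab_size, num_layers=1) for _ in range(4))
hidden_states = torch.randn(2, hidden_size)  # (batch, hidden)
speculative_logits = torch.stack([head(hidden_states) for head in heads], dim=1)  # (batch, heads, vocab)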