mirror of https://github.com/huggingface/text-generation-inference.git

commit 94a0bf1bbc (parent 3c71c656c7)

    Tmp work for medusa.
@@ -116,7 +116,7 @@ def download_weights(
         logger.info("Files are already present on the host. " "Skipping download.")
         return
     # Local files not found
-    except (utils.LocalEntryNotFoundError, FileNotFoundError):
+    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
         pass
 
     is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
@@ -137,6 +137,29 @@ def download_weights(
         except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
             pass
 
+        try:
+            import json
+            medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
+            if auto_convert:
+                medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
+                if not medusa_sf.exists():
+                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+            medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+
+            model_id = config["base_model_name_or_path"]
+            revision = "main"
+            try:
+                utils.weight_files(model_id, revision, extension)
+                logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
+                return
+            # Local files not found
+            except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+                pass
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+
         # Try to download weights from the hub
         try:
             filenames = utils.weight_hub_files(model_id, revision, extension)
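For orientation, here is roughly what this download path does when pointed at a Medusa checkpoint, as a standalone sketch outside the server. The repo id is illustrative; any Hub repo that ships a medusa_lm_head.pt plus a config.json containing base_model_name_or_path would behave the same way.

```python
# Standalone sketch of the resolution step above (repo id is an example).
import json
from huggingface_hub import hf_hub_download

repo_id = "FasterDecoding/medusa-vicuna-7b-v1.3"

# The Medusa repo only carries the extra decoding heads...
medusa_head = hf_hub_download(repo_id, filename="medusa_lm_head.pt")

# ...while the transformer weights come from the parent model named in its config.
with open(hf_hub_download(repo_id, filename="config.json"), "r") as f:
    base_model_id = json.load(f)["base_model_name_or_path"]

print(f"Medusa heads cached at: {medusa_head}")
print(f"Parent model to fetch separately: {base_model_id}")
```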
@@ -129,6 +129,17 @@ def get_model(
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
+
+    use_medusa = None
+    if "medusa_num_heads" in config_dict:
+        use_medusa = model_id
+        medusa_config = config_dict
+        model_id = config_dict["base_model_name_or_path"]
+        revision = "main"
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+
     model_type = config_dict["model_type"]
 
     if model_type == "gpt_bigcode":
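The branch above keys off fields in the Medusa adapter's own config.json rather than the base model's. A representative shape, with assumed example values not taken from the commit:

```python
# Assumed shape of a Medusa adapter's config.json. "medusa_num_heads" is the
# marker get_model() checks for, and "base_model_name_or_path" is what
# model_id gets rewritten to before the base config is fetched again.
medusa_config_dict = {
    "medusa_num_heads": 4,                             # speculative heads
    "medusa_num_layers": 1,                            # ResBlocks per head
    "base_model_name_or_path": "lmsys/vicuna-7b-v1.3"  # parent model (example)
}
```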
@@ -204,6 +215,7 @@ def get_model(
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                use_medusa=use_medusa
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -28,6 +28,7 @@ class FlashLlama(FlashCausalLM):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
|
|||||||
weights._set_gptq_params(model_id)
|
weights._set_gptq_params(model_id)
|
||||||
|
|
||||||
model = FlashLlamaForCausalLM(config, weights)
|
model = FlashLlamaForCausalLM(config, weights)
|
||||||
|
if use_medusa:
|
||||||
|
from text_generation_server.utils.medusa import MedusaModel
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
import json
|
||||||
|
medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
|
||||||
|
with open(medusa_config, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
|
||||||
|
medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
|
||||||
|
weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
|
||||||
|
model.lm_head = MedusaModel(config, weights)
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(FlashLlama, self).__init__(
|
super(FlashLlama, self).__init__(
|
||||||
|
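One subtlety here: the .pt checkpoint is assumed to have already been converted to safetensors by the download_weights hunk above, and FastLinear.load in the new medusa.py (below) resolves tensors by prefixed names. The key scheme is inferred from those prefixes, not stated in the commit; a quick way to sanity-check a converted file:

```python
# Hedged sanity check, not part of the commit: list the tensors in a converted
# Medusa head file. Given the prefixes in utils/medusa.py, expected keys look
# like "0.0.linear.weight" and "0.0.linear.bias" (head 0, ResBlock 0) and
# "0.1.weight" (head 0's final projection when medusa_num_layers == 1).
from safetensors import safe_open

with safe_open("medusa_lm_head.safetensors", framework="pt") as f:
    for name in sorted(f.keys()):
        print(name, f.get_slice(name).get_shape())
```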
new file: server/text_generation_server/utils/medusa.py (38 lines)
@@ -0,0 +1,38 @@
+import torch
+from text_generation_server.utils.layers import TensorParallelHead, FastLinear
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+    def __init__(
+        self,
+        config,
+        weights
+    ):
+        super().__init__()
+        self.heads = torch.nn.ModuleList(
+            [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
+        )
+
+
+class MedusaHead(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList([ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])])
+        n = len(self.blocks)
+        self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.out(x)
+        return x
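As the commit message says, this is tmp work: MedusaModel defines no forward() even though flash_llama.py assigns it to model.lm_head, and TensorParallelHead is imported but never used, hinting the base head was meant to be folded in later. A plausible completion, purely a sketch and not part of the diff, would return the base head's logits alongside one set of speculative logits per Medusa head:

```python
# Assumed completion sketch (not in the commit): wire the base lm_head back in
# and emit speculative logits, one distribution per Medusa head.
import torch


class MedusaModelSketch(torch.nn.Module):
    def __init__(self, heads: torch.nn.ModuleList, lm_head: torch.nn.Module):
        super().__init__()
        self.heads = heads        # MedusaHead modules, as loaded above
        self.lm_head = lm_head    # the base model's original output head

    def forward(self, x: torch.Tensor):
        logits = self.lm_head(x)  # (batch, vocab_size): the "real" next token
        # (batch, num_heads, vocab_size): candidates for positions +2, +3, ...
        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
        return logits, speculative_logits
```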