diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index b741a84c..fdbfeafc 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -116,7 +116,7 @@ def download_weights(
         logger.info("Files are already present on the host. " "Skipping download.")
         return
     # Local files not found
-    except (utils.LocalEntryNotFoundError, FileNotFoundError):
+    except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
         pass
 
     is_local_model = (Path(model_id).exists() and Path(model_id).is_dir()) or os.getenv(
@@ -137,6 +137,29 @@ def download_weights(
         except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
             pass
 
+        try:
+            import json
+            medusa_head = hf_hub_download(model_id, revision=revision, filename="medusa_lm_head.pt")
+            if auto_convert:
+                medusa_sf = Path(medusa_head[:-len(".pt")] + ".safetensors")
+                if not medusa_sf.exists():
+                    utils.convert_files([Path(medusa_head)], [medusa_sf], [])
+            medusa_config = hf_hub_download(model_id, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+
+            model_id = config["base_model_name_or_path"]
+            revision = "main"
+            try:
+                utils.weight_files(model_id, revision, extension)
+                logger.info(f"Files for parent {model_id} are already present on the host. " "Skipping download.")
+                return
+            # Local files not found
+            except (utils.LocalEntryNotFoundError, FileNotFoundError, utils.EntryNotFoundError):
+                pass
+        except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
+            pass
+
     # Try to download weights from the hub
     try:
         filenames = utils.weight_hub_files(model_id, revision, extension)
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 5b1b5715..44f1fddd 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -129,6 +129,17 @@
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
+
+    use_medusa = None
+    if "medusa_num_heads" in config_dict:
+        use_medusa = model_id
+        medusa_config = config_dict
+        model_id = config_dict["base_model_name_or_path"]
+        revision = "main"
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+
     model_type = config_dict["model_type"]
 
     if model_type == "gpt_bigcode":
@@ -204,6 +215,7 @@
                 quantize=quantize,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
+                use_medusa=use_medusa
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index d2ed0b15..42a82a1f 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -28,6 +28,7 @@ class FlashLlama(FlashCausalLM):
         quantize: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
+        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -66,6 +67,17 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id)
 
         model = FlashLlamaForCausalLM(config, weights)
+        if use_medusa:
+            from text_generation_server.utils.medusa import MedusaModel
+            from huggingface_hub import hf_hub_download
+            import json
+            medusa_config = hf_hub_download(use_medusa, revision=revision, filename="config.json")
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+            medusa_head = hf_hub_download(use_medusa, revision=revision, filename="medusa_lm_head.pt")
+            medusa_sf = medusa_head[:-len(".pt")] + ".safetensors"
+            weights = Weights([medusa_sf], device, dtype, process_group=self.process_group)
+            model.lm_head = MedusaModel(config, weights, model.lm_head)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py
new file mode 100644
index 00000000..a6230016
--- /dev/null
+++ b/server/text_generation_server/utils/medusa.py
@@ -0,0 +1,46 @@
+import torch
+from text_generation_server.utils.layers import TensorParallelHead, FastLinear
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.linear = FastLinear.load(config, prefix=f"{prefix}.linear", weights=weights, bias=True)
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        # Residual connection around a SiLU-activated linear projection
+        return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+    def __init__(self, config, weights, lm_head):
+        super().__init__()
+        self.heads = torch.nn.ModuleList(
+            [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config["medusa_num_heads"])]
+        )
+        # Keep the original lm_head so the base logits are still produced
+        self.lm_head = lm_head
+
+    def forward(self, x):
+        logits = self.lm_head(x)
+        # One set of speculative logits per Medusa head, stacked along dim=1
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return logits, speculative_logits
+
+
+class MedusaHead(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList(
+            [ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config["medusa_num_layers"])]
+        )
+        n = len(self.blocks)
+        # The final vocab projection lives under the next integer index after the ResBlocks
+        self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.out(x)
+        return x
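
For reference, the loading code above implies a particular layout for a Medusa repository: a `config.json` carrying `medusa_num_heads`, `medusa_num_layers`, and `base_model_name_or_path`, and a `medusa_lm_head.pt` checkpoint (converted to safetensors) whose tensors follow the prefixes used by `MedusaHead`: `{head}.{block}.linear.weight`/`.bias` for each `ResBlock`, and `{head}.{medusa_num_layers}.weight` for the final projection. The sketch below is illustrative only and not part of this diff; the sizes and the base model id are placeholders.

```python
# Illustrative only: builds a toy Medusa checkpoint with the key names the
# loader expects. All sizes and the base model id are placeholders.
import json

import torch
from safetensors.torch import save_file

hidden_size, vocab_size = 4096, 32000  # placeholders
num_heads, num_layers = 4, 1           # medusa_num_heads, medusa_num_layers

config = {
    "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",  # placeholder
    "medusa_num_heads": num_heads,
    "medusa_num_layers": num_layers,
}
with open("config.json", "w") as f:
    json.dump(config, f)

tensors = {}
for i in range(num_heads):
    for j in range(num_layers):
        # ResBlock j of head i -> loaded via prefix f"{i}.{j}.linear" with bias=True
        tensors[f"{i}.{j}.linear.weight"] = torch.zeros(hidden_size, hidden_size)
        tensors[f"{i}.{j}.linear.bias"] = torch.zeros(hidden_size)
    # Final projection of head i -> loaded via prefix f"{i}.{num_layers}" with bias=False
    tensors[f"{i}.{num_layers}.weight"] = torch.zeros(vocab_size, hidden_size)

save_file(tensors, "medusa_lm_head.safetensors")
```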
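And a minimal, self-contained sketch of the math the new module computes, with `torch.nn.Linear` standing in for TGI's `FastLinear` and toy dimensions (again, not part of the diff): each head applies `medusa_num_layers` residual SiLU blocks followed by a vocab projection, and the per-head outputs are stacked along `dim=1` next to the base `lm_head` logits.

```python
# Toy re-implementation of the Medusa head math; dimensions are illustrative.
import torch

hidden, vocab, num_heads, num_layers, num_tokens = 16, 100, 4, 1, 3

# Stand-ins for the real modules: the base lm_head plus, per head,
# a list of residual blocks and a final vocab projection.
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
heads = [
    ([torch.nn.Linear(hidden, hidden, bias=True) for _ in range(num_layers)],
     torch.nn.Linear(hidden, vocab, bias=False))
    for _ in range(num_heads)
]

def run_head(blocks, out, x):
    for linear in blocks:
        x = x + torch.nn.functional.silu(linear(x))  # ResBlock.forward
    return out(x)

x = torch.randn(num_tokens, hidden)  # flash models work on (num_tokens, hidden)
logits = lm_head(x)
speculative_logits = torch.stack([run_head(b, o, x) for b, o in heads], dim=1)
print(logits.shape, speculative_logits.shape)  # torch.Size([3, 100]) torch.Size([3, 4, 100])
```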