From ed95f1982defd483be4e97c9ab8e535977e070dc Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 23 Feb 2024 21:13:34 +0000
Subject: [PATCH] Fix gemma + medusa.

---
 .../models/flash_gemma.py | 30 ------------------------------
 1 file changed, 30 deletions(-)

diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index beb12371..8cfb6631 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -60,36 +60,6 @@ class FlashGemma(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         model = FlashGemmaForCausalLM(config, weights)
-        if use_medusa:
-            from text_generation_server.utils.medusa import MedusaModel
-            from huggingface_hub import hf_hub_download
-            import json
-            import os
-            from pathlib import Path
-
-            is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
-            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
-
-            if not is_local_model:
-                medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
-                )
-                medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
-                )
-            else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
-            weights = Weights(
-                [medusa_sf], device, dtype, process_group=self.process_group
-            )
-            lm_head = model.lm_head
-            model.lm_head = MedusaModel(config, weights, lm_head)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
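
Reviewer note: for context, the block removed above resolved the Medusa head files and then wrapped model.lm_head in a MedusaModel built from them. Below is a minimal, standalone sketch of that file-resolution logic, kept here only as a reference for what is being deleted; the helper name resolve_medusa_files is invented for this note, while the Hub calls, filenames, and environment check mirror the removed lines.

    import json
    import os
    from pathlib import Path

    from huggingface_hub import hf_hub_download


    def resolve_medusa_files(use_medusa, revision=None):
        # `use_medusa` is either a local directory or a Hub repo id.
        is_local_model = (
            Path(use_medusa).exists() and Path(use_medusa).is_dir()
        ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

        if not is_local_model:
            medusa_config = hf_hub_download(
                use_medusa, revision=revision, filename="config.json"
            )
            medusa_head = hf_hub_download(
                use_medusa, revision=revision, filename="medusa_lm_head.pt"
            )
        else:
            medusa_config = str(Path(use_medusa) / "config.json")
            medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")

        # The removed block expected a converted .safetensors file
        # sitting next to the .pt head checkpoint.
        medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
        with open(medusa_config, "r") as f:
            config = json.load(f)
        return config, medusa_sf

In the removed code, the returned config and safetensors path were fed into Weights(...) and MedusaModel(config, weights, lm_head) to replace the model's lm_head.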