diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py
index 4ac86978..b7f2aaf6 100644
--- a/server/text_generation_server/layers/medusa.py
+++ b/server/text_generation_server/layers/medusa.py
@@ -69,10 +69,10 @@ class MedusaHeadV1(nn.Module):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
@@ -108,10 +108,10 @@ class MedusaHeadV2(nn.Module):
         from safetensors import safe_open
         import json
 
-        use_medusa = config.use_medusa
+        speculator = config.speculator
 
-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")
 
         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py
index 67129ec3..1e3dd10c 100644
--- a/server/text_generation_server/models/bloom.py
+++ b/server/text_generation_server/models/bloom.py
@@ -42,7 +42,7 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -71,7 +71,7 @@ class BLOOMSharded(CausalLM):
         )
         config.pad_token_id = 3
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index 935f049b..51fd7c02 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -683,9 +683,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
-        config.vision_config.use_medusa = config.use_medusa
+        config.vision_config.speculator = config.speculator
         config.text_config.quantize = config.quantize
-        config.text_config.use_medusa = config.use_medusa
+        config.text_config.speculator = config.speculator
 
         vision_config = config.vision_config
         self.text_model = load_text_model(
diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index a049f756..de9673aa 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -135,7 +135,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         self.vocab_size = config.text_config.vocab_size
         self.config = config
         config.text_config.quantize = config.quantize
-        config.text_config.use_medusa = config.use_medusa
+        config.text_config.speculator = config.speculator
         self.language_model = load_text_model(
             prefix="language_model" if not prefix else f"{prefix}.language_model",
             config=config.text_config,
diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py
index f85c7722..b907ee08 100644
--- a/server/text_generation_server/models/flash_cohere.py
+++ b/server/text_generation_server/models/flash_cohere.py
@@ -24,7 +24,7 @@ class FlashCohere(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -49,7 +49,7 @@ class FlashCohere(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
 
diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py
index 367d3db0..d5eb1a6e 100644
--- a/server/text_generation_server/models/flash_dbrx.py
+++ b/server/text_generation_server/models/flash_dbrx.py
@@ -26,7 +26,7 @@ class FlashDbrx(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -74,7 +74,7 @@ class FlashDbrx(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
 
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index 7259b820..9c00a056 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -25,7 +25,7 @@ class FlashGemma(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -50,7 +50,7 @@ class FlashGemma(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
 
diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py
index 2ee35e82..587d423f 100644
--- a/server/text_generation_server/models/flash_mixtral.py
+++ b/server/text_generation_server/models/flash_mixtral.py
@@ -15,7 +15,7 @@ class FlashMixtral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -25,7 +25,7 @@ class FlashMixtral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 1119bdae..adefaeb2 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -25,7 +25,7 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,7 +51,7 @@ class FlashNeoXSharded(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py
index cb55f9e6..32b573a9 100644
--- a/server/text_generation_server/models/flash_phi.py
+++ b/server/text_generation_server/models/flash_phi.py
@@ -25,7 +25,7 @@ class FlashPhi(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -48,7 +48,7 @@ class FlashPhi(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
 
@@ -58,7 +58,7 @@ class FlashPhi(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         model = FlashPhiForCausalLM(config, weights)
-        if use_medusa:
+        if speculator:
             from text_generation_server.utils.medusa import MedusaModel
             from huggingface_hub import hf_hub_download
             import json
@@ -66,19 +66,19 @@ class FlashPhi(FlashCausalLM):
             from pathlib import Path
 
             is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
+                Path(speculator).exists() and Path(speculator).is_dir()
             ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
 
             if not is_local_model:
                 medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
+                    speculator, revision=revision, filename="config.json"
                 )
                 medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
+                    speculator, revision=revision, filename="medusa_lm_head.pt"
                 )
             else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+                medusa_config = str(Path(speculator) / "config.json")
+                medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
 
             with open(medusa_config, "r") as f:
                 config = json.load(f)
diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py
index cb3cf6b0..59064b30 100644
--- a/server/text_generation_server/models/flash_qwen2.py
+++ b/server/text_generation_server/models/flash_qwen2.py
@@ -30,7 +30,7 @@ class FlashQwen2(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -53,7 +53,7 @@ class FlashQwen2(BaseFlashMistral):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
        config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         # Set context windows
         if config.sliding_window is not None:
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index 33298e1a..e6350611 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -26,7 +26,7 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -66,7 +66,7 @@ class FlashRWSharded(FlashCausalLM):
         )
 
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id, revision)
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 66698a3a..2ad36b93 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -29,7 +29,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -57,7 +57,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             trust_remote_code=True,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         config.transpose = config.architectures[0].startswith("GPT2")
 
         torch.distributed.barrier(group=self.process_group)
diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py
index 68e726d8..dc5d49be 100644
--- a/server/text_generation_server/models/flash_starcoder2.py
+++ b/server/text_generation_server/models/flash_starcoder2.py
@@ -29,7 +29,7 @@ class FlashStarcoder2(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -52,7 +52,7 @@ class FlashStarcoder2(BaseFlashMistral):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         # Set context windows
         if config.sliding_window is not None:
diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py
index a46f86be..4656fd45 100644
--- a/server/text_generation_server/models/galactica.py
+++ b/server/text_generation_server/models/galactica.py
@@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -195,7 +195,7 @@ class GalacticaSharded(CausalLM):
         )
         config.quantize = quantize
         tokenizer.pad_token_id = config.pad_token_id
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py
index 1c4cfe7d..c0e1adf2 100644
--- a/server/text_generation_server/models/gpt_neox.py
+++ b/server/text_generation_server/models/gpt_neox.py
@@ -24,7 +24,7 @@ class GPTNeoxSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,7 +51,7 @@ class GPTNeoxSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py
index 30bf4aa6..c1fe03e4 100644
--- a/server/text_generation_server/models/idefics.py
+++ b/server/text_generation_server/models/idefics.py
@@ -31,7 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -52,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         config.vision_config.quantize = quantize
 
         tokenizer = LlamaTokenizerFast.from_pretrained(
diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py
index e831af89..314c0500 100644
--- a/server/text_generation_server/models/idefics2.py
+++ b/server/text_generation_server/models/idefics2.py
@@ -18,7 +18,7 @@ class Idefics2(VlmCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -35,7 +35,7 @@ class Idefics2(VlmCausalLM):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py
index 3983bc85..effe8b91 100644
--- a/server/text_generation_server/models/llava_next.py
+++ b/server/text_generation_server/models/llava_next.py
@@ -18,7 +18,7 @@ class LlavaNext(VlmCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -30,7 +30,7 @@ class LlavaNext(VlmCausalLM):
            model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 0884317e..b28b744f 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -408,7 +408,7 @@ class Mamba(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -445,7 +445,7 @@ class Mamba(Model):
         tokenizer.pad_token = tokenizer.eos_token
 
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py
index 6b3f29a6..8d8b4909 100644
--- a/server/text_generation_server/models/mpt.py
+++ b/server/text_generation_server/models/mpt.py
@@ -43,7 +43,7 @@ class MPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -76,7 +76,7 @@ class MPTSharded(CausalLM):
             config = json.load(f)
         config = PretrainedConfig(**config)
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         torch.distributed.barrier(group=self.process_group)
 
diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py
index 703e5b58..5b84f4ff 100644
--- a/server/text_generation_server/models/opt.py
+++ b/server/text_generation_server/models/opt.py
@@ -22,7 +22,7 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -48,7 +48,7 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         tokenizer.pad_token_id = config.pad_token_id
 
         torch.distributed.barrier(group=self.process_group)
diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py
index cc4e2505..d68866c1 100644
--- a/server/text_generation_server/models/phi.py
+++ b/server/text_generation_server/models/phi.py
@@ -22,7 +22,7 @@ class Phi(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -53,7 +53,7 @@ class Phi(CausalLM):
         tokenizer.pad_token = tokenizer.eos_token
 
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py
index 92c93542..d4764ded 100644
--- a/server/text_generation_server/models/rw.py
+++ b/server/text_generation_server/models/rw.py
@@ -12,11 +12,11 @@ class RW(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        if use_medusa:
+        if speculator:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")
 
         if torch.cuda.is_available():
diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py
index 73c21cce..323e4324 100644
--- a/server/text_generation_server/models/santacoder.py
+++ b/server/text_generation_server/models/santacoder.py
@@ -19,7 +19,7 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py
index 3f3cb965..8e0735e5 100644
--- a/server/text_generation_server/models/t5.py
+++ b/server/text_generation_server/models/t5.py
@@ -25,7 +25,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -43,7 +43,7 @@ class T5Sharded(Seq2SeqLM):
             trust_remote_code=trust_remote_code,
        )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
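
For reviewers, a minimal sketch of the pattern the renamed field drives, based on the `medusa.py` hunks above: `config.speculator` now carries the path that `config.use_medusa` used to hold, pointing at a directory containing the Medusa head's `config.json` and `medusa_lm_head.safetensors`. This is an illustration only, not the repository's actual loader; the `SimpleNamespace` config object and the example path are hypothetical stand-ins for TGI's real config and weight machinery.

```python
import json
from pathlib import Path
from types import SimpleNamespace

from safetensors import safe_open


def load_speculator_head(config):
    """Mirror of the renamed loading logic in the MedusaHeadV1/V2 hunks."""
    # After this change, the speculator path lives on `config.speculator`
    # (previously `config.use_medusa`).
    speculator = config.speculator

    medusa_config = str(Path(speculator) / "config.json")
    filename = str(Path(speculator) / "medusa_lm_head.safetensors")

    with open(medusa_config, "r") as f:
        medusa_config = json.load(f)

    # safe_open memory-maps the safetensors file; tensors are fetched by name.
    with safe_open(filename, framework="pt") as f:
        weights = {name: f.get_tensor(name) for name in f.keys()}

    return medusa_config, weights


# Hypothetical local speculator directory; any directory holding the two
# files above would work.
config = SimpleNamespace(speculator="/data/medusa-vicuna-7b-v1.3")
# heads_config, head_weights = load_speculator_head(config)
```

Since `speculator` remains an `Optional[str]` defaulting to `None`, callers that never enable speculative decoding are unaffected; only the keyword name changes at every construction site touched by this diff.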