Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Remove traces of use_medusa

This commit is contained in:
parent 3397b26341
commit de11fc064a
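Note: the change is mechanical. In every model constructor the `use_medusa` parameter becomes `speculator`, and every `config.use_medusa` assignment becomes `config.speculator`; behavior is unchanged. A minimal sketch of the pattern, assuming an illustrative `ExampleModel` with a `SimpleNamespace` standing in for the real `AutoConfig` object (neither is code from this repository):

from types import SimpleNamespace
from typing import Optional


class ExampleModel:
    """Illustrative stand-in for the model classes touched by this commit."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,  # was: use_medusa
        trust_remote_code: bool = False,
    ):
        # Stand-in for the AutoConfig object the real constructors build.
        config = SimpleNamespace()
        config.quantize = quantize
        # Renamed from config.use_medusa by this commit.
        config.speculator = speculator
        self.config = config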
@@ -69,10 +69,10 @@ class MedusaHeadV1(nn.Module):
         from safetensors import safe_open
         import json

-        use_medusa = config.use_medusa
+        speculator = config.speculator

-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")

         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
@@ -108,10 +108,10 @@ class MedusaHeadV2(nn.Module):
         from safetensors import safe_open
         import json

-        use_medusa = config.use_medusa
+        speculator = config.speculator

-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")

         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
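Note: both Medusa head loaders above resolve the same two files relative to the speculator path. A minimal sketch of that step, assuming `speculator` names a local directory that already contains the files (the helper name is hypothetical):

import json
from pathlib import Path


def resolve_medusa_files(speculator: str):
    """Resolve the Medusa config and weights the way both hunks above do."""
    medusa_config = Path(speculator) / "config.json"
    weights_file = Path(speculator) / "medusa_lm_head.safetensors"
    with open(medusa_config, "r") as f:
        config = json.load(f)
    return config, str(weights_file)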
@@ -42,7 +42,7 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -71,7 +71,7 @@ class BLOOMSharded(CausalLM):
         )
         config.pad_token_id = 3
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -683,9 +683,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.vision_config.quantize = config.quantize
-        config.vision_config.use_medusa = config.use_medusa
+        config.vision_config.speculator = config.speculator
         config.text_config.quantize = config.quantize
-        config.text_config.use_medusa = config.use_medusa
+        config.text_config.speculator = config.speculator

         vision_config = config.vision_config
         self.text_model = load_text_model(
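Note: multimodal models copy the setting into each sub-config so both towers see it. A minimal sketch of that propagation, pulled out as a standalone helper (the function name is hypothetical):

def propagate_speculator(config):
    """Copy quantize/speculator into the sub-configs, as the hunk above does."""
    for sub_config in (config.vision_config, config.text_config):
        sub_config.quantize = config.quantize
        sub_config.speculator = config.speculator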
@@ -135,7 +135,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         self.vocab_size = config.text_config.vocab_size
         self.config = config
         config.text_config.quantize = config.quantize
-        config.text_config.use_medusa = config.use_medusa
+        config.text_config.speculator = config.speculator
         self.language_model = load_text_model(
             prefix="language_model" if not prefix else f"{prefix}.language_model",
             config=config.text_config,
@@ -24,7 +24,7 @@ class FlashCohere(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -49,7 +49,7 @@ class FlashCohere(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
@@ -26,7 +26,7 @@ class FlashDbrx(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -74,7 +74,7 @@ class FlashDbrx(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
@@ -25,7 +25,7 @@ class FlashGemma(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -50,7 +50,7 @@ class FlashGemma(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
@@ -15,7 +15,7 @@ class FlashMixtral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -25,7 +25,7 @@ class FlashMixtral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
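Note: thin wrappers such as FlashMixtral (and Idefics2 and LlavaNext further down) do not consume the value; they only forward the renamed keyword to their base class. A minimal sketch of that delegation, with hypothetical class names:

class Base:
    def __init__(self, model_id: str, speculator=None, **kwargs):
        self.speculator = speculator


class Wrapper(Base):
    def __init__(self, model_id: str, speculator=None, **kwargs):
        # The wrapper adds nothing; it forwards the renamed keyword verbatim.
        super().__init__(model_id=model_id, speculator=speculator, **kwargs)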
@@ -25,7 +25,7 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,7 +51,7 @@ class FlashNeoXSharded(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -25,7 +25,7 @@ class FlashPhi(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -48,7 +48,7 @@ class FlashPhi(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)

@@ -58,7 +58,7 @@ class FlashPhi(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)

         model = FlashPhiForCausalLM(config, weights)
-        if use_medusa:
+        if speculator:
             from text_generation_server.utils.medusa import MedusaModel
             from huggingface_hub import hf_hub_download
             import json
@@ -66,19 +66,19 @@ class FlashPhi(FlashCausalLM):
             from pathlib import Path

             is_local_model = (
-                Path(use_medusa).exists() and Path(use_medusa).is_dir()
+                Path(speculator).exists() and Path(speculator).is_dir()
             ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

             if not is_local_model:
                 medusa_config = hf_hub_download(
-                    use_medusa, revision=revision, filename="config.json"
+                    speculator, revision=revision, filename="config.json"
                 )
                 medusa_head = hf_hub_download(
-                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
+                    speculator, revision=revision, filename="medusa_lm_head.pt"
                 )
             else:
-                medusa_config = str(Path(use_medusa) / "config.json")
-                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+                medusa_config = str(Path(speculator) / "config.json")
+                medusa_head = str(Path(speculator) / "medusa_lm_head.pt")

             with open(medusa_config, "r") as f:
                 config = json.load(f)
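Note: the FlashPhi hunks above are where the speculator value is consumed directly, resolving files either locally or from the Hub. Pulled out as a standalone helper, the logic looks roughly like this (a sketch; `hf_hub_download` is the real huggingface_hub API, while the function name is hypothetical):

import os
from pathlib import Path
from typing import Optional, Tuple

from huggingface_hub import hf_hub_download


def fetch_speculator_files(
    speculator: str, revision: Optional[str] = None
) -> Tuple[str, str]:
    """Resolve the Medusa config and head, locally or from the Hub."""
    is_local_model = (
        Path(speculator).exists() and Path(speculator).is_dir()
    ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

    if not is_local_model:
        # Treat `speculator` as a Hub repo id and download both files.
        medusa_config = hf_hub_download(
            speculator, revision=revision, filename="config.json"
        )
        medusa_head = hf_hub_download(
            speculator, revision=revision, filename="medusa_lm_head.pt"
        )
    else:
        # Treat `speculator` as a local directory.
        medusa_config = str(Path(speculator) / "config.json")
        medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
    return medusa_config, medusa_head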
@@ -30,7 +30,7 @@ class FlashQwen2(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -53,7 +53,7 @@ class FlashQwen2(BaseFlashMistral):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         # Set context windows
         if config.sliding_window is not None:
@@ -26,7 +26,7 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -66,7 +66,7 @@ class FlashRWSharded(FlashCausalLM):
         )

         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id, revision)
@@ -29,7 +29,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -57,7 +57,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             trust_remote_code=True,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         config.transpose = config.architectures[0].startswith("GPT2")

         torch.distributed.barrier(group=self.process_group)
@@ -29,7 +29,7 @@ class FlashStarcoder2(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -52,7 +52,7 @@ class FlashStarcoder2(BaseFlashMistral):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         # Set context windows
         if config.sliding_window is not None:
@@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -195,7 +195,7 @@ class GalacticaSharded(CausalLM):
         )
         config.quantize = quantize
         tokenizer.pad_token_id = config.pad_token_id
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -24,7 +24,7 @@ class GPTNeoxSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,7 +51,7 @@ class GPTNeoxSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
@@ -31,7 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -52,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         config.vision_config.quantize = quantize

         tokenizer = LlamaTokenizerFast.from_pretrained(
@@ -18,7 +18,7 @@ class Idefics2(VlmCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -35,7 +35,7 @@ class Idefics2(VlmCausalLM):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -18,7 +18,7 @@ class LlavaNext(VlmCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -30,7 +30,7 @@ class LlavaNext(VlmCausalLM):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
-            use_medusa=use_medusa,
+            speculator=speculator,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -408,7 +408,7 @@ class Mamba(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -445,7 +445,7 @@ class Mamba(Model):
         tokenizer.pad_token = tokenizer.eos_token

         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -43,7 +43,7 @@ class MPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -76,7 +76,7 @@ class MPTSharded(CausalLM):
             config = json.load(f)
         config = PretrainedConfig(**config)
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         torch.distributed.barrier(group=self.process_group)
@@ -22,7 +22,7 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -48,7 +48,7 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         tokenizer.pad_token_id = config.pad_token_id

         torch.distributed.barrier(group=self.process_group)
@@ -22,7 +22,7 @@ class Phi(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -53,7 +53,7 @@ class Phi(CausalLM):
         tokenizer.pad_token = tokenizer.eos_token

         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -12,11 +12,11 @@ class RW(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        if use_medusa:
+        if speculator:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")

         if torch.cuda.is_available():
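Note: RW is served through AutoModel, which has no Medusa support, so its constructor rejects any speculator outright rather than silently ignoring it. A minimal sketch of that fail-fast guard (the helper name is hypothetical):

from typing import Optional


def reject_speculator(speculator: Optional[str]) -> None:
    # Fail fast, as the RW hunk above does, instead of silently ignoring
    # a speculator the AutoModel path cannot honor.
    if speculator:
        raise RuntimeError("Medusa decoding is not enabled for AutoModel")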
@@ -19,7 +19,7 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -25,7 +25,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        use_medusa: Optional[str] = None,
+        speculator: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -43,7 +43,7 @@ class T5Sharded(Seq2SeqLM):
             trust_remote_code=trust_remote_code,
        )
         config.quantize = quantize
-        config.use_medusa = use_medusa
+        config.speculator = speculator

         tokenizer = AutoTokenizer.from_pretrained(
             model_id,