mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
Remove traces of use_medusa
.
This commit is contained in:
parent
3397b26341
commit
de11fc064a
@ -69,10 +69,10 @@ class MedusaHeadV1(nn.Module):
|
|||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
import json
|
import json
|
||||||
|
|
||||||
use_medusa = config.use_medusa
|
speculator = config.speculator
|
||||||
|
|
||||||
medusa_config = str(Path(use_medusa) / "config.json")
|
medusa_config = str(Path(speculator) / "config.json")
|
||||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
filename = str(Path(speculator) / "medusa_lm_head.safetensors")
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
with open(medusa_config, "r") as f:
|
||||||
medusa_config = json.load(f)
|
medusa_config = json.load(f)
|
||||||
@ -108,10 +108,10 @@ class MedusaHeadV2(nn.Module):
|
|||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
import json
|
import json
|
||||||
|
|
||||||
use_medusa = config.use_medusa
|
speculator = config.speculator
|
||||||
|
|
||||||
medusa_config = str(Path(use_medusa) / "config.json")
|
medusa_config = str(Path(speculator) / "config.json")
|
||||||
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
|
filename = str(Path(speculator) / "medusa_lm_head.safetensors")
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
with open(medusa_config, "r") as f:
|
||||||
medusa_config = json.load(f)
|
medusa_config = json.load(f)
|
||||||
|
@ -42,7 +42,7 @@ class BLOOMSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -71,7 +71,7 @@ class BLOOMSharded(CausalLM):
|
|||||||
)
|
)
|
||||||
config.pad_token_id = 3
|
config.pad_token_id = 3
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
@ -683,9 +683,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config.vision_config.quantize = config.quantize
|
config.vision_config.quantize = config.quantize
|
||||||
config.vision_config.use_medusa = config.use_medusa
|
config.vision_config.speculator = config.speculator
|
||||||
config.text_config.quantize = config.quantize
|
config.text_config.quantize = config.quantize
|
||||||
config.text_config.use_medusa = config.use_medusa
|
config.text_config.speculator = config.speculator
|
||||||
|
|
||||||
vision_config = config.vision_config
|
vision_config = config.vision_config
|
||||||
self.text_model = load_text_model(
|
self.text_model = load_text_model(
|
||||||
|
@ -135,7 +135,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
|||||||
self.vocab_size = config.text_config.vocab_size
|
self.vocab_size = config.text_config.vocab_size
|
||||||
self.config = config
|
self.config = config
|
||||||
config.text_config.quantize = config.quantize
|
config.text_config.quantize = config.quantize
|
||||||
config.text_config.use_medusa = config.use_medusa
|
config.text_config.speculator = config.speculator
|
||||||
self.language_model = load_text_model(
|
self.language_model = load_text_model(
|
||||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||||
config=config.text_config,
|
config=config.text_config,
|
||||||
|
@ -24,7 +24,7 @@ class FlashCohere(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -49,7 +49,7 @@ class FlashCohere(FlashCausalLM):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class FlashDbrx(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -74,7 +74,7 @@ class FlashDbrx(FlashCausalLM):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ class FlashGemma(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -50,7 +50,7 @@ class FlashGemma(FlashCausalLM):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ class FlashMixtral(BaseFlashMistral):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -25,7 +25,7 @@ class FlashMixtral(BaseFlashMistral):
|
|||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -25,7 +25,7 @@ class FlashNeoXSharded(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -51,7 +51,7 @@ class FlashNeoXSharded(FlashCausalLM):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
@ -25,7 +25,7 @@ class FlashPhi(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -48,7 +48,7 @@ class FlashPhi(FlashCausalLM):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ class FlashPhi(FlashCausalLM):
|
|||||||
weights._set_gptq_params(model_id, revision)
|
weights._set_gptq_params(model_id, revision)
|
||||||
|
|
||||||
model = FlashPhiForCausalLM(config, weights)
|
model = FlashPhiForCausalLM(config, weights)
|
||||||
if use_medusa:
|
if speculator:
|
||||||
from text_generation_server.utils.medusa import MedusaModel
|
from text_generation_server.utils.medusa import MedusaModel
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import json
|
import json
|
||||||
@ -66,19 +66,19 @@ class FlashPhi(FlashCausalLM):
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
is_local_model = (
|
is_local_model = (
|
||||||
Path(use_medusa).exists() and Path(use_medusa).is_dir()
|
Path(speculator).exists() and Path(speculator).is_dir()
|
||||||
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
|
||||||
|
|
||||||
if not is_local_model:
|
if not is_local_model:
|
||||||
medusa_config = hf_hub_download(
|
medusa_config = hf_hub_download(
|
||||||
use_medusa, revision=revision, filename="config.json"
|
speculator, revision=revision, filename="config.json"
|
||||||
)
|
)
|
||||||
medusa_head = hf_hub_download(
|
medusa_head = hf_hub_download(
|
||||||
use_medusa, revision=revision, filename="medusa_lm_head.pt"
|
speculator, revision=revision, filename="medusa_lm_head.pt"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
medusa_config = str(Path(use_medusa) / "config.json")
|
medusa_config = str(Path(speculator) / "config.json")
|
||||||
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
|
medusa_head = str(Path(speculator) / "medusa_lm_head.pt")
|
||||||
|
|
||||||
with open(medusa_config, "r") as f:
|
with open(medusa_config, "r") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
|
@ -30,7 +30,7 @@ class FlashQwen2(BaseFlashMistral):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -53,7 +53,7 @@ class FlashQwen2(BaseFlashMistral):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
# Set context windows
|
# Set context windows
|
||||||
if config.sliding_window is not None:
|
if config.sliding_window is not None:
|
||||||
|
@ -26,7 +26,7 @@ class FlashRWSharded(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -66,7 +66,7 @@ class FlashRWSharded(FlashCausalLM):
|
|||||||
)
|
)
|
||||||
|
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
if config.quantize == "gptq":
|
if config.quantize == "gptq":
|
||||||
weights._set_gptq_params(model_id, revision)
|
weights._set_gptq_params(model_id, revision)
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ class FlashSantacoderSharded(FlashCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -57,7 +57,7 @@ class FlashSantacoderSharded(FlashCausalLM):
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
config.transpose = config.architectures[0].startswith("GPT2")
|
config.transpose = config.architectures[0].startswith("GPT2")
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
@ -29,7 +29,7 @@ class FlashStarcoder2(BaseFlashMistral):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -52,7 +52,7 @@ class FlashStarcoder2(BaseFlashMistral):
|
|||||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
# Set context windows
|
# Set context windows
|
||||||
if config.sliding_window is not None:
|
if config.sliding_window is not None:
|
||||||
|
@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -195,7 +195,7 @@ class GalacticaSharded(CausalLM):
|
|||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
tokenizer.pad_token_id = config.pad_token_id
|
tokenizer.pad_token_id = config.pad_token_id
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
@ -24,7 +24,7 @@ class GPTNeoxSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -51,7 +51,7 @@ class GPTNeoxSharded(CausalLM):
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
|
@ -31,7 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -52,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
config.vision_config.quantize = quantize
|
config.vision_config.quantize = quantize
|
||||||
|
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||||
|
@ -18,7 +18,7 @@ class Idefics2(VlmCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -35,7 +35,7 @@ class Idefics2(VlmCausalLM):
|
|||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -18,7 +18,7 @@ class LlavaNext(VlmCausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -30,7 +30,7 @@ class LlavaNext(VlmCausalLM):
|
|||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
quantize=quantize,
|
quantize=quantize,
|
||||||
use_medusa=use_medusa,
|
speculator=speculator,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
@ -408,7 +408,7 @@ class Mamba(Model):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -445,7 +445,7 @@ class Mamba(Model):
|
|||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
@ -43,7 +43,7 @@ class MPTSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -76,7 +76,7 @@ class MPTSharded(CausalLM):
|
|||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
config = PretrainedConfig(**config)
|
config = PretrainedConfig(**config)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ class OPTSharded(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -48,7 +48,7 @@ class OPTSharded(CausalLM):
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
tokenizer.pad_token_id = config.pad_token_id
|
tokenizer.pad_token_id = config.pad_token_id
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
|
@ -22,7 +22,7 @@ class Phi(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -53,7 +53,7 @@ class Phi(CausalLM):
|
|||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||||
|
@ -12,11 +12,11 @@ class RW(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
if use_medusa:
|
if speculator:
|
||||||
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
|
raise RuntimeError("Medusa decoding is not enabled for AutoModel")
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -19,7 +19,7 @@ class SantaCoder(CausalLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
|
@ -25,7 +25,7 @@ class T5Sharded(Seq2SeqLM):
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
revision: Optional[str] = None,
|
revision: Optional[str] = None,
|
||||||
quantize: Optional[str] = None,
|
quantize: Optional[str] = None,
|
||||||
use_medusa: Optional[str] = None,
|
speculator: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
):
|
):
|
||||||
@ -43,7 +43,7 @@ class T5Sharded(Seq2SeqLM):
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
config.quantize = quantize
|
config.quantize = quantize
|
||||||
config.use_medusa = use_medusa
|
config.speculator = speculator
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
|
Loading…
Reference in New Issue
Block a user