Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

commit ac419f5e46 (parent 2446f3ec32)

    Upgrade ALL the code.
@@ -115,16 +115,6 @@ def get_model(
         else:
             set_speculate(0)

-    if "facebook/galactica" in model_id:
-        return GalacticaSharded(
-            model_id,
-            revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
@@ -177,7 +167,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

-    if model_type == "gpt_bigcode":
+    if model_type in {"gpt_bigcode", "gpt2"}:
         if FLASH_ATTENTION:
             return FlashSantacoderSharded(
                 model_id,
@@ -311,9 +301,9 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(
@@ -324,6 +314,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -448,6 +439,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
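A side note on the first hunk: `set_speculate(0)` appears to disable speculative decoding whenever no Medusa weights and no explicit speculate value are configured. A rough, hedged sketch of what such a helper can look like (only the function name comes from the diff; the module-level storage is an assumption):

    # Hedged sketch: a module-level speculation setting matching the
    # set_speculate(0) call above; the real implementation may differ.
    from typing import Optional

    SPECULATE: Optional[int] = None

    def set_speculate(speculate: int) -> None:
        # Number of speculative tokens attempted per decoding step; 0 disables it.
        global SPECULATE
        SPECULATE = speculate

    def get_speculate() -> Optional[int]:
        return SPECULATE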
@@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM):
         )
         config.pad_token_id = 3
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

@@ -482,6 +482,7 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
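Every model constructor in the rest of this commit gains the same `use_medusa: Optional[str] = None` keyword and stores it on the config as `config.use_medusa`. The `Optional[str]` hint suggests the value is an identifier or path for the Medusa speculator weights rather than a boolean switch; the snippet below only illustrates that reading, with a made-up value:

    from transformers import PretrainedConfig

    # Illustration only: per the Optional[str] hints, use_medusa carries a
    # checkpoint identifier/path (or None to disable), not a True/False flag.
    config = PretrainedConfig()
    config.quantize = None
    config.use_medusa = "some-org/medusa-heads-for-this-model"  # hypothetical value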
@@ -36,7 +36,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )

 CUSTOM_KERNELS_ENABLED = False
@@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         super().__init__(config)
         self.transformer = BloomModel(config, weights)

-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="word_embeddings",
             weights=weights,

@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = FlashGemmaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
             weights=weights,
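From this point on, every `TensorParallelHead` used as an output projection becomes a `SpeculativeHead`, and call sites unpack a `(logits, speculative_logits)` pair instead of a single tensor. A minimal sketch of the shape of such a head, assuming it wraps the ordinary LM head plus an optional Medusa-style module (the class name, constructor and loading details are illustrative, not the exact upstream API):

    from typing import Optional, Tuple
    import torch

    class SpeculativeHeadSketch(torch.nn.Module):
        # Illustrative stand-in for SpeculativeHead: the regular LM head plus
        # an optional speculative (e.g. Medusa) projection.
        def __init__(self, lm_head: torch.nn.Module, medusa: Optional[torch.nn.Module] = None):
            super().__init__()
            self.lm_head = lm_head
            self.medusa = medusa

        def forward(
            self, hidden_states: torch.Tensor
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            logits = self.lm_head(hidden_states)
            # Speculative logits are only produced when Medusa weights were loaded.
            speculative_logits = (
                self.medusa(hidden_states) if self.medusa is not None else None
            )
            return logits, speculative_logits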
@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = FlashLlamaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,

@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )

@@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = MixtralModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -33,7 +33,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,
@@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         super().__init__(config)
         self.gpt_neox = FlashGPTNeoXModel(config, weights)

-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )


@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastLayerNorm,
 )
@@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = FlashPhiModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,
@@ -613,7 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):

         self.transformer = FlashRWModel(config, weights)

-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="lm_head", weights=weights
         )


@@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastLayerNorm,
     get_linear,
@@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         self.transformer = FlashSantacoderModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )

@@ -51,7 +51,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     PositionRotaryEmbedding,
     FastLinear,
 )
@@ -272,7 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = TensorParallelHead.load(
+        self.fc = SpeculativeHead.load(
             config=config, prefix="lm_head", weights=weights
         )
         self.additional_fc = FastLinear.load(
@@ -9,6 +9,7 @@ from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F

 from text_generation_server.utils.layers import (
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastRMSNorm,
     FastLinear,
@@ -205,14 +206,14 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = FastLinear.load(
-            config, f"{prefix}.embedding", weights, bias=False
+        self.lm_head = SpeculativeHead.load(
+            config, f"{prefix}.embedding", weights
         )
         self.config = config

     def forward(
         self, input_ids: torch.Tensor, inference_params=None, residual=None
-    ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.embed_tokens(input_ids)
         for i, block in enumerate(self.blocks):
             hidden_states, residual, conv_state, ssm_state = block(
@@ -226,8 +227,8 @@ class MambaModel(nn.Module):
             )
         hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
         hidden_states = hidden_states.view(residual.shape)
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)

         # update the offset for the next inference using these params
         inference_params.seqlen_offset += input_ids.size(1)
-        return logits
+        return logits, speculative_logits
@@ -21,7 +21,7 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )

@@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
         if not config.tie_word_embeddings:
             raise ValueError("MPTForCausalLM only supports tied word embeddings")
         self.transformer = MPTModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
         self.logit_scale = None

@@ -44,7 +44,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )


@@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
         self.gpt_neox = GPTNeoXModel(config, weights)
-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )

@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )

 EPS = 1e-5
@@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel):

         self.model = OPTModel(config, weights)

-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="model.decoder.embed_tokens", weights=weights
         )


@@ -13,7 +13,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLinear,
 )

@@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module):
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
-        self.linear = TensorParallelHead.load(
+        self.linear = SpeculativeHead.load(
             config=config, prefix="lm_head.linear", weights=weights
         )

@@ -42,7 +42,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )


@@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             )

         try:
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="lm_head", weights=weights
             )
         except RuntimeError:
             # Some models like t5-small were saved with shared weights unlike flan
             # Since they are declared as the same arch we have no choice but hope
             # that this is OK instead of using a proper flag.
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="shared", weights=weights
             )

@@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
             sequence_output = sequence_output * (self.model_dim**-0.5)

-        lm_logits = self.lm_head(sequence_output)
+        logits, speculative_logits = self.lm_head(sequence_output)

         loss = None
         if labels is not None:
@@ -1142,7 +1142,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):

         return Seq2SeqLMOutput(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
@@ -1150,7 +1150,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
-        )
+        ), speculative_logits

     def prepare_inputs_for_generation(
         self,
@@ -723,7 +723,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            self.cuda_graphs[bs]["logits"] = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
@@ -734,6 +734,8 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 lm_head_indices=None,
             )
+            self.cuda_graphs[bs]["logits"] = logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

     def warmup(self, batch: FlashCausalLMBatch):
@@ -805,7 +807,7 @@ class FlashCausalLM(Model):

         return int(num_blocks * BLOCK_SIZE)

-    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
+    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -900,9 +902,10 @@ class FlashCausalLM(Model):

         # Replay the graph
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits

     @tracer.start_as_current_span("generate_token")
     def generate_token(
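The CUDA-graph path now has to keep both outputs alive: capture records the static `logits` and `speculative_logits` buffers in `self.cuda_graphs[bs]`, and replay slices each of them down to the live batch size. A hedged sketch of that bookkeeping (the helper names, cache layout and `model` callable are placeholders, not the repository's API):

    from typing import Optional, Tuple
    import torch

    def capture(model, static_input_ids: torch.Tensor, cache: dict, bs: int, pool=None) -> None:
        # Record one forward pass; the outputs stay alive as static buffers.
        graph = torch.cuda.CUDAGraph()
        torch.cuda.synchronize()
        with torch.cuda.graph(graph, pool=pool):
            logits, speculative_logits = model(static_input_ids)
        cache[bs] = {
            "graph": graph,
            "input_ids": static_input_ids,
            "logits": logits,                          # static output tensor
            "speculative_logits": speculative_logits,  # static tensor or None
        }
        torch.cuda.synchronize()

    def replay(cache: dict, input_ids: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        bs = input_ids.shape[0]
        entry = cache[bs]
        entry["input_ids"][:bs] = input_ids  # refresh the static input buffer
        entry["graph"].replay()
        # Slice both outputs to the correct shape.
        speculative_logits = (
            entry["speculative_logits"][:bs]
            if entry["speculative_logits"] is not None
            else None
        )
        return entry["logits"][:bs], speculative_logits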
@@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)


@@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)

@@ -409,7 +409,6 @@ class BaseFlashMistral(FlashCausalLM):
                 lm_head_indices=None,
             )
             self.cuda_graphs[bs]["logits"] = logits
-            if speculative_logits is not None:
             self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

@@ -516,7 +515,7 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if "speculative_logits" in cuda_graph else None
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits

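The two Mistral hunks are a small consistency fix: capture previously wrote the "speculative_logits" entry only when it was not None, and replay tested key membership; now the entry is always written (possibly as None) and replay checks the stored value. A tiny runnable illustration of the resulting invariant (the dictionary contents are placeholders):

    import torch

    bs = 2
    cuda_graph = {
        "logits": torch.zeros(4, 32000),
        "speculative_logits": None,  # always present; None when no Medusa head is loaded
    }

    speculative_logits = (
        cuda_graph["speculative_logits"][:bs]
        if cuda_graph["speculative_logits"] is not None
        else None
    )
    logits = cuda_graph["logits"][:bs]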
@@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

@@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)

@@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM):
         )

         config.quantize = quantize
+        config.use_medusa = use_medusa
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id, revision)


@@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             trust_remote_code=True,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.transpose = config.architectures[0].startswith("GPT2")

         torch.distributed.barrier(group=self.process_group)
@@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.vision_config.quantize = quantize

         tokenizer = LlamaTokenizerFast.from_pretrained(
@@ -408,6 +408,7 @@ class Mamba(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -444,6 +445,7 @@ class Mamba(Model):
         tokenizer.pad_token = tokenizer.eos_token

         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -505,7 +507,7 @@ class Mamba(Model):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            logits = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids, inference_params=inference_params
             )
         torch.cuda.synchronize()
@@ -514,6 +516,7 @@ class Mamba(Model):
             "inference_params": inference_params,
             "graph": graph,
             "logits": logits,
+            "speculative_logits": speculative_logits,
         }
         self.cuda_graphs[batch_size] = graph_dict

@@ -556,9 +559,10 @@ class Mamba(Model):
         inference_params.ssm_states.copy_(
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )

         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits

     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
@@ -589,7 +593,7 @@ class Mamba(Model):
         batch.inference_params = inference_params

         # Forward pass
-        logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(input_ids, inference_params=batch.inference_params)

         # batch.inference_params = new_inference_params
         # Results
@@ -43,6 +43,7 @@ class MPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -75,6 +76,7 @@ class MPTSharded(CausalLM):
             config = json.load(f)
         config = PretrainedConfig(**config)
         config.quantize = quantize
+        config.use_medusa = use_medusa

         torch.distributed.barrier(group=self.process_group)


@@ -22,6 +22,7 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -47,6 +48,7 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         tokenizer.pad_token_id = config.pad_token_id

         torch.distributed.barrier(group=self.process_group)
@@ -22,6 +22,7 @@ class Phi(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -52,6 +53,7 @@ class Phi(CausalLM):
         tokenizer.pad_token = tokenizer.eos_token

         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

@@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -532,6 +532,7 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -596,11 +597,12 @@ class Seq2SeqLM(Model):
         past_key_values: Optional = None,
     ) -> Tuple[
         torch.Tensor,
+        Optional[torch.Tensor],
         torch.Tensor,
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -611,6 +613,7 @@ class Seq2SeqLM(Model):
         )
         return (
             outputs.logits,
+            speculative_logits,
             outputs.encoder_last_hidden_state,
             outputs.past_key_values,
         )
@@ -635,7 +638,7 @@ class Seq2SeqLM(Model):
         else:
             encoder_last_hidden_state = None

-        logits, encoder_last_hidden_state, past = self.forward(
+        logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
             batch.decoder_input_ids,
@@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -94,7 +96,7 @@ class T5Sharded(Seq2SeqLM):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM):

         return (
             outputs.logits,
+            speculative_logits,
             outputs.encoder_last_hidden_state,
             outputs.past_key_values,
         )