Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
Upgrade ALL the code.
parent 2446f3ec32 · commit ac419f5e46
@@ -115,16 +115,6 @@ def get_model(
     else:
         set_speculate(0)
 
-    if "facebook/galactica" in model_id:
-        return GalacticaSharded(
-            model_id,
-            revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )

@@ -177,7 +167,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
 
-    if model_type == "gpt_bigcode":
+    if model_type in {"gpt_bigcode", "gpt2"}:
         if FLASH_ATTENTION:
             return FlashSantacoderSharded(
                 model_id,

@@ -311,9 +301,9 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(

@@ -324,6 +314,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )

@@ -448,6 +439,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
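
Aside from dropping the dedicated facebook/galactica shortcut before config resolution, the get_model hunks above all make the same move: a use_medusa argument is threaded into every model constructor, always between quantize and dtype. A minimal sketch of that calling convention follows; ExampleSharded and build are hypothetical stand-ins for illustration, not classes from the repository.

from typing import Optional

import torch


class ExampleSharded:
    # Hypothetical stand-in mirroring the constructor signature the commit
    # normalizes across models: use_medusa sits between quantize and dtype.
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.model_id = model_id
        self.use_medusa = use_medusa  # path to medusa weights, or None


def build(model_id: str, use_medusa: Optional[str] = None) -> ExampleSharded:
    # Every branch of get_model forwards use_medusa the same way, so the
    # speculative-decoding setup no longer depends on which architecture
    # branch is taken.
    return ExampleSharded(
        model_id,
        revision=None,
        quantize=None,
        use_medusa=use_medusa,
        dtype=torch.float16,
        trust_remote_code=False,
    )
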
@@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM):
         )
         config.pad_token_id = 3
         config.quantize = quantize
+        config.use_medusa = use_medusa
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

@@ -482,6 +482,7 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -36,7 +36,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )
 
 CUSTOM_KERNELS_ENABLED = False

@@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         super().__init__(config)
         self.transformer = BloomModel(config, weights)
 
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="word_embeddings",
             weights=weights,
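
Every modeling file below swaps TensorParallelHead.load for SpeculativeHead.load at the language-model head. The SpeculativeHead implementation itself lives in text_generation_server.utils.layers and is not shown in this diff, so the following is only a hedged sketch of the behaviour the call sites rely on: wrap the ordinary head and return a pair whose second element is None unless extra speculative (medusa) heads were loaded. SpeculativeHeadSketch and its constructor arguments are assumptions made for illustration, not the real class.

from typing import Optional, Tuple

import torch
import torch.nn as nn


class SpeculativeHeadSketch(nn.Module):
    """Illustrative only: a base LM head plus optional speculative heads."""

    def __init__(self, lm_head: nn.Module, medusa_heads: Optional[nn.ModuleList] = None):
        super().__init__()
        self.lm_head = lm_head
        self.medusa_heads = medusa_heads

    def forward(self, hidden: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Regular next-token logits, what the old TensorParallelHead produced.
        logits = self.lm_head(hidden)
        if self.medusa_heads is None:
            # No medusa weights configured: old behaviour, plus None.
            return logits, None
        # One extra logits tensor per speculative head, stacked on a new axis.
        speculative_logits = torch.stack(
            [head(hidden) for head in self.medusa_heads], dim=-2
        )
        return logits, speculative_logits
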
@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )

@@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()
 
         self.model = FlashGemmaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
             weights=weights,

@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )

@@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         super().__init__()
 
         self.model = FlashLlamaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,

@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )
 

@@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         super().__init__()
 
         self.model = MixtralModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,

@@ -33,7 +33,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,

@@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         super().__init__(config)
         self.gpt_neox = FlashGPTNeoXModel(config, weights)
 
-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )
 

@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastLayerNorm,
 )

@@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         super().__init__()
 
         self.model = FlashPhiModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,

@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,

@@ -613,7 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
 
         self.transformer = FlashRWModel(config, weights)
 
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="lm_head", weights=weights
         )
 

@@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastLayerNorm,
     get_linear,

@@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         self.transformer = FlashSantacoderModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
 

@@ -51,7 +51,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     PositionRotaryEmbedding,
     FastLinear,
 )

@@ -272,7 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = TensorParallelHead.load(
+        self.fc = SpeculativeHead.load(
             config=config, prefix="lm_head", weights=weights
         )
         self.additional_fc = FastLinear.load(

@@ -9,6 +9,7 @@ from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F
 
 from text_generation_server.utils.layers import (
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastRMSNorm,
     FastLinear,

@@ -205,14 +206,14 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = FastLinear.load(
-            config, f"{prefix}.embedding", weights, bias=False
+        self.lm_head = SpeculativeHead.load(
+            config, f"{prefix}.embedding", weights
         )
         self.config = config
 
     def forward(
         self, input_ids: torch.Tensor, inference_params=None, residual=None
-    ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.embed_tokens(input_ids)
         for i, block in enumerate(self.blocks):
             hidden_states, residual, conv_state, ssm_state = block(

@@ -226,8 +227,8 @@ class MambaModel(nn.Module):
             )
         hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
         hidden_states = hidden_states.view(residual.shape)
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
 
         # update the offset for the next inference using these params
         inference_params.seqlen_offset += input_ids.size(1)
-        return logits
+        return logits, speculative_logits
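
With that change, every lm_head call and model forward touched by the commit yields (logits, speculative_logits) instead of a bare tensor, and the second element may be None. A small caller-side illustration, not taken from the repository, assuming logits of shape (batch, vocab) and speculative logits of shape (batch, n_heads, vocab):

from typing import Optional, Tuple

import torch


def pick_next_tokens(
    logits: torch.Tensor, speculative_logits: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Greedy pick from the main head.
    next_tokens = logits.argmax(dim=-1)
    # Speculative guesses only exist when medusa weights were loaded.
    speculative_tokens = (
        speculative_logits.argmax(dim=-1) if speculative_logits is not None else None
    )
    return next_tokens, speculative_tokens
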
@@ -21,7 +21,7 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )
 

@@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
         if not config.tie_word_embeddings:
             raise ValueError("MPTForCausalLM only supports tied word embeddings")
         self.transformer = MPTModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
         self.logit_scale = None

@@ -44,7 +44,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )
 
 

@@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
         self.gpt_neox = GPTNeoXModel(config, weights)
-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )
 

@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )
 
 EPS = 1e-5

@@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
 
         self.model = OPTModel(config, weights)
 
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="model.decoder.embed_tokens", weights=weights
         )
 

@@ -13,7 +13,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLinear,
 )
 

@@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module):
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
-        self.linear = TensorParallelHead.load(
+        self.linear = SpeculativeHead.load(
             config=config, prefix="lm_head.linear", weights=weights
         )
 

@@ -42,7 +42,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )
 
 

@@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         )
 
         try:
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="lm_head", weights=weights
             )
         except RuntimeError:
             # Some models like t5-small were saved with shared weights unlike flan
             # Since they are declared as the same arch we have no choice but hope
             # that this is OK instead of using a proper flag.
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="shared", weights=weights
             )
 

@@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
             sequence_output = sequence_output * (self.model_dim**-0.5)
 
-        lm_logits = self.lm_head(sequence_output)
+        logits, speculative_logits = self.lm_head(sequence_output)
 
         loss = None
         if labels is not None:

@@ -1142,7 +1142,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
 
         return Seq2SeqLMOutput(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,

@@ -1150,7 +1150,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
-        )
+        ), speculative_logits
 
     def prepare_inputs_for_generation(
         self,

@@ -723,7 +723,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
-            self.cuda_graphs[bs]["logits"] = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,

@@ -734,6 +734,8 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 lm_head_indices=None,
             )
+            self.cuda_graphs[bs]["logits"] = logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()
 
     def warmup(self, batch: FlashCausalLMBatch):

@@ -805,7 +807,7 @@ class FlashCausalLM(Model):
 
         return int(num_blocks * BLOCK_SIZE)
 
-    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
+    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids

@@ -900,9 +902,10 @@ class FlashCausalLM(Model):
 
         # Replay the graph
         cuda_graph["graph"].replay()
 
         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits
 
     @tracer.start_as_current_span("generate_token")
     def generate_token(
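
The FlashCausalLM hunks adapt the CUDA-graph path to the pair-returning forward: both output tensors are kept in the per-batch-size graph cache at capture time and sliced to the live batch size after replay. A simplified sketch of that record/replay pattern follows; model, static_inputs, cache and mem_pool are stand-ins for the real attributes, not the repository's code.

import torch


def record_decode_graph(model, static_inputs, cache, bs, mem_pool=None):
    # Capture one decode step into a CUDA graph; the output tensors must stay
    # alive in the cache so replay can overwrite them in place.
    graph = torch.cuda.CUDAGraph()
    torch.cuda.synchronize()
    with torch.cuda.graph(graph, pool=mem_pool):
        logits, speculative_logits = model(**static_inputs)
    cache[bs] = {
        "graph": graph,
        "logits": logits,
        "speculative_logits": speculative_logits,  # may be None
        **static_inputs,
    }
    torch.cuda.synchronize()


def replay_decode_graph(cache, bs):
    cuda_graph = cache[bs]
    cuda_graph["graph"].replay()
    # Slice the padded static outputs down to the actual batch size.
    speculative_logits = (
        cuda_graph["speculative_logits"][:bs]
        if cuda_graph["speculative_logits"] is not None
        else None
    )
    return cuda_graph["logits"][:bs], speculative_logits
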
@@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
 
         torch.distributed.barrier(group=self.process_group)
 

@@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
 
         torch.distributed.barrier(group=self.process_group)
 

@@ -409,8 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
                 lm_head_indices=None,
             )
             self.cuda_graphs[bs]["logits"] = logits
-            if speculative_logits is not None:
-                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()
 
     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

@@ -516,7 +515,7 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()
 
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if "speculative_logits" in cuda_graph else None
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
 

@@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

@@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

@@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():

@@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
 
         torch.distributed.barrier(group=self.process_group)
 

@@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM):
         )
 
         config.quantize = quantize
+        config.use_medusa = use_medusa
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id, revision)
 

@@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             trust_remote_code=True,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.transpose = config.architectures[0].startswith("GPT2")
 
         torch.distributed.barrier(group=self.process_group)

@@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.vision_config.quantize = quantize
 
         tokenizer = LlamaTokenizerFast.from_pretrained(

@@ -408,6 +408,7 @@ class Mamba(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -444,6 +445,7 @@ class Mamba(Model):
         tokenizer.pad_token = tokenizer.eos_token
 
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

@@ -505,7 +507,7 @@ class Mamba(Model):
         torch.cuda.synchronize()
 
         with torch.cuda.graph(graph, pool=MEM_POOL):
-            logits = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids, inference_params=inference_params
             )
         torch.cuda.synchronize()

@@ -514,6 +516,7 @@ class Mamba(Model):
             "inference_params": inference_params,
             "graph": graph,
             "logits": logits,
+            "speculative_logits": speculative_logits,
         }
         self.cuda_graphs[batch_size] = graph_dict
 

@@ -556,9 +559,10 @@ class Mamba(Model):
         inference_params.ssm_states.copy_(
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )
 
         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits
 
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()

@@ -589,7 +593,7 @@ class Mamba(Model):
         batch.inference_params = inference_params
 
         # Forward pass
-        logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(input_ids, inference_params=batch.inference_params)
 
         # batch.inference_params = new_inference_params
         # Results

@@ -43,6 +43,7 @@ class MPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -75,6 +76,7 @@ class MPTSharded(CausalLM):
             config = json.load(f)
             config = PretrainedConfig(**config)
             config.quantize = quantize
+            config.use_medusa = use_medusa
 
         torch.distributed.barrier(group=self.process_group)
 

@@ -22,6 +22,7 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):

@@ -47,6 +48,7 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         tokenizer.pad_token_id = config.pad_token_id
 
         torch.distributed.barrier(group=self.process_group)

@@ -22,6 +22,7 @@ class Phi(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -52,6 +53,7 @@ class Phi(CausalLM):
         tokenizer.pad_token = tokenizer.eos_token
 
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

@@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -532,6 +532,7 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -596,11 +597,12 @@ class Seq2SeqLM(Model):
         past_key_values: Optional = None,
     ) -> Tuple[
         torch.Tensor,
+        Optional[torch.Tensor],
         torch.Tensor,
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,

@@ -611,6 +613,7 @@ class Seq2SeqLM(Model):
         )
         return (
             outputs.logits,
+            speculative_logits,
             outputs.encoder_last_hidden_state,
             outputs.past_key_values,
         )

@@ -635,7 +638,7 @@ class Seq2SeqLM(Model):
         else:
             encoder_last_hidden_state = None
 
-        logits, encoder_last_hidden_state, past = self.forward(
+        logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
             batch.decoder_input_ids,
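
The Seq2SeqLM changes widen forward from a 3-tuple to a 4-tuple, with the optional speculative logits in second position, and generate_token unpacks the extra element. A dummy sketch of that contract; forward_sketch is hypothetical and stands in for the real method only to show the shape of the return value.

from typing import List, Optional, Tuple

import torch


def forward_sketch(
    encoder_hidden: torch.Tensor, vocab_size: int = 32
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, List]:
    # Stand-in for Seq2SeqLM.forward after this commit: logits first, then the
    # optional speculative logits, then the encoder state and the KV cache.
    logits = torch.zeros(encoder_hidden.size(0), vocab_size)
    speculative_logits = None  # populated only when medusa weights exist
    past_key_values: List = []
    return logits, speculative_logits, encoder_hidden, past_key_values


# Caller-side unpacking mirrors the generate_token change above.
logits, speculative_logits, encoder_last_hidden_state, past = forward_sketch(
    torch.zeros(1, 8)
)
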
@@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

@@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,

@@ -94,7 +96,7 @@ class T5Sharded(Seq2SeqLM):
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,

@@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM):
 
         return (
             outputs.logits,
+            speculative_logits,
             outputs.encoder_last_hidden_state,
             outputs.past_key_values,
         )