Upgrade ALL the code.

Nicolas Patry 2024-02-22 11:37:05 +00:00
parent 2446f3ec32
commit ac419f5e46
35 changed files with 94 additions and 66 deletions
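In short: every model constructor now accepts a use_medusa argument and stores it on the config, and the lm_head projections switch from TensorParallelHead to SpeculativeHead, so each forward pass returns a (logits, speculative_logits) pair instead of bare logits. A minimal sketch of that calling contract, using a hypothetical stand-in class (SpeculativeHeadStub and its internals are illustrative assumptions; the real SpeculativeHead loads its weights from the checkpoint and the Medusa heads referenced by config.use_medusa):

from typing import Optional, Tuple

import torch


class SpeculativeHeadStub(torch.nn.Module):
    """Hypothetical stand-in for SpeculativeHead; only the return shape mirrors the diff."""

    def __init__(self, hidden_size: int, vocab_size: int, n_speculative: int = 0):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        # The extra projection only exists when speculation (e.g. Medusa) is enabled.
        self.medusa = (
            torch.nn.Linear(hidden_size, vocab_size * n_speculative, bias=False)
            if n_speculative > 0
            else None
        )
        self.n_speculative = n_speculative

    def forward(self, hidden: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden)
        if self.medusa is None:
            # Without medusa weights the second element is simply None.
            return logits, None
        speculative = self.medusa(hidden).view(hidden.shape[0], self.n_speculative, -1)
        return logits, speculative


head = SpeculativeHeadStub(hidden_size=16, vocab_size=32, n_speculative=2)
hidden_states = torch.randn(4, 16)
logits, speculative_logits = head(hidden_states)  # every caller now unpacks a pair

Call sites in the diff below, including the CUDA-graph replay paths, follow the same pattern and treat the second element as None when speculation is disabled.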

View File

@@ -115,16 +115,6 @@ def get_model(
     else:
         set_speculate(0)
-    if "facebook/galactica" in model_id:
-        return GalacticaSharded(
-            model_id,
-            revision,
-            quantize=quantize,
-            use_medusa=use_medusa,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
@@ -177,7 +167,7 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
-    if model_type == "gpt_bigcode":
+    if model_type in {"gpt_bigcode", "gpt2"}:
         if FLASH_ATTENTION:
             return FlashSantacoderSharded(
                 model_id,
@@ -311,9 +301,9 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(
@@ -324,6 +314,7 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -448,6 +439,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

View File

@@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM):
         )
         config.pad_token_id = 3
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@@ -482,6 +482,7 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

View File

@@ -36,7 +36,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )

 CUSTOM_KERNELS_ENABLED = False
@@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         super().__init__(config)
         self.transformer = BloomModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="word_embeddings",
             weights=weights,

View File

@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()
         self.model = FlashGemmaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
             weights=weights,

View File

@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         super().__init__()
         self.model = FlashLlamaModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,

View File

@@ -37,7 +37,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )
@@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
         super().__init__()
         self.model = MixtralModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,

View File

@@ -33,7 +33,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,
@@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         super().__init__(config)
         self.gpt_neox = FlashGPTNeoXModel(config, weights)
-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )

View File

@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastLayerNorm,
 )
@@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module):
         super().__init__()
         self.model = FlashPhiModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,

View File

@@ -12,7 +12,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLayerNorm,
     PositionRotaryEmbedding,
     get_linear,
@@ -613,7 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         self.transformer = FlashRWModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="lm_head", weights=weights
         )

View File

@@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastLayerNorm,
     get_linear,
@@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         self.transformer = FlashSantacoderModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )

View File

@@ -51,7 +51,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     PositionRotaryEmbedding,
     FastLinear,
 )
@@ -272,7 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = TensorParallelHead.load(
+        self.fc = SpeculativeHead.load(
             config=config, prefix="lm_head", weights=weights
         )
         self.additional_fc = FastLinear.load(

View File

@@ -9,6 +9,7 @@ from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F
 from text_generation_server.utils.layers import (
+    SpeculativeHead,
     TensorParallelEmbedding,
     FastRMSNorm,
     FastLinear,
@@ -205,14 +206,14 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = FastLinear.load(
-            config, f"{prefix}.embedding", weights, bias=False
+        self.lm_head = SpeculativeHead.load(
+            config, f"{prefix}.embedding", weights
         )
         self.config = config

     def forward(
         self, input_ids: torch.Tensor, inference_params=None, residual=None
-    ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.embed_tokens(input_ids)
         for i, block in enumerate(self.blocks):
             hidden_states, residual, conv_state, ssm_state = block(
@@ -226,8 +227,8 @@ class MambaModel(nn.Module):
             )
         hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
         hidden_states = hidden_states.view(residual.shape)
-        logits = self.lm_head(hidden_states)
+        logits, speculative_logits = self.lm_head(hidden_states)
         # update the offset for the next inference using these params
         inference_params.seqlen_offset += input_ids.size(1)
-        return logits
+        return logits, speculative_logits

View File

@@ -21,7 +21,7 @@ from text_generation_server.utils.layers import (
     TensorParallelEmbedding,
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
 )
@@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
         if not config.tie_word_embeddings:
             raise ValueError("MPTForCausalLM only supports tied word embeddings")
         self.transformer = MPTModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
         self.logit_scale = None

View File

@@ -44,7 +44,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )
@@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
         self.gpt_neox = GPTNeoXModel(config, weights)
-        self.embed_out = TensorParallelHead.load(
+        self.embed_out = SpeculativeHead.load(
             config, prefix="embed_out", weights=weights
         )

View File

@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )

 EPS = 1e-5
@@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
         self.model = OPTModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config, prefix="model.decoder.embed_tokens", weights=weights
         )

View File

@@ -13,7 +13,7 @@ from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     FastLinear,
 )
@@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module):
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
-        self.linear = TensorParallelHead.load(
+        self.linear = SpeculativeHead.load(
             config=config, prefix="lm_head.linear", weights=weights
         )

View File

@@ -42,7 +42,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
-    TensorParallelHead,
+    SpeculativeHead,
 )
@@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         )
         try:
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="lm_head", weights=weights
             )
         except RuntimeError:
             # Some models like t5-small were saved with shared weights unlike flan
             # Since they are declared as the same arch we have no choice but hope
             # that this is OK instead of using a proper flag.
-            self.lm_head = TensorParallelHead.load(
+            self.lm_head = SpeculativeHead.load(
                 config, prefix="shared", weights=weights
             )
@@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
             sequence_output = sequence_output * (self.model_dim**-0.5)
-        lm_logits = self.lm_head(sequence_output)
+        logits, speculative_logits = self.lm_head(sequence_output)

         loss = None
         if labels is not None:
@@ -1142,7 +1142,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         return Seq2SeqLMOutput(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
@@ -1150,7 +1150,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
-        )
+        ), speculative_logits

     def prepare_inputs_for_generation(
         self,

View File

@@ -723,7 +723,7 @@ class FlashCausalLM(Model):
         torch.cuda.synchronize()
         with torch.cuda.graph(graph, pool=MEM_POOL):
-            self.cuda_graphs[bs]["logits"] = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
@ -734,6 +734,8 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 lm_head_indices=None,
             )
+            self.cuda_graphs[bs]["logits"] = logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

     def warmup(self, batch: FlashCausalLMBatch):
@@ -805,7 +807,7 @@
         return int(num_blocks * BLOCK_SIZE)

-    def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor:
+    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -900,9 +902,10 @@
         # Replay the graph
         cuda_graph["graph"].replay()
         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits

     @tracer.start_as_current_span("generate_token")
     def generate_token(

View File

@@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)

View File

@@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)

View File

@@ -409,8 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
                 lm_head_indices=None,
             )
             self.cuda_graphs[bs]["logits"] = logits
-            if speculative_logits is not None:
-                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
+            self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

     def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -516,7 +515,7 @@
         cuda_graph["graph"].replay()
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if "speculative_logits" in cuda_graph else None
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits

View File

@@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

View File

@@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")

View File

@@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
-        use_medusa: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)

View File

@@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM):
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         if config.quantize == "gptq":
             weights._set_gptq_params(model_id, revision)

View File

@@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             trust_remote_code=True,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.transpose = config.architectures[0].startswith("GPT2")
         torch.distributed.barrier(group=self.process_group)

View File

@@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         config.vision_config.quantize = quantize
         tokenizer = LlamaTokenizerFast.from_pretrained(

View File

@@ -408,6 +408,7 @@ class Mamba(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -444,6 +445,7 @@
         tokenizer.pad_token = tokenizer.eos_token
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
@@ -505,7 +507,7 @@
         torch.cuda.synchronize()
         with torch.cuda.graph(graph, pool=MEM_POOL):
-            logits = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids, inference_params=inference_params
             )
         torch.cuda.synchronize()
@@ -514,6 +516,7 @@
             "inference_params": inference_params,
             "graph": graph,
             "logits": logits,
+            "speculative_logits": speculative_logits,
         }
         self.cuda_graphs[batch_size] = graph_dict
@@ -556,9 +559,10 @@
         inference_params.ssm_states.copy_(
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )
         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits

     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
@@ -589,7 +593,7 @@
         batch.inference_params = inference_params
         # Forward pass
-        logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(input_ids, inference_params=batch.inference_params)
         # batch.inference_params = new_inference_params
         # Results

View File

@@ -43,6 +43,7 @@ class MPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -75,6 +76,7 @@ class MPTSharded(CausalLM):
             config = json.load(f)
         config = PretrainedConfig(**config)
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)

View File

@@ -22,6 +22,7 @@ class OPTSharded(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -47,6 +48,7 @@ class OPTSharded(CausalLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         tokenizer.pad_token_id = config.pad_token_id
         torch.distributed.barrier(group=self.process_group)

View File

@@ -22,6 +22,7 @@ class Phi(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -52,6 +53,7 @@ class Phi(CausalLM):
         tokenizer.pad_token = tokenizer.eos_token
         config.quantize = quantize
+        config.use_medusa = use_medusa
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)

View File

@@ -19,6 +19,7 @@ class SantaCoder(CausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):

View File

@@ -532,6 +532,7 @@ class Seq2SeqLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -596,11 +597,12 @@
         past_key_values: Optional = None,
     ) -> Tuple[
         torch.Tensor,
+        Optional[torch.Tensor],
         torch.Tensor,
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -611,6 +613,7 @@ class Seq2SeqLM(Model):
         )
         return (
             outputs.logits,
+            speculative_logits,
             outputs.encoder_last_hidden_state,
             outputs.past_key_values,
         )
@@ -635,7 +638,7 @@ class Seq2SeqLM(Model):
         else:
             encoder_last_hidden_state = None
-        logits, encoder_last_hidden_state, past = self.forward(
+        logits, speculative_logits, encoder_last_hidden_state, past = self.forward(
             batch.input_ids,
             batch.attention_mask,
             batch.decoder_input_ids,

View File

@@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM):
             trust_remote_code=trust_remote_code,
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
@@ -94,7 +96,7 @@
         List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
     ]:
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
@@ -106,6 +108,7 @@
         return (
             outputs.logits,
+            speculative_logits,
             outputs.encoder_last_hidden_state,
             outputs.past_key_values,
         )