Nicolas Patry 2024-02-22 13:01:43 +00:00
parent 7a9998d47c
commit 64d38afa9f
10 changed files with 61 additions and 37 deletions

View File

@@ -551,7 +551,9 @@ class CausalLM(Model):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[
+        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
+    ]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,

View File

@@ -911,10 +911,13 @@ class BloomForCausalLM(BloomPreTrainedModel):
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        ), speculative_logits
+        return (
+            CausalLMOutputWithCrossAttentions(
+                loss=loss,
+                logits=logits,
+                past_key_values=transformer_outputs.past_key_values,
+                hidden_states=transformer_outputs.hidden_states,
+                attentions=transformer_outputs.attentions,
+            ),
+            speculative_logits,
+        )
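
Note: this reformat (here and in the T5 hunk below) is purely cosmetic. In Python, `return a, b` already builds a two-item tuple, so wrapping it in explicit parentheses changes nothing at runtime. A minimal sketch with toy values (not the model classes above):

    def old_style():
        return {"logits": 1.0}, None  # implicit tuple

    def new_style():
        return (
            {"logits": 1.0},
            None,
        )  # explicit parentheses, same tuple

    assert old_style() == new_style() == ({"logits": 1.0}, None)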

View File

@@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         self.transformer = FlashRWModel(config, weights)

-        self.lm_head = SpeculativeHead.load(
-            config, prefix="lm_head", weights=weights
-        )
+        self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)

     def forward(
         self,

View File

@@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = SpeculativeHead.load(
-            config=config, prefix="lm_head", weights=weights
-        )
+        self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
         self.additional_fc = FastLinear.load(
             config=config,
             prefix="lm_head.additional_fc",

View File

@@ -206,9 +206,7 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = SpeculativeHead.load(
-            config, f"{prefix}.embedding", weights
-        )
+        self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
         self.config = config

     def forward(

View File

@@ -1140,17 +1140,20 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output

-        return Seq2SeqLMOutput(
-            loss=loss,
-            logits=logits,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            encoder_hidden_states=encoder_outputs.hidden_states,
-            encoder_attentions=encoder_outputs.attentions,
-        ), speculative_logits
+        return (
+            Seq2SeqLMOutput(
+                loss=loss,
+                logits=logits,
+                past_key_values=decoder_outputs.past_key_values,
+                decoder_hidden_states=decoder_outputs.hidden_states,
+                decoder_attentions=decoder_outputs.attentions,
+                cross_attentions=decoder_outputs.cross_attentions,
+                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+                encoder_hidden_states=encoder_outputs.hidden_states,
+                encoder_attentions=encoder_outputs.attentions,
+            ),
+            speculative_logits,
+        )

     def prepare_inputs_for_generation(
         self,

View File

@@ -807,7 +807,9 @@ class FlashCausalLM(Model):
         return int(num_blocks * BLOCK_SIZE)

-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -903,7 +905,11 @@ class FlashCausalLM(Model):
         # Replay the graph
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
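
The hunk above only re-wraps a long conditional expression, but the pattern it touches is worth noting: CUDA graphs replay at the fixed batch size they were captured with, so their output buffers must be sliced back down to the live batch size `bs`, and `speculative_logits` is populated only when a speculative (Medusa) head is loaded. A minimal sketch under those assumptions; the dict below is an illustrative stand-in, not TGI's actual `cuda_graph` object:

    import torch

    bs = 3  # live batch size; the graph was captured for a larger padded batch
    cuda_graph = {
        "logits": torch.zeros(8, 32000),  # hypothetical padded output buffer
        "speculative_logits": None,  # stays None when no Medusa head is attached
    }

    speculative_logits = (
        cuda_graph["speculative_logits"][:bs]
        if cuda_graph["speculative_logits"] is not None
        else None
    )
    logits = cuda_graph["logits"][:bs]

    assert logits.shape[0] == bs and speculative_logits is None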
@@ -961,7 +967,6 @@ class FlashCausalLM(Model):
             batch.speculative_ids,
             speculative_logits,
         )
-
         logger.info(f"Accepted ids {accepted_ids}")

View File

@@ -412,7 +412,9 @@ class BaseFlashMistral(FlashCausalLM):
         self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

-    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashMistralBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -515,12 +517,15 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits


 class FlashMistral(BaseFlashMistral):
     def __init__(
         self,

View File

@@ -560,7 +560,11 @@ class Mamba(Model):
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
@@ -593,7 +597,9 @@ class Mamba(Model):
             batch.inference_params = inference_params

         # Forward pass
-        logits, speculative_logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(
+            input_ids, inference_params=batch.inference_params
+        )

         # batch.inference_params = new_inference_params
         # Results

View File

@@ -379,6 +379,7 @@ class SuperLayer(nn.Module):
     def forward(self, x):
         return self.linear.forward(x)

+
 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -426,6 +427,7 @@ class MedusaHead(torch.nn.Module):
         x = self.out(x)
         return x

+
 class SpeculativeHead(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
@@ -440,6 +442,7 @@ class SpeculativeHead(nn.Module):
             from pathlib import Path
             from safetensors import safe_open
             import json

+
             medusa_config = str(Path(use_medusa) / "config.json")
             medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
@@ -460,11 +463,14 @@ class SpeculativeHead(nn.Module):
             medusa = None
         return SpeculativeHead(lm_head, medusa)

-    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         speculative_logits = self.medusa(input) if self.medusa is not None else None
         return logits, speculative_logits

+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
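
Taken together, these hunks keep one contract intact while reformatting: every forward path returns a `(logits, speculative_logits)` pair whose second item is Optional. A self-contained sketch of that contract; `ToyHead` is a hypothetical class for illustration, not TGI's `SpeculativeHead`:

    from typing import Optional, Tuple

    import torch
    from torch import nn

    class ToyHead(nn.Module):
        def __init__(self, hidden: int, vocab: int, medusa: Optional[nn.Module] = None):
            super().__init__()
            self.lm_head = nn.Linear(hidden, vocab, bias=False)
            self.medusa = medusa  # optional extra head producing speculative logits

        def forward(
            self, input: torch.Tensor
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            logits = self.lm_head(input)
            speculative_logits = self.medusa(input) if self.medusa is not None else None
            return logits, speculative_logits

    head = ToyHead(hidden=16, vocab=32)
    logits, speculative_logits = head(torch.randn(2, 16))
    assert logits.shape == (2, 32) and speculative_logits is None  # no Medusa head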