commit 64d38afa9f
parent 7a9998d47c
Author: Nicolas Patry
Date:   2024-02-22 13:01:43 +00:00

10 changed files with 61 additions and 37 deletions


@@ -551,7 +551,9 @@ class CausalLM(Model):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[
+        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
+    ]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,

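Note: the reflow above is purely cosmetic; forward still returns logits, optional speculative logits, and the past key/value cache. A hypothetical caller unpacking that three-tuple (the model and batch names are assumptions for illustration, not part of this commit):

    # Hypothetical usage sketch; model and batch are assumed to exist.
    logits, speculative_logits, past_key_values = model.forward(
        batch.input_ids,
        batch.attention_mask,
        batch.position_ids,
        batch.past_key_values,
    )
    # speculative_logits is None unless a speculative (Medusa) head is loaded.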

@@ -911,10 +911,13 @@ class BloomForCausalLM(BloomPreTrainedModel):
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
 
-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        ), speculative_logits
+        return (
+            CausalLMOutputWithCrossAttentions(
+                loss=loss,
+                logits=logits,
+                past_key_values=transformer_outputs.past_key_values,
+                hidden_states=transformer_outputs.hidden_states,
+                attentions=transformer_outputs.attentions,
+            ),
+            speculative_logits,
+        )


@@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         self.transformer = FlashRWModel(config, weights)
 
-        self.lm_head = SpeculativeHead.load(
-            config, prefix="lm_head", weights=weights
-        )
+        self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
 
     def forward(
         self,


@@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = SpeculativeHead.load(
-            config=config, prefix="lm_head", weights=weights
-        )
+        self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
         self.additional_fc = FastLinear.load(
             config=config,
             prefix="lm_head.additional_fc",


@@ -206,9 +206,7 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = SpeculativeHead.load(
-            config, f"{prefix}.embedding", weights
-        )
+        self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
         self.config = config
 
     def forward(


@@ -1140,17 +1140,20 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output
 
-        return Seq2SeqLMOutput(
-            loss=loss,
-            logits=logits,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            encoder_hidden_states=encoder_outputs.hidden_states,
-            encoder_attentions=encoder_outputs.attentions,
-        ), speculative_logits
+        return (
+            Seq2SeqLMOutput(
+                loss=loss,
+                logits=logits,
+                past_key_values=decoder_outputs.past_key_values,
+                decoder_hidden_states=decoder_outputs.hidden_states,
+                decoder_attentions=decoder_outputs.attentions,
+                cross_attentions=decoder_outputs.cross_attentions,
+                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+                encoder_hidden_states=encoder_outputs.hidden_states,
+                encoder_attentions=encoder_outputs.attentions,
+            ),
+            speculative_logits,
+        )
 
     def prepare_inputs_for_generation(
         self,


@@ -807,7 +807,9 @@ class FlashCausalLM(Model):
         return int(num_blocks * BLOCK_SIZE)
 
-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -903,7 +905,11 @@ class FlashCausalLM(Model):
         # Replay the graph
         cuda_graph["graph"].replay()
 
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
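The slice-after-replay above is the usual CUDA-graph idiom: the graph is captured once at a maximum batch size into static buffers, so replayed outputs are padded and must be cut down to the live batch. A minimal self-contained sketch of the idiom, not TGI's actual helper (all names are illustrative):

    import torch

    MAX_BS = 8  # batch size the graph is captured at
    linear = torch.nn.Linear(16, 32, device="cuda")
    static_input = torch.zeros(MAX_BS, 16, device="cuda")
    static_logits = torch.zeros(MAX_BS, 32, device="cuda")

    # Warm up outside capture so lazy CUDA initialization doesn't happen in-graph.
    linear(static_input)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_logits.copy_(linear(static_input))

    def graph_forward(x: torch.Tensor) -> torch.Tensor:
        bs = x.shape[0]
        static_input[:bs].copy_(x)  # write the live batch into the static buffer
        graph.replay()              # re-run the captured kernels at MAX_BS
        return static_logits[:bs]   # slice the padded output to the live batch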
@@ -962,7 +968,6 @@ class FlashCausalLM(Model):
                 speculative_logits,
             )
 
-        logger.info(f"Accepted ids {accepted_ids}")
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(


@@ -412,7 +412,9 @@ class BaseFlashMistral(FlashCausalLM):
             self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()
 
-    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashMistralBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -515,12 +517,15 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()
 
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
 
 
 class FlashMistral(BaseFlashMistral):
     def __init__(
         self,


@@ -560,7 +560,11 @@ class Mamba(Model):
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )
 
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
@@ -593,7 +597,9 @@ class Mamba(Model):
             batch.inference_params = inference_params
 
         # Forward pass
-        logits, speculative_logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(
+            input_ids, inference_params=batch.inference_params
+        )
 
         # batch.inference_params = new_inference_params
         # Results


@@ -379,6 +379,7 @@ class SuperLayer(nn.Module):
     def forward(self, x):
         return self.linear.forward(x)
 
+
 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -426,6 +427,7 @@ class MedusaHead(torch.nn.Module):
         x = self.out(x)
         return x
 
+
 class SpeculativeHead(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
@@ -440,6 +442,7 @@ class SpeculativeHead(nn.Module):
             from pathlib import Path
             from safetensors import safe_open
             import json
+
             medusa_config = str(Path(use_medusa) / "config.json")
             medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
@@ -460,11 +463,14 @@ class SpeculativeHead(nn.Module):
                 medusa = None
         return SpeculativeHead(lm_head, medusa)
 
-    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         speculative_logits = self.medusa(input) if self.medusa is not None else None
         return logits, speculative_logits
 
+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
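SpeculativeHead is the hinge for the rest of this diff: every model's head now returns regular logits plus, when a Medusa head is attached, speculative logits for the lookahead tokens. A self-contained sketch of that contract (TinySpeculativeHead and its sizes are invented for illustration; the real class loads trained weights from safetensors):

    from typing import Optional, Tuple

    import torch
    import torch.nn as nn

    class TinySpeculativeHead(nn.Module):
        # Stand-in for the pattern: a base lm_head plus an optional extra head.
        def __init__(self, hidden: int, vocab: int, medusa: Optional[nn.Module] = None):
            super().__init__()
            self.lm_head = nn.Linear(hidden, vocab, bias=False)
            self.medusa = medusa  # None => plain decoding, no speculative logits

        def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            logits = self.lm_head(x)
            speculative_logits = self.medusa(x) if self.medusa is not None else None
            return logits, speculative_logits

    head = TinySpeculativeHead(hidden=64, vocab=128)
    logits, speculative_logits = head(torch.randn(2, 64))
    assert speculative_logits is None  # no medusa module attached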