Mirror of https://github.com/huggingface/text-generation-inference.git
Synced 2025-09-11 12:24:53 +00:00
Black.

This commit is contained in:
  parent 7a9998d47c
  commit 64d38afa9f
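Note: "Black." is the commit message; every hunk below is a pure formatting change produced by the Black code formatter, with no behavioral difference. As a hedged sketch of how such a pass is usually reproduced (the exact Black version and target path are not recorded in this commit; "server/" is an assumed source directory):

    # Format the Python sources in place with Black's defaults (88-column lines).
    python -m black server/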
@@ -551,7 +551,9 @@ class CausalLM(Model):

     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[
+        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
+    ]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,
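Note on the hunk above: with Black's default 88-column limit, a def whose return annotation overruns the line is split, first at the parameter list, then inside the annotation itself. A minimal runnable sketch with hypothetical names (assumes torch is installed):

    from typing import List, Optional, Tuple

    import torch

    class Sketch:
        # The split shape Black produces for an over-long signature:
        # parameters on their own line, then the return annotation
        # exploded element by element until each line fits.
        def forward(
            self, input_ids: torch.Tensor, attention_mask: torch.Tensor
        ) -> Tuple[
            torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
        ]:
            return input_ids, None, []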
@@ -911,10 +911,13 @@ class BloomForCausalLM(BloomPreTrainedModel):
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        ), speculative_logits
+        return (
+            CausalLMOutputWithCrossAttentions(
+                loss=loss,
+                logits=logits,
+                past_key_values=transformer_outputs.past_key_values,
+                hidden_states=transformer_outputs.hidden_states,
+                attentions=transformer_outputs.attentions,
+            ),
+            speculative_logits,
+        )
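Note on the hunk above: the change is purely syntactic. `return Obj(...), speculative_logits` already returned a 2-tuple; Black only wraps the tuple in explicit parentheses and, once split, puts one element per line with a trailing comma (the "magic trailing comma", which keeps future diffs to single lines). A runnable check of the equivalence, with stand-in values:

    implicit = "output", None
    explicit = (
        "output",
        None,
    )
    assert implicit == explicit  # same 2-tuple either way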
@@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):

         self.transformer = FlashRWModel(config, weights)

-        self.lm_head = SpeculativeHead.load(
-            config, prefix="lm_head", weights=weights
-        )
+        self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)

     def forward(
         self,
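Note on the hunk above: the inverse rule. When a previously wrapped call fits within 88 columns, Black collapses it back onto one line, which is why the three-line SpeculativeHead.load(...) becomes a one-liner here and in the two hunks that follow. A runnable length check on the collapsed line (78 characters bare, 86 with its 8-space indent, still under the limit):

    line = 'self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)'
    assert len(line) + 8 <= 88  # fits Black's default limit, so it stays on one line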
@@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = SpeculativeHead.load(
-            config=config, prefix="lm_head", weights=weights
-        )
+        self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
         self.additional_fc = FastLinear.load(
             config=config,
             prefix="lm_head.additional_fc",
@@ -206,9 +206,7 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = SpeculativeHead.load(
-            config, f"{prefix}.embedding", weights
-        )
+        self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
         self.config = config

     def forward(
@@ -1140,17 +1140,20 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output

-        return Seq2SeqLMOutput(
-            loss=loss,
-            logits=logits,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            encoder_hidden_states=encoder_outputs.hidden_states,
-            encoder_attentions=encoder_outputs.attentions,
-        ), speculative_logits
+        return (
+            Seq2SeqLMOutput(
+                loss=loss,
+                logits=logits,
+                past_key_values=decoder_outputs.past_key_values,
+                decoder_hidden_states=decoder_outputs.hidden_states,
+                decoder_attentions=decoder_outputs.attentions,
+                cross_attentions=decoder_outputs.cross_attentions,
+                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+                encoder_hidden_states=encoder_outputs.hidden_states,
+                encoder_attentions=encoder_outputs.attentions,
+            ),
+            speculative_logits,
+        )

     def prepare_inputs_for_generation(
         self,
@@ -807,7 +807,9 @@ class FlashCausalLM(Model):

         return int(num_blocks * BLOCK_SIZE)

-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -903,7 +905,11 @@ class FlashCausalLM(Model):
         # Replay the graph
         cuda_graph["graph"].replay()
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
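Note on the hunk above (and the matching hunks in BaseFlashMistral and Mamba below): a conditional expression that overruns the line is wrapped in parentheses with the value, the condition, and the fallback each on their own line; the computed result is identical. A runnable mini-example with a stand-in dict for the CUDA-graph cache:

    cuda_graph = {"speculative_logits": None}  # hypothetical stand-in
    bs = 4
    speculative_logits = (
        cuda_graph["speculative_logits"][:bs]
        if cuda_graph["speculative_logits"] is not None
        else None
    )
    assert speculative_logits is None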
@@ -961,7 +967,6 @@ class FlashCausalLM(Model):
             batch.speculative_ids,
             speculative_logits,
         )

-
         logger.info(f"Accepted ids {accepted_ids}")
@@ -412,7 +412,9 @@ class BaseFlashMistral(FlashCausalLM):
         self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

-    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashMistralBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -515,12 +517,15 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits


 class FlashMistral(BaseFlashMistral):
     def __init__(
         self,
@@ -560,7 +560,11 @@ class Mamba(Model):
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
@@ -593,7 +597,9 @@ class Mamba(Model):
         batch.inference_params = inference_params

         # Forward pass
-        logits, speculative_logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(
+            input_ids, inference_params=batch.inference_params
+        )

         # batch.inference_params = new_inference_params
         # Results
@@ -379,6 +379,7 @@ class SuperLayer(nn.Module):
     def forward(self, x):
         return self.linear.forward(x)

+
 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
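Note on the hunk above (and the next two): Black enforces PEP 8's rule of two blank lines around top-level class and function definitions, so a single "+" blank line is inserted before each top-level class. In isolation:

    class ResBlock:
        pass


    class MedusaHead:  # exactly two blank lines separate top-level definitions
        pass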
@@ -426,6 +427,7 @@ class MedusaHead(torch.nn.Module):
         x = self.out(x)
         return x

+
 class SpeculativeHead(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
@@ -440,6 +442,7 @@ class SpeculativeHead(nn.Module):
             from pathlib import Path
             from safetensors import safe_open
             import json
+
             medusa_config = str(Path(use_medusa) / "config.json")
             medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
@@ -460,11 +463,14 @@ class SpeculativeHead(nn.Module):
             medusa = None
         return SpeculativeHead(lm_head, medusa)

-    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         speculative_logits = self.medusa(input) if self.medusa is not None else None
         return logits, speculative_logits

+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
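Aside from formatting, the SpeculativeHead.forward shown above is the contract the earlier hunks rely on: it always returns the main logits, plus Medusa logits only when a medusa head was loaded. A minimal runnable stand-in (TinySpeculativeHead is a hypothetical name; assumes torch is installed):

    from typing import Optional, Tuple

    import torch

    class TinySpeculativeHead(torch.nn.Module):
        def __init__(self, lm_head: torch.nn.Module, medusa: Optional[torch.nn.Module]):
            super().__init__()
            self.lm_head = lm_head
            self.medusa = medusa

        def forward(
            self, input: torch.Tensor
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            logits = self.lm_head(input)
            speculative_logits = self.medusa(input) if self.medusa is not None else None
            return logits, speculative_logits

    head = TinySpeculativeHead(torch.nn.Linear(8, 16), medusa=None)
    logits, speculative_logits = head(torch.randn(2, 8))
    assert logits.shape == (2, 16) and speculative_logits is None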