Repository: https://github.com/huggingface/text-generation-inference.git
Commit: 64d38afa9f (parent: 7a9998d47c)

Black.
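The hunks below are pure formatting: Black rewraps lines longer than its default 88-column limit (long signatures, return annotations, and conditional expressions) and normalizes blank lines around top-level definitions, without changing behavior. As an illustration only (not part of the commit, and assuming Black is installed), the same rewrapping can be reproduced through Black's library API:

import black

# One of the long conditional expressions touched below, written as a single line.
src = (
    'speculative_logits = cuda_graph["speculative_logits"][:bs] '
    'if cuda_graph["speculative_logits"] is not None else None\n'
)

# format_str applies the same rules as the `black` CLI (default line length 88),
# wrapping the expression in parentheses in the same shape the hunks below show.
print(black.format_str(src, mode=black.Mode()))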
@@ -551,7 +551,9 @@ class CausalLM(Model):
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[
+        torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]
+    ]:
         # Model Forward
         kwargs = {
             "input_ids": input_ids,
@@ -911,10 +911,13 @@ class BloomForCausalLM(BloomPreTrainedModel):
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
 
-        return CausalLMOutputWithCrossAttentions(
+        return (
+            CausalLMOutputWithCrossAttentions(
                 loss=loss,
                 logits=logits,
                 past_key_values=transformer_outputs.past_key_values,
                 hidden_states=transformer_outputs.hidden_states,
                 attentions=transformer_outputs.attentions,
-            ), speculative_logits
+            ),
+            speculative_logits,
+        )
@@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
 
         self.transformer = FlashRWModel(config, weights)
 
-        self.lm_head = SpeculativeHead.load(
-            config, prefix="lm_head", weights=weights
-        )
+        self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
 
     def forward(
         self,
@@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
         weights,
     ) -> None:
         super().__init__()
-        self.fc = SpeculativeHead.load(
-            config=config, prefix="lm_head", weights=weights
-        )
+        self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights)
         self.additional_fc = FastLinear.load(
             config=config,
             prefix="lm_head.additional_fc",
@@ -206,9 +206,7 @@ class MambaModel(nn.Module):
         self.norm_f = FastRMSNorm.load(
             f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
         )
-        self.lm_head = SpeculativeHead.load(
-            config, f"{prefix}.embedding", weights
-        )
+        self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
         self.config = config
 
     def forward(
@@ -1140,7 +1140,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output
 
-        return Seq2SeqLMOutput(
+        return (
+            Seq2SeqLMOutput(
                 loss=loss,
                 logits=logits,
                 past_key_values=decoder_outputs.past_key_values,
@@ -1150,7 +1151,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
                 encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                 encoder_hidden_states=encoder_outputs.hidden_states,
                 encoder_attentions=encoder_outputs.attentions,
-            ), speculative_logits
+            ),
+            speculative_logits,
+        )
 
     def prepare_inputs_for_generation(
         self,
@@ -807,7 +807,9 @@ class FlashCausalLM(Model):
 
         return int(num_blocks * BLOCK_SIZE)
 
-    def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashCausalLMBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -903,7 +905,11 @@ class FlashCausalLM(Model):
         # Replay the graph
         cuda_graph["graph"].replay()
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
 
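The hunk above only rewraps the conditional, but it sits in the CUDA-graph fast path: graphs are captured once for a fixed, padded batch size, replayed after copying the live batch into static buffers, and the padded outputs are sliced back down with [:bs]. A minimal sketch of that pattern, assuming a CUDA device and illustrative names (MAX_BS, replay_forward) rather than the repository's actual capture code:

import torch

MAX_BS = 8  # graphs are captured for a fixed, padded batch size
model = torch.nn.Linear(16, 32).cuda().eval().requires_grad_(False)
static_input = torch.zeros(MAX_BS, 16, device="cuda")

# Warm up on a side stream, then capture a single forward pass into a graph.
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
    model(static_input)
torch.cuda.current_stream().wait_stream(stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_logits = model(static_input)

def replay_forward(batch: torch.Tensor) -> torch.Tensor:
    bs = batch.shape[0]
    static_input[:bs].copy_(batch)  # write the live batch into the padded buffer
    graph.replay()                  # replay the captured kernels
    return static_logits[:bs]       # slice output to the correct shape

print(replay_forward(torch.randn(2, 16, device="cuda")).shape)  # torch.Size([2, 32])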
@@ -962,7 +968,6 @@ class FlashCausalLM(Model):
                 speculative_logits,
             )
 
-
         logger.info(f"Accepted ids {accepted_ids}")
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
@@ -412,7 +412,9 @@ class BaseFlashMistral(FlashCausalLM):
         self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()
 
-    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, batch: FlashMistralBatch
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -515,12 +517,15 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()
 
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
 
 
-
 class FlashMistral(BaseFlashMistral):
     def __init__(
         self,
@@ -560,7 +560,11 @@ class Mamba(Model):
             cuda_graph["inference_params"].ssm_states[:, :bs]
         )
         # Slice output to the correct shape
-        speculative_logits = cuda_graph["speculative_logits"][:bs] if cuda_graph["speculative_logits"] is not None else None
+        speculative_logits = (
+            cuda_graph["speculative_logits"][:bs]
+            if cuda_graph["speculative_logits"] is not None
+            else None
+        )
         logits = cuda_graph["logits"][:bs]
         return logits, speculative_logits
 
@@ -593,7 +597,9 @@ class Mamba(Model):
             batch.inference_params = inference_params
 
         # Forward pass
-        logits, speculative_logits = self.forward(input_ids, inference_params=batch.inference_params)
+        logits, speculative_logits = self.forward(
+            input_ids, inference_params=batch.inference_params
+        )
 
         # batch.inference_params = new_inference_params
         # Results
@@ -379,6 +379,7 @@ class SuperLayer(nn.Module):
     def forward(self, x):
         return self.linear.forward(x)
 
+
 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()
@@ -426,6 +427,7 @@ class MedusaHead(torch.nn.Module):
         x = self.out(x)
         return x
 
+
 class SpeculativeHead(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
@@ -440,6 +442,7 @@ class SpeculativeHead(nn.Module):
             from pathlib import Path
             from safetensors import safe_open
             import json
+
             medusa_config = str(Path(use_medusa) / "config.json")
             medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
 
@@ -460,11 +463,14 @@ class SpeculativeHead(nn.Module):
             medusa = None
         return SpeculativeHead(lm_head, medusa)
 
-    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         logits = self.lm_head(input)
         speculative_logits = self.medusa(input) if self.medusa is not None else None
         return logits, speculative_logits
 
+
 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group, should_gather: bool):
         super().__init__(linear)
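These layers hunks reformat the speculative head: a base LM head plus an optional Medusa-style head, so forward() always returns a (logits, speculative_logits) pair, which is why the model forward() signatures above now return Tuple[torch.Tensor, Optional[torch.Tensor]]. A minimal sketch of that pattern with illustrative stand-ins (ToySpeculativeHead is not the repository's class; the real head is built with SpeculativeHead.load(config, prefix, weights)):

from typing import Optional, Tuple

import torch
import torch.nn as nn


class ToySpeculativeHead(nn.Module):
    def __init__(self, hidden: int, vocab: int, medusa: Optional[nn.Module] = None):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.medusa = medusa  # extra head producing speculative logits, or None

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(input)
        speculative_logits = self.medusa(input) if self.medusa is not None else None
        return logits, speculative_logits


# Callers always unpack both values, even when no speculative head is loaded.
head = ToySpeculativeHead(hidden=16, vocab=32)
logits, speculative_logits = head(torch.randn(2, 16))
print(logits.shape, speculative_logits)  # torch.Size([2, 32]) None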