Mirror of https://github.com/huggingface/text-generation-inference.git
feat: medusa v2 (#1734)
parent 661081d2d2
commit f6d5c2edf2
@@ -47,7 +47,7 @@ def get_model(
         if speculate is not None:
             if speculate > speculate_medusa:
                 raise RuntimeError(
-                    "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                 )
             else:
                 set_speculate(speculate)
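Note: the only functional change in this hunk is the added `f` prefix; without it the braces were printed literally instead of being interpolated. A minimal standalone illustration (the values below are made up):

speculate, speculate_medusa = 5, 3

# Without the f prefix, the placeholders are emitted verbatim:
plain = "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads"
# With the f prefix, the values are interpolated into the message:
formatted = f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads"

print(plain)      # ... `{speculate}` ... `{speculate_medusa}` ...
print(formatted)  # ... `5` ... `3` ...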
@@ -814,7 +814,7 @@ class FlashCausalLM(Model):
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
-            except Exception:
+            except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")

        return int(num_blocks * BLOCK_SIZE)
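Note: the warmup loop now only swallows `torch.cuda.OutOfMemoryError`, so any other failure during CUDA graph capture propagates instead of being logged and ignored. A rough sketch of the narrowed handling; `warmup_graphs` and its arguments are stand-ins for illustration, not the actual TGI method:

import torch
from loguru import logger


def warmup_graphs(model, cuda_graph_batch_sizes, max_s, max_bt):
    # Capture one decode graph per batch size; treat running out of GPU memory
    # as a soft failure (the server still starts, just without graphs).
    try:
        for bs in cuda_graph_batch_sizes:
            if model.speculate is None or model.speculate + 1 <= bs:
                model.cuda_graph_warmup(bs, max_s, max_bt)
    except torch.cuda.OutOfMemoryError:
        # Only OOM is tolerated; any other exception now surfaces to the caller.
        logger.exception("Decode cuda graph warmup failed")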
@@ -874,22 +874,14 @@ class FlashCausalLM(Model):
             lm_head_indices = batch.prefill_head_indices

         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
-        if (
-            cu_seqlen_prefill is not None
-            or cuda_graph is None
-            or batch.speculative_ids is not None
-        ):
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
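Note: batch-size padding to fixed buckets (4, 8, next multiple of 8) is replaced by picking the smallest captured CUDA graph whose batch size is at least the current one, falling back to the eager forward when none fits. A self-contained sketch of that selection rule; the `cuda_graphs` dict and its values are illustrative:

from typing import Optional


def select_cuda_graph(cuda_graphs: dict, bs: int) -> Optional[object]:
    # Keys are the batch sizes the graphs were captured at.
    candidates = sorted(k for k in cuda_graphs.keys() if k >= bs)
    if candidates:
        # Smallest captured graph that can hold the current batch.
        return cuda_graphs[candidates[0]]
    # No graph is large enough: the caller runs the regular forward pass.
    return None


graphs = {1: "graph_bs1", 2: "graph_bs2", 4: "graph_bs4", 8: "graph_bs8"}
assert select_cuda_graph(graphs, 3) == "graph_bs4"
assert select_cuda_graph(graphs, 9) is None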
@@ -432,12 +432,12 @@ class ResBlock(torch.nn.Module):


 class MedusaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, config, medusa_config, weights):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config["medusa_num_heads"])
+                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+                for i in range(medusa_config["medusa_num_heads"])
             ]
         )

@@ -447,12 +447,12 @@ class MedusaModel(torch.nn.Module):


 class MedusaHead(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, medusa_config, prefix, weights):
         super().__init__()
         self.blocks = torch.nn.ModuleList(
             [
                 ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config["medusa_num_layers"])
+                for i in range(medusa_config["medusa_num_layers"])
             ]
         )
         n = len(self.blocks)
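Note: both MedusaModel and MedusaHead now read the head and layer counts from a dedicated `medusa_config` dict rather than from the base model config, since the Medusa checkpoint directory ships its own config.json. A small sketch of the two keys involved; the values are made up for illustration:

# Stand-in for json.load(open(f"{use_medusa}/config.json")).
medusa_config = {"medusa_num_heads": 4, "medusa_num_layers": 1}

# MedusaModel builds one MedusaHead per speculative head:
num_heads = medusa_config["medusa_num_heads"]
# MedusaHead stacks this many ResBlocks (MedusaHeadV2 below asserts it is 1):
num_layers = medusa_config["medusa_num_layers"]
print(num_heads, num_layers)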
@@ -467,7 +467,7 @@ class MedusaHead(torch.nn.Module):
         return x


-class SpeculativeHead(nn.Module):
+class MedusaHeadV1(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
         self.lm_head = lm_head
@@ -475,38 +475,147 @@ class SpeculativeHead(nn.Module):

     @staticmethod
     def load(config, prefix: str, weights):
+        from pathlib import Path
+        from safetensors import safe_open
+        import json
+
+        use_medusa = config.use_medusa
+
+        medusa_config = str(Path(use_medusa) / "config.json")
+        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+        with open(medusa_config, "r") as f:
+            medusa_config = json.load(f)
+        routing = weights.routing
+        with safe_open(filename, framework="pytorch") as f:
+            for k in f.keys():
+                if k in routing and routing[k] != filename:
+                    raise RuntimeError(
+                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                    )
+                routing[k] = filename
+
+        medusa = MedusaModel(config, medusa_config, weights)
         lm_head = TensorParallelHead.load(config, prefix, weights)
+        return MedusaHeadV1(lm_head, medusa)
+
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        logits = self.lm_head(input)
+        speculative_logits = self.medusa(input)
+        return logits, speculative_logits
+
+
+class MedusaHeadV2(nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        from pathlib import Path
+        from safetensors import safe_open
+        import json
+
+        use_medusa = config.use_medusa
+
+        medusa_config = str(Path(use_medusa) / "config.json")
+        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+        with open(medusa_config, "r") as f:
+            medusa_config = json.load(f)
+        routing = weights.routing
+        with safe_open(filename, framework="pytorch") as f:
+            for k in f.keys():
+                if k in routing and routing[k] != filename:
+                    raise RuntimeError(
+                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                    )
+                routing[k] = filename
+
+        self.n_medusa_heads = medusa_config["medusa_num_heads"]
+
+        assert medusa_config["medusa_num_layers"] == 1
+        self.linear = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+        self.process_group = weights.process_group
+        self.world_size = self.process_group.size()
+        self.rank = self.process_group.rank()
+
+        self.act = torch.nn.SiLU()
+
+        self.lm_head = TensorParallelHead.load(config, prefix, weights)
+
+    def forward(self, x):
+        size = x.shape[-1]
+        block_size = (size + self.world_size - 1) // self.world_size
+        start = self.rank * block_size
+        stop = (self.rank + 1) * block_size
+
+        x_block = x[:, start:stop]
+
+        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
+        medusa_res = self.act(self.linear(x)).reshape(
+            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
+        )
+
+        # Apply all residual medusa heads
+        output = x[:, start:stop].unsqueeze(-2) + medusa_res
+
+        # Gather medusa heads
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+
+        # Stack x and medusa residual x
+        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
+
+        # Compute lm head on x + medusa residual x
+        logits = self.lm_head(stacked_x)
+
+        # Finally, split logits from speculative logits
+        logits, speculative_logits = torch.split(
+            logits, [1, self.n_medusa_heads], dim=-2
+        )
+        # Squeeze added dimension
+        logits = logits.squeeze(-2)
+
+        return logits, speculative_logits
+
+
+class SpeculativeHead(nn.Module):
+    def __init__(self, lm_head, medusa):
+        super().__init__()
+        self.head = lm_head
+        self.medusa = medusa
+
+    @staticmethod
+    def load(config, prefix: str, weights):
         use_medusa = config.use_medusa
         if use_medusa:
-            from pathlib import Path
-            from safetensors import safe_open
-            import json
-
-            medusa_config = str(Path(use_medusa) / "config.json")
-            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            routing = weights.routing
-            with safe_open(filename, framework="pytorch") as f:
-                for k in f.keys():
-                    if k in routing:
-                        raise RuntimeError(
-                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
-                        )
-                    weights.routing[k] = filename
-
-            medusa = MedusaModel(config, weights)
+            lm_head = None
+            try:
+                medusa = MedusaHeadV1.load(config, prefix, weights)
+            except:
+                medusa = MedusaHeadV2(config, prefix, weights)
         else:
+            lm_head = TensorParallelHead.load(config, prefix, weights)
             medusa = None
         return SpeculativeHead(lm_head, medusa)

     def forward(
         self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        logits = self.lm_head(input)
-        speculative_logits = self.medusa(input) if self.medusa is not None else None
-        return logits, speculative_logits
+        if self.medusa is not None:
+            return self.medusa(input)
+        assert self.head is not None
+        logits = self.head(input)
+        return logits, None


 class TensorParallelHead(SuperLayer):
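Note: MedusaHeadV2 fuses all Medusa heads into a single `TensorParallelColumnLinear` (which is why it asserts `medusa_num_layers == 1`), adds the activation as a residual on the shard of the hidden states owned by the current rank, all-gathers the per-rank results, and then runs one `lm_head` call over the stacked `[hidden, head_0, ..., head_{n-1}]` tensor before splitting regular logits from speculative logits. `SpeculativeHead.load` first tries the per-head `MedusaHeadV1` layout and falls back to the fused V2 layout if that fails. A single-process toy version of the V2 data flow (random weights, `world_size == 1` so the all_gather step is omitted, and all names here are illustrative):

import torch

# Toy dimensions; real values come from the model and the medusa checkpoint.
hidden_size, vocab_size, n_medusa_heads, bs = 16, 32, 4, 2

# Stand-in for the fused TensorParallelColumnLinear over all heads.
fused_linear = torch.nn.Linear(hidden_size, n_medusa_heads * hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
act = torch.nn.SiLU()

x = torch.randn(bs, hidden_size)

# One matmul for every head, then split out the head dimension.
medusa_res = act(fused_linear(x)).reshape(bs, n_medusa_heads, hidden_size)

# Each head is a residual on top of the hidden states
# (with tensor parallelism this happens per shard, followed by an all_gather).
output = x.unsqueeze(-2) + medusa_res                      # [bs, n_heads, hidden]

# Stack the original hidden states in front and run the lm_head once over everything.
stacked_x = torch.cat([x.unsqueeze(-2), output], dim=-2)   # [bs, 1 + n_heads, hidden]
logits = lm_head(stacked_x)                                # [bs, 1 + n_heads, vocab]

# Split regular logits from speculative logits and drop the helper dimension.
logits, speculative_logits = torch.split(logits, [1, n_medusa_heads], dim=-2)
logits = logits.squeeze(-2)

print(logits.shape, speculative_logits.shape)  # [2, 32] and [2, 4, 32]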