[Tmp] Revamping medusa to make it orthogonal.

Nicolas Patry 2024-02-21 21:37:27 +00:00
parent 010508cec8
commit 2446f3ec32
6 changed files with 147 additions and 44 deletions

View File

@@ -120,32 +120,11 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
if model_id.startswith("bigcode/"):
if FLASH_ATTENTION:
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
return SantaCoder(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
@@ -193,6 +172,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -203,6 +183,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -215,6 +196,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -224,6 +206,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -232,6 +215,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -242,6 +226,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -250,6 +235,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -258,6 +244,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -268,15 +255,16 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
)
else:
return CausalLM(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -291,6 +279,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -301,9 +290,9 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
)
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -312,6 +301,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -347,6 +337,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -357,6 +348,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -365,6 +357,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -378,6 +371,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -391,6 +385,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -400,6 +395,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -409,6 +405,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -418,6 +415,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -441,6 +439,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -460,6 +459,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
@@ -468,6 +468,7 @@ def get_model(
model_id,
revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
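
The hunks above are mechanical: every constructor reachable from get_model now takes the same use_medusa keyword, which carries either None, a local directory, or a hub repo id for a Medusa head. A minimal sketch of the resulting calling convention (ToyModel is a hypothetical stand-in for FlashLlama, CausalLM, and the other classes touched above; the real constructors also load weights, shard, etc.):

import torch
from typing import Optional

class ToyModel:
    # Hypothetical stand-in for the model classes named in the diff.
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,  # None, a local dir, or a hub repo id
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.use_medusa = use_medusa

model = ToyModel(
    "some/model",
    use_medusa="some/medusa-head",  # threaded through unchanged by get_model
    dtype=torch.float16,
)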

View File

@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
PositionRotaryEmbedding,
TensorParallelHead,
SpeculativeHead,
get_linear,
FastRMSNorm,
)
@@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
super().__init__()
self.model = MistralModel(config, weights)
self.lm_head = TensorParallelHead.load(
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",
weights=weights,
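
With SpeculativeHead.load swapped in for TensorParallelHead.load, the Mistral lm_head returns a pair instead of a bare tensor. A quick shape check with toy sizes (num_tokens=5, hidden=16, vocab=32 are assumptions, not values from the diff):

import torch

lm_head = torch.nn.Linear(16, 32, bias=False)  # stands in for the loaded head
hidden_states = torch.randn(5, 16)
logits = lm_head(hidden_states)
assert logits.shape == (5, 32)
# With a Medusa model attached, the second element of the pair would be
# [5, medusa_num_heads, 32]; without one it is simply None.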

View File

@@ -926,16 +926,11 @@ class FlashCausalLM(Model):
batch.slots = slots
try:
out = self.forward(batch)
out, speculative_logits = self.forward(batch)
except Exception as e:
del batch
raise e
if isinstance(out, tuple):
out, speculative_logits = out
else:
speculative_logits = None
if prefill:
next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out
@@ -964,6 +959,9 @@
speculative_logits,
)
logger.info(f"Accepted ids {accepted_ids}")
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
)
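
Two things change in this file: forward now always returns a (logits, speculative_logits) tuple, so the isinstance branching is gone, and batch_top_tokens takes accepted_ids because speculation can accept several tokens per sequence in one step, so per-sequence settings must be expanded to match the flattened logprobs. A hedged sketch of that expansion (repeat_interleave is my guess at the mechanism, not necessarily what batch_top_tokens does internally):

import torch

top_n_tokens_tensor = torch.tensor([2, 3])  # top-n requested per sequence
accepted_ids = torch.tensor([1, 3])         # tokens accepted this step
# Each sequence contributed accepted_ids[i] rows to the flattened logprobs,
# so per-sequence settings are repeated to line up row-for-row.
expanded = torch.repeat_interleave(top_n_tokens_tensor, accepted_ids)
print(expanded)  # tensor([2, 3, 3, 3])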

View File

@@ -294,6 +294,7 @@ class BaseFlashMistral(FlashCausalLM):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@@ -319,6 +320,7 @@ class BaseFlashMistral(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code
)
config.quantize = quantize
config.use_medusa = use_medusa
# Set context windows
if config.sliding_window is not None:
@@ -394,7 +396,7 @@ class BaseFlashMistral(FlashCausalLM):
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward(
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
@@ -406,9 +408,12 @@ class BaseFlashMistral(FlashCausalLM):
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
if speculative_logits is not None:
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize()
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
@@ -479,7 +484,7 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None:
logits = self.model.forward(
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
@@ -493,7 +498,7 @@ class BaseFlashMistral(FlashCausalLM):
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
@@ -511,7 +516,10 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph["graph"].replay()
# Slice output to the correct shape
return cuda_graph["logits"][:bs]
speculative_logits = cuda_graph["speculative_logits"][:bs] if "speculative_logits" in cuda_graph else None
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
class FlashMistral(BaseFlashMistral):
@@ -520,6 +528,7 @@ class FlashMistral(BaseFlashMistral):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
@@ -529,6 +538,7 @@ class FlashMistral(BaseFlashMistral):
model_id=model_id,
revision=revision,
quantize=quantize,
use_medusa=use_medusa,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
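
Recording speculative_logits during capture is what makes it retrievable later: a replayed CUDA graph only writes into tensors that existed at capture time, so every output the caller needs must be kept as a static buffer and then sliced from the padded batch size back to the live one. A self-contained capture/replay sketch (a plain Linear stands in for the model; a CUDA device is assumed):

import torch

model = torch.nn.Linear(16, 32).cuda()
static_input = torch.randn(8, 16, device="cuda")  # padded batch of 8

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(static_input)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_logits = model(static_input)  # recorded output buffer

# Replay with fresh data: copy into the static input, replay, slice to live bs.
static_input[:3].copy_(torch.randn(3, 16, device="cuda"))
graph.replay()
logits = static_logits[:3]  # mirrors cuda_graph["logits"][:bs] above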

View File

@@ -4,7 +4,7 @@ import torch.distributed
from torch import nn
from torch.nn import functional as F
from typing import List
from typing import List, Tuple, Optional
from loguru import logger
from functools import lru_cache
@@ -379,6 +379,106 @@ class SuperLayer(nn.Module):
def forward(self, x):
return self.linear.forward(x)
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU()
def forward(self, x):
return x + self.act(self.linear(x))
class MedusaModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
]
)
def forward(self, x):
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return speculative_logits
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
]
)
n = len(self.blocks)
self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.out(x)
return x
class SpeculativeHead(nn.Module):
def __init__(self, lm_head, medusa):
super().__init__()
self.lm_head = lm_head
self.medusa = medusa
@staticmethod
def load(config, prefix: str, weights):
lm_head = TensorParallelHead.load(config, prefix, weights)
use_medusa = config.use_medusa
if use_medusa:
from pathlib import Path
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import Weights
from safetensors import safe_open
import json
import os
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
filename = medusa_head[: -len(".pt")] + ".safetensors"
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
weights.routing[k] = filename
medusa = MedusaModel(config, weights)
else:
medusa = None
return SpeculativeHead(lm_head, medusa)
def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
speculative_logits = self.medusa(input) if self.medusa is not None else None
return logits, speculative_logits
class TensorParallelHead(SuperLayer):
def __init__(self, linear, process_group, should_gather: bool):
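
Shape-wise, each MedusaHead is a stack of residual SiLU blocks ending in an unbiased projection to the vocabulary, and MedusaModel stacks the per-head outputs on a new dimension. Below is a CPU-runnable toy with assumed sizes (hidden 16, vocab 32, 3 heads, 1 block per head) to make that concrete. The real weights are routed out of medusa_lm_head.safetensors through Weights; also, revision inside load does not appear to be defined in the function's scope here, which fits the commit's [Tmp] tag.

import torch

hidden, vocab, n_heads, n_blocks = 16, 32, 3, 1  # toy sizes, not from the diff

class ToyResBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(hidden, hidden)
        self.act = torch.nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))

class ToyMedusaHead(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(ToyResBlock() for _ in range(n_blocks))
        self.out = torch.nn.Linear(hidden, vocab, bias=False)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.out(x)

heads = torch.nn.ModuleList(ToyMedusaHead() for _ in range(n_heads))
x = torch.randn(5, hidden)                        # 5 token positions
spec = torch.stack([h(x) for h in heads], dim=1)  # [5, n_heads, vocab], like MedusaModel
print(spec.shape)                                 # torch.Size([5, 3, 32])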

View File

@@ -3,12 +3,6 @@ from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear
@dataclass
class Output:
logits: torch.FloatTensor = None
speculative_logits: torch.FloatTensor = None
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()