[Tmp] Revamping medusa to make it orthogonal.

This commit is contained in:
Nicolas Patry 2024-02-21 21:37:27 +00:00
parent 010508cec8
commit 2446f3ec32
6 changed files with 147 additions and 44 deletions

View File

@ -120,28 +120,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
dtype=dtype, use_medusa=use_medusa,
trust_remote_code=trust_remote_code,
)
if model_id.startswith("bigcode/"):
if FLASH_ATTENTION:
return FlashSantacoderSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif sharded:
raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
)
else:
return SantaCoder(
model_id,
revision,
quantize=quantize,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -193,6 +172,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -203,6 +183,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -215,6 +196,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -224,6 +206,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -232,6 +215,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -242,6 +226,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -250,6 +235,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -258,6 +244,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -268,15 +255,16 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
) )
else: else:
return CausalLM( return CausalLM(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -291,6 +279,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -301,9 +290,9 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
use_medusa=use_medusa,
) )
elif sharded: elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@ -312,6 +301,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -347,6 +337,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -357,6 +348,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -365,6 +357,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -378,6 +371,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -391,6 +385,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -400,6 +395,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -409,6 +405,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -418,6 +415,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -441,6 +439,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -460,6 +459,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
@ -468,6 +468,7 @@ def get_model(
model_id, model_id,
revision, revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )

View File

@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
PositionRotaryEmbedding, PositionRotaryEmbedding,
TensorParallelHead, SpeculativeHead,
get_linear, get_linear,
FastRMSNorm, FastRMSNorm,
) )
@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
super().__init__() super().__init__()
self.model = MistralModel(config, weights) self.model = MistralModel(config, weights)
self.lm_head = TensorParallelHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix="lm_head", prefix="lm_head",
weights=weights, weights=weights,

View File

@ -926,16 +926,11 @@ class FlashCausalLM(Model):
batch.slots = slots batch.slots = slots
try: try:
out = self.forward(batch) out, speculative_logits = self.forward(batch)
except Exception as e: except Exception as e:
del batch del batch
raise e raise e
if isinstance(out, tuple):
out, speculative_logits = out
else:
speculative_logits = None
if prefill: if prefill:
next_token_logits = ( next_token_logits = (
out[batch.prefill_next_token_indices] if prefill_logprobs else out out[batch.prefill_next_token_indices] if prefill_logprobs else out
@ -964,6 +959,9 @@ class FlashCausalLM(Model):
speculative_logits, speculative_logits,
) )
logger.info(f"Accepted ids {accepted_ids}")
batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
) )

View File

@ -294,6 +294,7 @@ class BaseFlashMistral(FlashCausalLM):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -319,6 +320,7 @@ class BaseFlashMistral(FlashCausalLM):
model_id, revision=revision, trust_remote_code=trust_remote_code model_id, revision=revision, trust_remote_code=trust_remote_code
) )
config.quantize = quantize config.quantize = quantize
config.use_medusa = use_medusa
# Set context windows # Set context windows
if config.sliding_window is not None: if config.sliding_window is not None:
@ -394,7 +396,7 @@ class BaseFlashMistral(FlashCausalLM):
torch.cuda.synchronize() torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL): with torch.cuda.graph(graph, pool=MEM_POOL):
self.cuda_graphs[bs]["logits"] = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=None, cu_seqlen_prefill=None,
@ -406,9 +408,12 @@ class BaseFlashMistral(FlashCausalLM):
prefill_cache_indices=None, prefill_cache_indices=None,
lm_head_indices=None, lm_head_indices=None,
) )
self.cuda_graphs[bs]["logits"] = logits
if speculative_logits is not None:
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
torch.cuda.synchronize() torch.cuda.synchronize()
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward # Model Forward
if batch.speculative_ids is not None: if batch.speculative_ids is not None:
input_ids = batch.input_ids input_ids = batch.input_ids
@ -479,7 +484,7 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph = self.cuda_graphs.get(padded_bs, None) cuda_graph = self.cuda_graphs.get(padded_bs, None)
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
@ -493,7 +498,7 @@ class BaseFlashMistral(FlashCausalLM):
) )
if batch.prefill_cache_indices is not None: if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None batch.prefill_cache_indices = None
return logits return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph # Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded # Static inputs are potentially padded
@ -511,7 +516,10 @@ class BaseFlashMistral(FlashCausalLM):
cuda_graph["graph"].replay() cuda_graph["graph"].replay()
# Slice output to the correct shape # Slice output to the correct shape
return cuda_graph["logits"][:bs] speculative_logits = cuda_graph["speculative_logits"][:bs] if "speculative_logits" in cuda_graph else None
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
class FlashMistral(BaseFlashMistral): class FlashMistral(BaseFlashMistral):
@ -520,6 +528,7 @@ class FlashMistral(BaseFlashMistral):
model_id: str, model_id: str,
revision: Optional[str] = None, revision: Optional[str] = None,
quantize: Optional[str] = None, quantize: Optional[str] = None,
use_medusa: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
@ -529,6 +538,7 @@ class FlashMistral(BaseFlashMistral):
model_id=model_id, model_id=model_id,
revision=revision, revision=revision,
quantize=quantize, quantize=quantize,
use_medusa=use_medusa,
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )

View File

@ -4,7 +4,7 @@ import torch.distributed
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from typing import List from typing import List, Tuple, Optional
from loguru import logger from loguru import logger
from functools import lru_cache from functools import lru_cache
@ -379,6 +379,106 @@ class SuperLayer(nn.Module):
def forward(self, x): def forward(self, x):
return self.linear.forward(x) return self.linear.forward(x)
class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.linear = FastLinear.load(
config, prefix=f"{prefix}.linear", weights=weights, bias=True
)
self.act = torch.nn.SiLU()
def forward(self, x):
return x + self.act(self.linear(x))
class MedusaModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config["medusa_num_heads"])
]
)
def forward(self, x):
speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
return speculative_logits
class MedusaHead(torch.nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config["medusa_num_layers"])
]
)
n = len(self.blocks)
self.out = FastLinear.load(
config, prefix=f"{prefix}.{n}", weights=weights, bias=False
)
def forward(self, x):
for block in self.blocks:
x = block(x)
x = self.out(x)
return x
class SpeculativeHead(nn.Module):
def __init__(self, lm_head, medusa):
super().__init__()
self.lm_head = lm_head
self.medusa = medusa
@staticmethod
def load(config, prefix: str, weights):
lm_head = TensorParallelHead.load(config, prefix, weights)
use_medusa = config.use_medusa
if use_medusa:
from pathlib import Path
from huggingface_hub import hf_hub_download
from text_generation_server.utils.weights import Weights
from safetensors import safe_open
import json
import os
is_local_model = (
Path(use_medusa).exists() and Path(use_medusa).is_dir()
) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
if not is_local_model:
medusa_config = hf_hub_download(
use_medusa, revision=revision, filename="config.json"
)
medusa_head = hf_hub_download(
use_medusa, revision=revision, filename="medusa_lm_head.pt"
)
else:
medusa_config = str(Path(use_medusa) / "config.json")
medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
with open(medusa_config, "r") as f:
config = json.load(f)
filename = medusa_head[: -len(".pt")] + ".safetensors"
routing = weights.routing
with safe_open(filename, framework="pytorch") as f:
for k in f.keys():
if k in routing:
raise RuntimeError(
f"Key {k} was found in multiple files: {filename} and {routing[k]}"
)
weights.routing[k] = filename
medusa = MedusaModel(config, weights)
else:
medusa = None
return SpeculativeHead(lm_head, medusa)
def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
logits = self.lm_head(input)
speculative_logits = self.medusa(input) if self.medusa is not None else None
return logits, speculative_logits
class TensorParallelHead(SuperLayer): class TensorParallelHead(SuperLayer):
def __init__(self, linear, process_group, should_gather: bool): def __init__(self, linear, process_group, should_gather: bool):

View File

@ -3,12 +3,6 @@ from dataclasses import dataclass
from text_generation_server.utils.layers import TensorParallelHead, FastLinear from text_generation_server.utils.layers import TensorParallelHead, FastLinear
@dataclass
class Output:
logits: torch.FloatTensor = None
speculative_logits: torch.FloatTensor = None
class ResBlock(torch.nn.Module): class ResBlock(torch.nn.Module):
def __init__(self, config, prefix, weights): def __init__(self, config, prefix, weights):
super().__init__() super().__init__()