From 2446f3ec32755ac7181bfc83a5631feab1289eee Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 21 Feb 2024 21:37:27 +0000
Subject: [PATCH] [Tmp] Revamping medusa to make it orthogonal.

---
 .../text_generation_server/models/__init__.py |  49 ++++-----
 .../custom_modeling/flash_mistral_modeling.py |   4 +-
 .../models/flash_causal_lm.py                 |  10 +-
 .../models/flash_mistral.py                   |  20 +++-
 server/text_generation_server/utils/layers.py | 102 +++++++++++++++++-
 server/text_generation_server/utils/medusa.py |   6 --
 6 files changed, 147 insertions(+), 44 deletions(-)

diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index abab3486..fcdc25f1 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -120,32 +120,11 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )

-    if model_id.startswith("bigcode/"):
-        if FLASH_ATTENTION:
-            return FlashSantacoderSharded(
-                model_id,
-                revision,
-                quantize=quantize,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
-        elif sharded:
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
-            )
-        else:
-            return SantaCoder(
-                model_id,
-                revision,
-                quantize=quantize,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
-            )
-
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
@@ -193,6 +172,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -203,6 +183,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -215,6 +196,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -224,6 +206,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -232,6 +215,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -242,6 +226,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -250,6 +235,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -258,6 +244,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -268,15 +255,16 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa,
             )
         else:
             return CausalLM(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
@@ -291,6 +279,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -301,9 +290,9 @@ def get_model(
                 model_id,
                 revision,
                 quantize=quantize,
+                use_medusa=use_medusa,
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
-                use_medusa=use_medusa,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
@@ -312,6 +301,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -347,6 +337,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -357,6 +348,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -365,6 +357,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -378,6 +371,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -391,6 +385,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -400,6 +395,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -409,6 +405,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -418,6 +415,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -441,6 +439,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -460,6 +459,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
@@ -468,6 +468,7 @@ def get_model(
             model_id,
             revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index fda34e5a..ed9306e0 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -32,7 +32,7 @@ from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
-    TensorParallelHead,
+    SpeculativeHead,
     get_linear,
     FastRMSNorm,
 )
@@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
         super().__init__()

         self.model = MistralModel(config, weights)
-        self.lm_head = TensorParallelHead.load(
+        self.lm_head = SpeculativeHead.load(
             config,
             prefix="lm_head",
             weights=weights,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index b8d0be22..a63c6641 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -926,16 +926,11 @@ class FlashCausalLM(Model):
             batch.slots = slots

         try:
-            out = self.forward(batch)
+            out, speculative_logits = self.forward(batch)
         except Exception as e:
             del batch
             raise e

-        if isinstance(out, tuple):
-            out, speculative_logits = out
-        else:
-            speculative_logits = None
-
         if prefill:
             next_token_logits = (
                 out[batch.prefill_next_token_indices] if prefill_logprobs else out
@@ -963,6 +958,9 @@ class FlashCausalLM(Model):
             batch.speculative_ids,
             speculative_logits,
         )
+
+
+        logger.info(f"Accepted ids {accepted_ids}")

         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs, accepted_ids
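
Every flash model's forward() is now expected to return a (logits, speculative_logits) pair instead of sometimes returning a bare tensor, which is why the isinstance fallback in generate_token above goes away. A minimal, self-contained sketch of that contract; the DummyHead module and its sizes are illustrative only and not part of the patch:

    from typing import Optional, Tuple

    import torch


    class DummyHead(torch.nn.Module):
        """Stand-in for a head that follows the new (logits, speculative_logits) contract."""

        def __init__(self, hidden: int, vocab: int, n_speculative: int):
            super().__init__()
            self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)
            # One extra projection per speculated position, Medusa-style.
            self.medusa = torch.nn.ModuleList(
                torch.nn.Linear(hidden, vocab, bias=False) for _ in range(n_speculative)
            )

        def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            logits = self.lm_head(x)
            # Shape [batch, n_speculative, vocab]; this would be None with speculation disabled.
            speculative_logits = torch.stack([h(x) for h in self.medusa], dim=1)
            return logits, speculative_logits


    out, speculative_logits = DummyHead(16, 32, 3)(torch.randn(2, 16))
    print(out.shape, speculative_logits.shape)  # torch.Size([2, 32]) torch.Size([2, 3, 32])
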
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 5df4e214..69d7493a 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -294,6 +294,7 @@ class BaseFlashMistral(FlashCausalLM):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -319,6 +320,7 @@ class BaseFlashMistral(FlashCausalLM):
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
+        config.use_medusa = use_medusa

         # Set context windows
         if config.sliding_window is not None:
@@ -394,7 +396,7 @@ class BaseFlashMistral(FlashCausalLM):
         torch.cuda.synchronize()

         with torch.cuda.graph(graph, pool=MEM_POOL):
-            self.cuda_graphs[bs]["logits"] = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=None,
@@ -406,9 +408,12 @@ class BaseFlashMistral(FlashCausalLM):
                 prefill_cache_indices=None,
                 lm_head_indices=None,
             )
+            self.cuda_graphs[bs]["logits"] = logits
+            if speculative_logits is not None:
+                self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
         torch.cuda.synchronize()

-    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
             input_ids = batch.input_ids
@@ -479,7 +484,7 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph = self.cuda_graphs.get(padded_bs, None)

         if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits = self.model.forward(
+            logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
@@ -493,7 +498,7 @@ class BaseFlashMistral(FlashCausalLM):
             )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
-            return logits
+            return logits, speculative_logits

         # Copy inputs to the static inputs of the cuda graph
         # Static inputs are potentially padded
@@ -511,7 +516,10 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph["graph"].replay()

         # Slice output to the correct shape
-        return cuda_graph["logits"][:bs]
+        speculative_logits = cuda_graph["speculative_logits"][:bs] if "speculative_logits" in cuda_graph else None
+        logits = cuda_graph["logits"][:bs]
+        return logits, speculative_logits
+


 class FlashMistral(BaseFlashMistral):
@@ -520,6 +528,7 @@ class FlashMistral(BaseFlashMistral):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
@@ -529,6 +538,7 @@ class FlashMistral(BaseFlashMistral):
             model_id=model_id,
             revision=revision,
             quantize=quantize,
+            use_medusa=use_medusa,
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
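
The CUDA-graph warmup above now records both outputs of the captured forward as static tensors, and the replay path slices both down to the live batch size. A rough sketch of that capture/replay pattern with toy modules standing in for self.model.forward; buffer names and sizes are assumptions for illustration, and the memory pool argument is omitted:

    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
        max_bs, hidden, vocab = 8, 16, 32
        static_input = torch.zeros(max_bs, hidden, device=device)
        lm_head = torch.nn.Linear(hidden, vocab).to(device)
        medusa = torch.nn.Linear(hidden, vocab).to(device)

        # Warm up on a side stream before capture, as recommended for CUDA graphs.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                lm_head(static_input)
                medusa(static_input)
        torch.cuda.current_stream().wait_stream(s)

        graph = torch.cuda.CUDAGraph()
        torch.cuda.synchronize()
        with torch.cuda.graph(graph):
            logits = lm_head(static_input)
            speculative_logits = medusa(static_input)
        # Keep references to the static outputs, as the patch does in self.cuda_graphs[bs].
        cuda_graph = {"graph": graph, "logits": logits, "speculative_logits": speculative_logits}

        # At decode time: copy live inputs into the static buffer, replay, slice to the real batch size.
        bs = 3
        static_input[:bs] = torch.randn(bs, hidden, device=device)
        cuda_graph["graph"].replay()
        out = cuda_graph["logits"][:bs]
        spec = cuda_graph["speculative_logits"][:bs] if "speculative_logits" in cuda_graph else None
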
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index bef2a146..35bfbcba 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -4,7 +4,7 @@ import torch.distributed
 from torch import nn
 from torch.nn import functional as F

-from typing import List
+from typing import List, Tuple, Optional

 from loguru import logger
 from functools import lru_cache
@@ -379,6 +379,106 @@ class SuperLayer(nn.Module):
     def forward(self, x):
         return self.linear.forward(x)

+class ResBlock(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.linear = FastLinear.load(
+            config, prefix=f"{prefix}.linear", weights=weights, bias=True
+        )
+        self.act = torch.nn.SiLU()
+
+    def forward(self, x):
+        return x + self.act(self.linear(x))
+
+
+class MedusaModel(torch.nn.Module):
+    def __init__(self, config, weights):
+        super().__init__()
+        self.heads = torch.nn.ModuleList(
+            [
+                MedusaHead(config, prefix=f"{i}", weights=weights)
+                for i in range(config["medusa_num_heads"])
+            ]
+        )
+
+    def forward(self, x):
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
+        return speculative_logits
+
+
+class MedusaHead(torch.nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        self.blocks = torch.nn.ModuleList(
+            [
+                ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
+                for i in range(config["medusa_num_layers"])
+            ]
+        )
+        n = len(self.blocks)
+        self.out = FastLinear.load(
+            config, prefix=f"{prefix}.{n}", weights=weights, bias=False
+        )
+
+    def forward(self, x):
+        for block in self.blocks:
+            x = block(x)
+        x = self.out(x)
+        return x
+
+class SpeculativeHead(nn.Module):
+    def __init__(self, lm_head, medusa):
+        super().__init__()
+        self.lm_head = lm_head
+        self.medusa = medusa
+
+    @staticmethod
+    def load(config, prefix: str, weights):
+        lm_head = TensorParallelHead.load(config, prefix, weights)
+        use_medusa = config.use_medusa
+        if use_medusa:
+            from pathlib import Path
+            from huggingface_hub import hf_hub_download
+            from text_generation_server.utils.weights import Weights
+            from safetensors import safe_open
+            import json
+            import os
+            is_local_model = (
+                Path(use_medusa).exists() and Path(use_medusa).is_dir()
+            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None
+
+            if not is_local_model:
+                medusa_config = hf_hub_download(
+                    use_medusa, revision=revision, filename="config.json"
+                )
+                medusa_head = hf_hub_download(
+                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
+                )
+            else:
+                medusa_config = str(Path(use_medusa) / "config.json")
+                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")
+
+            with open(medusa_config, "r") as f:
+                config = json.load(f)
+            filename = medusa_head[: -len(".pt")] + ".safetensors"
+            routing = weights.routing
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    weights.routing[k] = filename
+
+            medusa = MedusaModel(config, weights)
+        else:
+            medusa = None
+        return SpeculativeHead(lm_head, medusa)
+
+    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        logits = self.lm_head(input)
+        speculative_logits = self.medusa(input) if self.medusa is not None else None
+        return logits, speculative_logits

 class TensorParallelHead(SuperLayer):
     def __init__(self, linear, process_group, should_gather: bool):
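
SpeculativeHead.load registers every tensor of the Medusa checkpoint in weights.routing so that MedusaModel and MedusaHead can load them through the usual Weights machinery; with the prefixes above, a checkpoint with medusa_num_heads=2 and medusa_num_layers=1 would be expected to expose keys such as 0.0.linear.weight, 0.0.linear.bias, 0.1.weight, 1.0.linear.weight, and so on. Two caveats: load() downloads medusa_lm_head.pt but then opens the .safetensors sibling of that file, and revision is referenced here without being defined inside load() in this temporary patch. A sketch of the one-off conversion that would produce the expected .safetensors file, under the assumption that the .pt file holds a flat tensor state dict with that key layout (not something this patch guarantees):

    # Hypothetical conversion of a Medusa torch checkpoint into the
    # medusa_lm_head.safetensors file that safe_open() above looks for.
    import torch
    from safetensors.torch import save_file

    state_dict = torch.load("medusa_lm_head.pt", map_location="cpu")
    save_file({k: v.contiguous() for k, v in state_dict.items()}, "medusa_lm_head.safetensors")
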
diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py
index 634119cb..9f66ba10 100644
--- a/server/text_generation_server/utils/medusa.py
+++ b/server/text_generation_server/utils/medusa.py
@@ -3,12 +3,6 @@ from dataclasses import dataclass
 from text_generation_server.utils.layers import TensorParallelHead, FastLinear


-@dataclass
-class Output:
-    logits: torch.FloatTensor = None
-    speculative_logits: torch.FloatTensor = None
-
-
 class ResBlock(torch.nn.Module):
     def __init__(self, config, prefix, weights):
         super().__init__()