From cf16172a85b90ad0092001d45d5e712e4622bbfc Mon Sep 17 00:00:00 2001
From: Daniël de Kok
Date: Thu, 18 Jul 2024 15:15:57 +0000
Subject: [PATCH] Exclude non-MLP layers when using FP8 quantization with Llama

---
 .../models/test_completion_prompts.py        |  2 +
 .../custom_modeling/flash_llama_modeling.py  | 62 +++++++++++++------
 2 files changed, 46 insertions(+), 18 deletions(-)

diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py
index 0efb6693..d787873b 100644
--- a/integration-tests/models/test_completion_prompts.py
+++ b/integration-tests/models/test_completion_prompts.py
@@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
                 chunk = [c.replace("data:", "") for c in chunk]
                 # remove empty strings
                 chunk = [c for c in chunk if c]
+                # remove completion marking chunk
+                chunk = [c for c in chunk if c != " [DONE]"]
                 # parse json
                 chunk = [json.loads(c) for c in chunk]
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 78832341..5237a484 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import contextmanager
 from typing import List, Optional, Tuple
 
 import torch
@@ -25,7 +26,6 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional, List, Tuple
 
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
@@ -42,10 +42,16 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    Weights,
+)
 
 if SYSTEM == "rocm":
     try:
@@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id):
     )
 
 
+@contextmanager
+def no_fp8(weights: Weights):
+    weights_loader = weights.weights_loader
+    if (
+        isinstance(weights_loader, DefaultWeightsLoader)
+        and weights_loader.weight_class is Fp8Weight
+    ):
+        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
+
+    with weights.use_loader(weights_loader):
+        yield
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
@@ -330,12 +349,15 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
     def __init__(self, index, prefix, config, weights):
         super().__init__()
-        self.self_attn = FlashLlamaAttention(
-            index=index,
-            prefix=f"{prefix}.self_attn",
-            config=config,
-            weights=weights,
-        )
+
+        with no_fp8(weights):
+            self.self_attn = FlashLlamaAttention(
+                index=index,
+                prefix=f"{prefix}.self_attn",
+                config=config,
+                weights=weights,
+            )
+
         self.mlp = LlamaMLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
         )
@@ -470,23 +492,27 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
 
-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=(
-                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
-            ),
-            weights=weights,
-        )
+        with no_fp8(weights):
+            self.embed_tokens = TensorParallelEmbedding(
+                prefix=(
+                    "model.embed_tokens"
+                    if not prefix
+                    else f"{prefix}.model.embed_tokens"
+                ),
+                weights=weights,
+            )
         self.model = FlashLlamaModel(prefix, config, weights)
         if config.tie_word_embeddings:
            suffix = "model.embed_tokens"
        else:
            suffix = "lm_head"
 
-        self.lm_head = SpeculativeHead.load(
-            config,
-            prefix=suffix if not prefix else f"{prefix}.{suffix}",
-            weights=weights,
-        )
+        with no_fp8(weights):
+            self.lm_head = SpeculativeHead.load(
+                config,
+                prefix=suffix if not prefix else f"{prefix}.{suffix}",
+                weights=weights,
+            )
 
     def forward(
         self,
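
As a rough illustration of the pattern this patch introduces, the standalone sketch below mimics the no_fp8 context manager using made-up stand-in classes. WeightsStub, LoaderStub, Fp8Stub, and UnquantizedStub are hypothetical placeholders for TGI's Weights, DefaultWeightsLoader, Fp8Weight, and UnquantizedWeight; the sketch shows only the control flow (temporarily swapping the active weights loader), not the actual text_generation_server API.

from contextlib import contextmanager
from dataclasses import dataclass


class Fp8Stub:
    """Stand-in for an FP8 weight class (hypothetical)."""


class UnquantizedStub:
    """Stand-in for an unquantized weight class (hypothetical)."""


@dataclass
class LoaderStub:
    """Stand-in loader: records which weight class it would produce."""

    weight_class: type


class WeightsStub:
    """Stand-in weights container: holds the active loader and can swap it."""

    def __init__(self, loader: LoaderStub):
        self.weights_loader = loader

    @contextmanager
    def use_loader(self, loader: LoaderStub):
        # Swap the active loader for the duration of the block, then restore it.
        old = self.weights_loader
        self.weights_loader = loader
        try:
            yield
        finally:
            self.weights_loader = old


@contextmanager
def no_fp8(weights: WeightsStub):
    # Same shape as the patch's no_fp8: if the current loader would produce
    # FP8 weights, substitute an unquantized loader while the block runs.
    loader = weights.weights_loader
    if loader.weight_class is Fp8Stub:
        loader = LoaderStub(UnquantizedStub)
    with weights.use_loader(loader):
        yield


if __name__ == "__main__":
    weights = WeightsStub(LoaderStub(Fp8Stub))
    with no_fp8(weights):
        # In the patch, attention, embedding, and lm_head construction happen
        # here; their weights load unquantized even under FP8 quantization.
        assert weights.weights_loader.weight_class is UnquantizedStub
    # Outside the block the FP8 loader is restored, so MLP weights stay FP8.
    assert weights.weights_loader.weight_class is Fp8Stub
    print("no_fp8 sketch behaves as expected")

The design keeps the quantization policy inside the weights loader rather than inside each module, so excluding a layer group from FP8 amounts to wrapping its constructor in a single with block, as the patch does for self_attn, embed_tokens, and lm_head.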