Exclude non-MLP layers when using FP8 quantization with Llama

Daniël de Kok 2024-07-18 15:15:57 +00:00
parent a93b2b5083
commit cf16172a85
2 changed files with 46 additions and 18 deletions
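
The model changes below wrap the attention, embedding, and lm_head weight loads in a no_fp8 context manager that temporarily swaps an FP8 weights loader for an unquantized one, so that only the MLP weights remain FP8-quantized. What follows is a minimal, self-contained sketch of that pattern; DefaultWeightsLoader, UnquantizedWeight, Fp8Weight, and Weights.use_loader are mocked stand-ins here, not the actual text-generation-server implementations.

from contextlib import contextmanager


# Simplified stand-ins for the text_generation_server weight classes; the
# names mirror the diff below, but the bodies are mocked for illustration.
class UnquantizedWeight:
    pass


class Fp8Weight:
    pass


class DefaultWeightsLoader:
    def __init__(self, weight_class):
        self.weight_class = weight_class


class Weights:
    def __init__(self, weights_loader):
        self.weights_loader = weights_loader

    @contextmanager
    def use_loader(self, weights_loader):
        # Temporarily swap the active loader, restoring it afterwards.
        old_loader = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            self.weights_loader = old_loader


@contextmanager
def no_fp8(weights: Weights):
    # If the current loader would produce FP8 weights, substitute an
    # unquantized loader for the duration of the block.
    weights_loader = weights.weights_loader
    if (
        isinstance(weights_loader, DefaultWeightsLoader)
        and weights_loader.weight_class is Fp8Weight
    ):
        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
    with weights.use_loader(weights_loader):
        yield


# In the model code, the attention, embedding, and lm_head loads are wrapped
# like this, so only the MLP weights stay FP8-quantized.
weights = Weights(DefaultWeightsLoader(Fp8Weight))
with no_fp8(weights):
    assert weights.weights_loader.weight_class is UnquantizedWeight
assert weights.weights_loader.weight_class is Fp8Weight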

View File

@@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# remove completion marking chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json
chunk = [json.loads(c) for c in chunk]
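
The test change above drops the final " [DONE]" marker that the streaming completions endpoint sends before the chunks are JSON-parsed. Below is a standalone sketch of the same parsing steps; the sample payload is illustrative, not taken from the test.

import json

# Illustrative server-sent-events payload; the real chunks come from the
# streaming completions endpoint.
raw = (
    'data: {"choices": [{"text": "Hello"}]}\n\n'
    'data: {"choices": [{"text": " world"}]}\n\n'
    "data: [DONE]\n\n"
)

chunk = raw.split("\n\n")
chunk = [c.replace("data:", "") for c in chunk]  # strip the SSE "data:" prefix
chunk = [c for c in chunk if c]  # remove empty strings
chunk = [c for c in chunk if c != " [DONE]"]  # remove the completion marking chunk
chunk = [json.loads(c) for c in chunk]  # parse json
print(chunk)  # two parsed completion chunks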

View File

@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import List, Optional, Tuple
import torch
@@ -25,7 +26,6 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
@@ -42,10 +42,16 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
FastRMSNorm,
)
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
)
if SYSTEM == "rocm": if SYSTEM == "rocm":
try: try:
@@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id):
)
@contextmanager
def no_fp8(weights: Weights):
weights_loader = weights.weights_loader
if (
isinstance(weights_loader, DefaultWeightsLoader)
and weights_loader.weight_class is Fp8Weight
):
weights_loader = DefaultWeightsLoader(UnquantizedWeight)
with weights.use_loader(weights_loader):
yield
class FlashLlamaAttention(torch.nn.Module):
def __init__(
self,
@@ -330,12 +349,15 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights):
super().__init__()
with no_fp8(weights):
self.self_attn = FlashLlamaAttention(
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
)
@@ -470,9 +492,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding(
prefix=(
"model.embed_tokens"
if not prefix
else f"{prefix}.model.embed_tokens"
),
weights=weights,
)
@@ -482,6 +507,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else:
suffix = "lm_head"
with no_fp8(weights):
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",