Exclude non-MLP layers when using FP8 quantization with Llama

Daniël de Kok 2024-07-18 15:15:57 +00:00
parent a93b2b5083
commit cf16172a85
2 changed files with 46 additions and 18 deletions
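The core of the change is a small `no_fp8` context manager that temporarily swaps an FP8 default weights loader for an unquantized one while attention, embedding, and lm_head weights are loaded, so that only the MLP layers stay FP8-quantized. The sketch below illustrates the swap-and-restore pattern with stand-in classes; the real `Weights.use_loader` helper is not shown in this diff, so its install/restore behavior here is an assumption.

```python
from contextlib import contextmanager


class _StubLoader:
    """Stand-in for DefaultWeightsLoader; only the weight_class attribute matters here."""

    def __init__(self, weight_class):
        self.weight_class = weight_class


class _StubWeights:
    """Stand-in for the Weights object, with an assumed use_loader helper."""

    def __init__(self, weights_loader):
        self.weights_loader = weights_loader

    @contextmanager
    def use_loader(self, weights_loader):
        # Assumed semantics: install the given loader, restore the previous one afterwards.
        previous = self.weights_loader
        self.weights_loader = weights_loader
        try:
            yield
        finally:
            self.weights_loader = previous


class Fp8Weight: ...          # placeholder for the real FP8 weight class
class UnquantizedWeight: ...  # placeholder for the unquantized weight class


@contextmanager
def no_fp8(weights):
    # Mirrors the diff: only bypass FP8 when the default loader is an FP8 loader.
    loader = weights.weights_loader
    if isinstance(loader, _StubLoader) and loader.weight_class is Fp8Weight:
        loader = _StubLoader(UnquantizedWeight)
    with weights.use_loader(loader):
        yield


weights = _StubWeights(_StubLoader(Fp8Weight))
with no_fp8(weights):
    # attention / embeddings / lm_head would be loaded here, unquantized
    assert weights.weights_loader.weight_class is UnquantizedWeight
# outside the context manager, MLP weights still load through the FP8 loader
assert weights.weights_loader.weight_class is Fp8Weight
```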

View File

@@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
chunk = [c.replace("data:", "") for c in chunk]
# remove empty strings
chunk = [c for c in chunk if c]
# remove completion marking chunk
chunk = [c for c in chunk if c != " [DONE]"]
# parse json
chunk = [json.loads(c) for c in chunk]
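For context, the test change accounts for the terminating `data: [DONE]` chunk that OpenAI-compatible streaming endpoints emit, which is not valid JSON and must be dropped before parsing. A minimal, self-contained illustration follows; the sample payload is made up for this sketch.

```python
import json

# Hypothetical raw SSE payload as it might arrive from a /v1/completions stream.
raw = (
    'data: {"choices": [{"text": "Hello"}]}\n\n'
    'data: {"choices": [{"text": " world"}]}\n\n'
    "data: [DONE]\n\n"
)

chunk = raw.split("\n\n")
chunk = [c.replace("data:", "") for c in chunk]  # strip the SSE "data:" prefix
chunk = [c for c in chunk if c]                  # drop empty strings between events
chunk = [c for c in chunk if c != " [DONE]"]     # drop the completion marker, which is not JSON
chunk = [json.loads(c) for c in chunk]           # every remaining chunk now parses cleanly

print(chunk[0]["choices"][0]["text"])  # -> "Hello"
```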

View File

@@ -18,6 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import List, Optional, Tuple
import torch
@@ -25,7 +26,6 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import (
@@ -42,10 +42,16 @@ from text_generation_server.layers import (
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
    FastRMSNorm,
)
from text_generation_server.utils.weights import (
    DefaultWeightsLoader,
    UnquantizedWeight,
    Weights,
)

if SYSTEM == "rocm":
    try:
@@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id):
    )


@contextmanager
def no_fp8(weights: Weights):
    weights_loader = weights.weights_loader
    if (
        isinstance(weights_loader, DefaultWeightsLoader)
        and weights_loader.weight_class is Fp8Weight
    ):
        weights_loader = DefaultWeightsLoader(UnquantizedWeight)

    with weights.use_loader(weights_loader):
        yield


class FlashLlamaAttention(torch.nn.Module):
    def __init__(
        self,
@@ -330,12 +349,15 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
    def __init__(self, index, prefix, config, weights):
        super().__init__()

        with no_fp8(weights):
            self.self_attn = FlashLlamaAttention(
                index=index,
                prefix=f"{prefix}.self_attn",
                config=config,
                weights=weights,
            )

        self.mlp = LlamaMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
        )
@@ -470,9 +492,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()

        with no_fp8(weights):
            self.embed_tokens = TensorParallelEmbedding(
                prefix=(
                    "model.embed_tokens"
                    if not prefix
                    else f"{prefix}.model.embed_tokens"
                ),
                weights=weights,
            )
@@ -482,6 +507,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
        else:
            suffix = "lm_head"

        with no_fp8(weights):
            self.lm_head = SpeculativeHead.load(
                config,
                prefix=suffix if not prefix else f"{prefix}.{suffix}",
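As a follow-up note, `no_fp8` only takes effect when the default loader was configured for FP8 in the first place. The sketch below shows one way such a loader might be selected from a quantization flag; this wiring is an assumption for illustration and is not part of this diff.

```python
from typing import Optional

from text_generation_server.layers.fp8 import Fp8Weight
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight


def pick_default_loader(quantize: Optional[str]):
    # Hypothetical helper: choose the default weights loader from the quantize flag.
    # With "fp8", every layer would be FP8-quantized unless its load is wrapped in
    # no_fp8, as this commit does for attention, embeddings, and the lm_head.
    if quantize == "fp8":
        return DefaultWeightsLoader(Fp8Weight)
    return DefaultWeightsLoader(UnquantizedWeight)
```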