Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-12 04:44:52 +00:00)
Exclude non-MLP layers when using FP8 quantization with Llama
parent a93b2b5083
commit cf16172a85
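This change introduces a no_fp8 context manager in the flash Llama modeling code: when the active weights loader is a DefaultWeightsLoader producing Fp8Weight, it is temporarily swapped for an unquantized loader. The layers constructed inside the context manager (the attention projections, the token embeddings, and the lm_head) are thereby excluded from FP8 quantization; only the MLP layers keep loading FP8 weights.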
integration-tests/models/test_completion_prompts.py

@@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream(
         chunk = [c.replace("data:", "") for c in chunk]
         # remove empty strings
         chunk = [c for c in chunk if c]
+        # remove completion marking chunk
+        chunk = [c for c in chunk if c != " [DONE]"]
         # parse json
         chunk = [json.loads(c) for c in chunk]
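For context, the test's cleanup can be read as a small standalone helper. The following sketch is illustrative only (parse_sse_chunks and the sample payload are hypothetical, not repository code); it strips the SSE "data:" prefix, drops empty strings and the [DONE] terminator, and parses the remaining events as JSON:

import json

def parse_sse_chunks(raw: str) -> list[dict]:
    # Split the raw stream into events and strip the SSE "data:" prefix.
    # Note: the test keeps the leading space and compares against " [DONE]";
    # here we strip whitespace first and compare against "[DONE]".
    events = [e.replace("data:", "").strip() for e in raw.split("\n\n")]
    # Drop empty strings and the completion marker, then parse JSON.
    return [json.loads(e) for e in events if e and e != "[DONE]"]

# Hypothetical payload shaped like an OpenAI-compatible completion stream:
raw = 'data: {"choices": [{"text": "Hello"}]}\n\ndata: [DONE]\n\n'
print(parse_sse_chunks(raw))  # [{'choices': [{'text': 'Hello'}]}]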
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from contextlib import contextmanager
 from typing import List, Optional, Tuple

 import torch
@@ -25,7 +26,6 @@ import torch.distributed

 from torch import nn
 from transformers.activations import ACT2FN
-from typing import Optional, List, Tuple

 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
@@ -42,10 +42,16 @@ from text_generation_server.layers import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+from text_generation_server.layers.fp8 import Fp8Weight
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.utils.weights import (
+    DefaultWeightsLoader,
+    UnquantizedWeight,
+    Weights,
+)

 if SYSTEM == "rocm":
     try:
@@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id):
     )


+@contextmanager
+def no_fp8(weights: Weights):
+    weights_loader = weights.weights_loader
+    if (
+        isinstance(weights_loader, DefaultWeightsLoader)
+        and weights_loader.weight_class is Fp8Weight
+    ):
+        weights_loader = DefaultWeightsLoader(UnquantizedWeight)
+
+    with weights.use_loader(weights_loader):
+        yield
+
+
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
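Weights.use_loader is not part of this diff; assuming it installs the given loader for the duration of the with block and restores the previous one on exit, the pattern no_fp8 builds on looks roughly like the toy sketch below (ToyWeights is a made-up stand-in, not TGI code):

from contextlib import contextmanager

class ToyWeights:
    # Made-up stand-in for text_generation_server.utils.weights.Weights.
    def __init__(self, loader):
        self.weights_loader = loader

    @contextmanager
    def use_loader(self, loader):
        # Temporarily install `loader`, restoring the old one on exit.
        old = self.weights_loader
        self.weights_loader = loader
        try:
            yield
        finally:
            self.weights_loader = old

weights = ToyWeights("fp8")
with weights.use_loader("unquantized"):
    print(weights.weights_loader)  # unquantized: non-MLP layers load here
print(weights.weights_loader)  # fp8 again: MLP weights keep the FP8 loader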
@@ -330,12 +349,15 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
     def __init__(self, index, prefix, config, weights):
         super().__init__()
-        self.self_attn = FlashLlamaAttention(
-            index=index,
-            prefix=f"{prefix}.self_attn",
-            config=config,
-            weights=weights,
-        )
+
+        with no_fp8(weights):
+            self.self_attn = FlashLlamaAttention(
+                index=index,
+                prefix=f"{prefix}.self_attn",
+                config=config,
+                weights=weights,
+            )
+
         self.mlp = LlamaMLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
         )
@@ -470,9 +492,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()

-        self.embed_tokens = TensorParallelEmbedding(
-            prefix=(
-                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
-            ),
-            weights=weights,
-        )
+        with no_fp8(weights):
+            self.embed_tokens = TensorParallelEmbedding(
+                prefix=(
+                    "model.embed_tokens"
+                    if not prefix
+                    else f"{prefix}.model.embed_tokens"
+                ),
+                weights=weights,
+            )
@@ -482,6 +507,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         else:
             suffix = "lm_head"

-        self.lm_head = SpeculativeHead.load(
-            config,
-            prefix=suffix if not prefix else f"{prefix}.{suffix}",
+        with no_fp8(weights):
+            self.lm_head = SpeculativeHead.load(
+                config,
+                prefix=suffix if not prefix else f"{prefix}.{suffix}",
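Net effect: when FP8 quantization is enabled, the attention projections, the token embeddings, and the lm_head are loaded unquantized through no_fp8, while LlamaMLP continues to use the FP8 weights loader.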