Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-22 15:32:08 +00:00
hotfix: mixtral
commit 30620a9a44
parent ad9d6288c8
@@ -464,9 +464,9 @@ class DenseMoE(nn.Module):
 
 
 class MixtralLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
 
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
@@ -525,16 +525,20 @@ class MixtralLayer(nn.Module):
 
 
 class MixtralModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=(
+                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+            ),
+            weights=weights,
         )
 
         self.layers = nn.ModuleList(
             [
                 MixtralLayer(
+                    "model" if not prefix else f"{prefix}.model",
                     layer_id,
                     config,
                     weights,
@@ -543,7 +547,9 @@ class MixtralModel(torch.nn.Module):
             ]
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
        )
 
         self.head_size = self.layers[0].self_attn.head_size
@@ -593,13 +599,13 @@ class MixtralModel(torch.nn.Module):
 
 
 class FlashMixtralForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
-        self.model = MixtralModel(config, weights)
+        self.model = MixtralModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
         self.max_past = config.sliding_window
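Taken together, the hunks thread an optional prefix argument through MixtralLayer, MixtralModel, and FlashMixtralForCausalLM, so the same modeling code can load its weights either at the top level (empty prefix) or namespaced under a parent module. Below is a minimal sketch of the resulting weight-name resolution, assuming TGI-style safetensors names such as "model.norm" and "lm_head"; the maybe_prefix helper and the "language_model" parent prefix are illustrative, not part of the commit.

# Sketch of the conditional-prefix pattern the commit applies at each load
# site; maybe_prefix is a hypothetical helper, not code from the diff above.
def maybe_prefix(prefix: str, name: str) -> str:
    # Empty parent prefix: keep the standalone weight name. Otherwise,
    # namespace it under the parent, mirroring expressions like
    # `"model.norm" if not prefix else f"{prefix}.model.norm"` in the diff.
    return name if not prefix else f"{prefix}.{name}"

# Standalone Mixtral checkpoint: weights live at "model.*" and "lm_head".
assert maybe_prefix("", "model.norm") == "model.norm"
assert maybe_prefix("", "lm_head") == "lm_head"

# Mixtral nested under a parent module (illustrative prefix):
assert maybe_prefix("language_model", "model.norm") == "language_model.model.norm"
assert maybe_prefix("language_model", "lm_head") == "language_model.lm_head"

Note the asymmetry in MixtralLayer: it now receives the already-resolved parent prefix ("model" or f"{prefix}.model") and appends only ".layers.{layer_id}", so the conditional is evaluated once in MixtralModel rather than once per layer.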