From b9d8af694bb1a712c93869394eded6a5e6305c8f Mon Sep 17 00:00:00 2001
From: Joshua Rosenkranz
Date: Fri, 3 May 2024 10:02:11 -0400
Subject: [PATCH] Clean up the MLP speculator based on PR review comments

- Removed the conditionals from LayerNormParameterized and renamed it to
  MLPSpeculatorLayerNorm.
- Load weights through tensor-parallel modules (work in progress; still
  evaluating whether this is the right approach).
- Fixed an issue with detecting the Medusa model.
- Made weight loading more efficient.
---
 server/text_generation_server/utils/layers.py | 130 +++++++-----------
 1 file changed, 52 insertions(+), 78 deletions(-)

diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 0fccbd75..c3080937 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -442,69 +442,46 @@ class ResBlock(torch.nn.Module):
     def forward(self, x):
         return x + self.act(self.linear(x))
 
-class LayerNormParameterized(nn.Module):
+class MLPSpeculatorLayerNorm(nn.Module):
     """
-    A generalized LayerNorm implementation. With all optional arguments set to True, equivalent to nn.LayerNorm up to epsilon stabilization term
-    (this class divides inputs by min(norm, eps), while nn.LayerNorm divides by norm + eps).
+    An L2 normalization implementation
     ...

     Args
     ----
     normalized_shape : int
         Dimensionality of input data (size of final tensor axis)
+    elementwise_scale_weight : torch.Tensor
+        Learned scaling term applied after normalization.
+    elementwise_shift_bias : torch.Tensor
+        Learned bias term applied after normalization.
     eps : float
         Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8).
-    elementwise_scale : bool
-        Include a learned scaling term after normalization?
-    elementwise_shift : bool
-        Include a learned bias term after normalization?
-    use_mean : bool
-        Recenter inputs around zero before normalizing, or just rescale?
+ """ def __init__( self, normalized_shape, + elementwise_scale_weight: torch.Tensor, + elementwise_shift_bias: torch.Tensor, eps=1e-06, - elementwise_scale=True, - elementwise_shift=False, - use_mean=False, - use_high_precision_pow=False, ): - super(LayerNormParameterized, self).__init__() + super(MLPSpeculatorLayerNorm, self).__init__() self.normalized_shape = normalized_shape + self.weight = nn.Parameter(elementwise_scale_weight) + self.bias = nn.Parameter(elementwise_shift_bias) self.eps = eps - self.elementwise_scale = elementwise_scale - self.elementwise_shift = elementwise_shift - self.use_mean = use_mean - self.use_high_precision_pow = use_high_precision_pow - - if self.elementwise_scale: - self.weight = nn.Parameter(torch.empty(self.normalized_shape)) - if self.elementwise_shift: - self.bias = nn.Parameter(torch.empty(self.normalized_shape)) - - def reset_parameters(self): - if self.elementwise_scale: - self.weight.data.fill_(1) - if self.elementwise_shift: - self.bias.data.zero_() def forward(self, x): - if self.use_mean: - x = x - x.mean(-1, keepdim=True) xf = x - if self.use_high_precision_pow: - xf = x.float() xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) x = xf.type_as(x) - if self.elementwise_scale: - x = self.weight * x - if self.elementwise_shift: - x = x + self.bias + x = self.weight * x + x = x + self.bias return x class MLPSpeculatorModel(torch.nn.Module): - def __init__(self, config, emb, proj, head, ln): + def __init__(self, config, prefix, weights): super().__init__() self.config = config self.n_predict = config.n_predict @@ -512,10 +489,30 @@ class MLPSpeculatorModel(torch.nn.Module): inner_dim = config.inner_dim if config.inner_dim != 0 else config.emb_dim self.inner_dim = inner_dim self.config = config.vocab_size - self.emb = emb - self.proj = proj - self.head = head - self.ln = ln + self.emb = nn.ModuleList( + [TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) for i in range(config.n_predict)] + ) + self.proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.proj.{i}" for i in range(config.n_predict)], + weights=weights, + bias=False, + dim=0 + ) + self.head = nn.ModuleList( + [TensorParallelRowLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) for i in range(config.n_predict)] + ) + self.ln = nn.ModuleList( + [ + MLPSpeculatorLayerNorm( + config.inner_dim, + elementwise_scale_weight=weights.get_tensor(f"{prefix}.ln.{i}.weight"), + elementwise_shift_bias=weights.get_tensor(f"{prefix}.ln.{i}.bias") + ) + for i in range(config.n_predict) + ] + ) + # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation self.state_weight = 0.5 ** (0.5 / self.n_predict) self.emb_weight = math.sqrt(1 - self.state_weight ** 2) @@ -565,36 +562,8 @@ class MLPSpeculatorModel(torch.nn.Module): 1, best_guesses[:, :, None, None].expand(-1, -1, self.n_predict, self.vsize) ) # b n h v - def load(self, config, prefix, weights): - self.emb = nn.ModuleList( - [nn.Embedding(config.vocab_size, config.inner_dim) for _ in range(config.n_predict)] - ) - self.proj = nn.ModuleList( - [ - nn.Linear((config.emb_dim if i == 0 else config.inner_dim), config.inner_dim, bias=False) - for i in range(config.n_predict) - ] - ) - self.head = nn.ModuleList( - [nn.Linear(config.inner_dim, config.vocab_size, bias=False) for _ in range(config.n_predict)] - ) - self.ln = nn.ModuleList( - [ - LayerNormParameterized( - config.inner_dim, elementwise_shift=True, elementwise_scale=True - ) - for _ in range(config.n_predict) 
- ] - ) - for i in range(config.n_predict): - self.emb[i].weight.data.copy_(weights.get_tensor(f"{prefix}.emb.{i}.weight")) - self.proj[i].weight.data.copy_(weights.get_tensor(f"{prefix}.proj.{i}.weight")) - self.ln[i].weight.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.weight")) - self.ln[i].bias.data.copy_(weights.get_tensor(f"{prefix}.ln.{i}.bias")) - self.head[i].weight.data.copy_(weights.get_tensor(f"{prefix}.head.{i}.weight")) - -class MLPSpeculatorHeadV1(nn.Module): +class MLPSpeculatorHead(nn.Module): def __init__(self, lm_head, mlp_speculator): super().__init__() self.lm_head = lm_head @@ -629,9 +598,9 @@ class MLPSpeculatorHeadV1(nn.Module): ) routing[k] = filename - mlp_speculator = MLPSpeculatorModel.load(speculator_config, prefix, weights) + mlp_speculator = MLPSpeculatorModel(speculator_config, "speculator", weights) lm_head = TensorParallelHead.load(speculator_config, prefix, weights) - return MLPSpeculatorHeadV1(lm_head, mlp_speculator) + return MLPSpeculatorHead(lm_head, mlp_speculator) class MedusaModel(torch.nn.Module): @@ -815,12 +784,17 @@ class SpeculativeHead(nn.Module): speculator_config = json.load(f) lm_head = None - architecture = speculator_config["architectures"][0] + # currently medusa does not have an architecture specified, so try-except for now + # this should really be handled in a better way though (maybe the classname can be part of the config) + try: + architecture = speculator_config["architectures"][0] - if architecture == "MLPSpeculatorPreTrainedModel": - speculator_config.use_speculator = config.use_speculator - speculator = MLPSpeculatorHeadV1.load(speculator_config, "speculator", weights) - else: # not sure what medusa name is... + if architecture == "MLPSpeculatorPreTrainedModel": + speculator_config.use_speculator = config.use_speculator + speculator = MLPSpeculatorHead.load(speculator_config, prefix, weights) + else: + speculator = None + except KeyError: try: speculator = MedusaHeadV1.load(config, prefix, weights) except: