From 268e8d4935e458dfd5ca22cc004744af63923bc2 Mon Sep 17 00:00:00 2001 From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com> Date: Fri, 12 Apr 2024 20:21:52 +0200 Subject: [PATCH] fix: use get_speculate to set the number of heads --- server/text_generation_server/utils/layers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 8c46ea49..6e4a13cd 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -8,7 +8,8 @@ from typing import List, Tuple, Optional from loguru import logger from functools import lru_cache -# Dummy comment. +from text_generation_server.utils.speculate import get_speculate + HAS_BITS_AND_BYTES = True try: import bitsandbytes as bnb @@ -445,7 +446,7 @@ class MedusaModel(torch.nn.Module): self.heads = torch.nn.ModuleList( [ MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights) - for i in range(medusa_config["medusa_num_heads"]) + for i in range(get_speculate()) ] ) @@ -542,7 +543,7 @@ class MedusaHeadV2(nn.Module): ) routing[k] = filename - self.n_medusa_heads = medusa_config["medusa_num_heads"] + self.n_medusa_heads = get_speculate() assert medusa_config["medusa_num_layers"] == 1 self.linear = TensorParallelColumnLinear.load_multi(