diff --git a/Dockerfile b/Dockerfile index f2a7f9a1..6360ab06 100644 --- a/Dockerfile +++ b/Dockerfile @@ -165,6 +165,7 @@ COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bi # Install launcher COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher +RUN apt update && apt install build-essential g++ -y COPY proto proto COPY server/requirements.txt server/requirements.txt COPY server/pyproject.toml server/pyproject.toml @@ -176,7 +177,6 @@ COPY server/Makefile-transformers server/Makefile-transformers RUN cd server && \ make gen-server && \ pip install ".[bnb, accelerate]" --no-cache-dir -RUN apt update && apt install build-essential g++ -y # AWS Sagemaker compatbile image FROM base as sagemaker diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 56aa6b0d..c19641e8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -38,6 +38,7 @@ from flash_attn.layers.rotary import RotaryEmbedding # from safetensors.torch import load_file from safetensors import safe_open from huggingface_hub import hf_hub_download +from loguru import logger HAS_BITS_AND_BYTES = True try: @@ -197,29 +198,30 @@ class FastLinear(nn.Linear): with safe_open(filename, framework="pt", device=f"cuda:{rank}") as f: if name == 'self_attn.query_key_value': query_name = f'model.layers.{layer}.self_attn' - self.qlinear.qweight[:, : self.out_features // 3] = get_slice(f, f"{query_name}.q_proj.qweight") - self.qlinear.qweight[:, self.out_features // 3:-self.out_features // 3] = get_slice(f, f"{query_name}.k_proj.qweight") - self.qlinear.qweight[:,-self.out_features // 3: ] = get_slice(f, f"{query_name}.v_proj.qweight") + N = self.out_features // 3 + self.qlinear.qweight[:, : N] = get_slice(f, f"{query_name}.q_proj.qweight") + self.qlinear.qweight[:, N:2 * N] = get_slice(f, f"{query_name}.k_proj.qweight") + self.qlinear.qweight[:, 2*N:] = get_slice(f, f"{query_name}.v_proj.qweight") - N = self.qlinear.qzeros.shape[1] - self.qlinear.qzeros[:, : N // 3] = get_slice(f, f"{query_name}.q_proj.qzeros") - self.qlinear.qzeros[:, N // 3:-N // 3] = get_slice(f, f"{query_name}.k_proj.qzeros") - self.qlinear.qzeros[:,-N // 3: ] = get_slice(f, f"{query_name}.v_proj.qzeros") + self.qlinear.scales[:, : N] = get_slice(f, f"{query_name}.q_proj.scales") + self.qlinear.scales[:, N:2*N] = get_slice(f, f"{query_name}.k_proj.scales") + self.qlinear.scales[:, 2*N:] = get_slice(f, f"{query_name}.v_proj.scales") + + N = self.qlinear.qzeros.shape[1] // 3 + self.qlinear.qzeros[:, :N] = get_slice(f, f"{query_name}.q_proj.qzeros") + self.qlinear.qzeros[:, N:2*N] = get_slice(f, f"{query_name}.k_proj.qzeros") + self.qlinear.qzeros[:, 2*N:] = get_slice(f, f"{query_name}.v_proj.qzeros") - self.qlinear.scales[:, : self.out_features // 3] = get_slice(f, f"{query_name}.q_proj.scales") - self.qlinear.scales[:, self.out_features // 3:-self.out_features // 3] = get_slice(f, f"{query_name}.k_proj.scales") - self.qlinear.scales[:,-self.out_features // 3: ] = get_slice(f, f"{query_name}.v_proj.scales") torch.testing.assert_close(f.get_tensor(f"{query_name}.q_proj.g_idx"), f.get_tensor(f"{query_name}.k_proj.g_idx")) torch.testing.assert_close(f.get_tensor(f"{query_name}.q_proj.g_idx"), f.get_tensor(f"{query_name}.v_proj.g_idx")) self.qlinear.g_idx[:] = f.get_tensor(f"{query_name}.q_proj.g_idx") elif name == "self_attn.o_proj": - self.qlinear.qweight = f.get_tensor(f"model.layers.{layer}.self_attn.o_proj.qweight") + self.qlinear.qweight = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.qweight") self.qlinear.qzeros = f.get_tensor(f"model.layers.{layer}.self_attn.o_proj.qzeros") self.qlinear.scales = f.get_tensor(f"model.layers.{layer}.self_attn.o_proj.scales") - self.qlinear.g_idx[:] = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.g_idx") - import ipdb;ipdb.set_trace() + self.qlinear.g_idx = get_slice(f, f"model.layers.{layer}.self_attn.o_proj.g_idx") elif name == "mlp.gate_up_proj": N = self.qlinear.qweight.shape[1] // 2 @@ -233,21 +235,16 @@ class FastLinear(nn.Linear): self.qlinear.g_idx[:] = f.get_tensor(f"model.layers.{layer}.mlp.gate_proj.g_idx") N = self.qlinear.qzeros.shape[1] // 2 - self.qlinear.qzeros[:, N:] = get_slice(f, f"model.layers.{layer}.mlp.up_proj.qzeros") self.qlinear.qzeros[:, :N] = get_slice(f, f"model.layers.{layer}.mlp.gate_proj.qzeros") + self.qlinear.qzeros[:, N:] = get_slice(f, f"model.layers.{layer}.mlp.up_proj.qzeros") elif name == "mlp.down_proj": - self.qlinear.qweight = f.get_tensor(f"model.layers.{layer}.mlp.down_proj.qweight") + self.qlinear.qweight = get_slice(f, f"model.layers.{layer}.mlp.down_proj.qweight") self.qlinear.qzeros = f.get_tensor(f"model.layers.{layer}.mlp.down_proj.qzeros") self.qlinear.scales = f.get_tensor(f"model.layers.{layer}.mlp.down_proj.scales") - self.qlinear.g_idx[:] = get_slice(f, f"model.layers.{layer}.mlp.down_proj.g_idx") + self.qlinear.g_idx = get_slice(f, f"model.layers.{layer}.mlp.down_proj.g_idx") else: raise ValueError("Not handled") - print(layer, name) - if name == 'self_attn.query_key_value': - out = self.qlinear(torch.zeros((6, self.in_features)).cuda().half()) - if name == "self_attn.o_proj": - out = self.qlinear(torch.zeros((6, self.in_features)).cuda().half()) # Delete reference to data self.weight = None diff --git a/server/text_generation_server/quant/quant_linear.py b/server/text_generation_server/quant/quant_linear.py index 714c9296..3a181dfa 100644 --- a/server/text_generation_server/quant/quant_linear.py +++ b/server/text_generation_server/quant/quant_linear.py @@ -372,6 +372,7 @@ class QuantLinear(nn.Module): def forward(self, x): out_shape = x.shape[:-1] + (self.outfeatures, ) + assert x.shape[-1] == self.infeatures out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq) out = out + self.bias if self.bias is not None else out return out.reshape(out_shape)