From c3954319996b4833ace90fa5fcbecfe5a3e8f994 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 13 May 2024 13:46:29 +0200
Subject: [PATCH] Granite support? (#1882)

Fixes # (issue)

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 server/poetry.lock                                  |  3 +-
 server/requirements_cuda.txt                        |  8 +--
 server/requirements_rocm.txt                        |  8 +--
 .../custom_modeling/flash_llama_modeling.py         | 57 +++++++------------
 4 files changed, 31 insertions(+), 45 deletions(-)

diff --git a/server/poetry.lock b/server/poetry.lock
index cdbbd581..70e51d64 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "accelerate"
@@ -1585,6 +1585,7 @@ description = "Nvidia JIT LTO Library"
 optional = false
 python-versions = ">=3"
 files = [
+    {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"},
     {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"},
     {file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"},
 ]
diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt
index c2714764..7f0efded 100644
--- a/server/requirements_cuda.txt
+++ b/server/requirements_cuda.txt
@@ -11,7 +11,7 @@ googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version <
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
 grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -32,15 +32,15 @@ prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
-regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.5.10 ; python_version >= "3.9" and python_version < "3.13"
python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index c2714764..7f0efded 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -11,7 +11,7 @@ googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.62.2 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.62.2 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.62.2 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.63.0 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" idna==3.7 ; python_version >= "3.9" and python_version < "3.13" @@ -32,15 +32,15 @@ prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" -regex==2024.4.28 ; python_version >= "3.9" and python_version < "3.13" +regex==2024.5.10 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.4.3 ; python_version >= "3.9" and python_version < "3.13" scipy==1.13.0 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.5.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.40.1 ; python_version >= "3.9" and python_version < "3.13" +tqdm==4.66.4 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.40.2 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.11.0 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index a7969494..6a6b2e0a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ 
@@ -23,7 +23,6 @@ import torch.distributed
 
 from torch import nn
 from transformers.activations import ACT2FN
-from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 from text_generation_server.utils import paged_attention, flash_attn
@@ -32,7 +31,6 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    get_linear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -41,22 +39,29 @@ from text_generation_server.layers.layernorm import (
 
 
 def load_attention(config, prefix, weights):
+    bias = config.attention_bias
     if config.num_attention_heads != config.num_key_value_heads:
-        return _load_gqa(config, prefix, weights)
+        return TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+            dim=0,
+            weights=weights,
+            bias=bias,
+        )
     else:
         if config.model_type == "baichuan":
             return TensorParallelColumnLinear.load_qkv(
                 config,
                 prefix=f"{prefix}.W_pack",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         elif config.model_type == "phi3":
             return TensorParallelColumnLinear.load_qkv(
                 config,
                 prefix=f"{prefix}.qkv_proj",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         else:
             return TensorParallelColumnLinear.load_multi(
@@ -64,36 +69,10 @@ def load_attention(config, prefix, weights):
             prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
             dim=0,
             weights=weights,
-            bias=False,
+            bias=bias,
         )
 
 
-def _load_gqa(config, prefix: str, weights):
-    assert config.hidden_size % config.num_attention_heads == 0
-    assert config.num_attention_heads % weights.process_group.size() == 0
-
-    weight = weights.get_multi_weights_col(
-        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
-        quantize=config.quantize,
-        dim=0,
-    )
-
-    if config.quantize not in ["gptq", "awq"]:
-        weight = weight.to(dtype=weights.dtype).to(device=weights.device)
-
-    head_size = config.hidden_size // config.num_attention_heads
-    num_heads = config.num_attention_heads // weights.process_group.size()
-    num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
-    assert list(weight.shape) == [
-        (num_heads + 2 * num_key_value_heads) * head_size,
-        config.hidden_size,
-    ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
-
-    return TensorParallelColumnLinear(
-        get_linear(weight, bias=None, quantize=config.quantize)
-    )
-
-
 class FlashLlamaAttention(torch.nn.Module):
     def __init__(
         self,
@@ -214,12 +193,13 @@ class LlamaMLP(nn.Module):
             )
         )
         # Fuse gate and up proj
+        bias = getattr(config, "mlp_bias", False)
         if config.model_type == "phi3":
             self.gate_up_proj = TensorParallelColumnLinear.load_gate_up(
                 config,
                 prefix=f"{prefix}.gate_up_proj",
                 weights=weights,
-                bias=False,
+                bias=bias,
             )
         else:
             self.gate_up_proj = TensorParallelColumnLinear.load_multi(
@@ -227,13 +207,13 @@ class LlamaMLP(nn.Module):
                 prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
                 weights=weights,
                 dim=0,
-                bias=False,
+                bias=bias,
             )
         self.down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
-            bias=False,
+            bias=bias,
         )
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
@@ -385,9 +365,14 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             weights=weights,
         )
         self.model = FlashLlamaModel(prefix, config, weights)
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
+        else:
+            suffix = "lm_head"
+
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
+            prefix=suffix if not prefix else f"{prefix}.{suffix}",
             weights=weights,
         )
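
For reviewers: below is a minimal sketch, not part of the patch, of the behavior the last hunk introduces. It shows how the `lm_head` weight prefix is resolved when a checkpoint ties its word embeddings (as Granite-style configs with `tie_word_embeddings` do), while the other hunks thread `config.attention_bias` and a `getattr(config, "mlp_bias", False)` default into the QKV/MLP loaders instead of hard-coding `bias=False`. The helper name `resolve_lm_head_prefix` and the `SimpleNamespace` stand-in config are illustrative only; in the patch this logic is inlined in `FlashLlamaForCausalLM.__init__` and the result is passed to `SpeculativeHead.load`.

```python
# Illustrative only: mirrors the prefix selection added in
# FlashLlamaForCausalLM.__init__; not code taken from the patch itself.
from types import SimpleNamespace


def resolve_lm_head_prefix(config, prefix: str = "") -> str:
    """Pick the weight prefix the speculative head should load from."""
    # With tied word embeddings, the output projection reuses the input
    # embedding weights, so the head is loaded from model.embed_tokens
    # instead of a dedicated lm_head tensor.
    suffix = "model.embed_tokens" if config.tie_word_embeddings else "lm_head"
    return suffix if not prefix else f"{prefix}.{suffix}"


# Tied embeddings resolve to the embedding weights.
assert resolve_lm_head_prefix(SimpleNamespace(tie_word_embeddings=True)) == "model.embed_tokens"
# Untied weights behind a wrapper prefix keep the classic lm_head location.
assert resolve_lm_head_prefix(SimpleNamespace(tie_word_embeddings=False), "transformer") == "transformer.lm_head"
```

The second assertion also shows why the replacement line interpolates both `prefix` and `suffix`: a literal `"{prefix}.suffix"` would point at a weight key that does not exist in the checkpoint.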