From da4199ed970ad5db0bf62b921533f1d1558305d0 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Fri, 22 Mar 2024 17:59:25 +0100 Subject: [PATCH] feat: cohere (#1660) --- server/Makefile | 2 +- server/poetry.lock | 10 +- server/requirements_common.txt | 46 -- server/requirements_cuda.txt | 25 +- server/requirements_rocm.txt | 24 +- .../custom_modeling/flash_cohere_modeling.py | 461 ++++++++++++++++++ .../custom_modeling/flash_gemma_modeling.py | 161 ------ .../models/flash_cohere.py | 75 +++ .../models/flash_gemma.py | 2 +- 9 files changed, 567 insertions(+), 239 deletions(-) delete mode 100644 server/requirements_common.txt create mode 100644 server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py create mode 100644 server/text_generation_server/models/flash_cohere.py diff --git a/server/Makefile b/server/Makefile index d0fdbaad..04919a59 100644 --- a/server/Makefile +++ b/server/Makefile @@ -36,4 +36,4 @@ update-lock: poetry lock --no-update export-requirements: - poetry export -f requirements.txt --without-hashes --output requirements.txt + poetry export -o requirements.txt --without-hashes diff --git a/server/poetry.lock b/server/poetry.lock index ade73f03..f7d40699 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1830,13 +1830,13 @@ tests = ["GitPython", "datasets", "optuna", "parameterized", "psutil", "pytest ( [[package]] name = "outlines" -version = "0.0.27" +version = "0.0.36" description = "Probabilistic Generative Model Programming" optional = true python-versions = ">=3.8" files = [ - {file = "outlines-0.0.27-py3-none-any.whl", hash = "sha256:dd614f49760ff8870a5d491fad4a372d7b7d4da5c1646f1b42f12a9d5e34db4b"}, - {file = "outlines-0.0.27.tar.gz", hash = "sha256:debc49f0db4d968eea05a4a6134516b3e49c6c8607106aa097410a4b53d5af6c"}, + {file = "outlines-0.0.36-py3-none-any.whl", hash = "sha256:afa02ca5c449c47731fa06af66d13c2f5ee8b30f8b82b4db90e08215d6f111d1"}, + {file = "outlines-0.0.36.tar.gz", hash = "sha256:3cffb43143548cd78c6061990feb461cffd5479999391b8390471ea839c2d46e"}, ] [package.dependencies] @@ -1859,7 +1859,7 @@ transformers = "*" [package.extras] serve = ["fastapi", "pydantic (>=2.0)", "ray (==2.9.0)", "uvicorn", "vllm (>=0.3.0)"] -test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] +test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface-hub", "llama-cpp-python", "openai (>=1.0.0)", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] [[package]] name = "packaging" @@ -3634,4 +3634,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "7510d50d3465d99cf08f9b1daed50c41d251262f548fb11dd141d46f02b8a8f5" +content-hash = "70670851d12a378b67fd4b4ed8a4a17d0861637e13a02ddf96a119768e8444e5" diff --git a/server/requirements_common.txt b/server/requirements_common.txt deleted file mode 100644 index 5a321834..00000000 --- a/server/requirements_common.txt +++ /dev/null @@ -1,46 +0,0 @@ -backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13" -charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" -click==8.1.7 ; python_version >= "3.9" and 
python_version < "3.13" -colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") -deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" -einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13" -grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" -grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13" -grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" -idna==3.4 ; python_version >= "3.9" and python_version < "3.13" -loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" -opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" -pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13" -protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13" -pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" -regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13" -requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" -safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" -scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13" -sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13" -tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13" -typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13" -win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" -wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_cuda.txt b/server/requirements_cuda.txt index 1e3477bf..4b2fbc24 100644 --- a/server/requirements_cuda.txt +++ b/server/requirements_cuda.txt @@ -1,5 +1,4 @@ 
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" -bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" @@ -7,13 +6,13 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" -grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13" -grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" idna==3.6 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" @@ -27,21 +26,21 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" -safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13" scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13" +setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13" -transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" 
-typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/requirements_rocm.txt b/server/requirements_rocm.txt index 3912abd8..4b2fbc24 100644 --- a/server/requirements_rocm.txt +++ b/server/requirements_rocm.txt @@ -6,13 +6,13 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13" filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13" -googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13" +fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13" +googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" -grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13" -grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13" -grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13" -hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13" +grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13" +grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13" +grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13" +hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" idna==3.6 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" @@ -26,21 +26,21 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" +packaging==24.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" -safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" +safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13" scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" -setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13" +setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.2 ; 
python_version >= "3.9" and python_version < "3.13" -transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13" +transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" -typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13" -urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13" +typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13" +urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py new file mode 100644 index 00000000..985bbd8e --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -0,0 +1,461 @@ +# coding=utf-8 +# Copyright 2024 Cohere team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.distributed + +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig +from typing import Optional, List, Tuple + +from text_generation_server.utils import paged_attention, flash_attn +from text_generation_server.utils.layers import ( + TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelEmbedding, + PositionRotaryEmbedding, + SpeculativeHead, + get_linear, + FastRMSNorm, +) + + +class CohereConfig(PretrainedConfig): + def __init__( + self, + vocab_size=256000, + hidden_size=8192, + intermediate_size=22528, + num_hidden_layers=40, + num_attention_heads=64, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=8192, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_cache=True, + pad_token_id=0, + bos_token_id=5, + eos_token_id=255001, + pretraining_tp=1, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + logit_scale=1.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.logit_scale = logit_scale + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def load_attention(config, prefix, weights): + if config.num_attention_heads != config.num_key_value_heads: + return _load_gqa(config, prefix, weights) + else: + return TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + dim=0, + weights=weights, + bias=config.attention_bias, + ) + + +def _load_gqa(config, prefix: str, weights): + assert config.hidden_size % config.num_attention_heads == 0 + assert config.num_attention_heads % weights.process_group.size() == 0 + + weight = weights.get_multi_weights_col( + prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], + quantize=config.quantize, + dim=0, + ) + + if config.quantize not in ["gptq", "awq"]: + weight = weight.to(dtype=weights.dtype).to(device=weights.device) + + head_size = config.hidden_size // config.num_attention_heads + num_heads = config.num_attention_heads // weights.process_group.size() + num_key_value_heads = config.num_key_value_heads // weights.process_group.size() + assert list(weight.shape) == [ + (num_heads + 2 * num_key_value_heads) * head_size, + config.hidden_size, + ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + + if config.attention_bias: + w = [ + weights.get_sharded(f"{p}.bias", dim=0) + for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"] + ] + bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device) + else: + bias = None + + return TensorParallelColumnLinear( + get_linear(weight, bias=bias, 
quantize=config.quantize) + ) + + +class FlashCohereAttention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.num_heads + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, + ) + + self.softmax_scale = self.head_size**-0.5 + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + self.query_key_value = load_attention(config, prefix, weights) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=config.attention_bias, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + qkv = self.query_key_value(hidden_states) + query, kv = qkv.split( + [ + self.head_size * self.num_heads, + 2 * self.head_size * self.num_key_value_heads, + ], + dim=1, + ) + query = query.view(-1, self.num_heads, self.head_size) + kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size) + + self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) + + paged_attention.reshape_and_cache( + kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots + ) + + # output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + flash_attn.attention( + query, + torch.select(kv, dim=1, index=0), + torch.select(kv, dim=1, index=1), + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention.attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + return self.o_proj( + attn_output.view(-1, self.num_heads * self.head_size), reduce=False + ) + + +class CohereMLP(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + act = config.hidden_act + self.act = ( + ACT2FN[act] + if "gelu" not in act + else lambda x: torch.nn.functional.gelu( + x, + approximate=( + "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" + ), + ) + ) + # Fuse gate and up proj + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + self.intermediate_size = ( + config.intermediate_size // weights.process_group.size() + ) + + def forward(self, hidden_states): + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False + ) + + +class FlashCohereLayer(nn.Module): 
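+    # Cohere decoder block with a parallel residual: attention and the MLP both
+    # read the same `input_layernorm` output and their results are summed, so
+    # there is no post-attention norm. With tensor parallelism, the row-parallel
+    # o_proj and down_proj are called with reduce=False so the per-shard partial
+    # sums are combined in a single all_reduce at the end of the block.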
+ def __init__(self, layer_id, config, weights): + super().__init__() + prefix = f"model.layers.{layer_id}" + self.self_attn = FlashCohereAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) + self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_eps, + ) + self.process_group = weights.process_group + + def forward( + self, + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ): + normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + mlp_output = self.mlp(normed_hidden_states) + output = attn_output + mlp_output + + if self.process_group.size() > 1: + torch.distributed.all_reduce(output, group=self.process_group) + + return output, res + + +class FlashCohereModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + process_group = weights.process_group + self.tp_rank = process_group.rank() + self.tp_world_size = process_group.size() + self.embed_tokens = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.layers = nn.ModuleList( + [ + FlashCohereLayer( + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix="model.norm", weights=weights, eps=config.layer_norm_eps + ) + + self.gradient_checkpointing = False + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashCohereForCausalLM(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + + self.model = FlashCohereModel(config, weights) + try: + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head", + weights=weights, + ) + except RuntimeError: + self.lm_head = SpeculativeHead.load( + config, + prefix="model.embed_tokens", + weights=weights, + ) + self.logit_scale = config.logit_scale + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + lm_head_indices: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + logits *= self.logit_scale + if speculative_logits is not None: + speculative_logits *= self.logit_scale + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 69c1665d..bd7596db 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -20,16 +20,11 @@ import torch import torch.distributed -import os -from shutil import copyfile from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple -from tokenizers import processors -from transformers.tokenization_utils_fast import PreTrainedTokenizerFast -from transformers.utils import logging from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( @@ -42,162 +37,6 @@ from text_generation_server.utils.layers import ( FastRMSNorm, ) -GemmaTokenizer = None - -logger = logging.get_logger(__name__) -VOCAB_FILES_NAMES = { - "vocab_file": "tokenizer.model", - "tokenizer_file": "tokenizer.json", -} - -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", - }, - "tokenizer_file": { - "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", - }, -} -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>\n", "\n<>\n\n" - -# fmt: off -DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ -answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ - that your responses are socially unbiased and positive in nature. - -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ -correct. 
If you don't know the answer to a question, please don't share false information.""" -# fmt: on - - -class GemmaTokenizerFast(PreTrainedTokenizerFast): - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - slow_tokenizer_class = GemmaTokenizer - padding_side = "left" - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - vocab_file=None, - tokenizer_file=None, - clean_up_tokenization_spaces=False, - unk_token="", - bos_token="", - eos_token="", - pad_token="", - add_bos_token=True, - add_eos_token=False, - use_default_system_prompt=False, - **kwargs, - ): - super().__init__( - vocab_file=vocab_file, - tokenizer_file=tokenizer_file, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - unk_token=unk_token, - bos_token=bos_token, - eos_token=eos_token, - pad_token=pad_token, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - use_default_system_prompt=use_default_system_prompt, - **kwargs, - ) - self._add_bos_token = add_bos_token - self._add_eos_token = add_eos_token - self.update_post_processor() - self.use_default_system_prompt = use_default_system_prompt - self.vocab_file = vocab_file - - @property - def can_save_slow_tokenizer(self) -> bool: - return os.path.isfile(self.vocab_file) if self.vocab_file else False - - def update_post_processor(self): - """ - Updates the underlying post processor with the current `bos_token` and `eos_token`. - """ - bos = self.bos_token - bos_token_id = self.bos_token_id - if bos is None and self.add_bos_token: - raise ValueError("add_bos_token = True but bos_token = None") - - eos = self.eos_token - eos_token_id = self.eos_token_id - if eos is None and self.add_eos_token: - raise ValueError("add_eos_token = True but eos_token = None") - - single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" - pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" - - special_tokens = [] - if self.add_bos_token: - special_tokens.append((bos, bos_token_id)) - if self.add_eos_token: - special_tokens.append((eos, eos_token_id)) - self._tokenizer.post_processor = processors.TemplateProcessing( - single=single, pair=pair, special_tokens=special_tokens - ) - - @property - def add_eos_token(self): - return self._add_eos_token - - @property - def add_bos_token(self): - return self._add_bos_token - - @add_eos_token.setter - def add_eos_token(self, value): - self._add_eos_token = value - self.update_post_processor() - - @add_bos_token.setter - def add_bos_token(self, value): - self._add_bos_token = value - self.update_post_processor() - - def save_vocabulary( - self, save_directory: str, filename_prefix: Optional[str] = None - ) -> Tuple[str]: - if not self.can_save_slow_tokenizer: - raise ValueError( - "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " - "tokenizer." 
- ) - - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join( - save_directory, - (filename_prefix + "-" if filename_prefix else "") - + VOCAB_FILES_NAMES["vocab_file"], - ) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): - copyfile(self.vocab_file, out_vocab_file) - - return (out_vocab_file,) - - @property - def default_chat_template(self): - raise NotImplementedError - - # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = bos_token_id + token_ids_0 + eos_token_id - - if token_ids_1 is not None: - output = output + bos_token_id + token_ids_1 + eos_token_id - - return output - class GemmaConfig(PretrainedConfig): def __init__( diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py new file mode 100644 index 00000000..33b053a6 --- /dev/null +++ b/server/text_generation_server/models/flash_cohere.py @@ -0,0 +1,75 @@ +import torch +import torch.distributed + +from opentelemetry import trace +from typing import Optional +from transformers.models.llama import LlamaTokenizerFast + +from text_generation_server.models import FlashCausalLM +from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( + FlashCohereForCausalLM, + CohereConfig, +) +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) + +tracer = trace.get_tracer(__name__) + + +class FlashCohere(FlashCausalLM): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + use_medusa: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.bfloat16 if dtype is None else dtype + else: + raise NotImplementedError("FlashCohere is only available on GPU") + + tokenizer = LlamaTokenizerFast.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + use_fast=True, + from_slow=False, + ) + + config = CohereConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.use_medusa = use_medusa + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["gptq", "awq"]: + weights._set_gptq_params(model_id, revision) + + model = FlashCohereForCausalLM(config, weights) + + torch.distributed.barrier(group=self.process_group) + super(FlashCohere, self).__init__( + model=model, + tokenizer=tokenizer, + num_layers=len(model.model.layers), + num_kv_heads=model.model.num_key_value_heads, + head_size=model.model.head_size, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 8cfb6631..7259b820 100644 --- 
a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -3,10 +3,10 @@ import torch.distributed from opentelemetry import trace from typing import Optional +from transformers.models.gemma import GemmaTokenizerFast from text_generation_server.models import FlashCausalLM from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( - GemmaTokenizerFast, FlashGemmaForCausalLM, GemmaConfig, )