feat: relax mistral requirements (#1351)

Close #1253 
Close #1279
Author: OlivierDehaene (committed by GitHub)
Date: 2023-12-15 12:52:24 +01:00
Commit: 9b56d3fbf5 (parent: f3aea78fb6)
8 changed files with 671 additions and 621 deletions


@@ -146,11 +146,50 @@ jobs:
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
           cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min

+  integration-tests:
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    needs:
+      - start-runner
+      - build-and-push-image # Wait for the docker image to be built
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    env:
+      DOCKER_VOLUME: /cache
+    steps:
+      - uses: actions/checkout@v2
+      - name: Inject slug/short variables
+        uses: rlespinasse/github-slug-action@v4.4.1
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+      - name: Tailscale
+        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
+        with:
+          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
+      - name: Prepare disks
+        run: |
+          sudo mkfs -t ext4 /dev/nvme1n1
+          sudo mkdir ${{ env.DOCKER_VOLUME }}
+          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
+      - name: Install
+        run: |
+          make install-integration-tests
+      - name: Run tests
+        run: |
+          export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          pytest -s -vv integration-tests
+
   build-and-push-image-rocm:
     concurrency:
       group: ${{ github.workflow }}-build-and-push-image-rocm-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    needs: start-runner # required to start the main job when the runner is ready
+    needs:
+      - start-runner
+      - build-and-push-image # Wait for the main docker image to be built
+      - integration-tests # Wait for the main integration-tests
     runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
     permissions:
       contents: write
@@ -235,43 +274,6 @@ jobs:
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
           cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min

-  integration-tests:
-    concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
-      cancel-in-progress: true
-    needs:
-      - start-runner
-      - build-and-push-image # Wait for the docker image to be built
-      - build-and-push-image-rocm
-    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
-    env:
-      DOCKER_VOLUME: /cache
-    steps:
-      - uses: actions/checkout@v2
-      - name: Inject slug/short variables
-        uses: rlespinasse/github-slug-action@v4.4.1
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: 3.9
-      - name: Tailscale
-        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
-        with:
-          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
-      - name: Prepare disks
-        run: |
-          sudo mkfs -t ext4 /dev/nvme1n1
-          sudo mkdir ${{ env.DOCKER_VOLUME }}
-          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
-      - name: Install
-        run: |
-          make install-integration-tests
-      - name: Run tests
-        run: |
-          export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv integration-tests
-
   stop-runner:
     name: Stop self-hosted EC2 runner
     needs:

server/poetry.lock (generated): 1095 changed lines. File diff suppressed because it is too large.


@@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
 grpcio-reflection = "^1.51.1"
 grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
-accelerate = { version = "^0.20.0", optional = true }
+accelerate = { version = "^0.25.0", optional = true }
 bitsandbytes = { version = "^0.41.1", optional = true }
 safetensors = "^0.3.2"
 loguru = "^0.6.0"
@@ -24,9 +24,9 @@ opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
-tokenizers = "^0.13.3"
-huggingface-hub = "^0.16.4"
-transformers = "^4.32.1"
+tokenizers = "^0.15.0"
+huggingface-hub = "^0.19.3"
+transformers = "^4.36.1"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
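Alongside the lock-file refresh, the dependency floors move up: transformers to ^4.36.1 (the release line that adds Mixtral), tokenizers to ^0.15.0, huggingface-hub to ^0.19.3, and the optional accelerate extra to ^0.25.0. A small sketch for checking a local environment against these pins; the helper is illustrative rather than repository code and assumes the `packaging` library is installed:

```python
# Illustrative environment check against the new minimum pins (not repo code).
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version  # assumption: packaging is installed

MINIMUM_PINS = {
    "transformers": "4.36.1",
    "tokenizers": "0.15.0",
    "huggingface-hub": "0.19.3",
    "accelerate": "0.25.0",  # optional extra in pyproject.toml
}

for package, minimum in MINIMUM_PINS.items():
    try:
        installed = Version(version(package))
    except PackageNotFoundError:
        print(f"{package}: not installed")
        continue
    verdict = "ok" if installed >= Version(minimum) else f"needs >= {minimum}"
    print(f"{package}: {installed} ({verdict})")
```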


@@ -1,5 +1,5 @@
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-bitsandbytes==0.41.2.post2 ; python_version >= "3.9" and python_version < "3.13"
+bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -8,14 +8,14 @@ deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -37,11 +37,11 @@ safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"


@@ -7,14 +7,14 @@ deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -36,11 +36,11 @@ safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"


@@ -55,10 +55,14 @@ try:
         FlashSantacoderSharded,
     )
     from text_generation_server.models.idefics import IDEFICSSharded
+    from text_generation_server.models.flash_mistral import FlashMistral
+    from text_generation_server.models.flash_mixtral import FlashMixtral
+    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
+    HAS_FLASH_ATTN_V2_CUDA = False

 if FLASH_ATTENTION:
     __all__.append(FlashNeoXSharded)
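One detail worth calling out in this consolidated guard: HAS_FLASH_ATTN_V2_CUDA gets a False default in the except branch so that get_model can still evaluate it when flash-attention is not installed. A minimal sketch of the pattern, with hypothetical names:

```python
# Hypothetical names; the point is that any flag referenced after the guard
# must exist even when the optional import fails, or the fallback path dies
# with a NameError instead of degrading gracefully.
HAS_FAST_KERNELS = True
try:
    import some_optional_kernel_package  # noqa: F401  (assumed optional dep)
except ImportError:
    HAS_FAST_KERNELS = False  # safe default for later capability checks


def pick_backend() -> str:
    return "fast" if HAS_FAST_KERNELS else "reference"


print(pick_backend())
```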
@@ -66,25 +70,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
-
-MISTRAL = True
-try:
-    from text_generation_server.models.flash_mistral import FlashMistral
-except ImportError as e:
-    logger.warning(f"Could not import Mistral model: {e}")
-    MISTRAL = False
-
-if MISTRAL:
     __all__.append(FlashMistral)
-
-MIXTRAL = True
-try:
-    from text_generation_server.models.flash_mixtral import FlashMixtral
-except ImportError as e:
-    logger.warning(f"Could not import Mixtral model: {e}")
-    MIXTRAL = False
-
-if MIXTRAL:
     __all__.append(FlashMixtral)
@@ -295,7 +281,9 @@ def get_model(
         )

     if model_type == "mistral":
-        if MISTRAL:
+        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
+            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
+        ):
             return FlashMistral(
                 model_id,
                 revision,
@@ -303,10 +291,11 @@
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError("Mistral models requires flash attention v2")

     if model_type == "mixtral":
-        if MIXTRAL:
+        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
+            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
+        ):
             return FlashMixtral(
                 model_id,
                 revision,
@@ -314,9 +303,6 @@
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
-        raise NotImplementedError(
-            "Mixtral models requires flash attention v2, stk and megablocks"
-        )

     if model_type == "opt":
         return OPTSharded(
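The relaxed dispatch above reads as follows: when the model config disables sliding-window attention (sliding_window is null), any working flash-attention build is enough to serve Mistral or Mixtral; only when a sliding window is configured do the CUDA flash-attention v2 kernels stay mandatory, presumably because they are the only attention path here that handles the window. A sketch of the predicate (the helper name is illustrative, not part of the codebase):

```python
from typing import Optional


def can_use_flash_model(
    sliding_window: Optional[int],
    flash_attention: bool,
    has_flash_attn_v2_cuda: bool,
) -> bool:
    """Mirrors the Mistral/Mixtral dispatch condition in get_model."""
    if sliding_window is None:
        return flash_attention
    return sliding_window > 0 and has_flash_attn_v2_cuda


# Mistral-7B-v0.1 ships sliding_window=4096, so it still needs flash-attn v2 on CUDA:
assert not can_use_flash_model(4096, flash_attention=True, has_flash_attn_v2_cuda=False)
# A config with the sliding window disabled can run on any flash-attention build:
assert can_use_flash_model(None, flash_attention=True, has_flash_attn_v2_cuda=False)
```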
@@ -348,17 +334,17 @@
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

     if sharded:
-        raise ValueError("sharded is not supported for AutoModel")
+        raise NotImplementedError("sharded is not supported for AutoModel")
     if quantize == "gptq":
-        raise ValueError(
+        raise NotImplementedError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
     if quantize == "awq":
-        raise ValueError("awq quantization is not supported for AutoModel")
+        raise NotImplementedError("awq quantization is not supported for AutoModel")
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
-        raise ValueError("4bit quantization is not supported for AutoModel")
+        raise NotImplementedError("4bit quantization is not supported for AutoModel")
     elif quantize == "eetq":
-        raise ValueError("Eetq quantization is not supported for AutoModel")
+        raise NotImplementedError("Eetq quantization is not supported for AutoModel")

     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,
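The last hunk in this file swaps ValueError for NotImplementedError on the AutoModel fallback checks, which reads as a semantic cleanup: these branches reject combinations the fallback does not support rather than malformed arguments. A tiny illustration (not repository code) of how a caller can treat the two differently:

```python
# Illustrative only: a caller probing configurations can treat
# NotImplementedError as "unsupported here" rather than "bad input".
def load_automodel(sharded=False, quantize=None):
    if sharded:
        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "awq":
        raise NotImplementedError("awq quantization is not supported for AutoModel")
    return "model"


try:
    load_automodel(sharded=True)
except NotImplementedError as err:
    print(f"unsupported configuration, not a bad value: {err}")
```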


@@ -27,11 +27,6 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import (
-    attention,
-    HAS_FLASH_ATTN_V2_ROCM,
-    HAS_FLASH_ATTN_V2_CUDA,
-)
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -43,10 +38,6 @@ from text_generation_server.utils.layers import (
 )

-
-if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
-    raise ImportError("Mistral model requires flash attn v2")
-

 class MistralConfig(PretrainedConfig):
     model_type = "mistral"


@@ -27,12 +27,9 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from loguru import logger

 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import (
-    HAS_FLASH_ATTN_V2_ROCM,
-    HAS_FLASH_ATTN_V2_CUDA,
-)
 from text_generation_server.utils.layers import (
     FastLinear,
     FastRMSNorm,
@@ -44,18 +41,13 @@ from text_generation_server.utils.layers import (
     get_linear,
 )

-if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
-    raise ImportError("Mixtral model requires flash attn v2")
-
-try:
-    import megablocks.ops as ops
-except ImportError:
-    raise ImportError("Mixtral model requires megablocks to be installed")
-
+HAS_MEGABLOCKS = True
 try:
     import stk
+    import megablocks.ops as ops
 except ImportError:
-    raise ImportError("Mixtral model requires stk to be installed")
+    logger.warning("Mixtral: megablocks is not installed")
+    HAS_MEGABLOCKS = False


 class MixtralConfig(PretrainedConfig):
@@ -590,7 +582,7 @@ class BlockSparseMoE(nn.Module):
         return out

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256:
+        if len(x) > 256 and HAS_MEGABLOCKS:
             return self.sparse_forward(x)
         # This is faster when there is not a lot of tokens
         return self.dense_forward(x)
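With megablocks and stk made optional, a missing install no longer aborts Mixtral at import time; the block-sparse MoE simply stays on its dense expert path, and the grouped sparse kernels are only chosen for larger token counts when megablocks is present. A toy sketch of that dispatch shape (a stand-in, not the real BlockSparseMoE):

```python
import torch
from torch import nn

# Same optional-dependency guard shape as the diff above.
HAS_MEGABLOCKS = True
try:
    import megablocks.ops  # noqa: F401
except ImportError:
    HAS_MEGABLOCKS = False


class ToyMoE(nn.Module):
    """Illustrative stand-in for the sparse/dense dispatch, not the real layer."""

    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder: run every expert densely

    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder: grouped sparse kernels (megablocks/stk)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Sparse kernels only pay off for larger token counts, and only exist
        # when megablocks is installed; otherwise fall back to the dense path.
        if len(x) > 256 and HAS_MEGABLOCKS:
            return self.sparse_forward(x)
        return self.dense_forward(x)


print(ToyMoE()(torch.zeros(8, 16)).shape)
```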