Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)
feat: relax mistral requirements
Commit: 68990a5635 (parent: f3aea78fb6)
File: server/poetry.lock (generated; 1095 lines changed). Diff suppressed because it is too large.
File: (path not preserved in this extract; Poetry dependency specification)
@@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
 grpcio-reflection = "^1.51.1"
 grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
-accelerate = { version = "^0.20.0", optional = true }
+accelerate = { version = "^0.25.0", optional = true }
 bitsandbytes = { version = "^0.41.1", optional = true }
 safetensors = "^0.3.2"
 loguru = "^0.6.0"
@@ -24,9 +24,9 @@ opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
-tokenizers = "^0.13.3"
-huggingface-hub = "^0.16.4"
-transformers = "^4.32.1"
+tokenizers = "^0.15.0"
+huggingface-hub = "^0.19.3"
+transformers = "^4.36.1"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
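The net effect of these two hunks is to raise the minimum versions of accelerate, tokenizers, huggingface-hub and transformers, presumably so the server can pick up the transformers 4.36 line that ships Mixtral support. As a quick, illustrative sanity check (not part of the repository; the package names and lower bounds are copied from the hunks above), an existing environment can be compared against the relaxed minimums like this:

# Illustrative helper only: verify an installed environment against the new
# lower bounds from this commit. Bounds are copied from the diff above.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

RELAXED_MINIMUMS = {
    "accelerate": "0.25.0",
    "tokenizers": "0.15.0",
    "huggingface-hub": "0.19.3",
    "transformers": "4.36.1",
}

for name, minimum in RELAXED_MINIMUMS.items():
    try:
        installed = Version(version(name))
    except PackageNotFoundError:
        print(f"{name}: not installed")
        continue
    status = "ok" if installed >= Version(minimum) else f"needs >= {minimum}"
    print(f"{name}: {installed} ({status})")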
File: (path not preserved in this extract; pip requirements file)
@@ -1,5 +1,5 @@
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-bitsandbytes==0.41.2.post2 ; python_version >= "3.9" and python_version < "3.13"
+bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -8,14 +8,14 @@ deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -37,11 +37,11 @@ safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
File: (path not preserved in this extract; second pip requirements file with the same version bumps)
@@ -7,14 +7,14 @@ deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -36,11 +36,11 @@ safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
File: (path not preserved in this extract; Python source, model registry)
@@ -5,6 +5,7 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.models.auto import modeling_auto
 from typing import Optional

+from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
@@ -55,6 +56,8 @@ try:
         FlashSantacoderSharded,
     )
     from text_generation_server.models.idefics import IDEFICSSharded
+    from text_generation_server.models.flash_mistral import FlashMistral
+    from text_generation_server.models.flash_mixtral import FlashMixtral

 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
@@ -66,25 +69,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
-
-MISTRAL = True
-try:
-    from text_generation_server.models.flash_mistral import FlashMistral
-except ImportError as e:
-    logger.warning(f"Could not import Mistral model: {e}")
-    MISTRAL = False
-
-if MISTRAL:
     __all__.append(FlashMistral)
-
-MIXTRAL = True
-try:
-    from text_generation_server.models.flash_mixtral import FlashMixtral
-except ImportError as e:
-    logger.warning(f"Could not import Mixtral model: {e}")
-    MIXTRAL = False
-
-if MIXTRAL:
     __all__.append(FlashMixtral)


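With this hunk the per-model MISTRAL and MIXTRAL availability flags are gone: FlashMistral and FlashMixtral are now imported inside the same try/except that guards every other flash-attention model, and appended to __all__ under the existing FLASH_ATTENTION check. A minimal sketch of the consolidated pattern (simplified and self-contained; the real module, presumably the models package __init__, imports many more models and derives FLASH_ATTENTION elsewhere):

# Simplified sketch of the consolidated optional-import pattern after this
# commit; the names mirror the diff, but the surrounding module is illustrative.
from loguru import logger

__all__ = []
FLASH_ATTENTION = True  # in the real module this flag reflects the flash-attn install

try:
    from text_generation_server.models.flash_mistral import FlashMistral
    from text_generation_server.models.flash_mixtral import FlashMixtral
except ImportError as e:
    logger.warning(f"Could not import Flash Attention enabled models: {e}")
    FLASH_ATTENTION = False

if FLASH_ATTENTION:
    __all__.append(FlashMistral)
    __all__.append(FlashMixtral)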
@@ -295,7 +280,9 @@ def get_model(
            )

    if model_type == "mistral":
-        if MISTRAL:
+        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
+            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
+        ):
            return FlashMistral(
                model_id,
                revision,
@@ -303,10 +290,11 @@ def get_model(
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
-        raise NotImplementedError("Mistral models requires flash attention v2")

    if model_type == "mixtral":
-        if MIXTRAL:
+        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
+            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
+        ):
            return FlashMixtral(
                model_id,
                revision,
@@ -314,9 +302,6 @@ def get_model(
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
-        raise NotImplementedError(
-            "Mixtral models requires flash attention v2, stk and megablocks"
-        )

    if model_type == "opt":
        return OPTSharded(
@@ -348,17 +333,17 @@ def get_model(
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

    if sharded:
-        raise ValueError("sharded is not supported for AutoModel")
+        raise NotImplementedError("sharded is not supported for AutoModel")
    if quantize == "gptq":
-        raise ValueError(
+        raise NotImplementedError(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
-        raise ValueError("awq quantization is not supported for AutoModel")
+        raise NotImplementedError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
-        raise ValueError("4bit quantization is not supported for AutoModel")
+        raise NotImplementedError("4bit quantization is not supported for AutoModel")
    elif quantize == "eetq":
-        raise ValueError("Eetq quantization is not supported for AutoModel")
+        raise NotImplementedError("Eetq quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
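The two hard `raise NotImplementedError(...)` branches are removed; Mistral and Mixtral are now served whenever the available attention backend can handle the model's sliding window. The condition added in both branches can be read as a small predicate (the helper below is illustrative and not part of the codebase; only the flag and config names come from the diff):

# Illustrative restatement of the dispatch condition this commit introduces in
# get_model(); the helper name and signature are hypothetical.
from typing import Optional


def can_use_flash_backend(
    sliding_window: Optional[int],
    flash_attention: bool,
    has_flash_attn_v2_cuda: bool,
) -> bool:
    # No sliding window configured: any flash-attention backend will do.
    if sliding_window is None and flash_attention:
        return True
    # A sliding window is configured: flash-attention v2 on CUDA is needed to honor it.
    if sliding_window is not None and sliding_window > 0 and has_flash_attn_v2_cuda:
        return True
    return False

Because the raises are gone, a setup that fails this check now falls through to the later, generic branches of get_model (such as the plain CausalLM path) instead of aborting, which appears to be the relaxation the commit title refers to.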
File: (path not preserved in this extract; Python source, Mistral modeling)
@@ -27,11 +27,6 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import (
-    attention,
-    HAS_FLASH_ATTN_V2_ROCM,
-    HAS_FLASH_ATTN_V2_CUDA,
-)
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -43,10 +38,6 @@ from text_generation_server.utils.layers import (
 )


-if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
-    raise ImportError("Mistral model requires flash attn v2")
-
-
 class MistralConfig(PretrainedConfig):
     model_type = "mistral"

File: (path not preserved in this extract; Python source, Mixtral modeling)
@@ -27,12 +27,9 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from loguru import logger

 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import (
-    HAS_FLASH_ATTN_V2_ROCM,
-    HAS_FLASH_ATTN_V2_CUDA,
-)
 from text_generation_server.utils.layers import (
     FastLinear,
     FastRMSNorm,
@@ -44,18 +41,13 @@ from text_generation_server.utils.layers import (
     get_linear,
 )

-if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
-    raise ImportError("Mixtral model requires flash attn v2")
-
-try:
-    import megablocks.ops as ops
-except ImportError:
-    raise ImportError("Mixtral model requires megablocks to be installed")
-
+HAS_MEGABLOCKS = True
 try:
     import stk
+    import megablocks.ops as ops
 except ImportError:
-    raise ImportError("Mixtral model requires stk to be installed")
+    logger.warning("Mixtral: megablocks is not installed")
+    HAS_MEGABLOCKS = False


 class MixtralConfig(PretrainedConfig):
@@ -590,7 +582,7 @@ class BlockSparseMoE(nn.Module):
         return out

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256:
+        if len(x) > 256 and HAS_MEGABLOCKS:
             return self.sparse_forward(x)
         # This is faster when there is not a lot of tokens
         return self.dense_forward(x)
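Here megablocks and stk go from hard requirements to an optional acceleration: a failed import now only logs a warning and sets HAS_MEGABLOCKS = False, and BlockSparseMoE.forward takes the sparse kernel path only when the library is actually present. A self-contained sketch of that pattern, with placeholder method bodies (the real implementations are far larger; the class name below is hypothetical):

# Sketch of the optional-megablocks fallback this commit introduces in the
# Mixtral modeling code. Only the import/flag/routing structure mirrors the
# diff; the method bodies are placeholders.
import torch
from loguru import logger

HAS_MEGABLOCKS = True
try:
    import stk  # sparse toolkit used by the megablocks path
    import megablocks.ops as ops
except ImportError:
    logger.warning("Mixtral: megablocks is not installed")
    HAS_MEGABLOCKS = False


class BlockSparseMoESketch(torch.nn.Module):
    def sparse_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the megablocks/stk kernel path (reachable only with HAS_MEGABLOCKS).
        return x

    def dense_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the pure-PyTorch path that is always available.
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Large batches use the sparse kernels only if megablocks imported cleanly;
        # small batches, and any install without megablocks, use the dense path.
        if len(x) > 256 and HAS_MEGABLOCKS:
            return self.sparse_forward(x)
        return self.dense_forward(x)

Routing on both the token count and HAS_MEGABLOCKS is what lets a machine without megablocks still serve Mixtral, just through the slower dense expert computation for large batches.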