Merge branch 'main' into warmup_gaudi_backend

This commit is contained in:
Wang, Yi A 2025-04-15 22:16:42 -07:00
commit 01f17d526c
10 changed files with 1588 additions and 2581 deletions

View File

@ -87,7 +87,7 @@ RUN echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https:/
RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d RUN mv /tmp/intel-for-pytorch-gpu-dev.list /etc/apt/sources.list.d
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt install -y xpu-smi cmake ninja-build pciutils intel-ocloc libnl-genl-3-200
# Text Generation Inference base env # Text Generation Inference base env
ENV HF_HOME=/data \ ENV HF_HOME=/data \
@ -98,9 +98,7 @@ ENV HF_HOME=/data \
WORKDIR /usr/src WORKDIR /usr/src
RUN pip install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/xpu RUN pip install torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/xpu
RUN pip install triton-xpu==3.2.0b1 --no-cache-dir
# Install server # Install server
COPY proto proto COPY proto proto

File diff suppressed because it is too large Load Diff

View File

@ -9,30 +9,30 @@ text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.9,<3.13" python = ">=3.9,<3.13"
protobuf = "^3.20.3" protobuf = "^5.0"
grpcio = "^1.51.1" grpcio = "^1.71.1"
grpcio-status = "*" grpcio-status = "*"
grpcio-reflection = "*" grpcio-reflection = "*"
grpc-interceptor = "^0.15.0" grpc-interceptor = "^0.15.0"
typer = "^0.7.0" typer = "^0.15.0"
loguru = "^0.6.0" loguru = "^0.7.3"
opentelemetry-api = "^1.15.0" opentelemetry-api = "^1.32.0"
opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-exporter-otlp = "^1.32.0"
opentelemetry-instrumentation-grpc = "^0.36b0" opentelemetry-instrumentation-grpc = "^0.53b0"
hf-transfer = "^0.1.2" hf-transfer = "^0.1.9"
sentencepiece = "^0.1.97" sentencepiece = "^0.2.0"
peft = "^0.10" peft = "^0.15"
optimum-habana = "1.16.0" optimum-habana = "1.17"
transformers = "4.45.2" transformers = "^4.49"
numpy = "1.26.4" numpy = "^1.26"
accelerate = "0.33.0" accelerate = "^0.33"
outlines= { version = "^0.0.36", optional = true } outlines= { version = "^0.0.36", optional = true }
prometheus-client = "^0.20.0" prometheus-client = "^0.21.1"
py-cpuinfo = "^9.0.0" py-cpuinfo = "^9.0.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
grpcio-tools = "*" grpcio-tools = "*"
pytest = "^7.3.0" pytest = "^8.3.5"
[tool.pytest.ini_options] [tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
@ -40,3 +40,6 @@ markers = ["private: marks tests as requiring an admin hf token (deselect with '
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.poetry.requires-plugins]
poetry-plugin-export = ">=1.8"

View File

@ -1,104 +1,101 @@
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13" accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
aiohappyeyeballs==2.4.3 ; python_version >= "3.9" and python_version < "3.13"
aiohttp==3.10.10 ; python_version >= "3.9" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
attrs==24.2.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.8.30 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.0 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13"
datasets==3.0.1 ; python_version >= "3.9" and python_version < "3.13"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
dill==0.3.7 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.16.1 ; python_version >= "3.9" and python_version < "3.13"
frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec[http]==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.65.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.67.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.26.1 ; python_version >= "3.9" and python_version < "3.13"
humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.5.0 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
multidict==6.1.0 ; python_version >= "3.9" and python_version < "3.13"
multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.23.2 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
peft==0.10.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.0.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
propcache==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13"
psutil==6.1.0 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyarrow==17.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyreadline3==3.5.4 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13"
python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.13"
pytz==2024.2 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.9.11 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers[train]==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==75.2.0 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.20.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9"
typer==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2024.2 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.3 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.2 ; python_version >= "3.9" and python_version < "3.13"
outlines==0.0.34 ; python_version >= "3.9" and python_version < "3.13"
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
cloudpickle==3.1.0 ; python_version >= "3.9" and python_version < "3.13"
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13" annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.8 ; python_version >= "3.9" and python_version < "3.13"
cloudpickle==3.1.1 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Windows" or python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
deprecated==1.2.18 ; python_version >= "3.9" and python_version < "3.13"
diffusers==0.31.0 ; python_version >= "3.9" and python_version < "3.13"
diskcache==5.6.3 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.18.0 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2025.3.2 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.70.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.71.0 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.72.0rc1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.9 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.30.2 ; python_version >= "3.9" and python_version < "3.13"
idna==3.10 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.13"
interegular==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.6 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13" jsonschema-specifications==2024.10.1 ; python_version >= "3.9" and python_version < "3.13"
nest-asyncio==1.6.0; python_version >= "3.9" and python_version < "3.13" jsonschema==4.23.0 ; python_version >= "3.9" and python_version < "3.13"
pydantic==2.10.6; python_version >= "3.9" and python_version < "3.13" lark==1.2.2 ; python_version >= "3.9" and python_version < "3.13"
pydantic-core==2.27.2 ; python_version >= "3.9" and python_version < "3.13" llvmlite==0.43.0 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.7.3 ; python_version >= "3.9" and python_version < "3.13"
markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==3.0.2 ; python_version >= "3.9" and python_version < "3.13"
mdurl==0.1.2 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
nest-asyncio==1.6.0 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numba==0.60.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
nvidia-cublas-cu12==12.4.5.8 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-cupti-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-nvrtc-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-runtime-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cudnn-cu12==9.1.0.70 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cufft-cu12==11.2.1.3 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-curand-cu12==10.3.5.147 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusolver-cu12==11.6.1.9 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparse-cu12==12.3.1.170 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparselt-cu12==0.6.2 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nccl-cu12==2.21.5 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvjitlink-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvtx-cu12==12.4.127 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
opentelemetry-api==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-common==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
peft==0.15.1 ; python_version >= "3.9" and python_version < "3.13"
pillow==11.2.1 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
protobuf==5.29.4 ; python_version >= "3.9" and python_version < "3.13"
psutil==7.0.0 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pydantic-core==2.33.1 ; python_version >= "3.9" and python_version < "3.13"
pydantic==2.11.3 ; python_version >= "3.9" and python_version < "3.13"
pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13" referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.13"
rpds-py==0.22.3 ; python_version >= "3.9" and python_version < "3.13" regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
rich==14.0.0 ; python_version >= "3.9" and python_version < "3.13"
rpds-py==0.24.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.5.3 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.2.0 ; python_version >= "3.9" and python_version < "3.13"
setuptools==78.1.0 ; python_version >= "3.12" and python_version < "3.13"
shellingham==1.5.4 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
torch==2.6.0 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
typing-inspection==0.4.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.2.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.17.2 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.21.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@ -3,7 +3,7 @@ import torch
import habana_frameworks.torch.core as htcore import habana_frameworks.torch.core as htcore
from loguru import logger from loguru import logger
from typing import Dict, Union from typing import Dict
from text_generation_server.pb.generate_pb2 import GrammarType from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM from outlines.fsm.fsm import RegexFSM
@ -13,7 +13,6 @@ from typing import List, Optional, DefaultDict
import time import time
from transformers import ( from transformers import (
LogitsWarper,
LogitsProcessor, LogitsProcessor,
TemperatureLogitsWarper, TemperatureLogitsWarper,
TopKLogitsWarper, TopKLogitsWarper,
@ -191,7 +190,7 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
class HeterogeneousTemperatureLogitsWarper: class HeterogeneousTemperatureLogitsWarper:
r""" r"""
[`LogitsWarper`] for temperature (exponential scaling output probability distribution). [`LogitsProcessor`] for temperature (exponential scaling output probability distribution).
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs. It doesn't validate inputs.
@ -220,7 +219,7 @@ class HeterogeneousTemperatureLogitsWarper:
return None return None
class HeterogeneousTopPLogitsWarper(LogitsWarper): class HeterogeneousTopPLogitsWarper(LogitsProcessor):
""" """
[`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
@ -279,9 +278,9 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
return None return None
class HeterogeneousTopKLogitsWarper(LogitsWarper): class HeterogeneousTopKLogitsWarper(LogitsProcessor):
r""" r"""
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements.
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs. It doesn't validate inputs.
@ -360,9 +359,9 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
return None return None
class HeterogeneousTypicalLogitsWarper(LogitsWarper): class HeterogeneousTypicalLogitsWarper(LogitsProcessor):
r""" r"""
[`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language [`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language
Generation](https://arxiv.org/abs/2202.00666) for more information. Generation](https://arxiv.org/abs/2202.00666) for more information.
This version allows for a separate value for each sample and runs inplace when possible. This version allows for a separate value for each sample and runs inplace when possible.
It doesn't validate inputs. It doesn't validate inputs.
@ -454,13 +453,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor):
r""" r"""
A wrapper for logit warpers or processors without heterogeneous parameter support. A wrapper for logit warpers or processors without heterogeneous parameter support.
Args: Args:
processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): processors (`Dict[int, LogitsProcessor]`):
A mapping of sample indices to logit warpers or processors, to be run sequentially. A mapping of sample indices to logit warpers or processors, to be run sequentially.
""" """
def __init__( def __init__(
self, self,
processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], processors: Dict[int, LogitsProcessor],
): ):
self.processors = processors self.processors = processors

View File

@ -1,5 +1,5 @@
[build-system] [build-system]
requires = ["setuptools>=61.0"] requires = ["setuptools>=78.1"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project] [project]

View File

@ -201,7 +201,8 @@ except ImportError as e:
if MAMBA_AVAILABLE: if MAMBA_AVAILABLE:
__all__.append(Mamba) __all__.append(Mamba)
FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() FLASH_TRANSFORMERS_BACKEND = torch.cuda.is_available() or SYSTEM == "ipex"
try: try:
from text_generation_server.models.transformers_flash_causal_lm import ( from text_generation_server.models.transformers_flash_causal_lm import (
TransformersFlashCausalLM, TransformersFlashCausalLM,

View File

@ -12,7 +12,7 @@ from text_generation_server.utils import initialize_torch_distributed
from text_generation_server.layers.attention import paged_attention, attention, Seqlen from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION from text_generation_server.models.globals import ATTENTION
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -115,8 +115,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available(): elif SYSTEM == "ipex":
device = torch.device("xpu") if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = default_dtype if dtype is None else dtype dtype = default_dtype if dtype is None else dtype
else: else:
raise ValueError( raise ValueError(

View File

@ -14,6 +14,7 @@ from text_generation_server.layers.attention import paged_attention, attention,
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
from text_generation_server.models.globals import ATTENTION from text_generation_server.models.globals import ATTENTION
import torch.nn.functional as F import torch.nn.functional as F
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -174,8 +175,11 @@ class TransformersFlashVlmCausalLM(VlmCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available(): elif SYSTEM == "ipex":
device = torch.device("xpu") if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
else:
device = torch.device("cpu")
dtype = default_dtype if dtype is None else dtype dtype = default_dtype if dtype is None else dtype
else: else:
raise ValueError( raise ValueError(

View File

@ -73,6 +73,13 @@ def initialize_torch_distributed():
if SYSTEM == "ipex": if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
assert (
WORLD_SIZE <= torch.xpu.device_count()
), "Each process is one xpu"
device = RANK % torch.xpu.device_count()
torch.xpu.set_device(device)
ipex.distributed.init_process_group( ipex.distributed.init_process_group(
backend="ccl", backend="ccl",
world_size=WORLD_SIZE, world_size=WORLD_SIZE,