Make Gaudi adapt to TGI 2.3.0

Signed-off-by: yuanwu <yuan.wu@intel.com>
yuanwu 2024-09-26 01:53:52 +00:00
parent 14fdc4ae5e
commit bab529c916
25 changed files with 3123 additions and 2355 deletions

View File

@@ -61,6 +61,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
make \
curl \
git \
+python3.11-dev \
&& rm -rf /var/lib/apt/lists/*
# Install server
@@ -96,5 +97,5 @@ FROM base
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh
-ENTRYPOINT ["/tgi-entrypoint.sh"]
+#ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"]

View File

@@ -4,10 +4,6 @@ include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
-include Makefile-lorax-punica
-include Makefile-fbgemm
-include Makefile-exllamav2
-include Makefile-flashinfer
unit-tests:
pytest -s -vv -m "not private" tests
@@ -21,25 +17,20 @@ gen-server:
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
touch text_generation_server/pb/__init__.py
-install-server: gen-server
+install: gen-server
pip install pip --upgrade
-pip install -r requirements_cuda.txt
+pip install -r requirements.txt
-pip install -e ".[accelerate, quantize, peft, outlines]"
+pip install -e "."
-install: install-cuda
-echo "Installed server"
-install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
-pip install -e ".[bnb]"
-pip install nvidia-nccl-cu12==2.22.3
-install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
-install-poetry:
-curl -sSL https://install.python-poetry.org | python3 -
-update-lock:
-rm poetry.lock
-poetry lock --no-update
export-requirements:
-poetry export -o requirements_cuda.txt --without-hashes
+poetry export -o requirements.txt --without-hashes
-poetry export -o requirements_rocm.txt --without-hashes
-poetry export -o requirements_intel.txt --without-hashes

View File

@@ -0,0 +1,91 @@
#!/bin/bash
git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git
pushd dill
cat <<EOF > dill-0.3.7.patch
diff --git a/dill/_dill.py b/dill/_dill.py
index d0cf543..f6eb662 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
XRangeType = range
from types import MappingProxyType as DictProxyType, new_class
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
-import __main__ as _main_module
+class _LazyMainModule(object):
+ _module = None
+ @property
+ def module(self):
+ if self._module is None:
+ import __main__ as _m_module
+ self._module = _m_module
+ return self._module
+_main_module = _LazyMainModule()
import marshal
import gc
# import zlib
@@ -353,7 +361,7 @@ class Pickler(StockPickler):
_fmode = kwds.pop('fmode', None)
_recurse = kwds.pop('recurse', None)
StockPickler.__init__(self, file, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._diff_cache = {}
self._byref = settings['byref'] if _byref is None else _byref
self._strictio = False #_strictio
@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):
settings = Pickler.settings
_ignore = kwds.pop('ignore', None)
StockUnpickler.__init__(self, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._ignore = settings['ignore'] if _ignore is None else _ignore
def load(self): #NOTE: if settings change, need to update attributes
obj = StockUnpickler.load(self)
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
if not self._ignore:
# point obj class to main
try: obj.__class__ = getattr(self._main, type(obj).__name__)
@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
- elif '__name__' in obj and obj != _main_module.__dict__ \\
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
and type(obj['__name__']) is str \\
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
diff --git a/dill/session.py b/dill/session.py
index 74234ab..1be8d89 100644
--- a/dill/session.py
+++ b/dill/session.py
@@ -233,7 +233,7 @@ def dump_module(
protocol = settings['protocol']
main = module
if main is None:
- main = _main_module
+ main = _main_module.module
elif isinstance(main, str):
main = _import_module(main)
if not isinstance(main, ModuleType):
@@ -501,7 +501,7 @@ def load_module(
pass
assert loaded is main
_restore_modules(unpickler, main)
- if main is _main_module or main is module:
+ if main is _main_module.module or main is module:
return None
else:
return main
EOF
git apply dill-0.3.7.patch
python -m pip install .
popd
rm -fr dill

View File

@@ -0,0 +1,91 @@
#!/bin/bash
git clone -b 0.3.8 https://github.com/uqfoundation/dill.git
pushd dill
cat <<EOF > dill-0.3.8.patch
diff --git a/dill/_dill.py b/dill/_dill.py
index d42432f..1d251e6 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
XRangeType = range
from types import MappingProxyType as DictProxyType, new_class
from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
-import __main__ as _main_module
+class _LazyMainModule(object):
+ _module = None
+ @property
+ def module(self):
+ if self._module is None:
+ import __main__ as _m_module
+ self._module = _m_module
+ return self._module
+_main_module = _LazyMainModule()
import marshal
import gc
# import zlib
@@ -355,7 +363,7 @@ class Pickler(StockPickler):
_fmode = kwds.pop('fmode', None)
_recurse = kwds.pop('recurse', None)
StockPickler.__init__(self, file, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._diff_cache = {}
self._byref = settings['byref'] if _byref is None else _byref
self._strictio = False #_strictio
@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler):
settings = Pickler.settings
_ignore = kwds.pop('ignore', None)
StockUnpickler.__init__(self, *args, **kwds)
- self._main = _main_module
+ self._main = _main_module.module
self._ignore = settings['ignore'] if _ignore is None else _ignore
def load(self): #NOTE: if settings change, need to update attributes
obj = StockUnpickler.load(self)
- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
+ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
if not self._ignore:
# point obj class to main
try: obj.__class__ = getattr(self._main, type(obj).__name__)
@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
+ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
- elif '__name__' in obj and obj != _main_module.__dict__ \\
+ elif '__name__' in obj and obj != _main_module.module.__dict__ \\
and type(obj['__name__']) is str \\
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
diff --git a/dill/session.py b/dill/session.py
index e91068a..a921b43 100644
--- a/dill/session.py
+++ b/dill/session.py
@@ -233,7 +233,7 @@ def dump_module(
protocol = settings['protocol']
main = module
if main is None:
- main = _main_module
+ main = _main_module.module
elif isinstance(main, str):
main = _import_module(main)
if not isinstance(main, ModuleType):
@@ -501,7 +501,7 @@ def load_module(
pass
assert loaded is main
_restore_modules(unpickler, main)
- if main is _main_module or main is module:
+ if main is _main_module.module or main is module:
return None
else:
return main
EOF
git apply dill-0.3.8.patch
python -m pip install .
popd
rm -fr dill
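
Both helper scripts above apply the same fix to different dill releases: stop importing __main__ eagerly at module import time and defer it behind a property. A minimal standalone sketch of that lazy-import pattern (illustrative names only, not part of the patch):

# Sketch of the lazy-__main__ pattern introduced by the dill patches above.
# __main__ is only imported the first time the .module property is accessed.
class LazyMainModule:
    _module = None

    @property
    def module(self):
        if self._module is None:
            import __main__ as _m  # deferred until first use
            self._module = _m
        return self._module

main_proxy = LazyMainModule()

if __name__ == "__main__":
    print(main_proxy.module.__name__)  # -> "__main__"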

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
-version = "2.0.5-dev0"
+version = "2.0.4"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]
@@ -9,76 +9,34 @@ text-generation-server = 'text_generation_server.cli:app'
[tool.poetry.dependencies]
python = ">=3.9,<3.13"
-protobuf = "^4.25.3"
+protobuf = "^3.20.3"
grpcio = "^1.51.1"
-grpcio-status = "^1.51.1"
+grpcio-status = "*"
-grpcio-reflection = "^1.51.1"
+grpcio-reflection = "*"
grpc-interceptor = "^0.15.0"
-typer = "^0.6.1"
+typer = "^0.7.0"
-accelerate = { version = "^0.29.1", optional = true }
-bitsandbytes = { version = "^0.43.0", optional = true }
-safetensors = "^0.4"
loguru = "^0.6.0"
-opentelemetry-api = "^1.25.0"
+opentelemetry-api = "^1.15.0"
-opentelemetry-exporter-otlp = "^1.25.0"
+opentelemetry-exporter-otlp = "^1.15.0"
-opentelemetry-instrumentation-grpc = "^0.46b0"
+opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
-tokenizers = "^0.19.1"
+peft = "^0.10"
-huggingface-hub = "^0.23"
+optimum-habana = "1.13.2"
-transformers = "^4.43"
+transformers = "4.43.4"
-einops = "^0.6.1"
+numpy = "1.26.4"
-texttable = { version = "^1.6.7", optional = true }
+accelerate = "0.33.0"
-datasets = { version = "^2.14.0", optional = true }
+outlines= { version = "^0.0.36", optional = true }
-peft = { version = "^0.10", optional = true }
-torch = { version = "^2.4.0", optional = true }
-scipy = "^1.11.1"
-pillow = "^10.0.0"
-outlines= { version = "^0.0.34", optional = true }
prometheus-client = "^0.20.0"
py-cpuinfo = "^9.0.0"
-# Remove later, temporary workaround for outlines.
-numpy = "^1.26"
-marlin-kernels = [
-{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
-]
-moe-kernels = [
-{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
-{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
-{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
-{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
-]
-rich = "^13.7.1"
-[tool.poetry.extras]
-torch = ["torch"]
-accelerate = ["accelerate"]
-bnb = ["bitsandbytes"]
-marlin = ["marlin-kernels"]
-moe = ["moe-kernels"]
-peft = ["peft"]
-quantize = ["texttable", "datasets", "accelerate"]
-outlines = ["outlines"]
[tool.poetry.group.dev.dependencies]
-grpcio-tools = "^1.51.1"
+grpcio-tools = "*"
pytest = "^7.3.0"
-[[tool.poetry.source]]
-name = "pytorch-gpu-src"
-url = "https://download.pytorch.org/whl/cu121"
-priority = "explicit"
[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
[build-system]
-requires = [
-"poetry-core>=1.0.0",
-]
+requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

server/requirements.txt (new file, 88 lines)
View File

@@ -0,0 +1,88 @@
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
aiohappyeyeballs==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
aiohttp==3.10.5 ; python_version >= "3.9" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
attrs==24.2.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13"
datasets==2.21.0 ; python_version >= "3.9" and python_version < "3.13"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
diffusers==0.29.2 ; python_version >= "3.9" and python_version < "3.13"
dill==0.3.8 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec[http]==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.24.6 ; python_version >= "3.9" and python_version < "3.13"
humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.8 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.4.0 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
multidict==6.0.5 ; python_version >= "3.9" and python_version < "3.13"
multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.13.2 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
peft==0.10.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13"
psutil==6.0.0 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyarrow==17.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13"
python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.13"
pytz==2024.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.7.24 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.4 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.5.1 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers[train]==3.0.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==73.0.1 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9"
typer==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.0 ; python_version >= "3.9" and python_version < "3.13"

View File

@@ -47,9 +47,9 @@ def serve(
max_input_tokens: Optional[int] = None,
):
if sharded:
-assert (
-os.getenv("RANK", None) is not None
-), "RANK must be set when sharded is True"
+# assert (
+# os.getenv("RANK", None) is not None
+# ), "RANK must be set when sharded is True"
assert (
os.getenv("WORLD_SIZE", None) is not None
), "WORLD_SIZE must be set when sharded is True"
@@ -96,7 +96,7 @@ def serve(
# Downgrade enum into str for easier management later on
quantize = None if quantize is None else quantize.value
-dtype = None if dtype is None else dtype.value
+dtype = "bfloat16" if dtype is None else dtype.value
if dtype is not None and quantize not in {
None,
"bitsandbytes",
@@ -106,6 +106,64 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
if sharded:
tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
num_shard = int(os.getenv("WORLD_SIZE", "1"))
logger.info("CLI SHARDED = {}".format(num_shard))
import subprocess
cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
if speculate is not None:
cmd += f"--speculate {speculate}"
logger.info("CLI server start deepspeed ={} ".format(cmd))
sys.stdout.flush()
sys.stderr.flush()
with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
do_terminate = False
current_handler = signal.getsignal(signal.SIGTERM)
def terminate_handler(sig, frame):
nonlocal do_terminate
do_terminate = True
if callable(current_handler):
current_handler(sig, frame)
signal.signal(signal.SIGTERM, terminate_handler)
finished = False
while not finished:
try:
if do_terminate:
parent = psutil.Process(proc.pid)
all_procs = parent.children(recursive=True) + [parent]
for p in all_procs:
try:
p.terminate()
except psutil.NoSuchProcess:
pass
_, alive = psutil.wait_procs(all_procs, timeout=30)
for p in alive:
p.kill()
do_terminate = False
proc.wait(timeout=3)
except subprocess.TimeoutExpired:
pass
else:
finished = True
sys.stdout.flush()
sys.stderr.flush()
if proc.returncode != 0:
logger.error(f"{cmd} exited with status = {proc.returncode}")
return proc.returncode
else:
server.serve(
model_id,
lora_adapters,
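
The sharded branch added above launches the real workers through deepspeed and then supervises that child process. As a point of reference, a hedged sketch of the psutil-based process-tree shutdown it performs (generic subprocess handle, not the exact TGI objects):

# Illustrative sketch only: terminate a launcher process and all of its children,
# mirroring the psutil logic in the diff above.
import subprocess
import psutil

def terminate_tree(proc: subprocess.Popen, timeout: float = 30.0) -> None:
    parent = psutil.Process(proc.pid)
    procs = parent.children(recursive=True) + [parent]
    for p in procs:
        try:
            p.terminate()  # polite SIGTERM first
        except psutil.NoSuchProcess:
            pass
    _, alive = psutil.wait_procs(procs, timeout=timeout)
    for p in alive:
        p.kill()  # escalate for anything still running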

View File

@@ -0,0 +1,27 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import os
quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""
if is_quantization_enabled:
os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
os.environ.setdefault(
"UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
def prepare_model_for_quantization(model):
if is_quantization_enabled:
if os.getenv("USE_INC", "1") != "0":
from neural_compressor.torch.quantization import FP8Config, convert
config = FP8Config.from_json_file(quant_config)
model = convert(model, config)
else:
import habana_quantization_toolkit
habana_quantization_toolkit.prep_model(model)
return model
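
A hedged usage sketch for the module above: QUANT_CONFIG must be set before the module is imported (the flags are read at import time), and the model is wrapped once after loading. The module path and model name below are assumptions, since the file name is not visible in this view.

# Illustrative only.
import os
os.environ["QUANT_CONFIG"] = "/path/to/fp8_quant_config.json"  # placeholder path

from transformers import AutoModelForCausalLM
import text_generation_server.habana_quantization_env as hq_env  # assumed module path

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
model = hq_env.prepare_model_for_quantization(model)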

View File

@@ -1,3 +1,5 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import torch
import grpc
@@ -6,6 +8,8 @@ from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger
from typing import Callable, Any
+import traceback
+import os
class ExceptionInterceptor(AsyncServerInterceptor):
@@ -20,6 +24,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
response = method(request_or_iterator, context)
return await response
except Exception as err:
+trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
method_name = method_name.split("/")[-1]
logger.exception(f"Method {method_name} encountered an error.")
@@ -30,8 +35,10 @@ class ExceptionInterceptor(AsyncServerInterceptor):
if torch.cuda.is_available():
torch.cuda.empty_cache()
+from .utils.debug import dbg_trace
+dbg_trace('EXCEPTION', traceback.format_exc())
await context.abort_with_status(
rpc_status.to_status(
-status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
+status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
)
)
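
The interceptor change boils down to opt-in stack traces in gRPC error messages, controlled by a DUMP_STACK environment variable. A small self-contained sketch of the same pattern outside gRPC:

# Minimal sketch: append the traceback to an error message only when DUMP_STACK is set.
import os
import traceback

def describe_error(err: Exception) -> str:
    trace = " " + traceback.format_exc() if os.environ.get("DUMP_STACK") else ""
    return str(err) + trace

try:
    1 / 0
except Exception as err:
    print(describe_error(err))  # includes the traceback when DUMP_STACK=1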

File diff suppressed because it is too large.

View File

@@ -1,11 +1,10 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import torch
-import torch.distributed
from typing import Optional, Type
-from transformers import (
-PreTrainedTokenizerBase,
-)
+from transformers import PreTrainedTokenizerBase
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
@@ -21,26 +20,33 @@ class BloomCausalLMBatch(CausalLMBatch):
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
-batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
+batch = super().from_pb(
+pb=pb,
+tokenizer=tokenizer,
+dtype=dtype,
+device=device,
+)
batch.keys_head_dim_last = False
return batch
-class BLOOMSharded(CausalLM):
+class BLOOM(CausalLM):
+def __init__(
+self,
+model_id: str,
+revision: Optional[str] = None,
+speculator: Optional[str] = None,
+dtype: Optional[torch.dtype] = None,
+trust_remote_code: bool = False,
+):
+super(BLOOM, self).__init__(
+model_id=model_id,
+revision=revision,
+speculator=speculator,
+dtype=dtype,
+trust_remote_code=trust_remote_code,
+)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return BloomCausalLMBatch
-def forward(
-self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-):
-outputs, speculative_logits = self.model.forward(
-input_ids=input_ids,
-attention_mask=attention_mask,
-position_ids=position_ids,
-past_key_values=past_key_values,
-use_cache=True,
-)
-logits = outputs.logits
-return logits, speculative_logits, outputs.past_key_values

File diff suppressed because it is too large.

View File

@@ -14,33 +14,26 @@
# limitations under the License.
""" PyTorch Llava-NeXT model."""
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
+from transformers.models.llava_next.modeling_llava_next import (
+unpad_image,
+)
+from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution
-from text_generation_server.layers.attention import Seqlen
-from text_generation_server.models.custom_modeling.vlm import (
-load_text_model,
-load_vision_model,
-)
-from text_generation_server.layers import (
-TensorParallelColumnLinear,
-TensorParallelRowLinear,
-)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
-The size of the input image in the format (height, width).
+The size of the input image in the format (width, height).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
@@ -48,7 +41,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
The size of each image patch.
Returns:
-tuple: The shape of the image patch grid in the format (height, width).
+tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
@@ -57,100 +50,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
return height // patch_size, width // patch_size
-def unpad_image(tensor, original_size):
+class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.linear_1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class LlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
# Instead of selecting in hidden_states[-2].
# Instead compute only the n -2 + 1 layers and don't pool
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = LlavaNextMultiModalProjector(
prefix="multi_modal_projector", config=config, weights=weights
)
self.image_newline = weights.get_tensor("image_newline")
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
-input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
+input_ids: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = input_ids == self.config.image_token_index
@@ -165,61 +71,130 @@ class LlavaNextForConditionalGeneration(nn.Module):
def forward(
self,
-input_ids: torch.Tensor,
+input_ids: torch.LongTensor = None,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
-# Unused for this model
-pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
-adapter_data: Optional[torch.Tensor] = None,
+attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
-# 2. Merge text and images
+if token_idx is not None:
-num_images, num_patches, channels, height, width = pixel_values.shape
+output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-pixel_values = pixel_values.view(
-num_images * num_patches, channels, height, width
-)
+output_hidden_states = (
+output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+)
-image_features = self.vision_tower(pixel_values)
+return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+if inputs_embeds is None:
+inputs_embeds = self.get_input_embeddings()(input_ids)
-# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
-# Already done within the clip model
-selected_image_feature = image_features.last_hidden_state
+outputs = self.language_model(
+attention_mask=attention_mask,
+position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
-if self.config.vision_feature_select_strategy == "default":
-selected_image_feature = selected_image_feature[:, 1:]
-elif self.config.vision_feature_select_strategy == "full":
-selected_image_feature = selected_image_feature
+logits = outputs[0]
+if not return_dict:
+output = (logits,) + outputs[1:]
return output
return outputs
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
image_sizes=None,
attention_mask=None,
**kwargs,
):
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
The only differences are:
- add new args token_idx
- add the process of merging images into inputs_embeds
"""
token_idx = kwargs.get("token_idx", None)
if token_idx is None:
return super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
pixel_values=pixel_values,
image_sizes=image_sizes,
attention_mask=attention_mask,
**kwargs,
)
else:
-raise RuntimeError(
-f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
+use_flash_attention = kwargs.get("use_flash_attention", False)
+flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
position_ids = kwargs.get("position_ids", None)
labels = kwargs.get("labels", None)
if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
vision_feature_layer = kwargs.get("vision_feature_layer", None)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
# 1. Extract the input embeddings
inputs_embeds = self.get_input_embeddings()(input_ids)
# 2. Merge text and images
batch_size, num_patches, num_channels, height, width = pixel_values.shape
reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
image_features = self.vision_tower(
reshaped_pixel_values,
output_hidden_states=True,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
selected_image_feature = image_features.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
-split_sizes = [num_patches] * num_images
+split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-height = width = (
-self.config.vision_config.image_size
-// self.config.vision_config.patch_size
-)
+height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
@@ -228,62 +203,94 @@ class LlavaNextForConditionalGeneration(nn.Module):
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
-raise ValueError(
-"The number of patches is not consistent with the image size."
-)
+raise ValueError("The number of patches is not consistent with the image size.")
-# Dimensions are intentionally swapped to be bug-compatible with
-# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
-num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-image_sizes[image_idx],
+num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+image_sizes[image_idx].tolist(),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
-image_feature = image_feature.view(
-num_patch_height, num_patch_width, height, width, -1
-)
+image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
-self.image_newline[:, None, None].expand(
-*image_feature.shape[:-1], 1
-),
+self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
-image_feature = torch.cat(
-(base_image_feature, image_feature), dim=0
-)
+image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
-image_feature = torch.cat(
-(image_feature, self.image_newline[None]), dim=0
-)
+image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
self.image_offset = image_features.shape[1] - 1 # image_token has occupied 1 token position.
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
# generation with cache
elif past_key_values is not None:
seq_len = input_ids.shape[1]
pad_len = seq_len - token_idx.item()
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-inputs_embeds = self._merge_input_ids_with_image_features(
-input_ids, inputs_embeds, image_features
+# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]
# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = extended_attention_mask
attention_mask[:, -pad_len:] = 0
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_idx": token_idx,
"labels": labels,
"use_flash_attention": use_flash_attention,
"flash_attention_recompute": flash_attention_recompute,
}
)
-hidden_states = self.text_model.model(
+return model_inputs
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits
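
For orientation, the in-place merge done by _merge_input_ids_with_image_features can be summarized by the following hedged sketch; the placeholder token id and tensor shapes are illustrative, not values from the model config.

import torch

def merge_image_features(inputs_embeds, image_features, input_ids, image_token_index):
    # Replace the embeddings of image placeholder tokens with the projected
    # vision features, in place; assumes one feature vector per placeholder token.
    mask = input_ids == image_token_index
    inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
    return inputs_embeds

# Toy example: two placeholder positions, hidden size 4.
embeds = torch.zeros(1, 5, 4)
ids = torch.tensor([[1, 32000, 2, 32000, 3]])
feats = torch.ones(2, 4)
merge_image_features(embeds, feats, ids, image_token_index=32000)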

View File

@@ -1,36 +1,9 @@
import torch
import os
from loguru import logger
from typing import Dict, Optional
from text_generation_server.utils.log import log_master
PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")
if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
raise RuntimeError("Prefix caching is only supported with flashinfer")
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1
# This is overridden by the cli
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
BLOCK_SIZE = 1
else:
BLOCK_SIZE = 16
cuda_graphs = os.getenv("CUDA_GRAPHS") cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None: if cuda_graphs is not None:
try: try:
@ -41,13 +14,18 @@ if cuda_graphs is not None:
) )
else: else:
cuda_graphs = None cuda_graphs = None
# sorting the cuda graphs in descending order helps reduce the
# memory impact and results in less memory usage
if cuda_graphs is not None:
cuda_graphs.sort(reverse=True)
CUDA_GRAPHS = cuda_graphs
# This is overridden at model loading.
global MODEL_ID
MODEL_ID = None
def set_model_id(model_id: str):
global MODEL_ID
MODEL_ID = model_id
# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
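
The retained CUDA_GRAPHS handling parses an environment variable into a list of batch sizes. A hedged sketch of that kind of parsing (the exact error message in the real file is not shown in this view):

import os

def parse_cuda_graphs(env_value):
    # "1,2,4,8" -> [1, 2, 4, 8]; unset or empty -> None
    if not env_value:
        return None
    try:
        return [int(item) for item in env_value.split(",")]
    except ValueError:
        raise RuntimeError(f"Could not parse CUDA_GRAPHS={env_value!r}, expected a comma separated list of integers")

CUDA_GRAPHS = parse_cuda_graphs(os.getenv("CUDA_GRAPHS"))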

View File

@@ -0,0 +1,47 @@
from loguru import logger
import torch
from dataclasses import dataclass
import os
from typing import List, Optional, Type
from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
@dataclass
class StarCoderCausalLMBatch(CausalLMBatch):
past_key_values: Optional[List[torch.Tensor]]
def detach_kv_cache(self):
past_keys = []
past_values = []
last_dim = int(self.past_key_values[0].size(dim=-1)/2)
for key_value in self.past_key_values:
past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0])
past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1])
del self.past_key_values
return past_keys, past_values
def attach_kv_cache(self, past_keys, past_values):
self.past_key_values = [
torch.cat((key, value), dim=-1) for key, value in zip(past_keys, past_values)]
class StarCoder(CausalLM):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
super(StarCoder, self).__init__(
model_id=model_id,
revision=revision,
dtype=dtype,
)
@property
def batch_type(self) -> Type[CausalLMBatch]:
return StarCoderCausalLMBatch
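
StarCoder's multi-query attention keeps key and value fused in a single tensor, which is why the batch above splits the last dimension in half. A toy round-trip sketch of that detach/attach step (shapes are illustrative):

import torch

def split_fused_kv(fused):
    # The last dimension holds the key half followed by the value half.
    last_dim = fused.size(-1) // 2
    key, value = fused.split((last_dim, last_dim), dim=-1)
    return key, value

def fuse_kv(key, value):
    return torch.cat((key, value), dim=-1)

fused = torch.randn(2, 10, 128)  # (batch, seq, 2 * head_dim), toy shape
k, v = split_fused_kv(fused)
assert torch.equal(fuse_kv(k, v), fused)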

File diff suppressed because it is too large.

View File

@@ -1,5 +1,8 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import asyncio
import os
+import sys
import torch
import time
import signal
@@ -14,23 +17,24 @@ from typing import List, Optional
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
-from text_generation_server.pb import generate_pb2_grpc, generate_pb2
-from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.globals import set_model_id
-from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.adapter import AdapterInfo
try:
-from text_generation_server.models.pali_gemma import PaliGemmaBatch
+#from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch,
)
-from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
+#from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
-VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
+VLM_BATCH_TYPES = {VlmCausalLMBatch}
except (ImportError, NotImplementedError):
# These imports can fail on CPU/Non flash.
VLM_BATCH_TYPES = set()
+from text_generation_server.utils.version import is_driver_compatible, MIN_TGI_GAUDI_SYNAPSE_VERSION
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
+from text_generation_server.models.globals import set_adapter_to_index
class SignalHandler:
@@ -58,16 +62,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
self.quantize = model.quantize
self.server_urls = server_urls
-# For some reason, inference_mode does not work well with GLOO which we use on CPU
+# TODO: The inferecemode set messes up the autograd op dispatch. And results in aten::matmul
+# op not optimized issue. Will investigate further.
-if model.device.type == "cuda":
+# if model.device.type == "hpu":
# Force inference mode for the lifetime of TextGenerationService
-self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+# self._inference_mode_raii_guard = torch._C._InferenceMode(True)
async def Info(self, request, context):
return self.model.info
async def Health(self, request, context):
-if self.model.device.type == "cuda":
-torch.zeros((2, 2)).cuda()
+if self.model.device.type == "hpu":
+torch.zeros((2, 2)).to("hpu")
return generate_pb2.HealthResponse()
async def ServiceDiscovery(self, request, context):
@@ -90,41 +97,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
-if self.quantize in {"exl2", "gptq"}:
-try:
-# When using GPTQ, Exllama kernels need some global kernels
-# For which we have the finale shapes only after the model has loaded
-# This will allocate those buffers.
-from text_generation_server.layers.gptq import (
-create_exllama_buffers,
-set_device,
-)
-set_device(self.model.device)
-create_exllama_buffers(request.max_prefill_tokens)
-except ImportError:
-pass
-if (
-self.model.batch_type in VLM_BATCH_TYPES
-): # Hack, i would rather use kwargs in the `from_pb` call
-batch = self.model.batch_type.from_pb_processor(
-request.batch,
-self.model.tokenizer,
-self.model.processor,
-self.model.model.config,
-self.model.dtype,
-self.model.device,
-)
-else:
-batch = self.model.batch_type.from_pb(
-request.batch, self.model.tokenizer, self.model.dtype, self.model.device
-)
-max_supported_total_tokens = self.model.warmup(batch)
-return generate_pb2.WarmupResponse(
-max_supported_total_tokens=max_supported_total_tokens
-)
+max_supported_total_tokens = self.model.warmup(request)
+return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
+# else:
+# batch = self.model.batch_type.from_pb(
+# request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+# )
+# max_supported_total_tokens = self.model.warmup(batch)
+# return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
async def Prefill(self, request, context):
start = time.time_ns()
@@ -144,7 +127,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
-generations, next_batch, timings = self.model.generate_token(batch)
+generations, next_batch, timings = self.model.generate_token([batch])
self.cache.set(next_batch)
return generate_pb2.PrefillResponse(
@@ -170,21 +153,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
if len(batches) == 0:
raise ValueError("All batches are empty")
-if len(batches) > 1:
-start_concat = time.time_ns()
-batch = self.model.batch_type.concatenate(batches)
-concat_ns = time.time_ns() - start_concat
-else:
-batch = batches[0]
-concat_ns = None
-generations, next_batch, timings = self.model.generate_token(batch)
+generations, next_batch, timings = self.model.generate_token(batches)
self.cache.set(next_batch)
return generate_pb2.DecodeResponse(
generations=[generation.to_pb() for generation in generations],
batch=next_batch.to_pb() if next_batch else None,
-concat_ns=concat_ns,
+concat_ns=None,
forward_ns=timings[0],
decode_ns=timings[1],
total_ns=time.time_ns() - start,
@ -213,18 +188,31 @@ def serve(
dtype: Optional[str] = None, dtype: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
if not is_driver_compatible():
logger.warning(f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures")
unix_socket_template = "unix://{}-{}" unix_socket_template = "unix://{}-{}"
adapter_to_index = {} adapter_to_index = {}
logger.info("Server:server_inner: sharded ={}".format(sharded))
if sharded: if sharded:
rank = int(os.environ["RANK"])
logger.info("Server:server_inner: rank ={}".format(rank))
server_urls = [ server_urls = [
unix_socket_template.format(uds_path, rank) unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"]))
for rank in range(int(os.environ["WORLD_SIZE"]))
] ]
local_url = server_urls[int(os.environ["RANK"])] local_url = server_urls[int(os.environ["RANK"])]
else: else:
local_url = unix_socket_template.format(uds_path, 0) local_url = unix_socket_template.format(uds_path, 0)
server_urls = [local_url] server_urls = [local_url]
logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url))
if dtype == "bfloat16" or None:
data_type = torch.bfloat16
else:
data_type = torch.float
if revision == "None":
revision = None
try: try:
model = get_model_with_lora_adapters( model = get_model_with_lora_adapters(
model_id, model_id,
@ -233,7 +221,7 @@ def serve(
sharded, sharded,
quantize, quantize,
speculate, speculate,
dtype, data_type,
trust_remote_code, trust_remote_code,
max_input_tokens, max_input_tokens,
adapter_to_index, adapter_to_index,
@ -271,6 +259,7 @@ def serve(
while signal_handler.KEEP_PROCESSING: while signal_handler.KEEP_PROCESSING:
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
set_model_id(model_id)
asyncio.run( asyncio.run(
serve_inner( serve_inner(
model_id, model_id,

View File

@ -0,0 +1,45 @@
import os
from pathlib import Path
from loguru import logger
import sys
from text_generation_server import server
import argparse
from typing import List
from text_generation_server.utils.adapter import parse_lora_adapters
def main(args):
logger.info("TGIService: starting tgi service .... ")
logger.info(
"TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
args.model_id, args.revision, args.sharded, args.speculate, args.dtype, args.trust_remote_code, args.uds_path
)
)
lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
server.serve(
model_id=args.model_id,
lora_adapters=lora_adapters,
revision=args.revision,
sharded=args.sharded,
quantize=args.quantize,
speculate=args.speculate,
dtype=args.dtype,
trust_remote_code=args.trust_remote_code,
uds_path=args.uds_path,
max_input_tokens=args.max_input_tokens
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str)
parser.add_argument("--revision", type=str)
parser.add_argument("--sharded", type=bool)
parser.add_argument("--speculate", type=int, default=None)
parser.add_argument("--dtype", type=str)
parser.add_argument("--trust_remote_code", type=bool)
parser.add_argument("--uds_path", type=Path)
parser.add_argument("--quantize", type=str)
parser.add_argument("--max_input_tokens", type=int)
args = parser.parse_args()
main(args)
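This entry point is presumably spawned by the launcher, but it can also be exercised directly. Note that argparse's type=bool converts any non-empty string to True, so passing --sharded false on a command line would still enable sharding; a hedged programmatic invocation sidesteps that (all values below are placeholders):

# Hypothetical direct launch, bypassing argparse; values are placeholders.
from argparse import Namespace

main(Namespace(
    model_id="bigscience/bloom-560m",
    revision=None,
    sharded=False,
    speculate=None,
    dtype="bfloat16",
    trust_remote_code=False,
    uds_path=Path("/tmp/text-generation-server"),
    quantize=None,
    max_input_tokens=1024,
))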

View File

@ -1,3 +1,6 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import text_generation_server.habana_quantization_env
from text_generation_server.utils.convert import convert_file, convert_files from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import Weights
@ -18,6 +21,9 @@ from text_generation_server.utils.tokens import (
FinishReason, FinishReason,
Sampling, Sampling,
Greedy, Greedy,
make_tokenizer_optional,
is_tokenizer_transparent,
pad_next_token_chooser_parameters,
) )
__all__ = [ __all__ = [

View File

@ -0,0 +1,31 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
import os
import glob
import time
from optimum.habana.utils import to_gb_rounded
import habana_frameworks.torch as htorch
START_TS = None
DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
if 'GRAPH_VISUALIZATION' in os.environ:
for f in glob.glob('.graph_dumps/*'):
os.remove(f)
def count_hpu_graphs():
return len(glob.glob('.graph_dumps/*PreGraph*'))
def dbg_trace(tag, txt):
global START_TS
if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
if START_TS is None:
START_TS = time.perf_counter()
time_offset = time.perf_counter() - START_TS
mem_stats = htorch.hpu.memory.memory_stats()
mem_used = to_gb_rounded(mem_stats['InUse'])
max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))
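Nothing is written unless DBG_TRACE_FILENAME is set before this module is imported and the process is rank 0; each call appends one line with a time offset, the number of captured HPU graphs, and current/peak device memory. A minimal usage sketch (the module path is an assumption based on the surrounding files):

import os
os.environ["DBG_TRACE_FILENAME"] = "/tmp/tgi_hpu_trace.log"  # set before importing the module

from text_generation_server.debug import dbg_trace  # assumed module path

dbg_trace("WARMUP", "prefill batch_size=8 seq_len=1024")  # appends one timestamped line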

View File

@ -3,7 +3,6 @@ import torch
from datetime import timedelta from datetime import timedelta
from loguru import logger from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
# Tensor Parallelism settings # Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0")) RANK = int(os.getenv("RANK", "0"))
@ -45,6 +44,12 @@ class FakeGroup:
def initialize_torch_distributed(): def initialize_torch_distributed():
import habana_frameworks.torch.core as htcore
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
options = None
if torch.cuda.is_available(): if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL from torch.distributed import ProcessGroupNCCL
@ -56,8 +61,20 @@ def initialize_torch_distributed():
backend = "nccl" backend = "nccl"
options = ProcessGroupNCCL.Options() options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True options.is_high_priority_stream = True
options._timeout = timedelta(seconds=120) options._timeout = timedelta(seconds=60)
elif torch.hpu.is_available():
backend = "hccl"
n_hpus = torch.hpu.device_count()
if world_size > n_hpus:
raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).")
else: else:
try:
import oneccl_bindings_for_pytorch
backend = "ccl"
if os.getenv("CCL_WORKER_COUNT", None) is None:
os.environ["CCL_WORKER_COUNT"] = str(1)
except ImportError:
backend = "gloo" backend = "gloo"
options = None options = None
@ -69,22 +86,11 @@ def initialize_torch_distributed():
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
# Call the init process. # Call the init process.
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
ipex.distributed.init_process_group(
backend="ccl",
world_size=WORLD_SIZE,
rank=RANK,
timeout=timedelta(seconds=120),
pg_options=options,
)
else:
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=backend, backend=backend,
world_size=WORLD_SIZE, world_size=WORLD_SIZE,
rank=RANK, rank=RANK,
timeout=timedelta(seconds=120), timeout=timedelta(seconds=60),
pg_options=options, pg_options=options,
) )
else: else:
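The backend selection above prefers HCCL on HPU, then oneCCL if its bindings are importable, and finally gloo; rank and world size always come from the environment. A hedged single-card sanity check (the three-tuple return shape mirrors upstream TGI and is an assumption here):

import os

os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

process_group, rank, world_size = initialize_torch_distributed()  # assumed (group, rank, world_size)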

View File

@ -1,5 +1,6 @@
import math import math
import torch import torch
import habana_frameworks.torch.core as htcore
from loguru import logger from loguru import logger
from typing import Dict, Union from typing import Dict, Union
@ -43,38 +44,32 @@ class StaticWarper:
if typical_p is not None and typical_p < 1.0: if typical_p is not None and typical_p < 1.0:
self.warpers.append(TypicalLogitsWarper(mass=typical_p)) self.warpers.append(TypicalLogitsWarper(mass=typical_p))
self.cuda_graph = None self.hpu_graph = None
self.static_scores = None self.static_scores = None
self.static_warped_scores = None self.static_warped_scores = None
self.static_next_logprob = None self.static_next_logprob = None
def __call__(self, scores): def __call__(self, scores):
if torch.cuda.is_available(): if self.hpu_graph is None:
if self.cuda_graph is None: self.static_scores = scores.clone().contiguous()
self.static_scores = scores self.static_warped_scores = scores.clone().contiguous()
self.cuda_graph = torch.cuda.CUDAGraph() self.static_next_logprob = scores.clone().contiguous()
self.hpu_graph = htcore.hpu.HPUGraph()
with torch.cuda.graph(self.cuda_graph, pool=mempool): with htcore.hpu.graph(self.hpu_graph):
local_scores = self.static_scores local_scores = self.static_scores
for warper in self.warpers: for warper in self.warpers:
local_scores = warper(None, local_scores) local_scores = warper(None, local_scores)
self.static_warped_scores = local_scores self.static_warped_scores.copy_(local_scores)
# Compute logprobs # Compute logprobs
self.static_next_logprob = torch.log_softmax( self.static_next_logprob.copy_(torch.log_softmax(self.static_warped_scores, -1))
self.static_warped_scores, -1
)
self.static_scores.copy_(scores) self.static_scores.copy_(scores)
self.cuda_graph.replay() self.hpu_graph.replay()
return self.static_warped_scores, self.static_next_logprob return self.static_warped_scores, self.static_next_logprob
# CPU branch
for warper in self.warpers:
scores = warper(None, scores)
return scores, torch.log_softmax(scores, -1)
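The rewrite above swaps CUDA graphs for HPU graphs: the scores are copied into pre-allocated static buffers, the warper chain is captured once, and every later call only refills the input buffer and replays the captured kernels. A standalone sketch of the same capture/replay pattern, mirroring the htcore calls used here (assumes an HPU device; shapes are illustrative):

import torch
import habana_frameworks.torch.core as htcore

static_in = torch.zeros(4, 32000, device="hpu")
static_out = torch.zeros(4, 32000, device="hpu")

graph = htcore.hpu.HPUGraph()
with htcore.hpu.graph(graph):                           # capture once, on the static buffers
    static_out.copy_(torch.log_softmax(static_in, -1))

static_in.copy_(torch.randn(4, 32000, device="hpu"))    # refill the input in place
graph.replay()                                          # static_out now holds the new logprobs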
@lru_cache(10) @lru_cache(10)
def static_warper( def static_warper(
@ -83,9 +78,7 @@ def static_warper(
top_p: Optional[float], top_p: Optional[float],
typical_p: Optional[float], typical_p: Optional[float],
) -> StaticWarper: ) -> StaticWarper:
return StaticWarper( return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)
temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
)
class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
@ -102,17 +95,13 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device): def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
self.penalty = penalty self.penalty = penalty
self.penalty_tensor = torch.tensor( self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1)
penalty, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
score = torch.gather(scores, 1, input_ids) score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
score = torch.where( score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor)
score < 0, score * self.penalty_tensor, score / self.penalty_tensor
)
scores.scatter_(1, input_ids, score) scores.scatter_(1, input_ids, score)
return scores return scores
@ -170,9 +159,11 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
vocab_size = scores.size(1) vocab_size = scores.size(1)
# Calculate the frequency for each token so far # Calculate the frequency for each token so far
token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device) token_freq = torch.zeros(
batch_size, vocab_size, dtype=scores.dtype, device=scores.device
)
token_freq.scatter_add_( token_freq.scatter_add_(
1, input_ids, torch.ones_like(input_ids, dtype=torch.float) 1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device)
) )
token_freq /= input_size token_freq /= input_size
@ -199,13 +190,9 @@ class HeterogeneousTemperatureLogitsWarper:
The value used to module the logits distribution. The value used to module the logits distribution.
""" """
def __init__( def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device):
self, temperature: List[float], dtype: torch.dtype, device: torch.device
):
self.temperature = temperature self.temperature = temperature
self.temperature_tensor = torch.tensor( self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1)
temperature, dtype=dtype, device=device
).unsqueeze(1)
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
scores.div_(self.temperature_tensor) scores.div_(self.temperature_tensor)
@ -244,9 +231,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
min_tokens_to_keep: int = 1, min_tokens_to_keep: int = 1,
): ):
self.top_p = top_p self.top_p = top_p
self.top_p_opposite = 1 - torch.tensor( self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1)
top_p, dtype=dtype, device=device
).unsqueeze(1)
self.filter_value = filter_value self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep self.min_tokens_to_keep = min_tokens_to_keep
@ -263,9 +248,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0 sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
1, sorted_indices, sorted_indices_to_remove
)
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
return warped_scores return warped_scores
@ -313,9 +296,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
disabled = [x == 0 for x in top_k] disabled = [x == 0 for x in top_k]
if any(disabled): if any(disabled):
self.top_k_disabled_mask = torch.tensor( self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view(-1, 1)
disabled, dtype=torch.bool, device=device
).view(-1, 1)
else: else:
self.top_k_disabled_mask = None self.top_k_disabled_mask = None
@ -351,9 +332,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
self.max_top_k = max(self.top_k) self.max_top_k = max(self.top_k)
if self.top_k_disabled_mask is not None: if self.top_k_disabled_mask is not None:
self.top_k_disabled_mask = ( self.top_k_disabled_mask = self.top_k_disabled_mask[indices] if any(disabled) else None
self.top_k_disabled_mask[indices] if any(disabled) else None
)
return self return self
return None return None
@ -419,15 +398,11 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
if self.disabled_mask is not None: if self.disabled_mask is not None:
last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1) last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)
sorted_indices_to_remove = sorted_scores > sorted_scores.gather( sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
1, last_ind.view(-1, 1)
)
if self.min_tokens_to_keep > 1: if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter( indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
1, sorted_indices, sorted_indices_to_remove
)
warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value) warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)
@ -441,9 +416,7 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
self.mass_tensor = self.mass_tensor[indices] self.mass_tensor = self.mass_tensor[indices]
if self.disabled_mask is not None: if self.disabled_mask is not None:
self.disabled_mask = ( self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None
self.disabled_mask[indices] if any(disabled) else None
)
return self return self
return None return None
@ -521,13 +494,7 @@ class GrammarLogitProcessor(LogitsProcessor):
def _cached_compile_fsm(grammar_type, schema, tokenizer): def _cached_compile_fsm(grammar_type, schema, tokenizer):
start_time = time.time() start_time = time.time()
if grammar_type == GrammarType.GRAMMAR_TYPE_JSON: if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
try:
schema = build_regex_from_schema(schema) schema = build_regex_from_schema(schema)
# TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer
except Exception as e:
logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}")
# allows everything
schema = "(.*?)"
elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX: elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
pass # schema is already a regex just here for clarity pass # schema is already a regex just here for clarity
fsm = RegexFSM(schema, tokenizer) fsm = RegexFSM(schema, tokenizer)
@ -586,7 +553,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
mask = torch.full_like(logits, -math.inf) mask = torch.full_like(logits, -math.inf)
for i in range(logits.shape[0]): for i in range(logits.shape[0]):
fsm = self.fsms[i] fsm = self.fsms[i]
if fsm_grammar_states[i] == -1 or fsm is None: if fsm is None:
continue continue
allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i]) allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
mask[i, allowed_tokens] = 0 mask[i, allowed_tokens] = 0

View File

@ -247,10 +247,12 @@ class HeterogeneousNextTokenChooser:
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
grammars: List[str], grammars: List[str],
grammar_types: List[int], grammar_types: List[int],
fsm_grammar_states=List[int], fsm_grammar_states:List[int],
quantization_enabled: bool,
): ):
warpers = [] warpers = []
# TODO: enable watermark with FP8 quantization
self.watermark_processor = ( self.watermark_processor = (
HeterogeneousProcessorWrapper( HeterogeneousProcessorWrapper(
{ {
@ -259,7 +261,7 @@ class HeterogeneousNextTokenChooser:
if do_watermark if do_watermark
} }
) )
if any(watermark) if any(watermark) and not quantization_enabled
else None else None
) )
@ -431,6 +433,18 @@ class HeterogeneousNextTokenChooser:
) )
return self return self
def advance_grammar_single_with_past_state(
self, grammar_state_index: int, next_id: torch.Tensor, past_state: int
):
if self.grammar_processor is not None:
next_id = next_id.item()
self.fsm_grammar_states[grammar_state_index] = (
self.grammar_processor.advance_at_index(
next_id, past_state, grammar_state_index,
)
)
return self
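This helper advances one request's grammar FSM from an explicitly supplied past state instead of the state currently stored on the chooser, which lets a caller replay or correct a single step. An illustrative call (chooser, next_ids and saved_state are assumptions):

# chooser: a HeterogeneousNextTokenChooser; saved_state: an FSM state captured earlier.
chooser.advance_grammar_single_with_past_state(
    grammar_state_index=0,
    next_id=next_ids[0],     # 0-d tensor holding the chosen token id
    past_state=saved_state,
)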
def filter(self, indices): def filter(self, indices):
if self.watermark_processor is not None: if self.watermark_processor is not None:
self.watermark_processor = self.watermark_processor.filter(indices) self.watermark_processor = self.watermark_processor.filter(indices)
@ -481,6 +495,7 @@ class HeterogeneousNextTokenChooser:
device: torch.device, device: torch.device,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fsm_grammar_states: Optional[List[int]] = None, fsm_grammar_states: Optional[List[int]] = None,
quantization_enabled: bool = False,
) -> "HeterogeneousNextTokenChooser": ) -> "HeterogeneousNextTokenChooser":
return HeterogeneousNextTokenChooser( return HeterogeneousNextTokenChooser(
watermark=[pb_.watermark for pb_ in pb], watermark=[pb_.watermark for pb_ in pb],
@ -500,12 +515,37 @@ class HeterogeneousNextTokenChooser:
fsm_grammar_states=( fsm_grammar_states=(
fsm_grammar_states if fsm_grammar_states else [0] * len(pb) fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
), ),
quantization_enabled=quantization_enabled,
) )
def pad_next_token_chooser_parameters(
parameters: List[generate_pb2.NextTokenChooserParameters],
expected_size: int,
) -> List[generate_pb2.NextTokenChooserParameters]:
# disable all logits processors to minimize padding overhead
empty_parameters = generate_pb2.NextTokenChooserParameters(
temperature=1.0,
top_k=0,
top_p=1.0,
typical_p=1.0,
do_sample=False,
seed=0,
repetition_penalty=1.0,
frequency_penalty=0.0,
watermark=False,
grammar="",
grammar_type=0,
)
parameters.extend(
[empty_parameters] * (expected_size - len(parameters))
)
return parameters
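Padding the per-request parameters with a neutral entry keeps the downstream chooser tensors at a fixed bucket size, so small batches do not change tensor shapes. A usage sketch (the bucket size of 8 is illustrative):

params = [
    generate_pb2.NextTokenChooserParameters(temperature=0.7, top_p=0.9, do_sample=True)
    for _ in range(3)
]
params = pad_next_token_chooser_parameters(params, expected_size=8)
assert len(params) == 8  # entries 3..7 are the neutral parameters defined above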
class Sampling: class Sampling:
def __init__(self, seed: int, device: str = "cpu"): def __init__(self, seed: int, device: str = "cpu"):
self.generator = torch.Generator(device) self.generator = torch.Generator("cpu")
self.generator.manual_seed(seed) self.generator.manual_seed(seed)
self.seed = seed self.seed = seed
@ -541,7 +581,7 @@ class HeterogeneousSampling:
self.greedy = Greedy() self.greedy = Greedy()
def __call__(self, logits): def __call__(self, logits):
out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device) out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device)
if self.greedy_indices: if self.greedy_indices:
# Computing for all indices is faster than slicing # Computing for all indices is faster than slicing
torch.argmax(logits, -1, out=out) torch.argmax(logits, -1, out=out)
@ -643,3 +683,50 @@ def batch_top_tokens(
batch_top_token_logprobs.append(row_top_token_logprobs) batch_top_token_logprobs.append(row_top_token_logprobs)
return batch_top_token_ids, batch_top_token_logprobs return batch_top_token_ids, batch_top_token_logprobs
def make_tokenizer_optional(tokenizer):
class _(type(tokenizer)):
def __call__(
self,
text,
return_tensors,
padding,
return_token_type_ids,
truncation,
max_length
):
assert return_tensors == "pt", "inccorrect input arguments when calling TransparentTokenizer"
assert padding == "max_length" or padding == "longest", "inccorrect input arguments when calling TransparentTokenizer"
assert return_token_type_ids == False, "inccorrect input arguments when calling TransparentTokenizer"
assert truncation == True, "inccorrect input arguments when calling TransparentTokenizer"
def str_token_to_int(i):
if i == '?':
return tokenizer.pad_token_id
else:
return int(i)
all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')]
for inner_text in text]
if padding == "longest":
max_length = max(len(tokens) for tokens in all_tokens)
return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]),
"attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])}
def decode(
self,
token_ids,
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: bool = None,
**kwargs,
) -> str:
return ','.join(str(i) for i in to_py_obj(token_ids))
import os
if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
tokenizer.__class__ = _
tokenizer.is_transparent = True
def is_tokenizer_transparent(tokenizer):
return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
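With SKIP_TOKENIZER_IN_TGI=true the wrapped tokenizer becomes a pass-through: inputs are comma-separated token ids, '?' stands for the pad token, and decode() joins ids back into a comma-separated string. A sketch of the round trip (the ids are illustrative and the environment variable must be set before wrapping):

import os
os.environ["SKIP_TOKENIZER_IN_TGI"] = "true"
make_tokenizer_optional(tokenizer)              # tokenizer: any PreTrainedTokenizer instance

batch = tokenizer(
    ["15043, 3186", "?, 29871"],                # pre-tokenized inputs, '?' = pad
    return_tensors="pt",
    padding="longest",
    return_token_type_ids=False,
    truncation=True,
    max_length=1024,
)
text = tokenizer.decode(batch["input_ids"][0])  # -> "15043,3186"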

View File

@ -0,0 +1,12 @@
from optimum.habana.utils import get_driver_version
from packaging.version import Version
MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0")
def is_driver_compatible():
driver_version = get_driver_version()
if driver_version is not None:
if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION:
return False
return True
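serve() above only logs a warning when this check fails, so running on an older driver remains possible. A minimal guard for scripts that prefer to fail fast instead (a sketch, not part of the server):

if not is_driver_compatible():
    raise RuntimeError(
        f"Synapse driver older than {MIN_TGI_GAUDI_SYNAPSE_VERSION} is not supported by this build"
    )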

View File

@ -34,7 +34,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
# watermarking parameters # watermarking parameters
self.gamma = gamma self.gamma = gamma
self.delta = delta self.delta = delta
self.rng = torch.Generator(device=device) self.rng = torch.Generator(device="cpu")
self.hash_key = hash_key self.hash_key = hash_key
def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]): def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):