Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-26 12:32:10 +00:00)

Commit bab529c916 (parent 14fdc4ae5e)

Make Gaudi adapt to the tgi 2.3.0

Signed-off-by: yuanwu <yuan.wu@intel.com>
@@ -61,6 +61,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
    make \
    curl \
    git \
    python3.11-dev \
    && rm -rf /var/lib/apt/lists/*

# Install server
@@ -96,5 +97,5 @@ FROM base
COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

ENTRYPOINT ["/tgi-entrypoint.sh"]
#ENTRYPOINT ["/tgi-entrypoint.sh"]
# CMD ["--json-output"]
@@ -4,10 +4,6 @@ include Makefile-vllm
include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-fbgemm
include Makefile-exllamav2
include Makefile-flashinfer

unit-tests:
	pytest -s -vv -m "not private" tests
@@ -21,25 +17,20 @@ gen-server:
	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
	touch text_generation_server/pb/__init__.py

install-server: gen-server
install: gen-server
	pip install pip --upgrade
	pip install -r requirements_cuda.txt
	pip install -e ".[accelerate, quantize, peft, outlines]"


install: install-cuda
	echo "Installed server"

install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm
	pip install -e ".[bnb]"
	pip install nvidia-nccl-cu12==2.22.3

install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm
	pip install -r requirements.txt
	pip install -e "."

run-dev:
	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

install-poetry:
	curl -sSL https://install.python-poetry.org | python3 -

update-lock:
	rm poetry.lock
	poetry lock --no-update

export-requirements:
	poetry export -o requirements_cuda.txt --without-hashes
	poetry export -o requirements_rocm.txt --without-hashes
	poetry export -o requirements_intel.txt --without-hashes
	poetry export -o requirements.txt --without-hashes
server/dill-0.3.7-patch.sh  (new file, 91 lines)
@@ -0,0 +1,91 @@
#!/bin/bash
git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git
pushd dill
cat <<EOF > dill-0.3.7.patch
diff --git a/dill/_dill.py b/dill/_dill.py
index d0cf543..f6eb662 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
 XRangeType = range
 from types import MappingProxyType as DictProxyType, new_class
 from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
-import __main__ as _main_module
+class _LazyMainModule(object):
+    _module = None
+    @property
+    def module(self):
+        if self._module is None:
+            import __main__ as _m_module
+            self._module = _m_module
+        return self._module
+_main_module = _LazyMainModule()
 import marshal
 import gc
 # import zlib
@@ -353,7 +361,7 @@ class Pickler(StockPickler):
         _fmode = kwds.pop('fmode', None)
         _recurse = kwds.pop('recurse', None)
         StockPickler.__init__(self, file, *args, **kwds)
-        self._main = _main_module
+        self._main = _main_module.module
         self._diff_cache = {}
         self._byref = settings['byref'] if _byref is None else _byref
         self._strictio = False #_strictio
@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):
         settings = Pickler.settings
         _ignore = kwds.pop('ignore', None)
         StockUnpickler.__init__(self, *args, **kwds)
-        self._main = _main_module
+        self._main = _main_module.module
         self._ignore = settings['ignore'] if _ignore is None else _ignore

     def load(self): #NOTE: if settings change, need to update attributes
         obj = StockUnpickler.load(self)
-        if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
+        if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
             if not self._ignore:
                 # point obj class to main
                 try: obj.__class__ = getattr(self._main, type(obj).__name__)
@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):
         logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
         pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
         logger.trace(pickler, "# D1")
-    elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
+    elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
         logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
         pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
         logger.trace(pickler, "# D3")
-    elif '__name__' in obj and obj != _main_module.__dict__ \\
+    elif '__name__' in obj and obj != _main_module.module.__dict__ \\
         and type(obj['__name__']) is str \\
         and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
         logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
diff --git a/dill/session.py b/dill/session.py
index 74234ab..1be8d89 100644
--- a/dill/session.py
+++ b/dill/session.py
@@ -233,7 +233,7 @@ def dump_module(
     protocol = settings['protocol']
     main = module
     if main is None:
-        main = _main_module
+        main = _main_module.module
     elif isinstance(main, str):
         main = _import_module(main)
     if not isinstance(main, ModuleType):
@@ -501,7 +501,7 @@ def load_module(
         pass
     assert loaded is main
     _restore_modules(unpickler, main)
-    if main is _main_module or main is module:
+    if main is _main_module.module or main is module:
         return None
     else:
         return main

EOF
git apply dill-0.3.7.patch
python -m pip install .
popd
rm -fr dill
server/dill-0.3.8-patch.sh  (new file, 91 lines)
@@ -0,0 +1,91 @@
#!/bin/bash
git clone -b 0.3.8 https://github.com/uqfoundation/dill.git
pushd dill
cat <<EOF > dill-0.3.8.patch
diff --git a/dill/_dill.py b/dill/_dill.py
index d42432f..1d251e6 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
 XRangeType = range
 from types import MappingProxyType as DictProxyType, new_class
 from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
-import __main__ as _main_module
+class _LazyMainModule(object):
+    _module = None
+    @property
+    def module(self):
+        if self._module is None:
+            import __main__ as _m_module
+            self._module = _m_module
+        return self._module
+_main_module = _LazyMainModule()
 import marshal
 import gc
 # import zlib
@@ -355,7 +363,7 @@ class Pickler(StockPickler):
         _fmode = kwds.pop('fmode', None)
         _recurse = kwds.pop('recurse', None)
         StockPickler.__init__(self, file, *args, **kwds)
-        self._main = _main_module
+        self._main = _main_module.module
         self._diff_cache = {}
         self._byref = settings['byref'] if _byref is None else _byref
         self._strictio = False #_strictio
@@ -437,12 +445,12 @@ class Unpickler(StockUnpickler):
         settings = Pickler.settings
         _ignore = kwds.pop('ignore', None)
         StockUnpickler.__init__(self, *args, **kwds)
-        self._main = _main_module
+        self._main = _main_module.module
         self._ignore = settings['ignore'] if _ignore is None else _ignore

     def load(self): #NOTE: if settings change, need to update attributes
         obj = StockUnpickler.load(self)
-        if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
+        if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
             if not self._ignore:
                 # point obj class to main
                 try: obj.__class__ = getattr(self._main, type(obj).__name__)
@@ -1199,11 +1207,11 @@ def save_module_dict(pickler, obj):
         logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
         pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
         logger.trace(pickler, "# D1")
-    elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
+    elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
         logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
         pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
         logger.trace(pickler, "# D3")
-    elif '__name__' in obj and obj != _main_module.__dict__ \\
+    elif '__name__' in obj and obj != _main_module.module.__dict__ \\
         and type(obj['__name__']) is str \\
         and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
         logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
diff --git a/dill/session.py b/dill/session.py
index e91068a..a921b43 100644
--- a/dill/session.py
+++ b/dill/session.py
@@ -233,7 +233,7 @@ def dump_module(
     protocol = settings['protocol']
     main = module
     if main is None:
-        main = _main_module
+        main = _main_module.module
     elif isinstance(main, str):
         main = _import_module(main)
     if not isinstance(main, ModuleType):
@@ -501,7 +501,7 @@ def load_module(
         pass
     assert loaded is main
     _restore_modules(unpickler, main)
-    if main is _main_module or main is module:
+    if main is _main_module.module or main is module:
         return None
     else:
         return main

EOF
git apply dill-0.3.8.patch
python -m pip install .
popd
rm -fr dill
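Both patch scripts above make the same change to dill: the module-level "import __main__ as _main_module" is replaced by a _LazyMainModule wrapper, so __main__ is only resolved the first time it is actually needed instead of at import time. A minimal standalone sketch of that lazy-attribute pattern (names are illustrative; the real change lives inside dill/_dill.py as shown in the patches):

# Minimal sketch of the lazy-__main__ pattern used by the dill patches above.
class _LazyMainModule:
    _module = None

    @property
    def module(self):
        if self._module is None:
            import __main__ as _m  # resolved on first access, not at import time
            self._module = _m
        return self._module

_main_module = _LazyMainModule()

if __name__ == "__main__":
    # Callers go through `.module` instead of holding the module object directly.
    print(_main_module.module.__name__)  # -> "__main__"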
@@ -1,6 +1,6 @@
[tool.poetry]
name = "text-generation-server"
version = "2.0.5-dev0"
version = "2.0.4"
description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

@@ -9,76 +9,34 @@ text-generation-server = 'text_generation_server.cli:app'

[tool.poetry.dependencies]
python = ">=3.9,<3.13"
protobuf = "^4.25.3"
protobuf = "^3.20.3"
grpcio = "^1.51.1"
grpcio-status = "^1.51.1"
grpcio-reflection = "^1.51.1"
grpcio-status = "*"
grpcio-reflection = "*"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
accelerate = { version = "^0.29.1", optional = true }
bitsandbytes = { version = "^0.43.0", optional = true }
safetensors = "^0.4"
typer = "^0.7.0"
loguru = "^0.6.0"
opentelemetry-api = "^1.25.0"
opentelemetry-exporter-otlp = "^1.25.0"
opentelemetry-instrumentation-grpc = "^0.46b0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.19.1"
huggingface-hub = "^0.23"
transformers = "^4.43"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
peft = { version = "^0.10", optional = true }
torch = { version = "^2.4.0", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"
outlines= { version = "^0.0.34", optional = true }
peft = "^0.10"
optimum-habana = "1.13.2"
transformers = "4.43.4"
numpy = "1.26.4"
accelerate = "0.33.0"
outlines= { version = "^0.0.36", optional = true }
prometheus-client = "^0.20.0"
py-cpuinfo = "^9.0.0"
# Remove later, temporary workaround for outlines.
numpy = "^1.26"

marlin-kernels = [
    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
    { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.2.0/marlin_kernels-0.2.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
    { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.3.1/moe_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
rich = "^13.7.1"

[tool.poetry.extras]
torch = ["torch"]
accelerate = ["accelerate"]
bnb = ["bitsandbytes"]
marlin = ["marlin-kernels"]
moe = ["moe-kernels"]
peft = ["peft"]
quantize = ["texttable", "datasets", "accelerate"]
outlines = ["outlines"]

[tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.51.1"
grpcio-tools = "*"
pytest = "^7.3.0"


[[tool.poetry.source]]
name = "pytorch-gpu-src"
url = "https://download.pytorch.org/whl/cu121"
priority = "explicit"

[tool.pytest.ini_options]
markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]

[build-system]
requires = [
    "poetry-core>=1.0.0",
]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
server/requirements.txt  (new file, 88 lines)
@@ -0,0 +1,88 @@
accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
aiohappyeyeballs==2.4.0 ; python_version >= "3.9" and python_version < "3.13"
aiohttp==3.10.5 ; python_version >= "3.9" and python_version < "3.13"
aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "3.13"
async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
attrs==24.2.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.7.4 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
coloredlogs==15.0.1 ; python_version >= "3.9" and python_version < "3.13"
datasets==2.21.0 ; python_version >= "3.9" and python_version < "3.13"
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
diffusers==0.29.2 ; python_version >= "3.9" and python_version < "3.13"
dill==0.3.8 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.15.4 ; python_version >= "3.9" and python_version < "3.13"
frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec[http]==2024.6.1 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.63.2 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.48.2 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.66.0 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.8 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.24.6 ; python_version >= "3.9" and python_version < "3.13"
humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
idna==3.8 ; python_version >= "3.9" and python_version < "3.13"
importlib-metadata==8.4.0 ; python_version >= "3.9" and python_version < "3.13"
jinja2==3.1.4 ; python_version >= "3.9" and python_version < "3.13"
joblib==1.4.2 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13"
mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
multidict==6.0.5 ; python_version >= "3.9" and python_version < "3.13"
multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "3.13"
networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
optimum-habana==1.13.2 ; python_version >= "3.9" and python_version < "3.13"
optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13"
packaging==24.1 ; python_version >= "3.9" and python_version < "3.13"
pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
peft==0.10.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.4.0 ; python_version >= "3.9" and python_version < "3.13"
prometheus-client==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==3.20.3 ; python_version >= "3.9" and python_version < "3.13"
psutil==6.0.0 ; python_version >= "3.9" and python_version < "3.13"
py-cpuinfo==9.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyarrow==17.0.0 ; python_version >= "3.9" and python_version < "3.13"
pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.9" and python_version < "3.13"
python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.13"
pytz==2024.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
regex==2024.7.24 ; python_version >= "3.9" and python_version < "3.13"
requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.4.4 ; python_version >= "3.9" and python_version < "3.13"
scikit-learn==1.5.1 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
sentence-transformers[train]==3.0.1 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==73.0.1 ; python_version >= "3.9" and python_version < "3.13"
six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
sympy==1.12.1 ; python_version >= "3.9" and python_version < "3.13"
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.5 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
triton==3.0.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.13" and python_version >= "3.9"
typer==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
xxhash==3.5.0 ; python_version >= "3.9" and python_version < "3.13"
yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13"
zipp==3.20.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -47,9 +47,9 @@ def serve(
    max_input_tokens: Optional[int] = None,
):
    if sharded:
        assert (
            os.getenv("RANK", None) is not None
        ), "RANK must be set when sharded is True"
        # assert (
        #     os.getenv("RANK", None) is not None
        # ), "RANK must be set when sharded is True"
        assert (
            os.getenv("WORLD_SIZE", None) is not None
        ), "WORLD_SIZE must be set when sharded is True"
@@ -96,7 +96,7 @@ def serve(

    # Downgrade enum into str for easier management later on
    quantize = None if quantize is None else quantize.value
    dtype = None if dtype is None else dtype.value
    dtype = "bfloat16" if dtype is None else dtype.value
    if dtype is not None and quantize not in {
        None,
        "bitsandbytes",
@@ -106,18 +106,76 @@ def serve(
        raise RuntimeError(
            "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
        )
    server.serve(
        model_id,
        lora_adapters,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        trust_remote_code,
        uds_path,
        max_input_tokens,
    )

    logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))

    if sharded:
        tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
        num_shard = int(os.getenv("WORLD_SIZE", "1"))
        logger.info("CLI SHARDED = {}".format(num_shard))
        import subprocess

        cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
        cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
        cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
        cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
        if speculate is not None:
            cmd += f" --speculate {speculate}"
        logger.info("CLI server start deepspeed ={} ".format(cmd))
        sys.stdout.flush()
        sys.stderr.flush()
        with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
            do_terminate = False
            current_handler = signal.getsignal(signal.SIGTERM)

            def terminate_handler(sig, frame):
                nonlocal do_terminate
                do_terminate = True
                if callable(current_handler):
                    current_handler(sig, frame)

            signal.signal(signal.SIGTERM, terminate_handler)

            finished = False
            while not finished:
                try:
                    if do_terminate:
                        parent = psutil.Process(proc.pid)
                        all_procs = parent.children(recursive=True) + [parent]
                        for p in all_procs:
                            try:
                                p.terminate()
                            except psutil.NoSuchProcess:
                                pass
                        _, alive = psutil.wait_procs(all_procs, timeout=30)
                        for p in alive:
                            p.kill()

                        do_terminate = False

                    proc.wait(timeout=3)
                except subprocess.TimeoutExpired:
                    pass
                else:
                    finished = True

            sys.stdout.flush()
            sys.stderr.flush()
            if proc.returncode != 0:
                logger.error(f"{cmd} exited with status = {proc.returncode}")
            return proc.returncode
    else:
        server.serve(
            model_id,
            lora_adapters,
            revision,
            sharded,
            quantize,
            speculate,
            dtype,
            trust_remote_code,
            uds_path,
            max_input_tokens,
        )


@app.command()
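For reference, the sharded branch above shells out to DeepSpeed instead of calling server.serve() directly. A rough sketch of the command string it assembles; the model id, shard count, and paths below are made-up example values, not taken from the diff:

# Hypothetical example of the command built by the sharded branch above,
# assuming WORLD_SIZE=8 and an illustrative model id.
num_shard = 8
tgi_file = "/usr/src/server/text_generation_server/tgi_service.py"  # placeholder path
cmd = f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
cmd += " --model_id meta-llama/Llama-2-7b-hf --revision None --sharded True"
cmd += " --dtype bfloat16 --trust_remote_code False --uds_path /tmp/text-generation-server"
cmd += " --quantize None --max_input_tokens 1024"
print(cmd)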
server/text_generation_server/habana_quantization_env.py  (new file, 27 lines)
@@ -0,0 +1,27 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import os

quant_config = os.getenv("QUANT_CONFIG", "")
is_quantization_enabled = quant_config != ""

if is_quantization_enabled:
    os.environ.setdefault("ENABLE_EXPERIMENTAL_FLAGS", "true")
    os.environ.setdefault("USE_DEFAULT_QUANT_PARAM", "true")
    os.environ.setdefault("UPDATE_GRAPH_OUTPUT_MME", "false")
    os.environ.setdefault("ENABLE_CALC_DYNAMIC_RANGE", "false")
    os.environ.setdefault(
        "UPDATE_MME_OUTPUT_PRECISION_FILTER", "v_proj,matmul_av")
    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")


def prepare_model_for_quantization(model):
    if is_quantization_enabled:
        if os.getenv("USE_INC", "1") != "0":
            from neural_compressor.torch.quantization import FP8Config, convert
            config = FP8Config.from_json_file(quant_config)
            model = convert(model, config)
        else:
            import habana_quantization_toolkit
            habana_quantization_toolkit.prep_model(model)
    return model
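The new module gates FP8 quantization on the QUANT_CONFIG environment variable: when it points at a JSON config, Intel Neural Compressor converts the model (or the legacy habana_quantization_toolkit is used when USE_INC=0). A hedged usage sketch; the config path and the model-loading call are placeholders, not part of the diff:

# Illustrative only: QUANT_CONFIG must be set before the module is imported,
# because the environment defaults above are applied at import time.
import os
os.environ["QUANT_CONFIG"] = "/path/to/maxabs_quant.json"  # placeholder path

from text_generation_server import habana_quantization_env as hq_env

model = load_model_somehow()  # placeholder for the usual model-loading step
model = hq_env.prepare_model_for_quantization(model)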
@@ -1,3 +1,5 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import torch
import grpc

@@ -6,6 +8,8 @@ from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger
from typing import Callable, Any
import traceback
import os


class ExceptionInterceptor(AsyncServerInterceptor):
@@ -20,6 +24,7 @@ class ExceptionInterceptor(AsyncServerInterceptor):
            response = method(request_or_iterator, context)
            return await response
        except Exception as err:
            trace = " " + traceback.format_exc() if os.environ.get('DUMP_STACK') else ''
            method_name = method_name.split("/")[-1]
            logger.exception(f"Method {method_name} encountered an error.")

@@ -30,8 +35,10 @@ class ExceptionInterceptor(AsyncServerInterceptor):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            from .utils.debug import dbg_trace
            dbg_trace('EXCEPTION', traceback.format_exc())
            await context.abort_with_status(
                rpc_status.to_status(
                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err))
                    status_pb2.Status(code=code_pb2.INTERNAL, message=str(err) + trace)
                )
            )
File diff suppressed because it is too large
@@ -1,11 +1,10 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import torch
import torch.distributed

from typing import Optional, Type

from transformers import (
    PreTrainedTokenizerBase,
)
from transformers import PreTrainedTokenizerBase

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch
@@ -21,26 +20,33 @@ class BloomCausalLMBatch(CausalLMBatch):
        dtype: torch.dtype,
        device: torch.device,
    ) -> "CausalLMBatch":
        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
        batch = super().from_pb(
            pb=pb,
            tokenizer=tokenizer,
            dtype=dtype,
            device=device,
        )
        batch.keys_head_dim_last = False
        return batch


class BLOOMSharded(CausalLM):
class BLOOM(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        super(BLOOM, self).__init__(
            model_id=model_id,
            revision=revision,
            speculator=speculator,
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return BloomCausalLMBatch

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
    ):
        outputs, speculative_logits = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

        logits = outputs.logits
        return logits, speculative_logits, outputs.past_key_values
File diff suppressed because it is too large
@@ -14,33 +14,26 @@
# limitations under the License.
""" PyTorch Llava-NeXT model."""

from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from transformers.models.llava_next.modeling_llava_next import (
    unpad_image,
)
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
from transformers.image_processing_utils import select_best_resolution

from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,
)
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
)


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
            The size of the input image in the format (width, height).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
@@ -48,7 +41,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width).
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")
@@ -57,100 +50,13 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    return height // patch_size, width // patch_size


def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (`torch.Tensor`):
            The image tensor, assumed to be of shape (num_channels, height, width).
        original_size (`tuple`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: The unpadded image tensor.
    """
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor


# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()

        self.linear_1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
        )
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
        )

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class LlavaNextForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = config.quantize
        vision_config = config.vision_config
        # Instead of selecting in hidden_states[-2].
        # Instead compute only the n -2 + 1 layers and don't pool
        if config.vision_feature_layer < 0:
            vision_config.num_hidden_layers += config.vision_feature_layer + 1
        else:
            vision_config.num_hidden_layers = config.vision_feature_layer + 1
        self.vision_tower = load_vision_model(
            prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
            config=config.vision_config,
            weights=weights,
        )

        self.multi_modal_projector = LlavaNextMultiModalProjector(
            prefix="multi_modal_projector", config=config, weights=weights
        )

        self.image_newline = weights.get_tensor("image_newline")

        self.vocab_size = config.text_config.vocab_size
        self.config = config
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        self.text_model = load_text_model(
            prefix="language_model" if not prefix else f"{prefix}.language_model",
            config=config.text_config,
            weights=weights,
        )
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
        input_ids: torch.Tensor,
    ):
        """In place merges in vision_embeddings with inputs_embeds."""
        mask = input_ids == self.config.image_token_index
@@ -165,125 +71,226 @@ class LlavaNextForConditionalGeneration(nn.Module):

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        # Unused for this model
        pixel_attention_mask=None,
        image_sizes: Optional[torch.LongTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None and len(pixel_values) > 0:
            # num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
            # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
            # 1. Extract the input embeddings

            # 2. Merge text and images
            num_images, num_patches, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.view(
                num_images * num_patches, channels, height, width
        if token_idx is not None:
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            image_features = self.vision_tower(pixel_values)
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings()(input_ids)

            # selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
            # Already done within the clip model
            selected_image_feature = image_features.last_hidden_state
            outputs = self.language_model(
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                token_idx=token_idx,
                use_flash_attention=use_flash_attention,
                flash_attention_recompute=flash_attention_recompute,
            )

            if self.config.vision_feature_select_strategy == "default":
                selected_image_feature = selected_image_feature[:, 1:]
            elif self.config.vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            logits = outputs[0]

            if not return_dict:
                output = (logits,) + outputs[1:]
                return output

            return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        **kwargs,
    ):
        """
        Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
        The only differences are:
        - add new args token_idx
        - add the process of merging images into inputs_embeds
        """
        token_idx = kwargs.get("token_idx", None)
        if token_idx is None:
            return super().prepare_inputs_for_generation(
                input_ids=input_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                attention_mask=attention_mask,
                **kwargs,
            )
        else:
            raise RuntimeError(
                f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
            use_flash_attention = kwargs.get("use_flash_attention", False)
            flash_attention_recompute = kwargs.get("flash_attention_recompute", False)

            position_ids = kwargs.get("position_ids", None)
            labels = kwargs.get("labels", None)
            if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
                vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None)
                vision_feature_layer = kwargs.get("vision_feature_layer", None)
                vision_feature_select_strategy = (
                    vision_feature_select_strategy
                    if vision_feature_select_strategy is not None
                    else self.config.vision_feature_select_strategy
                )
                vision_feature_layer = (
                    vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
                )

                # 1. Extract the input embeddings
                inputs_embeds = self.get_input_embeddings()(input_ids)
                # 2. Merge text and images
                batch_size, num_patches, num_channels, height, width = pixel_values.shape
                reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
                image_features = self.vision_tower(
                    reshaped_pixel_values,
                    output_hidden_states=True,
                    use_flash_attention=use_flash_attention,
                    flash_attention_recompute=flash_attention_recompute,
                )

                selected_image_feature = image_features.hidden_states[vision_feature_layer]

                if vision_feature_select_strategy == "default":
                    selected_image_feature = selected_image_feature[:, 1:]
                elif vision_feature_select_strategy == "full":
                    selected_image_feature = selected_image_feature

                image_features = self.multi_modal_projector(selected_image_feature)

                # split up image_features for each of the individual images
                # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
                # if we assume each image has 5 image features (base image + 4 patches)
                split_sizes = [image.shape[0] for image in pixel_values]
                image_features = torch.split(image_features, split_sizes, dim=0)

                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]

                        if height * width != base_image_feature.shape[0]:
                            raise ValueError("The number of patches is not consistent with the image size.")

                        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                            image_sizes[image_idx].tolist(),
                            self.config.image_grid_pinpoints,
                            self.config.vision_config.image_size,
                        )

                        image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                        image_feature = unpad_image(image_feature, image_sizes[image_idx])
                        image_feature = torch.cat(
                            (
                                image_feature,
                                self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1),
                            ),
                            dim=-1,
                        )
                        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        image_feature = image_feature[0]
                        image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
                    new_image_features.append(image_feature)
                image_features = torch.stack(new_image_features, dim=0)
                inputs_embeds = self._merge_input_ids_with_image_features(inputs_embeds, image_features, input_ids)
                self.image_offset = image_features.shape[1] - 1  # image_token has occupied 1 token position.
            # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
            # generation with cache
            elif past_key_values is not None:
                seq_len = input_ids.shape[1]
                pad_len = seq_len - token_idx.item()
                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                past_length = first_layer_past_key_value.shape[-1]
                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], past_length),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                attention_mask = extended_attention_mask
                attention_mask[:, -pad_len:] = 0

            if attention_mask is not None and position_ids is None:
                # create position_ids on the fly for batch generation
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                if past_key_values:
                    if token_idx is not None:
                        position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    else:
                        position_ids = position_ids[:, -input_ids.shape[1] :]

            # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
            if inputs_embeds is not None and past_key_values is None:
                model_inputs = {"inputs_embeds": inputs_embeds}
            else:
                model_inputs = {"input_ids": input_ids}

            model_inputs.update(
                {
                    "position_ids": position_ids,
                    "past_key_values": past_key_values,
                    "use_cache": kwargs.get("use_cache"),
                    "attention_mask": attention_mask,
                    "token_idx": token_idx,
                    "labels": labels,
                    "use_flash_attention": use_flash_attention,
                    "flash_attention_recompute": flash_attention_recompute,
                }
            )

            image_features = self.multi_modal_projector(selected_image_feature)

            # split up image_features for each of the individual images
            # hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
            # if we assume each image has 5 image features (base image + 4 patches)
            split_sizes = [num_patches] * num_images
            image_features = torch.split(image_features, split_sizes, dim=0)

            # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
            height = width = (
                self.config.vision_config.image_size
                // self.config.vision_config.patch_size
            )

            new_image_features = []
            for image_idx, image_feature in enumerate(image_features):
                if image_feature.shape[0] > 1:
                    base_image_feature = image_feature[0]
                    image_feature = image_feature[1:]

                    if height * width != base_image_feature.shape[0]:
                        raise ValueError(
                            "The number of patches is not consistent with the image size."
                        )

                    # Dimensions are intentionally swapped to be bug-compatible with
                    # upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                        image_sizes[image_idx],
                        self.config.image_grid_pinpoints,
                        self.config.vision_config.image_size,
                    )
                    image_feature = image_feature.view(
                        num_patch_height, num_patch_width, height, width, -1
                    )
                    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                    image_feature = unpad_image(image_feature, image_sizes[image_idx])
                    image_feature = torch.cat(
                        (
                            image_feature,
                            self.image_newline[:, None, None].expand(
                                *image_feature.shape[:-1], 1
                            ),
                        ),
                        dim=-1,
                    )
                    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                    image_feature = torch.cat(
                        (base_image_feature, image_feature), dim=0
                    )
                else:
                    image_feature = image_feature[0]
                    image_feature = torch.cat(
                        (image_feature, self.image_newline[None]), dim=0
                    )
                new_image_features.append(image_feature)
            image_features = torch.stack(new_image_features, dim=0)

            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_features
            )

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            true_max_s=max_s,
            prefill_cache_indices=None,
            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits
            return model_inputs
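The unpad_image helper that this diff removes (in favour of the implementation imported from transformers) strips the letterbox padding introduced when an image is resized onto a square patch grid. A small self-contained illustration of that logic with made-up sizes:

# Worked example of the unpad_image logic shown above (sizes are illustrative).
import torch

def unpad_image(tensor, original_size):
    # Same logic as the helper removed in this diff.
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]
    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height
    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        return tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        return tensor[:, :, padding : current_width - padding]

# A 3x336x336 padded crop of an image that was originally 672x448 (height x width):
padded = torch.zeros(3, 336, 336)
print(unpad_image(padded, (672, 448)).shape)  # torch.Size([3, 336, 224])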
@@ -1,36 +1,9 @@
import torch
import os
from loguru import logger
from typing import Dict, Optional

from text_generation_server.utils.log import log_master

PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
ATTENTION = os.getenv("ATTENTION")
_expected = {"paged", "flashdecoding", "flashinfer"}
assert (
    ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")

if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
    raise RuntimeError("Prefix caching is only supported with flashinfer")

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1

# This is overridden by the cli
BLOCK_SIZE: int
if ATTENTION == "flashdecoding":
    BLOCK_SIZE = 256
elif ATTENTION == "flashinfer":
    BLOCK_SIZE = 1
else:
    BLOCK_SIZE = 16

cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
    try:
@@ -41,13 +14,18 @@ if cuda_graphs is not None:
        )
else:
    cuda_graphs = None
# sorting the cuda graphs in descending order helps reduce the
# memory impact and results in less memory usage
if cuda_graphs is not None:
    cuda_graphs.sort(reverse=True)

CUDA_GRAPHS = cuda_graphs

# This is overridden at model loading.
global MODEL_ID
MODEL_ID = None


def set_model_id(model_id: str):
    global MODEL_ID
    MODEL_ID = model_id

# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None
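The CUDA_GRAPHS handling above keeps only the parsing and descending sort; the body of the try block is elided by the hunk. A sketch of the overall behaviour, where the comma-separated parsing inside the try block is an assumption and only the sorting matches the lines shown verbatim above:

# Sketch of how CUDA_GRAPHS is typically consumed (parsing is assumed, not shown in the hunk).
import os

cuda_graphs = os.getenv("CUDA_GRAPHS")            # e.g. CUDA_GRAPHS="1,8,32"
if cuda_graphs is not None:
    try:
        cuda_graphs = [int(x) for x in cuda_graphs.split(",")]
    except ValueError:
        cuda_graphs = None
else:
    cuda_graphs = None

# sorting the cuda graphs in descending order helps reduce the memory impact
if cuda_graphs is not None:
    cuda_graphs.sort(reverse=True)

CUDA_GRAPHS = cuda_graphs                          # e.g. [32, 8, 1]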
server/text_generation_server/models/starcoder.py  (new file, 47 lines)
@@ -0,0 +1,47 @@
from loguru import logger
import torch
from dataclasses import dataclass
import os
from typing import List, Optional, Type

from text_generation_server.models import CausalLM
from text_generation_server.models.causal_lm import CausalLMBatch


@dataclass
class StarCoderCausalLMBatch(CausalLMBatch):
    past_key_values: Optional[List[torch.Tensor]]

    def detach_kv_cache(self):
        past_keys = []
        past_values = []
        last_dim = int(self.past_key_values[0].size(dim=-1)/2)
        for key_value in self.past_key_values:
            past_keys.append(key_value.split((last_dim, last_dim), dim=-1)[0])
            past_values.append(key_value.split((last_dim, last_dim), dim=-1)[1])
        del self.past_key_values

        return past_keys, past_values

    def attach_kv_cache(self, past_keys, past_values):
        self.past_key_values = [
            torch.cat((key, value), dim=-1) for key, value in zip(past_keys, past_values)]


class StarCoder(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):

        super(StarCoder, self).__init__(
            model_id=model_id,
            revision=revision,
            dtype=dtype,
        )

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return StarCoderCausalLMBatch
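StarCoder keeps key and value fused in a single tensor per layer, which is why detach_kv_cache splits the last dimension in half and attach_kv_cache concatenates it back. A toy illustration of that round trip; the tensor shape is made up and is not the model's real cache layout:

# Toy illustration of the fused key/value split used by StarCoderCausalLMBatch.
import torch

fused = torch.randn(2, 16, 128)                 # last dim = key half + value half
last_dim = int(fused.size(dim=-1) / 2)

key, value = fused.split((last_dim, last_dim), dim=-1)
refused = torch.cat((key, value), dim=-1)

assert torch.equal(fused, refused)              # the round trip is lossless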
File diff suppressed because it is too large
@ -1,5 +1,8 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import asyncio
import os
import sys
import torch
import time
import signal
@ -14,23 +17,24 @@ from typing import List, Optional
from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id
from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.adapter import AdapterInfo

try:
    from text_generation_server.models.pali_gemma import PaliGemmaBatch
    #from text_generation_server.models.pali_gemma import PaliGemmaBatch
    from text_generation_server.models.vlm_causal_lm import (
        VlmCausalLMBatch,
    )
    from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
    #from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

    VLM_BATCH_TYPES = {PaliGemmaBatch, VlmCausalLMBatch, IdeficsCausalLMBatch}
    VLM_BATCH_TYPES = {VlmCausalLMBatch}
except (ImportError, NotImplementedError):
    # These imports can fail on CPU/Non flash.
    VLM_BATCH_TYPES = set()

from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.version import is_driver_compatible, MIN_TGI_GAUDI_SYNAPSE_VERSION


class SignalHandler:
@ -58,16 +62,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        self.quantize = model.quantize
        self.server_urls = server_urls
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        if model.device.type == "cuda":
            # Force inference mode for the lifetime of TextGenerationService
            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
        # TODO: setting inference mode messes up the autograd op dispatch and results in the
        # aten::matmul op not being optimized. Will investigate further.
        # if model.device.type == "hpu":
        #     # Force inference mode for the lifetime of TextGenerationService
        #     self._inference_mode_raii_guard = torch._C._InferenceMode(True)


    async def Info(self, request, context):
        return self.model.info

    async def Health(self, request, context):
        if self.model.device.type == "cuda":
            torch.zeros((2, 2)).cuda()
        if self.model.device.type == "hpu":
            torch.zeros((2, 2)).to("hpu")
        return generate_pb2.HealthResponse()

    async def ServiceDiscovery(self, request, context):
@ -90,41 +97,17 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

    async def Warmup(self, request, context):
        if self.quantize in {"exl2", "gptq"}:
            try:
                # When using GPTQ, Exllama kernels need some global kernels
                # for which we have the final shapes only after the model has loaded.
                # This will allocate those buffers.
                from text_generation_server.layers.gptq import (
                    create_exllama_buffers,
                    set_device,
                )

                set_device(self.model.device)
                create_exllama_buffers(request.max_prefill_tokens)
            except ImportError:
                pass
        max_supported_total_tokens = self.model.warmup(request)
        return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)
        # else:
        #     batch = self.model.batch_type.from_pb(
        #         request.batch, self.model.tokenizer, self.model.dtype, self.model.device
        #     )

        if (
            self.model.batch_type in VLM_BATCH_TYPES
        ):  # Hack, I would rather use kwargs in the `from_pb` call
            batch = self.model.batch_type.from_pb_processor(
                request.batch,
                self.model.tokenizer,
                self.model.processor,
                self.model.model.config,
                self.model.dtype,
                self.model.device,
            )
        else:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.dtype, self.model.device
            )
        max_supported_total_tokens = self.model.warmup(batch)
        # max_supported_total_tokens = self.model.warmup(batch)
        # return generate_pb2.WarmupResponse(max_supported_total_tokens=max_supported_total_tokens)

        return generate_pb2.WarmupResponse(
            max_supported_total_tokens=max_supported_total_tokens
        )

    async def Prefill(self, request, context):
        start = time.time_ns()
@ -144,7 +127,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
        )

        generations, next_batch, timings = self.model.generate_token(batch)
        generations, next_batch, timings = self.model.generate_token([batch])
        self.cache.set(next_batch)

        return generate_pb2.PrefillResponse(
@ -170,21 +153,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        if len(batches) == 0:
            raise ValueError("All batches are empty")

        if len(batches) > 1:
            start_concat = time.time_ns()
            batch = self.model.batch_type.concatenate(batches)
            concat_ns = time.time_ns() - start_concat
        else:
            batch = batches[0]
            concat_ns = None

        generations, next_batch, timings = self.model.generate_token(batch)
        generations, next_batch, timings = self.model.generate_token(batches)
        self.cache.set(next_batch)

        return generate_pb2.DecodeResponse(
            generations=[generation.to_pb() for generation in generations],
            batch=next_batch.to_pb() if next_batch else None,
            concat_ns=concat_ns,
            concat_ns=None,
            forward_ns=timings[0],
            decode_ns=timings[1],
            total_ns=time.time_ns() - start,
@ -213,18 +188,31 @@ def serve(
        dtype: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        if not is_driver_compatible():
            logger.warning(f"Current Synapse version is lower than the minimum version supported: {MIN_TGI_GAUDI_SYNAPSE_VERSION}, this could result in failures")

        unix_socket_template = "unix://{}-{}"
        adapter_to_index = {}
        logger.info("Server:server_inner: sharded ={}".format(sharded))

        if sharded:
            rank = int(os.environ["RANK"])
            logger.info("Server:server_inner: rank ={}".format(rank))
            server_urls = [
                unix_socket_template.format(uds_path, rank)
                for rank in range(int(os.environ["WORLD_SIZE"]))
                unix_socket_template.format(uds_path, rank) for rank in range(int(os.environ["WORLD_SIZE"]))
            ]
            local_url = server_urls[int(os.environ["RANK"])]
        else:
            local_url = unix_socket_template.format(uds_path, 0)
            server_urls = [local_url]

        logger.info("Server:server_inner: data type = {}, local_url = {}".format(dtype, local_url))
        if dtype == "bfloat16" or dtype is None:
            data_type = torch.bfloat16
        else:
            data_type = torch.float
        if revision == "None":
            revision = None
        try:
            model = get_model_with_lora_adapters(
                model_id,
@ -233,7 +221,7 @@ def serve(
                sharded,
                quantize,
                speculate,
                dtype,
                data_type,
                trust_remote_code,
                max_input_tokens,
                adapter_to_index,
@ -271,6 +259,7 @@ def serve(
        while signal_handler.KEEP_PROCESSING:
            await asyncio.sleep(0.5)

    set_model_id(model_id)
    asyncio.run(
        serve_inner(
            model_id,
45
server/text_generation_server/tgi_service.py
Normal file
@ -0,0 +1,45 @@
import os
from pathlib import Path
from loguru import logger
import sys
from text_generation_server import server
import argparse
from typing import List
from text_generation_server.utils.adapter import parse_lora_adapters


def main(args):
    logger.info("TGIService: starting tgi service .... ")
    logger.info(
        "TGIService: --model_id {}, --revision {}, --sharded {}, --speculate {}, --dtype {}, --trust_remote_code {}, --uds_path {} ".format(
            args.model_id, args.revision, args.sharded, args.speculate, args.dtype, args.trust_remote_code, args.uds_path
        )
    )
    lora_adapters = parse_lora_adapters(os.getenv("LORA_ADAPTERS"))
    server.serve(
        model_id=args.model_id,
        lora_adapters=lora_adapters,
        revision=args.revision,
        sharded=args.sharded,
        quantize=args.quantize,
        speculate=args.speculate,
        dtype=args.dtype,
        trust_remote_code=args.trust_remote_code,
        uds_path=args.uds_path,
        max_input_tokens=args.max_input_tokens
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str)
    parser.add_argument("--revision", type=str)
    parser.add_argument("--sharded", type=bool)
    parser.add_argument("--speculate", type=int, default=None)
    parser.add_argument("--dtype", type=str)
    parser.add_argument("--trust_remote_code", type=bool)
    parser.add_argument("--uds_path", type=Path)
    parser.add_argument("--quantize", type=str)
    parser.add_argument("--max_input_tokens", type=int)
    args = parser.parse_args()
    main(args)
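As a hedged illustration (the script path, model id, and socket path are examples, not part of the diff), the entry point above is meant to be invoked with flags matching its argparse definition; a minimal sketch that only assembles such a command line:

import sys

cmd = [
    sys.executable,
    "server/text_generation_server/tgi_service.py",    # assumed location in the repo
    "--model_id", "bigscience/bloom-560m",              # example model id
    "--dtype", "bfloat16",
    "--uds_path", "/tmp/text-generation-server",        # example socket path
    "--max_input_tokens", "1024",
]
print(" ".join(cmd))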
@ -1,3 +1,6 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import text_generation_server.habana_quantization_env
from text_generation_server.utils.convert import convert_file, convert_files
from text_generation_server.utils.dist import initialize_torch_distributed
from text_generation_server.utils.weights import Weights
@ -18,6 +21,9 @@ from text_generation_server.utils.tokens import (
    FinishReason,
    Sampling,
    Greedy,
    make_tokenizer_optional,
    is_tokenizer_transparent,
    pad_next_token_chooser_parameters,
)

__all__ = [
31
server/text_generation_server/utils/debug.py
Normal file
@ -0,0 +1,31 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.

import os
import glob
import time

from optimum.habana.utils import to_gb_rounded
import habana_frameworks.torch as htorch

START_TS = None
DBG_TRACE_FILENAME = os.environ.get('DBG_TRACE_FILENAME')
if 'GRAPH_VISUALIZATION' in os.environ:
    for f in glob.glob('.graph_dumps/*'):
        os.remove(f)


def count_hpu_graphs():
    return len(glob.glob('.graph_dumps/*PreGraph*'))


def dbg_trace(tag, txt):
    global START_TS
    if DBG_TRACE_FILENAME is not None and int(os.getenv("RANK", 0)) == 0:
        if START_TS is None:
            START_TS = time.perf_counter()
        time_offset = time.perf_counter() - START_TS
        mem_stats = htorch.hpu.memory.memory_stats()
        mem_used = to_gb_rounded(mem_stats['InUse'])
        max_mem_used = to_gb_rounded(mem_stats['MaxInUse'])
        print(f'ts:{time_offset:.3f}s g:{count_hpu_graphs()} mu:{mem_used:.1f}GB '
              f'mmu:{max_mem_used:.1f}GB | {tag} | {txt}', flush=True, file=open(DBG_TRACE_FILENAME, 'a'))
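A minimal sketch (not part of the diff) of the trace line dbg_trace writes, with the HPU memory statistics replaced by placeholder numbers so it runs without habana_frameworks:

import time

START_TS = None

def dbg_trace_line(tag, txt, graphs=0, mem_used_gb=0.0, max_mem_used_gb=0.0):
    # mirrors the f-string format used by dbg_trace; memory and graph values are placeholders
    global START_TS
    if START_TS is None:
        START_TS = time.perf_counter()
    time_offset = time.perf_counter() - START_TS
    return (f'ts:{time_offset:.3f}s g:{graphs} mu:{mem_used_gb:.1f}GB '
            f'mmu:{max_mem_used_gb:.1f}GB | {tag} | {txt}')

print(dbg_trace_line("warmup", "prefill batch_size:4"))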
@ -3,7 +3,6 @@ import torch

from datetime import timedelta
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM

# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
@ -45,6 +44,12 @@ class FakeGroup:


def initialize_torch_distributed():
    import habana_frameworks.torch.core as htcore

    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    options = None
    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

@ -56,9 +61,21 @@ def initialize_torch_distributed():
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=120)
        options._timeout = timedelta(seconds=60)
    elif torch.hpu.is_available():
        backend = "hccl"
        n_hpus = torch.hpu.device_count()
        if world_size > n_hpus:
            raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).")
    else:
        backend = "gloo"
        try:
            import oneccl_bindings_for_pytorch

            backend = "ccl"
            if os.getenv("CCL_WORKER_COUNT", None) is None:
                os.environ["CCL_WORKER_COUNT"] = str(1)
        except ImportError:
            backend = "gloo"
        options = None

    if WORLD_SIZE == 1:
@ -69,24 +86,13 @@ def initialize_torch_distributed():

    if not torch.distributed.is_initialized():
        # Call the init process.
        if SYSTEM == "ipex":
            import intel_extension_for_pytorch as ipex

            ipex.distributed.init_process_group(
                backend="ccl",
                world_size=WORLD_SIZE,
                rank=RANK,
                timeout=timedelta(seconds=120),
                pg_options=options,
            )
        else:
            torch.distributed.init_process_group(
                backend=backend,
                world_size=WORLD_SIZE,
                rank=RANK,
                timeout=timedelta(seconds=120),
                pg_options=options,
            )
        torch.distributed.init_process_group(
            backend=backend,
            world_size=WORLD_SIZE,
            rank=RANK,
            timeout=timedelta(seconds=60),
            pg_options=options,
        )
    else:
        logger.warning("torch.distributed is already initialized.")
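A minimal sketch (not part of the diff) of the backend selection order the function above implements: NCCL on CUDA, HCCL on HPU, CCL when the oneCCL bindings import, otherwise GLOO. The helper name is hypothetical.

import torch

def pick_backend() -> str:
    if torch.cuda.is_available():
        return "nccl"
    if hasattr(torch, "hpu") and torch.hpu.is_available():   # torch.hpu only exists with Habana frameworks loaded
        return "hccl"
    try:
        import oneccl_bindings_for_pytorch  # noqa: F401
        return "ccl"
    except ImportError:
        return "gloo"

print(pick_backend())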
@ -1,5 +1,6 @@
import math
import torch
import habana_frameworks.torch.core as htcore

from loguru import logger
from typing import Dict, Union
@ -43,37 +44,31 @@ class StaticWarper:
        if typical_p is not None and typical_p < 1.0:
            self.warpers.append(TypicalLogitsWarper(mass=typical_p))

        self.cuda_graph = None
        self.hpu_graph = None
        self.static_scores = None
        self.static_warped_scores = None
        self.static_next_logprob = None

    def __call__(self, scores):
        if torch.cuda.is_available():
            if self.cuda_graph is None:
                self.static_scores = scores
                self.cuda_graph = torch.cuda.CUDAGraph()
        if self.hpu_graph is None:
            self.static_scores = scores.clone().contiguous()
            self.static_warped_scores = scores.clone().contiguous()
            self.static_next_logprob = scores.clone().contiguous()
            self.hpu_graph = htcore.hpu.HPUGraph()

                with torch.cuda.graph(self.cuda_graph, pool=mempool):
                    local_scores = self.static_scores
                    for warper in self.warpers:
                        local_scores = warper(None, local_scores)
            with htcore.hpu.graph(self.hpu_graph):
                local_scores = self.static_scores
                for warper in self.warpers:
                    local_scores = warper(None, local_scores)

                    self.static_warped_scores = local_scores
                    # Compute logprobs
                    self.static_next_logprob = torch.log_softmax(
                        self.static_warped_scores, -1
                    )
                self.static_warped_scores.copy_(local_scores)
                # Compute logprobs
                self.static_next_logprob.copy_(torch.log_softmax(self.static_warped_scores, -1))

            self.static_scores.copy_(scores)
            self.cuda_graph.replay()
        self.static_scores.copy_(scores)
        self.hpu_graph.replay()

            return self.static_warped_scores, self.static_next_logprob

        # CPU branch
        for warper in self.warpers:
            scores = warper(None, scores)
        return scores, torch.log_softmax(scores, -1)
        return self.static_warped_scores, self.static_next_logprob
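Independent of graph capture, a StaticWarper call returns the warped scores together with their log-probabilities. A CPU-only sketch of that output (a plain temperature warp stands in for the configured warper chain; shapes are assumed):

import torch

scores = torch.randn(1, 32)                 # one row of vocabulary logits, assumed shape
warped = scores / 0.7                       # stand-in for the warper chain
logprobs = torch.log_softmax(warped, -1)    # same final step as the warper
assert torch.allclose(logprobs.exp().sum(-1), torch.tensor(1.0))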
@lru_cache(10)
@ -83,9 +78,7 @@ def static_warper(
    top_p: Optional[float],
    typical_p: Optional[float],
) -> StaticWarper:
    return StaticWarper(
        temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p
    )
    return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p)


class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):
@ -102,17 +95,13 @@ class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor):

    def __init__(self, penalty: List[float], dtype: torch.dtype, device: torch.device):
        self.penalty = penalty
        self.penalty_tensor = torch.tensor(
            penalty, dtype=dtype, device=device
        ).unsqueeze(1)
        self.penalty_tensor = torch.tensor(penalty, dtype=dtype, device=device).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(
            score < 0, score * self.penalty_tensor, score / self.penalty_tensor
        )
        score = torch.where(score < 0, score * self.penalty_tensor, score / self.penalty_tensor)

        scores.scatter_(1, input_ids, score)
        return scores
@ -170,9 +159,11 @@ class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor):
        vocab_size = scores.size(1)

        # Calculate the frequency for each token so far
        token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device)
        token_freq = torch.zeros(
            batch_size, vocab_size, dtype=scores.dtype, device=scores.device
        )
        token_freq.scatter_add_(
            1, input_ids, torch.ones_like(input_ids, dtype=torch.float)
            1, input_ids, torch.ones_like(input_ids, dtype=scores.dtype, device=scores.device)
        )
        token_freq /= input_size

@ -199,13 +190,9 @@ class HeterogeneousTemperatureLogitsWarper:
        The value used to modulate the logits distribution.
    """

    def __init__(
        self, temperature: List[float], dtype: torch.dtype, device: torch.device
    ):
    def __init__(self, temperature: List[float], dtype: torch.dtype, device: torch.device):
        self.temperature = temperature
        self.temperature_tensor = torch.tensor(
            temperature, dtype=dtype, device=device
        ).unsqueeze(1)
        self.temperature_tensor = torch.tensor(temperature, dtype=dtype, device=device).unsqueeze(1)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        scores.div_(self.temperature_tensor)
@ -244,9 +231,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
        min_tokens_to_keep: int = 1,
    ):
        self.top_p = top_p
        self.top_p_opposite = 1 - torch.tensor(
            top_p, dtype=dtype, device=device
        ).unsqueeze(1)
        self.top_p_opposite = 1 - torch.tensor(top_p, dtype=dtype, device=device).unsqueeze(1)
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

@ -263,9 +248,7 @@ class HeterogeneousTopPLogitsWarper(LogitsWarper):
        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

        return warped_scores
@ -313,9 +296,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
        disabled = [x == 0 for x in top_k]

        if any(disabled):
            self.top_k_disabled_mask = torch.tensor(
                disabled, dtype=torch.bool, device=device
            ).view(-1, 1)
            self.top_k_disabled_mask = torch.tensor(disabled, dtype=torch.bool, device=device).view(-1, 1)
        else:
            self.top_k_disabled_mask = None

@ -351,9 +332,7 @@ class HeterogeneousTopKLogitsWarper(LogitsWarper):
            self.max_top_k = max(self.top_k)

            if self.top_k_disabled_mask is not None:
                self.top_k_disabled_mask = (
                    self.top_k_disabled_mask[indices] if any(disabled) else None
                )
                self.top_k_disabled_mask = self.top_k_disabled_mask[indices] if any(disabled) else None

            return self
        return None
@ -419,15 +398,11 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
        if self.disabled_mask is not None:
            last_ind.masked_fill_(self.disabled_mask, scores.shape[-1] - 1)

        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
            1, last_ind.view(-1, 1)
        )
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        warped_scores = scores.masked_fill_(indices_to_remove, self.filter_value)

@ -441,9 +416,7 @@ class HeterogeneousTypicalLogitsWarper(LogitsWarper):
            self.mass_tensor = self.mass_tensor[indices]

            if self.disabled_mask is not None:
                self.disabled_mask = (
                    self.disabled_mask[indices] if any(disabled) else None
                )
                self.disabled_mask = self.disabled_mask[indices] if any(disabled) else None

            return self
        return None
@ -521,13 +494,7 @@ class GrammarLogitProcessor(LogitsProcessor):
    def _cached_compile_fsm(grammar_type, schema, tokenizer):
        start_time = time.time()
        if grammar_type == GrammarType.GRAMMAR_TYPE_JSON:
            try:
                schema = build_regex_from_schema(schema)
                # TODO: this is only here short term to avoid crashing the python server, mid term we want this in the rust/router layer
            except Exception as e:
                logger.error(f"Error compiling FSM, grammar won't be enforced \n{e}")
                # allows everything
                schema = "(.*?)"
            schema = build_regex_from_schema(schema)
        elif grammar_type == GrammarType.GRAMMAR_TYPE_REGEX:
            pass  # schema is already a regex just here for clarity
        fsm = RegexFSM(schema, tokenizer)
@ -586,7 +553,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
        mask = torch.full_like(logits, -math.inf)
        for i in range(logits.shape[0]):
            fsm = self.fsms[i]
            if fsm_grammar_states[i] == -1 or fsm is None:
            if fsm is None:
                continue
            allowed_tokens = fsm.allowed_token_ids(fsm_grammar_states[i])
            mask[i, allowed_tokens] = 0
@ -247,10 +247,12 @@ class HeterogeneousNextTokenChooser:
        tokenizer: PreTrainedTokenizerBase,
        grammars: List[str],
        grammar_types: List[int],
        fsm_grammar_states=List[int],
        fsm_grammar_states: List[int],
        quantization_enabled: bool,
    ):
        warpers = []

        # TODO: enable watermark with FP8 quantization
        self.watermark_processor = (
            HeterogeneousProcessorWrapper(
                {
@ -259,7 +261,7 @@ class HeterogeneousNextTokenChooser:
                    if do_watermark
                }
            )
            if any(watermark)
            if any(watermark) and not quantization_enabled
            else None
        )

@ -431,6 +433,18 @@ class HeterogeneousNextTokenChooser:
        )
        return self

    def advance_grammar_single_with_past_state(
        self, grammar_state_index: int, next_id: torch.Tensor, past_state: int
    ):
        if self.grammar_processor is not None:
            next_id = next_id.item()
            self.fsm_grammar_states[grammar_state_index] = (
                self.grammar_processor.advance_at_index(
                    next_id, past_state, grammar_state_index,
                )
            )
        return self

    def filter(self, indices):
        if self.watermark_processor is not None:
            self.watermark_processor = self.watermark_processor.filter(indices)
@ -481,6 +495,7 @@ class HeterogeneousNextTokenChooser:
        device: torch.device,
        tokenizer: PreTrainedTokenizerBase,
        fsm_grammar_states: Optional[List[int]] = None,
        quantization_enabled: bool = False,
    ) -> "HeterogeneousNextTokenChooser":
        return HeterogeneousNextTokenChooser(
            watermark=[pb_.watermark for pb_ in pb],
@ -500,12 +515,37 @@ class HeterogeneousNextTokenChooser:
            fsm_grammar_states=(
                fsm_grammar_states if fsm_grammar_states else [0] * len(pb)
            ),
            quantization_enabled=quantization_enabled,
        )

def pad_next_token_chooser_parameters(
    parameters: List[generate_pb2.NextTokenChooserParameters],
    expected_size: int,
) -> List[generate_pb2.NextTokenChooserParameters]:
    # disable all logits processors to minimize padding overhead
    empty_parameters = generate_pb2.NextTokenChooserParameters(
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        typical_p=1.0,
        do_sample=False,
        seed=0,
        repetition_penalty=1.0,
        frequency_penalty=0.0,
        watermark=False,
        grammar="",
        grammar_type=0,
    )
    parameters.extend(
        [empty_parameters] * (expected_size - len(parameters))
    )
    return parameters


class Sampling:
    def __init__(self, seed: int, device: str = "cpu"):
        self.generator = torch.Generator(device)
        self.generator = torch.Generator("cpu")
        self.generator.manual_seed(seed)
        self.seed = seed

@ -541,7 +581,7 @@ class HeterogeneousSampling:
        self.greedy = Greedy()

    def __call__(self, logits):
        out = torch.empty(logits.shape[0], dtype=torch.int64, device=logits.device)
        out = torch.zeros(logits.shape[0], dtype=torch.int64, device=logits.device)
        if self.greedy_indices:
            # Computing for all indices is faster than slicing
            torch.argmax(logits, -1, out=out)
@ -643,3 +683,50 @@ def batch_top_tokens(
        batch_top_token_logprobs.append(row_top_token_logprobs)

    return batch_top_token_ids, batch_top_token_logprobs

def make_tokenizer_optional(tokenizer):
    class _(type(tokenizer)):
        def __call__(
            self,
            text,
            return_tensors,
            padding,
            return_token_type_ids,
            truncation,
            max_length
        ):
            assert return_tensors == "pt", "incorrect input arguments when calling TransparentTokenizer"
            assert padding == "max_length" or padding == "longest", "incorrect input arguments when calling TransparentTokenizer"
            assert return_token_type_ids == False, "incorrect input arguments when calling TransparentTokenizer"
            assert truncation == True, "incorrect input arguments when calling TransparentTokenizer"

            def str_token_to_int(i):
                if i == '?':
                    return tokenizer.pad_token_id
                else:
                    return int(i)
            all_tokens = [[str_token_to_int(i.strip()) for i in inner_text.split(',')]
                          for inner_text in text]
            if padding == "longest":
                max_length = max(len(tokens) for tokens in all_tokens)
            return {"input_ids": torch.tensor([[tokenizer.pad_token_id] * (max_length - len(tokens)) + tokens for tokens in all_tokens]),
                    "attention_mask": torch.tensor([[0] * (max_length - len(tokens)) + [1] * len(tokens) for tokens in all_tokens])}

        def decode(
            self,
            token_ids,
            skip_special_tokens: bool = False,
            clean_up_tokenization_spaces: bool = None,
            **kwargs,
        ) -> str:
            return ','.join(str(i) for i in to_py_obj(token_ids))

    import os
    if os.getenv("SKIP_TOKENIZER_IN_TGI", "false").lower() == "true":
        tokenizer.__class__ = _
        tokenizer.is_transparent = True


def is_tokenizer_transparent(tokenizer):
    return hasattr(tokenizer, "is_transparent") and tokenizer.is_transparent is True
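With SKIP_TOKENIZER_IN_TGI=true the "text" the server receives is already a comma-separated list of token ids, with '?' standing in for padding. A minimal sketch of that input format (pad_token_id=0 and the id strings are assumptions, not taken from the diff):

import torch

pad_token_id = 0   # assumed; the real value comes from the wrapped tokenizer

def parse_pretokenized(text):
    return [pad_token_id if tok.strip() == '?' else int(tok) for tok in text.split(',')]

batch = ["15496, 11, 995", "?, ?, 50256"]   # example id strings, not real prompts
print(torch.tensor([parse_pretokenized(t) for t in batch]))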
12
server/text_generation_server/utils/version.py
Normal file
@ -0,0 +1,12 @@
from optimum.habana.utils import get_driver_version
from packaging.version import Version

MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0")


def is_driver_compatible():
    driver_version = get_driver_version()
    if driver_version is not None:
        if driver_version < MIN_TGI_GAUDI_SYNAPSE_VERSION:
            return False
    return True
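A small sketch (not part of the diff) of the same comparison is_driver_compatible performs, using packaging.version directly with example version strings:

from packaging.version import Version

MIN_TGI_GAUDI_SYNAPSE_VERSION = Version("1.16.0")

for candidate in ("1.15.2", "1.16.0", "1.17.1"):   # example driver versions
    compatible = not (Version(candidate) < MIN_TGI_GAUDI_SYNAPSE_VERSION)
    print(candidate, "compatible" if compatible else "too old")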
@ -34,7 +34,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
        # watermarking parameters
        self.gamma = gamma
        self.delta = delta
        self.rng = torch.Generator(device=device)
        self.rng = torch.Generator(device="cpu")
        self.hash_key = hash_key

    def _seed_rng(self, input_ids: Union[List[int], torch.LongTensor]):