commit 600d033c04
Merge branch 'habana-main' into rebase_tgi_2.0
Dockerfile
@@ -58,6 +58,7 @@ COPY server/Makefile server/Makefile
 RUN cd server && \
     make gen-server && \
     pip install -r requirements.txt && \
+    bash ./dill-0.3.7-patch.sh && \
     pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.0 && \
     pip install . --no-cache-dir
server/dill-0.3.7-patch.sh (new file, 91 lines)
@@ -0,0 +1,91 @@
#!/bin/bash
git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git
pushd dill
cat <<EOF > dill-0.3.7.patch
diff --git a/dill/_dill.py b/dill/_dill.py
index d0cf543..f6eb662 100644
--- a/dill/_dill.py
+++ b/dill/_dill.py
@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered
 XRangeType = range
 from types import MappingProxyType as DictProxyType, new_class
 from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError
-import __main__ as _main_module
+class _LazyMainModule(object):
+    _module = None
+    @property
+    def module(self):
+        if self._module is None:
+            import __main__ as _m_module
+            self._module = _m_module
+        return self._module
+_main_module = _LazyMainModule()
 import marshal
 import gc
 # import zlib
@@ -353,7 +361,7 @@ class Pickler(StockPickler):
         _fmode = kwds.pop('fmode', None)
         _recurse = kwds.pop('recurse', None)
         StockPickler.__init__(self, file, *args, **kwds)
-        self._main = _main_module
+        self._main = _main_module.module
         self._diff_cache = {}
         self._byref = settings['byref'] if _byref is None else _byref
         self._strictio = False #_strictio
@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler):
         settings = Pickler.settings
         _ignore = kwds.pop('ignore', None)
         StockUnpickler.__init__(self, *args, **kwds)
-        self._main = _main_module
+        self._main = _main_module.module
         self._ignore = settings['ignore'] if _ignore is None else _ignore

     def load(self): #NOTE: if settings change, need to update attributes
         obj = StockUnpickler.load(self)
-        if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'):
+        if type(obj).__module__ == getattr(self._main, '__name__', '__main__'):
             if not self._ignore:
                 # point obj class to main
                 try: obj.__class__ = getattr(self._main, type(obj).__name__)
@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj):
             logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
             pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
             logger.trace(pickler, "# D1")
-        elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
+        elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__):
             logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
             pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
             logger.trace(pickler, "# D3")
-        elif '__name__' in obj and obj != _main_module.__dict__ \\
+        elif '__name__' in obj and obj != _main_module.module.__dict__ \\
                 and type(obj['__name__']) is str \\
                 and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
             logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
diff --git a/dill/session.py b/dill/session.py
index 74234ab..1be8d89 100644
--- a/dill/session.py
+++ b/dill/session.py
@@ -233,7 +233,7 @@ def dump_module(
     protocol = settings['protocol']
     main = module
     if main is None:
-        main = _main_module
+        main = _main_module.module
     elif isinstance(main, str):
         main = _import_module(main)
     if not isinstance(main, ModuleType):
@@ -501,7 +501,7 @@ def load_module(
             pass
         assert loaded is main
     _restore_modules(unpickler, main)
-    if main is _main_module or main is module:
+    if main is _main_module.module or main is module:
         return None
     else:
         return main

EOF
git apply dill-0.3.7.patch
python -m pip install .
popd
rm -fr dill
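In brief, the patch replaces dill's module-level import __main__ with a lazy proxy, so __main__ is resolved on first use rather than when dill itself is imported (which matters when dill is imported very early, before the final __main__ module is in place). A minimal standalone sketch of that pattern follows; only _LazyMainModule mirrors the patch, the demo line at the bottom is illustrative:

class _LazyMainModule(object):
    _module = None

    @property
    def module(self):
        if self._module is None:
            import __main__ as _m_module  # deferred until first access
            self._module = _m_module
        return self._module

_main_module = _LazyMainModule()

# Call sites read _main_module.module instead of the bare module object,
# which is exactly the rewrite the patch applies across _dill.py and session.py.
print(_main_module.module.__name__)  # -> '__main__' when run as a script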
server/text_generation_server/models/__init__.py
@@ -16,6 +16,8 @@ from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.santacoder import SantaCoder
 
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
 
 # Disable gradients
 torch.set_grad_enabled(False)
@@ -28,6 +30,8 @@ def get_model(
     dtype: Optional[torch.dtype],
     trust_remote_code: bool,
 ) -> Model:
+    adapt_transformers_to_gaudi()
+
     if speculate is not None:
         set_speculate(speculate)
     else:
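Net effect of the two hunks above: adapt_transformers_to_gaudi() now runs once at dispatch time, inside get_model, before any model class is constructed. A toy sketch of that design, with illustrative stand-in names rather than the real TGI code:

def adapt_framework():
    # Stand-in for adapt_transformers_to_gaudi(): a process-wide, idempotent patch.
    print("transformers patched for Gaudi")

def get_model(model_id: str) -> str:
    adapt_framework()  # one call covers every model class dispatched below
    if model_id == "bigscience/bloom":
        return "BLOOM"
    return "CausalLM"

print(get_model("bigscience/bloom"))

The matching removals in causal_lm.py below drop the constructor-level import and call, since the entry point now guarantees the adaptation has already happened.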
server/text_generation_server/models/causal_lm.py
@@ -17,7 +17,6 @@ from opentelemetry import trace
 import text_generation_server.habana_quantization_env as hq_env
 import habana_frameworks.torch as htorch
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
-from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from optimum.habana.checkpoint_utils import (
@@ -584,8 +583,6 @@ class CausalLM(Model):
         if use_medusa:
             raise RuntimeError("Medusa decoding is not enabled for AutoModel")
 
-        adapt_transformers_to_gaudi()
-
         # Create tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
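With the call hoisted into get_model, the constructor relies on an invariant instead of re-applying the patch itself. A small illustrative sketch of that contract; the guard flag is hypothetical, the real code simply omits the call:

_ADAPTED = False  # hypothetical flag, for this sketch only

def adapt_framework():
    global _ADAPTED
    _ADAPTED = True

class CausalLM:
    def __init__(self) -> None:
        # Post-merge, __init__ assumes the entry point already adapted the
        # framework; re-running the (idempotent) patch here was redundant.
        assert _ADAPTED, "call adapt_framework() (via get_model) first"

adapt_framework()
model = CausalLM()  # ok: adaptation happened at dispatch time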