From 37aabf8571f19d77c693dc050006b7f4b9fbafed Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Fri, 26 Apr 2024 11:07:27 +0200 Subject: [PATCH 1/2] Move call to `adapt_transformers_to_gaudi` earlier in the code (#133) --- server/text_generation_server/models/__init__.py | 3 +++ server/text_generation_server/models/causal_lm.py | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index efe9b62a..ce252ba1 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -10,6 +10,8 @@ from text_generation_server.models.causal_lm import CausalLM from text_generation_server.models.bloom import BLOOM from text_generation_server.models.santacoder import SantaCoder +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + # Disable gradients torch.set_grad_enabled(False) @@ -20,6 +22,7 @@ def get_model( revision: Optional[str], dtype: Optional[torch.dtype] = None, ) -> Model: + adapt_transformers_to_gaudi() config = AutoConfig.from_pretrained(model_id, revision=revision) model_type = config.model_type diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index bdc0b4c5..97a9fd6f 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -16,7 +16,6 @@ from opentelemetry import trace import text_generation_server.habana_quantization_env as hq_env import habana_frameworks.torch as htorch from habana_frameworks.torch.hpu import wrap_in_hpu_graph -from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi from optimum.habana.utils import HabanaProfile from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES from optimum.habana.checkpoint_utils import ( @@ -572,8 +571,6 @@ class CausalLM(Model): revision: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): - adapt_transformers_to_gaudi() - # Create tokenizer tokenizer = AutoTokenizer.from_pretrained( model_id, From 91eb4e555f5685f60949b9da2579c96884b81705 Mon Sep 17 00:00:00 2001 From: Yaser Afshar Date: Fri, 26 Apr 2024 02:08:15 -0700 Subject: [PATCH 2/2] Hgraph dill patch (#131) --- Dockerfile | 1 + server/dill-0.3.7-patch.sh | 91 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 server/dill-0.3.7-patch.sh diff --git a/Dockerfile b/Dockerfile index c49f43e6..481bfb2a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -58,6 +58,7 @@ COPY server/Makefile server/Makefile RUN cd server && \ make gen-server && \ pip install -r requirements.txt && \ + bash ./dill-0.3.7-patch.sh && \ pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.14.0 && \ pip install . --no-cache-dir diff --git a/server/dill-0.3.7-patch.sh b/server/dill-0.3.7-patch.sh new file mode 100644 index 00000000..ad8c8be5 --- /dev/null +++ b/server/dill-0.3.7-patch.sh @@ -0,0 +1,91 @@ +#!/bin/bash +git clone -b dill-0.3.7 https://github.com/uqfoundation/dill.git +pushd dill +cat < dill-0.3.7.patch +diff --git a/dill/_dill.py b/dill/_dill.py +index d0cf543..f6eb662 100644 +--- a/dill/_dill.py ++++ b/dill/_dill.py +@@ -69,7 +69,15 @@ TypeType = type # 'new-style' classes #XXX: unregistered + XRangeType = range + from types import MappingProxyType as DictProxyType, new_class + from pickle import DEFAULT_PROTOCOL, HIGHEST_PROTOCOL, PickleError, PicklingError, UnpicklingError +-import __main__ as _main_module ++class _LazyMainModule(object): ++ _module = None ++ @property ++ def module(self): ++ if self._module is None: ++ import __main__ as _m_module ++ self._module = _m_module ++ return self._module ++_main_module = _LazyMainModule() + import marshal + import gc + # import zlib +@@ -353,7 +361,7 @@ class Pickler(StockPickler): + _fmode = kwds.pop('fmode', None) + _recurse = kwds.pop('recurse', None) + StockPickler.__init__(self, file, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._diff_cache = {} + self._byref = settings['byref'] if _byref is None else _byref + self._strictio = False #_strictio +@@ -435,12 +443,12 @@ class Unpickler(StockUnpickler): + settings = Pickler.settings + _ignore = kwds.pop('ignore', None) + StockUnpickler.__init__(self, *args, **kwds) +- self._main = _main_module ++ self._main = _main_module.module + self._ignore = settings['ignore'] if _ignore is None else _ignore + + def load(self): #NOTE: if settings change, need to update attributes + obj = StockUnpickler.load(self) +- if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): ++ if type(obj).__module__ == getattr(self._main, '__name__', '__main__'): + if not self._ignore: + # point obj class to main + try: obj.__class__ = getattr(self._main, type(obj).__name__) +@@ -1194,11 +1202,11 @@ def save_module_dict(pickler, obj): + logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8')) + logger.trace(pickler, "# D1") +- elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__): ++ elif (not is_dill(pickler, child=False)) and (obj == _main_module.module.__dict__): + logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj + pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general? + logger.trace(pickler, "# D3") +- elif '__name__' in obj and obj != _main_module.__dict__ \\ ++ elif '__name__' in obj and obj != _main_module.module.__dict__ \\ + and type(obj['__name__']) is str \\ + and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None): + logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj +diff --git a/dill/session.py b/dill/session.py +index 74234ab..1be8d89 100644 +--- a/dill/session.py ++++ b/dill/session.py +@@ -233,7 +233,7 @@ def dump_module( + protocol = settings['protocol'] + main = module + if main is None: +- main = _main_module ++ main = _main_module.module + elif isinstance(main, str): + main = _import_module(main) + if not isinstance(main, ModuleType): +@@ -501,7 +501,7 @@ def load_module( + pass + assert loaded is main + _restore_modules(unpickler, main) +- if main is _main_module or main is module: ++ if main is _main_module.module or main is module: + return None + else: + return main + +EOF +git apply dill-0.3.7.patch +python -m pip install . +popd +rm -fr dill