diff --git a/Dockerfile b/Dockerfile
index 85463af1..9fe0b49b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -56,14 +56,16 @@ WORKDIR /usr/src
 # Install torch
 RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
 
-COPY server/Makefile server/Makefile
-
 # Install specific version of flash attention
+COPY server/Makefile-flash-att server/Makefile
 RUN cd server && make install-flash-attention
 
 # Install specific version of transformers
+COPY server/Makefile-transformers server/Makefile
 RUN cd server && BUILD_EXTENSIONS="True" make install-transformers
 
+COPY server/Makefile server/Makefile
+
 # Install server
 COPY proto proto
 COPY server server
diff --git a/server/Makefile b/server/Makefile
index 1c1d3578..d827ceca 100644
--- a/server/Makefile
+++ b/server/Makefile
@@ -1,5 +1,5 @@
-transformers_commit := b8d969ff47c6a9d40538a6ea33df021953363afc
-flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1
+include Makefile-transformers
+include Makefile-flash-att
 
 gen-server:
 	# Compile protos
@@ -10,24 +10,6 @@ gen-server:
 	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
 	touch text_generation_server/pb/__init__.py
 
-install-transformers:
-	# Install specific version of transformers with custom cuda kernels
-	pip install --upgrade setuptools
-	pip uninstall transformers -y || true
-	rm -rf transformers || true
-	git clone https://github.com/OlivierDehaene/transformers.git
-	cd transformers && git checkout $(transformers_commit)
-	cd transformers && python setup.py install
-
-install-flash-attention:
-	# Install specific version of flash attention
-	pip install packaging
-	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
-	rm -rf flash-attention || true
-	git clone https://github.com/HazyResearch/flash-attention.git
-	cd flash-attention && git checkout $(flash_att_commit)
-	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
-
 install-torch:
 	# Install specific version of torch
 	pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
diff --git a/server/Makefile-flash-att b/server/Makefile-flash-att
new file mode 100644
index 00000000..fa2fe3fd
--- /dev/null
+++ b/server/Makefile-flash-att
@@ -0,0 +1,10 @@
+flash_att_commit := 221670026643da10fa18391eb995ef6d9b407530
+
+install-flash-attention:
+	# Install specific version of flash attention
+	pip install packaging
+	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
+	rm -rf flash-attention || true
+	git clone https://github.com/HazyResearch/flash-attention.git
+	cd flash-attention && git checkout $(flash_att_commit)
+	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
\ No newline at end of file
diff --git a/server/Makefile-transformers b/server/Makefile-transformers
new file mode 100644
index 00000000..1e081336
--- /dev/null
+++ b/server/Makefile-transformers
@@ -0,0 +1,10 @@
+transformers_commit := b8d969ff47c6a9d40538a6ea33df021953363afc
+
+install-transformers:
+	# Install specific version of transformers with custom cuda kernels
+	pip install --upgrade setuptools
+	pip uninstall transformers -y || true
+	rm -rf transformers || true
+	git clone https://github.com/OlivierDehaene/transformers.git
+	cd transformers && git checkout $(transformers_commit)
+	cd transformers && python setup.py install
\ No newline at end of file
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 85039cdf..f1db6395 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -39,7 +39,7 @@ class FlashLlama(FlashCausalLM):
             raise NotImplementedError("FlashLlama does not support quantization")
 
         tokenizer = LlamaTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left",
         )
 
         config = AutoConfig.from_pretrained(
@@ -155,7 +155,7 @@ class FlashLlamaSharded(FlashLlama):
             raise NotImplementedError("FlashLlama does not support quantization")
 
         tokenizer = LlamaTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left",
         )
 
         config = AutoConfig.from_pretrained(