better docker layer caching

OlivierDehaene 2023-04-06 20:08:46 +02:00
parent c3779fa859
commit 1111125092
5 changed files with 28 additions and 24 deletions

Dockerfile

@@ -56,14 +56,16 @@ WORKDIR /usr/src
 # Install torch
 RUN pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir
 
-COPY server/Makefile server/Makefile
-
 # Install specific version of flash attention
+COPY server/Makefile-flash-att server/Makefile
 RUN cd server && make install-flash-attention
 
 # Install specific version of transformers
+COPY server/Makefile-transformers server/Makefile
 RUN cd server && BUILD_EXTENSIONS="True" make install-transformers
 
+COPY server/Makefile server/Makefile
+
 # Install server
 COPY proto proto
 COPY server server
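
Note on why this ordering helps (not part of the commit): Docker invalidates a layer, and every layer built after it, as soon as any file referenced by a COPY changes. Copying only Makefile-flash-att before the flash-attention build and only Makefile-transformers before the transformers build means a pin bump in one fragment no longer rebuilds the other, and the full server/Makefile plus the sources are copied only after the slow dependency layers. A minimal sketch of the pattern, with hypothetical file names:

# Copy only the fragment each step needs; editing deps-b.mk leaves the deps-a layer cached
COPY deps-a.mk deps-a.mk
RUN make -f deps-a.mk install-a

COPY deps-b.mk deps-b.mk
RUN make -f deps-b.mk install-b

# Copy the frequently edited sources last
COPY . /usr/src/app
RUN make -C /usr/src/app install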

server/Makefile

@@ -1,5 +1,5 @@
-transformers_commit := b8d969ff47c6a9d40538a6ea33df021953363afc
-flash_att_commit := 4d87e4d875077ad9efd25030efa4ab0ba92c19e1
+include Makefile-transformers
+include Makefile-flash-att
 
 gen-server:
 	# Compile protos
@@ -10,24 +10,6 @@ gen-server:
 	find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
 	touch text_generation_server/pb/__init__.py
 
-install-transformers:
-	# Install specific version of transformers with custom cuda kernels
-	pip install --upgrade setuptools
-	pip uninstall transformers -y || true
-	rm -rf transformers || true
-	git clone https://github.com/OlivierDehaene/transformers.git
-	cd transformers && git checkout $(transformers_commit)
-	cd transformers && python setup.py install
-
-install-flash-attention:
-	# Install specific version of flash attention
-	pip install packaging
-	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
-	rm -rf flash-attention || true
-	git clone https://github.com/HazyResearch/flash-attention.git
-	cd flash-attention && git checkout $(flash_att_commit)
-	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install
-
 install-torch:
 	# Install specific version of torch
 	pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 --no-cache-dir

server/Makefile-flash-att Normal file (+10)
@@ -0,0 +1,10 @@
+flash_att_commit := 221670026643da10fa18391eb995ef6d9b407530
+
+install-flash-attention:
+	# Install specific version of flash attention
+	pip install packaging
+	pip uninstall flash_attn rotary_emb dropout_layer_norm -y || true
+	rm -rf flash-attention || true
+	git clone https://github.com/HazyResearch/flash-attention.git
+	cd flash-attention && git checkout $(flash_att_commit)
+	cd flash-attention && python setup.py install && cd csrc/layer_norm && python setup.py install && cd ../rotary && python setup.py install

server/Makefile-transformers Normal file (+10)
@@ -0,0 +1,10 @@
+transformers_commit := b8d969ff47c6a9d40538a6ea33df021953363afc
+
+install-transformers:
+	# Install specific version of transformers with custom cuda kernels
+	pip install --upgrade setuptools
+	pip uninstall transformers -y || true
+	rm -rf transformers || true
+	git clone https://github.com/OlivierDehaene/transformers.git
+	cd transformers && git checkout $(transformers_commit)
+	cd transformers && python setup.py install
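
For context (not part of the commit): GNU Make's include directive splices a fragment's variables and targets into the including Makefile, so `cd server && make install-flash-attention` behaves exactly as it did when the target lived inline; the split only exists so the Dockerfile can COPY each fragment independently. A minimal sketch with hypothetical names:

# deps.mk - a pinned-dependency fragment
tool_commit := 0123abc
install-tool:
	git clone https://example.com/tool.git
	cd tool && git checkout $(tool_commit)
	cd tool && python setup.py install

# Makefile - the fragment's target and variable are visible after the include
include deps.mk

install-all: install-tool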

server/text_generation_server/models/flash_llama.py
@@ -39,7 +39,7 @@ class FlashLlama(FlashCausalLM):
             raise NotImplementedError("FlashLlama does not support quantization")
 
         tokenizer = LlamaTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left",
         )
 
         config = AutoConfig.from_pretrained(
@@ -155,7 +155,7 @@ class FlashLlamaSharded(FlashLlama):
             raise NotImplementedError("FlashLlama does not support quantization")
 
         tokenizer = LlamaTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left",
         )
 
         config = AutoConfig.from_pretrained(