From 50fd663b9bf1f0927032679e9e9e26facdbe363e Mon Sep 17 00:00:00 2001 From: Nikola Borisov Date: Wed, 9 Aug 2023 14:43:08 -0700 Subject: [PATCH] Fix docker build, pinning the pytorch version. Before this change we would sometimes get a version of pytorch installed without cuda support. Some people reported different versions of pytorch like 11.7 instead of 11.8 --- Dockerfile | 2 +- server/pyproject.toml | 6 ++++++ server/requirements.txt | 4 ++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 34109d02..587ab9b8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,7 +39,7 @@ RUN cargo build --release # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile FROM debian:bullseye-slim as pytorch-install -ARG PYTORCH_VERSION=2.0.0 +ARG PYTORCH_VERSION=2.0.1 ARG PYTHON_VERSION=3.9 ARG CUDA_VERSION=11.8 ARG MAMBA_VERSION=23.1.0-1 diff --git a/server/pyproject.toml b/server/pyproject.toml index 3ee3351c..620420cc 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -30,6 +30,7 @@ transformers = "4.29.2" einops = "^0.6.1" texttable = { version = "^1.6.7", optional = true } datasets = { version = "^2.14.0", optional = true } +torch = { version = "^2.0.1+cu118", source = "pytorch-gpu-src"} [tool.poetry.extras] accelerate = ["accelerate"] @@ -40,6 +41,11 @@ quantize = ["texttable", "datasets", "accelerate"] grpcio-tools = "^1.51.1" pytest = "^7.3.0" +[[tool.poetry.source]] +name = "pytorch-gpu-src" +url = "https://download.pytorch.org/whl/cu118" +priority = "explicit" + [tool.pytest.ini_options] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] diff --git a/server/requirements.txt b/server/requirements.txt index 98838b36..92729fe1 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -8,7 +8,7 @@ bitsandbytes==0.38.1 ; python_version >= "3.9" and python_version < "4.0" certifi==2023.5.7 ; python_version >= "3.9" and python_version < "4.0" charset-normalizer==3.1.0 ; python_version >= "3.9" and python_version < "4.0" click==8.1.3 ; python_version >= "3.9" and python_version < "4.0" -colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.9" and python_version < "4.0" and platform_system == "Windows" +colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0" and (sys_platform == "win32" or platform_system == "Windows") datasets==2.14.0 ; python_version >= "3.9" and python_version < "4.0" deprecated==1.2.14 ; python_version >= "3.9" and python_version < "4.0" dill==0.3.7 ; python_version >= "3.9" and python_version < "4.0" @@ -32,7 +32,7 @@ mpmath==1.3.0 ; python_version >= "3.9" and python_version < "4.0" multidict==6.0.4 ; python_version >= "3.9" and python_version < "4.0" multiprocess==0.70.15 ; python_version >= "3.9" and python_version < "4.0" networkx==3.1 ; python_version >= "3.9" and python_version < "4.0" -numpy==1.25.0 ; python_version < "4.0" and python_version >= "3.9" +numpy==1.25.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "4.0" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "4.0"