diff --git a/Dockerfile b/Dockerfile
index 78870f49..3e02b6ef 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -28,9 +28,9 @@ COPY router router
 COPY launcher launcher
 RUN cargo build --release
 
-# CUDA kernel builder
+# Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as kernel-builder
+FROM debian:bullseye-slim as pytorch-install
 
 ARG PYTORCH_VERSION=2.0.0
 ARG PYTHON_VERSION=3.9
@@ -41,19 +41,15 @@ ARG INSTALL_CHANNEL=pytorch
 # Automatically set by buildx
 ARG TARGETPLATFORM
 
+ENV PATH /opt/conda/bin:$PATH
+
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
         build-essential \
         ca-certificates \
         ccache \
-        ninja-build \
-        cmake \
         curl \
         git && \
         rm -rf /var/lib/apt/lists/*
-RUN /usr/sbin/update-ccache-symlinks && \
-    mkdir /opt/ccache && \
-    ccache --set-config=cache_dir=/opt/ccache
-ENV PATH /opt/conda/bin:$PATH
 
 # Install conda
 # translating Docker's TARGETPLATFORM into mamba arches
@@ -75,6 +71,16 @@ RUN case ${TARGETPLATFORM} in \
     esac && \
     /opt/conda/bin/conda clean -ya
 
+# CUDA kernels builder image
+FROM pytorch-install as kernel-builder
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+        ninja-build \
+        && rm -rf /var/lib/apt/lists/*
+
+RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
+    /opt/conda/bin/conda clean -ya
+
 # Build Flash Attention CUDA kernels
 FROM kernel-builder as flash-att-builder
 
@@ -97,10 +103,11 @@ COPY server/Makefile-transformers Makefile
 RUN BUILD_EXTENSIONS="True" make build-transformers
 
 # Text Generation Inference base image
-FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as base
+FROM debian:bullseye-slim as base
 
 # Conda env
-ENV PATH=/opt/conda/bin:$PATH
+ENV PATH=/opt/conda/bin:$PATH \
+    CONDA_PREFIX=/opt/conda
 
 # Text Generation Inference base env
 ENV HUGGINGFACE_HUB_CACHE=/data \
@@ -121,7 +128,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
     && rm -rf /var/lib/apt/lists/*
 
 # Copy conda with PyTorch installed
-COPY --from=kernel-builder /opt/conda /opt/conda
+COPY --from=pytorch-install /opt/conda /opt/conda
 
 # Copy build artifacts from flash attention builder
 COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages
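A note on the Dockerfile restructure above: the build now starts from debian:bullseye-slim instead of the nvidia/cuda images. A pytorch-install stage provides conda and PyTorch; the derived kernel-builder stage adds ninja and the conda-packaged CUDA 11.8 toolkit for compiling the flash-attention and transformers kernels. Because the final base stage copies /opt/conda from pytorch-install, the devel toolkit never ships in the runtime image; the CUDA userspace libraries come bundled with the conda PyTorch build, and only the driver is expected from the host. A quick smoke test for the resulting image, as a sketch (the image tag and script path are placeholders):

```python
# smoke_test.py - run inside the image, e.g.
#   docker run --rm --gpus all <image> python smoke_test.py
import torch

print(torch.__version__)          # expect 2.0.0 (the PYTORCH_VERSION build arg)
print(torch.version.cuda)         # expect 11.8, bundled via conda, not the base image
print(torch.cuda.is_available())  # True only when the container is started with --gpus
```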
diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml
index d2c9aa7e..99d0ed3d 100644
--- a/launcher/Cargo.toml
+++ b/launcher/Cargo.toml
@@ -8,6 +8,7 @@ description = "Text Generation Launcher"
 [dependencies]
 clap = { version = "4.1.4", features = ["derive", "env"] }
 ctrlc = { version = "3.2.5", features = ["termination"] }
+serde = { version = "1.0.152", features = ["derive"] }
 serde_json = "1.0.93"
 subprocess = "0.2.9"
 tracing = "0.1.37"
@@ -16,4 +17,3 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
 [dev-dependencies]
 float_eq = "1.0.1"
 reqwest = { version = "0.11.14", features = ["blocking", "json"] }
-serde = { version = "1.0.152", features = ["derive"] }
diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index a598a8bb..0cf43f16 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1,5 +1,5 @@
 use clap::Parser;
-use serde_json::Value;
+use serde::Deserialize;
 use std::env;
 use std::ffi::OsString;
 use std::io::{BufRead, BufReader, Read};
@@ -244,11 +244,8 @@ fn main() -> ExitCode {
         let _span = tracing::span!(tracing::Level::INFO, "download").entered();
         for line in stdout.lines() {
             // Parse loguru logs
-            if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
-                if let Some(text) = value.get("text") {
-                    // Format escaped newlines
-                    tracing::info!("{}", text.to_string().replace("\\n", ""));
-                }
+            if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
+                log.trace();
             }
         }
     });
@@ -525,7 +522,7 @@ fn shard_manager(
         "--uds-path".to_string(),
         uds_path,
         "--logger-level".to_string(),
-        "ERROR".to_string(),
+        "INFO".to_string(),
         "--json-output".to_string(),
     ];
 
@@ -643,11 +640,8 @@ fn shard_manager(
         let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
         for line in stdout.lines() {
             // Parse loguru logs
-            if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
-                if let Some(text) = value.get("text") {
-                    // Format escaped newlines
-                    tracing::error!("{}", text.to_string().replace("\\n", "\n"));
-                }
+            if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
+                log.trace();
             }
         }
     });
@@ -708,3 +702,45 @@ fn num_cuda_devices() -> Option<usize> {
     }
     None
 }
+
+#[derive(Deserialize)]
+#[serde(rename_all = "UPPERCASE")]
+enum PythonLogLevelEnum {
+    Trace,
+    Debug,
+    Info,
+    Success,
+    Warning,
+    Error,
+    Critical,
+}
+
+#[derive(Deserialize)]
+struct PythonLogLevel {
+    name: PythonLogLevelEnum,
+}
+
+#[derive(Deserialize)]
+struct PythonLogRecord {
+    level: PythonLogLevel,
+}
+
+#[derive(Deserialize)]
+struct PythonLogMessage {
+    text: String,
+    record: PythonLogRecord,
+}
+
+impl PythonLogMessage {
+    fn trace(&self) {
+        match self.record.level.name {
+            PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
+            PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
+            PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
+            PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
+            PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
+            PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
+            PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
+        }
+    }
+}
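A note on the launcher changes above: the shard is now started with --logger-level INFO (previously ERROR), and instead of fishing the "text" field out of an untyped serde_json::Value, every stdout line is deserialized into a typed PythonLogMessage and re-emitted through tracing at the matching level. loguru's SUCCESS and CRITICAL levels have no tracing counterparts, so they fold into info and error. The sketch below shows the record shape those structs rely on, using loguru's serialize option (which is what the server's --json-output flag maps to); the message text is illustrative:

```python
# What the launcher reads from each shard stdout line: loguru's serialized
# records are one JSON object per line, and PythonLogMessage only needs the
# "text" field plus the level name nested under "record" -> "level" -> "name".
import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, serialize=True, level="INFO")
logger.info("hello from the shard")
# -> {"text": "hello from the shard\n", "record": {..., "level": {"name": "INFO", ...}, ...}}
```

Dispatching with a match inside trace() rather than passing the level as a value fits tracing's design: each event macro fixes its level at the callsite.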
category = "main" optional = false @@ -138,17 +138,17 @@ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"] [[package]] name = "grpc-interceptor" -version = "0.15.0" +version = "0.15.1" description = "Simplifies gRPC interceptors" category = "main" optional = false -python-versions = ">=3.6.1,<4.0.0" +python-versions = ">=3.7,<4.0" [package.dependencies] -grpcio = ">=1.32.0,<2.0.0" +grpcio = ">=1.49.1,<2.0.0" [package.extras] -testing = ["protobuf (>=3.6.0)"] +testing = ["protobuf (>=4.21.9)"] [[package]] name = "grpcio" @@ -597,7 +597,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"] [[package]] name = "pytest" -version = "7.3.0" +version = "7.3.1" description = "pytest: simple powerful testing with Python" category = "dev" optional = false @@ -833,7 +833,7 @@ bnb = ["bitsandbytes"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "6141d488429e0ab579028036e8e4cbc54f583b48214cb4a6be066bb7ce5154db" +content-hash = "e05491a03938b79a71b498f2759169f5a41181084158fde5993e7dcb25292cb0" [metadata.files] accelerate = [ @@ -845,8 +845,8 @@ backoff = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] bitsandbytes = [ - {file = "bitsandbytes-0.35.4-py3-none-any.whl", hash = "sha256:201f168538ccfbd7594568a2f86c149cec8352782301076a15a783695ecec7fb"}, - {file = "bitsandbytes-0.35.4.tar.gz", hash = "sha256:b23db6b91cd73cb14faf9841a66bffa5c1722f9b8b57039ef2fb461ac22dd2a6"}, + {file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"}, + {file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"}, ] certifi = [ {file = "certifi-2022.12.7-py3-none-any.whl", hash = "sha256:4ad3232f5e926d6718ec31cfc1fcadfde020920e278684144551c91769c7bc18"}, @@ -973,8 +973,8 @@ googleapis-common-protos = [ {file = "googleapis_common_protos-1.59.0-py2.py3-none-any.whl", hash = "sha256:b287dc48449d1d41af0c69f4ea26242b5ae4c3d7249a38b0984c86a4caffff1f"}, ] grpc-interceptor = [ - {file = "grpc-interceptor-0.15.0.tar.gz", hash = "sha256:5c1aa9680b1d7e12259960c38057b121826860b05ebbc1001c74343b7ad1455e"}, - {file = "grpc_interceptor-0.15.0-py3-none-any.whl", hash = "sha256:63e390162e64df96c39c40508eb697def76a7cafac32a7eaf9272093eec1109e"}, + {file = "grpc-interceptor-0.15.1.tar.gz", hash = "sha256:3efadbc9aead272ac7a360c75c4bd96233094c9a5192dbb51c6156246bd64ba0"}, + {file = "grpc_interceptor-0.15.1-py3-none-any.whl", hash = "sha256:1cc52c34b0d7ff34512fb7780742ecda37bf3caa18ecc5f33f09b4f74e96b276"}, ] grpcio = [ {file = "grpcio-1.53.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:752d2949b40e12e6ad3ed8cc552a65b54d226504f6b1fb67cab2ccee502cc06f"}, @@ -1329,8 +1329,8 @@ psutil = [ {file = "psutil-5.9.4.tar.gz", hash = "sha256:3d7f9739eb435d4b1338944abe23f49584bde5395f27487d2ee25ad9a8774a62"}, ] pytest = [ - {file = "pytest-7.3.0-py3-none-any.whl", hash = "sha256:933051fa1bfbd38a21e73c3960cebdad4cf59483ddba7696c48509727e17f201"}, - {file = "pytest-7.3.0.tar.gz", hash = "sha256:58ecc27ebf0ea643ebfdf7fb1249335da761a00c9f955bcd922349bcb68ee57d"}, + {file = "pytest-7.3.1-py3-none-any.whl", hash = "sha256:3799fa815351fea3a5e96ac7e503a96fa51cc9942c3753cda7651b93c1cfa362"}, + {file = "pytest-7.3.1.tar.gz", hash = "sha256:434afafd78b1d78ed0addf160ad2b77a30d35d4bdf8af234fe621919d9ed15e3"}, ] PyYAML = [ {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, diff --git a/server/pyproject.toml b/server/pyproject.toml index c892cd3f..68dd9327 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" accelerate = "^0.15.0" -bitsandbytes = "^0.35.1" +bitsandbytes = "^0.38.1" safetensors = "^0.2.4" loguru = "^0.6.0" opentelemetry-api = "^1.15.0" diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 4d48f492..94340fac 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -6,8 +6,6 @@ from pathlib import Path from loguru import logger from typing import Optional -from text_generation_server import server, utils -from text_generation_server.tracing import setup_tracing app = typer.Typer() @@ -48,6 +46,11 @@ def serve( backtrace=True, diagnose=False, ) + + # Import here after the logger is added to log potential import exceptions + from text_generation_server import server + from text_generation_server.tracing import setup_tracing + # Setup OpenTelemetry distributed tracing if otlp_endpoint is not None: setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint) @@ -75,6 +78,9 @@ def download_weights( diagnose=False, ) + # Import here after the logger is added to log potential import exceptions + from text_generation_server import utils + # Test if files were already download try: utils.weight_files(model_id, revision, extension) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 368060a0..9c1ea3b0 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -26,7 +26,7 @@ try: FLASH_ATTENTION = torch.cuda.is_available() except ImportError: - logger.exception("Could not import Flash Attention enabled models") + logger.opt(exception=True).warning("Could not import Flash Attention enabled models") FLASH_ATTENTION = False __all__ = [