OlivierDehaene 2023-04-16 17:34:57 +02:00
parent c23cc3e2f7
commit a37e6edd5c
7 changed files with 90 additions and 41 deletions

Dockerfile

@@ -28,9 +28,9 @@ COPY router router
COPY launcher launcher
RUN cargo build --release
# CUDA kernel builder
# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as kernel-builder
FROM debian:bullseye-slim as pytorch-install
ARG PYTORCH_VERSION=2.0.0
ARG PYTHON_VERSION=3.9
@@ -41,19 +41,15 @@ ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
ninja-build \
cmake \
curl \
git && \
rm -rf /var/lib/apt/lists/*
RUN /usr/sbin/update-ccache-symlinks && \
mkdir /opt/ccache && \
ccache --set-config=cache_dir=/opt/ccache
ENV PATH /opt/conda/bin:$PATH
# Install conda
# translating Docker's TARGETPLATFORM into mamba arches
@@ -75,6 +71,16 @@ RUN case ${TARGETPLATFORM} in \
esac && \
/opt/conda/bin/conda clean -ya
# CUDA kernels builder image
FROM pytorch-install as kernel-builder
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ninja-build \
&& rm -rf /var/lib/apt/lists/*
RUN /opt/conda/bin/conda install -c "nvidia/label/cuda-11.8.0" cuda==11.8 && \
/opt/conda/bin/conda clean -ya
# Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder
@@ -97,10 +103,11 @@ COPY server/Makefile-transformers Makefile
RUN BUILD_EXTENSIONS="True" make build-transformers
# Text Generation Inference base image
FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as base
FROM debian:bullseye-slim as base
# Conda env
ENV PATH=/opt/conda/bin:$PATH
ENV PATH=/opt/conda/bin:$PATH \
CONDA_PREFIX=/opt/conda
# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
@@ -121,7 +128,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
&& rm -rf /var/lib/apt/lists/*
# Copy conda with PyTorch installed
COPY --from=kernel-builder /opt/conda /opt/conda
COPY --from=pytorch-install /opt/conda /opt/conda
# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-39 /opt/conda/lib/python3.9/site-packages

launcher/Cargo.toml

@@ -8,6 +8,7 @@ description = "Text Generation Launcher"
[dependencies]
clap = { version = "4.1.4", features = ["derive", "env"] }
ctrlc = { version = "3.2.5", features = ["termination"] }
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
subprocess = "0.2.9"
tracing = "0.1.37"
@@ -16,4 +17,3 @@ tracing-subscriber = { version = "0.3.16", features = ["json"] }
[dev-dependencies]
float_eq = "1.0.1"
reqwest = { version = "0.11.14", features = ["blocking", "json"] }
serde = { version = "1.0.152", features = ["derive"] }

launcher/src/main.rs

@@ -1,5 +1,5 @@
use clap::Parser;
use serde_json::Value;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::io::{BufRead, BufReader, Read};
@@ -244,11 +244,8 @@ fn main() -> ExitCode {
let _span = tracing::span!(tracing::Level::INFO, "download").entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
if let Some(text) = value.get("text") {
// Format escaped newlines
tracing::info!("{}", text.to_string().replace("\\n", ""));
}
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
}
}
});
@@ -525,7 +522,7 @@ fn shard_manager(
"--uds-path".to_string(),
uds_path,
"--logger-level".to_string(),
"ERROR".to_string(),
"INFO".to_string(),
"--json-output".to_string(),
];
@@ -643,11 +640,8 @@ fn shard_manager(
let _span = tracing::span!(tracing::Level::INFO, "shard-manager", rank = rank).entered();
for line in stdout.lines() {
// Parse loguru logs
if let Ok(value) = serde_json::from_str::<Value>(&line.unwrap()) {
if let Some(text) = value.get("text") {
// Format escaped newlines
tracing::error!("{}", text.to_string().replace("\\n", "\n"));
}
if let Ok(log) = serde_json::from_str::<PythonLogMessage>(&line.unwrap()) {
log.trace();
}
}
});
@@ -708,3 +702,45 @@ fn num_cuda_devices() -> Option<usize> {
}
None
}
#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
enum PythonLogLevelEnum {
Trace,
Debug,
Info,
Success,
Warning,
Error,
Critical,
}
#[derive(Deserialize)]
struct PythonLogLevel {
name: PythonLogLevelEnum,
}
#[derive(Deserialize)]
struct PythonLogRecord {
level: PythonLogLevel,
}
#[derive(Deserialize)]
struct PythonLogMessage {
text: String,
record: PythonLogRecord,
}
impl PythonLogMessage {
fn trace(&self) {
match self.record.level.name {
PythonLogLevelEnum::Trace => tracing::trace!("{}", self.text),
PythonLogLevelEnum::Debug => tracing::debug!("{}", self.text),
PythonLogLevelEnum::Info => tracing::info!("{}", self.text),
PythonLogLevelEnum::Success => tracing::info!("{}", self.text),
PythonLogLevelEnum::Warning => tracing::warn!("{}", self.text),
PythonLogLevelEnum::Error => tracing::error!("{}", self.text),
PythonLogLevelEnum::Critical => tracing::error!("{}", self.text),
}
}
}
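For reference, here is a minimal Python sketch of the loguru output these structs are written against; it assumes loguru's serialize=True JSON layout and is illustrative, not part of this commit:

import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, serialize=True)  # one JSON object per line, as a shard emits
logger.info("Shard ready")
# The printed line includes, among other fields:
#   {"text": "... Shard ready\n", "record": {"level": {"name": "INFO", ...}, ...}}
# PythonLogMessage::trace maps record.level.name onto the matching tracing macro,
# collapsing SUCCESS into info! and CRITICAL into error! on the Rust side.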

server/poetry.lock (generated)

@@ -33,7 +33,7 @@ python-versions = ">=3.7,<4.0"
[[package]]
name = "bitsandbytes"
version = "0.35.4"
version = "0.38.1"
description = "8-bit optimizers and matrix multiplication routines."
category = "main"
optional = false
@@ -138,17 +138,17 @@ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"]
[[package]]
name = "grpc-interceptor"
version = "0.15.0"
version = "0.15.1"
description = "Simplifies gRPC interceptors"
category = "main"
optional = false
python-versions = ">=3.6.1,<4.0.0"
python-versions = ">=3.7,<4.0"
[package.dependencies]
grpcio = ">=1.32.0,<2.0.0"
grpcio = ">=1.49.1,<2.0.0"
[package.extras]
testing = ["protobuf (>=3.6.0)"]
testing = ["protobuf (>=4.21.9)"]
[[package]]
name = "grpcio"
@@ -597,7 +597,7 @@ test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
[[package]]
name = "pytest"
version = "7.3.0"
version = "7.3.1"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
@@ -833,7 +833,7 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "6141d488429e0ab579028036e8e4cbc54f583b48214cb4a6be066bb7ce5154db"
content-hash = "e05491a03938b79a71b498f2759169f5a41181084158fde5993e7dcb25292cb0"
[metadata.files]
accelerate = [
@@ -845,8 +845,8 @@ backoff = [
{file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"},
]
bitsandbytes = [
{file = "bitsandbytes-0.35.4-py3-none-any.whl", hash = "sha256:201f168538ccfbd7594568a2f86c149cec8352782301076a15a783695ecec7fb"},
{file = "bitsandbytes-0.35.4.tar.gz", hash = "sha256:b23db6b91cd73cb14faf9841a66bffa5c1722f9b8b57039ef2fb461ac22dd2a6"},
{file = "bitsandbytes-0.38.1-py3-none-any.whl", hash = "sha256:5f532e7b1353eb7049ae831da2eb62ed8a1e0444116bd51b9e088a6e0bc7a34a"},
{file = "bitsandbytes-0.38.1.tar.gz", hash = "sha256:ba95a806b5065ea3263558e188f07eacb32ad691842932fb0d36a879883167ce"},
]
certifi = [
{file = "certifi-2022.12.7-py3-none-any.whl", hash = "sha256:4ad3232f5e926d6718ec31cfc1fcadfde020920e278684144551c91769c7bc18"},
@@ -973,8 +973,8 @@ googleapis-common-protos = [
{file = "googleapis_common_protos-1.59.0-py2.py3-none-any.whl", hash = "sha256:b287dc48449d1d41af0c69f4ea26242b5ae4c3d7249a38b0984c86a4caffff1f"},
]
grpc-interceptor = [
{file = "grpc-interceptor-0.15.0.tar.gz", hash = "sha256:5c1aa9680b1d7e12259960c38057b121826860b05ebbc1001c74343b7ad1455e"},
{file = "grpc_interceptor-0.15.0-py3-none-any.whl", hash = "sha256:63e390162e64df96c39c40508eb697def76a7cafac32a7eaf9272093eec1109e"},
{file = "grpc-interceptor-0.15.1.tar.gz", hash = "sha256:3efadbc9aead272ac7a360c75c4bd96233094c9a5192dbb51c6156246bd64ba0"},
{file = "grpc_interceptor-0.15.1-py3-none-any.whl", hash = "sha256:1cc52c34b0d7ff34512fb7780742ecda37bf3caa18ecc5f33f09b4f74e96b276"},
]
grpcio = [
{file = "grpcio-1.53.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:752d2949b40e12e6ad3ed8cc552a65b54d226504f6b1fb67cab2ccee502cc06f"},
@@ -1329,8 +1329,8 @@ psutil = [
{file = "psutil-5.9.4.tar.gz", hash = "sha256:3d7f9739eb435d4b1338944abe23f49584bde5395f27487d2ee25ad9a8774a62"},
]
pytest = [
{file = "pytest-7.3.0-py3-none-any.whl", hash = "sha256:933051fa1bfbd38a21e73c3960cebdad4cf59483ddba7696c48509727e17f201"},
{file = "pytest-7.3.0.tar.gz", hash = "sha256:58ecc27ebf0ea643ebfdf7fb1249335da761a00c9f955bcd922349bcb68ee57d"},
{file = "pytest-7.3.1-py3-none-any.whl", hash = "sha256:3799fa815351fea3a5e96ac7e503a96fa51cc9942c3753cda7651b93c1cfa362"},
{file = "pytest-7.3.1.tar.gz", hash = "sha256:434afafd78b1d78ed0addf160ad2b77a30d35d4bdf8af234fe621919d9ed15e3"},
]
PyYAML = [
{file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},

server/pyproject.toml

@@ -16,7 +16,7 @@ grpcio-reflection = "^1.51.1"
grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
accelerate = "^0.15.0"
bitsandbytes = "^0.35.1"
bitsandbytes = "^0.38.1"
safetensors = "^0.2.4"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"

server/text_generation_server/cli.py

@@ -6,8 +6,6 @@ from pathlib import Path
from loguru import logger
from typing import Optional
from text_generation_server import server, utils
from text_generation_server.tracing import setup_tracing
app = typer.Typer()
@@ -48,6 +46,11 @@ def serve(
backtrace=True,
diagnose=False,
)
# Import here after the logger is added to log potential import exceptions
from text_generation_server import server
from text_generation_server.tracing import setup_tracing
# Setup OpenTelemetry distributed tracing
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
@@ -75,6 +78,9 @@ def download_weights(
diagnose=False,
)
# Import here after the logger is added to log potential import exceptions
from text_generation_server import utils
# Test if files were already downloaded
try:
utils.weight_files(model_id, revision, extension)
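The deferred imports above follow a simple pattern: configure the JSON logger first, then import the modules whose import can raise (CUDA kernels, tracing exporters), so any failure surfaces as a log line the launcher can parse. A hedged sketch of the idea, with the try/except added for illustration only:

import sys
from loguru import logger

logger.add(sys.stdout, serialize=True)  # logger must exist before heavy imports

try:
    from text_generation_server import utils  # may raise if CUDA deps are missing
except Exception:
    logger.opt(exception=True).error("Import failed")
    raise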

server/text_generation_server/models/__init__.py

@@ -26,7 +26,7 @@ try:
FLASH_ATTENTION = torch.cuda.is_available()
except ImportError:
logger.exception("Could not import Flash Attention enabled models")
logger.opt(exception=True).warning("Could not import Flash Attention enabled models")
FLASH_ATTENTION = False
__all__ = [
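The last hunk swaps logger.exception, which always records at ERROR severity, for logger.opt(exception=True).warning, which attaches the same traceback to a WARNING record, so a missing optional dependency no longer reads as a fatal error. Both calls are part of loguru's public API; the module name below is illustrative:

from loguru import logger

try:
    import flash_attn_cuda  # illustrative optional dependency
except ImportError:
    # old behavior: ERROR severity, traceback attached
    logger.exception("Could not import Flash Attention enabled models")
    # new behavior: WARNING severity, same traceback attached
    logger.opt(exception=True).warning("Could not import Flash Attention enabled models")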