diff --git a/Dockerfile b/Dockerfile
index 50145395..930d7ad7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -108,7 +108,7 @@ COPY server/Makefile-transformers Makefile
 RUN BUILD_EXTENSIONS="True" make build-transformers
 
 # Text Generation Inference base image
-FROM debian:bullseye-slim as base
+FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as base
 
 # Conda env
 ENV PATH=/opt/conda/bin:$PATH \
@@ -122,17 +122,6 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
     NUM_SHARD=1 \
     PORT=80
 
-# NVIDIA env vars
-ENV NVIDIA_VISIBLE_DEVICES all
-ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
-# Required for nvidia-docker v1
-RUN /bin/bash -c echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
-    echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
-ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64
-ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
-
-LABEL com.nvidia.volumes.needed="nvidia_driver"
-
 WORKDIR /usr/src
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 309ec19f..23c3ea28 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -585,13 +585,25 @@ class FlashSantacoderForCausalLM(nn.Module):
 
         if self.transformer.tp_embeddings:
             # Logits are sharded, so we need to gather them
-            world_logits = [
-                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
-            ]
-            torch.distributed.all_gather(
-                world_logits, logits, group=self.transformer.process_group
-            )
-            world_logits = torch.cat(world_logits, dim=1)
+            if logits.shape[0] == 1:
+                # Fast path when batch size is 1
+                world_logits = logits.new_empty(
+                    (logits.shape[1] * self.transformer.tp_world_size)
+                )
+                torch.distributed.all_gather_into_tensor(
+                    world_logits, logits.view(-1), group=self.transformer.process_group
+                )
+                world_logits = world_logits.view(1, -1)
+            else:
+                # We cannot use all_gather_into_tensor as it only supports concatenating on the first dim
+                world_logits = [
+                    torch.empty_like(logits)
+                    for _ in range(self.transformer.tp_world_size)
+                ]
+                torch.distributed.all_gather(
+                    world_logits, logits, group=self.transformer.process_group
+                )
+                world_logits = torch.cat(world_logits, dim=1)
 
             return world_logits, present
 
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 6e5c231e..c463ee98 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -217,6 +217,7 @@ class FlashSantacoderSharded(FlashSantacoder):
             device=device,
             rank=rank,
             world_size=world_size,
+            decode_buffer=1,
         )
 
     @staticmethod
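Why the batch-size-1 fast path works: `all_gather_into_tensor` writes every rank's buffer back to back into a single contiguous output, i.e. it only concatenates on dim 0. When the logits have a single row, that flat layout is exactly the dim-1 concatenation the general path builds with `torch.cat`, so the list of per-rank tensors and the extra copy can be skipped. The single-process sketch below emulates the collective locally to show the equivalence; the two-rank setup, shard shapes, and values are illustrative, not taken from this PR.

```python
# Minimal, single-process sketch of the layout argument behind the fast path.
# torch.distributed.all_gather_into_tensor stacks each rank's buffer along
# dim 0 of one contiguous output, so we emulate it here with torch.cat on
# flattened shards and check when that matches a dim-1 concat.
import torch

world_size = 2   # illustrative two-rank setup
vocab_shard = 4  # each rank holds this many vocabulary columns

for batch in (1, 3):
    # Pretend each rank holds a distinct vocab shard of the logits.
    shards = [
        torch.arange(batch * vocab_shard, dtype=torch.float32).reshape(
            batch, vocab_shard
        )
        + 100 * rank
        for rank in range(world_size)
    ]

    # What the general path computes: concatenate shards along the vocab dim.
    expected = torch.cat(shards, dim=1)

    # What all_gather_into_tensor produces: rank buffers laid out back to
    # back, i.e. a dim-0 concat of whatever each rank passed in (here, the
    # flattened rows, as in the diff's logits.view(-1)).
    gathered = torch.cat([s.reshape(-1) for s in shards], dim=0)

    if batch == 1:
        # With a single row, "rank 0's row then rank 1's row" is exactly the
        # dim-1 concat, so one contiguous gather plus a view suffices.
        assert torch.equal(gathered.view(1, -1), expected)
    else:
        # With several rows, the flat layout interleaves whole per-rank
        # blocks rather than per-row slices, so a plain view cannot recover
        # the dim-1 concat; the list-based all_gather + cat path is kept.
        assert not torch.equal(gathered.view(batch, -1), expected)
```

Besides saving the `torch.cat` copy, the fast path allocates one output buffer instead of `tp_world_size` temporaries, which matters on the per-token decode path where batch size 1 is the common case.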