diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 6c752de8..fef29899 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -11,6 +11,10 @@ on: branches: - 'main' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: build-and-push-image: runs-on: ubuntu-latest diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 476ae443..830e2119 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -9,6 +9,10 @@ on: - "router/**" - "launcher/**" +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: run_tests: runs-on: ubuntu-20.04 diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 8caf7967..2f637ae1 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,5 +1,7 @@ +import os import torch +from loguru import logger from transformers import AutoConfig from typing import Optional @@ -14,9 +16,10 @@ from text_generation_server.models.t5 import T5Sharded try: from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded - - FLASH_NEOX = torch.cuda.is_available() + FLASH_NEOX = torch.cuda.is_available() and int(os.environ.get("FLASH_NEOX", 0)) == 1 except ImportError: + if int(os.environ.get("FLASH_NEOX", 0)) == 1: + logger.exception("Could not import FlashNeoX") FLASH_NEOX = False __all__ = [