diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 9992e0a2..124e6a33 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -213,13 +213,12 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
-          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv -n 2 --dist loadfile integration-tests
+          pytest -s -vv integration-tests
 
   stop-runner:
     name: Stop self-hosted EC2 runner
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 2c8b3933..bf7a2849 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -261,12 +261,17 @@ def get_model(
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
             return CausalLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )
         if "AutoModelForSeq2SeqLM" in auto_map.keys():
             return Seq2SeqLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )
 
-    raise ValueError(f"Unsupported model type {model_type}")
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 0cd47461..09df70d2 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -1,4 +1,5 @@
 import torch
+import inspect
 
 from dataclasses import dataclass
 from opentelemetry import trace
@@ -473,17 +474,28 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
 
-        tokenizer.pad_token_id = (
-            model.config.pad_token_id
-            if model.config.pad_token_id is not None
-            else model.config.eos_token_id
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
         )
 
         super(CausalLM, self).__init__(
@@ -507,14 +519,17 @@ class CausalLM(Model):
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=True,
-            return_dict=True,
-        )
+        kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+            "return_dict": True,
+        }
+        if self.has_position_ids:
+            kwargs["position_ids"] = position_ids
+
+        outputs = self.model.forward(**kwargs)
         return outputs.logits, outputs.past_key_values
 
     @tracer.start_as_current_span("generate_token")
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index 677bc61e..a1a39fd4 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -519,7 +519,9 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
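
For reference, a minimal runnable sketch of the two runtime checks the causal_lm.py change introduces: the inspect.signature probe behind self.has_position_ids (some forward() implementations take no position_ids argument), and the pad-token fallback chain (model pad -> model eos -> tokenizer eos -> freshly added [PAD]). TinyModel and TinyTokenizer are hypothetical stand-ins for a transformers model/tokenizer pair, not names from the repository.

import inspect
from typing import Optional


class TinyConfig:
    # Hypothetical config: no pad token, eos id 11.
    pad_token_id: Optional[int] = None
    eos_token_id: Optional[int] = 11


class TinyModel:
    config = TinyConfig()

    # forward() deliberately lacks a `position_ids` parameter, as with
    # some architectures whose forward signature omits it.
    def forward(self, input_ids, attention_mask, past_key_values=None,
                use_cache=True, return_dict=True):
        pass


class TinyTokenizer:
    pad_token_id: Optional[int] = None
    eos_token_id: Optional[int] = None

    def add_special_tokens(self, tokens):
        pass


model, tokenizer = TinyModel(), TinyTokenizer()

# Signature probe: only pass `position_ids` when forward() accepts it.
has_position_ids = (
    inspect.signature(model.forward).parameters.get("position_ids", None)
    is not None
)
assert not has_position_ids  # TinyModel.forward takes no position_ids

# Pad-token fallback chain, mirroring the diff.
if tokenizer.pad_token_id is None:
    if model.config.pad_token_id is not None:
        tokenizer.pad_token_id = model.config.pad_token_id
    elif model.config.eos_token_id is not None:
        tokenizer.pad_token_id = model.config.eos_token_id
    elif tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    else:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

assert tokenizer.pad_token_id == 11  # fell back to model.config.eos_token_id

Building the kwargs dict conditionally, rather than always passing position_ids, is what lets CausalLM.forward avoid a TypeError on models whose forward() signature omits that argument.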