inspect signature for position ids

OlivierDehaene 2023-05-23 20:05:56 +02:00
parent b83ea010fa
commit ac59aadf17
4 changed files with 40 additions and 19 deletions
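
Some decoder-only architectures accept a position_ids argument in forward() while others, such as BLOOM-style ALiBi models, do not. Instead of hard-coding which models need it, CausalLM now inspects the loaded model's forward signature once at load time and stores the result as self.has_position_ids; forward() then passes position_ids only when the parameter exists. The commit also hardens the pad-token setup with a fallback chain, and the CI workflow drops the pytest-xdist parallel run in favor of a single pytest invocation.

A minimal, self-contained sketch of the signature probe (the toy classes are illustrative only, not repository code):

import inspect


class WithPositions:
    # Toy stand-in for an architecture whose forward accepts position_ids.
    def forward(self, input_ids, position_ids=None):
        return input_ids


class WithoutPositions:
    # Toy stand-in for an architecture without position_ids (BLOOM-style).
    def forward(self, input_ids):
        return input_ids


for model in (WithPositions(), WithoutPositions()):
    # Same probe the commit performs once at load time.
    has_position_ids = (
        inspect.signature(model.forward).parameters.get("position_ids", None)
        is not None
    )
    print(type(model).__name__, has_position_ids)
# prints: WithPositions True, then WithoutPositions False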

.github/workflows/build.yaml

@@ -213,13 +213,12 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
-          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv -n 2 --dist loadfile integration-tests
+          pytest -s -vv integration-tests

   stop-runner:
     name: Stop self-hosted EC2 runner

server/text_generation_server/models/__init__.py

@@ -261,12 +261,17 @@ def get_model(
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
             return CausalLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )
         if "AutoModelForSeq2SeqLM" in auto_map.keys:
             return Seq2SeqLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )

     raise ValueError(f"Unsupported model type {model_type}")

server/text_generation_server/models/causal_lm.py

@@ -1,4 +1,5 @@
 import torch
+import inspect

 from dataclasses import dataclass
 from opentelemetry import trace
@@ -473,17 +474,28 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()

-        tokenizer.pad_token_id = (
-            model.config.pad_token_id
-            if model.config.pad_token_id is not None
-            else model.config.eos_token_id
-        )
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )

         super(CausalLM, self).__init__(
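
The pad-token hunk replaces an unconditional overwrite (which could clobber an existing pad id, or set it to None when the config carried neither id) with a fallback chain that only runs when the tokenizer lacks a pad token. A runnable illustration of the ordering, using SimpleNamespace stand-ins rather than real transformers tokenizer/config objects:

from types import SimpleNamespace

# Stand-ins for the tokenizer and model config; real ones come from transformers.
tokenizer = SimpleNamespace(
    pad_token_id=None,
    eos_token_id=2,
    add_special_tokens=lambda d: None,
)
config = SimpleNamespace(pad_token_id=None, eos_token_id=None)

# Same fallback chain as the commit: existing pad id wins, then the config's
# pad/eos ids, then the tokenizer's eos id, and finally a new [PAD] token.
if tokenizer.pad_token_id is None:
    if config.pad_token_id is not None:
        tokenizer.pad_token_id = config.pad_token_id
    elif config.eos_token_id is not None:
        tokenizer.pad_token_id = config.eos_token_id
    elif tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    else:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

print(tokenizer.pad_token_id)  # 2: fell through to the tokenizer's eos id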
@@ -507,14 +519,17 @@ class CausalLM(Model):
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=True,
-            return_dict=True,
-        )
+        kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+            "return_dict": True,
+        }
+        if self.has_position_ids:
+            kwargs["position_ids"] = position_ids
+
+        outputs = self.model.forward(**kwargs)
         return outputs.logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")
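
The guard matters because Python raises a TypeError as soon as a function receives a keyword it does not declare, so position_ids goes into kwargs only when the load-time probe found the parameter. A toy demonstration (the forward function below is a stand-in, not the model's real forward):

def forward(input_ids, attention_mask, past_key_values=None):
    # Toy forward without a position_ids parameter.
    return input_ids


kwargs = {"input_ids": [1, 2], "attention_mask": [1, 1]}
has_position_ids = False  # result of the load-time signature probe

if has_position_ids:
    kwargs["position_ids"] = [0, 1]

print(forward(**kwargs))  # works: [1, 2]

try:
    forward(**kwargs, position_ids=[0, 1])
except TypeError as exc:
    print(exc)  # unexpected keyword argument 'position_ids'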

server/text_generation_server/models/seq2seq_lm.py

@@ -519,7 +519,9 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )