inspect signature for position ids

OlivierDehaene 2023-05-23 20:05:56 +02:00
parent b83ea010fa
commit ac59aadf17
4 changed files with 40 additions and 19 deletions


@@ -213,13 +213,12 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
-          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv -n 2 --dist loadfile integration-tests
+          pytest -s -vv integration-tests
   stop-runner:
     name: Stop self-hosted EC2 runner


@@ -261,12 +261,17 @@ def get_model(
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
             return CausalLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )
         if "AutoModelForSeq2SeqLM" in auto_map.keys():
             return Seq2SeqLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )

     raise ValueError(f"Unsupported model type {model_type}")
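
For context, `auto_map` is read from the model repository's `config.json` when the repo ships custom modeling code under `trust_remote_code`; the dispatch above keys on which auto classes the repo declares. A minimal sketch of such an entry (module and class names here are invented for illustration):

    # Hypothetical auto_map, as it would appear in a trust_remote_code
    # repo's config.json; "foo" names are made up.
    auto_map = {
        "AutoConfig": "configuration_foo.FooConfig",
        "AutoModelForCausalLM": "modeling_foo.FooForCausalLM",
    }

    # Mirror of the dispatch in get_model above: decoder-only repos declare
    # AutoModelForCausalLM, encoder-decoder repos AutoModelForSeq2SeqLM.
    if "AutoModelForCausalLM" in auto_map:
        print("load with CausalLM")
    elif "AutoModelForSeq2SeqLM" in auto_map:
        print("load with Seq2SeqLM")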


@@ -1,4 +1,5 @@
 import torch
+import inspect

 from dataclasses import dataclass
 from opentelemetry import trace
@@ -473,17 +474,28 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()

-        tokenizer.pad_token_id = (
-            model.config.pad_token_id
-            if model.config.pad_token_id is not None
-            else model.config.eos_token_id
-        )
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )

         super(CausalLM, self).__init__(
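
The `inspect.signature` check above is the heart of this commit: it probes the loaded model's `forward` once at construction time instead of hard-coding which architectures take `position_ids`. A standalone sketch of the same probe (the two forwards are made up for illustration):

    import inspect

    # Hypothetical forwards: one accepts position_ids, the other does not.
    def forward_with(input_ids, attention_mask, position_ids=None):
        ...

    def forward_without(input_ids, attention_mask):
        ...

    def has_position_ids(fn):
        # signature(...).parameters maps parameter names to
        # inspect.Parameter objects, so a membership test is all we need.
        return "position_ids" in inspect.signature(fn).parameters

    print(has_position_ids(forward_with))     # True
    print(has_position_ids(forward_without))  # False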
@@ -507,14 +519,17 @@ class CausalLM(Model):
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=True,
-            return_dict=True,
-        )
+        kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+            "return_dict": True,
+        }
+        if self.has_position_ids:
+            kwargs["position_ids"] = position_ids
+
+        outputs = self.model.forward(**kwargs)
         return outputs.logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")
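
The guard matters because calling a function via `**kwargs` with a keyword it does not declare fails at call time; some architectures' `forward` simply has no `position_ids` parameter. A minimal sketch with made-up forwards:

    # Made-up forwards: only one declares position_ids.
    def forward_with(input_ids, attention_mask, position_ids=None):
        return "ok"

    def forward_without(input_ids, attention_mask):
        return "ok"

    kwargs = {"input_ids": [1], "attention_mask": [1]}
    forward_with(**kwargs, position_ids=[0])  # fine
    forward_without(**kwargs)                 # fine
    # forward_without(**kwargs, position_ids=[0]) would raise:
    #   TypeError: forward_without() got an unexpected keyword
    #   argument 'position_ids'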


@@ -519,7 +519,9 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
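
Both loaders now share the same placement heuristic. A sketch of the effective behavior, assuming `accelerate` is installed (it is required for `device_map="auto"`):

    import torch

    def pick_device_map():
        # "auto" lets accelerate shard the weights across every visible GPU;
        # None loads the model on a single device, after which the caller
        # moves it with .cuda() when exactly one GPU is available.
        if torch.cuda.is_available() and torch.cuda.device_count() > 1:
            return "auto"
        return None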