Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
inspect signature for position ids
This commit is contained in:
parent b83ea010fa
commit ac59aadf17
.github/workflows/build.yaml (vendored, 3 changes)
@@ -213,13 +213,12 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
-          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv -n 2 --dist loadfile integration-tests
+          pytest -s -vv integration-tests
 
   stop-runner:
     name: Stop self-hosted EC2 runner
server/text_generation_server/models/__init__.py
@@ -261,12 +261,17 @@ def get_model(
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
             return CausalLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )
         if "AutoModelForSeq2SeqLM" in auto_map.keys():
             return Seq2SeqLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )
 
     raise ValueError(f"Unsupported model type {model_type}")
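For context on the branch above: trust_remote_code models advertise their custom classes through an auto_map entry in their config.json, and get_model only needs to know which auto class is declared to route the model to CausalLM or Seq2SeqLM. A minimal sketch of that dispatch, with illustrative auto_map contents not taken from this commit:

# auto_map as it appears in a remote-code config.json (illustrative values)
auto_map = {
    "AutoModelForCausalLM": "modeling_custom.CustomModelForCausalLM",
}

if "AutoModelForCausalLM" in auto_map.keys():
    print("route to CausalLM")
elif "AutoModelForSeq2SeqLM" in auto_map.keys():
    print("route to Seq2SeqLM")
else:
    raise ValueError("Unsupported remote-code model")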
server/text_generation_server/models/causal_lm.py
@@ -1,4 +1,5 @@
 import torch
+import inspect
 
 from dataclasses import dataclass
 from opentelemetry import trace
@@ -473,17 +474,28 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
 
-        tokenizer.pad_token_id = (
-            model.config.pad_token_id
-            if model.config.pad_token_id is not None
-            else model.config.eos_token_id
-        )
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )
 
         super(CausalLM, self).__init__(
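The new has_position_ids flag is the heart of the commit: some remote-code models define forward() without a position_ids parameter, so support is detected once at load time instead of failing on every forward call. The widened pad_token_id fallback chain above similarly guards against remote-code tokenizers and configs that ship without a pad token. A self-contained sketch of the signature-inspection pattern, with gpt2 assumed purely for illustration:

import inspect

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # gpt2 is an illustrative choice

# True when the model's forward() declares a position_ids parameter;
# GPT-2's forward does, so this prints True.
has_position_ids = (
    inspect.signature(model.forward).parameters.get("position_ids", None)
    is not None
)
print(has_position_ids)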
@@ -507,14 +519,17 @@ class CausalLM(Model):
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=True,
-            return_dict=True,
-        )
+        kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+            "return_dict": True,
+        }
+        if self.has_position_ids:
+            kwargs["position_ids"] = position_ids
+
+        outputs = self.model.forward(**kwargs)
         return outputs.logits, outputs.past_key_values
 
     @tracer.start_as_current_span("generate_token")
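At call time the flag simply gates the keyword argument, so models whose forward() lacks position_ids never see it. A runnable sketch of the same pattern outside the class, again assuming gpt2 as a stand-in:

import inspect

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("inspect signatures, not tracebacks", return_tensors="pt")
has_position_ids = (
    inspect.signature(model.forward).parameters.get("position_ids", None) is not None
)

kwargs = {
    "input_ids": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"],
    "use_cache": True,
    "return_dict": True,
}
if has_position_ids:
    # Positions derived from the attention mask, similar to how TGI builds
    # them so that left padding is handled correctly.
    kwargs["position_ids"] = inputs["attention_mask"].cumsum(-1) - 1

outputs = model(**kwargs)
print(outputs.logits.shape)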
server/text_generation_server/models/seq2seq_lm.py
@@ -519,7 +519,9 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
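The device_map reformatting in both model classes preserves the same policy: let accelerate shard the weights across GPUs only when more than one is visible, and otherwise load normally and move to the single device by hand. Read in isolation (a sketch, with t5-small as an assumed stand-in and accelerate installed for the multi-GPU path):

import torch
from transformers import AutoModelForSeq2SeqLM

device_map = (
    "auto"
    if torch.cuda.is_available() and torch.cuda.device_count() > 1
    else None
)
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map=device_map)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
    # Single GPU: move the whole model over manually.
    model = model.cuda()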