Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-09-10 11:54:52 +00:00
inspect signature for position ids

This commit is contained in:
parent b83ea010fa
commit ac59aadf17
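
CausalLM previously passed position_ids to the underlying model unconditionally, which breaks models whose forward() does not accept that argument. The model's forward signature is now inspected once at load time and the result cached in self.has_position_ids, so position_ids is only forwarded when the model declares it. The tokenizer's pad-token fallback is also made more defensive, and the integration-test CI job drops pytest-xdist.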
.github/workflows/build.yaml (vendored, 3 changes)
@@ -213,13 +213,12 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
-          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv -n 2 --dist loadfile integration-tests
+          pytest -s -vv integration-tests

   stop-runner:
     name: Stop self-hosted EC2 runner
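The test step no longer uses pytest-xdist: the removed flags -n 2 --dist loadfile ran the integration tests in two worker processes with tests grouped by file, whereas the suite now runs serially in a single process.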
server/text_generation_server/models/__init__.py

@@ -261,12 +261,17 @@ def get_model(
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
             return CausalLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )
         if "AutoModelForSeq2SeqLM" in auto_map.keys():
             return Seq2SeqLM(
-                model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
+                model_id,
+                revision,
+                quantize=quantize,
+                trust_remote_code=trust_remote_code,
             )

     raise ValueError(f"Unsupported model type {model_type}")
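Both call sites above are purely reformatted, one keyword argument per line with a trailing comma; the dispatch logic for trust_remote_code models is unchanged.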
server/text_generation_server/models/causal_lm.py

@@ -1,4 +1,5 @@
 import torch
+import inspect

 from dataclasses import dataclass
 from opentelemetry import trace
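inspect is part of the Python standard library, so the new import introduces no dependency.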
@@ -473,17 +474,28 @@ class CausalLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()

-        tokenizer.pad_token_id = (
-            model.config.pad_token_id
-            if model.config.pad_token_id is not None
-            else model.config.eos_token_id
-        )
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )

         super(CausalLM, self).__init__(
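Two behavioral changes land here. The pad-token logic no longer overwrites tokenizer.pad_token_id unconditionally: it is filled only when unset, falling back from model.config.pad_token_id to model.config.eos_token_id to tokenizer.eos_token_id, and finally to adding a new [PAD] special token. And self.has_position_ids records, once at load time, whether the wrapped model's forward() accepts a position_ids argument. A minimal, self-contained sketch of that signature check; the two forward stubs are hypothetical stand-ins, not code from this repository:

    import inspect

    # Hypothetical stand-ins for two model forward() signatures: one that
    # accepts position_ids (GPT-2 style) and one that does not.
    def forward_with_positions(input_ids, attention_mask, position_ids=None):
        ...

    def forward_without_positions(input_ids, attention_mask):
        ...

    def accepts_position_ids(forward):
        # Same check as the commit: look the parameter up by name
        # in the callable's signature.
        return (
            inspect.signature(forward).parameters.get("position_ids", None)
            is not None
        )

    assert accepts_position_ids(forward_with_positions)
    assert not accepts_position_ids(forward_without_positions)

Because the signature is inspected by parameter name, the check works for any architecture without hard-coding a model list.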
@@ -507,14 +519,17 @@ class CausalLM(Model):
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=True,
-            return_dict=True,
-        )
+        kwargs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": True,
+            "return_dict": True,
+        }
+        if self.has_position_ids:
+            kwargs["position_ids"] = position_ids
+
+        outputs = self.model.forward(**kwargs)
         return outputs.logits, outputs.past_key_values

     @tracer.start_as_current_span("generate_token")
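forward() now assembles its keyword arguments in a dict and attaches position_ids only when the wrapped model declared the parameter, so a single code path serves both kinds of architectures.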
server/text_generation_server/models/seq2seq_lm.py

@@ -519,7 +519,9 @@ class Seq2SeqLM(Model):
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
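Seq2SeqLM receives the same multi-line reflow of the device_map ternary as CausalLM above; the expression itself is unchanged.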