Mirror of https://github.com/huggingface/text-generation-inference.git

cleanup

This commit is contained in:
commit dd0028f336
parent 6818b8c285
@@ -18,22 +18,15 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from transformers.utils import logging
from transformers.models.gpt_bigcode.configuration_gpt_bigcode import (
    GPTBigCodeConfig,
    InferenceRunnerType,
@@ -51,15 +44,6 @@ except ImportError:

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "bigcode/gpt_bigcode-santacoder"
_CONFIG_FOR_DOC = "GPTBigCodeConfig"

GPT_BIGCODE_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bigcode/gpt_bigcode-santacoder",
    # See all GPTBigCode models at https://huggingface.co/models?filter=gpt_bigcode
]


def upcast_masked_softmax(
    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
):
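The hunk above shows only the signature of upcast_masked_softmax. As a rough sketch of what a masked, upcasting softmax with these arguments typically computes (the body below is an assumption inferred from the signature and common practice, not the file's actual implementation; it assumes a boolean mask where True marks positions to keep):

import torch

def upcast_masked_softmax_sketch(
    x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor, scale: float, softmax_dtype: torch.dtype
):
    # Upcast to softmax_dtype (e.g. float32) for numerical stability, apply the
    # scale, fill masked-out positions with mask_value (a large negative number),
    # softmax over the last dimension, then cast back to the input dtype.
    input_dtype = x.dtype
    x = torch.where(mask, x.to(softmax_dtype) * scale, mask_value.to(softmax_dtype))
    return torch.nn.functional.softmax(x, dim=-1).to(input_dtype)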
@@ -581,95 +565,6 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
        module.gradient_checkpointing = value


GPT_BIGCODE_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GPTBigCodeConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

GPT_BIGCODE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[torch.Tensor]` of length `config.n_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
            `past_key_values`. In other words, the `attention_mask` always has to have the length:
            `len(past_key_values) + len(input_ids)`

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.Tensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare GPT_BIGCODE Model transformer outputting raw hidden-states without any specific head on top.",
    GPT_BIGCODE_START_DOCSTRING,
)
class GPTBigCodeModel(GPTBigCodePreTrainedModel):
    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

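The past_key_values / attention_mask contract documented above can be exercised roughly as follows; the checkpoint name comes from _CHECKPOINT_FOR_DOC in the earlier hunk, and the snippet is an illustrative sketch against the transformers API rather than code from this file:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
# First pass: the full prompt; the cache is returned because use_cache=True.
out = model(**inputs, use_cache=True)

# Second pass: feed only the token that has no past yet, and extend the
# attention mask so its length covers the past tokens plus the new input.
next_token = out.logits[:, -1:].argmax(-1)
new_mask = torch.cat([inputs["attention_mask"], torch.ones_like(next_token)], dim=-1)
out = model(
    input_ids=next_token,
    past_key_values=out.past_key_values,
    attention_mask=new_mask,
    use_cache=True,
)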
@@ -744,12 +639,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
        position_ids = torch.arange(key_length - query_length, key_length, dtype=torch.long, device=device)
        return position_ids.view(-1, query_length)

    @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
@@ -935,13 +824,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
        )


@add_start_docstrings(
    """
    The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    GPT_BIGCODE_START_DOCSTRING,
)
class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

@@ -985,12 +867,6 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
        )
        return model_inputs

    @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
@@ -1072,212 +948,3 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
        """
        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)


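The _reorder_cache line above applies the same index_select to every layer's cache so that, during beam search, each cache row follows the beam it now continues. A tiny self-contained illustration of that pattern on a dummy cache (shapes and values are made up for the example):

import torch

# One tensor per layer, batch/beam dimension first, as in the line above.
past_key_values = tuple(torch.arange(6.0).reshape(3, 2) + layer for layer in range(2))
beam_idx = torch.tensor([2, 0, 1])  # beam 0 continues from old beam 2, etc.

reordered = tuple(
    layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values
)
print(reordered[0])  # rows of layer 0 permuted into the order [2, 0, 1]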
@add_start_docstrings(
    """
    The GPTBigCode Model transformer with a sequence classification head on top (linear layer).

    [`GPTBigCodeForSequenceClassification`] uses the last token in order to do the classification, as other causal
    models (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    GPT_BIGCODE_START_DOCSTRING,
)
class GPTBigCodeForSequenceClassification(GPTBigCodePreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPTBigCodeModel(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        assert (
            self.config.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


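A compact, self-contained illustration of the last-token pooling described in the docstring and performed in the forward above; the shapes, pad_token_id, and values are toy examples, not taken from the diff:

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0], [8, 9, 10, 11, 12]])  # right-padded batch
logits = torch.randn(2, 5, 3)  # per-token scores, as produced by self.score

# Position of the last non-padding token in each row, as in the forward above.
sequence_lengths = (torch.ne(input_ids, pad_token_id).sum(-1) - 1).to(logits.device)
pooled_logits = logits[torch.arange(2, device=logits.device), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3]) -> one logit vector per sequence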
@add_start_docstrings(
    """
    GPT_BIGCODE Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    GPT_BIGCODE_START_DOCSTRING,
)
class GPTBigCodeForTokenClassification(GPTBigCodePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = GPTBigCodeModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GPT_BIGCODE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1).to(logits.device))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
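For reference, the flattened cross-entropy used by the token classification forward above behaves like this on toy data; shapes and label values are invented for the example, and positions labelled -100 are skipped because CrossEntropyLoss defaults to ignore_index=-100:

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 5, 3)  # batch of 2 sequences, 5 tokens, 3 entity labels
labels = torch.tensor([[0, 1, 2, -100, -100], [2, 2, 0, 1, -100]])

# Same flattening as logits.view(-1, self.num_labels) / labels.view(-1) above.
loss = CrossEntropyLoss()(logits.view(-1, 3), labels.view(-1))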