Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-21 23:12:07 +00:00)
fix(server): fix OPT implementation (#2061)

parent 99c947452d
commit e85e7ac4f9
@@ -792,7 +792,7 @@ class OPTForCausalLM(OPTPreTrainedModel):
             return_dict=return_dict,
         )
 
-        logits, speculative_logits = self.lm_head(outputs)
+        logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
 
         loss = None
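The hunk above passes the hidden-state tensor to the head instead of the whole model output object. As a minimal sketch of why that matters (ToySpeculativeHead below is hypothetical, not TGI's actual SpeculativeHead; the only assumption is that the head is a projection over a [batch, seq_len, hidden_size] tensor that also returns an optional speculative-logits value):

import torch
import torch.nn as nn


class ToySpeculativeHead(nn.Module):
    """Toy head: projects hidden states to vocab logits, second return is None."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states must be a tensor; passing a ModelOutput object here
        # (as the pre-fix code effectively did) would fail inside the matmul.
        logits = self.proj(hidden_states)
        speculative_logits = None  # no speculator attached in this sketch
        return logits, speculative_logits


head = ToySpeculativeHead(hidden_size=16, vocab_size=32)
hidden = torch.randn(1, 4, 16)  # stands in for outputs.last_hidden_state
logits, speculative_logits = head(hidden)
print(logits.shape)  # torch.Size([1, 4, 32])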
@@ -85,5 +85,4 @@ class GPTNeoxSharded(CausalLM):
             use_cache=True,
         )
 
-        logits = outputs.logits
-        return logits, speculative_logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values
@@ -75,11 +75,11 @@ class OPTSharded(CausalLM):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=True,
         )
 
-        return outputs.logits, outputs.past_key_values
+        return outputs.logits, speculative_logits, outputs.past_key_values
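After this change, forward() hands back a (logits, speculative_logits, past_key_values) triple rather than a pair. A minimal sketch of the caller-side contract, assuming speculative_logits is simply None when no speculator is loaded (next_token_from_forward is a hypothetical helper, not part of text-generation-inference):

from typing import Optional, Tuple

import torch


def next_token_from_forward(
    logits: torch.Tensor,
    speculative_logits: Optional[torch.Tensor],
    past_key_values,
) -> Tuple[torch.Tensor, object]:
    # Greedy pick from the last position; a real server samples instead.
    next_token = logits[:, -1, :].argmax(dim=-1)
    if speculative_logits is not None:
        # A speculator would propose extra draft tokens here; this sketch
        # only acknowledges that the tensor may exist.
        pass
    return next_token, past_key_values


logits = torch.randn(1, 5, 32000)
token, cache = next_token_from_forward(logits, None, past_key_values=None)
print(token.shape)  # torch.Size([1])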
@@ -71,11 +71,13 @@ class RW(CausalLM):
 
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+    ):
         # Model Forward
-        outputs = self.model.forward(
+        outputs, speculative_logits = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
+            use_cache=True,
         )
-        return outputs.logits, outputs.past_key_values
+
+        return outputs.logits, speculative_logits, outputs.past_key_values
Loading…
Reference in New Issue
Block a user