Fix gpt_bigcode/starcoderbase-3b accuracy issue (#228)
Co-authored-by: Thanaji Rao Thakkalapelli <tthakkalapelli@habana.ai>
This commit is contained in: parent fe8a373831, commit 0578bd917d
@@ -653,7 +653,10 @@ class CausalLM(Model):
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
-        model = remove_kv_cache_from_output(model)
+        if model.config.model_type not in ["gpt_bigcode"]:  # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
+            model = remove_kv_cache_from_output(model)
+
         if self.enable_hpu_graph:
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
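
The new gate keys off config.model_type. As a quick sanity check (assuming transformers is installed and the Hugging Face Hub is reachable), StarCoder checkpoints do report gpt_bigcode:

# Sanity check that the StarCoder family reports the "gpt_bigcode" model type.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bigcode/starcoderbase-3b")
print(config.model_type)  # expected: "gpt_bigcode"
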
@@ -866,7 +869,7 @@ class CausalLM(Model):
             kwargs["bypass_hpu_graphs"] = bypass_hpu_graph

         kwargs.update(self.kwargs)
-        if past_key_values is not None:
+        if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]:
             return self.model.forward(**kwargs)
         else:
             outputs = self.model.forward(**kwargs)
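
Because remove_kv_cache_from_output() is now skipped for gpt_bigcode, its forward keeps returning a cache next to the logits, which is why the early-return path above must be avoided for this model type. A small hedged illustration with a stock gpt_bigcode checkpoint (tiny_starcoder_py is used only to keep the example small; it is not part of this repo):

# Hedged illustration: a stock gpt_bigcode forward returns past_key_values
# alongside logits when use_cache=True, i.e. the cache that
# remove_kv_cache_from_output() would otherwise strip from the output.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigcode/tiny_starcoder_py")
out = model(input_ids=torch.tensor([[1, 2, 3]]), use_cache=True)
print(out.logits.shape, out.past_key_values is not None)  # cache is kept
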
@@ -950,7 +953,7 @@ class CausalLM(Model):
         else:
             batch.position_ids += 1
         # Update past key values
-        if prefill:
+        if prefill or self.model.config.model_type in ["gpt_bigcode"]:
             batch.past_key_values = past

         htorch.core.mark_step()
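
For gpt_bigcode the cache coming back from the model must now be stored on every decode step, not only after prefill. The underlying pattern is the usual transformers decode loop, sketched here outside the server code (the checkpoint name and greedy sampling are illustrative choices, not what this repo does):

# Hedged stand-alone sketch of the per-step cache refresh, using plain
# transformers rather than the CausalLM class from this repo.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bigcode/tiny_starcoder_py")
model = AutoModelForCausalLM.from_pretrained("bigcode/tiny_starcoder_py")

inputs = tok("def hello():", return_tensors="pt")
out = model(**inputs, use_cache=True)                      # prefill
past = out.past_key_values
next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
for _ in range(5):                                         # decode steps
    out = model(input_ids=next_id, past_key_values=past, use_cache=True)
    past = out.past_key_values                             # refresh the cache every step
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
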
@@ -994,7 +997,7 @@ class CausalLM(Model):
         else:
             token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
             input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
-        batch.logits = self.forward(
+        logits = self.forward(
             input_ids,
             batch.attention_mask,
             batch.position_ids,
@@ -1002,6 +1005,10 @@ class CausalLM(Model):
             batch.past_key_values,
             bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
         )
+        if self.model.config.model_type in ["gpt_bigcode"]:
+            batch.logits, batch.past = logits
+        else:
+            batch.logits = logits

         htorch.core.mark_step()
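
A side note on the unchanged bypass_hpu_graph context line above: it is a conditional expression, and Python's precedence groups the `and` before the `if`/`else`, so graphs are only bypassed on prefill when LIMIT_HPU_GRAPH is enabled. A tiny self-contained check (variable names mirror the flags, the values are arbitrary):

# Precedence check: `a and b if c else d` parses as `(a and b) if c else d`.
prefill, limit_hpu_graph, enable_hpu_graph = True, False, True
value = prefill and limit_hpu_graph if enable_hpu_graph else None
assert value == ((prefill and limit_hpu_graph) if enable_hpu_graph else None)
print(value)  # False unless LIMIT_HPU_GRAPH=true
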