Fix gpt_bigcode/starcoderbase-3b accuracy issue (#228)

Co-authored-by: Thanaji Rao Thakkalapelli <tthakkalapelli@habana.ai>
This commit is contained in:
Sun Choi 2024-10-14 01:01:55 -07:00 committed by GitHub
parent fe8a373831
commit 0578bd917d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -653,7 +653,10 @@ class CausalLM(Model):
self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
if model.config.model_type not in ["gpt_bigcode"]: # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
model = remove_kv_cache_from_output(model)
if self.enable_hpu_graph:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@ -866,7 +869,7 @@ class CausalLM(Model):
kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
kwargs.update(self.kwargs)
if past_key_values is not None:
if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]:
return self.model.forward(**kwargs)
else:
outputs = self.model.forward(**kwargs)
@ -950,7 +953,7 @@ class CausalLM(Model):
else:
batch.position_ids += 1
# Update past key values
if prefill:
if prefill or self.model.config.model_type in ["gpt_bigcode"]:
batch.past_key_values = past
htorch.core.mark_step()
@ -994,7 +997,7 @@ class CausalLM(Model):
else:
token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
batch.logits = self.forward(
logits = self.forward(
input_ids,
batch.attention_mask,
batch.position_ids,
@ -1002,6 +1005,10 @@ class CausalLM(Model):
batch.past_key_values,
bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
)
if self.model.config.model_type in ["gpt_bigcode"]:
batch.logits, batch.past = logits
else:
batch.logits = logits
htorch.core.mark_step()