diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index f814a8e2..22d7a71e 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -653,7 +653,10 @@ class CausalLM(Model):
 
         self.enable_hpu_graph = os.getenv("ENABLE_HPU_GRAPH", "true").lower() == "true" and LAZY_MODE == 1
         self.limit_hpu_graph = os.getenv("LIMIT_HPU_GRAPH", "false").lower() == "true"
-        model = remove_kv_cache_from_output(model)
+
+        if model.config.model_type not in ["gpt_bigcode"]:  # gpt_bigcode/starcoderbase-3b skips remove_kv_cache_from_output()
+            model = remove_kv_cache_from_output(model)
+
         if self.enable_hpu_graph:
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
             model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
@@ -866,7 +869,7 @@ class CausalLM(Model):
             kwargs["bypass_hpu_graphs"] = bypass_hpu_graph
 
         kwargs.update(self.kwargs)
-        if past_key_values is not None:
+        if past_key_values is not None and self.model.config.model_type not in ["gpt_bigcode"]:
             return self.model.forward(**kwargs)
         else:
             outputs = self.model.forward(**kwargs)
@@ -950,7 +953,7 @@ class CausalLM(Model):
         else:
             batch.position_ids += 1
         # Update past key values
-        if prefill:
+        if prefill or self.model.config.model_type in ["gpt_bigcode"]:
             batch.past_key_values = past
 
         htorch.core.mark_step()
@@ -994,7 +997,7 @@ class CausalLM(Model):
         else:
             token_idx = torch.tensor(batch.attention_mask.shape[-1] - batch.right_padding).to(self.device)
             input_ids = torch.index_select(batch.input_ids, 1, token_idx - 1)
-        batch.logits = self.forward(
+        logits = self.forward(
             input_ids,
             batch.attention_mask,
             batch.position_ids,
@@ -1002,6 +1005,10 @@
             batch.past_key_values,
             bypass_hpu_graph=prefill and self.limit_hpu_graph if self.enable_hpu_graph else None,
         )
+        if self.model.config.model_type in ["gpt_bigcode"]:
+            batch.logits, batch.past = logits
+        else:
+            batch.logits = logits
 
         htorch.core.mark_step()
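For readers tracing the control flow: because `remove_kv_cache_from_output()` is skipped for `gpt_bigcode`, the patched `CausalLM.forward()` always falls through to the `outputs = self.model.forward(**kwargs)` branch, presumably returning a `(logits, past_key_values)` pair on every step, which is what the tuple unpack in `generate_token()` implies. Below is a minimal sketch of that dispatch; the `fake_forward()` helper and the `SimpleNamespace` batch are illustrative stand-ins, not code from the repo:

```python
from types import SimpleNamespace

import torch


def fake_forward(model_type: str, hidden: torch.Tensor):
    """Illustrative stand-in for CausalLM.forward(): gpt_bigcode keeps the
    KV cache in the model output, so the caller receives (logits, past)."""
    logits = hidden @ torch.ones(hidden.shape[-1], 8)  # stand-in logits
    past = [(hidden.clone(), hidden.clone())]          # stand-in KV cache
    if model_type in ["gpt_bigcode"]:
        return logits, past  # cache was NOT stripped from the output
    return logits            # cache stripped by remove_kv_cache_from_output()


batch = SimpleNamespace(logits=None, past=None)
model_type = "gpt_bigcode"

out = fake_forward(model_type, torch.ones(2, 4))
if model_type in ["gpt_bigcode"]:
    batch.logits, batch.past = out  # same tuple unpack the patch adds
else:
    batch.logits = out
```

The companion change to the `if prefill` guard follows from the same fact: with the cache riding along in the model output on every step, `batch.past_key_values` has to be refreshed during decode as well, not only after prefill.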