From 9b71343328e386426937cccb4f6a525b7e340134 Mon Sep 17 00:00:00 2001
From: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
Date: Thu, 8 Aug 2024 01:36:05 +0530
Subject: [PATCH] Integrate flash attention for starcoder2 tgi through habana and some fixes, enabling (#198)

---
 launcher/src/env_runtime.rs                        | 13 ++++++++++++-
 server/text_generation_server/models/causal_lm.py  |  8 +++++---
 .../text_generation_server/models/flash_mistral.py |  2 +-
 3 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs
index 08fb301c..82562555 100644
--- a/launcher/src/env_runtime.rs
+++ b/launcher/src/env_runtime.rs
@@ -8,16 +8,19 @@ pub(crate) struct Env {
     docker_label: &'static str,
     nvidia_env: String,
     xpu_env: String,
+    hpu_env: String,
 }

 impl Env {
     pub fn new() -> Self {
         let nvidia_env = nvidia_smi();
         let xpu_env = xpu_smi();
+        let hpu_env = hl_smi();

         Self {
             nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
             xpu_env: xpu_env.unwrap_or("N/A".to_string()),
+            hpu_env: hpu_env.unwrap_or("N/A".to_string()),
             cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
             cargo_version: env!("VERGEN_RUSTC_SEMVER"),
             git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -35,7 +38,8 @@ impl fmt::Display for Env {
         writeln!(f, "Commit sha: {}", self.git_sha)?;
         writeln!(f, "Docker label: {}", self.docker_label)?;
         writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
-        write!(f, "xpu-smi:\n{}", self.xpu_env)?;
+        writeln!(f, "xpu-smi:\n{}", self.xpu_env)?;
+        write!(f, "hpu-smi:\n{}", self.hpu_env)?;

         Ok(())
     }
@@ -54,3 +58,10 @@ fn xpu_smi() -> Option<String> {
     let output = xpu_smi.replace('\n', "\n ");
     Some(output.trim().to_string())
 }
+
+fn hl_smi() -> Option<String> {
+    let output = Command::new("hl-smi").output().ok()?;
+    let hl_smi = String::from_utf8(output.stdout).ok()?;
+    let output = hl_smi.replace('\n', "\n ");
+    Some(output.trim().to_string())
+}
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index 012f6249..6402f385 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -684,9 +684,11 @@ class CausalLM(Model):
             "return_dict": True,
         }

-        if model.config.model_type in ["llama", "mistral"]:
-            kwargs["attn_softmax_bf16"] = True
-            kwargs["trim_logits"] = True
+        if model.config.model_type in ["llama", "mistral", "starcoder2"]:
+
+            if model.config.model_type in ["llama", "mistral"]:
+                kwargs["attn_softmax_bf16"] = True
+                kwargs["trim_logits"] = True

             if os.getenv("USE_FLASH_ATTENTION", "false").lower() == "true":
                 kwargs["use_flash_attention"] = True
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 6959e2ec..88b9fe63 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -10,7 +10,7 @@ from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, Type

 from text_generation_server.pb import generate_pb2
-from text_generation_server.models import FlashCausalLM
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
 from text_generation_server.models.cache_manager import (
     get_cache_manager,