Integrate flash attention for starcoder2 TGI through Habana, plus some enabling fixes (#198)

Abhilash Majumder, 2024-08-08 01:36:05 +05:30 (committed by GitHub)
parent 0ca54b55f8
commit 9b71343328
3 changed files with 18 additions and 5 deletions

View File

@@ -8,16 +8,19 @@ pub(crate) struct Env {
     docker_label: &'static str,
     nvidia_env: String,
     xpu_env: String,
+    hpu_env: String,
 }

 impl Env {
     pub fn new() -> Self {
         let nvidia_env = nvidia_smi();
         let xpu_env = xpu_smi();
+        let hpu_env = hl_smi();

         Self {
             nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
             xpu_env: xpu_env.unwrap_or("N/A".to_string()),
+            hpu_env: hpu_env.unwrap_or("N/A".to_string()),
             cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
             cargo_version: env!("VERGEN_RUSTC_SEMVER"),
             git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -35,7 +38,8 @@ impl fmt::Display for Env {
         writeln!(f, "Commit sha: {}", self.git_sha)?;
         writeln!(f, "Docker label: {}", self.docker_label)?;
         writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
-        write!(f, "xpu-smi:\n{}", self.xpu_env)?;
+        writeln!(f, "xpu-smi:\n{}", self.xpu_env)?;
+        write!(f, "hpu-smi:\n{}", self.hpu_env)?;

         Ok(())
     }
@@ -54,3 +58,10 @@ fn xpu_smi() -> Option<String> {
     let output = xpu_smi.replace('\n', "\n ");
     Some(output.trim().to_string())
 }
+
+fn hl_smi() -> Option<String> {
+    let output = Command::new("hl-smi").output().ok()?;
+    let hl_smi = String::from_utf8(output.stdout).ok()?;
+    let output = hl_smi.replace('\n', "\n ");
+    Some(output.trim().to_string())
+}
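
As a quick way to experiment with the same probe-and-fallback behavior outside the launcher, here is a minimal Python sketch; the script and the hl_smi_report name are illustrative only, and the actual change is the Rust hl_smi() above. It shells out to Habana's hl-smi CLI, indents continuation lines the same way, and falls back to "N/A" when the tool is missing or exits nonzero:

    import subprocess

    def hl_smi_report() -> str:
        # Run Habana's device-status CLI and capture stdout as text.
        try:
            out = subprocess.run(
                ["hl-smi"], capture_output=True, text=True, check=True
            ).stdout
        except (FileNotFoundError, subprocess.CalledProcessError):
            # Same fallback the launcher prints when the probe yields None.
            return "N/A"
        # Indent continuation lines, mirroring replace('\n', "\n ") above.
        return out.replace("\n", "\n ").strip()

    print(f"hpu-smi:\n{hl_smi_report()}")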

View File

@@ -684,6 +684,8 @@ class CausalLM(Model):
             "return_dict": True,
         }
-        if model.config.model_type in ["llama", "mistral"]:
-            kwargs["attn_softmax_bf16"] = True
-            kwargs["trim_logits"] = True
+        if model.config.model_type in ["llama", "mistral", "starcoder2"]:
+            if model.config.model_type in ["llama", "mistral"]:
+                kwargs["attn_softmax_bf16"] = True
+            kwargs["trim_logits"] = True

View File

@@ -10,7 +10,7 @@ from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
 from typing import Optional, Tuple, Type

 from text_generation_server.pb import generate_pb2
-from text_generation_server.models import FlashCausalLM
+from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
 from text_generation_server.models.cache_manager import (
     get_cache_manager,