mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-05-21 17:52:09 +00:00
Integrate flash attention for starcoder2 tgi through habana and some fixes, enabling (#198)
This commit is contained in:
parent
0ca54b55f8
commit
9b71343328
@ -8,16 +8,19 @@ pub(crate) struct Env {
|
||||
docker_label: &'static str,
|
||||
nvidia_env: String,
|
||||
xpu_env: String,
|
||||
hpu_env: String,
|
||||
}
|
||||
|
||||
impl Env {
|
||||
pub fn new() -> Self {
|
||||
let nvidia_env = nvidia_smi();
|
||||
let xpu_env = xpu_smi();
|
||||
let hpu_env = hl_smi();
|
||||
|
||||
Self {
|
||||
nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
|
||||
xpu_env: xpu_env.unwrap_or("N/A".to_string()),
|
||||
hpu_env: hpu_env.unwrap_or("N/A".to_string()),
|
||||
cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
|
||||
cargo_version: env!("VERGEN_RUSTC_SEMVER"),
|
||||
git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
|
||||
@ -35,7 +38,8 @@ impl fmt::Display for Env {
|
||||
writeln!(f, "Commit sha: {}", self.git_sha)?;
|
||||
writeln!(f, "Docker label: {}", self.docker_label)?;
|
||||
writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
|
||||
write!(f, "xpu-smi:\n{}", self.xpu_env)?;
|
||||
writeln!(f, "xpu-smi:\n{}", self.xpu_env)?;
|
||||
write!(f, "hpu-smi:\n{}", self.hpu_env)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -54,3 +58,10 @@ fn xpu_smi() -> Option<String> {
|
||||
let output = xpu_smi.replace('\n', "\n ");
|
||||
Some(output.trim().to_string())
|
||||
}
|
||||
|
||||
fn hl_smi() -> Option<String> {
|
||||
let output = Command::new("hl-smi").output().ok()?;
|
||||
let hl_smi = String::from_utf8(output.stdout).ok()?;
|
||||
let output = hl_smi.replace('\n', "\n ");
|
||||
Some(output.trim().to_string())
|
||||
}
|
||||
|
@ -684,6 +684,8 @@ class CausalLM(Model):
|
||||
"return_dict": True,
|
||||
}
|
||||
|
||||
if model.config.model_type in ["llama", "mistral", "starcoder2"]:
|
||||
|
||||
if model.config.model_type in ["llama", "mistral"]:
|
||||
kwargs["attn_softmax_bf16"] = True
|
||||
kwargs["trim_logits"] = True
|
||||
|
@ -10,7 +10,7 @@ from transformers import PreTrainedTokenizerBase, AutoTokenizer, AutoConfig
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
|
Loading…
Reference in New Issue
Block a user