This commit is contained in:
OlivierDehaene 2023-06-28 19:26:26 +02:00
parent ae466a8736
commit d649cd8e02
14 changed files with 469 additions and 394 deletions

View File

@ -115,12 +115,6 @@ struct Args {
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
/// The maximum allowed batch size during dynamic batching.
/// Using `max_batch_total_tokens` should be favored in general
/// as it's a finer way to control RAM usage.
#[clap(long, env)]
max_batch_size: Option<usize>,
/// This represents the ratio of waiting queries vs running queries where
/// you want to start considering pausing the running queries to include the waiting
/// ones into the same batch.
@ -134,6 +128,9 @@ struct Args {
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)]
max_batch_prefill_tokens: u32,
/// **IMPORTANT** This is one critical control to allow maximum usage
/// of the available hardware.
///
@ -181,7 +178,6 @@ struct Args {
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
#[clap(default_value = "3000", long, short, env)]
/// The port to listen on.
port: u16,
@ -329,6 +325,12 @@ fn shard_manager(
// Copy current process env
let mut env: Vec<(OsString, OsString)> = env::vars_os().collect();
// Use cuda allocator. It leads to less memory fragmentation
env.push((
"PYTORCH_CUDA_ALLOC_CONF".into(),
"backend:cudaMallocAsync".into(),
));
// Torch Distributed Env vars
env.push(("RANK".into(), rank.to_string().into()));
env.push(("WORLD_SIZE".into(), world_size.to_string().into()));
@ -822,6 +824,10 @@ fn spawn_webserver(
args.max_input_length.to_string(),
"--max-total-tokens".to_string(),
args.max_total_tokens.to_string(),
"--max-batch-prefill-tokens".to_string(),
args.max_batch_prefill_tokens.to_string(),
"--max-batch-total-tokens".to_string(),
args.max_batch_total_tokens.to_string(),
"--waiting-served-ratio".to_string(),
args.waiting_served_ratio.to_string(),
"--max-waiting-tokens".to_string(),
@ -834,15 +840,6 @@ fn spawn_webserver(
args.model_id,
];
// Deprecate max_batch_size
if let Some(max_batch_size) = args.max_batch_size {
argv.push("--max-batch-size".to_string());
argv.push(max_batch_size.to_string())
} else {
argv.push("--max-batch-total-tokens".to_string());
argv.push(args.max_batch_total_tokens.to_string())
}
// Model optional revision
if let Some(ref revision) = args.revision {
argv.push("--revision".to_string());

View File

@ -45,6 +45,7 @@ impl Infer {
client: ShardedClient,
validation: Validation,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
max_concurrent_requests: usize,
@ -61,6 +62,7 @@ impl Infer {
tokio::spawn(batching_task(
client,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
queue.clone(),
@ -243,6 +245,7 @@ impl Infer {
async fn batching_task(
mut client: ShardedClient,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
queue: Queue,
@ -257,8 +260,9 @@ async fn batching_task(
// Get the next batch from the queue
// This batch might be smaller than the maximum batch size if there are not enough requests
// waiting in the queue
while let Some((mut entries, batch, span)) =
queue.next_batch(None, max_batch_total_tokens).await
while let Some((mut entries, batch, span)) = queue
.next_batch(None, max_batch_prefill_tokens, max_batch_total_tokens)
.await
{
let mut cached_batch = prefill(&mut client, batch, &mut entries, &generation_health)
.instrument(span)
@ -287,8 +291,9 @@ async fn batching_task(
let token_budget = max_batch_total_tokens - batch_max_tokens;
// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) =
queue.next_batch(min_size, token_budget).await
if let Some((mut new_entries, new_batch, span)) = queue
.next_batch(min_size, max_batch_prefill_tokens, token_budget)
.await
{
// Tracking metrics
if min_size.is_some() {

View File

@ -32,11 +32,11 @@ struct Args {
max_input_length: usize,
#[clap(default_value = "1512", long, env)]
max_total_tokens: usize,
#[clap(long, env)]
max_batch_size: Option<usize>,
#[clap(default_value = "1.2", long, env)]
waiting_served_ratio: f32,
#[clap(default_value = "32000", long, env)]
max_batch_prefill_tokens: u32,
#[clap(default_value = "32000", long, env)]
max_batch_total_tokens: u32,
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,
@ -78,9 +78,9 @@ fn main() -> Result<(), std::io::Error> {
max_stop_sequences,
max_input_length,
max_total_tokens,
max_batch_size,
waiting_served_ratio,
mut max_batch_total_tokens,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
port,
master_shard_uds_path,
@ -141,12 +141,6 @@ fn main() -> Result<(), std::io::Error> {
.block_on(async {
init_logging(otlp_endpoint, json_output);
if let Some(max_batch_size) = max_batch_size {
tracing::warn!("`max-batch-size` is deprecated. Use `max-batch-total-tokens` instead");
max_batch_total_tokens = (max_batch_size * max_total_tokens) as u32;
tracing::warn!("Overriding `max-batch-total-tokens` value with `max-batch-size` * `max-total-tokens` = {max_batch_total_tokens}");
}
if tokenizer.is_none() {
tracing::warn!(
"Could not find a fast tokenizer implementation for {tokenizer_name}"
@ -161,9 +155,15 @@ fn main() -> Result<(), std::io::Error> {
sha: None,
pipeline_tag: None,
},
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
false => get_model_info(&tokenizer_name, &revision, authorization_token)
.await
.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
HubModelInfo {
model_id: tokenizer_name.to_string(),
sha: None,
pipeline_tag: None,
}
}),
};
@ -206,6 +206,7 @@ fn main() -> Result<(), std::io::Error> {
max_input_length,
max_total_tokens,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
sharded_client,

View File

@ -58,6 +58,7 @@ impl Queue {
pub(crate) async fn next_batch(
&self,
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
// Create response channel
@ -67,6 +68,7 @@ impl Queue {
self.queue_sender
.send(QueueCommand::NextBatch {
min_size,
prefill_token_budget,
token_budget,
response_sender,
span: Span::current(),
@ -90,11 +92,12 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
}
QueueCommand::NextBatch {
min_size,
prefill_token_budget,
token_budget,
response_sender,
span,
} => span.in_scope(|| {
let next_batch = state.next_batch(min_size, token_budget);
let next_batch = state.next_batch(min_size, prefill_token_budget, token_budget);
response_sender.send(next_batch).unwrap();
metrics::gauge!("tgi_queue_size", state.entries.len() as f64);
}),
@ -140,7 +143,12 @@ impl State {
}
// Get the next batch
fn next_batch(&mut self, min_size: Option<usize>, token_budget: u32) -> Option<NextBatch> {
fn next_batch(
&mut self,
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
) -> Option<NextBatch> {
if self.entries.is_empty() {
return None;
}
@ -184,7 +192,9 @@ impl State {
decode_tokens += entry.request.stopping_parameters.max_new_tokens;
if (prefill_tokens + decode_tokens) > token_budget {
if prefill_tokens > prefill_token_budget
|| (prefill_tokens + decode_tokens) > token_budget
{
// Entry is over budget
// Add it back to the front
self.entries.push_front((id, entry));
@ -259,6 +269,7 @@ enum QueueCommand {
Append(Box<Entry>, Span),
NextBatch {
min_size: Option<usize>,
prefill_token_budget: u32,
token_budget: u32,
response_sender: oneshot::Sender<Option<NextBatch>>,
span: Span,
@ -294,7 +305,7 @@ mod tests {
watermark: false,
},
stopping_parameters: StoppingCriteriaParameters {
ignore_eos_token: false,
ignore_eos_token: true,
max_new_tokens: 1,
stop_sequences: vec![],
},

View File

@ -152,7 +152,7 @@ async fn generate(
let start_time = Instant::now();
metrics::increment_counter!("tgi_request_count");
tracing::debug!("Input: {}", req.0.inputs);
// tracing::debug!("Input: {}", req.0.inputs);
let compute_characters = req.0.inputs.chars().count();
let mut add_prompt = None;
@ -286,7 +286,7 @@ async fn generate(
}
tracing::debug!("Output: {}", output_text);
tracing::info!("Success");
// tracing::info!("Success");
let response = GenerateResponse {
generated_text: output_text,
@ -513,6 +513,7 @@ pub async fn run(
max_input_length: usize,
max_total_tokens: usize,
waiting_served_ratio: f32,
max_batch_prefill_tokens: u32,
max_batch_total_tokens: u32,
max_waiting_tokens: usize,
client: ShardedClient,
@ -581,6 +582,7 @@ pub async fn run(
client,
validation,
waiting_served_ratio,
max_batch_prefill_tokens,
max_batch_total_tokens,
max_waiting_tokens,
max_concurrent_requests,

View File

@ -19,10 +19,12 @@ class Cache:
def delete(self, batch_id: int):
batch = self.pop(batch_id)
if batch is not None:
batch.cleanup()
del batch
def clear(self):
self.cache.clear()
for k in self.cache.keys():
self.delete(k)
def __len__(self):
return len(self.cache.keys())

View File

@ -23,7 +23,9 @@ import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
from typing import Optional, List, Tuple
from vllm import attention_ops
from vllm import cache_ops
# Flash attention imports
import flash_attn_cuda
@ -106,7 +108,7 @@ class FlashLlamaAttention(torch.nn.Module):
prefix=f"{prefix}.rotary_emb", weights=weights
)
self.softmax_scale = self.head_size ** (-0.5)
self.softmax_scale = self.head_size**-0.5
self.num_heads = self.num_heads // weights.process_group.size()
self.query_key_value = TensorParallelColumnLinear.load_multi(
@ -128,14 +130,13 @@ class FlashLlamaAttention(torch.nn.Module):
hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -144,23 +145,25 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(qkv[:, 0], cos, sin)
self.rotary_emb(qkv[:, 1], cos, sin)
# Prefill
if prefill:
# Copy to layer past
layer_past[...] = qkv[:, 1:]
cache_ops.reshape_and_cache(
qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots
)
# output
# output tensor
attn_output = torch.empty_like(qkv[:, 0])
# Prefill
if start_seq_prefill is not None:
# flash attention
flash_attn_cuda.fwd(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
attn_output,
start_seq,
end_seq,
start_seq,
end_seq,
start_seq_prefill,
end_seq_prefill,
start_seq_prefill,
end_seq_prefill,
max_s,
max_s,
0.0,
@ -173,31 +176,18 @@ class FlashLlamaAttention(torch.nn.Module):
)
# Decode
else:
query = qkv[:, 0]
# Add present to the layer_past tensor at the correct indices
layer_past[past_present_indices] = qkv[:, 1:]
# output
attn_output = torch.empty_like(query)
# flash attention
flash_attn_cuda.fwd(
query,
layer_past[:, 0],
layer_past[:, 1],
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
block_size = kv_cache[1].shape[3]
attention_ops.single_query_cached_kv_attention(
attn_output,
start_seq_q,
end_seq_q,
start_seq,
end_seq,
1,
max_s,
0.0,
qkv[:, 0],
kv_cache[0],
kv_cache[1],
self.softmax_scale,
False,
False,
False,
0,
None,
block_tables,
input_lengths,
block_size,
max_s,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -265,14 +255,13 @@ class FlashLlamaLayer(nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -281,14 +270,13 @@ class FlashLlamaLayer(nn.Module):
normed_hidden_states,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
layer_past,
past_present_indices,
prefill,
)
# faster post attention rms norm
@ -333,40 +321,18 @@ class FlashLlamaModel(torch.nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values=None,
pre_allocate_past_size: Optional[int] = None,
):
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Prefill
if past_key_values is None:
assert pre_allocate_past_size is not None
prefill = True
# Create past tensor
# We create a tensor of the same size as input_ids as we don't want to slice at every layer
past_key_values = hidden_states.new_empty(
(
len(input_ids),
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# Decode
else:
prefill = False
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
@ -380,34 +346,18 @@ class FlashLlamaModel(torch.nn.Module):
residual,
cos,
sin,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache[i],
block_tables,
slots,
input_lengths,
max_s,
past_key_values[:, i],
past_present_indices,
prefill,
)
if prefill:
present = past_key_values
# Create padded past tensor
past_key_values = hidden_states.new_empty(
(
pre_allocate_past_size,
len(self.layers),
2,
self.num_heads,
self.head_size,
)
)
# We slice only once instead of at every layer
past_key_values[past_present_indices] = present
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states, past_key_values
return hidden_states
class FlashLlamaForCausalLM(torch.nn.Module):
@ -423,31 +373,29 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def forward(
self,
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
max_s,
past_present_indices,
past_key_values: Optional[torch.Tensor] = None,
pre_allocate_past_size: Optional[int] = None,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
lm_head_indices: Optional[torch.Tensor] = None,
):
hidden_states, present = self.model(
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
position_ids,
start_seq,
end_seq,
start_seq_q,
end_seq_q,
start_seq_prefill,
end_seq_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
max_s,
past_present_indices,
past_key_values,
pre_allocate_past_size,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.lm_head(hidden_states)
return logits, present
return logits

View File

@ -1004,7 +1004,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
try:
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError:
self.shared = TensorParallelEmbedding(prefix="encoder.embed_tokens", weights=weights)
self.shared = TensorParallelEmbedding(
prefix="encoder.embed_tokens", weights=weights
)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False

View File

@ -1,3 +1,4 @@
import itertools
import torch
import torch.distributed
@ -5,7 +6,7 @@ import numpy as np
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
from text_generation_server.models import Model
@ -20,6 +21,66 @@ from text_generation_server.utils import StoppingCriteria, HeterogeneousNextToke
tracer = trace.get_tracer(__name__)
# Will be set in warmup
CACHE_MANAGER: Optional["CacheManager"] = None
class CacheManager:
def __init__(
self,
num_blocks: int,
num_layers: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
self.block_size = 16
element_size = torch.tensor([], dtype=dtype).element_size()
x = self.block_size // element_size
self.kv_cache = [
(
torch.empty(
(num_blocks, num_heads, head_size // x, self.block_size, x),
dtype=dtype,
device=device,
),
torch.empty(
(num_blocks, num_heads, head_size, self.block_size),
dtype=dtype,
device=device,
),
)
for _ in range(num_layers)
]
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
self.slots = torch.arange(
0, num_blocks * self.block_size, dtype=torch.int32
).view(num_blocks, self.block_size)
def allocate(self, n_tokens: int) -> Tuple[List[int], torch.Tensor]:
# Number of needed block to cover all tokens
needed_blocks = (n_tokens // self.block_size) + 1
# Get free blocks indices by finding values in mask that are not set to 0
free_block_indices = self.free_block_mask.nonzero()
assert len(free_block_indices) >= needed_blocks, "Out of available cache blocks"
# Allocate the required number of blocks by setting the mask to 0
block_indices = free_block_indices[:needed_blocks]
self.free_block_mask[block_indices] = 0
# Get slots for the allocated blocks
slots = self.slots[block_indices].flatten()[:n_tokens]
return block_indices.flatten().tolist(), slots
def free(self, block_indices: List[int]):
# Reset mask
self.free_block_mask[block_indices] = 1
@dataclass
class FlashCausalLMBatch(Batch):
@ -32,23 +93,20 @@ class FlashCausalLMBatch(Batch):
input_ids: torch.Tensor
position_ids: torch.Tensor
# Indices to copy present to the correct indices is the pre-allocated past key values
past_present_indices: torch.Tensor
# tensor of length b holding starting offset of each sequence
start_seq: torch.Tensor
# tensor of length b holding ending offset of each sequence
end_seq: torch.Tensor
# tensor of length b holding starting offset of each sequence, only used in prefill
start_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding ending offset of each sequence, only used in prefill
end_seq_prefill: Optional[torch.Tensor]
# tensor of length b holding starting offset of each query sequence, only used in decode
start_seq_q: Optional[torch.Tensor]
# tensor of length b holding ending offset of each query sequence, only used in decode
end_seq_q: Optional[torch.Tensor]
# past key values, only used in decode
past_key_values: Optional[torch.Tensor]
# list of length b of list of length s_i // block_size
block_tables: List[List[int]]
# tensor of size [b, max_seqlen // block_size] holding the paged attention block tables for all sequences
block_tables_tensor: torch.Tensor
# CPU tensor of length b indicating the start of each sequence in slots
start_slots: torch.Tensor
# tensor of length \sum_{i=0}^{b} max_s_i holding the paged attention slots for all sequences
slots: torch.Tensor
# tensor of indices of the currently used slots, length = \sum_{i=0}^{b} s_i in prefill, length = b in decode
slot_indices: torch.Tensor
max_seqlen: int
# Prefill metadata tensors to efficiently compute logprobs
@ -62,6 +120,7 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
input_lengths_tensor: torch.Tensor
prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]]
@ -69,15 +128,16 @@ class FlashCausalLMBatch(Batch):
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
# Maximum number of tokens this batch will grow to
max_tokens: int
# Maximum number of blocks
max_blocks: int
def to_pb(self) -> generate_pb2.CachedBatch:
global CACHE_MANAGER
return generate_pb2.CachedBatch(
id=self.batch_id,
request_ids=[r.id for r in self.requests],
size=len(self),
max_tokens=self.max_tokens,
max_tokens=len(self.slots),
)
@classmethod
@ -88,6 +148,8 @@ class FlashCausalLMBatch(Batch):
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
global CACHE_MANAGER
batch_inputs = []
max_truncation = 0
for r in pb.requests:
@ -99,12 +161,12 @@ class FlashCausalLMBatch(Batch):
)["input_ids"]
position_ids = []
past_present_indices = []
start_seq = []
end_seq = []
start_seq_prefill = []
end_seq_prefill = []
max_seqlen = 0
block_tables = []
start_slots = []
slots = []
slot_indices = []
input_lengths = []
prefix_offsets = []
@ -126,7 +188,9 @@ class FlashCausalLMBatch(Batch):
cumulative_max_length = 0
prefill_out_cumulative_length = 0
max_seqlen = 0
max_length = 0
max_blocks = 0
# Parse batch
for i, (r, tokenized_input) in enumerate(
@ -138,7 +202,6 @@ class FlashCausalLMBatch(Batch):
tokenized_input = tokenized_input[-r.truncate :]
input_length = len(tokenized_input)
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
prefix_offsets.append(input_length - 5)
@ -153,8 +216,6 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs
start_seq_prefill.append(cumulative_length)
end_seq_prefill.append(cumulative_length + input_length)
start_seq.append(cumulative_max_length)
end_seq.append(cumulative_max_length + input_length)
next_token_chooser_parameters.append(r.parameters)
@ -164,6 +225,21 @@ class FlashCausalLMBatch(Batch):
max_new_tokens = stopping_criteria.max_new_tokens
stopping_criterias.append(stopping_criteria)
# Paged attention
# Remove one as the first token des not have a past
total_tokens = input_length + max_new_tokens - 1
request_blocks, request_slots = CACHE_MANAGER.allocate(total_tokens)
block_tables.append(request_blocks)
slots.extend(request_slots)
start_slots.append(cumulative_max_length)
request_slot_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
slot_indices.append(request_slot_indices)
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
@ -184,22 +260,26 @@ class FlashCausalLMBatch(Batch):
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
prefill_out_cumulative_length += 1
request_past_present_indices = torch.arange(
cumulative_max_length,
cumulative_max_length + input_length,
dtype=torch.int64,
)
past_present_indices.append(request_past_present_indices)
# Update
# Remove one as the first token des not have a past
cumulative_length += input_length
cumulative_max_length += input_length + max_new_tokens - 1
cumulative_max_length += total_tokens
max_seqlen = max(max_seqlen, input_length)
max_blocks = max(max_blocks, len(request_blocks))
max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
)
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Padded block tables
block_tables_tensor = torch.zeros(
(len(pb.requests), max_blocks), dtype=torch.int32
)
for i, request_blocks in enumerate(block_tables):
block_tables_tensor[i, : len(request_blocks)] = torch.tensor(request_blocks)
block_tables_tensor = block_tables_tensor.to(device)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
@ -212,14 +292,15 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
start_seq = torch.tensor(start_seq, device=device, dtype=torch.int32)
end_seq = torch.tensor(end_seq, device=device, dtype=torch.int32)
if len(pb.requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
position_ids = torch.cat(position_ids)
past_present_indices = np.concatenate(past_present_indices, dtype=np.int64)
slot_indices = torch.cat(slot_indices)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
slot_indices = slot_indices[0]
start_seq_prefill = torch.tensor(
start_seq_prefill, device=device, dtype=torch.int32
@ -227,19 +308,13 @@ class FlashCausalLMBatch(Batch):
end_seq_prefill = torch.tensor(
end_seq_prefill, device=device, dtype=torch.int32
)
else:
input_ids = all_input_ids[0]
position_ids = position_ids[0]
past_present_indices = past_present_indices[0]
start_seq_prefill = start_seq
end_seq_prefill = end_seq
position_ids = position_ids.to(device)
slot_indices = slot_indices.to(device)
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
position_ids = torch.tensor(position_ids, dtype=torch.int32, device=device)
past_present_indices = torch.tensor(
past_present_indices, device=device, dtype=torch.int64
slots = torch.tensor(slots, dtype=torch.int32, device=device)
input_lengths_tensor = torch.tensor(
input_lengths, dtype=torch.int32, device=device
)
if all_prefill_logprobs:
@ -262,30 +337,31 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill,
start_seq_q=None,
end_seq_q=None,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen,
prefill_head_indices=prefill_head_indices,
prefill_next_token_indices=prefill_next_token_indices,
prefill_cu_outlens=prefill_cu_outlens,
past_key_values=None,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=cumulative_max_length,
max_blocks=max_blocks,
)
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
global CACHE_MANAGER
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
# We assume that if len(requests) == len(self) then the requests are the same
@ -294,28 +370,24 @@ class FlashCausalLMBatch(Batch):
device = self.input_ids.device
# Cumulative length
cumulative_max_length = 0
# New values after filtering
requests_idx_mapping = {}
# Used to index into tensors
indices = []
# past indices to keep
past_indices = torch.zeros(
self.past_key_values.shape[0], dtype=torch.bool, device=device
# slots to keep after filtering
slot_filtering_indices = torch.zeros(
self.slots.shape[0], dtype=torch.bool, device=device
)
# Create on CPU to only move to GPU once instead of at every copy
start_seq = torch.empty(len(request_ids), dtype=torch.int32)
end_seq = torch.empty(len(request_ids), dtype=torch.int32)
start_seq_q = self.start_seq_q[: len(request_ids)]
end_seq_q = self.end_seq_q[: len(request_ids)]
slot_indices = torch.empty(len(request_ids), dtype=torch.int64)
max_seqlen = 0
requests = []
start_slots = []
block_tables = []
all_input_ids = []
input_lengths = []
@ -324,6 +396,10 @@ class FlashCausalLMBatch(Batch):
stopping_criterias = []
max_blocks = 0
# Cumulative length
cumulative_max_length = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
@ -348,28 +424,45 @@ class FlashCausalLMBatch(Batch):
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
)
request_block_table = self.block_tables[idx]
block_tables.append(request_block_table)
start_slots.append(cumulative_max_length)
# Copy to tensor (CPU)
start_seq[i] = cumulative_max_length
end_seq[i] = cumulative_max_length + request_input_length
slot_indices[i] = cumulative_max_length + request_input_length - 1
# Set slice
past_indices[
self.start_seq[idx] : self.end_seq[idx] + remaining_tokens - 1
slot_filtering_indices[
self.start_slots[idx] : self.start_slots[idx]
+ request_input_length
+ remaining_tokens
- 1
] = True
cumulative_max_length += request_input_length + remaining_tokens - 1
max_blocks = max(max_blocks, len(request_block_table))
# Iterate on all requests
for i, r in enumerate(self.requests):
# Filter requests that are not part of the new batch
if r.id not in requests_idx_mapping.keys():
# Free blocks
CACHE_MANAGER.free(self.block_tables[i])
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
block_tables_tensor = self.block_tables_tensor[indices]
input_lengths_tensor = self.input_lengths_tensor[indices]
slots = self.slots[slot_filtering_indices]
next_token_chooser = self.next_token_chooser.filter(indices)
past_key_values = self.past_key_values[past_indices]
start_slots = torch.tensor(start_slots, dtype=torch.int64)
# Move to GPU now that we have the whole tensor
start_seq = start_seq.to(device)
end_seq = end_seq.to(device)
past_present_indices = end_seq - 1
slot_indices = slot_indices.to(device)
return FlashCausalLMBatch(
batch_id=self.batch_id,
@ -377,51 +470,74 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None,
end_seq_prefill=None,
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=cumulative_max_length,
max_blocks=max_blocks,
)
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
global CACHE_MANAGER
# Batch attributes
requests = []
requests_idx_mapping = {}
total_batch_size = sum([len(b) for b in batches])
dtype = batches[0].past_key_values.dtype
device = batches[0].input_ids.device
total_batch_size = 0
total_slots = 0
max_blocks = 0
max_length = 0
max_seqlen = 0
for b in batches:
total_batch_size += len(b)
total_slots += len(b.slots)
max_blocks = max(max_blocks, b.max_blocks)
max_seqlen = max(max_seqlen, b.max_seqlen)
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
b.input_lengths, b.stopping_criterias
)
),
)
input_ids = batches[0].input_ids.new_empty(total_batch_size)
position_ids = batches[0].position_ids.new_empty(total_batch_size)
start_seq = batches[0].start_seq.new_empty(total_batch_size)
end_seq = batches[0].end_seq.new_empty(total_batch_size)
start_seq_q = torch.arange(
0, total_batch_size, device=device, dtype=torch.int32
slots = batches[0].slots.new_empty(total_slots)
slot_indices = batches[0].slot_indices.new_empty(total_batch_size)
input_lengths_tensor = batches[0].input_lengths_tensor.new_empty(
total_batch_size
)
block_tables_tensor = batches[0].block_tables_tensor.new_zeros(
(total_batch_size, max_blocks)
)
all_input_ids_tensor = batches[0].all_input_ids_tensor.new_zeros(
(total_batch_size, max_length)
)
end_seq_q = start_seq_q + 1
max_seqlen = 0
past_key_values = []
start_slots = []
block_tables = []
all_input_ids = []
input_lengths = []
@ -433,8 +549,7 @@ class FlashCausalLMBatch(Batch):
# Cumulative length
cumulative_batch_size = 0
max_tokens = 0
max_length = 0
cumulative_slots = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
@ -448,16 +563,27 @@ class FlashCausalLMBatch(Batch):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
slots_start_index = cumulative_slots
slots_end_index = cumulative_slots + len(batch.slots)
# Copy tensors (GPU)
input_ids[start_index:end_index] = batch.input_ids
position_ids[start_index:end_index] = batch.position_ids
slot_indices[start_index:end_index] = batch.slot_indices + cumulative_slots
input_lengths_tensor[start_index:end_index] = batch.input_lengths_tensor
slots[slots_start_index:slots_end_index] = batch.slots
start_seq[start_index:end_index] = batch.start_seq + max_tokens
end_seq[start_index:end_index] = batch.end_seq + max_tokens
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
max_seqlen = max(max_seqlen, batch.max_seqlen)
block_tables_tensor[
start_index:end_index, : batch.block_tables_tensor.shape[1]
] = batch.block_tables_tensor[:, :max_blocks]
start_slots.append(batch.start_slots + cumulative_slots)
block_tables.extend(batch.block_tables)
all_input_ids.extend(batch.all_input_ids)
input_lengths.extend(batch.input_lengths)
@ -466,43 +592,17 @@ class FlashCausalLMBatch(Batch):
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
past_key_values.append(batch.past_key_values)
# Update
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
batch.input_lengths, batch.stopping_criterias
)
),
)
cumulative_slots += len(batch.slots)
past_key_values = torch.cat(past_key_values, dim=0)
past_present_indices = end_seq - 1
all_input_ids_tensor = torch.zeros(
(total_batch_size, max_length), dtype=torch.int64, device=device
)
cumulative_batch_size = 0
for i, batch in enumerate(batches):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
cumulative_batch_size += len(batch)
start_slots = torch.concat(start_slots)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype=dtype, device=device
next_token_chooser_parameters,
dtype=batches[0].next_token_chooser.dtype,
device=batches[0].next_token_chooser.device,
)
return FlashCausalLMBatch(
@ -511,28 +611,33 @@ class FlashCausalLMBatch(Batch):
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
position_ids=position_ids,
past_present_indices=past_present_indices,
start_seq=start_seq,
end_seq=end_seq,
start_seq_prefill=None,
end_seq_prefill=None,
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
block_tables=block_tables,
block_tables_tensor=block_tables_tensor,
start_slots=start_slots,
slots=slots,
slot_indices=slot_indices,
max_seqlen=max_seqlen,
prefill_head_indices=None,
prefill_next_token_indices=None,
prefill_cu_outlens=None,
past_key_values=past_key_values,
input_lengths=input_lengths,
input_lengths_tensor=input_lengths_tensor,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
max_blocks=max_blocks,
)
def cleanup(self):
global CACHE_MANAGER
# Free blocks
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
def __len__(self):
return len(self.requests)
@ -540,32 +645,24 @@ class FlashCausalLMBatch(Batch):
class FlashCausalLM(Model):
def __init__(
self,
model_cls: Type[PreTrainedModel],
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
trust_remote_code: bool = False,
model: torch.nn.Module,
tokenizer: PreTrainedTokenizerBase,
num_layers: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
rank: int = 0,
world_size: int = 1,
):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
else:
raise NotImplementedError("FlashCausalLM is only available on GPU")
self.num_heads = num_heads
self.head_size = head_size
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
global CACHE_MANAGER
torch.cuda.set_per_process_memory_fraction(1.0)
CACHE_MANAGER = CacheManager(
1000, num_layers, num_heads, head_size, dtype, device
)
model = model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
).to(device)
super(FlashCausalLM, self).__init__(
model=model,
@ -573,6 +670,8 @@ class FlashCausalLM(Model):
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
)
@property
@ -588,28 +687,27 @@ class FlashCausalLM(Model):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
start_seq: torch.Tensor,
end_seq: torch.Tensor,
start_seq_q: Optional[torch.Tensor],
end_seq_q: Optional[torch.Tensor],
start_seq_prefill: Optional[torch.Tensor],
end_seq_prefill: Optional[torch.Tensor],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
max_s: int,
past_present_indices: torch.Tensor,
past_key_values: Optional = None,
pre_allocate_past_size: Optional[int] = None,
lm_head_indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
global CACHE_MANAGER
# Model Forward
return self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
start_seq=start_seq,
end_seq=end_seq,
start_seq_q=start_seq_q,
end_seq_q=end_seq_q,
start_seq_prefill=start_seq_prefill,
end_seq_prefill=end_seq_prefill,
kv_cache=CACHE_MANAGER.kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
max_s=max_s,
past_present_indices=past_present_indices,
past_key_values=past_key_values,
pre_allocate_past_size=pre_allocate_past_size,
lm_head_indices=lm_head_indices,
)
@ -617,31 +715,18 @@ class FlashCausalLM(Model):
def generate_token(
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
prefill = batch.start_seq_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
if prefill:
# Ask to pre-allocate kv to its max size
# == Sum over batch size (number of tokens + max_new_tokens) - batch size
pre_allocate_past_size = batch.max_tokens
start_seq = batch.start_seq_prefill
end_seq = batch.end_seq_prefill
else:
pre_allocate_past_size = None
start_seq = batch.start_seq
end_seq = batch.end_seq
out, present = self.forward(
out = self.forward(
batch.input_ids,
batch.position_ids,
start_seq,
end_seq,
batch.start_seq_q,
batch.end_seq_q,
batch.start_seq_prefill,
batch.end_seq_prefill,
batch.block_tables_tensor,
batch.slots[batch.slot_indices],
batch.input_lengths_tensor,
batch.max_seqlen,
batch.past_present_indices,
batch.past_key_values,
pre_allocate_past_size,
batch.prefill_head_indices,
)
@ -662,12 +747,8 @@ class FlashCausalLM(Model):
# When batch == 1, we will just use the batch.input_ids values directly
prefill_tokens_indices = batch.input_ids.new_zeros(len(out))
# Create batch.start_seq_q and batch.end_seq_q for decode
batch.start_seq_q = torch.arange(
0, len(batch), device=self.device, dtype=torch.int32
)
batch.end_seq_q = batch.start_seq_q + 1
next_position_ids = batch.position_ids.new_empty(len(batch))
batch.slot_indices = batch.slot_indices[batch.end_seq_prefill - 1]
# We do not need start_seq_prefill and end_seq_prefill anymore
batch.start_seq_prefill = None
batch.end_seq_prefill = None
@ -731,8 +812,8 @@ class FlashCausalLM(Model):
# Set values in batch
batch.input_ids = next_input_ids
batch.position_ids = next_position_ids + 1
batch.past_present_indices = batch.end_seq
batch.end_seq = batch.end_seq + 1
batch.input_lengths_tensor += 1
batch.slot_indices += 1
if prefill and prefill_logprobs:
# Get prefill logprobs
@ -755,7 +836,6 @@ class FlashCausalLM(Model):
batch.read_offsets,
batch.stopping_criterias,
batch.all_input_ids,
batch.all_input_ids_tensor,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
next_token_ids,
@ -770,7 +850,6 @@ class FlashCausalLM(Model):
read_offset,
stopping_criteria,
all_input_ids,
all_input_ids_tensor,
do_sample,
seed,
next_token_id,
@ -845,19 +924,20 @@ class FlashCausalLM(Model):
generations.append(generation)
new_input_length = input_length + 1
# Update values
batch.input_lengths[i] = new_input_length
batch.input_lengths[i] = input_length + 1
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
if stopped:
batch.cleanup()
# No need to return a batch if we know that all requests stopped
return generations, None
batch.prefill_cu_outlens = None
batch.prefill_head_indices = None
batch.prefill_next_token_indices = None
batch.max_seqlen = batch.max_seqlen + 1
batch.past_key_values = present
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None
return generations, batch

View File

@ -64,10 +64,12 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
super(FlashLlama, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
num_layers=len(model.model.layers),
num_heads=model.model.num_heads,
head_size=model.model.head_size,
dtype=dtype,
device=device,
rank=rank,

View File

@ -52,8 +52,11 @@ class FlashSantacoderSharded(FlashCausalLM):
torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group,
aliases = {"transformer.wte.weight": ["lm_head.weight"]}
filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
)
model = FlashSantacoderForCausalLM(config, weights)

View File

@ -35,6 +35,9 @@ class Batch(ABC):
def concatenate(cls, batches: List["Batch"]) -> "Batch":
raise NotImplementedError
def cleanup(self):
pass
@abstractmethod
def __len__(self):
raise NotImplementedError

View File

@ -216,6 +216,8 @@ class HeterogeneousNextTokenChooser:
self.seeds = seeds
self.do_sample = do_sample
self.dtype = dtype
self.device = device
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor):
if self.watermark_processor is not None:

View File

@ -5,7 +5,14 @@ import torch
class Weights:
def __init__(self, filenames: List[Path], device, dtype, process_group, aliases: Optional[Dict[str, List[str]]]=None):
def __init__(
self,
filenames: List[Path],
device,
dtype,
process_group,
aliases: Optional[Dict[str, List[str]]] = None,
):
routing = {}
for filename in filenames:
with safe_open(filename, framework="pytorch") as f:
@ -43,7 +50,7 @@ class Weights:
return str(filename), tensor_name
def _get_slice(self, tensor_name: str):
filename, tensor_name= self.get_filename(tensor_name)
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
slice_ = f.get_slice(tensor_name)
return slice_
@ -94,12 +101,20 @@ class Weights:
def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
if quantize == "gptq":
try:
qweight = torch.cat([self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1)
qweight = torch.cat(
[self.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = torch.cat([self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1)
scales = torch.cat([self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1)
qzeros = torch.cat(
[self.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
scales = torch.cat(
[self.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
w = [self.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
@ -118,7 +133,9 @@ class Weights:
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)