feat(server): shard token decode (#303)

This commit is contained in:
OlivierDehaene 2023-05-10 15:48:21 +02:00 committed by GitHub
parent 1585404464
commit 68e9d6ab33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 224 additions and 178 deletions

View File

@ -23,6 +23,8 @@ pub enum ClientError {
Connection(String), Connection(String),
#[error("Server error: {0}")] #[error("Server error: {0}")]
Generation(String), Generation(String),
#[error("Sharded results are empty")]
EmptyResults,
} }
impl From<Status> for ClientError { impl From<Status> for ClientError {

View File

@ -1,6 +1,6 @@
/// Multi shard Client /// Multi shard Client
use crate::Result;
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo}; use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
use crate::{ClientError, Result};
use futures::future::join_all; use futures::future::join_all;
use tonic::transport::Uri; use tonic::transport::Uri;
use tracing::instrument; use tracing::instrument;
@ -98,8 +98,9 @@ impl ShardedClient {
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone()))) .map(|client| Box::pin(client.prefill(batch.clone())))
.collect(); .collect();
// all shards return the same message let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
join_all(futures).await.pop().unwrap() join_all(futures).await.into_iter().collect();
merge_generations(results?)
} }
/// Generate one token for each request in the given cached batches /// Generate one token for each request in the given cached batches
@ -116,7 +117,20 @@ impl ShardedClient {
.iter_mut() .iter_mut()
.map(|client| Box::pin(client.decode(batches.clone()))) .map(|client| Box::pin(client.decode(batches.clone())))
.collect(); .collect();
// all shards return the same message let results: Result<Vec<(Vec<Generation>, Option<Batch>)>> =
join_all(futures).await.pop().unwrap() join_all(futures).await.into_iter().collect();
merge_generations(results?)
} }
} }
/// Merge generations from the different model shards
fn merge_generations(
mut results: Vec<(Vec<Generation>, Option<Batch>)>,
) -> Result<(Vec<Generation>, Option<Batch>)> {
let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
for (mut shard_generations, _) in results.into_iter() {
generations.append(&mut shard_generations);
}
Ok((generations, next_batch))
}

View File

@ -63,10 +63,10 @@ class BLOOMSharded(BLOOM):
def __init__( def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else: else:
device = torch.device("cpu") device = torch.device("cpu")
@ -94,8 +94,8 @@ class BLOOMSharded(BLOOM):
quantize=quantize, quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=rank,
world_size=self.world_size, world_size=world_size,
) )
self.model = model.eval() self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -105,6 +105,8 @@ class BLOOMSharded(BLOOM):
dtype=dtype, dtype=dtype,
device=device, device=device,
decode_buffer=1, decode_buffer=1,
rank=rank,
world_size=world_size,
) )
@staticmethod @staticmethod

View File

@ -549,7 +549,7 @@ class CausalLM(Model):
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
next_token_id, logprobs = next_token_chooser( next_token_id, logprobs = next_token_chooser(
all_input_ids.view(1, -1), logits all_input_ids.view(1, -1), logits[-1:, :]
) )
# Append next token to all tokens # Append next token to all tokens
@ -569,6 +569,12 @@ class CausalLM(Model):
next_token_text, next_token_text,
) )
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop: if stop:
# Decode generated tokens # Decode generated tokens
output_text = self.decode( output_text = self.decode(
@ -584,16 +590,16 @@ class CausalLM(Model):
output_text, stopping_criteria.current_tokens, reason, seed output_text, stopping_criteria.current_tokens, reason, seed
) )
else: else:
# Keep request in the batch
generated_text = None generated_text = None
stopped = False
# Prefill # Prefill
if stopping_criteria.current_tokens == 1: if stopping_criteria.current_tokens == 1:
# Remove generated token to only have prefill and add nan for first prompt token # Remove generated token to only have prefill and add nan for first prompt token
prefill_logprobs = [float("nan")] + logprobs.gather( prefill_logprobs = [float("nan")] + torch.log_softmax(
1, all_input_ids[1:] logits, -1
).squeeze(1)[-new_input_length:-1].tolist() ).gather(1, all_input_ids[1:]).squeeze(1)[
-new_input_length:-1
].tolist()
prefill_token_ids = all_input_ids[-new_input_length:-1] prefill_token_ids = all_input_ids[-new_input_length:-1]
prefill_texts = self.tokenizer.batch_decode( prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids, prefill_token_ids,

View File

@ -622,10 +622,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
self.process_group = process_group self.process_group = process_group
if self.process_group is not None: if self.process_group is not None:
self.world_size = self.process_group.size() self.world_size = self.process_group.size()
self.rank = self.process_group.rank()
else: else:
self.world_size = 1 self.world_size = 1
self.rank = 0
self.model = FlashLlamaModel(config, process_group) self.model = FlashLlamaModel(config, process_group)

View File

@ -685,10 +685,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
self.process_group = process_group self.process_group = process_group
if self.process_group is not None: if self.process_group is not None:
self.world_size = self.process_group.size() self.world_size = self.process_group.size()
self.rank = self.process_group.rank()
else: else:
self.world_size = 1 self.world_size = 1
self.rank = 0
self.gpt_neox = FlashGPTNeoXModel(config, process_group) self.gpt_neox = FlashGPTNeoXModel(config, process_group)

View File

@ -687,6 +687,12 @@ class FlashCausalLM(Model):
next_token_text, next_token_text,
) )
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop: if stop:
# Decode generated tokens # Decode generated tokens
output_text = self.decode( output_text = self.decode(
@ -702,7 +708,6 @@ class FlashCausalLM(Model):
output_text, stopping_criteria.current_tokens, reason, seed output_text, stopping_criteria.current_tokens, reason, seed
) )
else: else:
stopped = False
generated_text = None generated_text = None
# Prefill # Prefill
@ -734,6 +739,7 @@ class FlashCausalLM(Model):
) )
generations.append(generation) generations.append(generation)
new_input_length = input_length + 1 new_input_length = input_length + 1
# Update values # Update values

View File

@ -157,10 +157,10 @@ class FlashLlamaSharded(FlashLlama):
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
self.past_pad = None self.past_pad = None
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16
else: else:
raise NotImplementedError("FlashLlama is only available on GPU") raise NotImplementedError("FlashLlama is only available on GPU")
@ -190,8 +190,8 @@ class FlashLlamaSharded(FlashLlama):
quantize=quantize, quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=rank,
world_size=self.world_size, world_size=world_size,
) )
self.model = model.eval().to(device) self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -200,6 +200,8 @@ class FlashLlamaSharded(FlashLlama):
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank,
world_size=world_size,
) )
@staticmethod @staticmethod

View File

@ -34,10 +34,10 @@ class FlashNeoXSharded(FlashNeoX):
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
self.past_pad = None self.past_pad = None
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16
else: else:
raise NotImplementedError("FlashNeoX is only available on GPU") raise NotImplementedError("FlashNeoX is only available on GPU")
@ -64,8 +64,8 @@ class FlashNeoXSharded(FlashNeoX):
quantize=quantize, quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=rank,
world_size=self.world_size, world_size=world_size,
) )
self.model = model.eval().to(device) self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -74,6 +74,8 @@ class FlashNeoXSharded(FlashNeoX):
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank,
world_size=world_size,
) )
@staticmethod @staticmethod

View File

@ -174,10 +174,10 @@ class FlashSantacoderSharded(FlashSantacoder):
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
self.past_pad = None self.past_pad = None
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 dtype = torch.float16
else: else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU") raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
@ -204,8 +204,8 @@ class FlashSantacoderSharded(FlashSantacoder):
quantize=quantize, quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=rank,
world_size=self.world_size, world_size=world_size,
transpose=config.architectures[0].startswith("GPT2"), transpose=config.architectures[0].startswith("GPT2"),
) )
self.model = model.eval().to(device) self.model = model.eval().to(device)
@ -215,6 +215,8 @@ class FlashSantacoderSharded(FlashSantacoder):
requires_padding=False, requires_padding=False,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank,
world_size=world_size,
) )
@staticmethod @staticmethod

View File

@ -195,10 +195,10 @@ class GalacticaSharded(Galactica):
def __init__( def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else: else:
device = torch.device("cpu") device = torch.device("cpu")
@ -226,8 +226,8 @@ class GalacticaSharded(Galactica):
quantize=quantize, quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=rank,
world_size=self.world_size, world_size=world_size,
) )
self.model = model.eval() self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -236,6 +236,8 @@ class GalacticaSharded(Galactica):
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank,
world_size=world_size,
) )
@staticmethod @staticmethod

View File

@ -34,10 +34,10 @@ class GPTNeoxSharded(CausalLM):
def __init__( def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else: else:
device = torch.device("cpu") device = torch.device("cpu")
@ -65,8 +65,8 @@ class GPTNeoxSharded(CausalLM):
quantize=quantize, quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=rank,
world_size=self.world_size, world_size=world_size,
) )
self.model = model.eval() self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -75,6 +75,8 @@ class GPTNeoxSharded(CausalLM):
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank,
world_size=world_size,
) )
@staticmethod @staticmethod

View File

@ -18,6 +18,8 @@ class Model(ABC):
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
decode_buffer: int = 3, decode_buffer: int = 3,
rank: int = 0,
world_size: int = 1,
): ):
if decode_buffer < 1: if decode_buffer < 1:
raise ValueError("decode_buffer must be >= 1") raise ValueError("decode_buffer must be >= 1")
@ -28,6 +30,8 @@ class Model(ABC):
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.decode_buffer = decode_buffer self.decode_buffer = decode_buffer
self.rank = rank
self.world_size = world_size
@property @property
def info(self) -> InfoResponse: def info(self) -> InfoResponse:

View File

@ -50,10 +50,10 @@ class OPTSharded(OPT):
def __init__( def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else: else:
device = torch.device("cpu") device = torch.device("cpu")
@ -81,8 +81,8 @@ class OPTSharded(OPT):
quantize=quantize, quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=rank,
world_size=self.world_size, world_size=world_size,
) )
self.model = model.eval() self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -91,6 +91,8 @@ class OPTSharded(OPT):
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank,
world_size=world_size,
) )
@staticmethod @staticmethod

View File

@ -631,7 +631,7 @@ class Seq2SeqLM(Model):
) in enumerate(iterator): ) in enumerate(iterator):
# Select next token # Select next token
next_token_id, logprobs = next_token_chooser( next_token_id, logprobs = next_token_chooser(
all_decoder_input_ids.view(1, -1), logits all_decoder_input_ids.view(1, -1), logits[-1:, :]
) )
# Append next token to decoder tokens # Append next token to decoder tokens
@ -650,10 +650,18 @@ class Seq2SeqLM(Model):
# Evaluate stopping criteria # Evaluate stopping criteria
stop, reason = stopping_criteria(next_token_id, next_token_text) stop, reason = stopping_criteria(next_token_id, next_token_text)
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop: if stop:
# Slice with decoder_input_length to remove padding # Slice with decoder_input_length to remove padding
# Decode all tokens # Decode all tokens
output_text = self.decode(all_decoder_input_ids[-decoder_input_length:]) output_text = self.decode(
all_decoder_input_ids[-decoder_input_length:]
)
# Get seed # Get seed
if isinstance(next_token_chooser.choice, Sampling): if isinstance(next_token_chooser.choice, Sampling):
@ -665,9 +673,7 @@ class Seq2SeqLM(Model):
output_text, stopping_criteria.current_tokens, reason, seed output_text, stopping_criteria.current_tokens, reason, seed
) )
else: else:
# Keep request in the batch
generated_text = None generated_text = None
stopped = False
# Prefill # Prefill
if stopping_criteria.current_tokens == 1: if stopping_criteria.current_tokens == 1:

View File

@ -34,10 +34,10 @@ class T5Sharded(Seq2SeqLM):
def __init__( def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False self, model_id: str, revision: Optional[str] = None, quantize: bool = False
): ):
self.process_group, self.rank, self.world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
self.master = self.rank == 0 self.master = rank == 0
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else: else:
device = torch.device("cpu") device = torch.device("cpu")
@ -65,8 +65,8 @@ class T5Sharded(Seq2SeqLM):
quantize=quantize, quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=rank,
world_size=self.world_size, world_size=world_size,
) )
self.model = model.eval() self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
@ -75,6 +75,8 @@ class T5Sharded(Seq2SeqLM):
requires_padding=True, requires_padding=True,
dtype=dtype, dtype=dtype,
device=device, device=device,
rank=rank,
world_size=world_size,
) )
@staticmethod @staticmethod

View File

@ -75,10 +75,6 @@ class NextTokenChooser:
def __call__(self, input_ids, scores): def __call__(self, input_ids, scores):
# Warp logits # Warp logits
if scores.shape[0] > 1:
# only warp the last token logits
scores[-1:, :] = self.warpers(input_ids, scores[-1:, :])
else:
scores = self.warpers(input_ids, scores) scores = self.warpers(input_ids, scores)
# Compute logprobs # Compute logprobs