This commit is contained in:
OlivierDehaene 2023-05-09 18:40:17 +02:00
parent 89565b4eaf
commit f0609e73d8
4 changed files with 8 additions and 1 deletion

View File

@@ -1,5 +1,5 @@
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
/// Multi shard Client
use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
use crate::{ClientError, Result};
use futures::future::join_all;
use tonic::transport::Uri;
@@ -123,6 +123,7 @@ impl ShardedClient {
}
}
/// Merge generations from the different model shards
fn merge_generations(
mut results: Vec<(Vec<Generation>, Option<Batch>)>,
) -> Result<(Vec<Generation>, Option<Batch>)> {

View File

@@ -572,6 +572,8 @@ class CausalLM(Model):
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens

View File

@@ -690,6 +690,8 @@ class FlashCausalLM(Model):
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Decode generated tokens

View File

@@ -653,6 +653,8 @@ class Seq2SeqLM(Model):
if not stop:
stopped = False
# Shard generations
# All generations will be appended in the rust sharded client
if i % self.world_size == self.rank:
if stop:
# Slice with decoder_input_length to remove padding