From f0609e73d8acd0e91442e47c5df10b758d1cd1c9 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Tue, 9 May 2023 18:40:17 +0200
Subject: [PATCH] add docs

---
 router/client/src/sharded_client.rs                     | 3 ++-
 server/text_generation_server/models/causal_lm.py       | 2 ++
 server/text_generation_server/models/flash_causal_lm.py | 2 ++
 server/text_generation_server/models/seq2seq_lm.py      | 2 ++
 4 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs
index 986ec377..60b81fe6 100644
--- a/router/client/src/sharded_client.rs
+++ b/router/client/src/sharded_client.rs
@@ -1,5 +1,5 @@
-use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
 /// Multi shard Client
+use crate::{Batch, Client, Generation, HealthResponse, Request, ShardInfo};
 use crate::{ClientError, Result};
 use futures::future::join_all;
 use tonic::transport::Uri;
@@ -123,6 +123,7 @@ impl ShardedClient {
     }
 }
 
+/// Merge generations from the different model shards
 fn merge_generations(
     mut results: Vec<(Vec<Generation>, Option<Batch>)>,
 ) -> Result<(Vec<Generation>, Option<Batch>)> {
diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index c029e3ab..610dc4e2 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -572,6 +572,8 @@ class CausalLM(Model):
             if not stop:
                 stopped = False
 
+            # Shard generations
+            # All generations will be appended in the rust sharded client
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index eefa5be9..e862cfeb 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -690,6 +690,8 @@ class FlashCausalLM(Model):
             if not stop:
                 stopped = False
 
+            # Shard generations
+            # All generations will be appended in the rust sharded client
             if i % self.world_size == self.rank:
                 if stop:
                     # Decode generated tokens
diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py
index ed658126..d4a0ddcc 100644
--- a/server/text_generation_server/models/seq2seq_lm.py
+++ b/server/text_generation_server/models/seq2seq_lm.py
@@ -653,6 +653,8 @@ class Seq2SeqLM(Model):
             if not stop:
                 stopped = False
 
+            # Shard generations
+            # All generations will be appended in the rust sharded client
             if i % self.world_size == self.rank:
                 if stop:
                     # Slice with decoder_input_length to remove padding
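
Note on the comments this patch adds to the three model files: each model
shard only produces generations for the requests where
`i % self.world_size == self.rank`, so the shards hold disjoint, round-robin
slices of the batch. A minimal sketch of that scheme (the `world_size`,
`rank`, and request values here are hypothetical, not taken from the patch):

    # Round-robin sharding: shard `rank` keeps request i iff i % world_size == rank.
    world_size = 2             # hypothetical number of model shards
    requests = list(range(8))  # hypothetical request indices in the batch

    def shard_generations(rank):
        # Mirrors `if i % self.world_size == self.rank` in the models above.
        return [i for i in requests if i % world_size == rank]

    assert shard_generations(0) == [0, 2, 4, 6]
    assert shard_generations(1) == [1, 3, 5, 7]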
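
On the rust side, the newly documented `merge_generations` helper is what
stitches those per-shard results back together. A rough Python sketch of that
merge, assuming each entry is a `(generations, next_batch)` pair as in the
rust signature (not a line-for-line translation of the rust code):

    def merge_generations(results):
        # Seed with one shard's result, then append every other shard's
        # generations, keeping a single next batch.
        generations, next_batch = results.pop()
        for shard_generations, _ in results:
            generations.extend(shard_generations)
        return generations, next_batch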