add option to set the watermark gamma & delta from the launcher

This commit is contained in:
OlivierDehaene 2023-03-01 17:14:44 +01:00
parent 299f7367a5
commit 0be0506a7e
2 changed files with 34 additions and 10 deletions

View File

@ -55,6 +55,10 @@ struct Args {
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
#[clap(long, env)] #[clap(long, env)]
cors_allow_origin: Vec<String>, cors_allow_origin: Vec<String>,
#[clap(long, env)]
watermark_gamma: Option<f32>,
#[clap(long, env)]
watermark_delta: Option<f32>,
} }
fn main() -> ExitCode { fn main() -> ExitCode {
@ -88,6 +92,8 @@ fn main() -> ExitCode {
json_output, json_output,
otlp_endpoint, otlp_endpoint,
cors_allow_origin, cors_allow_origin,
watermark_gamma,
watermark_delta,
} = args; } = args;
// Signal handler // Signal handler
@ -243,6 +249,8 @@ fn main() -> ExitCode {
huggingface_hub_cache, huggingface_hub_cache,
weights_cache_override, weights_cache_override,
disable_custom_kernels, disable_custom_kernels,
watermark_gamma,
watermark_delta,
otlp_endpoint, otlp_endpoint,
status_sender, status_sender,
shutdown, shutdown,
@ -414,6 +422,8 @@ fn shard_manager(
huggingface_hub_cache: Option<String>, huggingface_hub_cache: Option<String>,
weights_cache_override: Option<String>, weights_cache_override: Option<String>,
disable_custom_kernels: bool, disable_custom_kernels: bool,
watermark_gamma: Option<f32>,
watermark_delta: Option<f32>,
otlp_endpoint: Option<String>, otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>, status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<Mutex<bool>>, shutdown: Arc<Mutex<bool>>,
@ -494,6 +504,16 @@ fn shard_manager(
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into())) env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
} }
// Watermark Gamma
if let Some(watermark_gamma) = watermark_gamma {
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
}
// Watermark Delta
if let Some(watermark_delta) = watermark_delta {
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
}
// Start process // Start process
tracing::info!("Starting shard {rank}"); tracing::info!("Starting shard {rank}");
let mut p = match Popen::create( let mut p = match Popen::create(

View File

@ -13,17 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import torch import torch
from transformers import LogitsProcessor from transformers import LogitsProcessor
GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
DELTA = os.getenv("WATERMARK_DELTA", 2.0)
class WatermarkLogitsProcessor(LogitsProcessor): class WatermarkLogitsProcessor(LogitsProcessor):
def __init__( def __init__(
self, self,
vocab_size: int, vocab_size: int,
gamma: float = 0.5, gamma: float = GAMMA,
delta: float = 2.0, delta: float = DELTA,
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
device: str = "cpu", device: str = "cpu",
): ):