mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 11:54:52 +00:00
add option to set the watermark gamma & delta from the launcher
This commit is contained in:
parent
299f7367a5
commit
0be0506a7e
@ -55,6 +55,10 @@ struct Args {
|
|||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
#[clap(long, env)]
|
#[clap(long, env)]
|
||||||
cors_allow_origin: Vec<String>,
|
cors_allow_origin: Vec<String>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
watermark_gamma: Option<f32>,
|
||||||
|
#[clap(long, env)]
|
||||||
|
watermark_delta: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> ExitCode {
|
fn main() -> ExitCode {
|
||||||
@ -88,6 +92,8 @@ fn main() -> ExitCode {
|
|||||||
json_output,
|
json_output,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
cors_allow_origin,
|
cors_allow_origin,
|
||||||
|
watermark_gamma,
|
||||||
|
watermark_delta,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
// Signal handler
|
// Signal handler
|
||||||
@ -243,6 +249,8 @@ fn main() -> ExitCode {
|
|||||||
huggingface_hub_cache,
|
huggingface_hub_cache,
|
||||||
weights_cache_override,
|
weights_cache_override,
|
||||||
disable_custom_kernels,
|
disable_custom_kernels,
|
||||||
|
watermark_gamma,
|
||||||
|
watermark_delta,
|
||||||
otlp_endpoint,
|
otlp_endpoint,
|
||||||
status_sender,
|
status_sender,
|
||||||
shutdown,
|
shutdown,
|
||||||
@ -414,6 +422,8 @@ fn shard_manager(
|
|||||||
huggingface_hub_cache: Option<String>,
|
huggingface_hub_cache: Option<String>,
|
||||||
weights_cache_override: Option<String>,
|
weights_cache_override: Option<String>,
|
||||||
disable_custom_kernels: bool,
|
disable_custom_kernels: bool,
|
||||||
|
watermark_gamma: Option<f32>,
|
||||||
|
watermark_delta: Option<f32>,
|
||||||
otlp_endpoint: Option<String>,
|
otlp_endpoint: Option<String>,
|
||||||
status_sender: mpsc::Sender<ShardStatus>,
|
status_sender: mpsc::Sender<ShardStatus>,
|
||||||
shutdown: Arc<Mutex<bool>>,
|
shutdown: Arc<Mutex<bool>>,
|
||||||
@ -494,6 +504,16 @@ fn shard_manager(
|
|||||||
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
|
env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Watermark Gamma
|
||||||
|
if let Some(watermark_gamma) = watermark_gamma {
|
||||||
|
env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watermark Delta
|
||||||
|
if let Some(watermark_delta) = watermark_delta {
|
||||||
|
env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
|
||||||
|
}
|
||||||
|
|
||||||
// Start process
|
// Start process
|
||||||
tracing::info!("Starting shard {rank}");
|
tracing::info!("Starting shard {rank}");
|
||||||
let mut p = match Popen::create(
|
let mut p = match Popen::create(
|
||||||
|
@ -13,17 +13,21 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import LogitsProcessor
|
from transformers import LogitsProcessor
|
||||||
|
|
||||||
|
# Default watermark gamma, overridable through the WATERMARK_GAMMA environment
# variable (exported by the launcher). os.getenv returns a str whenever the
# variable is set, so convert explicitly to keep this constant a float.
GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
|
||||||
|
# Default watermark delta, overridable through the WATERMARK_DELTA environment
# variable (exported by the launcher). os.getenv returns a str whenever the
# variable is set, so convert explicitly to keep this constant a float.
DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))
|
||||||
|
|
||||||
|
|
||||||
class WatermarkLogitsProcessor(LogitsProcessor):
|
class WatermarkLogitsProcessor(LogitsProcessor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
gamma: float = 0.5,
|
gamma: float = GAMMA,
|
||||||
delta: float = 2.0,
|
delta: float = DELTA,
|
||||||
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
|
hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
|
Loading…
Reference in New Issue
Block a user