Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-09 19:34:53 +00:00)
add option to set the watermark gamma & delta from the launcher

commit 0be0506a7e · parent 299f7367a5
launcher/src/main.rs:

@@ -55,6 +55,10 @@ struct Args {
     otlp_endpoint: Option<String>,
     #[clap(long, env)]
     cors_allow_origin: Vec<String>,
+    #[clap(long, env)]
+    watermark_gamma: Option<f32>,
+    #[clap(long, env)]
+    watermark_delta: Option<f32>,
 }
 
 fn main() -> ExitCode {
@@ -88,6 +92,8 @@ fn main() -> ExitCode {
         json_output,
         otlp_endpoint,
         cors_allow_origin,
+        watermark_gamma,
+        watermark_delta,
     } = args;
 
     // Signal handler
@@ -243,6 +249,8 @@ fn main() -> ExitCode {
                 huggingface_hub_cache,
                 weights_cache_override,
                 disable_custom_kernels,
+                watermark_gamma,
+                watermark_delta,
                 otlp_endpoint,
                 status_sender,
                 shutdown,
@@ -414,6 +422,8 @@ fn shard_manager(
     huggingface_hub_cache: Option<String>,
     weights_cache_override: Option<String>,
     disable_custom_kernels: bool,
+    watermark_gamma: Option<f32>,
+    watermark_delta: Option<f32>,
     otlp_endpoint: Option<String>,
     status_sender: mpsc::Sender<ShardStatus>,
     shutdown: Arc<Mutex<bool>>,
@@ -494,6 +504,16 @@ fn shard_manager(
         env.push(("DISABLE_CUSTOM_KERNELS".into(), "True".into()))
     }
 
+    // Watermark Gamma
+    if let Some(watermark_gamma) = watermark_gamma {
+        env.push(("WATERMARK_GAMMA".into(), watermark_gamma.to_string().into()))
+    }
+
+    // Watermark Delta
+    if let Some(watermark_delta) = watermark_delta {
+        env.push(("WATERMARK_DELTA".into(), watermark_delta.to_string().into()))
+    }
+
     // Start process
     tracing::info!("Starting shard {rank}");
     let mut p = match Popen::create(
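The launcher hands the two values to each shard only as environment-variable strings (`WATERMARK_GAMMA`, `WATERMARK_DELTA`). The Python diff below reads them back as module-level defaults; since `os.getenv` returns a `str` whenever the variable is actually set, an explicit `float(...)` conversion is one way to keep the defaults numeric. A minimal sketch of that reading step, with a hypothetical helper name (`_env_float`) that is not part of the repository:

```python
import os


def _env_float(name: str, default: float) -> float:
    """Read a float from the environment, falling back to a numeric default.

    os.getenv returns a str whenever the variable is set, so convert
    explicitly to keep later arithmetic (e.g. gamma * vocab_size) numeric.
    """
    value = os.getenv(name)
    return default if value is None else float(value)


# Hypothetical usage mirroring the variables exported by the launcher above.
GAMMA = _env_float("WATERMARK_GAMMA", 0.5)
DELTA = _env_float("WATERMARK_DELTA", 2.0)
```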
server/text_generation_server/utils/watermark.py:

@@ -13,19 +13,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 
 import torch
 from transformers import LogitsProcessor
 
+GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
+DELTA = os.getenv("WATERMARK_DELTA", 2.0)
+
 
 class WatermarkLogitsProcessor(LogitsProcessor):
     def __init__(
-        self,
-        vocab_size: int,
-        gamma: float = 0.5,
-        delta: float = 2.0,
-        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
-        device: str = "cpu",
+        self,
+        vocab_size: int,
+        gamma: float = GAMMA,
+        delta: float = DELTA,
+        hash_key: int = 15485863,  # just a large prime number to create a rng seed with sufficient bit width
+        device: str = "cpu",
     ):
         # watermarking parameters
         self.vocab_size = vocab_size
@@ -36,7 +40,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     def _seed_rng(self, input_ids: torch.LongTensor) -> None:
         assert (
-            input_ids.shape[-1] >= 1
+            input_ids.shape[-1] >= 1
         ), "requires at least a 1 token prefix sequence to seed rng"
         prev_token = input_ids[-1].item()
         self.rng.manual_seed(self.hash_key * prev_token)
@@ -54,7 +58,7 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     @staticmethod
     def _calc_greenlist_mask(
-        scores: torch.FloatTensor, greenlist_token_ids
+        scores: torch.FloatTensor, greenlist_token_ids
     ) -> torch.BoolTensor:
         green_tokens_mask = torch.zeros_like(scores)
         green_tokens_mask[-1, greenlist_token_ids] = 1
@@ -63,13 +67,13 @@ class WatermarkLogitsProcessor(LogitsProcessor):
 
     @staticmethod
     def _bias_greenlist_logits(
-        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
+        scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float
     ) -> torch.Tensor:
         scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
         return scores
 
     def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
     ) -> torch.FloatTensor:
         assert len(input_ids) == 1
         greenlist_ids = self._get_greenlist_ids(input_ids[0])
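Taken together, the methods above implement greenlist watermarking in the style of "A Watermark for Large Language Models" (Kirchenbauer et al.): seed a generator from the previous token id, select a gamma-sized fraction of the vocabulary as the greenlist, and add delta to those logits before sampling. A compact, self-contained sketch of that step under toy assumptions (the vocab size, prompt ids, and zero logits below are made up for illustration and are not the repository's code):

```python
import torch

# Toy settings; these particular numbers are illustrative assumptions only.
vocab_size = 8
gamma, delta = 0.5, 2.0
hash_key = 15485863  # same large prime used as a seed multiplier above

input_ids = torch.tensor([3, 7, 1])   # prompt token ids for a single sequence
scores = torch.zeros(1, vocab_size)   # next-token logits (batch of 1)

# 1) Seed an RNG from the last token so the greenlist is reproducible
#    by a detector that replays the same computation.
rng = torch.Generator(device="cpu")
rng.manual_seed(hash_key * input_ids[-1].item())

# 2) Take the first gamma * vocab_size entries of a seeded permutation
#    of the vocabulary as the "greenlist".
greenlist_size = int(vocab_size * gamma)
greenlist_ids = torch.randperm(vocab_size, generator=rng)[:greenlist_size]

# 3) Mask the greenlist positions and bias their logits by delta, in the
#    spirit of _calc_greenlist_mask / _bias_greenlist_logits above.
green_mask = torch.zeros_like(scores, dtype=torch.bool)
green_mask[-1, greenlist_ids] = True
scores[green_mask] += delta

print(greenlist_ids.tolist())
print(scores)
```

In this scheme, gamma controls what fraction of the vocabulary is favored and delta how strongly its logits are boosted, which is why exposing both from the launcher lets operators trade watermark strength against generation quality.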