# Author: Hanzhou Wu
# EMAIL: h.wu.phd@ieee.org
# HTML: https://hzwu.github.io/
# TIME: 2024/09/22
# ADDRESS: Shanghai University

import random

import numpy as np
import torch
from transformers import BertForMaskedLM, BertTokenizer


def set_seed(seed=42):
    """Fix all relevant random seeds so the sampling below is reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Load a locally stored bert-base-uncased checkpoint and run it in inference mode on the GPU.
tokenizer = BertTokenizer.from_pretrained('./pt-models/bert-base-uncased')
model = BertForMaskedLM.from_pretrained('./pt-models/bert-base-uncased').cuda()
model.eval()

start = tokenizer.cls_token_id  # [CLS]
end = tokenizer.sep_token_id    # [SEP]
mask = tokenizer.mask_token_id  # [MASK]


def single_token_sampling(logit, top_k=0, overall=False):
    """Pick one token id from a vocabulary-sized logit vector.

    top_k > 0 : sample from the top-k logits;
    overall   : sample from the full distribution;
    otherwise : greedy argmax.
    """
    if top_k > 0:
        top_k_vals, top_k_idxs = logit.topk(top_k, dim=-1)
        top_k_dist = torch.distributions.categorical.Categorical(logits=top_k_vals)
        ret_idx = top_k_idxs.gather(dim=-1, index=top_k_dist.sample().unsqueeze(-1)).squeeze(-1)
    elif overall:
        overall_dist = torch.distributions.categorical.Categorical(logits=logit)
        ret_idx = overall_dist.sample().squeeze(-1)
    else:
        ret_idx = torch.argmax(logit, dim=-1)
    return ret_idx


def bert_only(token_indices, given_tokens, text_length, top_k=1, overall=False):
    """Fill every unconstrained position with a single forward pass.

    The given tokens are pinned at their positions, every remaining position
    is set to [MASK], and each of them is predicted independently from one
    call to the masked language model.
    """
    masked_tokens = [mask] * text_length
    # Sort the (index, token) pairs together so each given token stays attached
    # to its intended position even if the indices arrive out of order.
    for each_idx, each_token in sorted(zip(token_indices, given_tokens)):
        masked_tokens[each_idx] = tokenizer.convert_tokens_to_ids(each_token)
    indices_to_be_filled = sorted(set(range(text_length)) - set(token_indices))
    entire_indices = torch.tensor([start] + masked_tokens + [end]).unsqueeze(0).cuda()
    with torch.no_grad():
        entire_logits = model(entire_indices).logits
    for each_idx in indices_to_be_filled:
        # Offset by 1 to skip the leading [CLS] token.
        each_logit = entire_logits[0][each_idx + 1]
        sample_idx = single_token_sampling(each_logit, top_k, overall)
        masked_tokens[each_idx] = int(sample_idx)
    return tokenizer.convert_ids_to_tokens(masked_tokens)


def bert_gibbs(token_indices, given_tokens, text_length, top_k=1, overall=False, turns=42):
    """Fill the unconstrained positions by Gibbs sampling.

    Starting from an all-[MASK] draft with the given tokens pinned, each free
    position is re-masked and re-predicted in turn, conditioned on the current
    values of all other positions, for `turns` sweeps.
    """
    masked_tokens = [mask] * text_length
    for each_idx, each_token in sorted(zip(token_indices, given_tokens)):
        masked_tokens[each_idx] = tokenizer.convert_tokens_to_ids(each_token)
    indices_to_be_filled = sorted(set(range(text_length)) - set(token_indices))
    for _ in range(turns):
        for each_idx in indices_to_be_filled:
            masked_tokens[each_idx] = mask
            entire_indices = torch.tensor([start] + masked_tokens + [end]).unsqueeze(0).cuda()
            with torch.no_grad():
                entire_logits = model(entire_indices).logits
            each_logit = entire_logits[0][each_idx + 1]
            sample_idx = single_token_sampling(each_logit, top_k, overall)
            masked_tokens[each_idx] = int(sample_idx)
    return tokenizer.convert_ids_to_tokens(masked_tokens)


# Demo: pin 'come' at position 3 and 'on' at position 7 of a randomly sized
# text, then fill the remaining positions with each strategy.
set_seed()
token_indices = [3, 7]
given_tokens = ['come', 'on']
text_length = random.randint(10, 20)

set_seed()
bert_gibbs_text = bert_gibbs(token_indices, given_tokens, text_length, top_k=0, overall=False)
print(given_tokens, token_indices, text_length)
print('bert_gibbs:', bert_gibbs_text)
print('bert_gibbs:', ' '.join(bert_gibbs_text))

set_seed()
bert_only_text = bert_only(token_indices, given_tokens, text_length, top_k=0, overall=False)
print(given_tokens, token_indices, text_length)
print('bert_only:', bert_only_text)
print('bert_only:', ' '.join(bert_only_text))