# in easynlp/appzoo/text2image_generation/model.py
@torch.no_grad()
def encode_to_z(self, x):
    """Encode ``x`` with the first-stage model and flatten its code indices.

    Runs with gradient tracking disabled.

    Args:
        x: Input forwarded unchanged to ``self.first_stage_model.encode``.

    Returns:
        A ``(quant_z, indices)`` pair: the first value returned by the
        encoder, and element 2 of the encoder's third return value viewed
        as ``(quant_z.shape[0], -1)``.
    """
    latents, _, extras = self.first_stage_model.encode(x)
    leading_dim = latents.shape[0]
    # One row of indices per entry along the latents' leading dimension
    # (presumably the batch dimension -- confirm against the encoder).
    flat_codes = extras[2].view(leading_dim, -1)
    return latents, flat_codes
# Reorder the image axes with permute(0, 3, 1, 2) (presumably channels-last
# to channels-first -- confirm against the data loader) and request a
# contiguous memory layout.
x = inputs['image'].permute(0, 3, 1, 2)
x = x.to(memory_format=torch.contiguous_format)
# A single encoder pass; only the flattened code indices are kept.
# Original author's note -- z_indices: torch.Size([batch_size, 256])
_, z_indices = self.encode_to_z(x)
VQModel的Decoding阶段过程为:
# in easynlp/appzoo/text2image_generation/model.py
@torch.no_grad()
def decode_to_img(self, index, zshape):
    """Look up codebook entries for ``index`` and decode them.

    Runs with gradient tracking disabled.

    Args:
        index: Code indices; flattened with ``reshape(-1)`` before lookup.
        zshape: Indexable with at least four entries. Entries 0, 2, 3, 1
            (in that order) form the ``shape`` passed to the codebook lookup.

    Returns:
        Whatever ``self.first_stage_model.decode`` returns for the
        looked-up codebook entries.
    """
    first_stage = self.first_stage_model
    # Move entry 1 of zshape to the end (named ``bhwc`` in the original).
    lookup_shape = (zshape[0], zshape[2], zshape[3], zshape[1])
    flat_index = index.reshape(-1)
    codebook_vectors = first_stage.quantize.get_codebook_entry(
        flat_index, shape=lookup_shape)
    return first_stage.decode(codebook_vectors)
# in easynlp/appzoo/text2image_generation/model.py
def forward(self, inputs):
    """Compute next-token logits for the image codes given the text tokens.

    Args:
        inputs: Mapping with an ``'image'`` tensor (its axes are reordered
            with ``permute(0, 3, 1, 2)`` before encoding) and a ``'text'``
            tensor of conditioning token ids.

    Returns:
        ``(logits, target)`` where ``target`` is the uncorrupted image code
        indices and ``logits`` are the transformer outputs from the last
        conditioning position onward.
    """
    images = inputs['image']
    c_indices = inputs['text']
    images = images.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)

    # A single encoder pass; only the code indices are used.
    # Original author's note -- z_indices: torch.Size([batch_size, 256])
    _, z_indices = self.encode_to_z(images)

    corrupt_inputs = self.training and self.pkeep < 1.0
    if not corrupt_inputs:
        a_indices = z_indices
    else:
        # Keep each code with probability pkeep; replace the rest with
        # random ids drawn below the transformer's vocab_size.
        keep = torch.bernoulli(
            self.pkeep * torch.ones(z_indices.shape, device=z_indices.device))
        keep = keep.round().to(dtype=torch.int64)
        noise = torch.randint_like(z_indices, self.transformer.config.vocab_size)
        a_indices = keep * z_indices + (1 - keep) * noise

    # The target covers every image code; no special case for the first one
    # because the sequence is conditioned on the text tokens.
    target = z_indices

    # Feed everything except the final token.
    sequence = torch.cat((c_indices, a_indices), dim=1)
    logits, _ = self.transformer(sequence[:, :-1])

    # Drop the conditioning outputs: output i corresponds to p(z_i | z_{<i}, c).
    first_image_logit = c_indices.shape[1] - 1
    return logits[:, first_image_logit:], target
# in easynlp/appzoo/text2image_generation/predictor.py
def preprocess(self, in_data): if not in_data: raise RuntimeError("Input data should not be None.")
if not isinstance(in_data, list): in_data = [in_data] rst = {"idx": [], "input_ids": []} max_seq_length = -1 for record in in_data: if "sequence_length" not in record: break max_seq_length = max(max_seq_length, record["sequence_length"]) max_seq_length = self.sequence_length if (max_seq_length == -1) else max_seq_length
for record in in_data: text= record[self.first_sequence] try: self.MUTEX.acquire() text_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) text_ids = text_ids[: self.text_len] n_pad = self.text_len - len(text_ids) text_ids += [self.pad_id] * n_pad text_ids = np.array(text_ids) + self.img_vocab_size
# in easynlp/appzoo/text2image_generation/model.py
def generate(self, inputs, top_k=100, temperature=1.0, steps=256, sample=True,
             cond_len=None):
    """Autoregressively extend ``inputs`` and return only the new tokens.

    Each step runs the transformer on the whole sequence so far, takes the
    logits at the last position, and appends one token per row.

    Args:
        inputs: Conditioning token ids; new tokens are concatenated along
            dimension 1.
        top_k: If not ``None``, the logits are passed through
            ``self.top_k_logits(logits, top_k)`` before the softmax.
        temperature: Divisor applied to the last-position logits.
        steps: Number of tokens to generate. Defaults to 256, the value the
            original hard-coded.
        sample: If true, draw each token with ``torch.multinomial``;
            otherwise take the most likely token.
        cond_len: Number of leading positions to strip from the result.
            Defaults to ``inputs.shape[1]``, which equals the previously
            hard-coded 32 whenever the conditioning is 32 tokens long.

    Returns:
        The generated token ids: the final sequence with the first
        ``cond_len`` columns removed.
    """
    if cond_len is None:
        # Strip exactly the conditioning prefix; the original sliced at a
        # fixed 32, which is only right for 32-token prompts.
        cond_len = inputs.shape[1]

    cidx = inputs
    for _ in range(steps):
        logits, _unused = self.transformer(cidx)
        # Pluck the logits at the final step and scale by temperature.
        logits = logits[:, -1, :] / temperature
        # Optionally crop the distribution to the top-k options.
        if top_k is not None:
            logits = self.top_k_logits(logits, top_k)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # Sample from the distribution or take the most likely token.
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _unused, ix = torch.topk(probs, k=1, dim=-1)
        # Append to the sequence and continue.
        cidx = torch.cat((cidx, ix), dim=1)

    return cidx[:, cond_len:]
Chengyu Wang, Minghui Qiu, Taolin Zhang, Tingting Liu, Lei Li, Jianing Wang, Ming Wang, Jun Huang, Wei Lin. EasyNLP: A Comprehensive and Easy-to-use Toolkit for Natural Language Processing. arXiv
Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, Ilya Sutskever. Zero-Shot Text-to-Image Generation. ICML 2021: 8821-8831
Ming Ding, Zhuoyi Yang, Wenyi Hong, Wendi Zheng, Chang Zhou, Da Yin, Junyang Lin, Xu Zou, Zhou Shao, Hongxia Yang, Jie Tang. CogView: Mastering Text-to-Image Generation via Transformers. NeurIPS 2021: 19822-19835
Peng Wang, An Yang, Rui Men, Junyang Lin, Shuai Bai, Zhikang Li, Jianxin Ma, Chang Zhou, Jingren Zhou, Hongxia Yang. OFA: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework. ICML 2022
Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. Hierarchical Text-Conditional Image Generation with CLIP Latents. arXiv
Van Den Oord A, Vinyals O, Kavukcuoglu K. Neural discrete representation learning. NIPS 2017
Esser P, Rombach R, Ommer B. Taming transformers for high-resolution image synthesis. CVPR 2021: 12873-12883.
Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho, David J. Fleet, Mohammad Norouzi: Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. arXiv
Jiahui Yu, Yuanzhong Xu, Jing Yu Koh, Thang Luong, Gunjan Baid, Zirui Wang, Vijay Vasudevan, Alexander Ku, Yinfei Yang, Burcu Karagol Ayan, Ben Hutchinson, Wei Han, Zarana Parekh, Xin Li, Han Zhang, Jason Baldridge, Yonghui Wu. Scaling Autoregressive Models for Content-Rich Text-to-Image Generation. arXiv