OpenAI魔法模型DALL-E论文、代码公布( 二 )


基于以上问题 , 我们借用Oord和Razavi在2017和2019年的工作:两阶段训练法 , 进行尝试解决 。
阶段1:训练一个离散变分自动编码器(DVAE) , 将每个256×256RGB图像压缩成一个32×32的图像token网络 , 每个网格的每个元素可以取8192个可能的值 。 这一阶段会让transformer的上下文尺寸(contextsize)减少192倍 , 同时还不会大幅降低“视觉”质量 。
阶段2:将256个BPE编码的文本token与32×32=1024图片tokens连接起来 , 然后训练一个自回归transformer对文本和图像的联合分布进行建模 。

OpenAI魔法模型DALL-E论文、代码公布
文章图片
图1:原始图像(上图)和离散VAE重建图像(下图)的比较

OpenAI魔法模型DALL-E论文、代码公布
文章图片
建模公式如上图所示 , 整体可以看成联合分布的似然函数 , x代表图像 , y代表图像的标题 , z代表token , 使用因式分解p_θψ(x , y , z)=p_θ(x|y , z)pψ(y , z)对该分布进行建模 , 得到下界 。 其中:
q_φ表示在给定RGB图像x2的情况下 , 由DVAE编码器生成的32×32图像token上的分布
p_θ表示由DVAE解码器在给定图像token的情况下生成的RGB图像上的分布
p_ψ表示文本和图像token在transformer建模中得到的联合分布 。
值得一提的是 , 这个界(bound)只在β=1时成立 , 实际上 , 使用更大的β值非常有好处 。
阶段1:学习视觉编码
在阶段1的训练中 , 针对φ和θ最大化ELB(evidencelowerbound) , 这相当于在图像上训练DVAE 。 一开始将p_ψ设置为K=8192个向量上的均匀分类分布 , q_φ为编码器输出的32×32网格中同一空间位置上的8192个logits参数化的分类分布 。
但ELB难以优化:因为q_ψ是一个离散分布 , 不能使用重参数化技巧进行最大化 。 有人使用在线聚类分配程序加上直通估计器来解决这个问题 。 我们还使用Gumbel-Softmax技巧转换q_φ 。 此外 , 条件放松的ELB使用Adam与指数加权迭代平均法进行最大化 。 其中 , 在编码器的末端和解码器的开始使用1×1卷积;将编码器和解码器重分块的输出激活乘以一个小常量等是非常重要的技巧和参数 。
阶段2:学习先验
在第二阶段 , 修正了φ和θ , 并通过最大化关于ψ的ELB来学习文本和图像token的先验分布 , 其中p_ψ由含有120亿个参数的稀疏transformer进行表示 。
给定一个文本-图像对 , 最多使用256个词汇大小(vocabularysize)为16384的tokens对小写标题进行BPE编码 , 并使用32×32=1024个词汇大小为8192的tokens对图像进行编码 。 图像token是通过使用argmax采样从DVAE编码器获得的 , 没有添加任何gumbel噪声 。 最后 , 文本和图像token进行连接 , 并作为一个单一的数据流进行自回归建模 。
我们通过一堆数据中各个种类的总数 , 对文本-图像token的交叉熵损失进行了归一化 。 因为我们主要对图像建模感兴趣 , 因此我们将文本的交叉熵损失乘以1/8 , 将图像的交叉熵损失乘以7/8 。 目标则通过使用Adam算法 , 以指数加权的迭代平均法进行了优化 。 我们大概用了606,000张图像用于验证 , 但在收敛时没有发现过度拟合现象 。
数据收集
我们在一个包含330万个文本-图像对的数据集ConceptualCaptions上对模型进行了高达12亿参数的初步实验 。
为了扩展到120亿个参数 , 我们从互联网上收集了2.5亿个文本-图像对 , 创建了一个与JFT-300M规模相当的数据集 。 该数据集不包括MS-COCO , 但包含了ConceptualCaptions数据集和YFCC100M的一个过滤子集 。 由于MS-COCO是基于YFCC100M创建的 , 我们的训练数据还包含了一部分MS-COCO验证图像(但没有caption部分) 。
混合精度训练
为了节省GPU内存并提高吞吐量 , 大多数参数、Adam矩和激活都以16位精度存储 。 我们还使用激活checkpointing , 并在向后传递期间重新计算resblock中的激活 。 我们还使模型以16位精度对10亿个参数进行无差异训练 , 这是该项目最具挑战性的部分 。
分布式优化
当以16位精度存储时 , 我们的120亿参数模型需要消耗约24GB的显存 , 这超过了NVIDIAV10016GB的显存 。 我们使用参数分片(parametersharding)来解决这个问题 。
如图5所示 , 参数分片允许我们通过将其与计算密集型操作重叠 , 从而几乎可以完全忽略机器内通信的延迟 。

OpenAI魔法模型DALL-E论文、代码公布
文章图片
图5:用于分布式训练的通信模式 。
样本生成
我们使用预训练的对比模型(Radfordetal.,2021)对从transformer提取的样本进行重新排序 。 给定字幕和候选图像后 , 对比模型会根据图像与字幕的匹配程度来分配分数 。 图6显示了增加样本数量N的效果 , 我们从中选择了前k个图像 。 这个过程可以看作是一种语言指导的搜索(Andreasetal.,2017) , 也类似于辅助文本-图像匹配损失(Xuetal.,2018) 。