TensorFlow在美团推荐系统中的分布式训练优化实践( 六 )



TensorFlow在美团推荐系统中的分布式训练优化实践

文章插图
图15 Embedding流水线模块交互关系
两张子图的交互关系为:EG向MG传递Embeding向量(从MG的视角看,是从一个稠密Variable读取数值);MG向EG传递Embedding参数对应的梯度 。上述两个过程的表达都是TensorFlow的计算图,我们利用两个线程,两个Session并发的执行两张计算图,使得两个阶段Overlap起来,以此到达了更大的训练吞吐 。

TensorFlow在美团推荐系统中的分布式训练优化实践

文章插图
图16 Embedding流水线架构流程图
上图是Embedding流水线的架构流程图 。直观来看分为左侧的样本分发模块,顶部的跨Session数据交换模块,以及自动图切分得到的Embedding Graph和Main Graph,蓝色的圆圈代表新增算子,橙色箭头代表EG重点流程,蓝色箭头代表MG重点流程,红色箭头代表样本数据重点流程 。
以对用户透明的形式引入了一层名为Pipeline Dataset的抽象层,这一层的产生是为了满足EG/MG两张计算图以不同节奏运行的需求,支持自定义配置 。另外,为了使得整个流水线中的数据做到彼此的配套,这里还会负责进行一个全局Batch ID的生成及注册工作 。Pipeline Dataset对外暴露两种Iterator,一个供EG使用,一个供MG使用 。Pipeline Dataset底部共享TensorFlow原生的各层Dataset 。顶部的ExchangeManager是一个静态的,跨Session的数据交换媒介,对外暴露数据注册和数据拉取的能力 。抽象这个模块的原因是,EG和MG原本归属于一张计算图,因为流水线的原因拆解为拆为两张图,这样我们需要建立一种跨Session的数据交换机制,并准确进行配套 。它内部以全局Batch ID做Key,后面管理了样本数据、Embeding向量、Embedding梯度、Unique后的Index等数据,并负责这些数据的生命周期管理 。中间的Embedding Graph由独立的TF Session运行于一个独立的线程中,通过a算子获得样本数据后,进行特征ID的抽取等动作,并进行基于HashTable方法的稀疏参数查询,查询结果通过c算子放置到ExchangeManager中 。EG中还包含用于反向更新的f算子,它会从ExchangeManager中获取Embedding梯度和与其配套的前向参数,然后执行梯度更新参数逻辑 。下面的Main Graph负责实际稠密子网络的计算,我们继承并实现一种可训练的EmbeddingVariable,它的构建过程(d算子)会从ExchangeManager查找与自己配套的Embedding向量封装成EmbeddingVariable,给稠密子网络 。此外,在EmbeddingVariable注册的反向方法中,我们添加了e算子使得Embedding梯度得以添加到ExchangeManager中,供EG中的f算子消费 。
通过上面的设计,我们就搭建起了一套可控的EG/MG并发流水线训练模式 。总体来看,Embedding流水线训练模式的收益来源有:
经过我们对多个业务模型的Profiling分析发现,EG和MG在时间的比例上在3:7或4:6的左右,通过将这两个阶段并行起来,可以有效的隐藏Embedding阶段,使得MG网络计算部分几乎总是可以立即开始,大大加速了整体模型的训练吞吐 。TensorFlow引擎中当使用多个优化器(稀疏与非稀疏)的时候,会出现重复构建反向计算图的问题,一定程度增加了额外计算,通过两张子图的拆分,恰好避免了这个问题 。在实施过程中的ExchangeManager不仅负责了Embedding参数和梯度的交换,还承担了元数据复用管理的职责 。例如Unique等算子的结果保存,进一步降低了重复计算 。
另外,在API设计上,我们做到了对用户透明,仅需一行代码即可开启Embedding流水线功能,对用户隐藏了EG/MG的切割过程 。目前,在美团某业务训练中,Embedding流水线功能在CPU PS架构下可以带来20%~60%的性能提升(而且Worker并发规模越大,性能越好) 。
3.5 单实例PS并发优化
经过2.2章节的分析可知,我们不能通过持续扩PS来提升分布式任务的吞吐,单实例PS的并发优化,也是非常重要的优化方向 。我们主要的优化工作如下 。
3.5.1 高性能的HashTable
PS架构下,大规模稀疏模型训练对于HashTable的并发读写要求很高,因为每个PS都要承担成百乃至上千个Worker的Embedding压力,这里我们综合速度和稳定性考虑,选用了tbb::concurrent_hash_map[10]作为底层HashTable表实现,并将其包装成一个新的TBBConcurrentHashTable算子 。经过测试,在千亿规模下TBBConcurrentHashTable比原生MutableDenseHashTable训练速度上快了3倍 。
3.5.2 HashTable BucketPool
对于大规模稀疏模型训练来说,Embedding HashTable会面对大量的并发操作,通过Profiling我们发现,频繁动态的内存申请会带来了较大性能开销(即使TensorFlow的Tensor有专门的内存分配器) 。我们基于内存池化的思路优化了HashTable的内存管理 。