Toccata in Nowhere.

mpi4py MPI 并行的 Python 封装

2020.07.04

消息传递接口(Message Passing Interface, MPI)是一种并行计算的API,其支持点对点传播、广播、分散与聚集等在集群中多进程之间的交互操作。得益于Python易用的numpy数组操作,MPI并行的使用与实现相较于 C / C++ 简单而易于理解。

依赖

系统依赖:OpenMPI

计算机环境中的 mpirun支持,即需要安装open MPI。 对于 macOS,可使用 homebrew直接安装:

brew install open-mpi

Python 依赖:mpi4py

安装:

pip install mpi4py

使用:

from mpi4py import MPI                                                     
comm = MPI.COMM_WORLD                                                      
rank = comm.Get_rank()                                                     
size = comm.Get_size()

以上方法可直接获取该程序所在进程的 rank 与并行总量 size

运行方法

与 C / C++的程序相同,需要使用 mpirun指定并行数。

mpirun -np 4 python file_name.py

以上例中使用 4 个进程进行 MPI 并行,可并行进程数目与硬件相关。

使用例

点对点通信

阻塞通信 Send / Recv

comm = MPI.COMM_WORLD                                                  
rank = comm.Get_rank()                                                 
size = comm.Get_size()                                                 
                                                                       
if rank == 0:                                                          
    data = np.arange(10, dtype='i')                                    
    comm.Send([data, MPI.INT], dest=1, tag=11)                         
    print(str(rank) + " send: " + str(data))
else:                                                                  
    data = np.empty(10, dtype='i')                                     
    comm.Recv([data, MPI.INT], source=0, tag=11)                       
    print(str(rank) + " recv: " + str(data))

以上例子中由 rank 0rank 1 发送了一条消息,包含有一个长度为 10 的 intnumpy 数组。

组内通信(多对一,多对多,一对多)

广播 bcast

一对多通信,以下例子将 rank 0 中数据分发100%复制到其他所有 rank 中:

comm = MPI.COMM_WORLD                                                      
rank = comm.Get_rank()                                                     
size = comm.Get_size()                                                     
                                                                           
if rank == 0:                                                              
    data = range(10)                                                       
    print(str(rank) + " broadcast: " + str(data))
else:                                                                      
    data = None                                                            
data = comm.bcast(data, root=0)                                            
print(str(rank) + " recv: " + str(data))

发散 scatter

相比于bcast, scatter 将数据分块后复制到各个rank中。


from mpi4py import MPI                                                            
import numpy as np      

comm = MPI.COMM_WORLD                                                             
rank = comm.Get_rank()                                                            
size = comm.Get_size()                                                            
                                                                                  
recv_data = None                                                                  
                                                                                  
if rank == 0:                                                                     
    send_data = range(4)                                                         
    print(str(rank) + " scatter: " + str(send_data))
else:                                                                             
    send_data = None                                                              
recv_data = comm.scatter(send_data, root=0)                                       
print(str(rank) + " recv: " + str(recv_data))

集合 gather

scatter 相反, gather 集合所有 rank 的数据到 root进程中:

comm = MPI.COMM_WORLD                                               
rank = comm.Get_rank()                                              
size = comm.Get_size()                                              
                                                                    
send_data = rank                                                    
print(str(rank) + " send gather data: " + str(send_data))
recv_data = comm.gather(send_data, root=0)                          
if rank == 0:                                                       
    print(str(rank) + " recv gather data: " + str(recv_data))

值得注意的是,在使用 gather 时,numpy 数组的聚合会新增一个维度用以区分来自不同 rank的数据,相比于 C / C++ 的叠放更为灵活。