# Tensor Parallelism
# 权重切分
对于矩阵乘法,如下X∗W=Y, 其中:
b
: batch_size
, 批量大小s
: sequence_length
,输入序列长度h
: hidden_size/embedding_size
# 按行切分权重
# forward
# backward
∂W1∂L=∂Y∂L∗∂W1∂Y∂W2∂L=∂Y∂L∗∂W2∂Y
对于反向传播,更新W,只需要将∂Y∂L 更新到两块 GPU 内分别更新W1 和W_
X
的 backward
:∂X∂L=concat[∂X1∂L,∂X2∂L]
# 按列切分权重
# forward
# backward
# 参考文章