Step 1. Build the dominator tree
For a node A, its dominator S is simply the LCA (lowest common ancestor) of all of A's output (consumer) nodes. Note that TVM builds this as a post-dominator tree, processing nodes in reverse topological order so that a node's consumers are handled before the node itself.
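A minimal sketch of the idea (loosely modeled on TVM's `DominatorTree` in `fuse_ops.cc`; the graph representation and all names here are my own, not TVM's): walking in reverse topological order, each node's immediate post-dominator is the LCA of its consumers in the partially built tree.

```cpp
#include <cassert>
#include <vector>

// Hypothetical minimal DAG: every edge goes from a smaller index to a
// larger one, so index order is already a topological order.
struct Graph {
  std::vector<std::vector<int>> outputs;  // outputs[i] = consumers of node i
};

// Returns, for each node, its immediate post-dominator (-1 for the sink),
// computed as the LCA of all its consumers in the partially built tree.
std::vector<int> BuildPostDomTree(const Graph& g) {
  int n = static_cast<int>(g.outputs.size());
  std::vector<int> idom(n, -1), depth(n, 0);
  // Reverse topological order: consumers are processed before producers.
  for (int i = n - 1; i >= 0; --i) {
    int lca = -1;
    for (int out : g.outputs[i]) {
      if (lca == -1) { lca = out; continue; }
      int a = lca, b = out;
      // Brute-force LCA: repeatedly lift whichever node is deeper.
      while (a != b) {
        if (depth[a] >= depth[b]) a = idom[a]; else b = idom[b];
      }
      lca = a;
    }
    idom[i] = lca;
    depth[i] = (lca == -1) ? 0 : depth[lca] + 1;
  }
  return idom;
}
```

For the diamond graph 0 → {1, 2} → 3, node 0's post-dominator is node 3, which is exactly why an elemwise op feeding two branches can only be fused up to the point where the branches rejoin.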
Dominator algorithm
The algorithm TVM uses:
LCA(a,b)
Here, too, TVM takes the most brute-force approach: the two nodes climb upward (always lifting the deeper one) until they meet. More efficient algorithms exist, e.g. binary lifting.
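The brute-force climb can be sketched as follows (a simplified illustration, not TVM's actual code; the `parent` and `depth` arrays are assumed precomputed for the tree):

```cpp
#include <cassert>
#include <vector>

// Brute-force LCA: repeatedly lift whichever node is deeper until the two
// pointers meet. O(depth) per query; binary lifting would give O(log n).
int LCA(const std::vector<int>& parent, const std::vector<int>& depth, int a, int b) {
  while (a != b) {
    if (depth[a] >= depth[b]) a = parent[a];
    else b = parent[b];
  }
  return a;
}
```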
Step 2. Fuse
Op classification
```cpp
enum OpPatternKind {
  // Elementwise operation
  kElemWise = 0,
  // Broadcasting operator, can always map output axis to the input in order.
  // for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
  // Note that the axis need to be in order so transpose is not a bcast operator.
  kBroadcast = 1,
  // Injective operator, can always injectively map output axis to a single input axis.
  // All injective operator can still be safely fused to injective and reduction.
  kInjective = 2,
  // Communicative reduction operator.
  kCommReduce = 3,
  // Complex operation, can still fuse elemwise operations into its output.
  // but cannot chain another complex op
  kOutEWiseFusable = 4,
  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
  // but treated specially
  kTuple = 7,
  // Opaque operation, cannot fuse anything.
  kOpaque = 8
};
```
Some example ops for each category:
kElemWise
```cpp
RELAY_REGISTER_UNARY_OP("log")
    .describe(R"code(Returns the log input array, computed element-wise.
)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));

RELAY_REGISTER_UNARY_OP("log2")
    .describe(R"code(Returns the log to base 2 of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2));

RELAY_REGISTER_UNARY_OP("log10")
    .describe(R"code(Returns the log to base 10 of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10));

RELAY_REGISTER_UNARY_OP("tan")
    .describe(R"code(Returns the tan of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan));
```
kBroadcast
```cpp
// Addition
RELAY_REGISTER_BINARY_OP("add")
    .describe("Elementwise add with broadcasting")
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));

// Subtraction
RELAY_REGISTER_BINARY_OP("subtract")
    .describe("Elementwise substract with broadcasting")
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));

// Right shift
RELAY_REGISTER_BINARY_OP("right_shift")
    .describe("Elementwise right shift with broadcasting")
    .set_support_level(4)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
```
In fact, when the input shapes match, the kBroadcast ops above behave exactly like elementwise ops; they are classified as kBroadcast because they additionally support broadcasting.
kInjective
```cpp
RELAY_REGISTER_OP("image.resize2d")
    .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation.

RELAY_REGISTER_OP("concatenate")
    .describe(R"code(Concatenate the input tensors along the given axis.

RELAY_REGISTER_OP("transpose")
    .describe(R"code(Permutes the dimensions of an array.

RELAY_REGISTER_OP("reshape")
    .describe(R"code(Reshapes the input array.
```
My understanding: each output element maps injectively back to a single input element. Unlike broadcast, injective does not require the output axes to keep the same order as the input axes, which is why ops like transpose and reshape qualify.
kCommReduce
```cpp
RELAY_REGISTER_REDUCE_OP("sum")
    .describe(R"code(Computes the sum of array elements over given axes.

RELAY_REGISTER_REDUCE_OP("max")
    .describe(R"code(Computes the max of array elements over given axes.

RELAY_REGISTER_REDUCE_OP("min")
    .describe(R"code(Computes the min of array elements over given axes.
```
kOutEWiseFusable
```cpp
{"nn.matmul", OpPatternKind::kOutEWiseFusable},
{"nn.conv1d", OpPatternKind::kOutEWiseFusable},
{"nn.conv2d", OpPatternKind::kOutEWiseFusable},
{"nn.max_pool3d", OpPatternKind::kOutEWiseFusable},
```
kTuple / kOpaque
```cpp
RELAY_REGISTER_OP("shape_of")
    .describe(R"code(Returns a tensor representing the shape of a tensor.
```
Combine pattern
```cpp
// Combine two patterns together.
static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
  if (lhs > kBroadcast && rhs > kBroadcast) {
    LOG(FATAL) << "Cannot merge two complex group together";
  }
  if (lhs > rhs) return lhs;
  return rhs;
}
```
- Note that two ops whose patterns are both higher than kBroadcast cannot be merged into one group.
- When two ops are fused, the resulting group takes the higher of the two patterns.
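A standalone version makes the two rules easy to test (a sketch: the enum is copied from above, and `LOG(FATAL)` is replaced with a `kOpaque` sentinel return so it compiles without TVM headers — the real function aborts there instead):

```cpp
#include <cassert>

enum OpPatternKind {
  kElemWise = 0, kBroadcast = 1, kInjective = 2, kCommReduce = 3,
  kOutEWiseFusable = 4, kTuple = 7, kOpaque = 8
};

// Combine two patterns: two "complex" patterns (above kBroadcast) cannot
// merge; otherwise the fused group takes the higher pattern.
OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
  if (lhs > kBroadcast && rhs > kBroadcast) {
    return kOpaque;  // sentinel for "cannot merge"; TVM calls LOG(FATAL) here
  }
  return lhs > rhs ? lhs : rhs;
}
```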
RunFuse
The core function is implemented in:
`tvm::relay::GraphPartitioner::RunFuse`
Fusion here is split into 3 phases in total.
Pattern
- kOutEWiseFusable (phase 0)
- kElemWise / kBroadcast (phase 0, 1, 2)
- kInjective / kTuple (phase 1)
I think the main reason for splitting into 3 phases is to control the order in which the different patterns get fused; in particular, fusion for injective ops is deliberately deferred to a later phase.
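The phase gating can be sketched as a small predicate (my own simplification of the dispatch inside `GraphPartitioner::RunFuse`, with all dominator and path checks omitted; the function name is mine):

```cpp
#include <cassert>

enum OpPatternKind {
  kElemWise = 0, kBroadcast = 1, kInjective = 2, kCommReduce = 3,
  kOutEWiseFusable = 4, kTuple = 7, kOpaque = 8
};

// Sketch: in which phase(s) does RunFuse even attempt to fuse a node of
// a given pattern? (The real code then also checks the dominator and the
// paths to it before committing the fusion.)
bool TriedInPhase(OpPatternKind pattern, int phase) {
  if (pattern == kOutEWiseFusable) return phase == 0;   // conv2d-like ops first
  if (pattern <= kBroadcast) return true;               // elemwise/broadcast: every phase
  if (pattern == kInjective || pattern == kTuple) return phase == 1;  // deferred
  return false;                                         // kCommReduce/kOpaque never initiate
}
```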
Fuse logic
Because fusion is carried out on a tree (the dominator tree), there are many possible choices of which node(s) to fuse a given node with, and at the same time cycles must be avoided.
The patterns above can be summarized as:

start node (+ intermediate nodes) + end node

TVM requires the end node to be the dominator of the start node.
- Check whether the current node matches the pattern's start node.
- Check whether the current node's dominator matches the pattern's end node.
- Check whether every node on every path from the start node to the end node satisfies the intermediate-node requirement.
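The third check, in particular, can be sketched as a DFS (loosely modeled on `GraphPartitioner::CheckPath`; the graph representation and names are mine, and the real implementation memoizes visited nodes and passes each node's pattern to the predicate):

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Verify that every node strictly between `src` and `sink` (on every
// path) satisfies `fcond`. Simplified: no memoization, predicate on ids.
bool CheckPath(const std::vector<std::vector<int>>& outputs, int src, int sink,
               const std::function<bool(int)>& fcond) {
  std::function<bool(int)> visit = [&](int node) -> bool {
    if (node == sink) return true;   // reached the end node: this path is fine
    if (!fcond(node)) return false;  // an intermediate node violates the pattern
    for (int out : outputs[node])
      if (!visit(out)) return false;
    return true;
  };
  for (int out : outputs[src])
    if (!visit(out)) return false;
  return true;
}
```

On the diamond 0 → {1, 2} → 3, fusing 0 into its dominator 3 requires both intermediate nodes 1 and 2 to pass the predicate; if either fails, the fusion is rejected.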
Codegen for fused op
To be done