TVM Fuse

AI summary
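TVM Fuse (the FuseOps pass) is a graph optimization that fuses nodes by building a dominator tree and applying a set of fusion patterns. The dominator tree defines the dominance relation between nodes, and the patterns define which kinds of nodes may be fused. The fusion logic selects and fuses nodes according to each pattern's requirements, improving the performance of the computation graph. Fusion runs in three phases, and each phase selects and fuses nodes using different patterns. Fusing several nodes into one reduces memory traffic for intermediate results and per-kernel overhead.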

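Step 1. Build the dominator tree

Strictly speaking, TVM builds a post-dominator tree. The immediate post-dominator S of a node A is the LCA (lowest common ancestor), in that tree, of all nodes that consume A's output. Nodes are visited in reverse topological order, so every consumer of A is already in the tree when A is processed.

Dominator algorithm

The algorithm TVM uses:

LCA(a,b)

TVM again takes the simplest brute-force approach. The two nodes keep jumping up toward the root until they meet: the deeper node moves first, and both move when their depths are equal. More efficient LCA algorithms exist, such as binary lifting.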

Step 2. Fuse

Op classification

```cpp
enum OpPatternKind {
  // Elementwise operation
  kElemWise = 0,
  // Broadcasting operator, can always map output axis to the input in order.
  // for example :code:`out[i, ax1, j, ax2] = input[i, j]`.
  // Note that the axis need to be in order so transpose is not a bcast operator.
  kBroadcast = 1,
  // Injective operator, can always injectively map output axis to a single input axis.
  // All injective operator can still be safely fused to injective and reduction.
  kInjective = 2,
  // Commutative reduction operator.
  kCommReduce = 3,
  // Complex operation, can still fuse elemwise operations into its output.
  // but cannot chain another complex op
  kOutEWiseFusable = 4,
  // The pattern for tuple nodes. Can fuse into subsequent injective ops,
  // but treated specially
  kTuple = 7,
  // Opaque operation, cannot fuse anything.
  kOpaque = 8
};
```
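The enum values form an ordered scale, and the fusion code compares them with < and >. A larger value roughly means "harder to fuse". The sketch below is a hypothetical lookup table that assigns the example ops from this note to their kinds; LookupPattern is not a TVM function. In TVM, each op's kind is attached through its "TOpPattern" attribute instead. Treating unknown ops as kOpaque is an assumption of this sketch.

```cpp
#include <cassert>
#include <map>
#include <string>

// Same values as the enum above.
enum OpPatternKind {
  kElemWise = 0,
  kBroadcast = 1,
  kInjective = 2,
  kCommReduce = 3,
  kOutEWiseFusable = 4,
  kTuple = 7,
  kOpaque = 8
};

// Hypothetical table built from the example ops listed in this note.
OpPatternKind LookupPattern(const std::string& op) {
  static const std::map<std::string, OpPatternKind> table = {
      {"log", kElemWise},        {"tan", kElemWise},
      {"add", kBroadcast},       {"subtract", kBroadcast},
      {"transpose", kInjective}, {"reshape", kInjective},
      {"sum", kCommReduce},      {"max", kCommReduce},
      {"nn.conv2d", kOutEWiseFusable},
      {"shape_of", kOpaque},
  };
  auto it = table.find(op);
  // In this sketch, anything not in the table is conservatively opaque (never fused).
  return it == table.end() ? kOpaque : it->second;
}
```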
Some example ops of each kind:

kElemWise

```cpp
RELAY_REGISTER_UNARY_OP("log")
    .describe(R"code(Returns the log input array, computed element-wise.
)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));

RELAY_REGISTER_UNARY_OP("log2")
    .describe(R"code(Returns the log to base 2 of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2));

RELAY_REGISTER_UNARY_OP("log10")
    .describe(R"code(Returns the log to base 10 of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10));

RELAY_REGISTER_UNARY_OP("tan")
    .describe(R"code(Returns the tan of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan));
```

kBroadcast

```cpp
// Addition
RELAY_REGISTER_BINARY_OP("add")
    .describe("Elementwise add with broadcasting")
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add));

// Subtraction
RELAY_REGISTER_BINARY_OP("subtract")
    .describe("Elementwise substract with broadcasting")
    .set_support_level(1)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract));

// Right shift
RELAY_REGISTER_BINARY_OP("right_shift")
    .describe("Elementwise right shift with broadcasting")
    .set_support_level(4)
    .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift));
```
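The kBroadcast ops above are in fact element-wise as well. The difference is that their inputs may be broadcast to a common output shape.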

kInjective

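```cpp
// Excerpts: each registration is cut after the first line of its description.
RELAY_REGISTER_OP("image.resize2d")
    .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. ...)code" TVM_ADD_FILELINE);

RELAY_REGISTER_OP("concatenate")
    .describe(R"code(Concatenate the input tensors along the given axis. ...)code" TVM_ADD_FILELINE);

RELAY_REGISTER_OP("transpose")
    .describe(R"code(Permutes the dimensions of an array. ...)code" TVM_ADD_FILELINE);

RELAY_REGISTER_OP("reshape")
    .describe(R"code(Reshapes the input array. ...)code" TVM_ADD_FILELINE);
```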
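As I understand it, each axis of the output maps to some axis of the input. Unlike broadcast, injective does not require the output axes to appear in the same order as the input axes. Transpose, for example, permutes them.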

kCommReduce

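```cpp
// Excerpts: each registration is cut after the first line of its description.
RELAY_REGISTER_REDUCE_OP("sum")
    .describe(R"code(Computes the sum of array elements over given axes. ...)code" TVM_ADD_FILELINE);

RELAY_REGISTER_REDUCE_OP("max")
    .describe(R"code(Computes the max of array elements over given axes. ...)code" TVM_ADD_FILELINE);

RELAY_REGISTER_REDUCE_OP("min")
    .describe(R"code(Computes the min of array elements over given axes. ...)code" TVM_ADD_FILELINE);
```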
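Reductions get their own kind partly because element-wise and broadcast producers can still be fused into a reduction's input loop. The hypothetical functions below compute sum(exp(x)) two ways. The unfused version uses two passes and a temporary buffer. The fused version evaluates exp inside the reduction loop, so the intermediate tensor is never materialized. This only illustrates the idea; it is not code that TVM generates.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Unfused: exp (kElemWise) writes a temporary tensor, then sum (kCommReduce) reads it back.
float SumExpUnfused(const std::vector<float>& x) {
  std::vector<float> tmp(x.size());
  for (size_t i = 0; i < x.size(); ++i) tmp[i] = std::exp(x[i]);
  float acc = 0.0f;
  for (float v : tmp) acc += v;
  return acc;
}

// Fused: the element-wise op is evaluated inside the reduction loop,
// so no intermediate buffer is allocated.
float SumExpFused(const std::vector<float>& x) {
  float acc = 0.0f;
  for (float v : x) acc += std::exp(v);
  return acc;
}
```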

kOutEWiseFusable

```cpp
{"nn.matmul", OpPatternKind::kOutEWiseFusable},
{"nn.conv1d", OpPatternKind::kOutEWiseFusable},
{"nn.conv2d", OpPatternKind::kOutEWiseFusable},
{"nn.max_pool3d", OpPatternKind::kOutEWiseFusable},
```
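The enum comment says a kOutEWiseFusable op "can still fuse elemwise operations into its output". The hedged illustration below uses MatmulBiasRelu, a hypothetical function rather than TVM output. A naive matmul applies a bias add and a ReLU to each accumulated value just before storing it, instead of making two extra passes over the result. What this pattern rules out is chaining a second complex op, such as another matmul, into the same loop nest.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Row-major C = A (m x k) * B (k x n), with bias add and ReLU fused into the
// output write. Each element-wise op runs on the accumulator right before the store.
std::vector<float> MatmulBiasRelu(const std::vector<float>& A, const std::vector<float>& B,
                                  const std::vector<float>& bias, int m, int k, int n) {
  std::vector<float> C(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      // Fused epilogue: bias add (broadcast over rows) followed by ReLU (element-wise).
      C[i * n + j] = std::max(acc + bias[j], 0.0f);
    }
  }
  return C;
}
```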

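kTuple, kOpaque

shape_of, for example, is registered as kOpaque:

```cpp
// Excerpt: the registration is cut after the first line of its description.
RELAY_REGISTER_OP("shape_of")
    .describe(R"code(Returns a tensor representing the shape of a tensor. ...)code" TVM_ADD_FILELINE);
```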

Combine pattern

```cpp
// Combine two patterns together.
static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
  if (lhs > kBroadcast && rhs > kBroadcast) {
    LOG(FATAL) << "Cannot merge two complex group together";
  }
  if (lhs > rhs) return lhs;
  return rhs;
}
```
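  • This function refuses to combine two ops whose patterns are both above kBroadcast.
  • After fusion, the pattern of the combined op is the higher (more complex) of the two.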
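To exercise the rule outside TVM, the sketch below repeats the CombinePattern logic but throws an exception where TVM calls LOG(FATAL), so the rejected case can be tested. CanCombine is a hypothetical helper added only for that purpose. The sketch covers just the function as written above, not every merge path in TVM.

```cpp
#include <cassert>
#include <stdexcept>

enum OpPatternKind { kElemWise = 0, kBroadcast = 1, kInjective = 2, kCommReduce = 3,
                     kOutEWiseFusable = 4, kTuple = 7, kOpaque = 8 };

// Same rule as CombinePattern above, but throws instead of LOG(FATAL).
OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
  if (lhs > kBroadcast && rhs > kBroadcast) {
    throw std::runtime_error("Cannot merge two complex group together");
  }
  return lhs > rhs ? lhs : rhs;  // the more complex pattern wins
}

// Hypothetical helper: true if CombinePattern accepts the pair.
bool CanCombine(OpPatternKind lhs, OpPatternKind rhs) {
  try {
    (void)CombinePattern(lhs, rhs);
    return true;
  } catch (const std::runtime_error&) {
    return false;
  }
}
```

For example, combining kOutEWiseFusable with kElemWise yields kOutEWiseFusable. Combining kInjective with kCommReduce is rejected, because both are above kBroadcast.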

RunFuse

The core implementation is tvm::relay::GraphPartitioner::RunFuse.
Fusion runs in 3 phases in total.

Pattern

  1. (phase 0)
  2. (phase 0,1,2)
  3. (phase 1)
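I believe the main reason for splitting fusion into 3 phases is to control the order in which the patterns are applied. For example, a comment in TVM's source says injective fusion is deferred to a later phase so that conv2d (kOutEWiseFusable) fusion always finishes first.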
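The sketch below simplifies the per-node decision inside the phase loop. It is based on my reading of RunFuse in recent TVM versions, and the details differ between versions. TryFuseInPhase is a hypothetical name. dom_edge stands for the combined pattern recorded between a node and its immediate post-dominator. The sketch deliberately leaves out the path check, the maximum fuse depth, and phase 2, which fuses injective-or-cheaper ops into tuples.

```cpp
#include <cassert>

enum OpPatternKind { kElemWise = 0, kBroadcast = 1, kInjective = 2, kCommReduce = 3,
                     kOutEWiseFusable = 4, kTuple = 7, kOpaque = 8 };

// Simplified sketch for phases 0 and 1: should a node of pattern `node` try to fuse
// into its immediate post-dominator, given the combined pattern `dom_edge`?
bool TryFuseInPhase(int phase, OpPatternKind node, OpPatternKind dom_edge) {
  switch (node) {
    case kOutEWiseFusable:
      // conv2d-like anchors absorb an element-wise tail, in phase 0 only.
      return phase == 0 && dom_edge == kElemWise;
    case kElemWise:
    case kBroadcast:
      // Cheap ops fuse forward into injective or reduction consumers (phases 0 and 1).
      return dom_edge <= kInjective || dom_edge == kCommReduce;
    case kInjective:
    case kTuple:
      // Deferred to phase 1 so kOutEWiseFusable groups finish fusing first.
      return phase == 1;
    default:
      // kCommReduce does not start a fusion here; kOpaque never fuses.
      return false;
  }
}
```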

Fuse logic

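Fusion decisions are guided by the post-dominator tree, but they are applied to the original DAG. There are therefore many possible choices of which node(s) to fuse with, and the result must not create a cycle.
All the patterns above can be summarized as:
start node (+ intermediate nodes) + end node
TVM requires the end node to be the (immediate) post-dominator of the start node.
  1. Check whether the current node matches the pattern's start node.
  2. Check whether the current node's post-dominator matches the pattern's end node.
  3. Check whether every node on every path from the start node to the end node meets the requirements for intermediate nodes.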
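As I understand it, step 3 corresponds to CheckPath in GraphPartitioner. Below is a minimal sketch on a toy graph. The Node struct and the functions are illustrative, and real TVM checks the pattern of each node's group root rather than the node's own pattern. The sketch walks every consumer path from the start node and stops at the end node. Because the end node post-dominates the start node, every path from the start node reaches it. If all intermediate nodes pass and the whole set is fused, no path can leave the fused group and re-enter it. That is how I understand the cycle problem is avoided.

```cpp
#include <cassert>
#include <functional>
#include <set>
#include <vector>

enum OpPatternKind { kElemWise = 0, kBroadcast = 1, kInjective = 2, kCommReduce = 3,
                     kOutEWiseFusable = 4, kTuple = 7, kOpaque = 8 };

struct Node {
  OpPatternKind pattern;
  std::vector<int> outputs;  // consumer node indices
};

// fcond(pattern, is_sink) decides whether a node on the path is acceptable.
using Cond = std::function<bool(OpPatternKind, bool)>;

// Depth-first walk over consumers, stopping at the sink.
static bool CheckPathImpl(const std::vector<Node>& g, int node, int sink, const Cond& fcond,
                          std::set<int>& visited) {
  if (!visited.insert(node).second) return true;  // already checked via another path
  if (!fcond(g[node].pattern, node == sink)) return false;
  if (node == sink) return true;  // do not walk past the sink
  for (int out : g[node].outputs) {
    if (!CheckPathImpl(g, out, sink, fcond, visited)) return false;
  }
  return true;
}

// Checks every node on all paths from src (exclusive) to sink (inclusive).
bool CheckPath(const std::vector<Node>& g, int src, int sink, const Cond& fcond) {
  std::set<int> visited;
  for (int out : g[src].outputs) {
    if (!CheckPathImpl(g, out, sink, fcond, visited)) return false;
  }
  return true;
}
```

Suppose every intermediate op must be kBroadcast or cheaper, a condition similar to the one TVM uses for kOutEWiseFusable in phase 0. A conv2d whose two branches are add and exp, rejoining at an add, passes the check. Replacing exp with a kInjective op such as transpose makes it fail.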

Codegen for fused op

To be done