MLIR CSE (Common Subexpression Elimination)

How the CSE pass that ships with MLIR's built-in passes is implemented.

Overview

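MLIR ships a CSE pass among its built-in passes (`-cse` in `mlir-opt`).
At its core, MLIR's CSE looks for two equivalent ops and eliminates the later one. Two ops are equivalent when:
  • the ops themselves match: same op name, same attributes, same result types;
  • their operands are equivalent. Here the pass must take into account whether the op is commutative: for example, add(a, b) and add(b, a) are equivalent.
The implementation handles many more details. For example, it accounts for each op's side effects, and it uses the dominance relation to decide the order in which blocks are processed.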

simplifyRegion

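CSE may eliminate an expression only if that expression has already been computed on every path that reaches it. This is exactly the condition that the earlier computation dominates the later one. So blocks must be scanned in the order given by the dominance relation between them: a preorder walk of the region's dominator tree, which visits each block only after all of the blocks that dominate it.
Pseudocode: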
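```
# The dom tree describes the dominance relation between the blocks of one region.
def simplifyDomTree(node):
    knownValues.pushScope()      # ops recorded below are visible only to dominated blocks
    simplifyBlock(node.getBlock())
    for childNode in node.children():
        simplifyDomTree(childNode)
    knownValues.popScope()       # leaving this subtree: forget this block's ops

def simplifyRegion(region):
    domTree = getDomTree(region)
    simplifyDomTree(domTree.getRootNode())
```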
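The following minimal Python sketch shows why a preorder walk with one scope per dom-tree node is enough. It assumes a hand-written dominator tree for a diamond CFG, in which entry branches to then/else and both jump to merge. The block names and the `visible` list, which stands in for the scoped knownValues table, are hypothetical and not MLIR API.

```python
from typing import Dict, List, Tuple

# Hypothetical dominator tree for a diamond CFG: entry -> then/else -> merge.
# Neither `then` nor `else` dominates `merge`, so all three hang off `entry`.
DOM_TREE: Dict[str, List[str]] = {
    "entry": ["then", "else", "merge"],
    "then": [],
    "else": [],
    "merge": [],
}


def simplify_dom_tree(node: str, visible: List[str],
                      log: List[Tuple[str, List[str]]]) -> None:
    """Preorder walk: process `node`, then its children.

    `visible` lists the blocks whose ops are currently "known"; it models the
    scoped knownValues table (push on entry, pop on exit).
    """
    # Record which blocks' ops this block could reuse: only its dominators.
    log.append((node, list(visible)))
    visible.append(node)            # open this block's scope
    for child in DOM_TREE[node]:
        simplify_dom_tree(child, visible, log)
    visible.pop()                   # close the scope: drop this block's ops


log: List[Tuple[str, List[str]]] = []
simplify_dom_tree("entry", [], log)
for block, seen in log:
    print(block, "can reuse ops from", seen)
```

Because neither `then` nor `else` dominates `merge`, an op computed only in `then` is never visible while `merge` is processed. That is exactly the "computed on every path" requirement.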
The corresponding MLIR source, which performs this dom-tree walk inside `simplifyRegion`:
```cpp
std::deque<std::unique_ptr<CFGStackNode>> stack;

// Process the nodes of the dom tree for this region.
// Constructing a CFGStackNode also opens a new scope in knownValues;
// that scope is closed again when the node is popped off the stack.
stack.emplace_back(std::make_unique<CFGStackNode>(
    knownValues, domInfo->getRootNode(&region)));

while (!stack.empty()) {
  auto &currentNode = stack.back();

  // Check to see if we need to process this node.
  if (!currentNode->processed) {
    currentNode->processed = true;
    simplifyBlock(knownValues, currentNode->node->getBlock(),
                  hasSSADominance);
  }

  // Otherwise, check to see if we need to process a child node.
  if (currentNode->childIterator != currentNode->node->end()) {
    auto *childNode = *(currentNode->childIterator++);
    stack.emplace_back(
        std::make_unique<CFGStackNode>(knownValues, childNode));
  } else {
    // Finally, if the node and all of its children have been processed
    // then we delete the node.
    stack.pop_back();
  }
}
```
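Here the `stack` deque simulates the recursion of simplifyDomTree by hand. Each entry remembers its dom-tree node, whether its block has already been processed, and which child to visit next.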
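The same walk without recursion can be sketched in Python, mirroring the C++ loop step by step. `StackNode` is a hypothetical stand-in for CFGStackNode: a `processed` flag plus an index that plays the role of `childIterator`. It produces the same preorder as the recursive version.

```python
from collections import deque
from dataclasses import dataclass
from typing import Dict, List

# Hypothetical diamond-CFG dominator tree: block -> dominated children.
DOM_TREE: Dict[str, List[str]] = {
    "entry": ["then", "else", "merge"],
    "then": [],
    "else": [],
    "merge": [],
}


@dataclass
class StackNode:
    """Toy analogue of CFGStackNode (in MLIR it also owns a knownValues scope)."""
    block: str
    processed: bool = False
    child_index: int = 0   # plays the role of childIterator


def walk_iteratively(tree: Dict[str, List[str]], root: str) -> List[str]:
    """Preorder walk of `tree` using an explicit stack instead of recursion."""
    order: List[str] = []
    stack = deque([StackNode(root)])
    while stack:
        current = stack[-1]
        if not current.processed:
            # First time this node is on top of the stack: "simplifyBlock" it.
            current.processed = True
            order.append(current.block)
        children = tree[current.block]
        if current.child_index < len(children):
            # Push the next unvisited child; it is processed on the next turn.
            child = children[current.child_index]
            current.child_index += 1
            stack.append(StackNode(child))
        else:
            # The node and all of its children are done: pop it.
            stack.pop()
    return order


print(walk_iteratively(DOM_TREE, "entry"))  # ['entry', 'then', 'else', 'merge']
```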

simplifyBlock

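Pseudocode (simplified: the real simplifyBlock also recurses into the regions nested inside each op, and uses a fresh knownValues table when that op is IsolatedFromAbove):
```
def simplifyBlock(block):
    for op in block:
        simplifyOperation(op)
```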

simplifyOperation

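MLIR's elimination is fairly conservative (at least in the version I read).
  • For an op with no memory effects: look up an equivalent op in the visible scope and, if one is found, replace the op with it.
  • For an op whose only memory effect is a read: look up an equivalent op in the visible scope, called existingOp. The op can be eliminated only if:
    • existingOp and op are in the same block, and
    • no op between existingOp and the op to be eliminated has a write effect. Ops with no memory effect or only read effects are fine; an op whose effects are unknown is conservatively treated as a writer. Note that "between" refers to positions in the block's op list, not to paths in the graph.
  • An op with any other memory effect (e.g. a write) is never eliminated and is not recorded as a known op.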
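The "between" condition can be sketched in Python as a scan over the block's op list. `Op` and `has_write_in_between` are hypothetical toy names; the corresponding MLIR function is hasOtherSideEffectingOpInBetween. Treating ops with unknown effects as writers follows MLIR's conservative choice as I understand it.

```python
from typing import List, Optional, Set


class Op:
    """Hypothetical toy op: a label plus its memory effects.

    effects=None models an op whose effects are unknown; the sketch
    conservatively treats such an op as a writer.
    """

    def __init__(self, name: str, effects: Optional[Set[str]]):
        self.name = name
        self.effects = effects


def has_write_in_between(block: List[Op], existing: Op, op: Op) -> bool:
    """True if an op strictly between `existing` and `op` in the block's op
    list may write memory. Both must be in `block`, with `existing` first."""
    start, end = block.index(existing), block.index(op)
    assert start < end, "existing must come before op"
    for other in block[start + 1:end]:
        if other.effects is None or "write" in other.effects:
            return True
    return False


load1 = Op("load %m", {"read"})
add = Op("add %x, %y", set())
load2 = Op("load %m", {"read"})
store = Op("store %v, %m", {"write"})
load3 = Op("load %m", {"read"})
block = [load1, add, load2, store, load3]
print(has_write_in_between(block, load1, load2))  # False
print(has_write_in_between(block, load1, load3))  # True
```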
```
def simplifyOperation(op):
    if op.hasNoMemoryEffect():
        existingOp = findEqualOpInKnownOps(op)
        if existingOp:
            replaceWith(op, existingOp)
            return
    elif op.onlyHasReadEffect():
        existingOp = findEqualOpInKnownOps(op)
        if existingOp and isSameBlock(existingOp, op) \
                and not hasOtherSideEffectingOpInBetween(existingOp, op):
            replaceWith(op, existingOp)
            return
    else:
        # Writes or other effects: never eliminated, never recorded.
        return
    addToKnownOps(op)
```
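These rules combine into the following minimal single-block Python sketch of simplifyOperation. `Op.key` is an assumed stand-in for the full equivalence check: equal keys mean equivalent ops. A plain dict replaces knownValues, and the same-block condition holds trivially because there is only one block.

```python
from typing import Dict, List, Optional, Set, Tuple


class Op:
    """Hypothetical toy op.

    `key` stands in for name + attributes + operands + result types
    (equal keys mean equivalent ops); `effects` is a subset of
    {"read", "write"} describing the op's memory effects.
    """

    def __init__(self, key: Tuple, effects: Set[str]):
        self.key = key
        self.effects = effects
        self.replaced_by: Optional["Op"] = None


def _write_between(kept: List[Op], existing: Op) -> bool:
    # Ops after `existing` in `kept` are the surviving ops between it and the
    # op being visited; eliminated ops never write, so skipping them is safe.
    start = kept.index(existing)
    return any("write" in o.effects for o in kept[start + 1:])


def cse_block(block: List[Op]) -> List[Op]:
    """Return the surviving ops; eliminated ops get `replaced_by` set."""
    known: Dict[Tuple, Op] = {}   # stand-in for knownValues (one scope)
    kept: List[Op] = []
    for op in block:
        if not op.effects:                       # no memory effect
            existing = known.get(op.key)
            if existing is not None:
                op.replaced_by = existing
                continue
        elif op.effects == {"read"}:             # read-only
            existing = known.get(op.key)
            if existing is not None and not _write_between(kept, existing):
                op.replaced_by = existing
                continue
        else:                                    # writes: never CSE'd, never recorded
            kept.append(op)
            continue
        known[op.key] = op                       # record (or shadow) this op
        kept.append(op)
    return kept


a = Op(("load", "%m"), {"read"})
b = Op(("add", "%x", "%y"), set())
c = Op(("add", "%x", "%y"), set())           # same as b: eliminated
d = Op(("load", "%m"), {"read"})             # no write since a: eliminated
s = Op(("store", "%v", "%m"), {"write"})
s2 = Op(("store", "%v", "%m"), {"write"})    # writes are never eliminated
e = Op(("load", "%m"), {"read"})             # stores in between: kept
kept = cse_block([a, b, c, d, s, s2, e])
print([op.key[0] for op in kept])  # ['load', 'add', 'store', 'store', 'load']
```

When the last read cannot be eliminated, it replaces the older read in the table, so later reads compare against the most recent one. As far as I can tell, this mirrors MLIR inserting the op into knownValues in that case.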

isEqual(Operation*, Operation*)

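```cpp
// Pseudocode: sort, checkValues, checkResultTypes and checkRegions are
// placeholder helpers, not real MLIR functions.
bool isEqual(Operation *lhs, Operation *rhs) {
  // 1. Compare the operation properties.
  if (lhs->getName() != rhs->getName() ||
      lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
      lhs->getNumRegions() != rhs->getNumRegions() ||
      lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
      lhs->getNumOperands() != rhs->getNumOperands() ||
      lhs->getNumResults() != rhs->getNumResults())
    return false;

  // 2. Compare operands. For commutative ops the order does not matter,
  //    so both operand lists are brought into the same canonical order.
  auto lhsOperands = lhs->getOperands();
  auto rhsOperands = rhs->getOperands();
  if (lhs->hasTrait<OpTrait::IsCommutative>()) {
    lhsOperands = sort(lhsOperands);
    rhsOperands = sort(rhsOperands);
  }
  if (!checkValues(lhsOperands, rhsOperands))
    return false;

  // 3. Results must have the same types.
  if (!checkResultTypes(lhs->getResults(), rhs->getResults()))
    return false;

  // 4. Nested regions must be equivalent as well.
  return checkRegions(lhs->getRegions(), rhs->getRegions());
}
```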
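The same idea can be sketched in Python as a hashable key. This assumes operands are plain SSA value names and that a hand-picked set `COMMUTATIVE` plays the role of the IsCommutative trait; regions and successors are left out. Sorting is just one canonical order; any order works as long as both sides use it.

```python
from typing import Dict, Hashable, Tuple

# Assumption for this sketch: these op names are commutative
# (in MLIR the information comes from the IsCommutative trait).
COMMUTATIVE = {"add", "mul"}


def equivalence_key(name: str,
                    attrs: Dict[str, Hashable],
                    operands: Tuple[str, ...],
                    result_types: Tuple[str, ...]) -> Tuple:
    """Hashable key: two toy ops are equivalent iff their keys are equal."""
    # Commutative ops: compare operands in an order-insensitive way.
    ops = tuple(sorted(operands)) if name in COMMUTATIVE else tuple(operands)
    # Attributes form a dictionary, so compare them in key order.
    return (name, tuple(sorted(attrs.items())), ops, tuple(result_types))


k1 = equivalence_key("add", {}, ("%a", "%b"), ("i32",))
k2 = equivalence_key("add", {}, ("%b", "%a"), ("i32",))
k3 = equivalence_key("sub", {}, ("%a", "%b"), ("i32",))
k4 = equivalence_key("sub", {}, ("%b", "%a"), ("i32",))
print(k1 == k2, k3 == k4)  # True False
```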

Looking up in the visible scope

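MLIR keeps track of the ops currently in scope with an llvm::ScopedHashTable, which is passed around as the function parameter `knownValues`. Ops used as keys are hashed and compared by equivalence (the isEqual logic above), so a lookup returns a visible op equivalent to the given one.
MLIR source: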
```cpp
/// Attempt to eliminate a redundant operation. Returns success if the
/// operation was marked for removal, failure otherwise.
LogicalResult simplifyOperation(ScopedMapTy &knownValues, Operation *op,
                                bool hasSSADominance);
void simplifyBlock(ScopedMapTy &knownValues, Block *bb, bool hasSSADominance);
void simplifyRegion(ScopedMapTy &knownValues, Region &region);
```
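ScopedHashTable maps naturally onto the lexical scopes that `{}` delimits in ordinary code. insert adds an entry to the current scope. When the current scope exits, every entry inserted in that scope is removed, and any outer entry it shadowed becomes visible again.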
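The toy Python `ScopedTable` below makes these semantics concrete. It is hypothetical: it mimics the behaviour described above, not llvm::ScopedHashTable's actual interface. Each scope logs its inserts together with the value each key had before, so popping the scope undoes them.

```python
from typing import Dict, Generic, Hashable, List, Optional, Tuple, TypeVar

K = TypeVar("K", bound=Hashable)
V = TypeVar("V")


class ScopedTable(Generic[K, V]):
    """Toy analogue of llvm::ScopedHashTable (behaviour only, not its API)."""

    def __init__(self) -> None:
        self._table: Dict[K, V] = {}
        # One undo log per open scope: (key, had_previous, previous_value).
        self._scopes: List[List[Tuple[K, bool, Optional[V]]]] = []

    def push_scope(self) -> None:
        self._scopes.append([])

    def pop_scope(self) -> None:
        # Undo this scope's inserts newest-first, restoring shadowed entries.
        for key, had_previous, previous in reversed(self._scopes.pop()):
            if had_previous:
                self._table[key] = previous
            else:
                del self._table[key]

    def insert(self, key: K, value: V) -> None:
        assert self._scopes, "insert needs an open scope"
        had_previous = key in self._table
        self._scopes[-1].append((key, had_previous, self._table.get(key)))
        self._table[key] = value  # the innermost entry wins on lookup

    def lookup(self, key: K) -> Optional[V]:
        return self._table.get(key)


t: ScopedTable[str, str] = ScopedTable()
t.push_scope()                   # e.g. the entry block's scope
t.insert("add %a, %b", "%0")
t.push_scope()                   # e.g. a dominated block's scope
t.insert("mul %a, %b", "%1")
print(t.lookup("add %a, %b"), t.lookup("mul %a, %b"))  # %0 %1
t.pop_scope()
print(t.lookup("mul %a, %b"))    # None
```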