Skip to content

Commit d2a291a

Browse files
committed
[MLIR][Linalg] Lower linalg.tiled_loop to scf loops
Differential Revision: https://reviews.llvm.org/D101747
1 parent 40f7834 commit d2a291a

File tree

2 files changed

+111
-1
lines changed

2 files changed

+111
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,14 +555,49 @@ class LinalgRewritePattern : public RewritePattern {
555555
}
556556
};
557557

558+
struct TiledLoopPattern : public OpRewritePattern<TiledLoopOp> {
559+
using OpRewritePattern<TiledLoopOp>::OpRewritePattern;
560+
561+
LogicalResult matchAndRewrite(TiledLoopOp tiledLoop,
562+
PatternRewriter &rewriter) const override {
563+
Location loc = tiledLoop.getLoc();
564+
565+
// Fail conversion if the `tiled_loop` has not been bufferized.
566+
if (!llvm::all_of(tiledLoop.outputs(), [&](Value arg) {
567+
return arg.getType().isa<MemRefType>();
568+
}))
569+
return failure();
570+
571+
// TODO: Build loop nest with `scf.for` and `scf.parallel` depending on the
572+
// iterator type.
573+
scf::buildLoopNest(rewriter, loc, tiledLoop.lowerBound(),
574+
tiledLoop.upperBound(), tiledLoop.step(),
575+
[&](OpBuilder &builder, Location loc, ValueRange ivs) {
576+
// Move body without its terminator.
577+
SmallVector<Value, 16> newBlockArgs;
578+
newBlockArgs.append(ivs.begin(), ivs.end());
579+
newBlockArgs.append(tiledLoop.inputs().begin(),
580+
tiledLoop.inputs().end());
581+
newBlockArgs.append(tiledLoop.outputs().begin(),
582+
tiledLoop.outputs().end());
583+
Block *newBody = rewriter.getInsertionBlock();
584+
rewriter.mergeBlocks(tiledLoop.getBody(), newBody,
585+
newBlockArgs);
586+
rewriter.eraseOp(newBody->getTerminator());
587+
});
588+
rewriter.eraseOp(tiledLoop);
589+
return success();
590+
}
591+
};
592+
558593
struct FoldAffineOp;
559594
} // namespace
560595

561596
template <typename LoopType>
562597
static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
563598
MLIRContext *context = funcOp.getContext();
564599
RewritePatternSet patterns(context);
565-
patterns.add<LinalgRewritePattern<LoopType>>(context);
600+
patterns.add<LinalgRewritePattern<LoopType>, TiledLoopPattern>(context);
566601
memref::DimOp::getCanonicalizationPatterns(patterns, context);
567602
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
568603
patterns.add<FoldAffineOp>(context);

mlir/test/Dialect/Linalg/loops.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,3 +1522,78 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
15221522
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
15231523
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
15241524
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
1525+
1526+
1527+
#map0 = affine_map<(d0) -> (24, -d0 + 192)>
1528+
#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
1529+
#map2 = affine_map<(d0) -> (16, -d0 + 192)>
1530+
1531+
func @tiled_loop_to_parallel(%A: memref<192x192xf32>,
1532+
%B: memref<192x192xf32>,
1533+
%C: memref<192x192xf32>) {
1534+
%cst = constant 0.000000e+00 : f32
1535+
%c24 = constant 24 : index
1536+
%c16 = constant 16 : index
1537+
%c0 = constant 0 : index
1538+
%c192 = constant 192 : index
1539+
1540+
linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
1541+
ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
1542+
outs (%C_ = %C: memref<192x192xf32>) {
1543+
%0 = affine.min #map0(%i)
1544+
%1 = memref.subview %A_[%i, 0] [%0, 192] [1, 1]
1545+
: memref<192x192xf32> to memref<?x192xf32, #map1>
1546+
%2 = affine.min #map2(%j)
1547+
%3 = memref.subview %B_[0, %j] [192, %2] [1, 1]
1548+
: memref<192x192xf32> to memref<192x?xf32, #map1>
1549+
%4 = memref.subview %C_[%i, %j] [%0, %2] [1, 1]
1550+
: memref<192x192xf32> to memref<?x?xf32, #map1>
1551+
linalg.fill(%4, %cst) : memref<?x?xf32, #map1>, f32
1552+
linalg.matmul ins(%1, %3 : memref<?x192xf32, #map1>,
1553+
memref<192x?xf32, #map1>)
1554+
outs(%4 : memref<?x?xf32, #map1>)
1555+
linalg.yield
1556+
}
1557+
return
1558+
}
1559+
1560+
// CHECKLOOP-LABEL: @tiled_loop_to_parallel
1561+
// CHECKLOOP-SAME: %[[A:.*]]: memref<192x192xf32>, %[[B:.*]]: memref<192x192xf32>,
1562+
// CHECKLOOP-SAME: %[[C:.*]]: memref<192x192xf32>) {
1563+
// CHECKLOOP: %[[C24:.*]] = constant 24 : index
1564+
// CHECKLOOP: %[[C16:.*]] = constant 16 : index
1565+
// CHECKLOOP: %[[C192:.*]] = constant 192 : index
1566+
// CHECKLOOP: %[[C0:.*]] = constant 0 : index
1567+
// CHECKLOOP: scf.for %[[I:.*]] = %[[C0]] to %[[C192]] step %[[C24]] {
1568+
// CHECKLOOP: scf.for %[[J:.*]] = %[[C0]] to %[[C192]] step %[[C16]] {
1569+
// CHECKLOOP: %[[A_sub:.*]] = memref.subview %[[A]][%[[I]]
1570+
// CHECKLOOP: %[[B_sub:.*]] = memref.subview %[[B]][0, %[[J]]]
1571+
// CHECKLOOP: %[[C_sub:.*]] = memref.subview %[[C]][%[[I]]
1572+
1573+
1574+
func @tiled_loop_to_for(%A: memref<192x192xf32>,
1575+
%B: memref<192x192xf32>,
1576+
%C: memref<f32>) {
1577+
%c24 = constant 24 : index
1578+
%c16 = constant 16 : index
1579+
%c0 = constant 0 : index
1580+
%c192 = constant 192 : index
1581+
%cst = constant 0.000000e+00 : f32
1582+
1583+
linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192) step (%c24, %c16)
1584+
ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
1585+
outs (%C_ = %C: memref<f32>)
1586+
iterators["reduction", "reduction"] {
1587+
linalg.fill(%A_, %cst) : memref<192x192xf32>, f32
1588+
linalg.yield
1589+
}
1590+
return
1591+
}
1592+
1593+
// CHECKLOOP-LABEL: @tiled_loop_to_for
1594+
// CHECKLOOP: %[[C24:.*]] = constant 24 : index
1595+
// CHECKLOOP: %[[C16:.*]] = constant 16 : index
1596+
// CHECKLOOP: %[[C192:.*]] = constant 192 : index
1597+
// CHECKLOOP: %[[C0:.*]] = constant 0 : index
1598+
// CHECKLOOP: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C24]]
1599+
// CHECKLOOP: scf.for %{{.*}} = %[[C0]] to %[[C192]] step %[[C16]]

0 commit comments

Comments
 (0)