Skip to content

Commit fa2b6bf

Browse files
committed
Add MLIR support (#1044)
1 parent 61768de commit fa2b6bf

2 files changed

Lines changed: 102 additions & 40 deletions

File tree

source/mlir-metadata.json

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110579,6 +110579,17 @@
110579110579
],
110580110580
"assemblyFormat": "$module `into` $container attr-dict `:` type($module) `,` type($container)"
110581110581
},
110582+
{
110583+
"name": "transform.util.eliminate_hoistable_conversions",
110584+
"description": "Hoists and cancels matching `util.hoistable_conversion` pairs on the\n target op. See the op documentation for the optimizations performed and the\n requirements on the op contents.",
110585+
"operands": [
110586+
{ "name": "target", "type": "TransformHandleTypeInterface" }
110587+
],
110588+
"results": [
110589+
{ "name": "result", "type": "TransformHandleTypeInterface" }
110590+
],
110591+
"assemblyFormat": "$target attr-dict `:` functional-type(operands, results)"
110592+
},
110582110593
{
110583110594
"name": "transform.util.get_nearest_symbol_table",
110584110595
"description": "Returns the nearest symbol table op for each op in the payload, inclusive.\n\n This operation reads the `target` handle and produces the `result`\n handle. This operation emits a definite failure if the nearest symbol table\n is unknown.",
@@ -121196,6 +121207,25 @@
121196121207
"assemblyFormat": "$value `,` $global attr-dict `:` type($value) `->` qualified(type($global))",
121197121208
"hasCustomAssemblyFormat": true
121198121209
},
121210+
{
121211+
"name": "util.hoistable_conversion",
121212+
"summary": "Defines a hoistable conversion between inputs and outputs.",
121213+
"description": "Defines a conversion between some number of inputs and outputs that should\n be hoisted out of loops or canceled where possible.\n\n This is an operation that allows local pattern rewrites to be defined that\n shouldn't modify loop structure to allow that structure to be modified as\n a post-processing step. It is **not** expected to appear in IR long-term,\n and should, if possible, be rewritten away by the same pass that introduced\n it.\n\n The initial motivation for this operation is to enable code such as the\n lowering of an `inner_tiled` operation to platform intrinsics to mark\n marshaling/unmarshaling needed to convert between the types and shapes IREE\n expects and what the platform expects. However, it could be used for other\n cases where, within a loop, we have `arg <- f^{-1}(g(f(arg)))` where the\n $f$ and $f^{-1}$ can be moved out of the loop without sacrificing\n correctness.\n\n In general, this operation executes the contained single-block region once,\n with the arguments of the region bound to its inputs and the returned\n results bound to its outputs. It must always be correct to perform this\n inlining, though it is likely to be less performant.\n\n The `tag` and `inverse_tag` are used to identify pairs of hoistable\n conversions that can cancel with each other. The operations placed in a\n `hoistable_conversion` operation must be such that the following operations\n are correct for a pair of operations F and G such that `tag(F) ==\n inverse_tag(G)` and `inverse_tag(G) == tag(F)`:\n\n 1. If the outputs of F are the inputs of G, the results of G can be\n replaced by the arguments to F\n 2. If the arguments to F are loop iteration arguments and the results of G\n are the subsequent yielded values of those arguments, then\n - Those loop arguments can be replaced by a set of arguments whose types\n match the result types of F / input types of G\n - The initial values of those new loop arguments can be constructed by\n applying F to the previous initial values\n - The results of the loop can be transformed back into their old types\n by applying G to the new results\n - If F is not the only user of those loop-carried arguments, applying G\n to the new arguments inside the loop will preserve correctness\n\n These optimizations are performed by `eliminateHoistableConversions`\n transformation in util optimizations.\n\n These properties allow pure pre-/post-processing of operands, such as\n chains of `shape_cast` operations surrounding the accumulators of intrinsic\n calls, to be pulled out of loops, and to prevent redundant conversions from\n being added after unrolling intrinsics.\n\n Note that these hoistable conversions are **not** expected to commute with\n other hoistable conversions.",
121214+
"operands": [
121215+
{ "name": "inputs", "type": "Variadic<AnyType>" }
121216+
],
121217+
"results": [
121218+
{ "name": "results", "type": "Variadic<AnyType>" }
121219+
],
121220+
"attributes": [
121221+
{ "name": "tag", "type": "StringProp" },
121222+
{ "name": "inverseTag", "type": "StringProp" }
121223+
],
121224+
"regions": [
121225+
{ "name": "body", "type": "AnyRegion" }
121226+
],
121227+
"hasCustomAssemblyFormat": true
121228+
},
121199121229
{
121200121230
"name": "util.initializer",
121201121231
"summary": "Global initialization function.",

source/mlir.js

Lines changed: 72 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11742,9 +11742,9 @@ _.AffineDialect = class extends _.Dialect {
1174211742
const numUbOperands = result.operands.length - numOperandsBeforeUb;
1174311743
if (parser.parseOptionalKeyword('step')) {
1174411744
const step = parser.parseInteger();
11745-
result.addAttribute('step', String(step));
11745+
result.addAttribute('step', step);
1174611746
} else {
11747-
result.addAttribute('step', '1');
11747+
result.addAttribute('step', 1);
1174811748
}
1174911749
const regionArgs = [inductionVariable];
1175011750
const operands = [];
@@ -13930,6 +13930,9 @@ _.UtilDialect = class extends _.IREEDialect {
1393013930
this.parseUtilFuncOp(parser, result);
1393113931
return true;
1393213932
}
13933+
if (op === 'util.hoistable_conversion') {
13934+
return this.parseHoistableConversionOp(parser, result);
13935+
}
1393313936
if (op === 'util.unfoldable_constant') {
1393413937
parser.parseOptionalAttrDict(result.attributes);
1393513938
const value = parser.parseAttribute();
@@ -13946,6 +13949,29 @@ _.UtilDialect = class extends _.IREEDialect {
1394613949
return super.parseOperation(parser, result);
1394713950
}
1394813951

13952+
parseHoistableConversionOp(parser, result) {
13953+
const tag = parser.parseString();
13954+
result.addAttribute('tag', tag);
13955+
parser.parseKeyword('inverts');
13956+
parser.parseLParen();
13957+
const inverseTag = parser.parseString();
13958+
result.addAttribute('inverseTag', inverseTag);
13959+
parser.parseRParen();
13960+
const regionArgs = [];
13961+
const inputOperands = [];
13962+
parser.parseAssignmentList(regionArgs, inputOperands);
13963+
const fnType = parser.parseColonType();
13964+
for (let i = 0; i < regionArgs.length; i++) {
13965+
regionArgs[i].type = fnType.inputs[i];
13966+
}
13967+
parser.resolveOperands(inputOperands, fnType.inputs, result.operands);
13968+
result.addTypes(fnType.results);
13969+
const region = result.addRegion();
13970+
parser.parseRegion(region, regionArgs);
13971+
parser.parseOptionalAttrDict(result.attributes);
13972+
return true;
13973+
}
13974+
1394913975
parseUtilFuncOp(parser, result) {
1395013976
parser.parseOptionalVisibilityKeyword(result.attributes);
1395113977
parser.parseSymbolName('sym_name', result.attributes);
@@ -14262,18 +14288,16 @@ _.FlowDialect = class extends _.IREEDialect {
1426214288
}
1426314289

1426414290
parseDispatchRegionOp(parser, result) {
14291+
const indexType = new _.IndexType();
1426514292
const workloadOperands = parser.parseOperandList('optionalSquare');
14266-
for (const workload of workloadOperands) {
14267-
parser.resolveOperand(workload, null, result.operands);
14268-
}
14293+
const dynamicDimOperands = [];
1426914294
if (parser.parseOptionalArrow()) {
1427014295
if (parser.parseOptionalLParen()) {
1427114296
while (!parser.parseOptionalRParen()) {
1427214297
const type = parser.parseType();
1427314298
if (parser.parseOptionalLBrace()) {
1427414299
while (!parser.parseOptionalRBrace()) {
14275-
const tied = parser.parseOperand();
14276-
parser.resolveOperand(tied, null, result.operands);
14300+
dynamicDimOperands.push(parser.parseOperand());
1427714301
parser.parseOptionalComma();
1427814302
}
1427914303
}
@@ -14282,6 +14306,8 @@ _.FlowDialect = class extends _.IREEDialect {
1428214306
}
1428314307
}
1428414308
}
14309+
parser.resolveOperands(dynamicDimOperands, dynamicDimOperands.map(() => indexType), result.operands);
14310+
parser.resolveOperands(workloadOperands, workloadOperands.map(() => indexType), result.operands);
1428514311
parser.parseOptionalAttrDictWithKeyword(result.attributes);
1428614312
const region = result.addRegion();
1428714313
parser.parseRegion(region);
@@ -14423,8 +14449,8 @@ _.FlowDialect = class extends _.IREEDialect {
1442314449
}
1442414450

1442514451
parseTensorLoadStoreOp(parser, result) {
14426-
// or: store %26, %arg4, offsets = [...] : type -> type
1442714452
const op = result.name.getStringRef();
14453+
const isLoad = op.endsWith('.load');
1442814454
const unresolvedOperands = [];
1442914455
let nextOp = parser.parseOptionalOperand();
1443014456
while (nextOp) {
@@ -14437,43 +14463,49 @@ _.FlowDialect = class extends _.IREEDialect {
1443714463
break;
1443814464
}
1443914465
}
14440-
// Note: first parameter might not need comma-eating if we just broke from operand loop
14441-
let paramName = parser.parseOptionalKeyword();
14442-
while (true) {
14443-
if (!paramName) {
14444-
if (!parser.parseOptionalComma()) {
14445-
break;
14446-
}
14447-
paramName = parser.parseOptionalKeyword();
14448-
}
14449-
if (paramName) {
14450-
if (parser.parseOptionalEqual()) {
14451-
if (parser.parseOptionalLSquare()) {
14452-
while (!parser.parseOptionalRSquare()) {
14453-
const operand = parser.parseOptionalOperand();
14454-
if (!operand) {
14455-
// Handle integer literals (e.g., 0, 1, 3)
14456-
parser.parseInteger();
14457-
}
14458-
parser.parseOptionalComma();
14459-
}
14466+
const parseDynamicIndexList = (name) => {
14467+
parser.parseKeyword(name);
14468+
parser.parseEqual();
14469+
parser.parseLSquare();
14470+
const staticValues = [];
14471+
if (!parser.parseOptionalRSquare()) {
14472+
do {
14473+
const dynOp = parser.parseOptionalOperand();
14474+
if (dynOp) {
14475+
parser.resolveOperand(dynOp, null, result.operands);
14476+
staticValues.push(-9223372036854775808n);
1446014477
} else {
14461-
parser.parseKeyword();
14478+
const intVal = parser.parseOptionalInteger('int64');
14479+
if (intVal !== null) {
14480+
staticValues.push(intVal);
14481+
}
1446214482
}
14463-
result.addAttribute(paramName, paramName);
14464-
}
14465-
paramName = null;
14466-
} else {
14467-
break;
14483+
} while (parser.parseOptionalComma());
14484+
parser.parseRSquare();
14485+
}
14486+
result.addAttribute(`static_${name}`, staticValues);
14487+
};
14488+
parseDynamicIndexList('offsets');
14489+
parser.parseComma();
14490+
parseDynamicIndexList('sizes');
14491+
parser.parseComma();
14492+
parseDynamicIndexList('strides');
14493+
parser.parseOptionalAttrDict(result.attributes);
14494+
parser.parseColon();
14495+
const sourceType = parser.parseType();
14496+
parser.resolveOperands(unresolvedOperands, unresolvedOperands.map(() => sourceType), result.operands);
14497+
if (parser.parseOptionalLBrace()) {
14498+
if (!parser.parseOptionalRBrace()) {
14499+
do {
14500+
const dim = parser.parseOperand();
14501+
parser.resolveOperand(dim, null, result.operands);
14502+
} while (parser.parseOptionalComma());
14503+
parser.parseRBrace();
1446814504
}
1446914505
}
14470-
const types = parser.parseOptionalColonTypeList();
14471-
parser.resolveOperands(unresolvedOperands, types, result.operands);
14472-
// For tensor.load, there's a -> result type
14473-
// For tensor.store, the -> is followed by the output tensor type (not a result)
1447414506
if (parser.parseOptionalArrow() || parser.parseOptionalKeyword('to')) {
1447514507
const resultType = parser.parseType();
14476-
if (op === 'flow.dispatch.tensor.load' && resultType) {
14508+
if (isLoad && resultType) {
1447714509
result.addTypes([resultType]);
1447814510
}
1447914511
}
@@ -15631,7 +15663,7 @@ _.LinalgDialect = class extends _.Dialect {
1563115663
const lastOperand = op.operands[op.operands.length - 1];
1563215664
if (lastOperand && lastOperand.type) {
1563315665
const elemType = lastOperand.type.elementType || lastOperand.type;
15634-
payloadState.types = [elemType];
15666+
payloadState.addTypes([elemType]);
1563515667
}
1563615668
}
1563715669
for (const [name, value] of payloadOpAttrs) {

0 commit comments

Comments
 (0)