diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2023-05-13 06:33:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-12 15:33:26 -0700 |
| commit | 65103bc9a0c72117d3c9410e361947cdd568ae55 (patch) | |
| tree | 9dcb3dea5082d12366e078b3c7b62faa89ef5c73 /source | |
| parent | 332f60c19336252d907b83882aa70665ca93a9d2 (diff) | |
Fusion pass for saturated_cooperation (#2874)
* Fusion pass for saturated_cooperation
* simplify assert
* regenerate vs projects
* missing test output files
* rename shadowing variable to appease msvc
* Fuse calls to sat_coop with differing inputs
* formatting
* add cpu test for hof simple
* Make higher-order functions into compute comparison tests
* comment tests
* remove redundant test
* Add test to confirm inlining in sat_coop fuse
* Add clarifying comment for sat coop fusing
* Add KnownBuiltin decoration
* s/CanUseFuncSignature/TypesFullyResolved for higher order function checking
* Add TODO
* spelling
* Correct detection of sat_coop calls
* Disable tests which are unsupported on testing infra
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 114 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 19 | ||||
| -rw-r--r-- | source/slang/slang-ir-fuse-satcoop.cpp | 540 | ||||
| -rw-r--r-- | source/slang/slang-ir-fuse-satcoop.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 29 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-witness-lookup.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-function-call.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 4 |
15 files changed, 747 insertions, 15 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index b992def6e..730c3fcc8 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -3160,3 +3160,6 @@ attribute_syntax [PreferRecompute] : PreferRecomputeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [PreferCheckpoint] : PreferCheckpointAttribute; + +__attributeTarget(DeclBase) +attribute_syntax [KnownBuiltin(name : String)] : KnownBuiltinAttribute; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 03cdc9ee2..40a372a32 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6948,3 +6948,117 @@ uint3 cudaBlockIdx(); __target_intrinsic(cuda, "(blockDim)") [__readNone] uint3 cudaBlockDim(); + +// +// Workgroup cooperation +// + +// +// `saturated_cooperation(c, f, s, u)` will call `f(s, u)` if not all lanes in the +// workgroup are currently executing. however if all lanes are saturated, then +// for each unique `s` across all the active lanes `c(s, u)` is called. The +// return value is the one corresponding to the input `s` from this lane. +// +// Adjacent calls to saturated_cooperation are subject to fusion, i.e. +// saturated_cooperation(c1, f1, s, u1); +// saturated_cooperation(c2, f2, s, u2); +// will be transformed to: +// saturated_cooperation(c1c2, f1f2, s, u1u2); +// where +// c1c2 is a function which calls c1(s, u1) and then c2(s, u2); +// f1f2 is a function which calls f1(s, u1) and then f2(s, u2); +// +// When the input differs, calls are fused +// saturated_cooperation(c1, f1, s1, u1); +// saturated_cooperation(c2, f2, s2, u2); +// will be transformed to: +// saturated_cooperation(c1c2, f1f2, s1s2, u1u2); +// where +// s1s2 is a tuple of s1 and s2 +// c1c2 is a function which calls c1(s1, u1) and then c2(s2, u2); +// f1f2 is a function which calls f1(s1, u1) and then f2(s2, u2); +// Note that in this case, we will make a call to c1c2 for every unique pair +// s1s2 across all lanes +// +// (This fusion takes place in the fuse-satcoop pass, and as such any changes to +// the signature or behavior of this function should be adjusted for there). +// +[KnownBuiltin("saturated_cooperation")] +func saturated_cooperation<A : __BuiltinType, B, C>( + cooperate : functype (A, B) -> C, + fallback : functype (A, B) -> C, + A input, + B otherArg) + -> C +{ + return saturated_cooperation_using(cooperate, fallback, __WaveMatchBuitin<A>, __WaveReadLaneAtBuiltin<A>, input, otherArg); +} + +// These two functions are a temporary (circa May 2023) workaround to the fact +// that we can't deduce which overload to pass to saturated_cooperation_using +// in the call above +[__unsafeForceInlineEarly] +func __WaveMatchBuitin<T : __BuiltinType>(T t) -> uint4 +{ + return WaveMatch(t); +} +[__unsafeForceInlineEarly] +func __WaveReadLaneAtBuiltin<T : __BuiltinType>(T t, int i) -> T +{ + return WaveReadLaneAt(t, i); +} + +// +// saturated_cooperation, but you're able to specify manually the functions: +// +// waveMatch: a function to return a mask of lanes with the same input as this one +// broadcast: a function which returns the value passed into it on the specified lane +// +[KnownBuiltin("saturated_cooperation_using")] +func saturated_cooperation_using<A, B, C>( + cooperate : functype (A, B) -> C, + fallback : functype (A, B) -> C, + waveMatch : functype (A) -> uint4, + broadcast : functype (A, int) -> A, + A input, + B otherArg) + -> C +{ + const bool isWaveSaturated = WaveActiveCountBits(true) == WaveGetLaneCount(); + if(isWaveSaturated) + { + let lanesWithSameInput = waveMatch(input).x; + // Keep least significant lane in our set + let ourRepresentative = lanesWithSameInput & -lanesWithSameInput; + // The representative lanes for all lanes + var allRepresentatives = WaveActiveBitOr(ourRepresentative); + + C ret; + + // Iterate over set bits in mask from low to high. + // In each iteration the lowest bit is cleared. + while(bool(allRepresentatives)) + { + // Broadcast input across warp. + let laneIdx = firstbitlow(allRepresentatives); + let uniformInput = broadcast(input, int(laneIdx)); + + // All lanes perform some cooperative computation with dynamic + // uniform input + C c = cooperate(uniformInput, otherArg); + + // Update our return value until it + if(bool(allRepresentatives & ourRepresentative)) + ret = c; + + // Clear the lowest bit + allRepresentatives &= allRepresentatives - 1; + } + + return ret; + } + else + { + return fallback(input, otherArg); + } +} diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 25761b11c..59ac26833 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1303,6 +1303,18 @@ class NoSideEffectAttribute : public Attribute { SLANG_AST_CLASS(NoSideEffectAttribute) }; + + /// A `[KnownBuiltin("name")]` attribute allows the compiler to + /// identify this declaration during compilation, despite obfuscation or + /// linkage removing optimizations + /// +class KnownBuiltinAttribute : public Attribute +{ + SLANG_AST_CLASS(KnownBuiltinAttribute) + + String name; +}; + /// A modifier that applies to types rather than declarations. /// /// In most cases, the Slang compiler assumes that a modifier should diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 2d5f6aad7..b1f36ca2b 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -725,6 +725,18 @@ namespace Slang deprecatedAttr->message = message; } + else if (auto knownBuiltinAttr = as<KnownBuiltinAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + + String name; + if(!checkLiteralStringVal(attr->args[0], &name)) + { + return false; + } + + knownBuiltinAttr->name = name; + } else { if(attr->args.getCount() == 0) diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 5160f3c6f..4bd8506ed 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1454,7 +1454,8 @@ namespace Slang // We could probably be broader than just parameters here // eventually. // Limit it for now though to make the specialization easier - ensureDecl(localDeclRef, DeclCheckState::CanUseFuncSignature); + // TODO: why can't this use DeclCheckState::CanUseFuncSignature + ensureDecl(localDeclRef, DeclCheckState::TypesFullyResolved); const auto type = localDeclRef.getDecl()->getType(); // We can only add overload candidates if this is known to be a function if(const auto funType = as<FuncType>(type)) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 8a04e5cc8..071ea9639 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -21,6 +21,7 @@ #include "slang-ir-entry-point-raw-ptr-params.h" #include "slang-ir-explicit-global-context.h" #include "slang-ir-explicit-global-init.h" +#include "slang-ir-fuse-satcoop.h" #include "slang-ir-glsl-legalize.h" #include "slang-ir-insts.h" #include "slang-ir-inline.h" @@ -350,6 +351,11 @@ Result linkAndOptimizeIR( #endif validateIRModuleIfEnabled(codeGenContext, irModule); + // It's important that this takes place before defunctionalization as we + // want to be able to easily discover the cooperate and fallback funcitons + // being passed to saturated_cooperation + fuseCallsToSaturatedCooperation(irModule); + // Next, we need to ensure that the code we emit for // the target doesn't contain any operations that would // be illegal on the target platform. For example, @@ -397,6 +403,12 @@ Result linkAndOptimizeIR( return SLANG_FAIL; } + // Few of our targets support higher order functions, and + // we don't have the backend code to emit higher order functions for those + // which do. + // Specialize away these parameters + // TODO: We should implement a proper defunctionalization pass + changed |= specializeHigherOrderParameters(codeGenContext, irModule); dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); enableIRValidationAtInsert(); @@ -617,13 +629,6 @@ Result linkAndOptimizeIR( break; } - // Few of our targets support higher order functions, and - // we don't have the backend code to emit higher order functions for those - // which do. - // Specialize away these parameters - // TODO: We should implement a proper defunctionalization pass - specializeHigherOrderParameters(codeGenContext, irModule); - // For all targets, we translate load/store operations // of aggregate types from/to byte-address buffers into // stores of individual scalar or vector values. diff --git a/source/slang/slang-ir-fuse-satcoop.cpp b/source/slang/slang-ir-fuse-satcoop.cpp new file mode 100644 index 000000000..3c827ef25 --- /dev/null +++ b/source/slang/slang-ir-fuse-satcoop.cpp @@ -0,0 +1,540 @@ +#include "slang-ir-fuse-satcoop.h" + +#include "slang-ir-inline.h" +#include "slang-ir-insts.h" +#include "slang-ir-specialize-function-call.h" +#include "slang-ir-ssa-simplification.h" +#include "slang-ir.h" + +namespace Slang +{ + +// +// Some helpers +// + +// Run an operation over every block in a module +template<typename F> +static void overAllBlocks(IRModule* module, F f) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as<IRGlobalValueWithCode>(globalInst)) + { + for (auto block : func->getBlocks()) + { + f(block); + } + } + } +} + +static bool uses(IRInst* used, IRInst* user) +{ + for(auto use = used->firstUse; use; use = use->nextUse) + { + if(use->getUser() == user) + return true; + } + return false; +}; + +// given: `f; x; g` +// reorder instructions such that f and g are adjacent, in the form: +// `p; f; g; q`, +// +// p is the set of instructions upon which g depends and q is the +// set of instructions which depend on f. If these sets are not disjoint then +// we can't float f and g together. Instructions not used by g and which don't +// use f can go in either p or q. +// +// Returns g on success +static IRInst* floatTogether(IRInst* f, IRInst* g) +{ + List<IRInst*> ps, qs; + + auto usesF = [&](IRInst* i){ + if(uses(f, i)) + return true; + for(auto q : qs) + if(uses(q, i)) + return true; + return false; + }; + auto usedByG = [&](IRInst* i){ + if(uses(i, g)) + return true; + for(auto p : ps) + if(uses(i, p)) + return true; + return false; + }; + + // Scan backwards to find which instructions g depends on, known as p + auto i = g->prev; + while(i != f) + { + SLANG_ASSERT(i); + + // If any instruction in x has side effects, we can't reorder things + if(i->mightHaveSideEffects()) + return nullptr; + + if(usedByG(i)) + ps.add(i); + i = i->prev; + } + + // Scan forwards to compute instructions which depend on f, the instructions in q + i = f->next; + while(i != g) + { + if(usesF(i)) + { + // If this happens then ps and qs are not disjoint, and we will not + // be able to float f and g together + if(ps.contains(i)) + return nullptr; + qs.add(i); + } + + i = i->next; + } + + // Now we can safely reorder things by moving p;f;g before everything else + // Remember, we constructed ps in reverse, so we must insert these + // backwards too + for(Index j = ps.getCount()-1; j >= 0; --j) + { + auto p = ps[j]; + p->removeFromParent(); + p->insertBefore(f); + } + g->removeFromParent(); + g->insertAfter(f); + return g; +} + +// bifanout(f, g)((x, y), (a, b)) = (f(x, a), g(y, b)) +// +// Make a function `bifanout` which applies two functions to their respective +// elements in two pairs. Optionally the first and second inputs can be shared +// instead of split in a tuple. +// +// The outputs are returned in a 2-tuple +static IRFunc* makeBiFanout(IRBuilder& builder, IRFunc* f, IRFunc* g, bool shareFirst, bool shareSecond) +{ + SLANG_ASSERT(f->getParamCount() == 2); + SLANG_ASSERT(g->getParamCount() == f->getParamCount()); + SLANG_ASSERT(!shareFirst || f->getParamType(0) == g->getParamType(0)); + SLANG_ASSERT(!shareSecond || f->getParamType(1) == g->getParamType(1)); + IRBuilderInsertLocScope insertLocScope(&builder); + + // Create (using shareFirst = false, shareSecond = true as an example) + // func myFunc(s : S, u : (U1,U2)) -> (R1, R2) + // { + // let fRes = f(s, u.fst); + // let gRes = g(s, u.snd); + // return (fRes, gRes); + // } + + // The return type is the tuple of f and g's return types + auto resType = builder.getTupleType(f->getResultType(), g->getResultType()); + auto firstInputType = shareFirst + ? f->getParamType(0) + : builder.getTupleType(f->getParamType(0), g->getParamType(0)); + auto secondInputType = shareSecond + ? f->getParamType(1) + : builder.getTupleType(f->getParamType(1), g->getParamType(1)); + + // Set up our function + // func myFunc(s : S, u : (U1,U2)) -> (R1, R2) + auto func = builder.createFunc(); + builder.addDecoration(func, kIROp_ForceInlineDecoration); + builder.setDataType(func, builder.getFuncType({firstInputType, secondInputType}, resType)); + builder.setInsertInto(func); + auto b = builder.emitBlock(); + builder.setInsertInto(b); + + auto s = builder.emitParam(firstInputType); + auto s1 = shareFirst ? s : builder.emitGetTupleElement(f->getParamType(0), s, 0); + auto s2 = shareFirst ? s : builder.emitGetTupleElement(g->getParamType(0), s, 1); + + auto u = builder.emitParam(secondInputType); + auto u1 = shareSecond ? u : builder.emitGetTupleElement(f->getParamType(1), u, 0); + auto u2 = shareSecond ? u : builder.emitGetTupleElement(g->getParamType(1), u, 1); + + // let fRes = f(s, u.fst); + auto fRes = builder.emitCallInst(f->getResultType(), f, {s1, u1}); + // let gRes = g(s, u.snd); + auto gRes = builder.emitCallInst(g->getResultType(), g, {s2, u2}); + // return (fRes, gRes); + builder.emitReturn(builder.emitMakeTuple(fRes, gRes)); + return func; +} + +// Given f : a -> uint4, g : b -> uint4, return z : (a, b) -> uint4 using +// bitwise and to combine the outputs +static IRFunc* makeWaveMatchBoth(IRBuilder& builder, IRType* inputTypeF, IRType* inputTypeG, IRInst* f, IRInst* g) +{ + // SLANG_ASSERT(f->getParamCount() == 1); + // SLANG_ASSERT(g->getParamCount() == f->getParamCount()); + auto uint4Type = builder.getVectorType(builder.getUIntType(), 4); + // SLANG_ASSERT(f->getResultType() == uint4Type); + // SLANG_ASSERT(g->getResultType() == f->getResultType()); + IRBuilderInsertLocScope insertLocScope(&builder); + + // Create (using shareFirst = false, shareSecond = true as an example) + // func myFunc(x : (A,B)) -> uint4 + // { + // let fRes = f(x.fst); + // let gRes = g(x.snd); + // return fRes & gRes; + // } + + auto inputTypeFG = builder.getTupleType(inputTypeF, inputTypeG); + auto resType = uint4Type; + + auto func = builder.createFunc(); + builder.addDecoration(func, kIROp_ForceInlineDecoration); + builder.setDataType(func, builder.getFuncType({inputTypeFG}, resType)); + builder.setInsertInto(func); + auto b = builder.emitBlock(); + builder.setInsertInto(b); + + auto x = builder.emitParam(inputTypeFG); + auto x1 = builder.emitGetTupleElement(inputTypeF, x, 0); + auto x2 = builder.emitGetTupleElement(inputTypeG, x, 1); + + auto b1 = builder.emitCallInst(uint4Type, f, {x1}); + auto b2 = builder.emitCallInst(uint4Type, g, {x2}); + auto r = builder.emitBitAnd(uint4Type, b1, b2); + + builder.emitReturn(r); + return func; +} + +// Similar to above +static IRFunc* makeBroadcastBoth(IRBuilder& builder, IRType* inputTypeF, IRType* inputTypeG, IRInst* f, IRInst* g) +{ + // SLANG_ASSERT(f->getParamCount() == 2); + // SLANG_ASSERT(g->getParamCount() == f->getParamCount()); + auto intType = builder.getIntType(); + // SLANG_ASSERT(f->getParamType(1) == intType); + // SLANG_ASSERT(g->getParamType(1) == f->getParamType(1)); + IRBuilderInsertLocScope insertLocScope(&builder); + + // Create (using shareFirst = false, shareSecond = true as an example) + // func myFunc(x : (A,B), i : int) -> (A, B) + // { + // let fRes = f(x.fst, i); + // let gRes = g(x.snd, i); + // return (fRes, gRes); + // } + + auto inputTypeFG = builder.getTupleType(inputTypeF, inputTypeG); + auto resType = inputTypeFG; + + auto func = builder.createFunc(); + builder.addDecoration(func, kIROp_ForceInlineDecoration); + builder.setDataType(func, builder.getFuncType({inputTypeFG, intType}, resType)); + builder.setInsertInto(func); + auto b = builder.emitBlock(); + builder.setInsertInto(b); + + auto x = builder.emitParam(inputTypeFG); + auto i = builder.emitParam(intType); + auto x1 = builder.emitGetTupleElement(inputTypeF, x, 0); + auto x2 = builder.emitGetTupleElement(inputTypeG, x, 1); + + auto b1 = builder.emitCallInst(inputTypeF, f, {x1, i}); + auto b2 = builder.emitCallInst(inputTypeG, g, {x2, i}); + auto r = builder.emitMakeTuple(b1, b2); + + builder.emitReturn(r); + return func; +} + +// All the information on a call to saturated_cooperation_using +struct SatCoopCall +{ + // The definition in hlsl.slang + IRGeneric* generic; + + // The specialization of that call + IRSpecialize* specialize; + + // Called 'A' in the definition + IRType* sharedInputType; + // Called 'B' in the definition + IRType* distinctInputType; + // Called 'C' in the definition + IRType* retType; + + // The function arguments to the call + IRFunc* cooperate; + IRFunc* fallback; + + // The inter-lane communication functions + // TODO: call specializeGeneric on these and extract the IRFunc + IRInst* waveMatch; + IRInst* broadcast; + + // The values to pass to these functions + IRInst* sharedInput; + IRInst* distinctInput; +}; + +static SatCoopCall getSatCoopCall(IRCall* f) +{ + SatCoopCall ret; + ret.specialize = as<IRSpecialize>(f->getCallee()); + + // Since this is a call to saturated_cooperation, it must have at least + // three specialization arguments for the type parameters A, B, C. We allow + // more here for any dictionaries or witnesses. + SLANG_ASSERT(ret.specialize && ret.specialize->getArgCount() >= 3); + ret.generic = as<IRGeneric>(ret.specialize->getBase()); + SLANG_ASSERT(ret.generic); + ret.sharedInputType = as<IRType>(ret.specialize->getArg(0)); + ret.distinctInputType = as<IRType>(ret.specialize->getArg(1)); + ret.retType = as<IRType>(ret.specialize->getArg(2)); + SLANG_ASSERT(ret.sharedInputType); + SLANG_ASSERT(ret.distinctInputType); + SLANG_ASSERT(ret.retType); + + SLANG_ASSERT(f->getArgCount() == 6); + ret.cooperate = as<IRFunc>(f->getArg(0)); + ret.fallback = as<IRFunc>(f->getArg(1)); + SLANG_ASSERT(ret.cooperate); + SLANG_ASSERT(ret.fallback); + + ret.waveMatch = f->getArg(2); + ret.broadcast = f->getArg(3); + SLANG_ASSERT(ret.waveMatch); + SLANG_ASSERT(ret.broadcast); + + ret.sharedInput = f->getArg(4); + ret.distinctInput = f->getArg(5); + SLANG_ASSERT(ret.sharedInput->getDataType() == ret.sharedInputType); + SLANG_ASSERT(ret.distinctInput->getDataType() == ret.distinctInputType); + return ret; +} + +// transform: +// a = sat_coop(c1, f1, s1, u1); // f +// p; +// q; +// b = sat_coop(c2, f2, s2, u2); // g +// to: +// p; +// (a,b) = sat_coop(c1 &&& c2, f1 &&& f2, (s1, s2), (u1, u2)); +// q; +// +// Removes the first two calls, and returns the second one if creation was +// successful. +// +// This can fail if: +// +// p has side effects which c1 or f1 may depend on +// q has side effects which c2 or f2 may depend on +// p depends on a +// the second call to sat_coop depends on a +// the second call to sat_coop depends on q +static IRCall* tryFuseCalls(IRBuilder& builder, IRCall* f, IRCall* g) +{ + // TODO: Make sure that the types in here are concrete, use + // `isGenericParam` + + IRBuilderInsertLocScope insertLocScope(&builder); + + SatCoopCall callF = getSatCoopCall(f); + SatCoopCall callG = getSatCoopCall(g); + // If these aren't referencing the same generic, then something has gone + // wrong in our assumptions. + SLANG_ASSERT(callF.generic == callG.generic); + + // If g uses the result of f, we can't fuse them with this logic (we could + // however with a replacement for 'fanout') + if(uses(f, g)) + return nullptr; + + // If there is no safe way to float these together, then fail + const auto q = floatTogether(f, g); + if(!q) + return nullptr; + builder.setInsertBefore(q); + + // As a slight neatening, we'll avoid wrapping and upwrapping a tuple (u,u) + // if both f and g use the same distinct input.. + bool usesSameDistinctInput = callF.distinctInput == callG.distinctInput; + SLANG_ASSERT(!usesSameDistinctInput || callF.distinctInputType == callG.distinctInputType); + + // Similarly for the shared input: if these use the same shared input then + // the fusing is simpler (no need to make a product of s1 and s2) + // TODO: if there is an injection from s1 to s2, then we can avoid the WaveMatch on s2 + const bool usesSameSharedInput = + callF.sharedInput == callG.sharedInput && + callF.waveMatch == callG.waveMatch && + callF.broadcast == callG.broadcast; + SLANG_ASSERT(!usesSameSharedInput || callF.sharedInputType == callG.sharedInputType); + + // Generate a new specialization of our saturated_cooperation_using function, + // reflecting the new input and output types. + const auto newRetType = builder.getTupleType(callF.retType, callG.retType); + const auto sharedInputType = usesSameSharedInput + ? callF.sharedInputType + : builder.getTupleType(callF.sharedInputType, callG.sharedInputType); + const auto distinctInputType = usesSameDistinctInput + ? callF.distinctInputType + : builder.getTupleType(callF.distinctInputType, callG.distinctInputType); + + // Make sure there are no other generic parameters which are are failing to + // take care of here. + SLANG_ASSERT(callF.specialize->getArgCount() == 3); + SLANG_ASSERT(callG.specialize->getArgCount() == 3); + + // Specialize our new call + const auto newSpec = builder.emitSpecializeInst( + builder.getTypeKind(), + callF.generic, + {sharedInputType, distinctInputType, newRetType}); + + // Make our new functions, and joined inputs + const auto newCooperate = makeBiFanout(builder, callF.cooperate, callG.cooperate, usesSameSharedInput, usesSameDistinctInput); + const auto newFallback = makeBiFanout(builder, callF.fallback, callG.fallback, usesSameSharedInput, usesSameDistinctInput); + const auto newWaveMatch = usesSameSharedInput + ? callF.waveMatch + : makeWaveMatchBoth(builder, callF.sharedInputType, callG.sharedInputType, callF.waveMatch, callG.waveMatch); + const auto newBroadcast = usesSameSharedInput + ? callF.broadcast + : makeBroadcastBoth(builder, callF.sharedInputType, callG.sharedInputType, callF.broadcast, callG.broadcast); + const auto newSharedInput = usesSameSharedInput + ? callF.sharedInput + : builder.emitMakeTuple(callF.sharedInput, callG.sharedInput); + const auto newDistinctInput = usesSameDistinctInput + ? callF.distinctInput + : builder.emitMakeTuple(callF.distinctInput, callG.distinctInput); + + // Call it and extract the results from f and g + const auto res = builder.emitCallInst( + newRetType, + newSpec, + {newCooperate, newFallback, newWaveMatch, newBroadcast, newSharedInput, newDistinctInput}); + const auto resF = builder.emitGetTupleElement(callF.retType, res, 0); + const auto resG = builder.emitGetTupleElement(callG.retType, res, 1); + f->replaceUsesWith(resF); + g->replaceUsesWith(resG); + f->removeAndDeallocate(); + g->removeAndDeallocate(); + + return res; +} + +// +// Identify calls which we can fuse +// +IRCall* isKnownFunction(const char* n, IRInst* i) +{ + auto call = as<IRCall>(i); + if(!call) + return nullptr; + // saturated_cooperation is a generic function, so look for specializations thereof + auto spec = as<IRSpecialize>(call->getCallee()); + if(!spec) + return nullptr; + auto generic = findSpecializedGeneric(spec); + if(!generic) + return nullptr; + + auto h = generic->findDecoration<IRKnownBuiltinDecoration>(); + if(!h || h->getName() != n) + return nullptr; + return call; +} + +// +// We perform a left fold over calls to saturated_cooperation +// +// sc(ca, fa) +// sc(cb, fb) +// sc(cc, fc) +// +// to +// +// sc(cacbcc, fafbfc) +// +// where cacbcc (and fafbfc) look like +// +// cacbcc(){ +// cacb(); +// cc(); +// } +// +// cacb(){ +// ca(); +// cb(); +// } +// +// These helper functions are inlined shortly after and the generated code is +// exactly what you'd expect: it's the body of sat_coop except that the +// original call to cooperate is replaced by three calls to ca, cb, cc. +// +// We use a fold here rather than accumulating everything at once as it's +// easier to implement fusing for 2 functions than n +static void fuseCallsInBlock(IRBuilder& builder, IRBlock* block) +{ + // first, inline calls to saturated_cooperation to expose + // saturated_cooperation_using which is simpler to fuse. + // It is simpler to fuse because it makes explicit the inter-lane + // communication functions, which we can use as buiding blocks in our + // composition. + + List<IRCall*> toInline; + for (auto inst : block->getChildren()) + { + if(auto sat_coop = isKnownFunction("saturated_cooperation", inst)) + toInline.add(sat_coop); + } + for(auto c : toInline) + inlineCall(c); + + // Walk over the instructions in this block + // If we see a call to sat_coop then remember where it is and keep + // walking, if we reach another call without first encountering any + // instructions with which our first call can't be safely reordered + // then we remove the first call and replace the second with a fused + // call. + IRCall* lastCall = nullptr; + for(auto inst = block->getFirstInst(); inst != block->getTerminator(); inst = inst->getNextInst()) + { + if(auto call = isKnownFunction("saturated_cooperation_using", inst)) + { + if(lastCall) + { + auto fused = tryFuseCalls(builder, lastCall, call); + if(fused) + { + inst = fused; + lastCall = fused; + } + else + { + lastCall = call; + } + } + else + { + lastCall = call; + } + } + } +} + +void fuseCallsToSaturatedCooperation(IRModule* module) +{ + IRBuilder builder(module); + overAllBlocks(module, [&](auto b){fuseCallsInBlock(builder, b);}); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-fuse-satcoop.h b/source/slang/slang-ir-fuse-satcoop.h new file mode 100644 index 000000000..8761c13da --- /dev/null +++ b/source/slang/slang-ir-fuse-satcoop.h @@ -0,0 +1,11 @@ +#pragma once + +namespace Slang +{ + struct CodeGenContext; + struct IRModule; + struct IRType; + + /// Fuse adjacent calls to saturated_cooperation + void fuseCallsToSaturatedCooperation(IRModule* module); +} diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index d9036f8bc..ea7547171 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -738,6 +738,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// Marks an interface as a COM interface declaration. INST(ComInterfaceDecoration, COMInterface, 0, 0) + /// Attaches a name to this instruction so that it can be identified + /// later in the compiler reliably + INST(KnownBuiltinDecoration, KnownBuiltinDecoration, 1, 0) + /* Decorations for RTTI objects */ INST(RTTITypeSizeDecoration, RTTI_typeSize, 1, 0) INST(AnyValueSizeDecoration, AnyValueSize, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 3f49be801..0b4ddf1a6 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -541,6 +541,18 @@ struct IRTorchEntryPointDecoration : IRDecoration UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); } }; +struct IRKnownBuiltinDecoration : IRDecoration +{ + enum + { + kOp = kIROp_KnownBuiltinDecoration + }; + IR_LEAF_ISA(KnownBuiltinDecoration) + + IRStringLit* getNameOperand() { return cast<IRStringLit>(getOperand(0)); } + UnownedStringSlice getName() { return getNameOperand()->getStringSlice(); } +}; + struct IRFormatDecoration : IRDecoration { enum { kOp = kIROp_FormatDecoration }; @@ -3016,6 +3028,14 @@ public: UInt argCount, IRInst* const* args); + IRInst* emitSpecializeInst( + IRType* type, + IRInst* genericVal, + const List<IRInst*>& args) + { + return emitSpecializeInst(type, genericVal, args.getCount(), args.begin()); + } + IRInst* emitLookupInterfaceMethodInst( IRType* type, IRInst* witnessTableVal, @@ -3029,13 +3049,13 @@ public: IRInst* emitUnpackAnyValue(IRType* type, IRInst* value); - IRInst* emitCallInst( + IRCall* emitCallInst( IRType* type, IRInst* func, UInt argCount, IRInst* const* args); - IRInst* emitCallInst( + IRCall* emitCallInst( IRType* type, IRInst* func, List<IRInst*> const& args) @@ -4051,6 +4071,11 @@ public: // TODO: Ellie, correct int type here? addDecoration(value, d, maxCount); } + + void addKnownBuiltinDecoration(IRInst* value, UnownedStringSlice const& name) + { + addDecoration(value, kIROp_KnownBuiltinDecoration, getStringValue(name)); + } }; // Helper to establish the source location that will be used diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp index f87633656..841617ac8 100644 --- a/source/slang/slang-ir-lower-witness-lookup.cpp +++ b/source/slang/slang-ir-lower-witness-lookup.cpp @@ -284,7 +284,7 @@ struct WitnessLookupLoweringContext callReturnType = dispatchFuncType->getResultType(); } - auto call = builder.emitCallInst( + IRInst* ret = builder.emitCallInst( callReturnType, entry, (UInt)args.getCount(), @@ -292,9 +292,9 @@ struct WitnessLookupLoweringContext // If result type is an associated type, we need to pack it into an anyValue. if (as<IRAssociatedType>(dispatchFuncType->getResultType())) { - call = builder.emitPackAnyValue(dispatchFuncType->getResultType(), call); + ret = builder.emitPackAnyValue(dispatchFuncType->getResultType(), ret); } - builder.emitReturn(call); + builder.emitReturn(ret); } builder.setInsertInto(firstBlock); if (witnessTables.getCount() == 1) diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index bc238e6ec..e0a757e11 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -32,7 +32,7 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(IRParam* if (as<IRGlobalParam>(arg)) return true; // Similarly for these global values - if( as<IRGlobalValueWithCode>(arg) ) return true; + if (as<IRGlobalValueWithCode>(arg)) return true; // As we will see later, we can also // specialize a call when the argument diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6adf8ee1c..72b65ea3e 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3395,7 +3395,7 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitCallInst( + IRCall* IRBuilder::emitCallInst( IRType* type, IRInst* pFunc, UInt argCount, diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 67f812134..66295b0fb 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -677,6 +677,7 @@ struct IRInst // IRInstListBase m_decorationsAndChildren; + IRInst* getFirstDecorationOrChild() { return m_decorationsAndChildren.first; } IRInst* getLastDecorationOrChild() { return m_decorationsAndChildren.last; } IRInstListBase getDecorationsAndChildren() { return m_decorationsAndChildren; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b1726487d..486c152d5 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1271,6 +1271,10 @@ static void addLinkageDecoration( builder->addPublicDecoration(inst); builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice()); } + else if (as<KnownBuiltinAttribute>(modifier)) + { + builder->addKnownBuiltinDecoration(inst, decl->getName()->text.getUnownedSlice()); + } } if (as<InterfaceDecl>(decl->parentDecl) && decl->parentDecl->hasModifier<ComInterfaceAttribute>() && |
