summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2023-05-13 06:33:26 +0800
committerGitHub <noreply@github.com>2023-05-12 15:33:26 -0700
commit65103bc9a0c72117d3c9410e361947cdd568ae55 (patch)
tree9dcb3dea5082d12366e078b3c7b62faa89ef5c73 /source
parent332f60c19336252d907b83882aa70665ca93a9d2 (diff)
Fusion pass for saturated_cooperation (#2874)
* Fusion pass for saturated_cooperation * simplify assert * regenerate vs projects * missing test output files * rename shadowing variable to appease msvc * Fuse calls to sat_coop with differing inputs * formatting * add cpu test for hof simple * Make higher-order functions into compute comparison tests * comment tests * remove redundant test * Add test to confirm inlining in sat_coop fuse * Add clarifying comment for sat coop fusing * Add KnownBuiltin decoration * s/CanUseFuncSignature/TypesFullyResolved for higher order function checking * Add TODO * spelling * Correct detection of sat_coop calls * Disable tests which are unsupported on testing infra
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/hlsl.meta.slang114
-rw-r--r--source/slang/slang-ast-modifier.h12
-rw-r--r--source/slang/slang-check-modifier.cpp12
-rw-r--r--source/slang/slang-check-overload.cpp3
-rw-r--r--source/slang/slang-emit.cpp19
-rw-r--r--source/slang/slang-ir-fuse-satcoop.cpp540
-rw-r--r--source/slang/slang-ir-fuse-satcoop.h11
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h29
-rw-r--r--source/slang/slang-ir-lower-witness-lookup.cpp6
-rw-r--r--source/slang/slang-ir-specialize-function-call.cpp2
-rw-r--r--source/slang/slang-ir.cpp2
-rw-r--r--source/slang/slang-ir.h1
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
15 files changed, 747 insertions, 15 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index b992def6e..730c3fcc8 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -3160,3 +3160,6 @@ attribute_syntax [PreferRecompute] : PreferRecomputeAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [PreferCheckpoint] : PreferCheckpointAttribute;
+
+__attributeTarget(DeclBase)
+attribute_syntax [KnownBuiltin(name : String)] : KnownBuiltinAttribute;
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 03cdc9ee2..40a372a32 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -6948,3 +6948,117 @@ uint3 cudaBlockIdx();
__target_intrinsic(cuda, "(blockDim)")
[__readNone]
uint3 cudaBlockDim();
+
+//
+// Workgroup cooperation
+//
+
+//
+// `saturated_cooperation(c, f, s, u)` will call `f(s, u)` if not all lanes in the
+// workgroup are currently executing. However, if all lanes are saturated, then
+// for each unique `s` across all the active lanes `c(s, u)` is called. The
+// return value is the one corresponding to the input `s` from this lane.
+//
+// Adjacent calls to saturated_cooperation are subject to fusion, i.e.
+// saturated_cooperation(c1, f1, s, u1);
+// saturated_cooperation(c2, f2, s, u2);
+// will be transformed to:
+// saturated_cooperation(c1c2, f1f2, s, u1u2);
+// where
+// c1c2 is a function which calls c1(s, u1) and then c2(s, u2);
+// f1f2 is a function which calls f1(s, u1) and then f2(s, u2);
+//
+// When the inputs differ, the calls are still fused:
+// saturated_cooperation(c1, f1, s1, u1);
+// saturated_cooperation(c2, f2, s2, u2);
+// will be transformed to:
+// saturated_cooperation(c1c2, f1f2, s1s2, u1u2);
+// where
+// s1s2 is a tuple of s1 and s2
+// c1c2 is a function which calls c1(s1, u1) and then c2(s2, u2);
+// f1f2 is a function which calls f1(s1, u1) and then f2(s2, u2);
+// Note that in this case, we will make a call to c1c2 for every unique pair
+// s1s2 across all lanes
+//
+// (This fusion takes place in the fuse-satcoop pass, and as such any changes to
+// the signature or behavior of this function must be mirrored there).
+//
+[KnownBuiltin("saturated_cooperation")]
+func saturated_cooperation<A : __BuiltinType, B, C>(
+ cooperate : functype (A, B) -> C,
+ fallback : functype (A, B) -> C,
+ A input,
+ B otherArg)
+ -> C
+{
+ return saturated_cooperation_using(cooperate, fallback, __WaveMatchBuitin<A>, __WaveReadLaneAtBuiltin<A>, input, otherArg);
+}
+
+// These two functions are a temporary (circa May 2023) workaround to the fact
+// that we can't deduce which overload to pass to saturated_cooperation_using
+// in the call above
+[__unsafeForceInlineEarly]
+func __WaveMatchBuitin<T : __BuiltinType>(T t) -> uint4
+{
+ return WaveMatch(t);
+}
+[__unsafeForceInlineEarly]
+func __WaveReadLaneAtBuiltin<T : __BuiltinType>(T t, int i) -> T
+{
+ return WaveReadLaneAt(t, i);
+}
+
+//
+// saturated_cooperation, but you're able to specify manually the functions:
+//
+// waveMatch: a function to return a mask of lanes with the same input as this one
+// broadcast: a function which returns the value passed into it on the specified lane
+//
+[KnownBuiltin("saturated_cooperation_using")]
+func saturated_cooperation_using<A, B, C>(
+ cooperate : functype (A, B) -> C,
+ fallback : functype (A, B) -> C,
+ waveMatch : functype (A) -> uint4,
+ broadcast : functype (A, int) -> A,
+ A input,
+ B otherArg)
+ -> C
+{
+ const bool isWaveSaturated = WaveActiveCountBits(true) == WaveGetLaneCount();
+ if(isWaveSaturated)
+ {
+ let lanesWithSameInput = waveMatch(input).x;
+ // Keep least significant lane in our set
+ let ourRepresentative = lanesWithSameInput & -lanesWithSameInput;
+ // The representative lanes for all lanes
+ var allRepresentatives = WaveActiveBitOr(ourRepresentative);
+
+ C ret;
+
+ // Iterate over set bits in mask from low to high.
+ // In each iteration the lowest bit is cleared.
+ while(bool(allRepresentatives))
+ {
+ // Broadcast input across warp.
+ let laneIdx = firstbitlow(allRepresentatives);
+ let uniformInput = broadcast(input, int(laneIdx));
+
+ // All lanes perform some cooperative computation with dynamic
+ // uniform input
+ C c = cooperate(uniformInput, otherArg);
+
+ // Keep updating our return value until our representative lane has been processed
+ if(bool(allRepresentatives & ourRepresentative))
+ ret = c;
+
+ // Clear the lowest bit
+ allRepresentatives &= allRepresentatives - 1;
+ }
+
+ return ret;
+ }
+ else
+ {
+ return fallback(input, otherArg);
+ }
+}
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 25761b11c..59ac26833 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1303,6 +1303,18 @@ class NoSideEffectAttribute : public Attribute
{
SLANG_AST_CLASS(NoSideEffectAttribute)
};
+
+ /// A `[KnownBuiltin("name")]` attribute allows the compiler to
+ /// identify this declaration during compilation, despite obfuscation or
+ /// linkage removing optimizations
+ ///
+class KnownBuiltinAttribute : public Attribute
+{
+ SLANG_AST_CLASS(KnownBuiltinAttribute)
+
+ String name;
+};
+
/// A modifier that applies to types rather than declarations.
///
/// In most cases, the Slang compiler assumes that a modifier should
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 2d5f6aad7..b1f36ca2b 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -725,6 +725,18 @@ namespace Slang
deprecatedAttr->message = message;
}
+ else if (auto knownBuiltinAttr = as<KnownBuiltinAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+
+ String name;
+ if(!checkLiteralStringVal(attr->args[0], &name))
+ {
+ return false;
+ }
+
+ knownBuiltinAttr->name = name;
+ }
else
{
if(attr->args.getCount() == 0)
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 5160f3c6f..4bd8506ed 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1454,7 +1454,8 @@ namespace Slang
// We could probably be broader than just parameters here
// eventually.
// Limit it for now though to make the specialization easier
- ensureDecl(localDeclRef, DeclCheckState::CanUseFuncSignature);
+ // TODO: why can't this use DeclCheckState::CanUseFuncSignature
+ ensureDecl(localDeclRef, DeclCheckState::TypesFullyResolved);
const auto type = localDeclRef.getDecl()->getType();
// We can only add overload candidates if this is known to be a function
if(const auto funType = as<FuncType>(type))
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 8a04e5cc8..071ea9639 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -21,6 +21,7 @@
#include "slang-ir-entry-point-raw-ptr-params.h"
#include "slang-ir-explicit-global-context.h"
#include "slang-ir-explicit-global-init.h"
+#include "slang-ir-fuse-satcoop.h"
#include "slang-ir-glsl-legalize.h"
#include "slang-ir-insts.h"
#include "slang-ir-inline.h"
@@ -350,6 +351,11 @@ Result linkAndOptimizeIR(
#endif
validateIRModuleIfEnabled(codeGenContext, irModule);
+ // It's important that this takes place before defunctionalization as we
+ // want to be able to easily discover the cooperate and fallback functions
+ // being passed to saturated_cooperation
+ fuseCallsToSaturatedCooperation(irModule);
+
// Next, we need to ensure that the code we emit for
// the target doesn't contain any operations that would
// be illegal on the target platform. For example,
@@ -397,6 +403,12 @@ Result linkAndOptimizeIR(
return SLANG_FAIL;
}
+ // Few of our targets support higher order functions, and
+ // we don't have the backend code to emit higher order functions for those
+ // which do.
+ // Specialize away these parameters
+ // TODO: We should implement a proper defunctionalization pass
+ changed |= specializeHigherOrderParameters(codeGenContext, irModule);
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
enableIRValidationAtInsert();
@@ -617,13 +629,6 @@ Result linkAndOptimizeIR(
break;
}
- // Few of our targets support higher order functions, and
- // we don't have the backend code to emit higher order functions for those
- // which do.
- // Specialize away these parameters
- // TODO: We should implement a proper defunctionalization pass
- specializeHigherOrderParameters(codeGenContext, irModule);
-
// For all targets, we translate load/store operations
// of aggregate types from/to byte-address buffers into
// stores of individual scalar or vector values.
diff --git a/source/slang/slang-ir-fuse-satcoop.cpp b/source/slang/slang-ir-fuse-satcoop.cpp
new file mode 100644
index 000000000..3c827ef25
--- /dev/null
+++ b/source/slang/slang-ir-fuse-satcoop.cpp
@@ -0,0 +1,540 @@
+#include "slang-ir-fuse-satcoop.h"
+
+#include "slang-ir-inline.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-specialize-function-call.h"
+#include "slang-ir-ssa-simplification.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+
+//
+// Some helpers
+//
+
+// Run an operation over every block in a module
+template<typename F>
+static void overAllBlocks(IRModule* module, F f)
+{
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRGlobalValueWithCode>(globalInst))
+ {
+ for (auto block : func->getBlocks())
+ {
+ f(block);
+ }
+ }
+ }
+}
+
+static bool uses(IRInst* used, IRInst* user)
+{
+ for(auto use = used->firstUse; use; use = use->nextUse)
+ {
+ if(use->getUser() == user)
+ return true;
+ }
+ return false;
+};
+
+// given: `f; x; g`
+// reorder instructions such that f and g are adjacent, in the form:
+// `p; f; g; q`,
+//
+// p is the set of instructions upon which g depends and q is the
+// set of instructions which depend on f. If these sets are not disjoint then
+// we can't float f and g together. Instructions not used by g and which don't
+// use f can go in either p or q.
+//
+// Returns g on success, or null if f and g can't be safely floated together
+static IRInst* floatTogether(IRInst* f, IRInst* g)
+{
+ List<IRInst*> ps, qs;
+
+ auto usesF = [&](IRInst* i){
+ if(uses(f, i))
+ return true;
+ for(auto q : qs)
+ if(uses(q, i))
+ return true;
+ return false;
+ };
+ auto usedByG = [&](IRInst* i){
+ if(uses(i, g))
+ return true;
+ for(auto p : ps)
+ if(uses(i, p))
+ return true;
+ return false;
+ };
+
+ // Scan backwards to find which instructions g depends on, known as p
+ auto i = g->prev;
+ while(i != f)
+ {
+ SLANG_ASSERT(i);
+
+ // If any instruction in x has side effects, we can't reorder things
+ if(i->mightHaveSideEffects())
+ return nullptr;
+
+ if(usedByG(i))
+ ps.add(i);
+ i = i->prev;
+ }
+
+ // Scan forwards to compute instructions which depend on f, the instructions in q
+ i = f->next;
+ while(i != g)
+ {
+ if(usesF(i))
+ {
+ // If this happens then ps and qs are not disjoint, and we will not
+ // be able to float f and g together
+ if(ps.contains(i))
+ return nullptr;
+ qs.add(i);
+ }
+
+ i = i->next;
+ }
+
+ // Now we can safely reorder things by moving p;f;g before everything else
+ // Remember, we constructed ps in reverse, so we must insert these
+ // backwards too
+ for(Index j = ps.getCount()-1; j >= 0; --j)
+ {
+ auto p = ps[j];
+ p->removeFromParent();
+ p->insertBefore(f);
+ }
+ g->removeFromParent();
+ g->insertAfter(f);
+ return g;
+}
+
+// bifanout(f, g)((x, y), (a, b)) = (f(x, a), g(y, b))
+//
+// Make a function `bifanout` which applies two functions to their respective
+// elements in two pairs. Optionally the first and second inputs can be shared
+// instead of split in a tuple.
+//
+// The outputs are returned in a 2-tuple
+static IRFunc* makeBiFanout(IRBuilder& builder, IRFunc* f, IRFunc* g, bool shareFirst, bool shareSecond)
+{
+ SLANG_ASSERT(f->getParamCount() == 2);
+ SLANG_ASSERT(g->getParamCount() == f->getParamCount());
+ SLANG_ASSERT(!shareFirst || f->getParamType(0) == g->getParamType(0));
+ SLANG_ASSERT(!shareSecond || f->getParamType(1) == g->getParamType(1));
+ IRBuilderInsertLocScope insertLocScope(&builder);
+
+ // Create (using shareFirst = false, shareSecond = true as an example)
+ // func myFunc(s : S, u : (U1,U2)) -> (R1, R2)
+ // {
+ // let fRes = f(s, u.fst);
+ // let gRes = g(s, u.snd);
+ // return (fRes, gRes);
+ // }
+
+ // The return type is the tuple of f and g's return types
+ auto resType = builder.getTupleType(f->getResultType(), g->getResultType());
+ auto firstInputType = shareFirst
+ ? f->getParamType(0)
+ : builder.getTupleType(f->getParamType(0), g->getParamType(0));
+ auto secondInputType = shareSecond
+ ? f->getParamType(1)
+ : builder.getTupleType(f->getParamType(1), g->getParamType(1));
+
+ // Set up our function
+ // func myFunc(s : S, u : (U1,U2)) -> (R1, R2)
+ auto func = builder.createFunc();
+ builder.addDecoration(func, kIROp_ForceInlineDecoration);
+ builder.setDataType(func, builder.getFuncType({firstInputType, secondInputType}, resType));
+ builder.setInsertInto(func);
+ auto b = builder.emitBlock();
+ builder.setInsertInto(b);
+
+ auto s = builder.emitParam(firstInputType);
+ auto s1 = shareFirst ? s : builder.emitGetTupleElement(f->getParamType(0), s, 0);
+ auto s2 = shareFirst ? s : builder.emitGetTupleElement(g->getParamType(0), s, 1);
+
+ auto u = builder.emitParam(secondInputType);
+ auto u1 = shareSecond ? u : builder.emitGetTupleElement(f->getParamType(1), u, 0);
+ auto u2 = shareSecond ? u : builder.emitGetTupleElement(g->getParamType(1), u, 1);
+
+ // let fRes = f(s, u.fst);
+ auto fRes = builder.emitCallInst(f->getResultType(), f, {s1, u1});
+ // let gRes = g(s, u.snd);
+ auto gRes = builder.emitCallInst(g->getResultType(), g, {s2, u2});
+ // return (fRes, gRes);
+ builder.emitReturn(builder.emitMakeTuple(fRes, gRes));
+ return func;
+}
+
+// Given f : a -> uint4, g : b -> uint4, return z : (a, b) -> uint4 using
+// bitwise and to combine the outputs
+static IRFunc* makeWaveMatchBoth(IRBuilder& builder, IRType* inputTypeF, IRType* inputTypeG, IRInst* f, IRInst* g)
+{
+ // SLANG_ASSERT(f->getParamCount() == 1);
+ // SLANG_ASSERT(g->getParamCount() == f->getParamCount());
+ auto uint4Type = builder.getVectorType(builder.getUIntType(), 4);
+ // SLANG_ASSERT(f->getResultType() == uint4Type);
+ // SLANG_ASSERT(g->getResultType() == f->getResultType());
+ IRBuilderInsertLocScope insertLocScope(&builder);
+
+ // Create the following function:
+ // func myFunc(x : (A,B)) -> uint4
+ // {
+ // let fRes = f(x.fst);
+ // let gRes = g(x.snd);
+ // return fRes & gRes;
+ // }
+
+ auto inputTypeFG = builder.getTupleType(inputTypeF, inputTypeG);
+ auto resType = uint4Type;
+
+ auto func = builder.createFunc();
+ builder.addDecoration(func, kIROp_ForceInlineDecoration);
+ builder.setDataType(func, builder.getFuncType({inputTypeFG}, resType));
+ builder.setInsertInto(func);
+ auto b = builder.emitBlock();
+ builder.setInsertInto(b);
+
+ auto x = builder.emitParam(inputTypeFG);
+ auto x1 = builder.emitGetTupleElement(inputTypeF, x, 0);
+ auto x2 = builder.emitGetTupleElement(inputTypeG, x, 1);
+
+ auto b1 = builder.emitCallInst(uint4Type, f, {x1});
+ auto b2 = builder.emitCallInst(uint4Type, g, {x2});
+ auto r = builder.emitBitAnd(uint4Type, b1, b2);
+
+ builder.emitReturn(r);
+ return func;
+}
+
+// Similar to above
+static IRFunc* makeBroadcastBoth(IRBuilder& builder, IRType* inputTypeF, IRType* inputTypeG, IRInst* f, IRInst* g)
+{
+ // SLANG_ASSERT(f->getParamCount() == 2);
+ // SLANG_ASSERT(g->getParamCount() == f->getParamCount());
+ auto intType = builder.getIntType();
+ // SLANG_ASSERT(f->getParamType(1) == intType);
+ // SLANG_ASSERT(g->getParamType(1) == f->getParamType(1));
+ IRBuilderInsertLocScope insertLocScope(&builder);
+
+ // Create the following function:
+ // func myFunc(x : (A,B), i : int) -> (A, B)
+ // {
+ // let fRes = f(x.fst, i);
+ // let gRes = g(x.snd, i);
+ // return (fRes, gRes);
+ // }
+
+ auto inputTypeFG = builder.getTupleType(inputTypeF, inputTypeG);
+ auto resType = inputTypeFG;
+
+ auto func = builder.createFunc();
+ builder.addDecoration(func, kIROp_ForceInlineDecoration);
+ builder.setDataType(func, builder.getFuncType({inputTypeFG, intType}, resType));
+ builder.setInsertInto(func);
+ auto b = builder.emitBlock();
+ builder.setInsertInto(b);
+
+ auto x = builder.emitParam(inputTypeFG);
+ auto i = builder.emitParam(intType);
+ auto x1 = builder.emitGetTupleElement(inputTypeF, x, 0);
+ auto x2 = builder.emitGetTupleElement(inputTypeG, x, 1);
+
+ auto b1 = builder.emitCallInst(inputTypeF, f, {x1, i});
+ auto b2 = builder.emitCallInst(inputTypeG, g, {x2, i});
+ auto r = builder.emitMakeTuple(b1, b2);
+
+ builder.emitReturn(r);
+ return func;
+}
+
+// All the information on a call to saturated_cooperation_using
+struct SatCoopCall
+{
+ // The definition in hlsl.meta.slang
+ IRGeneric* generic;
+
+ // The specialization of that call
+ IRSpecialize* specialize;
+
+ // Called 'A' in the definition
+ IRType* sharedInputType;
+ // Called 'B' in the definition
+ IRType* distinctInputType;
+ // Called 'C' in the definition
+ IRType* retType;
+
+ // The function arguments to the call
+ IRFunc* cooperate;
+ IRFunc* fallback;
+
+ // The inter-lane communication functions
+ // TODO: call specializeGeneric on these and extract the IRFunc
+ IRInst* waveMatch;
+ IRInst* broadcast;
+
+ // The values to pass to these functions
+ IRInst* sharedInput;
+ IRInst* distinctInput;
+};
+
+static SatCoopCall getSatCoopCall(IRCall* f)
+{
+ SatCoopCall ret;
+ ret.specialize = as<IRSpecialize>(f->getCallee());
+
+ // Since this is a call to saturated_cooperation, it must have at least
+ // three specialization arguments for the type parameters A, B, C. We allow
+ // more here for any dictionaries or witnesses.
+ SLANG_ASSERT(ret.specialize && ret.specialize->getArgCount() >= 3);
+ ret.generic = as<IRGeneric>(ret.specialize->getBase());
+ SLANG_ASSERT(ret.generic);
+ ret.sharedInputType = as<IRType>(ret.specialize->getArg(0));
+ ret.distinctInputType = as<IRType>(ret.specialize->getArg(1));
+ ret.retType = as<IRType>(ret.specialize->getArg(2));
+ SLANG_ASSERT(ret.sharedInputType);
+ SLANG_ASSERT(ret.distinctInputType);
+ SLANG_ASSERT(ret.retType);
+
+ SLANG_ASSERT(f->getArgCount() == 6);
+ ret.cooperate = as<IRFunc>(f->getArg(0));
+ ret.fallback = as<IRFunc>(f->getArg(1));
+ SLANG_ASSERT(ret.cooperate);
+ SLANG_ASSERT(ret.fallback);
+
+ ret.waveMatch = f->getArg(2);
+ ret.broadcast = f->getArg(3);
+ SLANG_ASSERT(ret.waveMatch);
+ SLANG_ASSERT(ret.broadcast);
+
+ ret.sharedInput = f->getArg(4);
+ ret.distinctInput = f->getArg(5);
+ SLANG_ASSERT(ret.sharedInput->getDataType() == ret.sharedInputType);
+ SLANG_ASSERT(ret.distinctInput->getDataType() == ret.distinctInputType);
+ return ret;
+}
+
+// transform:
+// a = sat_coop(c1, f1, s1, u1); // f
+// p;
+// q;
+// b = sat_coop(c2, f2, s2, u2); // g
+// to:
+// p;
+// (a,b) = sat_coop(c1 &&& c2, f1 &&& f2, (s1, s2), (u1, u2));
+// q;
+//
+// Removes both original calls and returns the newly created fused call on
+// success; returns null if the calls could not be fused.
+//
+// This can fail if:
+//
+// p has side effects which c1 or f1 may depend on
+// q has side effects which c2 or f2 may depend on
+// p depends on a
+// the second call to sat_coop depends on a
+// the second call to sat_coop depends on q
+static IRCall* tryFuseCalls(IRBuilder& builder, IRCall* f, IRCall* g)
+{
+ // TODO: Make sure that the types in here are concrete, use
+ // `isGenericParam`
+
+ IRBuilderInsertLocScope insertLocScope(&builder);
+
+ SatCoopCall callF = getSatCoopCall(f);
+ SatCoopCall callG = getSatCoopCall(g);
+ // If these aren't referencing the same generic, then something has gone
+ // wrong in our assumptions.
+ SLANG_ASSERT(callF.generic == callG.generic);
+
+ // If g uses the result of f, we can't fuse them with this logic (we could
+ // however with a replacement for 'fanout')
+ if(uses(f, g))
+ return nullptr;
+
+ // If there is no safe way to float these together, then fail
+ const auto q = floatTogether(f, g);
+ if(!q)
+ return nullptr;
+ builder.setInsertBefore(q);
+
+ // As a slight neatening, we'll avoid wrapping and unwrapping a tuple (u,u)
+ // if both f and g use the same distinct input.
+ bool usesSameDistinctInput = callF.distinctInput == callG.distinctInput;
+ SLANG_ASSERT(!usesSameDistinctInput || callF.distinctInputType == callG.distinctInputType);
+
+ // Similarly for the shared input: if these use the same shared input then
+ // the fusing is simpler (no need to make a product of s1 and s2)
+ // TODO: if there is an injection from s1 to s2, then we can avoid the WaveMatch on s2
+ const bool usesSameSharedInput =
+ callF.sharedInput == callG.sharedInput &&
+ callF.waveMatch == callG.waveMatch &&
+ callF.broadcast == callG.broadcast;
+ SLANG_ASSERT(!usesSameSharedInput || callF.sharedInputType == callG.sharedInputType);
+
+ // Generate a new specialization of our saturated_cooperation_using function,
+ // reflecting the new input and output types.
+ const auto newRetType = builder.getTupleType(callF.retType, callG.retType);
+ const auto sharedInputType = usesSameSharedInput
+ ? callF.sharedInputType
+ : builder.getTupleType(callF.sharedInputType, callG.sharedInputType);
+ const auto distinctInputType = usesSameDistinctInput
+ ? callF.distinctInputType
+ : builder.getTupleType(callF.distinctInputType, callG.distinctInputType);
+
+ // Make sure there are no other generic parameters which we are failing to
+ // take care of here.
+ SLANG_ASSERT(callF.specialize->getArgCount() == 3);
+ SLANG_ASSERT(callG.specialize->getArgCount() == 3);
+
+ // Specialize our new call
+ const auto newSpec = builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ callF.generic,
+ {sharedInputType, distinctInputType, newRetType});
+
+ // Make our new functions, and joined inputs
+ const auto newCooperate = makeBiFanout(builder, callF.cooperate, callG.cooperate, usesSameSharedInput, usesSameDistinctInput);
+ const auto newFallback = makeBiFanout(builder, callF.fallback, callG.fallback, usesSameSharedInput, usesSameDistinctInput);
+ const auto newWaveMatch = usesSameSharedInput
+ ? callF.waveMatch
+ : makeWaveMatchBoth(builder, callF.sharedInputType, callG.sharedInputType, callF.waveMatch, callG.waveMatch);
+ const auto newBroadcast = usesSameSharedInput
+ ? callF.broadcast
+ : makeBroadcastBoth(builder, callF.sharedInputType, callG.sharedInputType, callF.broadcast, callG.broadcast);
+ const auto newSharedInput = usesSameSharedInput
+ ? callF.sharedInput
+ : builder.emitMakeTuple(callF.sharedInput, callG.sharedInput);
+ const auto newDistinctInput = usesSameDistinctInput
+ ? callF.distinctInput
+ : builder.emitMakeTuple(callF.distinctInput, callG.distinctInput);
+
+ // Call it and extract the results from f and g
+ const auto res = builder.emitCallInst(
+ newRetType,
+ newSpec,
+ {newCooperate, newFallback, newWaveMatch, newBroadcast, newSharedInput, newDistinctInput});
+ const auto resF = builder.emitGetTupleElement(callF.retType, res, 0);
+ const auto resG = builder.emitGetTupleElement(callG.retType, res, 1);
+ f->replaceUsesWith(resF);
+ g->replaceUsesWith(resG);
+ f->removeAndDeallocate();
+ g->removeAndDeallocate();
+
+ return res;
+}
+
+//
+// Identify calls which we can fuse
+//
+IRCall* isKnownFunction(const char* n, IRInst* i)
+{
+ auto call = as<IRCall>(i);
+ if(!call)
+ return nullptr;
+ // saturated_cooperation is a generic function, so look for specializations thereof
+ auto spec = as<IRSpecialize>(call->getCallee());
+ if(!spec)
+ return nullptr;
+ auto generic = findSpecializedGeneric(spec);
+ if(!generic)
+ return nullptr;
+
+ auto h = generic->findDecoration<IRKnownBuiltinDecoration>();
+ if(!h || h->getName() != n)
+ return nullptr;
+ return call;
+}
+
+//
+// We perform a left fold over calls to saturated_cooperation
+//
+// sc(ca, fa)
+// sc(cb, fb)
+// sc(cc, fc)
+//
+// to
+//
+// sc(cacbcc, fafbfc)
+//
+// where cacbcc (and fafbfc) look like
+//
+// cacbcc(){
+// cacb();
+// cc();
+// }
+//
+// cacb(){
+// ca();
+// cb();
+// }
+//
+// These helper functions are inlined shortly after and the generated code is
+// exactly what you'd expect: it's the body of sat_coop except that the
+// original call to cooperate is replaced by three calls to ca, cb, cc.
+//
+// We use a fold here rather than accumulating everything at once as it's
+// easier to implement fusing for 2 functions than n
+static void fuseCallsInBlock(IRBuilder& builder, IRBlock* block)
+{
+ // first, inline calls to saturated_cooperation to expose
+ // saturated_cooperation_using which is simpler to fuse.
+ // It is simpler to fuse because it makes explicit the inter-lane
+ // communication functions, which we can use as building blocks in our
+ // composition.
+
+ List<IRCall*> toInline;
+ for (auto inst : block->getChildren())
+ {
+ if(auto sat_coop = isKnownFunction("saturated_cooperation", inst))
+ toInline.add(sat_coop);
+ }
+ for(auto c : toInline)
+ inlineCall(c);
+
+ // Walk over the instructions in this block
+ // If we see a call to sat_coop then remember where it is and keep
+ // walking, if we reach another call without first encountering any
+ // instructions with which our first call can't be safely reordered
+ // then we remove the first call and replace the second with a fused
+ // call.
+ IRCall* lastCall = nullptr;
+ for(auto inst = block->getFirstInst(); inst != block->getTerminator(); inst = inst->getNextInst())
+ {
+ if(auto call = isKnownFunction("saturated_cooperation_using", inst))
+ {
+ if(lastCall)
+ {
+ auto fused = tryFuseCalls(builder, lastCall, call);
+ if(fused)
+ {
+ inst = fused;
+ lastCall = fused;
+ }
+ else
+ {
+ lastCall = call;
+ }
+ }
+ else
+ {
+ lastCall = call;
+ }
+ }
+ }
+}
+
+void fuseCallsToSaturatedCooperation(IRModule* module)
+{
+ IRBuilder builder(module);
+ overAllBlocks(module, [&](auto b){fuseCallsInBlock(builder, b);});
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-fuse-satcoop.h b/source/slang/slang-ir-fuse-satcoop.h
new file mode 100644
index 000000000..8761c13da
--- /dev/null
+++ b/source/slang/slang-ir-fuse-satcoop.h
@@ -0,0 +1,11 @@
+#pragma once
+
+namespace Slang
+{
+ struct CodeGenContext;
+ struct IRModule;
+ struct IRType;
+
+ /// Fuse adjacent calls to saturated_cooperation
+ void fuseCallsToSaturatedCooperation(IRModule* module);
+}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index d9036f8bc..ea7547171 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -738,6 +738,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// Marks an interface as a COM interface declaration.
INST(ComInterfaceDecoration, COMInterface, 0, 0)
+ /// Attaches a name to this instruction so that it can be identified
+ /// later in the compiler reliably
+ INST(KnownBuiltinDecoration, KnownBuiltinDecoration, 1, 0)
+
/* Decorations for RTTI objects */
INST(RTTITypeSizeDecoration, RTTI_typeSize, 1, 0)
INST(AnyValueSizeDecoration, AnyValueSize, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 3f49be801..0b4ddf1a6 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -541,6 +541,18 @@ struct IRTorchEntryPointDecoration : IRDecoration
UnownedStringSlice getFunctionName() { return getFunctionNameOperand()->getStringSlice(); }
};
+struct IRKnownBuiltinDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_KnownBuiltinDecoration
+ };
+ IR_LEAF_ISA(KnownBuiltinDecoration)
+
+ IRStringLit* getNameOperand() { return cast<IRStringLit>(getOperand(0)); }
+ UnownedStringSlice getName() { return getNameOperand()->getStringSlice(); }
+};
+
struct IRFormatDecoration : IRDecoration
{
enum { kOp = kIROp_FormatDecoration };
@@ -3016,6 +3028,14 @@ public:
UInt argCount,
IRInst* const* args);
+ IRInst* emitSpecializeInst(
+ IRType* type,
+ IRInst* genericVal,
+ const List<IRInst*>& args)
+ {
+ return emitSpecializeInst(type, genericVal, args.getCount(), args.begin());
+ }
+
IRInst* emitLookupInterfaceMethodInst(
IRType* type,
IRInst* witnessTableVal,
@@ -3029,13 +3049,13 @@ public:
IRInst* emitUnpackAnyValue(IRType* type, IRInst* value);
- IRInst* emitCallInst(
+ IRCall* emitCallInst(
IRType* type,
IRInst* func,
UInt argCount,
IRInst* const* args);
- IRInst* emitCallInst(
+ IRCall* emitCallInst(
IRType* type,
IRInst* func,
List<IRInst*> const& args)
@@ -4051,6 +4071,11 @@ public:
// TODO: Ellie, correct int type here?
addDecoration(value, d, maxCount);
}
+
+ void addKnownBuiltinDecoration(IRInst* value, UnownedStringSlice const& name)
+ {
+ addDecoration(value, kIROp_KnownBuiltinDecoration, getStringValue(name));
+ }
};
// Helper to establish the source location that will be used
diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp
index f87633656..841617ac8 100644
--- a/source/slang/slang-ir-lower-witness-lookup.cpp
+++ b/source/slang/slang-ir-lower-witness-lookup.cpp
@@ -284,7 +284,7 @@ struct WitnessLookupLoweringContext
callReturnType = dispatchFuncType->getResultType();
}
- auto call = builder.emitCallInst(
+ IRInst* ret = builder.emitCallInst(
callReturnType,
entry,
(UInt)args.getCount(),
@@ -292,9 +292,9 @@ struct WitnessLookupLoweringContext
// If result type is an associated type, we need to pack it into an anyValue.
if (as<IRAssociatedType>(dispatchFuncType->getResultType()))
{
- call = builder.emitPackAnyValue(dispatchFuncType->getResultType(), call);
+ ret = builder.emitPackAnyValue(dispatchFuncType->getResultType(), ret);
}
- builder.emitReturn(call);
+ builder.emitReturn(ret);
}
builder.setInsertInto(firstBlock);
if (witnessTables.getCount() == 1)
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp
index bc238e6ec..e0a757e11 100644
--- a/source/slang/slang-ir-specialize-function-call.cpp
+++ b/source/slang/slang-ir-specialize-function-call.cpp
@@ -32,7 +32,7 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(IRParam*
if (as<IRGlobalParam>(arg)) return true;
// Similarly for these global values
- if( as<IRGlobalValueWithCode>(arg) ) return true;
+ if (as<IRGlobalValueWithCode>(arg)) return true;
// As we will see later, we can also
// specialize a call when the argument
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6adf8ee1c..72b65ea3e 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3395,7 +3395,7 @@ namespace Slang
return inst;
}
- IRInst* IRBuilder::emitCallInst(
+ IRCall* IRBuilder::emitCallInst(
IRType* type,
IRInst* pFunc,
UInt argCount,
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 67f812134..66295b0fb 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -677,6 +677,7 @@ struct IRInst
//
IRInstListBase m_decorationsAndChildren;
+
IRInst* getFirstDecorationOrChild() { return m_decorationsAndChildren.first; }
IRInst* getLastDecorationOrChild() { return m_decorationsAndChildren.last; }
IRInstListBase getDecorationsAndChildren() { return m_decorationsAndChildren; }
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b1726487d..486c152d5 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1271,6 +1271,10 @@ static void addLinkageDecoration(
builder->addPublicDecoration(inst);
builder->addExternCppDecoration(inst, decl->getName()->text.getUnownedSlice());
}
+ else if (auto knownBuiltinAttr = as<KnownBuiltinAttribute>(modifier))
+ {
+ builder->addKnownBuiltinDecoration(inst, knownBuiltinAttr->name.getUnownedSlice());
+ }
}
if (as<InterfaceDecl>(decl->parentDecl) &&
decl->parentDecl->hasModifier<ComInterfaceAttribute>() &&