#include "slang-ir-fuse-satcoop.h"

#include "slang-ir-inline.h"
#include "slang-ir-insts.h"
#include "slang-ir-specialize-function-call.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-util.h"
#include "slang-ir.h"

namespace Slang
{

//
// Some helpers
//

// Returns true if `user` has `used` as one of its operands, i.e. there is a
// direct use edge from `used` to `user`. Transitive uses are not considered.
static bool uses(IRInst* used, IRInst* user)
{
    for (auto use = used->firstUse; use; use = use->nextUse)
    {
        if (use->getUser() == user)
            return true;
    }
    return false;
}

// given: `f; x; g`
// reorder instructions such that f and g are adjacent, in the form:
// `p; f; g; q`,
//
// p is the set of instructions upon which g depends (directly, or through
// other members of p) and q is the set of instructions which depend on f
// (directly, or through other members of q). If these sets are not disjoint
// then we can't float f and g together. Instructions not used by g and which
// don't use f can go in either p or q; this implementation leaves them in
// place, so they end up after g.
//
// Precondition: f precedes g in the same block.
//
// Returns g on success. On failure returns nullptr, and the IR is left
// unmodified (all rejection happens before any instruction is moved).
static IRInst* floatTogether(IRInst* f, IRInst* g)
{
    List<IRInst*> ps, qs;

    // Does `i` use f, or any instruction already known to be in q?
    auto usesF = [&](IRInst* i)
    {
        if (uses(f, i))
            return true;
        for (auto q : qs)
            if (uses(q, i))
                return true;
        return false;
    };

    // Is `i` used by g, or by any instruction already known to be in p?
    auto usedByG = [&](IRInst* i)
    {
        if (uses(i, g))
            return true;
        for (auto p : ps)
            if (uses(i, p))
                return true;
        return false;
    };

    // Scan backwards to find which instructions g depends on, known as p.
    // Scanning backwards means every later user of `i` between f and g has
    // already been classified when `i` is tested.
    auto i = g->prev;
    while (i != f)
    {
        // If we run off the start of the block then f did not precede g.
        // Bail out rather than dereference null in case assertions are
        // disabled.
        SLANG_ASSERT(i);
        if (!i)
            return nullptr;

        // If the instruction is not movable, then we can't reorder f and g
        // around it.
        //
        // For a slight optimization: This is actually stricter than we need:
        // if `x = p;q` and f and g are movable, then we can safely move f and
        // g in and maintain the ordering of p and q
        if (!isMovableInst(i))
            return nullptr;
        if (usedByG(i))
            ps.add(i);
        i = i->prev;
    }

    // Scan forwards to compute instructions which depend on f, the
    // instructions in q.
    for (i = f->next; i != g; i = i->next)
    {
        if (!usesF(i))
            continue;

        // If this happens then ps and qs are not disjoint, and we will not
        // be able to float f and g together
        if (ps.contains(i))
            return nullptr;
        qs.add(i);
    }

    // Now we can safely reorder things by moving every p before f and then
    // g directly after f. ps was constructed in reverse program order, so
    // walk it backwards to preserve the original relative order of its
    // members.
    for (Index j = ps.getCount() - 1; j >= 0; --j)
    {
        auto p = ps[j];
        p->removeFromParent();
        p->insertBefore(f);
    }
    g->removeFromParent();
    g->insertAfter(f);
    return g;
}

// bifanout(f, g)((x, y), (a, b)) = (f(x, a), g(y, b))
//
// Make a function `bifanout` which applies two functions to their respective
// elements in two pairs. Optionally the first and second inputs can be shared
// instead of split in a tuple:
//  - shareFirst: both f and g receive the same (untupled) first argument.
//  - shareSecond: both f and g receive the same (untupled) second argument.
//
// The outputs are returned in a 2-tuple. The generated function is marked
// with a ForceInline decoration.
static IRFunc* makeBiFanout(
    IRBuilder& builder,
    IRFunc* f,
    IRFunc* g,
    bool shareFirst,
    bool shareSecond)
{
    SLANG_ASSERT(f->getParamCount() == 2);
    SLANG_ASSERT(g->getParamCount() == f->getParamCount());
    SLANG_ASSERT(!shareFirst || f->getParamType(0) == g->getParamType(0));
    SLANG_ASSERT(!shareSecond || f->getParamType(1) == g->getParamType(1));

    IRBuilderInsertLocScope insertLocScope(&builder);

    // Create (using shareFirst = true, shareSecond = false as an example)
    // func myFunc(s : S, u : (U1,U2)) -> (R1, R2)
    // {
    //     let fRes = f(s, u.fst);
    //     let gRes = g(s, u.snd);
    //     return (fRes, gRes);
    // }

    // The return type is the tuple of f and g's return types
    auto resType = builder.getTupleType(f->getResultType(), g->getResultType());
    auto firstInputType = shareFirst
                              ? f->getParamType(0)
                              : builder.getTupleType(f->getParamType(0), g->getParamType(0));
    auto secondInputType = shareSecond
                               ? f->getParamType(1)
                               : builder.getTupleType(f->getParamType(1), g->getParamType(1));

    // Set up our function
    // func myFunc(s : S, u : (U1,U2)) -> (R1, R2)
    auto func = builder.createFunc();
    builder.addDecoration(func, kIROp_ForceInlineDecoration);
    builder.setDataType(func, builder.getFuncType({firstInputType, secondInputType}, resType));
    builder.setInsertInto(func);
    auto b = builder.emitBlock();
    builder.setInsertInto(b);

    // Emit both parameters before any ordinary instruction in the block.
    auto s = builder.emitParam(firstInputType);
    auto u = builder.emitParam(secondInputType);

    // Split (or share) the inputs for f and g.
    auto s1 = shareFirst ? s : builder.emitGetTupleElement(f->getParamType(0), s, 0);
    auto s2 = shareFirst ? s : builder.emitGetTupleElement(g->getParamType(0), s, 1);
    auto u1 = shareSecond ? u : builder.emitGetTupleElement(f->getParamType(1), u, 0);
    auto u2 = shareSecond ? u : builder.emitGetTupleElement(g->getParamType(1), u, 1);

    // let fRes = f(s, u.fst);
    auto fRes = builder.emitCallInst(f->getResultType(), f, {s1, u1});
    // let gRes = g(s, u.snd);
    auto gRes = builder.emitCallInst(g->getResultType(), g, {s2, u2});
    // return (fRes, gRes);
    builder.emitReturn(builder.emitMakeTuple(fRes, gRes));

    return func;
}

// Given f : A -> uint4 and g : B -> uint4, return z : (A, B) -> uint4 which
// applies f and g to their respective halves of the input tuple and combines
// the two results with a bitwise and.
//
// f and g are IRInst* rather than IRFunc* (they may be specializations), so
// their signatures are not checked here; the caller passes the input types
// explicitly. The generated function is marked with a ForceInline decoration.
static IRFunc* makeWaveMatchBoth(
    IRBuilder& builder,
    IRType* inputTypeF,
    IRType* inputTypeG,
    IRInst* f,
    IRInst* g)
{
    auto uint4Type = builder.getVectorType(builder.getUIntType(), 4);

    IRBuilderInsertLocScope insertLocScope(&builder);

    // Create:
    // func myFunc(x : (A,B)) -> uint4
    // {
    //     let fRes = f(x.fst);
    //     let gRes = g(x.snd);
    //     return fRes & gRes;
    // }
    auto inputTypeFG = builder.getTupleType(inputTypeF, inputTypeG);
    auto resType = uint4Type;

    auto func = builder.createFunc();
    builder.addDecoration(func, kIROp_ForceInlineDecoration);
    builder.setDataType(func, builder.getFuncType({inputTypeFG}, resType));
    builder.setInsertInto(func);
    auto b = builder.emitBlock();
    builder.setInsertInto(b);

    auto x = builder.emitParam(inputTypeFG);
    auto x1 = builder.emitGetTupleElement(inputTypeF, x, 0);
    auto x2 = builder.emitGetTupleElement(inputTypeG, x, 1);
    auto b1 = builder.emitCallInst(uint4Type, f, {x1});
    auto b2 = builder.emitCallInst(uint4Type, g, {x2});
    auto r = builder.emitBitAnd(uint4Type, b1, b2);
    builder.emitReturn(r);

    return func;
}

// Given f : (A, int) -> A and g : (B, int) -> B, return
// z : ((A, B), int) -> (A, B) which applies f and g to their respective
// halves of the input tuple, passing the same int to both, and returns both
// results as a tuple.
//
// As with makeWaveMatchBoth, f and g are IRInst* so their signatures are not
// checked here. The generated function is marked with a ForceInline
// decoration.
static IRFunc* makeBroadcastBoth(
    IRBuilder& builder,
    IRType* inputTypeF,
    IRType* inputTypeG,
    IRInst* f,
    IRInst* g)
{
    auto intType = builder.getIntType();

    IRBuilderInsertLocScope insertLocScope(&builder);

    // Create:
    // func myFunc(x : (A,B), i : int) -> (A, B)
    // {
    //     let fRes = f(x.fst, i);
    //     let gRes = g(x.snd, i);
    //     return (fRes, gRes);
    // }
    auto inputTypeFG = builder.getTupleType(inputTypeF, inputTypeG);
    auto resType = inputTypeFG;

    auto func = builder.createFunc();
    builder.addDecoration(func, kIROp_ForceInlineDecoration);
    builder.setDataType(func, builder.getFuncType({inputTypeFG, intType}, resType));
    builder.setInsertInto(func);
    auto b = builder.emitBlock();
    builder.setInsertInto(b);

    auto x = builder.emitParam(inputTypeFG);
    auto i = builder.emitParam(intType);
    auto x1 = builder.emitGetTupleElement(inputTypeF, x, 0);
    auto x2 = builder.emitGetTupleElement(inputTypeG, x, 1);
    auto b1 = builder.emitCallInst(inputTypeF, f, {x1, i});
    auto b2 = builder.emitCallInst(inputTypeG, g, {x2, i});
    auto r = builder.emitMakeTuple(b1, b2);
    builder.emitReturn(r);

    return func;
}

// All the information on a call to saturated_cooperation_using
struct SatCoopCall
{
    // The definition in hlsl.slang
    IRGeneric* generic = nullptr;
    // The specialization of that call
    IRSpecialize* specialize = nullptr;

    // Called 'A' in the definition
    IRType* sharedInputType = nullptr;
    // Called 'B' in the definition
    IRType* distinctInputType = nullptr;
    // Called 'C' in the definition
    IRType* retType = nullptr;

    // The function arguments to the call
    IRFunc* cooperate = nullptr;
    IRFunc* fallback = nullptr;

    // The inter-lane communication functions
    // TODO: call specializeGeneric on these and extract the IRFunc
    IRInst* waveMatch = nullptr;
    IRInst* broadcast = nullptr;

    // The values to pass to these functions
    IRInst* sharedInput = nullptr;
    IRInst* distinctInput = nullptr;
};

// Decompose a call to a specialization of saturated_cooperation_using into
// its generic, type arguments and the six call arguments. The expected shape
// is only checked with assertions.
static SatCoopCall getSatCoopCall(IRCall* f)
{
    SatCoopCall ret = {};

    ret.specialize = as<IRSpecialize>(f->getCallee());
    // Since this is a call to saturated_cooperation, it must have at least
    // three specialization arguments for the type parameters A, B, C. We allow
    // more here for any dictionaries or witnesses.
    SLANG_ASSERT(ret.specialize && ret.specialize->getArgCount() >= 3);
    ret.generic = as<IRGeneric>(ret.specialize->getBase());
    SLANG_ASSERT(ret.generic);

    ret.sharedInputType = as<IRType>(ret.specialize->getArg(0));
    ret.distinctInputType = as<IRType>(ret.specialize->getArg(1));
    ret.retType = as<IRType>(ret.specialize->getArg(2));
    SLANG_ASSERT(ret.sharedInputType);
    SLANG_ASSERT(ret.distinctInputType);
    SLANG_ASSERT(ret.retType);

    // (cooperate, fallback, waveMatch, broadcast, sharedInput, distinctInput)
    SLANG_ASSERT(f->getArgCount() == 6);
    ret.cooperate = as<IRFunc>(f->getArg(0));
    ret.fallback = as<IRFunc>(f->getArg(1));
    SLANG_ASSERT(ret.cooperate);
    SLANG_ASSERT(ret.fallback);

    ret.waveMatch = f->getArg(2);
    ret.broadcast = f->getArg(3);
    SLANG_ASSERT(ret.waveMatch);
    SLANG_ASSERT(ret.broadcast);

    ret.sharedInput = f->getArg(4);
    ret.distinctInput = f->getArg(5);
    SLANG_ASSERT(ret.sharedInput->getDataType() == ret.sharedInputType);
    SLANG_ASSERT(ret.distinctInput->getDataType() == ret.distinctInputType);

    return ret;
}

// transform:
//   a = sat_coop(c1, f1, s1, u1); // f
//   p;
//   q;
//   b = sat_coop(c2, f2, s2, u2); // g
// to:
//   p;
//   (a,b) = sat_coop(c1 &&& c2, f1 &&& f2, (s1, s2), (u1, u2));
//   q;
//
// On success both original calls are removed (their uses are redirected to
// tuple projections of the fused result) and the fused call is returned.
// On failure nullptr is returned and the IR is left unmodified.
//
// This can fail if:
//
//   any instruction between the two calls is not movable
//   p depends on a
//   the second call to sat_coop depends on a
//   the second call to sat_coop depends on q
//   the specializations carry arguments beyond the type parameters A, B, C
static IRCall* tryFuseCalls(IRBuilder& builder, IRCall* f, IRCall* g)
{
    // TODO: Make sure that the types in here are concrete, use
    // `isGenericParam`
    IRBuilderInsertLocScope insertLocScope(&builder);

    SatCoopCall callF = getSatCoopCall(f);
    SatCoopCall callG = getSatCoopCall(g);

    // Every check which can reject the fusion happens before
    // `floatTogether` below, so that a rejection leaves the IR untouched.

    // If these aren't referencing the same generic, then something has gone
    // wrong in our assumptions.
    SLANG_ASSERT(callF.generic == callG.generic);
    if (callF.generic != callG.generic)
        return nullptr;

    // The new specialization only forwards the type arguments A, B and C.
    // `getSatCoopCall` tolerates extra specialization arguments (such as
    // witnesses); we don't know how to combine those, so rather than silently
    // dropping them we decline to fuse.
    if (callF.specialize->getArgCount() != 3 || callG.specialize->getArgCount() != 3)
        return nullptr;

    // If g uses the result of f, we can't fuse them with this logic (we could
    // however with a replacement for 'fanout')
    if (uses(f, g))
        return nullptr;

    // If there is no safe way to float these together, then fail.
    // On success this returns g, now positioned directly after f.
    const auto insertPoint = floatTogether(f, g);
    if (!insertPoint)
        return nullptr;
    builder.setInsertBefore(insertPoint);

    // As a slight neatening, we'll avoid wrapping and unwrapping a tuple (u,u)
    // if both f and g use the same distinct input.
    const bool usesSameDistinctInput = callF.distinctInput == callG.distinctInput;
    SLANG_ASSERT(!usesSameDistinctInput || callF.distinctInputType == callG.distinctInputType);

    // Similarly for the shared input: if these use the same shared input then
    // the fusing is simpler (no need to make a product of s1 and s2)
    // TODO: if there is an injection from s1 to s2, then we can avoid the WaveMatch on s2
    const bool usesSameSharedInput = callF.sharedInput == callG.sharedInput &&
                                     callF.waveMatch == callG.waveMatch &&
                                     callF.broadcast == callG.broadcast;
    SLANG_ASSERT(!usesSameSharedInput || callF.sharedInputType == callG.sharedInputType);

    // Generate a new specialization of our saturated_cooperation_using function,
    // reflecting the new input and output types.
    const auto newRetType = builder.getTupleType(callF.retType, callG.retType);
    const auto sharedInputType =
        usesSameSharedInput
            ? callF.sharedInputType
            : builder.getTupleType(callF.sharedInputType, callG.sharedInputType);
    const auto distinctInputType =
        usesSameDistinctInput
            ? callF.distinctInputType
            : builder.getTupleType(callF.distinctInputType, callG.distinctInputType);

    // Specialize our new call
    const auto newSpec = builder.emitSpecializeInst(
        builder.getTypeKind(),
        callF.generic,
        {sharedInputType, distinctInputType, newRetType});

    // Make our new functions, and joined inputs
    const auto newCooperate = makeBiFanout(
        builder,
        callF.cooperate,
        callG.cooperate,
        usesSameSharedInput,
        usesSameDistinctInput);
    const auto newFallback = makeBiFanout(
        builder,
        callF.fallback,
        callG.fallback,
        usesSameSharedInput,
        usesSameDistinctInput);
    const auto newWaveMatch = usesSameSharedInput ? callF.waveMatch
                                                  : makeWaveMatchBoth(
                                                        builder,
                                                        callF.sharedInputType,
                                                        callG.sharedInputType,
                                                        callF.waveMatch,
                                                        callG.waveMatch);
    const auto newBroadcast = usesSameSharedInput ? callF.broadcast
                                                  : makeBroadcastBoth(
                                                        builder,
                                                        callF.sharedInputType,
                                                        callG.sharedInputType,
                                                        callF.broadcast,
                                                        callG.broadcast);
    const auto newSharedInput = usesSameSharedInput
                                    ? callF.sharedInput
                                    : builder.emitMakeTuple(callF.sharedInput, callG.sharedInput);
    const auto newDistinctInput =
        usesSameDistinctInput ? callF.distinctInput
                              : builder.emitMakeTuple(callF.distinctInput, callG.distinctInput);

    // Call it and extract the results from f and g
    const auto res = builder.emitCallInst(
        newRetType,
        newSpec,
        {newCooperate, newFallback, newWaveMatch, newBroadcast, newSharedInput, newDistinctInput});
    const auto resF = builder.emitGetTupleElement(callF.retType, res, 0);
    const auto resG = builder.emitGetTupleElement(callG.retType, res, 1);

    f->replaceUsesWith(resF);
    g->replaceUsesWith(resG);
    f->removeAndDeallocate();
    g->removeAndDeallocate();

    return res;
}

//
// Identify calls which we can fuse
//

// Returns `i` as an IRCall if it is a call to a specialization of a generic
// whose inner return value carries a KnownBuiltin decoration named
// `expectedNameEnum`; otherwise returns nullptr.
IRCall* isKnownFunction(KnownBuiltinDeclName expectedNameEnum, IRInst* i)
{
    auto call = as<IRCall>(i);
    if (!call)
        return nullptr;

    // saturated_cooperation is a generic function, so look for specializations thereof
    auto spec = as<IRSpecialize>(call->getCallee());
    if (!spec)
        return nullptr;

    auto generic = findSpecializedGeneric(spec);
    if (!generic)
        return nullptr;

    auto inner = findGenericReturnVal(generic);
    if (!inner)
        return nullptr;

    auto h = inner->findDecoration<IRKnownBuiltinDecoration>();
    if (!h)
        return nullptr;

    if (h->getName() != expectedNameEnum)
        return nullptr;

    return call;
}

//
// We perform a left fold over calls to saturated_cooperation
//
// sc(ca, fa)
// sc(cb, fb)
// sc(cc, fc)
//
// to
//
// sc(cacbcc, fafbfc)
//
// where cacbcc (and fafbfc) look like
//
// cacbcc(){
//   cacb();
//   cc();
// }
//
// cacb(){
//   ca();
//   cb();
// }
//
// These helper functions are inlined shortly after and the generated code is
// exactly what you'd expect: it's the body of sat_coop except that the
// original call to cooperate is replaced by three calls to ca, cb, cc.
//
// We use a fold here rather than accumulating everything at once as it's
// easier to implement fusing for 2 functions than n
static void fuseCallsInBlock(IRBuilder& builder, IRBlock* block)
{
    // first, inline calls to saturated_cooperation to expose
    // saturated_cooperation_using which is simpler to fuse.
    // It is simpler to fuse because it makes explicit the inter-lane
    // communication functions, which we can use as building blocks in our
    // composition.
    //
    // Collect first and inline afterwards so we don't mutate the block while
    // iterating over it.
    List<IRCall*> toInline;
    for (auto inst : block->getChildren())
    {
        if (auto satCoop = isKnownFunction(KnownBuiltinDeclName::saturated_cooperation, inst))
            toInline.add(satCoop);
    }
    for (auto c : toInline)
        inlineCall(c);

    // Walk over the instructions in this block
    // If we see a call to sat_coop then remember where it is and keep
    // walking, if we reach another call without first encountering any
    // instructions with which our first call can't be safely reordered
    // then we remove the first call and replace the second with a fused
    // call.
    IRCall* lastCall = nullptr;
    for (auto inst = block->getFirstInst(); inst != block->getTerminator();
         inst = inst->getNextInst())
    {
        auto call = isKnownFunction(KnownBuiltinDeclName::saturated_cooperation_using, inst);
        if (!call)
            continue;

        if (!lastCall)
        {
            lastCall = call;
            continue;
        }

        if (auto fused = tryFuseCalls(builder, lastCall, call))
        {
            // `call` has been deallocated; resume the walk from the fused call.
            inst = fused;
            lastCall = fused;
        }
        else
        {
            lastCall = call;
        }
    }
}

void fuseCallsToSaturatedCooperation(IRModule* module)
{
    IRBuilder builder(module);
    overAllBlocks(module, [&](auto b) { fuseCallsInBlock(builder, b); });
}

} // namespace Slang