diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-22 21:16:35 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-22 21:16:35 -0700 |
| commit | 259a015feb9d4ab65e8fbba32f6c777e92780cc7 (patch) | |
| tree | 45bd4cb9217325c67f5a27d8562b0e7e6b79bb77 | |
| parent | d4f99c8bac8b28f18c864a717d8833db6a1c872d (diff) | |
Type legalization and autodiff bug fixes. (#2722)
* Bug fixes.
* Fix.
* Only perform autodiff for functions whose derivative is actually used.
* Fix loop optimize bug.
* Fix high order diff.
* Fix trivial diff func generation.
* Fixes.
* Cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
24 files changed, 238 insertions, 71 deletions
diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h index e923832e5..fffec9640 100644 --- a/source/core/slang-dictionary.h +++ b/source/core/slang-dictionary.h @@ -648,6 +648,14 @@ namespace Slang { return dict.AddIfNotExists(_Move(obj), _DummyClass()); } + bool add(const T& obj) + { + return dict.AddIfNotExists(obj, _DummyClass()); + } + bool add(T&& obj) + { + return dict.AddIfNotExists(_Move(obj), _DummyClass()); + } void Remove(const T & obj) { dict.Remove(obj); diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a9b8209f3..26a673512 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1022,31 +1022,35 @@ __generic<T : __BuiltinFloatingPointType, let N : int> [__readNone] T __determinant_impl(matrix<T,N,N> m) { - if (N == 1) - return m[0][0]; - else if (N == 2) - return m[0][0] * m[1][1] - m[0][1] * m[1][0]; - else if (N == 3) + T result = T(0); + switch (N) { - return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) + case 1: + result = m[0][0]; + break; + case 2: + result = m[0][0] * m[1][1] - m[0][1] * m[1][0]; + break; + case 3: + result = m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2]) - m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2]) + m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]); - } - else if (N == 4) - { - T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; + break; + case 4: + T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3]; T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3]; T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2]; T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3]; T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2]; T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1]; - return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02) + result = m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02) - m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04) + m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05) - m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05); + break; } - return T(0.0); + return result; } __generic<T : __BuiltinFloatingPointType, let N : int> [ForwardDerivativeOf(determinant)] diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index bc62e488f..1b4eed8fd 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -372,6 +372,8 @@ Result linkAndOptimizeIR( changed |= specializeModule(irModule); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE"); + eliminateDeadCode(irModule); + validateIRModuleIfEnabled(codeGenContext, irModule); // Inline calls to any functions marked with [__unsafeInlineEarly] again, diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index f3c739894..fe3c70bde 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -199,7 +199,7 @@ struct CFGNormalizationPass bool currBreakRegion = false; bool currBaseRegion = true; - // Detect the trivial case. The current block is alredy + // Detect the trivial case. The current block is already // in the next region => this region is empty. // if (afterBlocks.contains(currentBlock)) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index bc7e03ad3..3f31f1463 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -65,7 +65,7 @@ void ForwardDiffTranscriber::generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFu auto primal = builder.emitDefaultConstruct(pairType->getValueType()); builder.markInstAsPrimal(primal); auto diff = getDifferentialZeroOfType(&builder, pairType->getValueType()); - builder.markInstAsDifferential(primal); + builder.markInstAsDifferential(diff, primal->getDataType()); auto val = builder.emitMakeDifferentialPair(pairType, primal, diff); builder.markInstAsMixedDifferential(val); @@ -178,7 +178,8 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); auto resultType = primalArith->getDataType(); - auto diffType = (IRType*) differentiableTypeConformanceContext.getDifferentialForType(builder, resultType); + auto origResultType = origArith->getDataType(); + auto diffType = (IRType*)differentiateType(builder, origResultType); switch(origArith->getOp()) { @@ -263,15 +264,13 @@ InstPair ForwardDiffTranscriber::transcribeSelect(IRBuilder* builder, IRInst* or auto primalSelect = maybeCloneForPrimalInst(builder, origSelect); - auto resultType = primalCondition->getDataType(); - // If both sides have no differential, skip if (diffLeft || diffRight) { diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType()); diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); - auto diffType = (IRType*) differentiableTypeConformanceContext.getDifferentialForType(builder, resultType); + auto diffType = differentiateType(builder, origSelect->getDataType()); return InstPair( primalSelect, @@ -1831,7 +1830,9 @@ String ForwardDiffTranscriber::makeDiffPairName(IRInst* origVar) InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) { - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalType)) + SLANG_UNUSED(primalType); + + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origParam->getFullType())) { IRInst* diffPairParam = builder->emitParam(diffPairType); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index e01d65f4f..f23e45be0 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -391,6 +391,8 @@ namespace Slang auto origFuncType = as<IRFuncType>(origFunc->getFullType()); List<IRInst*> primalArgs, propagateArgs; List<IRType*> primalTypes, propagateTypes; + IRType* primalResultType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getResultType()); + for (UInt i = 0; i < origFuncType->getParamCount(); i++) { auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i)); @@ -465,11 +467,11 @@ namespace Slang auto primalFuncType = builder.getFuncType( primalTypes, - origFuncType->getResultType()); + primalResultType); primalArgs.add(intermediateVar); primalTypes.add(builder.getOutType(intermediateType)); auto primalFunc = builder.emitBackwardDifferentiatePrimalInst(primalFuncType, specializedOriginalFunc); - builder.emitCallInst(origFuncType->getResultType(), primalFunc, primalArgs); + builder.emitCallInst(primalResultType, primalFunc, primalArgs); propagateTypes.add(intermediateType); propagateArgs.add(builder.emitLoad(intermediateVar)); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 6784391a0..2bc67e561 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -474,7 +474,9 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy } default: - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); + return (IRType*)maybeCloneForPrimalInst( + builder, + differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)origType)); } } @@ -674,7 +676,7 @@ InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* o { auto primal = cloneInst(&cloneEnv, builder, origParam); IRInst* diff = nullptr; - if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) + if (IRType* diffType = differentiateType(builder, (IRType*)origParam->getDataType())) { diff = builder->emitParam(diffType); } @@ -749,11 +751,11 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui // result, it's useful to have a method to generate zero literals of any (arithmetic) type. // The current implementation requires that types are defined linearly. // -IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) +IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* originalType) { - primalType = (IRType*)unwrapAttributedType(primalType); - - if (auto diffType = differentiateType(builder, primalType)) + originalType = (IRType*)unwrapAttributedType(originalType); + auto primalType = (IRType*)lookupPrimalInst(builder, originalType); + if (auto diffType = differentiateType(builder, originalType)) { switch (diffType->getOp()) { @@ -777,7 +779,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I } } - if (auto arrayType = as<IRArrayType>(primalType)) + if (auto arrayType = as<IRArrayType>(originalType)) { auto diffElementType = (IRType*)differentiableTypeConformanceContext.getDifferentialForType( @@ -803,7 +805,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I } else { - zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); + zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, originalType); } SLANG_RELEASE_ASSERT(zeroMethod); @@ -1108,7 +1110,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst if (!pair.differential->findDecoration<IRAutodiffInstDecoration>() && !as<IRConstant>(pair.differential)) { - auto primalType = as<IRType>(pair.primal->getDataType()); + auto primalType = (IRType*)(pair.primal->getDataType()); builder->markInstAsDifferential(pair.differential, primalType); } } @@ -1117,7 +1119,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst if (!pair.primal->findDecoration<IRAutodiffInstDecoration>() && !as<IRConstant>(pair.differential)) { - auto mixedType = as<IRType>(pair.primal->getDataType()); + auto mixedType = (IRType*)(pair.primal->getDataType()); builder->markInstAsMixedDifferential(pair.primal, mixedType); } } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index e59f27881..653589933 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1708,7 +1708,9 @@ struct DiffTransposePass // Look for gradient entries for this inst. List<RevGradient> gradients; if (hasRevGradients(inst)) + { gradients = popRevGradients(inst); + } IRType* primalType = tryGetPrimalTypeFromDiffInst(inst); @@ -2746,7 +2748,8 @@ struct DiffTransposePass } default: - SLANG_ASSERT_FAILURE("Unhandled target type for promotion"); + // Default is not to promote. + return inst; } } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 224cca9e0..f173aaa8b 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1143,9 +1143,8 @@ struct AutoDiffPass : public InstPassBase { bool changed = false; List<IRInst*> autoDiffWorkList; - // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the module. - autoDiffWorkList.clear(); - processAllInsts([&](IRInst* inst) + // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call graph. + processAllReachableInsts([&](IRInst* inst) { switch (inst->getOp()) { @@ -1164,11 +1163,15 @@ struct AutoDiffPass : public InstPassBase { // Skip functions whose body still has a differentiate inst (higher order func). if (!isFullyDifferentiated(innerFunc)) + { + addToWorkList(inst->getOperand(0)); return; + } } autoDiffWorkList.add(inst); break; default: + autoDiffWorkList.add(inst->getOperand(0)); break; } break; @@ -1176,6 +1179,11 @@ struct AutoDiffPass : public InstPassBase // Explicit primal subst operator is not yet supported. SLANG_UNIMPLEMENTED_X("explicit primal_subst operator."); default: + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + addToWorkList(operand); + } break; } }); @@ -1199,7 +1207,7 @@ struct AutoDiffPass : public InstPassBase } break; case kIROp_BackwardDifferentiatePrimal: - { + { auto baseFunc = differentiateInst->getOperand(0); diffFunc = backwardPrimalTranscriber.transcribe(&subBuilder, baseFunc); } @@ -1300,7 +1308,6 @@ struct AutoDiffPass : public InstPassBase hasChanges |= changed; } - return hasChanges; } diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index 050d1e392..6ac9442ee 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -142,12 +142,9 @@ static void _cloneInstDecorationsAndChildren( // If `newInst` already has non-decoration children, we want to // insert the new children between the existing decoration and non-decoration children // so that we maintain the invariant that all decorations are defined before non-decorations. - if (auto lastDecor = newInst->getLastDecoration()) + if (auto firstChild = newInst->getFirstChild()) { - if (auto nextInstBeforeLastDecor = lastDecor->getNextInst()) - { - builder->setInsertBefore(nextInstBeforeLastDecor); - } + builder->setInsertBefore(firstChild); } // When applying the first phase of cloning to diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h index 3b2331963..2db8a725f 100644 --- a/source/slang/slang-ir-inst-pass-base.h +++ b/source/slang/slang-ir-inst-pass-base.h @@ -3,6 +3,7 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-dce.h" namespace Slang { @@ -23,14 +24,15 @@ namespace Slang workListSet.Add(inst); } - IRInst* pop() + IRInst* pop(bool removeFromSet = true) { if (workList.getCount() == 0) return nullptr; IRInst* inst = workList.getLast(); workList.removeLast(); - workListSet.Remove(inst); + if (removeFromSet) + workListSet.Remove(inst); return inst; } @@ -113,6 +115,35 @@ namespace Slang processChildInsts(module->getModuleInst(), f); } + template <typename Func> + void processAllReachableInsts(const Func& f) + { + workList.clear(); + workListSet.Clear(); + + addToWorkList(module->getModuleInst()); + while (workList.getCount() != 0) + { + IRInst* inst = pop(false); + f(inst); + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + if (as<IRDecoration>(child)) + break; + switch (child->getOp()) + { + case kIROp_GenericSpecializationDictionary: + case kIROp_ExistentialFuncSpecializationDictionary: + case kIROp_ExistentialTypeSpecializationDictionary: + continue; + default: + break; + } + if (shouldInstBeLiveIfParentIsLive(child, IRDeadCodeEliminationOptions())) + addToWorkList(child); + } + } + } }; } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index cf66f1f6b..88a5f6075 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -728,8 +728,8 @@ struct IRDifferentialInstDecoration : IRAutodiffInstDecoration IRUse primalType; IR_LEAF_ISA(DifferentialInstDecoration) - IRType* getPrimalType() { return as<IRType>(getOperand(0)); } - IRInst* getPrimalInst() { return as<IRInst>(getOperand(1)); } + IRType* getPrimalType() { return (IRType*)(getOperand(0)); } + IRInst* getPrimalInst() { return getOperand(1); } }; struct IRPrimalInstDecoration : IRAutodiffInstDecoration diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index d7ed1f63f..ef7b74906 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -2169,7 +2169,8 @@ static LegalVal legalizeInst( context->replacedInstructions.add(inst); return LegalVal::simple(newInst); } - inst->setFullType(legalType.getSimple()); + if (inst->getFullType() != legalType.getSimple()) + inst->setFullType(legalType.getSimple()); return LegalVal::simple(inst); } @@ -3473,8 +3474,15 @@ struct IRTypeLegalizationPass // instructions have ever been added to the work list. List<IRInst*> workList; + HashSet<IRInst*> hasBeenAddedOrProcessedSet; HashSet<IRInst*> addedToWorkListSet; + bool hasBeenAddedToWorkListOrProcessed(IRInst* inst) + { + if (hasBeenAddedToWorkList(inst)) return true; + return hasBeenAddedOrProcessedSet.Contains(inst); + } + // We will add a simple query to check whether an instruciton // has been put on the work list before (or if it should be // treated *as if* it has been placed on the work list). @@ -3523,9 +3531,9 @@ struct IRTypeLegalizationPass // if(addedToWorkListSet.Contains(inst)) return; - workList.add(inst); addedToWorkListSet.Add(inst); + hasBeenAddedOrProcessedSet.Add(inst); } void processModule(IRModule* module) @@ -3549,6 +3557,7 @@ struct IRTypeLegalizationPass // List<IRInst*> workListCopy; Swap(workListCopy, workList); + addedToWorkListSet.Clear(); // Now we simply process each instruction on the copy of // the work list, knowing that `processInst` may add additional @@ -3567,6 +3576,20 @@ struct IRTypeLegalizationPass // for (auto& lv : context->replacedInstructions) { +#if _DEBUG + for (auto use = lv->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (user->getModule() == nullptr) + continue; + if (as<IRType>(user)) + continue; + if (!context->replacedInstructions.Contains(user)) + SLANG_UNEXPECTED("replaced inst still has use."); + if (lv->getParent()) + SLANG_UNEXPECTED("replaced inst still in a parent."); + } +#endif lv->removeAndDeallocate(); } } @@ -3635,19 +3658,19 @@ struct IRTypeLegalizationPass // Next, we don't want to add something if its parent // hasn't been added already. // - if(!hasBeenAddedToWorkList(inst->getParent())) + if(!hasBeenAddedToWorkListOrProcessed(inst->getParent())) return; // Finally, we don't want to add something if its // type and/or operands haven't all been added. // - if(!hasBeenAddedToWorkList(inst->getFullType())) + if(!hasBeenAddedToWorkListOrProcessed(inst->getFullType())) return; Index operandCount = (Index) inst->getOperandCount(); for( Index i = 0; i < operandCount; ++i ) { auto operand = inst->getOperand(i); - if(!hasBeenAddedToWorkList(operand)) + if(!hasBeenAddedToWorkListOrProcessed(operand)) return; } diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index 4f9b8d272..121665c85 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -67,8 +67,14 @@ List<IRBlock*> _collectBlocksInLoop(IRDominatorTree* dom, IRLoop* loopInst) { if (succ == breakBlock) continue; - if (dom->dominates(firstBlock, succ) && !dom->dominates(breakBlock, succ)) - addBlock(succ); + if (!dom->dominates(firstBlock, succ)) + continue; + if (!as<IRUnreachable>(breakBlock->getTerminator())) + { + if (dom->dominates(breakBlock, succ)) + continue; + } + addBlock(succ); } } return loopBlocks; diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index b814442fa..e98d14a0c 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -105,7 +105,7 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) loopBlocks.Add(b); auto addressHasOutOfLoopUses = [&](IRInst* addr) { - // The entire access chain of `addr` must have no uses out side the loop. + // The entire access chain of `addr` must have no uses outside the loop. // The root variable must be a local var. for (auto chainNode = addr; chainNode;) { @@ -123,6 +123,11 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) chainNode = chainNode->getOperand(0); continue; case kIROp_Var: + if (auto rate = chainNode->getFullType()->getRate()) + { + if (!as<IRConstExprRate>(rate)) + return true; + } break; default: return true; @@ -143,10 +148,6 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) return true; } - // The inst can't possibly have side effect? Skip it. - if (!inst->mightHaveSideEffects()) - continue; - // This inst might have side effect, try to prove that the // side effect does not leak beyond the scope of the loop. if (auto call = as<IRCall>(inst)) @@ -187,6 +188,10 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) } else { + // The inst can't possibly have side effect? Skip it. + if (!inst->mightHaveSideEffects()) + continue; + // For all other insts, we assume it has a global side effect. return true; } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 050c9bfc7..05c28d131 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -62,7 +62,7 @@ struct SpecializationContext // if it is in our set. // bool isInstFullySpecialized( - IRInst* inst) + IRInst* inst) { // A small wrinkle is that a null instruction pointer // sometimes appears a a type, and so should be treated @@ -70,7 +70,7 @@ struct SpecializationContext // // TODO: It would be nice to remove this wrinkle. // - if(!inst) return true; + if (!inst) return true; // An interface requirement entry should always be considered // to be fully specialized, even if it hasn't been visited. @@ -79,7 +79,7 @@ struct SpecializationContext // can't mark an interface as used until its requirements are // used, etc. // - if(inst->getOp() == kIROp_InterfaceRequirementEntry) + if (inst->getOp() == kIROp_InterfaceRequirementEntry) return true; // A generic parameter is never specialized. @@ -93,6 +93,20 @@ struct SpecializationContext inst->getParent()->getParent()->getOp() == kIROp_Generic) return false; } + + // A global value is always specialized. + if (inst->getParent() == module->getModuleInst()) + { + switch (inst->getOp()) + { + case kIROp_LookupWitness: + case kIROp_Specialize: + return false; + default: + return true; + } + } + return fullySpecializedInsts.Contains(inst); } @@ -504,12 +518,9 @@ struct SpecializationContext // be considered as fully specialized as soon // as all of its operands are. // - // TODO: We realistically need a more refined - // check here that uses an allow-list of instructions - // that can represent values suitable for use - // as generic arguments. - // - if(areAllOperandsFullySpecialized(inst)) + // Anything defined in global scope can be viewed as fully specialized. + if (inst->getParent() == module->getModuleInst() || + areAllOperandsFullySpecialized(inst)) { markInstAsFullySpecialized(inst); } diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index beaaae065..a0acf5082 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -30,13 +30,14 @@ namespace Slang changed |= peepholeOptimize(module); changed |= removeRedundancy(module); changed |= simplifyCFG(module); - changed |= propagateFuncProperties(module); // Note: we disregard the `changed` state from dead code elimination pass since // SCCP pass could be generating temporarily evaluated constant values and never actually use them. // DCE will always remove those nearly generated consts and always returns true here. eliminateDeadCode(module); + changed |= propagateFuncProperties(module); + changed |= constructSSA(module); changed |= removeUnusedGenericParam(module); iterationCounter++; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index b2f49c24b..4f1c15459 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -49,9 +49,13 @@ struct DeduplicateContext return *newValue; for (UInt i = 0; i < value->getOperandCount(); i++) { - value->unsafeSetOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate)); + auto deduplicatedOperand = deduplicate(value->getOperand(i), shouldDeduplicate); + if (deduplicatedOperand != value->getOperand(i)) + value->unsafeSetOperand(i, deduplicatedOperand); } - value->setFullType((IRType*)deduplicate(value->getFullType(), shouldDeduplicate)); + auto deduplicatedType = (IRType*)deduplicate(value->getFullType(), shouldDeduplicate); + if (deduplicatedType != value->getFullType()) + value->setFullType(deduplicatedType); if (auto newValue = deduplicateMap.TryGetValue(key)) return *newValue; deduplicateMap[key] = value; diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp index 1b72b5c00..8d8056e30 100644 --- a/source/slang/slang-legalize-types.cpp +++ b/source/slang/slang-legalize-types.cpp @@ -464,6 +464,7 @@ struct TupleTypeBuilder // collide. // // (Also, the original type wasn't legal - that was the whole point...) + originalStructType->removeFromParent(); context->replacedInstructions.add(originalStructType); for(auto ee : ordinaryElements) diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h index 0b4acf0fe..693a154f6 100644 --- a/source/slang/slang-legalize-types.h +++ b/source/slang/slang-legalize-types.h @@ -620,7 +620,7 @@ struct IRTypeLegalizationContext // store instructions that have been replaced here, so we can free them // when legalization has done - List<IRInst*> replacedInstructions; + OrderedHashSet<IRInst*> replacedInstructions; Dictionary<IRType*, LegalType> mapTypeToLegalType; diff --git a/tests/autodiff-dstdlib/determinant.slang b/tests/autodiff-dstdlib/determinant.slang new file mode 100644 index 000000000..d2e699551 --- /dev/null +++ b/tests/autodiff-dstdlib/determinant.slang @@ -0,0 +1,24 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +float diffDeterminant(float2x2 x) +{ + return determinant(x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + var dpx = diffPair(float2x2(1.0, 2.0, 3.0, 4.0)); + __bwd_diff(diffDeterminant)(dpx, 1.0); + outputBuffer[0] = dpx.d[0][0]; + outputBuffer[1] = dpx.d[0][1]; + outputBuffer[2] = dpx.d[1][0]; + outputBuffer[3] = dpx.d[1][1]; + } +} diff --git a/tests/autodiff-dstdlib/determinant.slang.expected.txt b/tests/autodiff-dstdlib/determinant.slang.expected.txt new file mode 100644 index 000000000..b2c2f7e17 --- /dev/null +++ b/tests/autodiff-dstdlib/determinant.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +4.000000 +-3.000000 +-2.000000 +1.000000 diff --git a/tests/bugs/loop-optimize.slang b/tests/bugs/loop-optimize.slang new file mode 100644 index 000000000..85231dbea --- /dev/null +++ b/tests/bugs/loop-optimize.slang @@ -0,0 +1,28 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name outputBuffer +RWStructuredBuffer<float> outputBuffer; + +void test(inout float v, int i) +{ + while (true) + { + i--; + if (i < 2) + { + v += 1.0; + return; + } + } +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + float v = 2.0; + test(v, 3); + outputBuffer[0] = v; // Expect 3.0 +}
\ No newline at end of file diff --git a/tests/bugs/loop-optimize.slang.expected.txt b/tests/bugs/loop-optimize.slang.expected.txt new file mode 100644 index 000000000..f38cc1080 --- /dev/null +++ b/tests/bugs/loop-optimize.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +3.0 |
