summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Yong He <yonghe@outlook.com> 2023-03-22 21:16:35 -0700
committer GitHub <noreply@github.com> 2023-03-22 21:16:35 -0700
commit 259a015feb9d4ab65e8fbba32f6c777e92780cc7 (patch)
tree 45bd4cb9217325c67f5a27d8562b0e7e6b79bb77
parent d4f99c8bac8b28f18c864a717d8833db6a1c872d (diff)
Type legalization and autodiff bug fixes. (#2722)
* Bug fixes.
* Fix.
* Only perform autodiff for functions whose derivative is actually used.
* Fix loop optimize bug.
* Fix high order diff.
* Fix trivial diff func generation.
* Fixes.
* Cleanup.

---------

Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r-- source/core/slang-dictionary.h 8
-rw-r--r-- source/slang/diff.meta.slang 28
-rw-r--r-- source/slang/slang-emit.cpp 2
-rw-r--r-- source/slang/slang-ir-autodiff-cfg-norm.cpp 2
-rw-r--r-- source/slang/slang-ir-autodiff-fwd.cpp 13
-rw-r--r-- source/slang/slang-ir-autodiff-rev.cpp 6
-rw-r--r-- source/slang/slang-ir-autodiff-transcriber-base.cpp 22
-rw-r--r-- source/slang/slang-ir-autodiff-transpose.h 5
-rw-r--r-- source/slang/slang-ir-autodiff.cpp 17
-rw-r--r-- source/slang/slang-ir-clone.cpp 7
-rw-r--r-- source/slang/slang-ir-inst-pass-base.h 35
-rw-r--r-- source/slang/slang-ir-insts.h 4
-rw-r--r-- source/slang/slang-ir-legalize-types.cpp 33
-rw-r--r-- source/slang/slang-ir-loop-unroll.cpp 10
-rw-r--r-- source/slang/slang-ir-simplify-cfg.cpp 15
-rw-r--r-- source/slang/slang-ir-specialize.cpp 29
-rw-r--r-- source/slang/slang-ir-ssa-simplification.cpp 3
-rw-r--r-- source/slang/slang-ir-util.h 8
-rw-r--r-- source/slang/slang-legalize-types.cpp 1
-rw-r--r-- source/slang/slang-legalize-types.h 2
-rw-r--r-- tests/autodiff-dstdlib/determinant.slang 24
-rw-r--r-- tests/autodiff-dstdlib/determinant.slang.expected.txt 5
-rw-r--r-- tests/bugs/loop-optimize.slang 28
-rw-r--r-- tests/bugs/loop-optimize.slang.expected.txt 2
24 files changed, 238 insertions, 71 deletions
diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h
index e923832e5..fffec9640 100644
--- a/source/core/slang-dictionary.h
+++ b/source/core/slang-dictionary.h
@@ -648,6 +648,14 @@ namespace Slang
{
return dict.AddIfNotExists(_Move(obj), _DummyClass());
}
+ bool add(const T& obj)
+ {
+ return dict.AddIfNotExists(obj, _DummyClass());
+ }
+ bool add(T&& obj)
+ {
+ return dict.AddIfNotExists(_Move(obj), _DummyClass());
+ }
void Remove(const T & obj)
{
dict.Remove(obj);
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index a9b8209f3..26a673512 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -1022,31 +1022,35 @@ __generic<T : __BuiltinFloatingPointType, let N : int>
[__readNone]
T __determinant_impl(matrix<T,N,N> m)
{
- if (N == 1)
- return m[0][0];
- else if (N == 2)
- return m[0][0] * m[1][1] - m[0][1] * m[1][0];
- else if (N == 3)
+ T result = T(0);
+ switch (N)
{
- return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2])
+ case 1:
+ result = m[0][0];
+ break;
+ case 2:
+ result = m[0][0] * m[1][1] - m[0][1] * m[1][0];
+ break;
+ case 3:
+ result = m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2])
- m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2])
+ m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]);
- }
- else if (N == 4)
- {
- T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
+ break;
+ case 4:
+ T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3];
T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2];
T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3];
T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2];
T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1];
- return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02)
+ result = m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02)
- m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04)
+ m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05)
- m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05);
+ break;
}
- return T(0.0);
+ return result;
}
__generic<T : __BuiltinFloatingPointType, let N : int>
[ForwardDerivativeOf(determinant)]
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index bc62e488f..1b4eed8fd 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -372,6 +372,8 @@ Result linkAndOptimizeIR(
changed |= specializeModule(irModule);
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE");
+ eliminateDeadCode(irModule);
+
validateIRModuleIfEnabled(codeGenContext, irModule);
// Inline calls to any functions marked with [__unsafeInlineEarly] again,
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index f3c739894..fe3c70bde 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -199,7 +199,7 @@ struct CFGNormalizationPass
bool currBreakRegion = false;
bool currBaseRegion = true;
- // Detect the trivial case. The current block is alredy
+ // Detect the trivial case. The current block is already
// in the next region => this region is empty.
//
if (afterBlocks.contains(currentBlock))
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index bc7e03ad3..3f31f1463 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -65,7 +65,7 @@ void ForwardDiffTranscriber::generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFu
auto primal = builder.emitDefaultConstruct(pairType->getValueType());
builder.markInstAsPrimal(primal);
auto diff = getDifferentialZeroOfType(&builder, pairType->getValueType());
- builder.markInstAsDifferential(primal);
+ builder.markInstAsDifferential(diff, primal->getDataType());
auto val = builder.emitMakeDifferentialPair(pairType, primal, diff);
builder.markInstAsMixedDifferential(val);
@@ -178,7 +178,8 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns
diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
auto resultType = primalArith->getDataType();
- auto diffType = (IRType*) differentiableTypeConformanceContext.getDifferentialForType(builder, resultType);
+ auto origResultType = origArith->getDataType();
+ auto diffType = (IRType*)differentiateType(builder, origResultType);
switch(origArith->getOp())
{
@@ -263,15 +264,13 @@ InstPair ForwardDiffTranscriber::transcribeSelect(IRBuilder* builder, IRInst* or
auto primalSelect = maybeCloneForPrimalInst(builder, origSelect);
- auto resultType = primalCondition->getDataType();
-
// If both sides have no differential, skip
if (diffLeft || diffRight)
{
diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType());
diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
- auto diffType = (IRType*) differentiableTypeConformanceContext.getDifferentialForType(builder, resultType);
+ auto diffType = differentiateType(builder, origSelect->getDataType());
return InstPair(
primalSelect,
@@ -1831,7 +1830,9 @@ String ForwardDiffTranscriber::makeDiffPairName(IRInst* origVar)
InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType)
{
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalType))
+ SLANG_UNUSED(primalType);
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origParam->getFullType()))
{
IRInst* diffPairParam = builder->emitParam(diffPairType);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index e01d65f4f..f23e45be0 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -391,6 +391,8 @@ namespace Slang
auto origFuncType = as<IRFuncType>(origFunc->getFullType());
List<IRInst*> primalArgs, propagateArgs;
List<IRType*> primalTypes, propagateTypes;
+ IRType* primalResultType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getResultType());
+
for (UInt i = 0; i < origFuncType->getParamCount(); i++)
{
auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i));
@@ -465,11 +467,11 @@ namespace Slang
auto primalFuncType = builder.getFuncType(
primalTypes,
- origFuncType->getResultType());
+ primalResultType);
primalArgs.add(intermediateVar);
primalTypes.add(builder.getOutType(intermediateType));
auto primalFunc = builder.emitBackwardDifferentiatePrimalInst(primalFuncType, specializedOriginalFunc);
- builder.emitCallInst(origFuncType->getResultType(), primalFunc, primalArgs);
+ builder.emitCallInst(primalResultType, primalFunc, primalArgs);
propagateTypes.add(intermediateType);
propagateArgs.add(builder.emitLoad(intermediateVar));
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 6784391a0..2bc67e561 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -474,7 +474,9 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
}
default:
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
+ return (IRType*)maybeCloneForPrimalInst(
+ builder,
+ differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)origType));
}
}
@@ -674,7 +676,7 @@ InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* o
{
auto primal = cloneInst(&cloneEnv, builder, origParam);
IRInst* diff = nullptr;
- if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType))
+ if (IRType* diffType = differentiateType(builder, (IRType*)origParam->getDataType()))
{
diff = builder->emitParam(diffType);
}
@@ -749,11 +751,11 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
// result, it's useful to have a method to generate zero literals of any (arithmetic) type.
// The current implementation requires that types are defined linearly.
//
-IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* originalType)
{
- primalType = (IRType*)unwrapAttributedType(primalType);
-
- if (auto diffType = differentiateType(builder, primalType))
+ originalType = (IRType*)unwrapAttributedType(originalType);
+ auto primalType = (IRType*)lookupPrimalInst(builder, originalType);
+ if (auto diffType = differentiateType(builder, originalType))
{
switch (diffType->getOp())
{
@@ -777,7 +779,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
}
}
- if (auto arrayType = as<IRArrayType>(primalType))
+ if (auto arrayType = as<IRArrayType>(originalType))
{
auto diffElementType =
(IRType*)differentiableTypeConformanceContext.getDifferentialForType(
@@ -803,7 +805,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I
}
else
{
- zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
+ zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, originalType);
}
SLANG_RELEASE_ASSERT(zeroMethod);
@@ -1108,7 +1110,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
if (!pair.differential->findDecoration<IRAutodiffInstDecoration>()
&& !as<IRConstant>(pair.differential))
{
- auto primalType = as<IRType>(pair.primal->getDataType());
+ auto primalType = (IRType*)(pair.primal->getDataType());
builder->markInstAsDifferential(pair.differential, primalType);
}
}
@@ -1117,7 +1119,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
if (!pair.primal->findDecoration<IRAutodiffInstDecoration>()
&& !as<IRConstant>(pair.differential))
{
- auto mixedType = as<IRType>(pair.primal->getDataType());
+ auto mixedType = (IRType*)(pair.primal->getDataType());
builder->markInstAsMixedDifferential(pair.primal, mixedType);
}
}
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index e59f27881..653589933 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1708,7 +1708,9 @@ struct DiffTransposePass
// Look for gradient entries for this inst.
List<RevGradient> gradients;
if (hasRevGradients(inst))
+ {
gradients = popRevGradients(inst);
+ }
IRType* primalType = tryGetPrimalTypeFromDiffInst(inst);
@@ -2746,7 +2748,8 @@ struct DiffTransposePass
}
default:
- SLANG_ASSERT_FAILURE("Unhandled target type for promotion");
+ // Default is not to promote.
+ return inst;
}
}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 224cca9e0..f173aaa8b 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1143,9 +1143,8 @@ struct AutoDiffPass : public InstPassBase
{
bool changed = false;
List<IRInst*> autoDiffWorkList;
- // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the module.
- autoDiffWorkList.clear();
- processAllInsts([&](IRInst* inst)
+ // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call graph.
+ processAllReachableInsts([&](IRInst* inst)
{
switch (inst->getOp())
{
@@ -1164,11 +1163,15 @@ struct AutoDiffPass : public InstPassBase
{
// Skip functions whose body still has a differentiate inst (higher order func).
if (!isFullyDifferentiated(innerFunc))
+ {
+ addToWorkList(inst->getOperand(0));
return;
+ }
}
autoDiffWorkList.add(inst);
break;
default:
+ autoDiffWorkList.add(inst->getOperand(0));
break;
}
break;
@@ -1176,6 +1179,11 @@ struct AutoDiffPass : public InstPassBase
// Explicit primal subst operator is not yet supported.
SLANG_UNIMPLEMENTED_X("explicit primal_subst operator.");
default:
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = inst->getOperand(i);
+ addToWorkList(operand);
+ }
break;
}
});
@@ -1199,7 +1207,7 @@ struct AutoDiffPass : public InstPassBase
}
break;
case kIROp_BackwardDifferentiatePrimal:
- {
+ {
auto baseFunc = differentiateInst->getOperand(0);
diffFunc = backwardPrimalTranscriber.transcribe(&subBuilder, baseFunc);
}
@@ -1300,7 +1308,6 @@ struct AutoDiffPass : public InstPassBase
hasChanges |= changed;
}
-
return hasChanges;
}
diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp
index 050d1e392..6ac9442ee 100644
--- a/source/slang/slang-ir-clone.cpp
+++ b/source/slang/slang-ir-clone.cpp
@@ -142,12 +142,9 @@ static void _cloneInstDecorationsAndChildren(
// If `newInst` already has non-decoration children, we want to
// insert the new children between the existing decoration and non-decoration children
// so that we maintain the invariant that all decorations are defined before non-decorations.
- if (auto lastDecor = newInst->getLastDecoration())
+ if (auto firstChild = newInst->getFirstChild())
{
- if (auto nextInstBeforeLastDecor = lastDecor->getNextInst())
- {
- builder->setInsertBefore(nextInstBeforeLastDecor);
- }
+ builder->setInsertBefore(firstChild);
}
// When applying the first phase of cloning to
diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h
index 3b2331963..2db8a725f 100644
--- a/source/slang/slang-ir-inst-pass-base.h
+++ b/source/slang/slang-ir-inst-pass-base.h
@@ -3,6 +3,7 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
+#include "slang-ir-dce.h"
namespace Slang
{
@@ -23,14 +24,15 @@ namespace Slang
workListSet.Add(inst);
}
- IRInst* pop()
+ IRInst* pop(bool removeFromSet = true)
{
if (workList.getCount() == 0)
return nullptr;
IRInst* inst = workList.getLast();
workList.removeLast();
- workListSet.Remove(inst);
+ if (removeFromSet)
+ workListSet.Remove(inst);
return inst;
}
@@ -113,6 +115,35 @@ namespace Slang
processChildInsts(module->getModuleInst(), f);
}
+ template <typename Func>
+ void processAllReachableInsts(const Func& f)
+ {
+ workList.clear();
+ workListSet.Clear();
+
+ addToWorkList(module->getModuleInst());
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = pop(false);
+ f(inst);
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ if (as<IRDecoration>(child))
+ break;
+ switch (child->getOp())
+ {
+ case kIROp_GenericSpecializationDictionary:
+ case kIROp_ExistentialFuncSpecializationDictionary:
+ case kIROp_ExistentialTypeSpecializationDictionary:
+ continue;
+ default:
+ break;
+ }
+ if (shouldInstBeLiveIfParentIsLive(child, IRDeadCodeEliminationOptions()))
+ addToWorkList(child);
+ }
+ }
+ }
};
}
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index cf66f1f6b..88a5f6075 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -728,8 +728,8 @@ struct IRDifferentialInstDecoration : IRAutodiffInstDecoration
IRUse primalType;
IR_LEAF_ISA(DifferentialInstDecoration)
- IRType* getPrimalType() { return as<IRType>(getOperand(0)); }
- IRInst* getPrimalInst() { return as<IRInst>(getOperand(1)); }
+ IRType* getPrimalType() { return (IRType*)(getOperand(0)); }
+ IRInst* getPrimalInst() { return getOperand(1); }
};
struct IRPrimalInstDecoration : IRAutodiffInstDecoration
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index d7ed1f63f..ef7b74906 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -2169,7 +2169,8 @@ static LegalVal legalizeInst(
context->replacedInstructions.add(inst);
return LegalVal::simple(newInst);
}
- inst->setFullType(legalType.getSimple());
+ if (inst->getFullType() != legalType.getSimple())
+ inst->setFullType(legalType.getSimple());
return LegalVal::simple(inst);
}
@@ -3473,8 +3474,15 @@ struct IRTypeLegalizationPass
// instructions have ever been added to the work list.
List<IRInst*> workList;
+ HashSet<IRInst*> hasBeenAddedOrProcessedSet;
HashSet<IRInst*> addedToWorkListSet;
+ bool hasBeenAddedToWorkListOrProcessed(IRInst* inst)
+ {
+ if (hasBeenAddedToWorkList(inst)) return true;
+ return hasBeenAddedOrProcessedSet.Contains(inst);
+ }
+
// We will add a simple query to check whether an instruciton
// has been put on the work list before (or if it should be
// treated *as if* it has been placed on the work list).
@@ -3523,9 +3531,9 @@ struct IRTypeLegalizationPass
//
if(addedToWorkListSet.Contains(inst))
return;
-
workList.add(inst);
addedToWorkListSet.Add(inst);
+ hasBeenAddedOrProcessedSet.Add(inst);
}
void processModule(IRModule* module)
@@ -3549,6 +3557,7 @@ struct IRTypeLegalizationPass
//
List<IRInst*> workListCopy;
Swap(workListCopy, workList);
+ addedToWorkListSet.Clear();
// Now we simply process each instruction on the copy of
// the work list, knowing that `processInst` may add additional
@@ -3567,6 +3576,20 @@ struct IRTypeLegalizationPass
//
for (auto& lv : context->replacedInstructions)
{
+#if _DEBUG
+ for (auto use = lv->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (user->getModule() == nullptr)
+ continue;
+ if (as<IRType>(user))
+ continue;
+ if (!context->replacedInstructions.Contains(user))
+ SLANG_UNEXPECTED("replaced inst still has use.");
+ if (lv->getParent())
+ SLANG_UNEXPECTED("replaced inst still in a parent.");
+ }
+#endif
lv->removeAndDeallocate();
}
}
@@ -3635,19 +3658,19 @@ struct IRTypeLegalizationPass
// Next, we don't want to add something if its parent
// hasn't been added already.
//
- if(!hasBeenAddedToWorkList(inst->getParent()))
+ if(!hasBeenAddedToWorkListOrProcessed(inst->getParent()))
return;
// Finally, we don't want to add something if its
// type and/or operands haven't all been added.
//
- if(!hasBeenAddedToWorkList(inst->getFullType()))
+ if(!hasBeenAddedToWorkListOrProcessed(inst->getFullType()))
return;
Index operandCount = (Index) inst->getOperandCount();
for( Index i = 0; i < operandCount; ++i )
{
auto operand = inst->getOperand(i);
- if(!hasBeenAddedToWorkList(operand))
+ if(!hasBeenAddedToWorkListOrProcessed(operand))
return;
}
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index 4f9b8d272..121665c85 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -67,8 +67,14 @@ List<IRBlock*> _collectBlocksInLoop(IRDominatorTree* dom, IRLoop* loopInst)
{
if (succ == breakBlock)
continue;
- if (dom->dominates(firstBlock, succ) && !dom->dominates(breakBlock, succ))
- addBlock(succ);
+ if (!dom->dominates(firstBlock, succ))
+ continue;
+ if (!as<IRUnreachable>(breakBlock->getTerminator()))
+ {
+ if (dom->dominates(breakBlock, succ))
+ continue;
+ }
+ addBlock(succ);
}
}
return loopBlocks;
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index b814442fa..e98d14a0c 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -105,7 +105,7 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst)
loopBlocks.Add(b);
auto addressHasOutOfLoopUses = [&](IRInst* addr)
{
- // The entire access chain of `addr` must have no uses out side the loop.
+ // The entire access chain of `addr` must have no uses outside the loop.
// The root variable must be a local var.
for (auto chainNode = addr; chainNode;)
{
@@ -123,6 +123,11 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst)
chainNode = chainNode->getOperand(0);
continue;
case kIROp_Var:
+ if (auto rate = chainNode->getFullType()->getRate())
+ {
+ if (!as<IRConstExprRate>(rate))
+ return true;
+ }
break;
default:
return true;
@@ -143,10 +148,6 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst)
return true;
}
- // The inst can't possibly have side effect? Skip it.
- if (!inst->mightHaveSideEffects())
- continue;
-
// This inst might have side effect, try to prove that the
// side effect does not leak beyond the scope of the loop.
if (auto call = as<IRCall>(inst))
@@ -187,6 +188,10 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst)
}
else
{
+ // The inst can't possibly have side effect? Skip it.
+ if (!inst->mightHaveSideEffects())
+ continue;
+
// For all other insts, we assume it has a global side effect.
return true;
}
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 050c9bfc7..05c28d131 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -62,7 +62,7 @@ struct SpecializationContext
// if it is in our set.
//
bool isInstFullySpecialized(
- IRInst* inst)
+ IRInst* inst)
{
// A small wrinkle is that a null instruction pointer
// sometimes appears a a type, and so should be treated
@@ -70,7 +70,7 @@ struct SpecializationContext
//
// TODO: It would be nice to remove this wrinkle.
//
- if(!inst) return true;
+ if (!inst) return true;
// An interface requirement entry should always be considered
// to be fully specialized, even if it hasn't been visited.
@@ -79,7 +79,7 @@ struct SpecializationContext
// can't mark an interface as used until its requirements are
// used, etc.
//
- if(inst->getOp() == kIROp_InterfaceRequirementEntry)
+ if (inst->getOp() == kIROp_InterfaceRequirementEntry)
return true;
// A generic parameter is never specialized.
@@ -93,6 +93,20 @@ struct SpecializationContext
inst->getParent()->getParent()->getOp() == kIROp_Generic)
return false;
}
+
+ // A global value is always specialized.
+ if (inst->getParent() == module->getModuleInst())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_LookupWitness:
+ case kIROp_Specialize:
+ return false;
+ default:
+ return true;
+ }
+ }
+
return fullySpecializedInsts.Contains(inst);
}
@@ -504,12 +518,9 @@ struct SpecializationContext
// be considered as fully specialized as soon
// as all of its operands are.
//
- // TODO: We realistically need a more refined
- // check here that uses an allow-list of instructions
- // that can represent values suitable for use
- // as generic arguments.
- //
- if(areAllOperandsFullySpecialized(inst))
+ // Anything defined in global scope can be viewed as fully specialized.
+ if (inst->getParent() == module->getModuleInst() ||
+ areAllOperandsFullySpecialized(inst))
{
markInstAsFullySpecialized(inst);
}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index beaaae065..a0acf5082 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -30,13 +30,14 @@ namespace Slang
changed |= peepholeOptimize(module);
changed |= removeRedundancy(module);
changed |= simplifyCFG(module);
- changed |= propagateFuncProperties(module);
// Note: we disregard the `changed` state from dead code elimination pass since
// SCCP pass could be generating temporarily evaluated constant values and never actually use them.
// DCE will always remove those nearly generated consts and always returns true here.
eliminateDeadCode(module);
+ changed |= propagateFuncProperties(module);
+
changed |= constructSSA(module);
changed |= removeUnusedGenericParam(module);
iterationCounter++;
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index b2f49c24b..4f1c15459 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -49,9 +49,13 @@ struct DeduplicateContext
return *newValue;
for (UInt i = 0; i < value->getOperandCount(); i++)
{
- value->unsafeSetOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate));
+ auto deduplicatedOperand = deduplicate(value->getOperand(i), shouldDeduplicate);
+ if (deduplicatedOperand != value->getOperand(i))
+ value->unsafeSetOperand(i, deduplicatedOperand);
}
- value->setFullType((IRType*)deduplicate(value->getFullType(), shouldDeduplicate));
+ auto deduplicatedType = (IRType*)deduplicate(value->getFullType(), shouldDeduplicate);
+ if (deduplicatedType != value->getFullType())
+ value->setFullType(deduplicatedType);
if (auto newValue = deduplicateMap.TryGetValue(key))
return *newValue;
deduplicateMap[key] = value;
diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp
index 1b72b5c00..8d8056e30 100644
--- a/source/slang/slang-legalize-types.cpp
+++ b/source/slang/slang-legalize-types.cpp
@@ -464,6 +464,7 @@ struct TupleTypeBuilder
// collide.
//
// (Also, the original type wasn't legal - that was the whole point...)
+ originalStructType->removeFromParent();
context->replacedInstructions.add(originalStructType);
for(auto ee : ordinaryElements)
diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h
index 0b4acf0fe..693a154f6 100644
--- a/source/slang/slang-legalize-types.h
+++ b/source/slang/slang-legalize-types.h
@@ -620,7 +620,7 @@ struct IRTypeLegalizationContext
// store instructions that have been replaced here, so we can free them
// when legalization has done
- List<IRInst*> replacedInstructions;
+ OrderedHashSet<IRInst*> replacedInstructions;
Dictionary<IRType*, LegalType> mapTypeToLegalType;
diff --git a/tests/autodiff-dstdlib/determinant.slang b/tests/autodiff-dstdlib/determinant.slang
new file mode 100644
index 000000000..d2e699551
--- /dev/null
+++ b/tests/autodiff-dstdlib/determinant.slang
@@ -0,0 +1,24 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[BackwardDifferentiable]
+float diffDeterminant(float2x2 x)
+{
+ return determinant(x);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ {
+ var dpx = diffPair(float2x2(1.0, 2.0, 3.0, 4.0));
+ __bwd_diff(diffDeterminant)(dpx, 1.0);
+ outputBuffer[0] = dpx.d[0][0];
+ outputBuffer[1] = dpx.d[0][1];
+ outputBuffer[2] = dpx.d[1][0];
+ outputBuffer[3] = dpx.d[1][1];
+ }
+}
diff --git a/tests/autodiff-dstdlib/determinant.slang.expected.txt b/tests/autodiff-dstdlib/determinant.slang.expected.txt
new file mode 100644
index 000000000..b2c2f7e17
--- /dev/null
+++ b/tests/autodiff-dstdlib/determinant.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+4.000000
+-3.000000
+-2.000000
+1.000000
diff --git a/tests/bugs/loop-optimize.slang b/tests/bugs/loop-optimize.slang
new file mode 100644
index 000000000..85231dbea
--- /dev/null
+++ b/tests/bugs/loop-optimize.slang
@@ -0,0 +1,28 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+void test(inout float v, int i)
+{
+ while (true)
+ {
+ i--;
+ if (i < 2)
+ {
+ v += 1.0;
+ return;
+ }
+ }
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ float v = 2.0;
+ test(v, 3);
+ outputBuffer[0] = v; // Expect 3.0
+}
\ No newline at end of file
diff --git a/tests/bugs/loop-optimize.slang.expected.txt b/tests/bugs/loop-optimize.slang.expected.txt
new file mode 100644
index 000000000..f38cc1080
--- /dev/null
+++ b/tests/bugs/loop-optimize.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+3.0