summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-30 13:24:39 -0800
committerGitHub <noreply@github.com>2022-11-30 13:24:39 -0800
commit09684224d5ab63f530d66c0be65fa50e6fc5290b (patch)
tree292d0f257b3d5a5e027892a5a1e046d60166aadd
parentf52b4de3b29ee27213b7d60fb620a0d5d50b49f9 (diff)
Support `no_diff` on existential typed params. (#2540)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-check-conformance.cpp6
-rw-r--r--source/slang/slang-check-decl.cpp3
-rw-r--r--source/slang/slang-emit.cpp41
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp8
-rw-r--r--source/slang/slang-ir-autodiff-pairs.cpp5
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp5
-rw-r--r--source/slang/slang-ir-autodiff.cpp7
-rw-r--r--source/slang/slang-ir-specialize.cpp37
-rw-r--r--source/slang/slang-ir-specialize.h2
-rw-r--r--source/slang/slang-ir.h17
-rw-r--r--tests/autodiff/no-diff-param-2.slang38
-rw-r--r--tests/autodiff/no-diff-param-2.slang.expected.txt5
12 files changed, 112 insertions, 62 deletions
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index d2335efbf..4d983b746 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -87,8 +87,10 @@ namespace Slang
// that `subType` has been proven to be *equal*
// to `superTypeDeclRef`.
//
- SLANG_UNEXPECTED("reflexive type witness");
- UNREACHABLE_RETURN(nullptr);
+ auto witness = m_astBuilder->create<TypeEqualityWitness>();
+ witness->sub = subType;
+ witness->sup = subType;
+ return witness;
}
// We might have one or more steps in the breadcrumb trail, e.g.:
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5e6c6eedf..d36e6286d 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -4715,7 +4715,8 @@ namespace Slang
maybeRegisterDifferentiableType(m_astBuilder, decl->returnType.type);
if (as<ConstructorDecl>(decl) || !isEffectivelyStatic(decl))
{
- auto thisType = calcThisType(makeDeclRef(decl));
+ auto parentDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(decl->parentDecl));
+ auto thisType = calcThisType(parentDeclRef);
maybeRegisterDifferentiableType(m_astBuilder, thisType);
}
m_parentDifferentiableAttr = oldAttr;
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index ca55a68bc..508402736 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -358,33 +358,36 @@ Result linkAndOptimizeIR(
// perform specialization of functions based on parameter
// values that need to be compile-time constants.
//
+ // Specialization passes and auto-diff passes runs in an iterative loop
+ // since each pass can enable the other pass to progress further.
+ for (;;)
+ {
+ bool changed = false;
- dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE");
- if (!codeGenContext->isSpecializationDisabled())
- specializeModule(irModule);
- dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE");
-
- applySparseConditionalConstantPropagation(irModule);
- eliminateDeadCode(irModule);
+ dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE");
+ if (!codeGenContext->isSpecializationDisabled())
+ changed |= specializeModule(irModule);
+ dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE");
- lowerReinterpret(targetRequest, irModule, sink);
-
- validateIRModuleIfEnabled(codeGenContext, irModule);
+ validateIRModuleIfEnabled(codeGenContext, irModule);
- // Inline calls to any functions marked with [__unsafeInlineEarly] again,
- // since we may be missing out cases prevented by the functions that we just specialzied.
- performMandatoryEarlyInlining(irModule);
+ // Inline calls to any functions marked with [__unsafeInlineEarly] again,
+ // since we may be missing out cases prevented by the functions that we just specialzied.
+ performMandatoryEarlyInlining(irModule);
- dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
-
- processAutodiffCalls(irModule, sink);
+ dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
+ changed |= processAutodiffCalls(irModule, sink);
+ dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
- dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
+ if (!changed)
+ break;
+ }
+
+ lowerReinterpret(targetRequest, irModule, sink);
validateIRModuleIfEnabled(codeGenContext, irModule);
- applySparseConditionalConstantPropagation(irModule);
- eliminateDeadCode(irModule);
+ simplifyIR(irModule);
// For targets that supports dynamic dispatch, we need to lower the
// generics / interface types to ordinary functions and types using
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index c9b186c8a..d45dd0c10 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -691,7 +691,8 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
primalCallee);
}
- else
+
+ if (!diffCallee)
{
// The callee is non differentiable, just return primal value with null diff value.
IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
@@ -1614,8 +1615,8 @@ struct ForwardDerivativePass : public InstPassBase
//
bool processReferencedFunctions(IRBuilder* builder)
{
+ bool changed = false;
List<IRInst*> autoDiffWorkList;
-
for (;;)
{
// Collect all `ForwardDifferentiate` insts from the module.
@@ -1669,6 +1670,7 @@ struct ForwardDerivativePass : public InstPassBase
differentiateInst->replaceUsesWith(diffFunc);
differentiateInst->removeAndDeallocate();
}
+ changed = true;
}
}
// Actually synthesize the derivatives.
@@ -1689,7 +1691,7 @@ struct ForwardDerivativePass : public InstPassBase
SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0);
}
- return true;
+ return changed;
}
// Checks decorators to see if the function should
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
index 1dbb1bd7c..b9b4a8b66 100644
--- a/source/slang/slang-ir-autodiff-pairs.cpp
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -133,12 +133,10 @@ struct DiffPairLoweringPass : InstPassBase
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
lowerPairAccess(builder, inst);
- modified = true;
break;
case kIROp_MakeDifferentialPair:
lowerMakePair(builder, inst);
- modified = true;
break;
default:
@@ -152,6 +150,7 @@ struct DiffPairLoweringPass : InstPassBase
{
inst->replaceUsesWith(loweredType);
inst->removeAndDeallocate();
+ modified = true;
}
});
return modified;
@@ -179,4 +178,4 @@ bool processPairTypes(AutoDiffSharedContext* context)
return pairLoweringPass.processModule();
}
-} \ No newline at end of file
+}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index daf45e1ef..8ec8f581c 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -871,6 +871,8 @@ struct ReverseDerivativePass : public InstPassBase
//
bool processReferencedFunctions(IRBuilder* builder)
{
+ bool changed = false;
+
List<IRInst*> autoDiffWorkList;
for (;;)
@@ -922,6 +924,7 @@ struct ReverseDerivativePass : public InstPassBase
SLANG_ASSERT(diffFunc);
differentiateInst->replaceUsesWith(diffFunc);
differentiateInst->removeAndDeallocate();
+ changed = true;
}
else
{
@@ -950,7 +953,7 @@ struct ReverseDerivativePass : public InstPassBase
SLANG_RELEASE_ASSERT(backwardTranscriberStorage.followUpFunctionsToTranscribe.getCount() == 0);
}
- return true;
+ return changed;
}
// Checks decorators to see if the function should
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 4373cf44b..5b5832073 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -448,12 +448,6 @@ bool processAutodiffCalls(
DiagnosticSink* sink,
IRAutodiffPassOptions const&)
{
- // Simplify module to remove dead code.
- IRDeadCodeEliminationOptions dceOptions;
- dceOptions.keepExportsAlive = true;
- dceOptions.keepLayoutsAlive = true;
- eliminateDeadCode(module, dceOptions);
-
bool modified = false;
// Create shared context for all auto-diff related passes
@@ -487,7 +481,6 @@ bool processAutodiffCalls(
// Remove auto-diff related decorations.
stripAutoDiffDecorations(module);
-
return modified;
}
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 406e5157c..74caa30ae 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -44,6 +44,8 @@ struct SpecializationContext
// we are specializing.
IRModule* module;
+ bool changed = false;
+
// We know that we can only perform generic specialization when all
// of the arguments to a generic are also fully specialized.
// The "is fully specialized" condition is something we
@@ -793,8 +795,6 @@ struct SpecializationContext
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
- bool changed = true;
-
// Read specialization dictionary from module if it is defined.
// This prevents us from generating duplicated specializations
// when this pass is invoked iteratively.
@@ -839,9 +839,9 @@ struct SpecializationContext
// We start out simple by putting the root instruction for the
// module onto our work list.
//
- while (changed)
+ for (;;)
{
- changed = false;
+ bool iterChanged = false;
addToWorkList(module->getModuleInst());
while (workList.Count() != 0)
@@ -868,7 +868,7 @@ struct SpecializationContext
// specialization opportunities (generic specialization,
// existential specialization, simplifications, etc.)
//
- changed |= maybeSpecializeInst(inst);
+ iterChanged |= maybeSpecializeInst(inst);
// Finally, we need to make our logic recurse through
// the whole IR module, so we want to add the children
@@ -896,8 +896,15 @@ struct SpecializationContext
addDirtyInstsToWorkListRec(module->getModuleInst());
}
- if (changed)
+ if (iterChanged)
+ {
simplifyIR(module);
+ this->changed = true;
+ }
+ else
+ {
+ break;
+ }
}
// Once the work list has gone dry, we should have the invariant
@@ -1776,6 +1783,11 @@ struct SpecializationContext
type = sbType->getElementType();
goto top;
}
+ else if (auto attributedType = as<IRAttributedType>(type))
+ {
+ type = attributedType->getBaseType();
+ goto top;
+ }
else if( auto structType = as<IRStructType>(type) )
{
UInt count = 0;
@@ -2070,6 +2082,11 @@ struct SpecializationContext
type = sbType->getElementType();
goto top;
}
+ else if (auto attributedType = as<IRAttributedType>(type))
+ {
+ type = attributedType->getBaseType();
+ goto top;
+ }
else if( auto structType = as<IRStructType>(type) )
{
UInt count = 0;
@@ -2114,7 +2131,8 @@ struct SpecializationContext
}
else if( as<IRPointerLikeType>(baseType) ||
as<IRHLSLStructuredBufferTypeBase>(baseType) ||
- as<IRArrayTypeBase>(baseType))
+ as<IRArrayTypeBase>(baseType) ||
+ as<IRAttributedType>(baseType) )
{
// A `BindExistentials<P<T>, ...>` can be simplified to
// `P<BindExistentials<T, ...>>` when `P` is a pointer-like
@@ -2127,6 +2145,8 @@ struct SpecializationContext
baseElementType = arrayType->getElementType();
else if (auto baseSBType = as<IRHLSLStructuredBufferTypeBase>(baseType))
baseElementType = baseSBType->getElementType();
+ else if (auto baseAttrType = as<IRAttributedType>(baseType))
+ baseElementType = baseAttrType->getBaseType();
IRInst* wrappedElementType = builder.getBindExistentialsType(
baseElementType,
@@ -2283,12 +2303,13 @@ struct SpecializationContext
}
};
-void specializeModule(
+bool specializeModule(
IRModule* module)
{
SpecializationContext context;
context.module = module;
context.processModule();
+ return context.changed;
}
diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h
index 9c2c19785..1503c238e 100644
--- a/source/slang/slang-ir-specialize.h
+++ b/source/slang/slang-ir-specialize.h
@@ -6,7 +6,7 @@ namespace Slang
struct IRModule;
/// Specialize generic and interface-based code to use concrete types.
-void specializeModule(
+bool specializeModule(
IRModule* module);
}
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 36fab6da1..56a33c02b 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -716,28 +716,11 @@ struct IRInst
void _insertAt(IRInst* inPrev, IRInst* inNext, IRInst* inParent);
};
-inline bool isModifierInst(IROp op)
-{
- switch (op)
- {
- case kIROp_AttributedType:
- return true;
- }
- return false;
-}
-
template<typename T>
T* dynamicCast(IRInst* inst)
{
if (inst && T::isaImpl(inst->getOp()))
return static_cast<T*>(inst);
- if (inst)
- {
- if (isModifierInst(inst->getOp()))
- {
- return dynamicCast<T>(inst->getOperand(0));
- }
- }
return nullptr;
}
diff --git a/tests/autodiff/no-diff-param-2.slang b/tests/autodiff/no-diff-param-2.slang
new file mode 100644
index 000000000..d29928d69
--- /dev/null
+++ b/tests/autodiff/no-diff-param-2.slang
@@ -0,0 +1,38 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+
+interface IFoo : IDifferentiable
+{
+ [ForwardDifferentiable]
+ float getVal();
+}
+
+struct A : IFoo
+{
+ float x;
+ [ForwardDifferentiable]
+ float getVal(){return x;}
+}
+
+[ForwardDifferentiable]
+float f(float x, no_diff IFoo y)
+{
+ return x * x + y.getVal();
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ A a;
+ a.x = 2.0;
+ let rs = __fwd_diff(f)(dpfloat(1.5, 1.0), a);
+ outputBuffer[0] = rs.p; // Expect: 6.25
+ outputBuffer[1] = rs.d; // Expect: 3.0
+ }
+}
diff --git a/tests/autodiff/no-diff-param-2.slang.expected.txt b/tests/autodiff/no-diff-param-2.slang.expected.txt
new file mode 100644
index 000000000..18066089d
--- /dev/null
+++ b/tests/autodiff/no-diff-param-2.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+4.250000
+3.000000
+0.000000
+0.000000 \ No newline at end of file