summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-decl.cpp2
-rw-r--r--source/slang/slang-check-modifier.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp96
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp156
-rw-r--r--source/slang/slang-ir-autodiff-rev.h11
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp12
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h36
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp16
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h36
-rw-r--r--source/slang/slang-ir-autodiff.cpp60
-rw-r--r--source/slang/slang-ir-autodiff.h1
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp15
-rw-r--r--source/slang/slang-ir-hoist-local-types.cpp56
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h16
-rw-r--r--source/slang/slang-ir-util.cpp2
-rw-r--r--source/slang/slang-ir.cpp9
-rw-r--r--source/slang/slang-ir.h3
-rw-r--r--tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang61
-rw-r--r--tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt6
-rw-r--r--tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang53
-rw-r--r--tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt6
-rw-r--r--tests/language-server/robustness-6.slang10
-rw-r--r--tests/language-server/robustness-6.slang.expected.txt13
24 files changed, 512 insertions, 171 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index f016ae3d8..a535ba104 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -6918,6 +6918,8 @@ namespace Slang
// has an associated derivative function.
if (func->findModifier<BackwardDifferentiableAttribute>())
return true;
+ if (func->findModifier<BackwardDerivativeAttribute>())
+ return true;
for (auto assocDecl : getAssociatedDeclsForDecl(func))
{
switch (assocDecl.kind)
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index f505b1321..f3623f19f 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -613,7 +613,7 @@ namespace Slang
hitObjectAttributesAttr->location = (int32_t)val->value;
}
- else if (auto forwardDerivativeAttr = as<ForwardDerivativeAttribute>(attr))
+ else if (auto derivativeAttr = as<UserDefinedDerivativeAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
SLANG_ASSERT(as<Decl>(attrTarget));
@@ -633,7 +633,7 @@ namespace Slang
//
// Set type to null to indicate that this needs expr needs to be further resolved.
diffExpr->type.type = nullptr;
- forwardDerivativeAttr->funcExpr = diffExpr;
+ derivativeAttr->funcExpr = diffExpr;
}
else if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr))
{
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 54d32ae3e..68a86bc00 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -287,6 +287,33 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst*
}
}
+static bool _isDifferentiableFunc(IRInst* func)
+{
+ for (auto decor = func->getFirstDecoration(); decor; decor = decor->getNextDecoration())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_ForwardDifferentiableDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDifferentiableDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ return true;
+ }
+ }
+ return false;
+}
+
+static IRFuncType* _getCalleeActualFuncType(IRInst* callee)
+{
+ auto type = callee->getFullType();
+ if (auto funcType = as<IRFuncType>(type))
+ return funcType;
+ if (auto specialize = as<IRSpecialize>(callee))
+ return as<IRFuncType>(findGenericReturnVal(as<IRGeneric>(specialize->getBase()))->getFullType());
+ return nullptr;
+}
+
// Differentiating a call instruction here is primarily about generating
// an appropriate call list based on whichever parameters have differentials
// in the current transcription context.
@@ -310,10 +337,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
return InstPair(nullptr, nullptr);
}
- // Since concrete functions are globals, the primal callee is the same
- // as the original callee.
- //
- auto primalCallee = origCallee;
+ auto primalCallee = lookupPrimalInst(builder, origCallee, origCallee);
IRInst* diffCallee = nullptr;
@@ -325,8 +349,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
// If the user has already provided an differentiated implementation, use that.
diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc();
}
- else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>() ||
- primalCallee->findDecoration<IRBackwardDifferentiableDecoration>())
+ else if (_isDifferentiableFunc(primalCallee))
{
// If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
// to generate the implementation.
@@ -343,7 +366,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
return InstPair(primalCall, nullptr);
}
- auto calleeType = as<IRFuncType>(diffCallee->getDataType());
+ auto calleeType = _getCalleeActualFuncType(diffCallee);
SLANG_ASSERT(calleeType);
SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount());
@@ -399,6 +422,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
diffCallee,
args);
builder->markInstAsMixedDifferential(callInst, diffReturnType);
+ builder->addAutoDiffOriginalValueDecoration(callInst, primalCallee);
if (diffReturnType->getOp() != kIROp_VoidType)
{
@@ -629,7 +653,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec
builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
- else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>())
+ else if (_isDifferentiableFunc(genericInnerVal))
{
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
@@ -927,8 +951,15 @@ InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, I
// Create an empty func to represent the transcribed func of `origFunc`.
InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
- if (auto bwdDecor = origFunc->findDecoration<IRForwardDerivativeDecoration>())
- return InstPair(origFunc, bwdDecor->getForwardDerivativeFunc());
+ if (auto fwdDecor = origFunc->findDecoration<IRForwardDerivativeDecoration>())
+ {
+ // If we reach here, the function must have been used directly in a `call` inst, and therefore
+ // can't be a generic.
+ // Generic function are always referenced with `specialize` inst and the handling logic for
+ // custom derivatives is implemented in `transcribeSpecialize`.
+ SLANG_RELEASE_ASSERT(fwdDecor->getForwardDerivativeFunc()->getOp() == kIROp_Func);
+ return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc());
+ }
auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc);
@@ -1012,51 +1043,6 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr
return InstPair(primalFunc, diffFunc);
}
-// Transcribe a generic definition
-InstPair ForwardDiffTranscriber::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric)
-{
- auto innerVal = findInnerMostGenericReturnVal(origGeneric);
- if (auto innerFunc = as<IRFunc>(innerVal))
- {
- differentiableTypeConformanceContext.setFunc(innerFunc);
- }
- else if (auto funcType = as<IRFuncType>(innerVal))
- {
- }
- else
- {
- return InstPair(origGeneric, nullptr);
- }
-
- IRGeneric* primalGeneric = origGeneric;
-
- IRBuilder builder(inBuilder->getSharedBuilder());
- builder.setInsertBefore(origGeneric);
-
- auto diffGeneric = builder.emitGeneric();
-
- // Process type of generic. If the generic is a function, then it's type will also be a
- // generic and this logic will transcribe that generic first before continuing with the
- // function itself.
- //
- auto primalType = primalGeneric->getFullType();
-
- IRType* diffType = nullptr;
- if (primalType)
- {
- diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType);
- }
-
- diffGeneric->setFullType(diffType);
-
- // Transcribe children from origFunc into diffFunc.
- builder.setInsertInto(diffGeneric);
- for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
- this->transcribe(&builder, block);
-
- return InstPair(primalGeneric, diffGeneric);
-}
-
InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* origInst)
{
// Handle common SSA-style operations
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index af408a5b3..8d6419cf2 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -46,7 +46,8 @@ namespace Slang
}
}
- newParameterTypes.add(differentiateType(builder, funcType->getResultType()));
+ if (auto diffResultType = differentiateType(builder, funcType->getResultType()))
+ newParameterTypes.add(diffResultType);
if (intermeidateType)
{
@@ -58,20 +59,14 @@ namespace Slang
return builder->getFuncType(newParameterTypes, diffReturnType);
}
- static IRInst* getOriginalFuncRef(IRBuilder& builder, IRInst* func, IRInst* useSite)
- {
- if (!func) return nullptr;
- auto userGeneric = findOuterGeneric(useSite);
- if (!userGeneric) return func;
- auto funcGen = findOuterGeneric(func);
- SLANG_RELEASE_ASSERT(funcGen);
- return maybeSpecializeWithGeneric(builder, funcGen, userGeneric);
- }
-
IRFuncType* BackwardDiffPrimalTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType)
{
- auto funcRef = getOriginalFuncRef(*builder, func, builder->getInsertLoc().getParent());
- auto intermediateType = builder->getBackwardDiffIntermediateContextType(funcRef);
+ IRType* intermediateType = builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
+ if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent()))
+ {
+ intermediateType = (IRType*)specializeWithGeneric(*builder, intermediateType, as<IRGeneric>(outerGeneric));
+ }
+
auto outType = builder->getOutType(intermediateType);
List<IRType*> paramTypes;
for (UInt i = 0; i < funcType->getParamCount(); i++)
@@ -91,13 +86,98 @@ namespace Slang
// Don't need to do anything other than add a decoration in the original func to point to the primal func.
// The body of the primal func will be generated by propagateTranscriber together with propagate func.
addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
- return InstPair(primalFunc, primalFunc);
+ return InstPair(primalFunc, diffFunc);
+ }
+
+ static List<IRInst*> _defineFuncParams(IRBuilder* builder, IRFunc* func)
+ {
+ auto propFuncType = cast<IRFuncType>(func->getFullType());
+ List<IRInst*> params;
+ for (UInt i = 0; i < propFuncType->getParamCount(); i++)
+ {
+ auto paramType = propFuncType->getParamType(i);
+ auto param = builder->emitParam(paramType);
+ params.add(param);
+ }
+ return params;
+ }
+
+ void BackwardDiffPropagateTranscriber::generateTrivialDiffFuncFromUserDefinedDerivative(
+ IRBuilder* builder,
+ IRFunc* originalFunc,
+ IRFunc* diffPropFunc,
+ IRUserDefinedBackwardDerivativeDecoration* udfDecor)
+ {
+ // Create an empty struct type to use as the intermediate context type.
+ auto originalGeneric = findOuterGeneric(originalFunc);
+ builder->setInsertBefore(originalFunc);
+ IRInst* emptyStruct = builder->createStructType();
+ IRInst* emptyStructType = nullptr;
+ auto emptyStructGeneric = hoistValueFromGeneric(*builder, emptyStruct, emptyStructType, false);
+ builder->addBackwardDerivativeIntermediateTypeDecoration(originalFunc, emptyStructGeneric);
+
+ IRInst* udf = udfDecor->getBackwardDerivativeFunc();
+ builder->setInsertInto(diffPropFunc);
+ builder->emitBlock();
+ List<IRInst*> params = _defineFuncParams(builder, diffPropFunc);
+ params.removeLast();
+ IRInst* udfRefFromPropFunc = udf;
+ if (auto specialize = as<IRSpecialize>(udf))
+ {
+ udf = specialize->getBase();
+ auto propGeneric = findOuterGeneric(diffPropFunc);
+ SLANG_RELEASE_ASSERT(propGeneric);
+ udfRefFromPropFunc = maybeSpecializeWithGeneric(*builder, udf, propGeneric);
+ }
+ builder->emitCallInst(builder->getVoidType(), udfRefFromPropFunc, params);
+ builder->emitReturn();
+
+ // Now create the trivial primal function.
+ auto existingDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>();
+ if (!existingDecor)
+ {
+ // We haven't created a header for primal func yet, create it now.
+ if (originalGeneric)
+ builder->setInsertBefore(originalGeneric);
+ else
+ builder->setInsertBefore(originalFunc);
+
+ autoDiffSharedContext->transcriberSet.primalTranscriber->transcribe(builder, originalGeneric ? originalGeneric : originalFunc);
+ existingDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>();
+ }
+ SLANG_RELEASE_ASSERT(existingDecor);
+
+ // Fill the primal func header with trivial call to original func.
+ IRInst* existingPrimalFunc = existingDecor->getBackwardDerivativePrimalFunc();
+ IRGeneric* existingPriamlFuncGeneric = nullptr;
+ if (auto specialize = as<IRSpecialize>(existingPrimalFunc))
+ {
+ existingPriamlFuncGeneric = as<IRGeneric>(specialize->getBase());
+ existingPrimalFunc = findGenericReturnVal(existingPriamlFuncGeneric);
+ }
+ builder->setInsertBefore(existingPrimalFunc);
+
+ builder->setInsertInto(existingPrimalFunc);
+ builder->emitBlock();
+ params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc));
+ params.removeLast();
+ IRInst* originalFuncRefFromPrimalFunc = originalFunc;
+ if (originalGeneric)
+ originalFuncRefFromPrimalFunc = maybeSpecializeWithGeneric(*builder, originalGeneric, existingPriamlFuncGeneric);
+ auto result = builder->emitCallInst(
+ cast<IRFuncType>(existingPrimalFunc->getFullType())->getResultType(),
+ originalFuncRefFromPrimalFunc,
+ params);
+ builder->emitReturn(result);
}
IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType)
{
- auto funcRef = getOriginalFuncRef(*builder, func, builder->getInsertLoc().getParent());
- auto intermediateType = builder->getBackwardDiffIntermediateContextType(funcRef);
+ IRType* intermediateType = builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
+ if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent()))
+ {
+ intermediateType = (IRType*)specializeWithGeneric(*builder, intermediateType, as<IRGeneric>(outerGeneric));
+ }
return differentiateFunctionTypeImpl(builder, funcType, intermediateType);
}
@@ -109,9 +189,15 @@ namespace Slang
InstPair BackwardDiffPropagateTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
{
- IRGlobalValueWithCode* diffPrimalFunc = nullptr;
addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
- transcribeFuncImpl(builder, primalFunc, diffFunc, diffPrimalFunc);
+ if (auto udf = primalFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
+ {
+ generateTrivialDiffFuncFromUserDefinedDerivative(builder, primalFunc, diffFunc, udf);
+ }
+ else
+ {
+ transcribeFuncImpl(builder, primalFunc, diffFunc);
+ }
return InstPair(primalFunc, diffFunc);
}
@@ -212,18 +298,13 @@ namespace Slang
return InstPair(diffBlock, diffBlock);
}
- static bool isMarkedForBackwardDifferentiation(IRInst* callable)
- {
- return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr;
- }
-
// Create an empty func to represent the transcribed func of `origFunc`.
InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc)
{
if (auto bwdDiffFunc = findExistingDiffFunc(origFunc))
return InstPair(origFunc, bwdDiffFunc);
- if (!isMarkedForBackwardDifferentiation(origFunc))
+ if (!isBackwardDifferentiableFunc(origFunc))
return InstPair(nullptr, nullptr);
IRBuilder builder = *inBuilder;
@@ -253,7 +334,6 @@ namespace Slang
// Mark the generated derivative function itself as differentiable.
builder.addBackwardDifferentiableDecoration(diffFunc);
-
// Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
{
@@ -339,13 +419,14 @@ namespace Slang
}
auto outerGeneric = findOuterGeneric(origFunc);
+ IRType* intermediateType = builder.getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(origFunc));
IRInst* specializedOriginalFunc = origFunc;
if (outerGeneric)
{
specializedOriginalFunc = maybeSpecializeWithGeneric(builder, outerGeneric, findOuterGeneric(header.differential));
+ intermediateType = (IRType*)specializeWithGeneric(builder, intermediateType, as<IRGeneric>(findOuterGeneric(header.differential)));
}
- auto intermediateType = builder.getBackwardDiffIntermediateContextType(specializedOriginalFunc);
auto intermediateVar = builder.emitVar(intermediateType);
auto origFuncType = as<IRFuncType>(origFunc->getDataType());
@@ -420,11 +501,19 @@ namespace Slang
eliminateDeadCode(primalOuterParent);
// Forward transcribe the clone of the original func.
- ForwardDiffTranscriber fwdTranscriber(autoDiffSharedContext, builder->getSharedBuilder(), sink);
- fwdTranscriber.pairBuilder = pairBuilder;
+ ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(
+ autoDiffSharedContext->transcriberSet.forwardTranscriber);
+ auto oldCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
IRFunc* fwdDiffFunc = as<IRFunc>(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent)));
SLANG_ASSERT(fwdDiffFunc);
- fwdTranscriber.transcribeFunc(builder, primalFunc, fwdDiffFunc);
+ auto newCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
+ for (auto i = oldCount; i < newCount; i++)
+ {
+ auto pendingTask = autoDiffSharedContext->followUpFunctionsToTranscribe.getLast();
+ autoDiffSharedContext->followUpFunctionsToTranscribe.removeLast();
+ SLANG_RELEASE_ASSERT(pendingTask.type == FuncBodyTranscriptionTaskType::Forward);
+ fwdTranscriber.transcribeFunc(builder, pendingTask.originalFunc, pendingTask.resultFunc);
+ }
// Remove the clone of original func.
primalOuterParent->removeAndDeallocate();
@@ -453,12 +542,11 @@ namespace Slang
}
fwdParentGeneric->removeAndDeallocate();
}
-
return fwdDiffFunc;
}
// Transcribe a function definition.
- void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc)
+ void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc)
{
SLANG_ASSERT(primalFunc);
SLANG_ASSERT(diffPropagateFunc);
@@ -546,8 +634,7 @@ namespace Slang
IRInst* specializedFunc = nullptr;
auto intermediateTypeGeneric = hoistValueFromGeneric(*builder, intermediateType, specializedFunc, true);
builder->setInsertBefore(primalFunc);
- auto specializedIntermeidateType = maybeSpecializeWithGeneric(*builder, intermediateTypeGeneric, primalOuterGeneric);
- builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, specializedIntermeidateType);
+ builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, intermediateTypeGeneric);
auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc, true);
builder->setInsertBefore(primalFunc);
@@ -567,7 +654,6 @@ namespace Slang
auto specializedBackwardPrimalFunc = maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric);
builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc);
}
- diffPrimalFunc = as<IRGlobalValueWithCode>(primalOuterGeneric);
}
void BackwardDiffTranscriberBase::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc)
@@ -900,7 +986,7 @@ namespace Slang
return InstPair(primalSpecialize, diffSpecialize);
}
- else if (auto diffDecor = genericInnerVal->findDecoration<IRBackwardDifferentiableDecoration>())
+ else if (isBackwardDifferentiableFunc(genericInnerVal))
{
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index 02a100c80..228bcf588 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -87,7 +87,7 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc);
- void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc);
+ void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc);
InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc);
@@ -144,6 +144,11 @@ struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase
inSharedBuilder,
inSink)
{ }
+ void generateTrivialDiffFuncFromUserDefinedDerivative(
+ IRBuilder* builder,
+ IRFunc* primalFunc,
+ IRFunc* diffPropFunc,
+ IRUserDefinedBackwardDerivativeDecoration* udfDecor);
virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override;
@@ -189,6 +194,10 @@ struct BackwardDiffTranscriber : BackwardDiffTranscriberBase
{
return backDecor->getBackwardDerivativeFunc();
}
+ if (auto backDecor = originalFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
+ {
+ return backDecor->getBackwardDerivativeFunc();
+ }
return nullptr;
}
virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index deb1b2da9..275b40b46 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -143,6 +143,9 @@ IRInst* AutoDiffTranscriberBase::findOrTranscribePrimalInst(IRBuilder* builder,
IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst)
{
+ if (!inst)
+ return nullptr;
+
IRInst* primal = lookupPrimalInst(builder, inst, nullptr);
if (!primal)
{
@@ -234,6 +237,13 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType)
IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType)
{
+ auto primalType = lookupPrimalInst(builder, origType, origType);
+ if (primalType->getOp() == kIROp_Param &&
+ primalType->getParent() && primalType->getParent()->getParent() &&
+ primalType->getParent()->getParent()->getOp() == kIROp_Generic)
+ {
+ return (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType);
+ }
return (IRType*)transcribe(builder, origType);
}
@@ -725,6 +735,8 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
builder.setInsertBefore(origGeneric);
auto diffGeneric = builder.emitGeneric();
+
+ mapDifferentialInst(origGeneric, diffGeneric);
// Process type of generic. If the generic is a function, then it's type will also be a
// generic and this logic will transcribe that generic first before continuing with the
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index fa9f4ffb2..78b8c5098 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -413,6 +413,14 @@ struct DiffTransposePass
void transposeInst(IRBuilder* builder, IRInst* inst)
{
+ switch (inst->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ return;
+ default:
+ break;
+ }
+
// Look for gradient entries for this inst.
List<RevGradient> gradients;
if (hasRevGradients(inst))
@@ -520,14 +528,21 @@ struct DiffTransposePass
List<IRInst*> args;
List<IRType*> argTypes;
- List<bool> isArgPtrTyped;
+ List<bool> argRequiresLoad;
+
+ auto getDiffPairType = [](IRType* type)
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(type))
+ type = ptrType->getValueType();
+ return as<IRDifferentialPairType>(type);
+ };
for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++)
{
auto arg = fwdCall->getArg(ii);
// If this isn't a ptr-type, make a var.
- if (!as<IRPtrTypeBase>(arg->getDataType()) && as<IRDifferentialPairType>(arg->getDataType()))
+ if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
{
auto pairType = as<IRDifferentialPairType>(arg->getDataType());
@@ -548,24 +563,26 @@ struct DiffTransposePass
args.add(var);
argTypes.add(builder->getInOutType(pairType));
- isArgPtrTyped.add(false);
+ argRequiresLoad.add(true);
}
else
{
args.add(arg);
argTypes.add(arg->getDataType());
- isArgPtrTyped.add(true);
+ argRequiresLoad.add(false);
}
}
args.add(revValue);
argTypes.add(revValue->getDataType());
+ argRequiresLoad.add(false);
args.add(primalContextDecor->getBackwardDerivativePrimalContextVar());
argTypes.add(builder->getOutType(
as<IRPtrTypeBase>(
primalContextDecor->getBackwardDerivativePrimalContextVar()->getDataType())
->getValueType()));
+ argRequiresLoad.add(false);
auto revFnType = builder->getFuncType(argTypes, builder->getVoidType());
auto revCallee = builder->emitBackwardDifferentiatePropagateInst(
@@ -578,17 +595,16 @@ struct DiffTransposePass
for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++)
{
// Is this arg relevant to auto-diff?
- if (as<IRDifferentialPairType>(as<IRPtrTypeBase>(args[ii]->getDataType())->getValueType()))
+ if (auto diffPairType = getDiffPairType(args[ii]->getDataType()))
{
// If this is ptr typed, ignore (the gradient will be accumulated on the pointer)
// automatically.
//
- if (!isArgPtrTyped[ii])
+ if (argRequiresLoad[ii])
{
auto diffArgType = (IRType*)diffTypeContext.getDifferentialForType(
builder,
- as<IRDifferentialPairType>(
- as<IRPtrTypeBase>(argTypes[ii])->getValueType())->getValueType());
+ diffPairType->getValueType());
auto diffArgPtrType = builder->getPtrType(kIROp_PtrType, diffArgType);
gradients.add(RevGradient(
@@ -889,7 +905,6 @@ struct DiffTransposePass
TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
{
-
// Dispatch logic.
switch(fwdInst->getOp())
{
@@ -924,7 +939,8 @@ struct DiffTransposePass
case kIROp_MakeVector:
return transposeMakeVector(builder, fwdInst, revValue);
-
+
+ case kIROp_Specialize:
case kIROp_unconditionalBranch:
case kIROp_conditionalBranch:
case kIROp_ifElse:
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index b8a4c4f08..a95fd7b9b 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -252,6 +252,12 @@ struct ExtractPrimalFuncContext
SLANG_RELEASE_ASSERT(structType);
auto structKey = genTypeBuilder.createStructKey();
genTypeBuilder.setInsertInto(structType);
+
+ if (isChildInstOf(fieldType->getParent(), structType->getParent()))
+ {
+ IRCloneEnv cloneEnv;
+ fieldType = cloneInst(&cloneEnv, &genTypeBuilder, fieldType);
+ }
return genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType);
}
@@ -452,19 +458,21 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>())
{
builder.setInsertBefore(inst);
- auto addr = builder.emitFieldAddress(
- builder.getPtrType(inst->getDataType()),
+ auto val = builder.emitFieldExtract(
+ inst->getDataType(),
intermediateVar,
structKeyDecor->getStructKey());
if (inst->getOp() == kIROp_Var)
{
// This is a var for intermediate context.
- inst->replaceUsesWith(addr);
+ auto tempVar =
+ builder.emitVar(cast<IRPtrTypeBase>(inst->getFullType())->getValueType());
+ builder.emitStore(tempVar, val);
+ inst->replaceUsesWith(tempVar);
}
else
{
// Orindary value.
- auto val = builder.emitLoad(addr);
inst->replaceUsesWith(val);
}
instsToRemove.add(inst);
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 612212dd9..b06ed29bf 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -184,19 +184,43 @@ struct DiffUnzipPass
return false;
}
+ static IRInst* _getOriginalFunc(IRInst* call)
+ {
+ if (auto decor = call->findDecoration<IRAutoDiffOriginalValueDecoration>())
+ return decor->getOriginalValue();
+ return nullptr;
+ }
+
InstPair splitCall(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRCall* mixedCall)
{
IRBuilder globalBuilder;
globalBuilder.init(autodiffContext->sharedBuilder);
- auto fwdCallee = as<IRForwardDifferentiate>(mixedCall->getCallee());
- auto fwdCalleeType = as<IRFuncType>(fwdCallee->getDataType());
- auto baseFn = fwdCallee->getBaseFn();
+ auto fwdCalleeType = mixedCall->getCallee()->getDataType();
+ auto baseFn = _getOriginalFunc(mixedCall);
+ SLANG_RELEASE_ASSERT(baseFn);
auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->differentiateFunctionType(
primalBuilder, baseFn, as<IRFuncType>(baseFn->getDataType()));
- auto intermediateVar = primalBuilder->emitVar(primalBuilder->getBackwardDiffIntermediateContextType(baseFn));
+ IRInst* intermediateType = nullptr;
+
+ if (auto specialize = as<IRSpecialize>(baseFn))
+ {
+ auto func = findSpecializeReturnVal(specialize);
+ auto outerGen = findOuterGeneric(func);
+ intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen);
+ intermediateType = specializeWithGeneric(
+ *primalBuilder,
+ intermediateType,
+ as<IRGeneric>(findOuterGeneric(primalBuilder->getInsertLoc().getParent())));
+ }
+ else
+ {
+ intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn);
+ }
+
+ auto intermediateVar = primalBuilder->emitVar((IRType*)intermediateType);
primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar);
auto primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn);
@@ -204,7 +228,7 @@ struct DiffUnzipPass
List<IRInst*> primalArgs;
for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
{
- auto arg = mixedCall->getArg(0);
+ auto arg = mixedCall->getArg(ii);
if (isRelevantDifferentialPair(arg->getDataType()))
{
@@ -232,7 +256,7 @@ struct DiffUnzipPass
List<IRInst*> diffArgs;
for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
{
- auto arg = mixedCall->getArg(0);
+ auto arg = mixedCall->getArg(ii);
if (isRelevantDifferentialPair(arg->getDataType()))
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 74afa4002..7182375de 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -6,6 +6,21 @@
namespace Slang
{
+
+bool isBackwardDifferentiableFunc(IRInst* func)
+{
+ for (auto decorations : func->getDecorations())
+ {
+ switch (decorations->getOp())
+ {
+ case kIROp_BackwardDifferentiableDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ return true;
+ }
+ }
+ return false;
+}
+
static IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
{
if (auto witnessTable = as<IRWitnessTable>(witness))
@@ -388,7 +403,7 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
if (auto pairType = as<IRDifferentialPairType>(globalInst))
{
- differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness());
+ differentiableWitnessDictionary.AddIfNotExists(pairType->getValueType(), pairType->getWitness());
}
}
}
@@ -406,6 +421,8 @@ void stripDerivativeDecorations(IRInst* inst)
case kIROp_BackwardDerivativeIntermediateTypeDecoration:
case kIROp_BackwardDerivativePropagateDecoration:
case kIROp_BackwardDerivativePrimalDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_AutoDiffOriginalValueDecoration:
decor->removeAndDeallocate();
break;
default:
@@ -435,6 +452,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_BackwardDerivativePrimalDecoration:
case kIROp_BackwardDerivativePrimalContextDecoration:
case kIROp_BackwardDerivativePrimalReturnDecoration:
+ case kIROp_AutoDiffOriginalValueDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
decor->removeAndDeallocate();
break;
default:
@@ -456,27 +475,26 @@ void stripAutoDiffDecorations(IRModule* module)
}
-void stripBlockTypeDecorations(IRFunc* func)
+void stripTempDecorations(IRInst* inst)
{
- for (auto child : func->getChildren())
+ for (auto decor = inst->getFirstDecoration(); decor; )
{
- if (auto block = as<IRBlock>(child))
+ auto next = decor->getNextDecoration();
+ switch (decor->getOp())
{
- for (auto decor = block->getFirstDecoration(); decor; )
- {
- auto next = decor->getNextDecoration();
- switch (decor->getOp())
- {
- case kIROp_DifferentialInstDecoration:
- case kIROp_MixedDifferentialInstDecoration:
- decor->removeAndDeallocate();
- break;
- default:
- break;
- }
- decor = next;
- }
+ case kIROp_DifferentialInstDecoration:
+ case kIROp_MixedDifferentialInstDecoration:
+ case kIROp_AutoDiffOriginalValueDecoration:
+ decor->removeAndDeallocate();
+ break;
+ default:
+ break;
}
+ decor = next;
+ }
+ for (auto child : inst->getChildren())
+ {
+ stripTempDecorations(child);
}
}
@@ -554,9 +572,7 @@ struct AutoDiffPass : public InstPassBase
auto inner = findGenericReturnVal(baseGeneric);
if (auto typeDecor = inner->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>())
{
- auto typeSpec = cast<IRSpecialize>(typeDecor->getBackwardDerivativeIntermediateType());
- auto typeSpecBase = typeSpec->getBase();
- return typeSpecBase;
+ return typeDecor->getBackwardDerivativeIntermediateType();
}
}
else if (auto func = as<IRFunc>(base))
@@ -742,7 +758,7 @@ struct AutoDiffPass : public InstPassBase
// passes since they don't expect decorations in blocks.
//
for (auto diffFunc : autodiffCleanupList)
- stripBlockTypeDecorations(diffFunc);
+ stripTempDecorations(diffFunc);
autodiffCleanupList.clear();
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index f468b1fca..7479e4eee 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -259,4 +259,5 @@ bool finalizeAutoDiffPass(IRModule* module);
void stripDerivativeDecorations(IRInst* inst);
+bool isBackwardDifferentiableFunc(IRInst* func);
};
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 83351d07b..8413e7e79 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -71,6 +71,7 @@ public:
{
case kIROp_ForwardDerivativeDecoration:
case kIROp_ForwardDifferentiableDecoration:
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
case kIROp_BackwardDerivativeDecoration:
case kIROp_BackwardDifferentiableDecoration:
return true;
@@ -140,20 +141,6 @@ public:
return false;
}
- bool isBackwardDifferentiableFunc(IRInst* func)
- {
- for (auto decorations : func->getDecorations())
- {
- switch (decorations->getOp())
- {
- case kIROp_BackwardDerivativeDecoration:
- case kIROp_BackwardDifferentiableDecoration:
- return true;
- }
- }
- return false;
- }
-
bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst)
{
HashSet<IRInst*> processedSet;
diff --git a/source/slang/slang-ir-hoist-local-types.cpp b/source/slang/slang-ir-hoist-local-types.cpp
index cf091f701..d8b0eab22 100644
--- a/source/slang/slang-ir-hoist-local-types.cpp
+++ b/source/slang/slang-ir-hoist-local-types.cpp
@@ -16,12 +16,6 @@ struct HoistLocalTypesContext
void addToWorkList(IRInst* inst)
{
- for (auto ii = inst->getParent(); ii; ii = ii->getParent())
- {
- if (as<IRGeneric>(ii))
- return;
- }
-
if (workListSet.Contains(inst))
return;
@@ -29,19 +23,28 @@ struct HoistLocalTypesContext
workListSet.Add(inst);
}
- void processInst(IRInst* inst)
+ bool processInst(IRInst* inst)
{
auto sharedBuilder = &sharedBuilderStorage;
if (!as<IRType>(inst))
- return;
+ return false;
if (inst->getParent() == module->getModuleInst())
- return;
+ return false;
+ switch (inst->getOp())
+ {
+ case kIROp_InterfaceType:
+ case kIROp_StructType:
+ case kIROp_ClassType:
+ return false;
+ default:
+ break;
+ }
IRInstKey key = {inst};
if (auto value = sharedBuilder->getGlobalValueNumberingMap().TryGetValue(key))
{
inst->replaceUsesWith(*value);
inst->removeAndDeallocate();
- return;
+ return true;
}
IRBuilder builder(sharedBuilder);
builder.setInsertInto(module->getModuleInst());
@@ -67,7 +70,9 @@ struct HoistLocalTypesContext
inst->transferDecorationsTo(newType);
inst->replaceUsesWith(newType);
inst->removeAndDeallocate();
+ return true;
}
+ return false;
}
void processModule()
@@ -75,24 +80,31 @@ struct HoistLocalTypesContext
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
- // Deduplicate equivalent types and build numbering map for global types.
- sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
+ for (;;)
+ {
+ bool changed = false;
+ // Deduplicate equivalent types and build numbering map for global types.
+ sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
- addToWorkList(module->getModuleInst());
+ addToWorkList(module->getModuleInst());
- while (workList.getCount() != 0)
- {
- IRInst* inst = workList.getLast();
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = workList.getLast();
- workList.removeLast();
- workListSet.Remove(inst);
+ workList.removeLast();
+ workListSet.Remove(inst);
- processInst(inst);
+ changed |= processInst(inst);
- for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
- {
- addToWorkList(child);
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
}
+
+ if (!changed)
+ break;
}
}
};
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index ab7453b41..68afbbb95 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -726,6 +726,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// Decorated function is marked for the forward-mode differentiation pass.
INST(ForwardDifferentiableDecoration, forwardDifferentiable, 0, 0)
+ /// Decorates a auto-diff transcribed value with the original value that the inst is transcribed from.
+ INST(AutoDiffOriginalValueDecoration, AutoDiffOriginalValueDecoration, 1, 0)
+
/// Used by the auto-diff pass to hold a reference to the
/// generated derivative function.
INST(ForwardDerivativeDecoration, fwdDerivative, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index b30d489dc..22da763b3 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -566,6 +566,16 @@ struct IRSequentialIDDecoration : IRDecoration
IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); }
};
+struct IRAutoDiffOriginalValueDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_AutoDiffOriginalValueDecoration
+ };
+ IR_LEAF_ISA(AutoDiffOriginalValueDecoration)
+ IRInst* getOriginalValue() { return getOperand(0); }
+};
+
struct IRForwardDifferentiableDecoration : IRDecoration
{
enum
@@ -708,6 +718,7 @@ struct IRUserDefinedBackwardDerivativeDecoration : IRDecoration
kOp = kIROp_UserDefinedBackwardDerivativeDecoration
};
IR_LEAF_ISA(UserDefinedBackwardDerivativeDecoration)
+ IRInst* getBackwardDerivativeFunc() { return getOperand(0); }
};
struct IRTreatAsDifferentiableDecoration : IRDecoration
@@ -3491,6 +3502,11 @@ public:
addDecoration(value, kIROp_ForceInlineDecoration);
}
+ void addAutoDiffOriginalValueDecoration(IRInst* value, IRInst* originalVal)
+ {
+ addDecoration(value, kIROp_AutoDiffOriginalValueDecoration, originalVal);
+ }
+
void addForwardDifferentiableDecoration(IRInst* value)
{
addDecoration(value, kIROp_ForwardDifferentiableDecoration);
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 73d8865ed..fb465f638 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1,6 +1,7 @@
#include "slang-ir-util.h"
#include "slang-ir-insts.h"
#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
namespace Slang
{
@@ -198,6 +199,7 @@ IRInst* hoistValueFromGeneric(IRBuilder& inBuilder, IRInst* value, IRInst*& outS
value->replaceUsesWith(outSpecializedVal);
value->removeAndDeallocate();
}
+ eliminateDeadCode(newGeneric);
return newGeneric;
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index b36a2ebec..e400d0a17 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -6671,6 +6671,8 @@ namespace Slang
case kIROp_ForwardDifferentiate:
case kIROp_BackwardDifferentiate:
+ case kIROp_BackwardDifferentiatePrimal:
+ case kIROp_BackwardDifferentiatePropagate:
return false;
}
@@ -6815,6 +6817,13 @@ namespace Slang
return nullptr;
}
+ IRInst* maybeFindOuterGeneric(IRInst* inst)
+ {
+ auto outerGeneric = findOuterGeneric(inst);
+ if (!outerGeneric) return inst;
+ return outerGeneric;
+ }
+
IRInst* findOuterMostGeneric(IRInst* inst)
{
IRInst* currInst = inst;
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index b4a657545..e22e41f0c 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1758,6 +1758,9 @@ IRInst* findOuterGeneric(IRInst* inst);
// Recursively find the outer most generic container.
IRInst* findOuterMostGeneric(IRInst* inst);
+// Returns `inst` if it is not a generic, otherwise its outer generic.
+IRInst* maybeFindOuterGeneric(IRInst* inst);
+
struct IRSpecialize;
IRGeneric* findSpecializedGeneric(IRSpecialize* specialize);
IRInst* findSpecializeReturnVal(IRSpecialize* specialize);
diff --git a/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang
new file mode 100644
index 000000000..bd0780174
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang
@@ -0,0 +1,61 @@
+// Test calling differentiable function through dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[anyValueSize(16)]
+interface IInterface
+{
+ static float calc(float x);
+}
+
+struct A : IInterface
+{
+ static float calc(float x) { return 1.0; }
+};
+
+struct B : IInterface
+{
+ static float calc(float x) { return 2.0; }
+};
+
+void dsqr<T:IInterface>(T obj, inout DifferentialPair<float> x, float dOut)
+{
+ float diff = 2.0 * x.p * dOut;
+ updateDiff(x, diff);
+}
+
+[BackwardDerivative(dsqr)]
+float sqr<T:IInterface>(T obj, float x)
+{
+ return no_diff(obj.calc(x)) + x * x;
+}
+
+// Use automatically differentiated outer function to triger the primal/propagate func generation logic
+// on a function that has user provided backward derivative.
+[BackwardDifferentiable]
+float sqr_outter<T:IInterface>(T obj, float x)
+{
+ return sqr(obj, x);
+}
+
+//TEST_INPUT: type_conformance A:IInterface = 0
+//TEST_INPUT: type_conformance B:IInterface = 1
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A
+ var p = DifferentialPair<float>(2.0, 1.0);
+ __bwd_diff(sqr_outter)(obj, p, 1.0); // A.calc, expect 4
+ outputBuffer[0] = p.d;
+
+ obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B
+ p = DifferentialPair<float>(1.5, 1.0);
+ __bwd_diff(sqr)(obj, p, 1.0); // A.calc, expect 4
+ outputBuffer[1] = p.d; // B.calc, expect 3
+}
diff --git a/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt
new file mode 100644
index 000000000..780ba6ed4
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+4.000000
+3.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang
new file mode 100644
index 000000000..930c1c82b
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang
@@ -0,0 +1,53 @@
+// Test calling differentiable function through dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[anyValueSize(16)]
+interface IInterface
+{
+ static float calc(float x);
+}
+
+struct A : IInterface
+{
+ static float calc(float x) { return 1.0; }
+};
+
+struct B : IInterface
+{
+ static float calc(float x) { return 2.0; }
+};
+
+DifferentialPair<float> dsqr<T:IInterface>(T obj, DifferentialPair<float> x)
+{
+ float primal = obj.calc(x.p) + x.p * x.p;
+ float diff = 2.0 * x.p * x.d;
+ return diffPair(primal, diff);
+}
+
+[ForwardDerivative(dsqr)]
+float sqr<T:IInterface>(T obj, float x)
+{
+ return no_diff(obj.calc(x)) + x * x;
+}
+
+//TEST_INPUT: type_conformance A:IInterface = 0
+//TEST_INPUT: type_conformance B:IInterface = 1
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A
+ var p = DifferentialPair<float>(2.0, 1.0);
+
+ outputBuffer[0] = __fwd_diff(sqr)(obj, p).d; // A.calc, expect 4
+
+ obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B
+ p = DifferentialPair<float>(1.5, 1.0);
+ outputBuffer[1] = __fwd_diff(sqr)(obj, p).d; // B.calc, expect 3
+}
diff --git a/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt
new file mode 100644
index 000000000..780ba6ed4
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+4.000000
+3.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/language-server/robustness-6.slang b/tests/language-server/robustness-6.slang
new file mode 100644
index 000000000..ef5924cf3
--- /dev/null
+++ b/tests/language-server/robustness-6.slang
@@ -0,0 +1,10 @@
+//TEST:LANG_SERVER:
+//HOVER:4,8
+
+float dsqr<T:II
+
+[ForwardDerivative(dsqr)]
+float sqr<T:IInterface>(T obj, float x)
+{
+ return no_diff(obj.calc(x)) + x * x;
+}
diff --git a/tests/language-server/robustness-6.slang.expected.txt b/tests/language-server/robustness-6.slang.expected.txt
new file mode 100644
index 000000000..d5aa6c8c9
--- /dev/null
+++ b/tests/language-server/robustness-6.slang.expected.txt
@@ -0,0 +1,13 @@
+--------
+range: 3,6 - 3,10
+content:
+```
+func dsqr<T>(T obj, float x) -> float
+```
+
+TEST:LANG_SERVER:
+HOVER:4,8
+
+{REDACTED}.slang(4)
+
+