summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp4
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-addr-inst-elimination.cpp21
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp40
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp8
-rw-r--r--source/slang/slang-ir-autodiff.cpp1
-rw-r--r--source/slang/slang-lower-to-ir.cpp11
7 files changed, 54 insertions, 33 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 6083ce9c0..6a32f59d3 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -5638,6 +5638,10 @@ namespace Slang
bool isDiffFunc = false;
if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>())
{
+ if (GetOuterGeneric(decl))
+ {
+ getSink()->diagnose(decl, Diagnostics::differentiableGenericInterfaceMethodNotSupported);
+ }
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
cloneModifiers(reqDecl, decl);
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 2401b6e58..e3e9cfc44 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -352,6 +352,8 @@ DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative att
DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.")
DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.")
+DIAGNOSTIC(31148, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.")
+
// Enums
DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'")
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
index 6715f2c6a..16bd67f66 100644
--- a/source/slang/slang-ir-addr-inst-elimination.cpp
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -99,22 +99,11 @@ struct AddressInstEliminationContext
IRBuilder builder(module);
builder.setInsertBefore(call);
auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType());
- auto callee = getResolvedInstForDecorations(call->getCallee());
- auto funcType = as<IRFuncType>(callee->getFullType());
- SLANG_RELEASE_ASSERT(funcType);
- UInt paramIndex = (UInt)(use - call->getOperands() - 1);
- SLANG_RELEASE_ASSERT(call->getArg(paramIndex) == addr);
- if (!as<IROutType>(funcType->getParamType(paramIndex)))
- {
- builder.emitStore(tempVar, getValue(builder, addr));
- }
- else
- {
- builder.emitStore(
- tempVar,
- builder.emitDefaultConstruct(
- as<IRPtrTypeBase>(tempVar->getDataType())->getValueType()));
- }
+
+ // Store the initial value of the mutable argument into temp var.
+ // If this is an `out` var, the initial value will be undefined,
+ // which will get cleaned up later into a `defaultConstruct`.
+ builder.emitStore(tempVar, getValue(builder, addr));
builder.setInsertAfter(call);
storeValue(builder, addr, builder.emitLoad(tempVar));
use->set(tempVar);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 3f31f1463..869f8920c 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -510,7 +510,7 @@ IRInst* tryFindPrimalSubstitute(IRBuilder* builder, IRInst* callee)
{
auto innerGen = as<IRGeneric>(specialize->getBase());
if (!innerGen)
- return nullptr;
+ return callee;
auto innerFunc = findGenericReturnVal(innerGen);
if (auto decor = innerFunc->findDecoration<IRPrimalSubstituteDecoration>())
{
@@ -553,7 +553,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
return InstPair(nullptr, nullptr);
}
- auto primalCallee = lookupPrimalInst(builder, origCallee, origCallee);
+ auto primalCallee = findOrTranscribePrimalInst(builder, origCallee);
auto substPrimalCallee = tryFindPrimalSubstitute(builder, primalCallee);
IRInst* diffCallee = nullptr;
@@ -563,7 +563,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
else
{
- instMapD.TryGetValue(substPrimalCallee, diffCallee);
+ diffCallee = findOrTranscribeDiffInst(builder, origCallee);
primalCallee = substPrimalCallee;
}
@@ -904,17 +904,32 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec
IRInst* diffBase = nullptr;
if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase))
{
- List<IRInst*> args;
- for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
+ if (diffBase)
{
- args.add(primalSpecialize->getArg(i));
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
+ {
+ args.add(primalSpecialize->getArg(i));
+ }
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
+ }
+ else
+ {
+ return InstPair(primalSpecialize, nullptr);
}
- auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
- return InstPair(primalSpecialize, diffSpecialize);
}
auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
+
+ // Right now we don't support transcribing a differentiable callee that is a specialize of a interface lookup
+ // (calling differentiable generic interface method). To support it, we need to recursively transcribe the
+ // specialization base here.
+
+ if (!genericInnerVal)
+ return InstPair(primalSpecialize, nullptr);
+
// Look for an IRForwardDerivativeDecoration on the specialize inst.
// (Normally, this would be on the inner IRFunc, but in this case only the JVP func
// can be specialized, so we put a decoration on the IRSpecialize)
@@ -963,10 +978,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec
builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
- else
- {
- return InstPair(primalSpecialize, nullptr);
- }
+ return InstPair(primalSpecialize, nullptr);
}
InstPair ForwardDiffTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
@@ -1433,6 +1445,8 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
IRFunc* primalFunc = origFunc;
+ maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);
+
differentiableTypeConformanceContext.setFunc(origFunc);
primalFunc = origFunc;
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 552ac762c..9cbea7873 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -594,7 +594,11 @@ void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivative
}
else
{
- cloneDecoration(udfDecor, origFunc);
+ auto udfDictDecor = derivative->findDecoration< IRDifferentiableTypeDictionaryDecoration>();
+ if (udfDictDecor)
+ {
+ cloneDecoration(udfDictDecor, origFunc);
+ }
}
}
@@ -977,6 +981,8 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
if (auto innerFunc = as<IRFunc>(innerVal))
{
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc);
+ if (!innerFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ return InstPair(origGeneric, nullptr);
differentiableTypeConformanceContext.setFunc(innerFunc);
}
else if (auto funcType = as<IRFuncType>(innerVal))
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index f173aaa8b..1909f860c 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -368,6 +368,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
parentFunc = func;
+
auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
SLANG_RELEASE_ASSERT(decor);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index f84f17886..9c27beb58 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8484,10 +8484,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
funcExpr = udAttr->funcExpr;
else if (auto primalAttr = as<PrimalSubstituteAttribute>(modifier))
funcExpr = primalAttr->funcExpr;
+ DeclRefExpr* declRefExpr = as<DeclRefExpr>(funcExpr);
+ auto funcType = lowerType(subContext, funcExpr->type);
+ auto loweredVal = emitDeclRef(
+ subContext,
+ declRefExpr->declRef,
+ funcType);
+
+ SLANG_RELEASE_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
- auto loweredVal = lowerRValueExpr(subContext, funcExpr);
-
- SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
IRInst* derivativeFunc = loweredVal.val;
if (as<ForwardDerivativeAttribute>(modifier))