summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-10 09:01:59 -0800
committerGitHub <noreply@github.com>2023-02-10 09:01:59 -0800
commit6e7b424953ae6732d4863e887e7e452396095d71 (patch)
tree883ca4d168ba679bd7ffe197fa765c8b42b19c6b
parentdf02f3f50f977112ca1fbb148cd48ee41d560f41 (diff)
Fix checking of `[BackwardDerivativeOf]` attribute. (#2640)
* Fix checking of `[BackwardDerivativeOf]` attribute. * Fix crash in `canInstHaveSideEffectAtAddress`. * Fix. * Revert fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-check-decl.cpp123
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp3
-rw-r--r--source/slang/slang-ir-util.cpp3
4 files changed, 101 insertions, 29 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 9bda6c3e7..7e8e94d95 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -4771,6 +4771,9 @@ namespace Slang
List<Expr*> getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc)
{
+ // Note: it isn't always possible to construct original arguments from
+ // backward propagation arguments because backward propagation function
+ // may drop certain parameters.
List<Expr*> imaginaryArguments;
for (auto param : bwdDiffFunc->getParameters())
{
@@ -4789,6 +4792,12 @@ namespace Slang
}
imaginaryArguments.add(arg);
}
+ // Assume the last parameter is `dOut`.
+ // This is not true if the function returns a non-differentiable value.
+ // However in that uncommon case we just fail the overload resolution
+ // and require the user to provide disambiguate themselves.
+ if (imaginaryArguments.getCount())
+ imaginaryArguments.fastRemoveAt(imaginaryArguments.getCount() - 1);
return imaginaryArguments;
}
@@ -4808,41 +4817,99 @@ namespace Slang
DeclAssociationKind assocKind,
const List<Expr*>& imaginaryArgsToOriginal)
{
- auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal);
- auto resolved = visitor->ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ DeclRef<Decl> calleeDeclRef;
+ auto calleeDeclRefExpr = as<DeclRefExpr>(derivativeOfAttr->funcExpr);
+ if (!calleeDeclRefExpr)
{
- if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal);
+ auto resolved = visitor->ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ calleeDeclRefExpr = as<DeclRefExpr>(resolvedInvoke->functionExpr);
+ }
+ }
+ if (!calleeDeclRefExpr)
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+ calleeDeclRef = calleeDeclRefExpr->declRef;
+ if (auto calleeGenDecl = as<GenericDecl>(calleeDeclRef.getDecl()))
+ {
+ auto parentGenericDecl = as<GenericDecl>(funcDecl->parentDecl);
+ if (!parentGenericDecl)
{
- auto calleeDecl = calleeDeclRef->declRef.getDecl();
- if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeDecl))
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+ FunctionDeclBase* funcReturnVal = nullptr;
+ List<Val*> args;
+ for (auto mm : parentGenericDecl->members)
+ {
+ if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm))
{
- // The primal function already has a `[*Derivative]` attribute, this is invalid.
- visitor->getSink()->diagnose(
- derivativeOfAttr,
- Diagnostics::declAlreadyHasAttribute,
- calleeDeclRef->declRef,
- getDerivativeAttrName<TDerivativeAttr>());
- visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl());
+ args.add(DeclRefType::create(visitor->getASTBuilder(), DeclRef<Decl>(genericTypeParamDecl, nullptr)));
}
- derivativeOfAttr->funcExpr = calleeDeclRef;
- auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>();
- derivativeAttr->loc = derivativeOfAttr->loc;
- auto outterGeneric = visitor->GetOuterGeneric(funcDecl);
- auto declRef =
- DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
- auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr);
- declRefExpr->type.type = nullptr;
- derivativeAttr->args.add(declRefExpr);
- derivativeAttr->funcExpr = declRefExpr;
- checkDerivativeAttribute(visitor, as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), derivativeAttr);
- derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr;
- derivativeAttr->funcExpr = nullptr;
- visitor->getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), assocKind, funcDecl);
+ else if (auto genericValueParamDecl = as<GenericValueParamDecl>(mm))
+ {
+ args.add(visitor->getASTBuilder()->getOrCreate<GenericParamIntVal>(
+ genericValueParamDecl->getType(),
+ genericValueParamDecl, nullptr));
+ }
+ }
+ auto funcs = calleeGenDecl->getMembersOfType<FunctionDeclBase>();
+ if (funcs.isEmpty())
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
return;
}
+ funcReturnVal = funcs.getFirst();
+ if (funcReturnVal)
+ {
+ auto subst = visitor->getASTBuilder()->getOrCreateGenericSubstitution(calleeGenDecl, args, nullptr);
+ calleeDeclRef.decl = funcReturnVal;
+ calleeDeclRef.substitutions = subst;
+ calleeDeclRefExpr = as<DeclRefExpr>(visitor->ConstructDeclRefExpr(
+ calleeDeclRef, nullptr, derivativeOfAttr->loc, nullptr));
+ }
+ else
+ {
+ calleeDeclRef = DeclRef<Decl>();
+ calleeDeclRefExpr = nullptr;
+ }
+ }
+
+ auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
+ if (!calleeFunc)
+ {
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
}
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::invalidCustomDerivative);
+
+ if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc))
+ {
+ // The primal function already has a `[*Derivative]` attribute, this is invalid.
+ visitor->getSink()->diagnose(
+ derivativeOfAttr,
+ Diagnostics::declAlreadyHasAttribute,
+ calleeDeclRef,
+ getDerivativeAttrName<TDerivativeAttr>());
+ visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl());
+ }
+ derivativeOfAttr->funcExpr = calleeDeclRefExpr;
+ auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>();
+ derivativeAttr->loc = derivativeOfAttr->loc;
+ auto outterGeneric = visitor->GetOuterGeneric(funcDecl);
+ auto declRef =
+ DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
+ auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr);
+ declRefExpr->type.type = nullptr;
+ derivativeAttr->args.add(declRefExpr);
+ derivativeAttr->funcExpr = declRefExpr;
+ checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr);
+ derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr;
+ derivativeAttr->funcExpr = nullptr;
+ visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl);
}
static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 0d4088d75..d3731756a 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -349,6 +349,7 @@ DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original defi
DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.")
DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.")
+DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.")
// Enums
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 3e7e346d2..25f6c3964 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -256,7 +256,8 @@ struct ExtractPrimalFuncContext
auto structKey = genTypeBuilder.createStructKey();
genTypeBuilder.setInsertInto(structType);
- if (isChildInstOf(fieldType->getParent(), structType->getParent()))
+ if (fieldType->getParent() != structType->getParent() &&
+ isChildInstOf(fieldType->getParent(), structType->getParent()))
{
IRCloneEnv cloneEnv;
fieldType = cloneInst(&cloneEnv, &genTypeBuilder, fieldType);
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index af6fd8ac4..942c8f2f8 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -364,6 +364,8 @@ bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IR
bool isPtrLikeOrHandleType(IRInst* type)
{
+ if (!type)
+ return false;
switch (type->getOp())
{
case kIROp_ComPtrType:
@@ -413,6 +415,7 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I
// If any pointer typed argument of the call inst may overlap addr, return true.
for (UInt i = 0; i < call->getArgCount(); i++)
{
+ SLANG_RELEASE_ASSERT(call->getArg(i)->getDataType());
if (isPtrLikeOrHandleType(call->getArg(i)->getDataType()))
{
if (canAddressesPotentiallyAlias(func, call->getArg(i), addr))