From 6e7b424953ae6732d4863e887e7e452396095d71 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 10 Feb 2023 09:01:59 -0800 Subject: Fix checking of `[BackwardDerivativeOf]` attribute. (#2640) * Fix checking of `[BackwardDerivativeOf]` attribute. * Fix crash in `canInstHaveSideEffectAtAddress`. * Fix. * Revert fix. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 123 ++++++++++++++++++++++++------- source/slang/slang-diagnostic-defs.h | 1 + source/slang/slang-ir-autodiff-unzip.cpp | 3 +- source/slang/slang-ir-util.cpp | 3 + 4 files changed, 101 insertions(+), 29 deletions(-) (limited to 'source') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 9bda6c3e7..7e8e94d95 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4771,6 +4771,9 @@ namespace Slang List getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc) { + // Note: it isn't always possible to construct original arguments from + // backward propagation arguments because backward propagation function + // may drop certain parameters. List imaginaryArguments; for (auto param : bwdDiffFunc->getParameters()) { @@ -4789,6 +4792,12 @@ namespace Slang } imaginaryArguments.add(arg); } + // Assume the last parameter is `dOut`. + // This is not true if the function returns a non-differentiable value. + // However in that uncommon case we just fail the overload resolution + // and require the user to provide disambiguate themselves. + if (imaginaryArguments.getCount()) + imaginaryArguments.fastRemoveAt(imaginaryArguments.getCount() - 1); return imaginaryArguments; } @@ -4808,41 +4817,99 @@ namespace Slang DeclAssociationKind assocKind, const List& imaginaryArgsToOriginal) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as(resolved)) + DeclRef calleeDeclRef; + auto calleeDeclRefExpr = as(derivativeOfAttr->funcExpr); + if (!calleeDeclRefExpr) { - if (auto calleeDeclRef = as(resolvedInvoke->functionExpr)) + auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal); + auto resolved = visitor->ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as(resolved)) + { + calleeDeclRefExpr = as(resolvedInvoke->functionExpr); + } + } + if (!calleeDeclRefExpr) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + calleeDeclRef = calleeDeclRefExpr->declRef; + if (auto calleeGenDecl = as(calleeDeclRef.getDecl())) + { + auto parentGenericDecl = as(funcDecl->parentDecl); + if (!parentGenericDecl) { - auto calleeDecl = calleeDeclRef->declRef.getDecl(); - if (auto existingModifier = _findModifier(calleeDecl)) + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + FunctionDeclBase* funcReturnVal = nullptr; + List args; + for (auto mm : parentGenericDecl->members) + { + if (auto genericTypeParamDecl = as(mm)) { - // The primal function already has a `[*Derivative]` attribute, this is invalid. - visitor->getSink()->diagnose( - derivativeOfAttr, - Diagnostics::declAlreadyHasAttribute, - calleeDeclRef->declRef, - getDerivativeAttrName()); - visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl()); + args.add(DeclRefType::create(visitor->getASTBuilder(), DeclRef(genericTypeParamDecl, nullptr))); } - derivativeOfAttr->funcExpr = calleeDeclRef; - auto derivativeAttr = visitor->getASTBuilder()->create(); - derivativeAttr->loc = derivativeOfAttr->loc; - auto outterGeneric = visitor->GetOuterGeneric(funcDecl); - auto declRef = - DeclRef((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); - auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); - declRefExpr->type.type = nullptr; - derivativeAttr->args.add(declRefExpr); - derivativeAttr->funcExpr = declRefExpr; - checkDerivativeAttribute(visitor, as(calleeDeclRef->declRef.getDecl()), derivativeAttr); - derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; - derivativeAttr->funcExpr = nullptr; - visitor->getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), assocKind, funcDecl); + else if (auto genericValueParamDecl = as(mm)) + { + args.add(visitor->getASTBuilder()->getOrCreate( + genericValueParamDecl->getType(), + genericValueParamDecl, nullptr)); + } + } + auto funcs = calleeGenDecl->getMembersOfType(); + if (funcs.isEmpty()) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } + funcReturnVal = funcs.getFirst(); + if (funcReturnVal) + { + auto subst = visitor->getASTBuilder()->getOrCreateGenericSubstitution(calleeGenDecl, args, nullptr); + calleeDeclRef.decl = funcReturnVal; + calleeDeclRef.substitutions = subst; + calleeDeclRefExpr = as(visitor->ConstructDeclRefExpr( + calleeDeclRef, nullptr, derivativeOfAttr->loc, nullptr)); + } + else + { + calleeDeclRef = DeclRef(); + calleeDeclRefExpr = nullptr; + } + } + + auto calleeFunc = as(calleeDeclRef.getDecl()); + if (!calleeFunc) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; } - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::invalidCustomDerivative); + + if (auto existingModifier = _findModifier(calleeFunc)) + { + // The primal function already has a `[*Derivative]` attribute, this is invalid. + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::declAlreadyHasAttribute, + calleeDeclRef, + getDerivativeAttrName()); + visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl()); + } + derivativeOfAttr->funcExpr = calleeDeclRefExpr; + auto derivativeAttr = visitor->getASTBuilder()->create(); + derivativeAttr->loc = derivativeOfAttr->loc; + auto outterGeneric = visitor->GetOuterGeneric(funcDecl); + auto declRef = + DeclRef((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); + declRefExpr->type.type = nullptr; + derivativeAttr->args.add(declRefExpr); + derivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); + derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; + derivativeAttr->funcExpr = nullptr; + visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 0d4088d75..d3731756a 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -349,6 +349,7 @@ DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original defi DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.") DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.") +DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.") // Enums diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 3e7e346d2..25f6c3964 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -256,7 +256,8 @@ struct ExtractPrimalFuncContext auto structKey = genTypeBuilder.createStructKey(); genTypeBuilder.setInsertInto(structType); - if (isChildInstOf(fieldType->getParent(), structType->getParent())) + if (fieldType->getParent() != structType->getParent() && + isChildInstOf(fieldType->getParent(), structType->getParent())) { IRCloneEnv cloneEnv; fieldType = cloneInst(&cloneEnv, &genTypeBuilder, fieldType); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index af6fd8ac4..942c8f2f8 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -364,6 +364,8 @@ bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IR bool isPtrLikeOrHandleType(IRInst* type) { + if (!type) + return false; switch (type->getOp()) { case kIROp_ComPtrType: @@ -413,6 +415,7 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I // If any pointer typed argument of the call inst may overlap addr, return true. for (UInt i = 0; i < call->getArgCount(); i++) { + SLANG_RELEASE_ASSERT(call->getArg(i)->getDataType()); if (isPtrLikeOrHandleType(call->getArg(i)->getDataType())) { if (canAddressesPotentiallyAlias(func, call->getArg(i), addr)) -- cgit v1.2.3