diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-10 09:01:59 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-10 09:01:59 -0800 |
| commit | 6e7b424953ae6732d4863e887e7e452396095d71 (patch) | |
| tree | 883ca4d168ba679bd7ffe197fa765c8b42b19c6b | |
| parent | df02f3f50f977112ca1fbb148cd48ee41d560f41 (diff) | |
Fix checking of `[BackwardDerivativeOf]` attribute. (#2640)
* Fix checking of `[BackwardDerivativeOf]` attribute.
* Fix crash in `canInstHaveSideEffectAtAddress`.
* Fix.
* Revert fix.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 123 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 3 |
4 files changed, 101 insertions, 29 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 9bda6c3e7..7e8e94d95 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4771,6 +4771,9 @@ namespace Slang List<Expr*> getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc) { + // Note: it isn't always possible to construct original arguments from + // backward propagation arguments because backward propagation function + // may drop certain parameters. List<Expr*> imaginaryArguments; for (auto param : bwdDiffFunc->getParameters()) { @@ -4789,6 +4792,12 @@ namespace Slang } imaginaryArguments.add(arg); } + // Assume the last parameter is `dOut`. + // This is not true if the function returns a non-differentiable value. + // However in that uncommon case we just fail the overload resolution + // and require the user to provide disambiguate themselves. + if (imaginaryArguments.getCount()) + imaginaryArguments.fastRemoveAt(imaginaryArguments.getCount() - 1); return imaginaryArguments; } @@ -4808,41 +4817,99 @@ namespace Slang DeclAssociationKind assocKind, const List<Expr*>& imaginaryArgsToOriginal) { - auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal); - auto resolved = visitor->ResolveInvoke(invokeExpr); - if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + DeclRef<Decl> calleeDeclRef; + auto calleeDeclRefExpr = as<DeclRefExpr>(derivativeOfAttr->funcExpr); + if (!calleeDeclRefExpr) { - if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal); + auto resolved = visitor->ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + calleeDeclRefExpr = as<DeclRefExpr>(resolvedInvoke->functionExpr); + } + } + if (!calleeDeclRefExpr) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + calleeDeclRef = calleeDeclRefExpr->declRef; + if (auto calleeGenDecl = as<GenericDecl>(calleeDeclRef.getDecl())) + { + auto parentGenericDecl = as<GenericDecl>(funcDecl->parentDecl); + if (!parentGenericDecl) { - auto calleeDecl = calleeDeclRef->declRef.getDecl(); - if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeDecl)) + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; + } + FunctionDeclBase* funcReturnVal = nullptr; + List<Val*> args; + for (auto mm : parentGenericDecl->members) + { + if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm)) { - // The primal function already has a `[*Derivative]` attribute, this is invalid. - visitor->getSink()->diagnose( - derivativeOfAttr, - Diagnostics::declAlreadyHasAttribute, - calleeDeclRef->declRef, - getDerivativeAttrName<TDerivativeAttr>()); - visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl()); + args.add(DeclRefType::create(visitor->getASTBuilder(), DeclRef<Decl>(genericTypeParamDecl, nullptr))); } - derivativeOfAttr->funcExpr = calleeDeclRef; - auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); - derivativeAttr->loc = derivativeOfAttr->loc; - auto outterGeneric = visitor->GetOuterGeneric(funcDecl); - auto declRef = - DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); - auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); - declRefExpr->type.type = nullptr; - derivativeAttr->args.add(declRefExpr); - derivativeAttr->funcExpr = declRefExpr; - checkDerivativeAttribute(visitor, as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), derivativeAttr); - derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; - derivativeAttr->funcExpr = nullptr; - visitor->getShared()->registerAssociatedDecl(calleeDeclRef->declRef.getDecl(), assocKind, funcDecl); + else if (auto genericValueParamDecl = as<GenericValueParamDecl>(mm)) + { + args.add(visitor->getASTBuilder()->getOrCreate<GenericParamIntVal>( + genericValueParamDecl->getType(), + genericValueParamDecl, nullptr)); + } + } + auto funcs = calleeGenDecl->getMembersOfType<FunctionDeclBase>(); + if (funcs.isEmpty()) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); return; } + funcReturnVal = funcs.getFirst(); + if (funcReturnVal) + { + auto subst = visitor->getASTBuilder()->getOrCreateGenericSubstitution(calleeGenDecl, args, nullptr); + calleeDeclRef.decl = funcReturnVal; + calleeDeclRef.substitutions = subst; + calleeDeclRefExpr = as<DeclRefExpr>(visitor->ConstructDeclRefExpr( + calleeDeclRef, nullptr, derivativeOfAttr->loc, nullptr)); + } + else + { + calleeDeclRef = DeclRef<Decl>(); + calleeDeclRefExpr = nullptr; + } + } + + auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl()); + if (!calleeFunc) + { + visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); + return; } - visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::invalidCustomDerivative); + + if (auto existingModifier = _findModifier<TDerivativeAttr>(calleeFunc)) + { + // The primal function already has a `[*Derivative]` attribute, this is invalid. + visitor->getSink()->diagnose( + derivativeOfAttr, + Diagnostics::declAlreadyHasAttribute, + calleeDeclRef, + getDerivativeAttrName<TDerivativeAttr>()); + visitor->getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef.getDecl()); + } + derivativeOfAttr->funcExpr = calleeDeclRefExpr; + auto derivativeAttr = visitor->getASTBuilder()->create<TDerivativeAttr>(); + derivativeAttr->loc = derivativeOfAttr->loc; + auto outterGeneric = visitor->GetOuterGeneric(funcDecl); + auto declRef = + DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRefExpr = visitor->ConstructDeclRefExpr(declRef, nullptr, derivativeOfAttr->loc, nullptr); + declRefExpr->type.type = nullptr; + derivativeAttr->args.add(declRefExpr); + derivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(visitor, calleeFunc, derivativeAttr); + derivativeOfAttr->backDeclRef = derivativeAttr->funcExpr; + derivativeAttr->funcExpr = nullptr; + visitor->getShared()->registerAssociatedDecl(calleeDeclRef.getDecl(), assocKind, funcDecl); } static void checkDerivativeAttribute(SemanticsVisitor* visitor, FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 0d4088d75..d3731756a 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -349,6 +349,7 @@ DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original defi DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.") DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.") +DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.") // Enums diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 3e7e346d2..25f6c3964 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -256,7 +256,8 @@ struct ExtractPrimalFuncContext auto structKey = genTypeBuilder.createStructKey(); genTypeBuilder.setInsertInto(structType); - if (isChildInstOf(fieldType->getParent(), structType->getParent())) + if (fieldType->getParent() != structType->getParent() && + isChildInstOf(fieldType->getParent(), structType->getParent())) { IRCloneEnv cloneEnv; fieldType = cloneInst(&cloneEnv, &genTypeBuilder, fieldType); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index af6fd8ac4..942c8f2f8 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -364,6 +364,8 @@ bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IR bool isPtrLikeOrHandleType(IRInst* type) { + if (!type) + return false; switch (type->getOp()) { case kIROp_ComPtrType: @@ -413,6 +415,7 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I // If any pointer typed argument of the call inst may overlap addr, return true. for (UInt i = 0; i < call->getArgCount(); i++) { + SLANG_RELEASE_ASSERT(call->getArg(i)->getDataType()); if (isPtrLikeOrHandleType(call->getArg(i)->getDataType())) { if (canAddressesPotentiallyAlias(func, call->getArg(i), addr)) |
