From 004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 9 Nov 2022 19:19:17 -0800 Subject: Add `[ForwardDerivativeOf]` attribute. (#2501) * Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He --- source/slang/slang-check-modifier.cpp | 100 +++++++--------------------------- 1 file changed, 19 insertions(+), 81 deletions(-) (limited to 'source/slang/slang-check-modifier.cpp') diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index d8b05198c..b8ac21e2d 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -617,92 +617,30 @@ namespace Slang getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as(attrTarget)); return false; } - - // Either diffExpr has a function type, or it is a reference to a generic. - if (!as(diffExpr->type) && - !(as(diffExpr) && - as(diffExpr)->declRef.as().getDecl() != nullptr)) - { - return false; - } - - auto diffDeclRef = as(diffExpr)->declRef; - - UCount genericLevels = 0; - // If we've grabbed the outer generic for some reason, - // recursively construct GenericAppExpr<...>(generic) - // and check that to get a specialized func. - // - while (diffDeclRef.as().getDecl() != nullptr) - { - // Forward to the inner decl - diffDeclRef = makeDeclRef(diffDeclRef.as().getDecl()->inner); - - // Increment counter. - genericLevels += 1; - } - - auto targetGeneric = as(as(attrTarget)->parentDecl); - auto diffGeneric = as(diffDeclRef.getDecl()->parentDecl); - Expr* currentDiffExpr = diffExpr; - - // Go back through each level, and use generic declarations in the - // target's generic scope as arguments for the diff function's generic. + // We store the partially checked funcExpr in the attribute, and + // rely on `ResolveInvoke` to resolve it to the actual function decl. + // The call to `ResolveInvoke` is deferred until we are checking the + // body of the function. // - for (UIndex ii = 0; ii < genericLevels; ii++) - { - // Nest our expression inside a GenericAppExpr - auto genericAppExpr = getASTBuilder()->create(); - genericAppExpr->functionExpr = currentDiffExpr; - - // Construct references to the generic args in the current scope. - // TODO: Probably an easier way to do this. - for (auto member : targetGeneric->members) - { - if (auto typeParamDecl = as(member)) - { - genericAppExpr->arguments.add( - ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr)); - } - else if (auto valueParamDecl = as(member)) - { - genericAppExpr->arguments.add( - ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr)); - } - } - - // Set our generic-app-expr as the new expr. - currentDiffExpr = genericAppExpr; - - // Peel the generic layer. - diffGeneric = as(diffGeneric->parentDecl); - targetGeneric = as(targetGeneric->parentDecl); - } - - if ((diffGeneric == nullptr && targetGeneric != nullptr) || - (targetGeneric == nullptr && diffGeneric != nullptr)) - { - //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget); - SLANG_UNEXPECTED(""); - } - - // If we had to change currentDiffExpr, then re-check the expr. - if (!currentDiffExpr->type) - { - currentDiffExpr = CheckTerm(currentDiffExpr); - } + // Set type to null to indicate that this needs expr needs to be further resolved. + diffExpr->type.type = nullptr; + forwardDerivativeAttr->funcExpr = diffExpr; + } + else if (auto forwardDerivativeOfAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + SLANG_ASSERT(as(attrTarget)); // Ensure that the argument is a reference to a function definition or declaration. - auto currentDiffDeclRefExpr = as(currentDiffExpr); - auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef; - - if (!as(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc()))) + auto primalFunc = CheckTerm(attr->args[0]); + if (primalFunc->type == getASTBuilder()->getErrorType()) { - getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef); + // Could not resolve the term. + getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as(attrTarget)); + return false; } - - // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr) - forwardDerivativeAttr->funcDeclRef = as(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); + + forwardDerivativeOfAttr->funcExpr = primalFunc; } else if (auto comInterfaceAttr = as(attr)) { -- cgit v1.2.3