diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-09 19:19:17 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-09 19:19:17 -0800 |
| commit | 004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch) | |
| tree | cbc942746bab043da0eb5298993d95f9665dfddf /source/slang/slang-check-modifier.cpp | |
| parent | cedd93690c63188cf98e452c9d104cf51aad6c4e (diff) | |
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute.
* Fix handling around phi nodes.
* Fixes.
* Remove IR opcode for ForwardDerivativeOfDecoration.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-modifier.cpp')
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 100 |
1 files changed, 19 insertions, 81 deletions
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index d8b05198c..b8ac21e2d 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -617,92 +617,30 @@ namespace Slang getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); return false; } - - // Either diffExpr has a function type, or it is a reference to a generic. - if (!as<FuncType>(diffExpr->type) && - !(as<DeclRefExpr>(diffExpr) && - as<DeclRefExpr>(diffExpr)->declRef.as<GenericDecl>().getDecl() != nullptr)) - { - return false; - } - - auto diffDeclRef = as<DeclRefExpr>(diffExpr)->declRef; - - UCount genericLevels = 0; - // If we've grabbed the outer generic for some reason, - // recursively construct GenericAppExpr<...>(generic) - // and check that to get a specialized func. - // - while (diffDeclRef.as<GenericDecl>().getDecl() != nullptr) - { - // Forward to the inner decl - diffDeclRef = makeDeclRef(diffDeclRef.as<GenericDecl>().getDecl()->inner); - - // Increment counter. - genericLevels += 1; - } - - auto targetGeneric = as<GenericDecl>(as<Decl>(attrTarget)->parentDecl); - auto diffGeneric = as<GenericDecl>(diffDeclRef.getDecl()->parentDecl); - Expr* currentDiffExpr = diffExpr; - - // Go back through each level, and use generic declarations in the - // target's generic scope as arguments for the diff function's generic. + // We store the partially checked funcExpr in the attribute, and + // rely on `ResolveInvoke` to resolve it to the actual function decl. + // The call to `ResolveInvoke` is deferred until we are checking the + // body of the function. // - for (UIndex ii = 0; ii < genericLevels; ii++) - { - // Nest our expression inside a GenericAppExpr - auto genericAppExpr = getASTBuilder()->create<GenericAppExpr>(); - genericAppExpr->functionExpr = currentDiffExpr; - - // Construct references to the generic args in the current scope. - // TODO: Probably an easier way to do this. - for (auto member : targetGeneric->members) - { - if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) - { - genericAppExpr->arguments.add( - ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr)); - } - else if (auto valueParamDecl = as<GenericValueParamDecl>(member)) - { - genericAppExpr->arguments.add( - ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr)); - } - } - - // Set our generic-app-expr as the new expr. - currentDiffExpr = genericAppExpr; - - // Peel the generic layer. - diffGeneric = as<GenericDecl>(diffGeneric->parentDecl); - targetGeneric = as<GenericDecl>(targetGeneric->parentDecl); - } - - if ((diffGeneric == nullptr && targetGeneric != nullptr) || - (targetGeneric == nullptr && diffGeneric != nullptr)) - { - //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget); - SLANG_UNEXPECTED(""); - } - - // If we had to change currentDiffExpr, then re-check the expr. - if (!currentDiffExpr->type) - { - currentDiffExpr = CheckTerm(currentDiffExpr); - } + // Set type to null to indicate that this needs expr needs to be further resolved. + diffExpr->type.type = nullptr; + forwardDerivativeAttr->funcExpr = diffExpr; + } + else if (auto forwardDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + SLANG_ASSERT(as<Decl>(attrTarget)); // Ensure that the argument is a reference to a function definition or declaration. - auto currentDiffDeclRefExpr = as<DeclRefExpr>(currentDiffExpr); - auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef; - - if (!as<FuncType>(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc()))) + auto primalFunc = CheckTerm(attr->args[0]); + if (primalFunc->type == getASTBuilder()->getErrorType()) { - getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef); + // Could not resolve the term. + getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); + return false; } - - // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr) - forwardDerivativeAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); + + forwardDerivativeOfAttr->funcExpr = primalFunc; } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { |
