diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-26 22:21:29 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-26 19:21:29 -0700 |
| commit | f7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch) | |
| tree | 574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/slang-check-modifier.cpp | |
| parent | 939be44ca23476e622dfb24a592383fe2a1da61f (diff) | |
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/slang-check-modifier.cpp')
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 97 |
1 files changed, 93 insertions, 4 deletions
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 7e11ee3ca..20e5d5378 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -628,13 +628,102 @@ namespace Slang else if (auto customJVPAttr = as<CustomJVPAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); - + SLANG_ASSERT(as<Decl>(attrTarget)); + // Ensure that the argument is a reference to a function definition or declaration. - auto funcExpr = as<DeclRefExpr>(CheckTerm(attr->args[0])); - if (!as<FuncType>(funcExpr->type)) + auto diffExpr = CheckTerm(attr->args[0]); + if (diffExpr->type == getASTBuilder()->getErrorType()) + { + // Could not resolve the term. + getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); + return false; + } + + // Either diffExpr has a function type, or it is a reference to a generic. + if (!as<FuncType>(diffExpr->type) && + !(as<DeclRefExpr>(diffExpr) && + as<DeclRefExpr>(diffExpr)->declRef.as<GenericDecl>().getDecl() != nullptr)) + { return false; + } - customJVPAttr->funcDeclRef = funcExpr; + auto diffDeclRef = as<DeclRefExpr>(diffExpr)->declRef; + + UCount genericLevels = 0; + // If we've grabbed the outer generic for some reason, + // recursively construct GenericAppExpr<...>(generic) + // and check that to get a specialized func. + // + while (diffDeclRef.as<GenericDecl>().getDecl() != nullptr) + { + // Forward to the inner decl + diffDeclRef = makeDeclRef(diffDeclRef.as<GenericDecl>().getDecl()->inner); + + // Increment counter. + genericLevels += 1; + } + + auto targetGeneric = as<GenericDecl>(as<Decl>(attrTarget)->parentDecl); + auto diffGeneric = as<GenericDecl>(diffDeclRef.getDecl()->parentDecl); + Expr* currentDiffExpr = diffExpr; + + // Go back through each level, and use generic declarations in the + // target's generic scope as arguments for the diff function's generic. + // + for (UIndex ii = 0; ii < genericLevels; ii++) + { + // Nest our expression inside a GenericAppExpr + auto genericAppExpr = getASTBuilder()->create<GenericAppExpr>(); + genericAppExpr->functionExpr = currentDiffExpr; + + // Construct references to the generic args in the current scope. + // TODO: Probably an easier way to do this. + for (auto member : targetGeneric->members) + { + if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) + { + genericAppExpr->arguments.add( + ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr)); + } + else if (auto valueParamDecl = as<GenericValueParamDecl>(member)) + { + genericAppExpr->arguments.add( + ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr)); + } + } + + // Set our generic-app-expr as the new expr. + currentDiffExpr = genericAppExpr; + + // Peel the generic layer. + diffGeneric = as<GenericDecl>(diffGeneric->parentDecl); + targetGeneric = as<GenericDecl>(targetGeneric->parentDecl); + } + + if ((diffGeneric == nullptr && targetGeneric != nullptr) || + (targetGeneric == nullptr && diffGeneric != nullptr)) + { + //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget); + SLANG_UNEXPECTED(""); + } + + // If we had to change currentDiffExpr, then re-check the expr. + if (!currentDiffExpr->type) + { + currentDiffExpr = CheckTerm(currentDiffExpr); + } + + // Ensure that the argument is a reference to a function definition or declaration. + auto currentDiffDeclRefExpr = as<DeclRefExpr>(currentDiffExpr); + auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef; + + if (!as<FuncType>(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc()))) + { + getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef); + } + + // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr) + customJVPAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { |
