From f7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 26 Oct 2022 22:21:29 -0400 Subject: Adding a differentiable standard library (#2465) --- source/slang/slang-check-modifier.cpp | 97 +++++++++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 4 deletions(-) (limited to 'source/slang/slang-check-modifier.cpp') diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 7e11ee3ca..20e5d5378 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -628,13 +628,102 @@ namespace Slang else if (auto customJVPAttr = as(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); - + SLANG_ASSERT(as(attrTarget)); + // Ensure that the argument is a reference to a function definition or declaration. - auto funcExpr = as(CheckTerm(attr->args[0])); - if (!as(funcExpr->type)) + auto diffExpr = CheckTerm(attr->args[0]); + if (diffExpr->type == getASTBuilder()->getErrorType()) + { + // Could not resolve the term. + getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as(attrTarget)); + return false; + } + + // Either diffExpr has a function type, or it is a reference to a generic. + if (!as(diffExpr->type) && + !(as(diffExpr) && + as(diffExpr)->declRef.as().getDecl() != nullptr)) + { return false; + } - customJVPAttr->funcDeclRef = funcExpr; + auto diffDeclRef = as(diffExpr)->declRef; + + UCount genericLevels = 0; + // If we've grabbed the outer generic for some reason, + // recursively construct GenericAppExpr<...>(generic) + // and check that to get a specialized func. + // + while (diffDeclRef.as().getDecl() != nullptr) + { + // Forward to the inner decl + diffDeclRef = makeDeclRef(diffDeclRef.as().getDecl()->inner); + + // Increment counter. + genericLevels += 1; + } + + auto targetGeneric = as(as(attrTarget)->parentDecl); + auto diffGeneric = as(diffDeclRef.getDecl()->parentDecl); + Expr* currentDiffExpr = diffExpr; + + // Go back through each level, and use generic declarations in the + // target's generic scope as arguments for the diff function's generic. + // + for (UIndex ii = 0; ii < genericLevels; ii++) + { + // Nest our expression inside a GenericAppExpr + auto genericAppExpr = getASTBuilder()->create(); + genericAppExpr->functionExpr = currentDiffExpr; + + // Construct references to the generic args in the current scope. + // TODO: Probably an easier way to do this. + for (auto member : targetGeneric->members) + { + if (auto typeParamDecl = as(member)) + { + genericAppExpr->arguments.add( + ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr)); + } + else if (auto valueParamDecl = as(member)) + { + genericAppExpr->arguments.add( + ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr)); + } + } + + // Set our generic-app-expr as the new expr. + currentDiffExpr = genericAppExpr; + + // Peel the generic layer. + diffGeneric = as(diffGeneric->parentDecl); + targetGeneric = as(targetGeneric->parentDecl); + } + + if ((diffGeneric == nullptr && targetGeneric != nullptr) || + (targetGeneric == nullptr && diffGeneric != nullptr)) + { + //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget); + SLANG_UNEXPECTED(""); + } + + // If we had to change currentDiffExpr, then re-check the expr. + if (!currentDiffExpr->type) + { + currentDiffExpr = CheckTerm(currentDiffExpr); + } + + // Ensure that the argument is a reference to a function definition or declaration. + auto currentDiffDeclRefExpr = as(currentDiffExpr); + auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef; + + if (!as(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc()))) + { + getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef); + } + + // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr) + customJVPAttr->funcDeclRef = as(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); } else if (auto comInterfaceAttr = as(attr)) { -- cgit v1.2.3