From 77af111867eb72f26b460c5925be47aa22c71556 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 30 Jun 2022 19:24:24 -0400 Subject: Added `[__custom_jvp(func)]` attribute, and modified the derivative pass to only process referenced functions. (#2309) * Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types * Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly * Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated * Replaced some SLANG_UNEXPECTED macro uses with diagnostics instead * Added work-list-based on-demand generation of derivative functions * Fixed up a couple of TODOs * Added attribute [__custom_jvp(f)] to assign a custom derivative function to a declaration * Added a test for CustomJVPAttribute on a redeclaration of an imported function * Moving arithmetic test to new folder * Moving arithmetic test to new folder (2) * Added missing test module * Fixed a minor note Co-authored-by: Yong He --- source/slang/slang-check-modifier.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) (limited to 'source/slang/slang-check-modifier.cpp') diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index c6a33930b..28164c126 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -595,6 +595,17 @@ namespace Slang callablePayloadAttr->location = (int32_t)val->value; } + else if (auto customJVPAttr = as(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + + // Ensure that the argument is a reference to a function definition or declaration. + auto funcExpr = as(CheckTerm(attr->args[0])); + if (!as(funcExpr->type)) + return false; + + customJVPAttr->funcDeclRef = funcExpr; + } else { if(attr->args.getCount() == 0) -- cgit v1.2.3