summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-modifier.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-06-30 19:24:24 -0400
committerGitHub <noreply@github.com>2022-06-30 19:24:24 -0400
commit77af111867eb72f26b460c5925be47aa22c71556 (patch)
treeb516734ccec92f01eaa07a7844b3862b3c5ab628 /source/slang/slang-check-modifier.cpp
parent2c09275388d4c88ea26bf709132b8be4a9e342bc (diff)
Added `[__custom_jvp(func)]` attribute, and modified the derivative pass to only process referenced functions. (#2309)
* Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types * Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly * Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated * Replaced some SLANG_UNEXPECTED macro uses with diagnostics instead * Added work-list-based on-demand generation of derivative functions * Fixed up a couple of TODOs * Added attribute [__custom_jvp(f)] to assign a custom derivative function to a declaration * Added a test for CustomJVPAttribute on a redeclaration of an imported function * Moving arithmetic test to new folder * Moving arithmetic test to new folder (2) * Added missing test module * Fixed a minor note Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-modifier.cpp')
-rw-r--r--source/slang/slang-check-modifier.cpp11
1 files changed, 11 insertions, 0 deletions
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index c6a33930b..28164c126 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -595,6 +595,17 @@ namespace Slang
callablePayloadAttr->location = (int32_t)val->value;
}
+ else if (auto customJVPAttr = as<CustomJVPAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+
+ // Ensure that the argument is a reference to a function definition or declaration.
+ auto funcExpr = as<DeclRefExpr>(CheckTerm(attr->args[0]));
+ if (!as<FuncType>(funcExpr->type))
+ return false;
+
+ customJVPAttr->funcDeclRef = funcExpr;
+ }
else
{
if(attr->args.getCount() == 0)