summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-06-16 11:03:59 -0400
committerGitHub <noreply@github.com>2022-06-16 11:03:59 -0400
commitd2a467c7a941c4453b3d825c9d5bb4d72230c8ba (patch)
tree79a72dcdb48436b86b87216a36e4963dcaa36235 /source
parent37c43e20fc5ab42d26695c990edf2835952087c8 (diff)
Added a decorator to mark functions for forward-mode differentiation (#2283)
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang5
-rw-r--r--source/slang/slang-ast-modifier.h1
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h5
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
5 files changed, 18 insertions, 0 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 41cfea6af..721529726 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -71,6 +71,11 @@ syntax snorm : SNormModifier;
///
syntax __extern_cpp : ExternCppModifier;
+/// Modifer to mark a function for forward-mode differentiation.
+/// i.e. the compiler will automatically generate a new function
+/// that computes the jacobian-vector product of the original.
+syntax __differentiate_jvp : JVPDerivativeModifier;
+
/// A type that can be used as an operand for builtins
[sealed]
[builtin]
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 012c74377..d0f215cbe 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -30,6 +30,7 @@ class ExportedModifier : public Modifier { SLANG_AST_CLASS(ExportedModifier)};
class ConstExprModifier : public Modifier { SLANG_AST_CLASS(ConstExprModifier)};
class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoherentModifier)};
class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)};
+class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)};
// An 'ActualGlobal' is a global that is output as a normal global in CPU code.
// Globals in HLSL/Slang are constant state passed into kernel execution
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 6547d949e..4d927bdaf 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -661,6 +661,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(SPIRVOpDecoration, spirvOpDecoration, 1, 0)
+ /// Decorated function is marked for the forward-mode differentiation pass.
+ INST(JVPDerivativeDecoration, differentiateJvp, 0, 0)
+
/// Marks a struct type as being used as a structured buffer block.
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 5e8e11f84..80438504c 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2967,6 +2967,11 @@ public:
addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName));
}
+ void addJVPDerivativeDecoration(IRInst* value, UnownedStringSlice const& mangledName)
+ {
+ addDecoration(value, kIROp_JVPDerivativeDecoration, getStringValue(mangledName));
+ }
+
void addDllImportDecoration(IRInst* value, UnownedStringSlice const& libraryName, UnownedStringSlice const& functionName)
{
addDecoration(value, kIROp_DllImportDecoration, getStringValue(libraryName), getStringValue(functionName));
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 86edf9282..791180890 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1149,6 +1149,10 @@ static void addLinkageDecoration(
{
builder->addExternCppDecoration(inst, mangledName);
}
+ if (decl->findModifier<JVPDerivativeModifier>())
+ {
+ builder->addJVPDerivativeDecoration(inst, mangledName);
+ }
if (as<InterfaceDecl>(decl->parentDecl) &&
decl->parentDecl->hasModifier<ComInterfaceAttribute>())
{