diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-06-16 11:03:59 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-16 11:03:59 -0400 |
| commit | d2a467c7a941c4453b3d825c9d5bb4d72230c8ba (patch) | |
| tree | 79a72dcdb48436b86b87216a36e4963dcaa36235 | |
| parent | 37c43e20fc5ab42d26695c990edf2835952087c8 (diff) | |
Added a decorator to mark functions for forward-mode differentiation (#2283)
| -rw-r--r-- | source/slang/core.meta.slang | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 4 |
5 files changed, 18 insertions, 0 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 41cfea6af..721529726 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -71,6 +71,11 @@ syntax snorm : SNormModifier; /// syntax __extern_cpp : ExternCppModifier; +/// Modifer to mark a function for forward-mode differentiation. +/// i.e. the compiler will automatically generate a new function +/// that computes the jacobian-vector product of the original. +syntax __differentiate_jvp : JVPDerivativeModifier; + /// A type that can be used as an operand for builtins [sealed] [builtin] diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 012c74377..d0f215cbe 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -30,6 +30,7 @@ class ExportedModifier : public Modifier { SLANG_AST_CLASS(ExportedModifier)}; class ConstExprModifier : public Modifier { SLANG_AST_CLASS(ConstExprModifier)}; class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoherentModifier)}; class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)}; +class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)}; // An 'ActualGlobal' is a global that is output as a normal global in CPU code. // Globals in HLSL/Slang are constant state passed into kernel execution diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 6547d949e..4d927bdaf 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -661,6 +661,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(SPIRVOpDecoration, spirvOpDecoration, 1, 0) + /// Decorated function is marked for the forward-mode differentiation pass. + INST(JVPDerivativeDecoration, differentiateJvp, 0, 0) + /// Marks a struct type as being used as a structured buffer block. /// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration. INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5e8e11f84..80438504c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2967,6 +2967,11 @@ public: addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName)); } + void addJVPDerivativeDecoration(IRInst* value, UnownedStringSlice const& mangledName) + { + addDecoration(value, kIROp_JVPDerivativeDecoration, getStringValue(mangledName)); + } + void addDllImportDecoration(IRInst* value, UnownedStringSlice const& libraryName, UnownedStringSlice const& functionName) { addDecoration(value, kIROp_DllImportDecoration, getStringValue(libraryName), getStringValue(functionName)); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 86edf9282..791180890 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1149,6 +1149,10 @@ static void addLinkageDecoration( { builder->addExternCppDecoration(inst, mangledName); } + if (decl->findModifier<JVPDerivativeModifier>()) + { + builder->addJVPDerivativeDecoration(inst, mangledName); + } if (as<InterfaceDecl>(decl->parentDecl) && decl->parentDecl->hasModifier<ComInterfaceAttribute>()) { |
