diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-06-30 19:24:24 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-30 19:24:24 -0400 |
| commit | 77af111867eb72f26b460c5925be47aa22c71556 (patch) | |
| tree | b516734ccec92f01eaa07a7844b3862b3c5ab628 /source | |
| parent | 2c09275388d4c88ea26bf709132b8be4a9e342bc (diff) | |
Added `[__custom_jvp(func)]` attribute, and modified the derivative pass to only process referenced functions. (#2309)
* Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types
* Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly
* Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated
* Replaced some SLANG_UNEXPECTED macro uses with diagnostics instead
* Added work-list-based on-demand generation of derivative functions
* Fixed up a couple of TODOs
* Added attribute [__custom_jvp(f)] to assign a custom derivative function to a declaration
* Added a test for CustomJVPAttribute on a redeclaration of an imported function
* Moving arithmetic test to new folder
* Moving arithmetic test to new folder (2)
* Added missing test module
* Fixed a minor note
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 184 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 10 |
7 files changed, 200 insertions, 24 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 721529726..e1f7503a8 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2212,3 +2212,6 @@ attribute_syntax [noinline] : NoInlineAttribute; __attributeTarget(StructDecl) attribute_syntax [payload] : PayloadAttribute; +// Custom JVP Function reference +__attributeTarget(FuncDecl) +attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index d0f215cbe..62a32045b 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -955,6 +955,15 @@ class RequiresNVAPIAttribute : public Attribute SLANG_AST_CLASS(RequiresNVAPIAttribute) }; + /// The `[__custom_jvp(function)]` attribute specifies a custom function that should + /// be used as the derivative for the decorated function. +class CustomJVPAttribute : public Attribute +{ + SLANG_AST_CLASS(CustomJVPAttribute) + + DeclRefExpr* funcDeclRef; +}; + /// Indicates that the modified declaration is one of the "magic" declarations /// that NVAPI uses to communicate extended operations. When NVAPI is being included /// via the prelude for downstream compilation, declarations with this modifier @@ -1052,4 +1061,6 @@ class SNormModifier : public ResourceElementFormatModifier SLANG_AST_CLASS(SNormModifier) }; + + } // namespace Slang diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index c6a33930b..28164c126 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -595,6 +595,17 @@ namespace Slang callablePayloadAttr->location = (int32_t)val->value; } + else if (auto customJVPAttr = as<CustomJVPAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + + // Ensure that the argument is a reference to a function definition or declaration. + auto funcExpr = as<DeclRefExpr>(CheckTerm(attr->args[0])); + if (!as<FuncType>(funcExpr->type)) + return false; + + customJVPAttr->funcDeclRef = funcExpr; + } else { if(attr->args.getCount() == 0) diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 2a42a7b6e..00210daaa 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -13,11 +13,20 @@ struct JVPTranscriber // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent // their differential values. - Dictionary<IRInst*, IRInst*> instMapD; + Dictionary<IRInst*, IRInst*> instMapD; // Cloning environment to hold mapping from old to new copies for the primal // instructions. - IRCloneEnv cloneEnv; + IRCloneEnv cloneEnv; + + // Diagnostic sink for error messages. + DiagnosticSink* sink; + + DiagnosticSink* getSink() + { + SLANG_ASSERT(sink); + return sink; + } void mapDifferentialInst(IRInst* instP, IRInst* instD) { @@ -156,7 +165,9 @@ struct JVPTranscriber rightP->getDataType(), rightP, rightP )); default: - SLANG_UNEXPECTED("Attempting to differentiate unsupported arithmetic"); + getSink()->diagnose(arith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); } } @@ -179,8 +190,11 @@ struct JVPTranscriber } return nullptr; } - - SLANG_UNEXPECTED("Attempting to differentiate an unsupported load instruction"); + else + getSink()->diagnose(loadP->sourceLoc, + Diagnostics::unimplemented, + "this load instruction cannot be differentiated"); + return nullptr; } IRInst* differentiateStore(IRBuilder* builder, IRStore* storeP) @@ -204,8 +218,11 @@ struct JVPTranscriber } return nullptr; } - - SLANG_UNEXPECTED("Attempting to differentiate an unsupported store instruction"); + else + getSink()->diagnose(storeP->sourceLoc, + Diagnostics::unimplemented, + "this store instruction cannot be differentiated"); + return nullptr; } IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP) @@ -226,12 +243,13 @@ struct JVPTranscriber // IRInst* differentiateConstruct(IRBuilder*, IRInst* consP) { - if (as<IRConstant>(consP->getOperand(0)) && consP->getOperandCount() == 1) - { return nullptr; - } - SLANG_UNEXPECTED("Attempting to differentiate unsupported constructor"); + else + getSink()->diagnose(consP->sourceLoc, + Diagnostics::unimplemented, + "this construct instruction cannot be differentiated"); + return nullptr; } // Logic for whether a primal instruction needs to be replicated @@ -242,13 +260,9 @@ struct JVPTranscriber bool requiresPrimalClone(IRBuilder*, IRInst* instP) { if (as<IRReturn>(instP)) - { return false; - } else - { return true; - } } IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP) @@ -295,11 +309,52 @@ struct JVPTranscriber return differentiateConstruct(builder, instP); default: - SLANG_UNEXPECTED("Attempting to differentiate unrecognized instruction"); + getSink()->diagnose(instP->sourceLoc, + Diagnostics::unimplemented, + "this instruction cannot be differentiated"); + return nullptr; } } }; +struct IRWorkQueue +{ + // Work list to hold the active set of insts whose children + // need to be looked at. + // + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + + void push(IRInst* inst) + { + if(!inst) return; + if(workListSet.Contains(inst)) return; + + workList.add(inst); + workListSet.Add(inst); + } + + IRInst* pop() + { + if (workList.getCount() != 0) + { + IRInst* topItem = workList.getFirst(); + // TODO(Sai): Repeatedly calling removeAt() can be really slow. + // Consider a specialized data structure or using removeLast() + // + workList.removeAt(0); + workListSet.Remove(topItem); + return topItem; + } + return nullptr; + } + + IRInst* peek() + { + return workList.getFirst(); + } +}; + struct JVPDerivativeContext { // This type passes over the module and generates @@ -315,6 +370,19 @@ struct JVPDerivativeContext // processing instructions while maintaining state. // JVPTranscriber transcriberStorage; + + // Diagnostic object from the compile request for + // error messages. + DiagnosticSink* sink; + + // Work queue to hold a stream of instructions that need + // to be checked for references to derivative functions. + IRWorkQueue workQueueStorage; + + DiagnosticSink* getSink() + { + return sink; + } bool processModule() { @@ -324,14 +392,81 @@ struct JVPDerivativeContext // SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); - - // Run through all the global-level instructions, - // looking for callables. - // Note: We're only processing global callables (IRGlobalValueWithCode) - // for now. - // + IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; + + // processMarkedGlobalFunctions(builder); + return processReferencedFunctions(builder); + } + + IRInst* lookupJVPReference(IRInst* primalFunction) + { + if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>()) + return jvpDefinition->getJVPFunc(); + + return nullptr; + } + + // Recursively process instructions looking for JVP calls (kIROp_JVPDifferentiate), + // then check that the referenced function is marked correctly for differentiation. + // + bool processReferencedFunctions(IRBuilder* builder) + { + IRWorkQueue* workQueue = &(workQueueStorage); + + // Put the top-level inst into the queue. + workQueue->push(module->getModuleInst()); + + // Keep processing items until the queue is complete. + while (IRInst* workItem = workQueue->pop()) + { + for(auto child = workItem->getFirstChild(); child; child = child->getNextInst()) + { + // Either the child instruction has more children (func/block etc..) + // and we add it to the work list for further processing, or + // it's an ordinary inst in which case we check if it's a JVPDifferentiate + // instruction. + // + if (child->getFirstChild() != nullptr) + workQueue->push(child); + + if (auto jvpDiffInst = as<IRJVPDifferentiate>(child)) + { + auto baseFunction = jvpDiffInst->getBaseFn(); + // If the JVP Reference already exists, no need to + // differentiate again. + // + if(lookupJVPReference(baseFunction)) continue; + + if (isFunctionMarkedForJVP(as<IRGlobalValueWithCode>(baseFunction))) + { + IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(baseFunction)); + builder->addJVPDerivativeReferenceDecoration(baseFunction, jvpFunction); + workQueue->push(jvpFunction); + } + else + { + // TODO(Sai): This would probably be better with a more specific + // error code. + getSink()->diagnose(jvpDiffInst->sourceLoc, + Diagnostics::internalCompilerError, + "Cannot differentiate functions not marked for differentiation"); + } + } + } + } + + return true; + } + + // Run through all the global-level instructions, + // looking for callables. + // Note: We're only processing global callables (IRGlobalValueWithCode) + // for now. + // + bool processMarkedGlobalFunctions(IRBuilder* builder) + { for (auto inst : module->getGlobalInsts()) { // If the instr is a callable, get all the basic blocks @@ -340,7 +475,8 @@ struct JVPDerivativeContext if (isFunctionMarkedForJVP(callable)) { SLANG_ASSERT(as<IRFunc>(callable)); - IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable)); + + IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(callable)); builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction); unmarkForJVP(callable); @@ -485,10 +621,12 @@ struct JVPDerivativeContext // bool processJVPDerivativeMarkers( IRModule* module, + DiagnosticSink* sink, IRJVPDerivativePassOptions const&) { JVPDerivativeContext context; context.module = module; + context.sink = sink; return context.processModule(); } diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h index 9bbcad4fb..8ae6e949a 100644 --- a/source/slang/slang-ir-diff-jvp.h +++ b/source/slang/slang-ir-diff-jvp.h @@ -1,6 +1,9 @@ // slang-ir-diff-jvp.h #pragma once +#include "slang-ir.h" +#include "slang-compiler.h" + namespace Slang { struct IRModule; @@ -12,6 +15,7 @@ namespace Slang bool processJVPDerivativeMarkers( IRModule* module, + DiagnosticSink* sink, IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions()); } diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 7984c5037..14724046a 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -446,6 +446,7 @@ static void cloneExtraDecorations( case kIROp_LayoutDecoration: case kIROp_PublicDecoration: case kIROp_SequentialIDDecoration: + case kIROp_JVPDerivativeReferenceDecoration: if(!clonedInst->findDecorationImpl(decoration->getOp())) { cloneInst(context, builder, decoration); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b7c5155b5..e65bcc8e3 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7832,6 +7832,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); } + if (auto attr = decl->findModifier<CustomJVPAttribute>()) + { + auto loweredVal = lowerLValueExpr(this->context, attr->funcDeclRef); + SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); + IRFunc* jvpFunc = as<IRFunc>(loweredVal.val); + getBuilder()->addDecoration(irFunc, kIROp_JVPDerivativeReferenceDecoration, jvpFunc); + } + // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list @@ -8477,7 +8485,7 @@ RefPtr<IRModule> generateIRForTranslationUnit( // Process higher-order-function calls before any optimization passes // to allow the optimizations to affect the generated funcitons. // 1. Process JVP derivative functions. - processJVPDerivativeMarkers(module); + processJVPDerivativeMarkers(module, compileRequest->getSink()); // 2. Process VJP derivative functions. // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet. // 3. Replace JVP & VJP calls. |
