summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-06-23 16:02:05 -0400
committerGitHub <noreply@github.com>2022-06-23 16:02:05 -0400
commit6cf3d496005c5635b273d9ce6c110f14541a9492 (patch)
tree91d6480cec20ca88e85fb9b4d437e4661fba229c /source/slang/slang-ir-diff-jvp.cpp
parent4aa6344f772d31c1f7b0676cbaf315104c4b30a2 (diff)
Added basic syntax to mark and request function derivatives, as well as the framework for passes to process them. (#2297)
* Added a decorator to mark functions for forward-mode differentiation * Fill out support for calls to non-decl values The existing compiler logic has a few places (semantic checking plus AST-to-IR lowering) where it assumes that function calls (`InvokeExpr`) are only ever made to expressions that resolve to a specific `Decl` (`DeclRefExpr`). This assumption allows semantic checking and lowering code to inspect things like the parameter list of an actual declaration, rather than just the type signature of the callee, and that infrastructure is used to support various features (e.g., default argument values on parameters). The AST and IR representations themselves have no matching requirement, and the places where the more general case of call expressions would need to be supported were relatively clear in the code. This change attempts to add suitable logic into each of those places. Note that this change does *not* surface any valid way to form input code that would cause these new code paths to be executed, so it is entirely possible that there are bugs in the logic as written here. The primary goal of this change is simply to get a sketch of the correct code checked in so that we have something to build on once we have language features that will require this support. * fixup: warnings-as-errors * Added parser logic for '__jvp(<fn-name>)' operator * Fixed issue with missing overload candidate item and added basic parsing test for the __jvp syntax * Added a blank JVP Auto-diff pass and a pass that replaces 'JVPDerivativeOf' calls with the differentiated function * Added a couple comments * Added parameter handling for the JVP pass Co-authored-by: Theresa Foley <tfoley@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp180
1 files changed, 180 insertions, 0 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
new file mode 100644
index 000000000..431c8e5b2
--- /dev/null
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -0,0 +1,180 @@
+// slang-ir-diff-jvp.cpp
+#include "slang-ir-diff-jvp.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+struct JVPDerivativeContext
+{
+ // This type passes over the module and generates
+ // forward-mode derivative versions of functions
+ // that are explicitly marked for it.
+ //
+ IRModule* module;
+
+ // Shared builder state for our derivative passes.
+ SharedIRBuilder sharedBuilderStorage;
+
+ bool processModule()
+ {
+ // We start by initializing our shared IR building state,
+ // since we will re-use that state for any code we
+ // generate along the way.
+ //
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->init(module);
+
+ // Run through all the global-level instructions,
+ // looking for callables.
+ // Note: We're only processing global callables (IRGlobalValueWithCode)
+ // for now.
+ IRBuilder builderStorage(sharedBuilderStorage);
+ IRBuilder* builder = &builderStorage;
+ for (auto inst : module->getGlobalInsts())
+ {
+ // If the instr is a callable, get all the basic blocks
+ if (auto callable = as<IRGlobalValueWithCode>(inst))
+ {
+ if (isFunctionMarkedForJVP(callable))
+ {
+ SLANG_ASSERT(as<IRFunc>(callable));
+ IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable));
+ builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction);
+ }
+ }
+ }
+ return true;
+ }
+
+ // Checks decorators to see if the function should
+ // be differentiated (kIROp_JVPDerivativeMarkerDecoration)
+ //
+ bool isFunctionMarkedForJVP(IRGlobalValueWithCode* callable)
+ {
+ for(auto decoration = callable->getFirstDecoration();
+ decoration;
+ decoration = decoration->getNextDecoration())
+ {
+ if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration)
+ {
+ return true;
+ }
+ // TODO: Need to remove this decoration or check for
+ // JVPDerivativeReferenceDecoration to avoid re-generating code.
+ }
+ return false;
+ }
+
+ List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType)
+ {
+ List<IRParam*> params;
+ for(UIndex i = 0; i < dataType->getParamCount(); i++)
+ {
+ params.add(
+ builder->emitParam(dataType->getParamType(i)));
+ }
+ return params;
+ }
+
+ // Perform forward-mode automatic differentiation on
+ // the intstructions.
+ IRFunc* emitJVPFunction(IRBuilder* builder,
+ IRFunc* primalFn)
+ {
+ // Note (sai): Is this safe? Should we use setInsertInto?
+ builder->setInsertBefore(primalFn->getNextInst());
+
+ auto jvpFn = builder->createFunc();
+ IRType* jvpFuncType = primalTypeToJVPType(primalFn->getFullType());
+ jvpFn->setFullType(jvpFuncType);
+ if (auto jvpName = getJVPFuncName(builder, primalFn))
+ builder->addNameHintDecoration(jvpFn, jvpName);
+
+ builder->setInsertInto(jvpFn);
+
+ // Start with _extremely_ basic functions
+ SLANG_ASSERT(primalFn->getFirstBlock() == primalFn->getLastBlock());
+
+ for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
+ {
+ IRBlock* newJVPBlock = nullptr;
+ if (block == primalFn->getFirstBlock())
+ {
+ newJVPBlock = builder->emitBlock();
+ emitFuncParameters(builder, as<IRFuncType>(jvpFuncType));
+ }
+ newJVPBlock = emitJVPBlock(builder, primalFn->getFirstBlock(), newJVPBlock);
+ }
+
+ return jvpFn;
+ }
+
+ IRStringLit* getJVPFuncName(IRBuilder* builder,
+ IRFunc* func)
+ {
+ auto oldLoc = builder->getInsertLoc();
+ builder->setInsertBefore(func);
+
+ IRStringLit* name = nullptr;
+ if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
+ {
+ name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_jvp").getUnownedSlice());
+ }
+ else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
+ {
+ name = builder->getStringValue((String(namehintDecoration->getName()) + "_jvp").getUnownedSlice());
+ }
+
+ builder->setInsertLoc(oldLoc);
+
+ return name;
+ }
+
+
+ IRBlock* emitJVPBlock(IRBuilder* builder,
+ IRBlock* primalBlock,
+ IRBlock* jvpBlock = nullptr)
+ {
+ // Create if not already provided, and insert into new block.
+ if (!jvpBlock)
+ jvpBlock = builder->emitBlock();
+ else
+ builder->setInsertInto(jvpBlock);
+
+ // Temporarily, we're going to just emit a single return 0 instruction.
+ for(auto child = primalBlock->getFirstInst(); child; child = child->getNextInst())
+ {
+ if (auto returnOp = as<IRReturn>(child))
+ {
+ auto zeroVal = builder->getFloatValue(returnOp->getVal()->getDataType(), 0.0);
+ builder->emitReturn(zeroVal);
+ }
+ }
+
+ return jvpBlock;
+ }
+
+ IRType* primalTypeToJVPType(IRType* primalType)
+ {
+ // Temporarily, we're going to implement the identity transform.
+ // The return type is the same as the primal type.
+ return primalType;
+ }
+};
+
+// Set up context and call main process method.
+//
+bool processJVPDerivativeMarkers(
+ IRModule* module,
+ IRJVPDerivativePassOptions const&)
+{
+ JVPDerivativeContext context;
+ context.module = module;
+
+ return context.processModule();
+}
+
+}