summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-07-18 23:32:30 -0400
committerGitHub <noreply@github.com>2022-07-18 23:32:30 -0400
commit5b4f35b8d00661852c607a49d81c590d4050a166 (patch)
tree320027c51f44c83c731d9121e41453dda67ed3ce /source
parent2e4b5770fa7e6dbf56845382706b33a22d6a025b (diff)
Added forward-mode autodiff support for more instructions (#2331)
* Merge slang-ir-diff-jvp.cpp * Added support and tests for other float vector types * Added swizzle test and code to handle it (tests failing currently) * Fixed one test, the other is still pending * Fixed instruction cloning logic to avoid modifying original function * Fixed an issue with custom 'pow_jvp' and added support for vector contructor * Minor update to comments * Fixed support for division * Fixed an issue with uninitialized diagnostic sink * Moved derivative processing to after mandatory inlining. Skip instructions that don't have side-effects and aren't used by anything. * WIP: Handling unconditional control flow and multi-block functions * Support for unconditional multi-block functions * Added a dead code elimination step to the derivative pass * Changed name of 'hasNoSideEffects()'
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-expr.cpp7
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp242
-rw-r--r--source/slang/slang-lower-to-ir.cpp18
3 files changed, 198 insertions, 69 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 67e8bf650..1895da70b 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1525,15 +1525,14 @@ namespace Slang
Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType)
{
- // Only float and float3 types can be differentiated for now.
+ // Only float and vector<float> types can be differentiated for now.
if (primalType->equals(builder->getFloatType()))
return primalType;
else if (auto primalVectorType = as<VectorExpressionType>(primalType))
{
- // TODO(sai): There's probably a more elegant way to check if a type is a float3?
- if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType()))
- return primalVectorType;
+ if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType))
+ return builder->getVectorType(jvpElementType, primalVectorType->elementCount);
}
else if (auto primalOutType = as<OutType>(primalType))
{
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index f5afccd0c..fd1d0086d 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -4,6 +4,7 @@
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
namespace Slang
{
@@ -103,6 +104,11 @@ struct JVPTranscriber
if (IRType* typeD = differentiateType(builder, paramP->getFullType()))
{
IRParam* paramD = builder->emitParam(typeD);
+
+ auto nameHintD = getJVPVarName(paramP);
+ if (nameHintD.getLength() > 0)
+ builder->addNameHintDecoration(paramD, nameHintD.getUnownedSlice());
+
SLANG_ASSERT(paramD);
return paramD;
}
@@ -118,6 +124,7 @@ struct JVPTranscriber
{
auto newParamP = builder->emitParam(inoutTypeP->getValueType());
cloneEnv.mapOldValToNew.Add(paramP, newParamP);
+ cloneInstDecorationsAndChildren(&cloneEnv, builder->getSharedBuilder(), paramP, newParamP);
return newParamP;
}
@@ -138,7 +145,7 @@ struct JVPTranscriber
List<IRParam*> newParamListP;
for (auto paramP : paramListP)
{
- if(requiresPrimalClone(builder, paramP))
+ if(isPurelyFunctional(builder, paramP))
newParamListP.add(as<IRParam>(emitInputParam(builder, paramP)));
}
@@ -154,12 +161,30 @@ struct JVPTranscriber
return newParamListD;
}
+ // Returns "d<var-name>" to use as a name hint for variables and parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String getJVPVarName(IRInst* varP)
+ {
+ if (auto namehintDecoration = varP->findDecoration<IRNameHintDecoration>())
+ {
+ return ("d" + String(namehintDecoration->getName()));
+ }
+
+ return String("");
+ }
+
IRInst* differentiateVar(IRBuilder* builder, IRVar* varP)
{
if (IRType* typeD = differentiateType(builder, varP->getDataType()->getValueType()))
{
IRVar* varD = builder->emitVar(typeD);
SLANG_ASSERT(varD);
+
+ auto nameHintD = getJVPVarName(varP);
+ if (nameHintD.getLength() > 0)
+ builder->addNameHintDecoration(varD, nameHintD.getUnownedSlice());
+
return varD;
}
return nullptr;
@@ -270,7 +295,7 @@ struct JVPTranscriber
IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP)
{
- IRInst* returnVal = findCloneForOperand(&cloneEnv, returnP->getVal());
+ IRInst* returnVal = returnP->getVal();
if (auto returnValD = getDifferentialInst(returnVal, nullptr))
{
IRReturn* returnD = as<IRReturn>(builder->emitReturn(returnValD));
@@ -366,6 +391,71 @@ struct JVPTranscriber
return nullptr;
}
+ IRInst* differentiateSwizzle(IRBuilder* builder, IRSwizzle* swizzleP)
+ {
+ if (auto baseD = getDifferentialInst(swizzleP->getBase(), nullptr))
+ {
+ List<IRInst*> swizzleIndices;
+ for (UIndex ii = 0; ii < swizzleP->getElementCount(); ii++)
+ swizzleIndices.add(swizzleP->getElementIndex(ii));
+
+ return builder->emitSwizzle(differentiateType(builder, swizzleP->getDataType()),
+ baseD,
+ swizzleP->getElementCount(),
+ swizzleIndices.getBuffer());
+ }
+ return nullptr;
+ }
+
+ IRInst* differentiateByPassthrough(IRBuilder* builder, IRInst* origInst)
+ {
+ UCount operandCount = origInst->getOperandCount();
+
+ List<IRInst*> diffOperands;
+ for (UIndex ii = 0; ii < operandCount; ii++)
+ {
+ // If the operand has a differential version, replace the original with the
+ // differential.
+ // Otherwise, abandon the differentiation attempt and assume that origInst
+ // cannot (or does not need to) be differentiated.
+ //
+ if (auto diffInst = getDifferentialInst(origInst->getOperand(ii), nullptr))
+ diffOperands.add(diffInst);
+ else
+ return nullptr;
+ }
+
+ return builder->emitIntrinsicInst(
+ differentiateType(builder, origInst->getDataType()),
+ origInst->getOp(),
+ operandCount,
+ diffOperands.getBuffer());
+ }
+
+ IRInst* handleControlFlow(IRBuilder* builder, IRInst* origInst)
+ {
+ switch(origInst->getOp())
+ {
+ case kIROp_unconditionalBranch:
+ auto origBranch = as<IRUnconditionalBranch>(origInst);
+
+ // Branches with extra operands not handled currently.
+ if (origBranch->getOperandCount() > 1)
+ break;
+
+ if (auto diffBlock = getDifferentialInst(origBranch->getTargetBlock(), nullptr))
+ return builder->emitBranch(as<IRBlock>(diffBlock));
+ else
+ return nullptr;
+ }
+
+ getSink()->diagnose(
+ origInst->sourceLoc,
+ Diagnostics::unimplemented,
+ "attempting to differentiate unhandled control flow");
+ return nullptr;
+ }
+
// In differential computation, the 'default' differential value is always zero.
// This is a consequence of differential computing being inherently linear. As a
// result, it's useful to have a method to generate zero literals of any (arithmetic) type.
@@ -390,11 +480,11 @@ struct JVPTranscriber
// Logic for whether a primal instruction needs to be replicated
// in the differential function. We detect and avoid replicating
- // side-effect instructions.
+ // 'side-effect' instructions.
//
- bool requiresPrimalClone(IRBuilder*, IRInst* instP)
+ bool isPurelyFunctional(IRBuilder*, IRInst* instP)
{
- if (as<IRReturn>(instP))
+ if (as<IRTerminatorInst>(instP))
return false;
else if (auto paramP = as<IRParam>(instP))
{
@@ -425,38 +515,37 @@ struct JVPTranscriber
IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP)
{
- IRInst* instP = oldInstP;
- // Clone the old instruction, but only if it's safe to do so.
+ // Clone the old instruction into the new differential function.
+ //
+ IRInst* instP = cloneInst(&cloneEnv, builder, oldInstP);
+
+ SLANG_ASSERT(instP);
+
+ IRInst* instD = differentiateInst(builder, instP);
+
+ // In case it's not safe to clone the old instruction,
+ // remove it from the graph.
// For instance, instructions that handle control flow
// (return statements) shouldn't be replicated.
//
- if (requiresPrimalClone(builder, oldInstP))
- instP = cloneInst(&cloneEnv, builder, oldInstP);
+ if (isPurelyFunctional(builder, oldInstP))
+ mapDifferentialInst(instP, instD);
else
{
- // We replace the operands of the old instruction with their clones,
- // if available.
- //
- for(UInt ii = 0; ii < oldInstP->getOperandCount(); ++ii)
- {
- auto oldOperand = oldInstP->getOperand(ii);
- auto newOperand = findCloneForOperand(&cloneEnv, oldOperand);
+ // This inst should never have been used.
+ SLANG_ASSERT(instP->firstUse == nullptr);
- instP->getOperands()[ii].init(instP, newOperand);
- }
+ instP->removeAndDeallocate();
+ mapDifferentialInst(oldInstP, instD);
}
- SLANG_ASSERT(instP);
-
- IRInst* instD = differentiateInst(builder, instP);
-
- mapDifferentialInst(instP, instD);
return instD;
}
IRInst* differentiateInst(IRBuilder* builder, IRInst* instP)
{
+ // Handle common operations
switch (instP->getOp())
{
case kIROp_Var:
@@ -474,6 +563,7 @@ struct JVPTranscriber
case kIROp_Add:
case kIROp_Mul:
case kIROp_Sub:
+ case kIROp_Div:
return differentiateBinaryArith(builder, instP);
case kIROp_Construct:
@@ -481,13 +571,33 @@ struct JVPTranscriber
case kIROp_Call:
return differentiateCall(builder, as<IRCall>(instP));
+
+ case kIROp_swizzle:
+ return differentiateSwizzle(builder, as<IRSwizzle>(instP));
+
+ case kIROp_constructVectorFromScalar:
+ return differentiateByPassthrough(builder, instP);
- default:
- getSink()->diagnose(instP->sourceLoc,
+ case kIROp_unconditionalBranch:
+ case kIROp_conditionalBranch:
+ return handleControlFlow(builder, instP);
+
+ }
+
+ // If none of the cases have been hit, check if the instruction is a
+ // type.
+ // For now we don't have logic to differentiate types that appear in blocks.
+ // So, we ignore them.
+ //
+ if (as<IRType>(instP))
+ return nullptr;
+
+
+ // If we reach this statement, the instruction type is likely unhandled.
+ getSink()->diagnose(instP->sourceLoc,
Diagnostics::unimplemented,
"this instruction cannot be differentiated");
- return nullptr;
- }
+ return nullptr;
}
};
@@ -531,27 +641,6 @@ struct IRWorkQueue
struct JVPDerivativeContext
{
- // This type passes over the module and generates
- // forward-mode derivative versions of functions
- // that are explicitly marked for it.
- //
- IRModule* module;
-
- // Shared builder state for our derivative passes.
- SharedIRBuilder sharedBuilderStorage;
-
- // A transcriber object that handles the main job of
- // processing instructions while maintaining state.
- //
- JVPTranscriber transcriberStorage;
-
- // Diagnostic object from the compile request for
- // error messages.
- DiagnosticSink* sink;
-
- // Work queue to hold a stream of instructions that need
- // to be checked for references to derivative functions.
- IRWorkQueue workQueueStorage;
DiagnosticSink* getSink()
{
@@ -726,13 +815,22 @@ struct JVPDerivativeContext
builder->addNameHintDecoration(jvpFn, jvpName);
builder->setInsertInto(jvpFn);
-
- // Start with _extremely_ basic functions
- SLANG_ASSERT(primalFn->getFirstBlock() == primalFn->getLastBlock());
+ // Emit a block instruction for every block in the function, and map it as the
+ // corresponding differential.
+ //
+ for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
+ {
+ auto jvpBlock = builder->emitBlock();
+ transcriberStorage.mapDifferentialInst(block, jvpBlock);
+ }
+
+ // Go back over the blocks, and process the children of each block.
for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
{
- emitJVPBlock(builder, primalFn->getFirstBlock());
+ auto jvpBlock = as<IRBlock>(transcriberStorage.getDifferentialInst(block, block));
+ SLANG_ASSERT(jvpBlock);
+ emitJVPBlock(builder, block, jvpBlock);
}
return jvpFn;
@@ -759,7 +857,6 @@ struct JVPDerivativeContext
return name;
}
-
IRBlock* emitJVPBlock(IRBuilder* builder,
IRBlock* primalBlock,
IRBlock* jvpBlock = nullptr)
@@ -789,6 +886,35 @@ struct JVPDerivativeContext
return jvpBlock;
}
+ JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : module(module), sink(sink)
+ {
+ transcriberStorage.sink = sink;
+ }
+
+ protected:
+
+ // This type passes over the module and generates
+ // forward-mode derivative versions of functions
+ // that are explicitly marked for it.
+ //
+ IRModule* module;
+
+ // Shared builder state for our derivative passes.
+ SharedIRBuilder sharedBuilderStorage;
+
+ // A transcriber object that handles the main job of
+ // processing instructions while maintaining state.
+ //
+ JVPTranscriber transcriberStorage;
+
+ // Diagnostic object from the compile request for
+ // error messages.
+ DiagnosticSink* sink;
+
+ // Work queue to hold a stream of instructions that need
+ // to be checked for references to derivative functions.
+ IRWorkQueue workQueueStorage;
+
};
// Set up context and call main process method.
@@ -798,9 +924,13 @@ bool processJVPDerivativeMarkers(
DiagnosticSink* sink,
IRJVPDerivativePassOptions const&)
{
- JVPDerivativeContext context;
- context.module = module;
- context.sink = sink;
+ JVPDerivativeContext context(module, sink);
+
+ // Simplify module to remove dead code.
+ IRDeadCodeEliminationOptions options;
+ options.keepExportsAlive = true;
+ options.keepLayoutsAlive = true;
+ eliminateDeadCode(module, options);
return context.processModule();
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 628f37c5b..81201f5f8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8516,15 +8516,6 @@ RefPtr<IRModule> generateIRForTranslationUnit(
#endif
validateIRModuleIfEnabled(compileRequest, module);
-
- // Process higher-order-function calls before any optimization passes
- // to allow the optimizations to affect the generated funcitons.
- // 1. Process JVP derivative functions.
- processJVPDerivativeMarkers(module, compileRequest->getSink());
- // 2. Process VJP derivative functions.
- // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
- // 3. Replace JVP & VJP calls.
- processDerivativeCalls(module);
// We will perform certain "mandatory" optimization passes now.
@@ -8560,6 +8551,15 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// temporaries whenever possible.
constructSSA(module);
+ // Process higher-order-function calls before any optimization passes
+ // to allow the optimizations to affect the generated funcitons.
+ // 1. Process JVP derivative functions.
+ processJVPDerivativeMarkers(module, compileRequest->getSink());
+ // 2. Process VJP derivative functions.
+ // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
+ // 3. Replace JVP & VJP calls.
+ processDerivativeCalls(module);
+
// Do basic constant folding and dead code elimination
// using Sparse Conditional Constant Propagation (SCCP)
//