summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp184
1 files changed, 161 insertions, 23 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 2a42a7b6e..00210daaa 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -13,11 +13,20 @@ struct JVPTranscriber
// Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
// their differential values.
- Dictionary<IRInst*, IRInst*> instMapD;
+ Dictionary<IRInst*, IRInst*> instMapD;
// Cloning environment to hold mapping from old to new copies for the primal
// instructions.
- IRCloneEnv cloneEnv;
+ IRCloneEnv cloneEnv;
+
+ // Diagnostic sink for error messages.
+ DiagnosticSink* sink;
+
+ DiagnosticSink* getSink()
+ {
+ SLANG_ASSERT(sink);
+ return sink;
+ }
void mapDifferentialInst(IRInst* instP, IRInst* instD)
{
@@ -156,7 +165,9 @@ struct JVPTranscriber
rightP->getDataType(), rightP, rightP
));
default:
- SLANG_UNEXPECTED("Attempting to differentiate unsupported arithmetic");
+ getSink()->diagnose(arith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
}
}
@@ -179,8 +190,11 @@ struct JVPTranscriber
}
return nullptr;
}
-
- SLANG_UNEXPECTED("Attempting to differentiate an unsupported load instruction");
+ else
+ getSink()->diagnose(loadP->sourceLoc,
+ Diagnostics::unimplemented,
+ "this load instruction cannot be differentiated");
+ return nullptr;
}
IRInst* differentiateStore(IRBuilder* builder, IRStore* storeP)
@@ -204,8 +218,11 @@ struct JVPTranscriber
}
return nullptr;
}
-
- SLANG_UNEXPECTED("Attempting to differentiate an unsupported store instruction");
+ else
+ getSink()->diagnose(storeP->sourceLoc,
+ Diagnostics::unimplemented,
+ "this store instruction cannot be differentiated");
+ return nullptr;
}
IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP)
@@ -226,12 +243,13 @@ struct JVPTranscriber
//
IRInst* differentiateConstruct(IRBuilder*, IRInst* consP)
{
-
if (as<IRConstant>(consP->getOperand(0)) && consP->getOperandCount() == 1)
- {
return nullptr;
- }
- SLANG_UNEXPECTED("Attempting to differentiate unsupported constructor");
+ else
+ getSink()->diagnose(consP->sourceLoc,
+ Diagnostics::unimplemented,
+ "this construct instruction cannot be differentiated");
+ return nullptr;
}
// Logic for whether a primal instruction needs to be replicated
@@ -242,13 +260,9 @@ struct JVPTranscriber
bool requiresPrimalClone(IRBuilder*, IRInst* instP)
{
if (as<IRReturn>(instP))
- {
return false;
- }
else
- {
return true;
- }
}
IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP)
@@ -295,11 +309,52 @@ struct JVPTranscriber
return differentiateConstruct(builder, instP);
default:
- SLANG_UNEXPECTED("Attempting to differentiate unrecognized instruction");
+ getSink()->diagnose(instP->sourceLoc,
+ Diagnostics::unimplemented,
+ "this instruction cannot be differentiated");
+ return nullptr;
}
}
};
+struct IRWorkQueue
+{
+ // Work list to hold the active set of insts whose children
+ // need to be looked at.
+ //
+ List<IRInst*> workList;
+ HashSet<IRInst*> workListSet;
+
+ void push(IRInst* inst)
+ {
+ if(!inst) return;
+ if(workListSet.Contains(inst)) return;
+
+ workList.add(inst);
+ workListSet.Add(inst);
+ }
+
+ IRInst* pop()
+ {
+ if (workList.getCount() != 0)
+ {
+ IRInst* topItem = workList.getFirst();
+ // TODO(Sai): Repeatedly calling removeAt() can be really slow.
+ // Consider a specialized data structure or using removeLast()
+ //
+ workList.removeAt(0);
+ workListSet.Remove(topItem);
+ return topItem;
+ }
+ return nullptr;
+ }
+
+ IRInst* peek()
+ {
+ return workList.getFirst();
+ }
+};
+
struct JVPDerivativeContext
{
// This type passes over the module and generates
@@ -315,6 +370,19 @@ struct JVPDerivativeContext
// processing instructions while maintaining state.
//
JVPTranscriber transcriberStorage;
+
+ // Diagnostic object from the compile request for
+ // error messages.
+ DiagnosticSink* sink;
+
+ // Work queue to hold a stream of instructions that need
+ // to be checked for references to derivative functions.
+ IRWorkQueue workQueueStorage;
+
+ DiagnosticSink* getSink()
+ {
+ return sink;
+ }
bool processModule()
{
@@ -324,14 +392,81 @@ struct JVPDerivativeContext
//
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
-
- // Run through all the global-level instructions,
- // looking for callables.
- // Note: We're only processing global callables (IRGlobalValueWithCode)
- // for now.
- //
+
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
+
+ // processMarkedGlobalFunctions(builder);
+ return processReferencedFunctions(builder);
+ }
+
+ IRInst* lookupJVPReference(IRInst* primalFunction)
+ {
+ if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ return jvpDefinition->getJVPFunc();
+
+ return nullptr;
+ }
+
+ // Recursively process instructions looking for JVP calls (kIROp_JVPDifferentiate),
+ // then check that the referenced function is marked correctly for differentiation.
+ //
+ bool processReferencedFunctions(IRBuilder* builder)
+ {
+ IRWorkQueue* workQueue = &(workQueueStorage);
+
+ // Put the top-level inst into the queue.
+ workQueue->push(module->getModuleInst());
+
+ // Keep processing items until the queue is complete.
+ while (IRInst* workItem = workQueue->pop())
+ {
+ for(auto child = workItem->getFirstChild(); child; child = child->getNextInst())
+ {
+ // Either the child instruction has more children (func/block etc..)
+ // and we add it to the work list for further processing, or
+ // it's an ordinary inst in which case we check if it's a JVPDifferentiate
+ // instruction.
+ //
+ if (child->getFirstChild() != nullptr)
+ workQueue->push(child);
+
+ if (auto jvpDiffInst = as<IRJVPDifferentiate>(child))
+ {
+ auto baseFunction = jvpDiffInst->getBaseFn();
+ // If the JVP Reference already exists, no need to
+ // differentiate again.
+ //
+ if(lookupJVPReference(baseFunction)) continue;
+
+ if (isFunctionMarkedForJVP(as<IRGlobalValueWithCode>(baseFunction)))
+ {
+ IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(baseFunction));
+ builder->addJVPDerivativeReferenceDecoration(baseFunction, jvpFunction);
+ workQueue->push(jvpFunction);
+ }
+ else
+ {
+ // TODO(Sai): This would probably be better with a more specific
+ // error code.
+ getSink()->diagnose(jvpDiffInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Cannot differentiate functions not marked for differentiation");
+ }
+ }
+ }
+ }
+
+ return true;
+ }
+
+ // Run through all the global-level instructions,
+ // looking for callables.
+ // Note: We're only processing global callables (IRGlobalValueWithCode)
+ // for now.
+ //
+ bool processMarkedGlobalFunctions(IRBuilder* builder)
+ {
for (auto inst : module->getGlobalInsts())
{
// If the instr is a callable, get all the basic blocks
@@ -340,7 +475,8 @@ struct JVPDerivativeContext
if (isFunctionMarkedForJVP(callable))
{
SLANG_ASSERT(as<IRFunc>(callable));
- IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable));
+
+ IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(callable));
builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction);
unmarkForJVP(callable);
@@ -485,10 +621,12 @@ struct JVPDerivativeContext
//
bool processJVPDerivativeMarkers(
IRModule* module,
+ DiagnosticSink* sink,
IRJVPDerivativePassOptions const&)
{
JVPDerivativeContext context;
context.module = module;
+ context.sink = sink;
return context.processModule();
}