diff options
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 184 |
1 files changed, 161 insertions, 23 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 2a42a7b6e..00210daaa 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -13,11 +13,20 @@ struct JVPTranscriber // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent // their differential values. - Dictionary<IRInst*, IRInst*> instMapD; + Dictionary<IRInst*, IRInst*> instMapD; // Cloning environment to hold mapping from old to new copies for the primal // instructions. - IRCloneEnv cloneEnv; + IRCloneEnv cloneEnv; + + // Diagnostic sink for error messages. + DiagnosticSink* sink; + + DiagnosticSink* getSink() + { + SLANG_ASSERT(sink); + return sink; + } void mapDifferentialInst(IRInst* instP, IRInst* instD) { @@ -156,7 +165,9 @@ struct JVPTranscriber rightP->getDataType(), rightP, rightP )); default: - SLANG_UNEXPECTED("Attempting to differentiate unsupported arithmetic"); + getSink()->diagnose(arith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); } } @@ -179,8 +190,11 @@ struct JVPTranscriber } return nullptr; } - - SLANG_UNEXPECTED("Attempting to differentiate an unsupported load instruction"); + else + getSink()->diagnose(loadP->sourceLoc, + Diagnostics::unimplemented, + "this load instruction cannot be differentiated"); + return nullptr; } IRInst* differentiateStore(IRBuilder* builder, IRStore* storeP) @@ -204,8 +218,11 @@ struct JVPTranscriber } return nullptr; } - - SLANG_UNEXPECTED("Attempting to differentiate an unsupported store instruction"); + else + getSink()->diagnose(storeP->sourceLoc, + Diagnostics::unimplemented, + "this store instruction cannot be differentiated"); + return nullptr; } IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP) @@ -226,12 +243,13 @@ struct JVPTranscriber // IRInst* differentiateConstruct(IRBuilder*, IRInst* consP) { - if (as<IRConstant>(consP->getOperand(0)) && consP->getOperandCount() == 1) - { return nullptr; - } - SLANG_UNEXPECTED("Attempting to differentiate unsupported constructor"); + else + getSink()->diagnose(consP->sourceLoc, + Diagnostics::unimplemented, + "this construct instruction cannot be differentiated"); + return nullptr; } // Logic for whether a primal instruction needs to be replicated @@ -242,13 +260,9 @@ struct JVPTranscriber bool requiresPrimalClone(IRBuilder*, IRInst* instP) { if (as<IRReturn>(instP)) - { return false; - } else - { return true; - } } IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP) @@ -295,11 +309,52 @@ struct JVPTranscriber return differentiateConstruct(builder, instP); default: - SLANG_UNEXPECTED("Attempting to differentiate unrecognized instruction"); + getSink()->diagnose(instP->sourceLoc, + Diagnostics::unimplemented, + "this instruction cannot be differentiated"); + return nullptr; } } }; +struct IRWorkQueue +{ + // Work list to hold the active set of insts whose children + // need to be looked at. + // + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + + void push(IRInst* inst) + { + if(!inst) return; + if(workListSet.Contains(inst)) return; + + workList.add(inst); + workListSet.Add(inst); + } + + IRInst* pop() + { + if (workList.getCount() != 0) + { + IRInst* topItem = workList.getFirst(); + // TODO(Sai): Repeatedly calling removeAt() can be really slow. + // Consider a specialized data structure or using removeLast() + // + workList.removeAt(0); + workListSet.Remove(topItem); + return topItem; + } + return nullptr; + } + + IRInst* peek() + { + return workList.getFirst(); + } +}; + struct JVPDerivativeContext { // This type passes over the module and generates @@ -315,6 +370,19 @@ struct JVPDerivativeContext // processing instructions while maintaining state. // JVPTranscriber transcriberStorage; + + // Diagnostic object from the compile request for + // error messages. + DiagnosticSink* sink; + + // Work queue to hold a stream of instructions that need + // to be checked for references to derivative functions. + IRWorkQueue workQueueStorage; + + DiagnosticSink* getSink() + { + return sink; + } bool processModule() { @@ -324,14 +392,81 @@ struct JVPDerivativeContext // SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); - - // Run through all the global-level instructions, - // looking for callables. - // Note: We're only processing global callables (IRGlobalValueWithCode) - // for now. - // + IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; + + // processMarkedGlobalFunctions(builder); + return processReferencedFunctions(builder); + } + + IRInst* lookupJVPReference(IRInst* primalFunction) + { + if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>()) + return jvpDefinition->getJVPFunc(); + + return nullptr; + } + + // Recursively process instructions looking for JVP calls (kIROp_JVPDifferentiate), + // then check that the referenced function is marked correctly for differentiation. + // + bool processReferencedFunctions(IRBuilder* builder) + { + IRWorkQueue* workQueue = &(workQueueStorage); + + // Put the top-level inst into the queue. + workQueue->push(module->getModuleInst()); + + // Keep processing items until the queue is complete. + while (IRInst* workItem = workQueue->pop()) + { + for(auto child = workItem->getFirstChild(); child; child = child->getNextInst()) + { + // Either the child instruction has more children (func/block etc..) + // and we add it to the work list for further processing, or + // it's an ordinary inst in which case we check if it's a JVPDifferentiate + // instruction. + // + if (child->getFirstChild() != nullptr) + workQueue->push(child); + + if (auto jvpDiffInst = as<IRJVPDifferentiate>(child)) + { + auto baseFunction = jvpDiffInst->getBaseFn(); + // If the JVP Reference already exists, no need to + // differentiate again. + // + if(lookupJVPReference(baseFunction)) continue; + + if (isFunctionMarkedForJVP(as<IRGlobalValueWithCode>(baseFunction))) + { + IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(baseFunction)); + builder->addJVPDerivativeReferenceDecoration(baseFunction, jvpFunction); + workQueue->push(jvpFunction); + } + else + { + // TODO(Sai): This would probably be better with a more specific + // error code. + getSink()->diagnose(jvpDiffInst->sourceLoc, + Diagnostics::internalCompilerError, + "Cannot differentiate functions not marked for differentiation"); + } + } + } + } + + return true; + } + + // Run through all the global-level instructions, + // looking for callables. + // Note: We're only processing global callables (IRGlobalValueWithCode) + // for now. + // + bool processMarkedGlobalFunctions(IRBuilder* builder) + { for (auto inst : module->getGlobalInsts()) { // If the instr is a callable, get all the basic blocks @@ -340,7 +475,8 @@ struct JVPDerivativeContext if (isFunctionMarkedForJVP(callable)) { SLANG_ASSERT(as<IRFunc>(callable)); - IRFunc* jvpFunction = emitJVPFunction(&builderStorage, as<IRFunc>(callable)); + + IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(callable)); builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction); unmarkForJVP(callable); @@ -485,10 +621,12 @@ struct JVPDerivativeContext // bool processJVPDerivativeMarkers( IRModule* module, + DiagnosticSink* sink, IRJVPDerivativePassOptions const&) { JVPDerivativeContext context; context.module = module; + context.sink = sink; return context.processModule(); } |
