diff options
| -rw-r--r-- | source/slang/slang-emit.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-call.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 57 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 28 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 6 |
9 files changed, 51 insertions, 68 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 896171f32..fcdee78ea 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -377,7 +377,7 @@ Result linkAndOptimizeIR( // Process higher-order calles to auto-diff passes. // 1. Generate JVP code wherever necessary. (Linearization or "forward-mode" pass) - processJVPDerivativeMarkers(irModule, sink); + processForwardDifferentiableFuncs(irModule, sink); // 2. Transpose JVP to VJP code wherever needed. (Transposition or "reverse-mode" pass) // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet. diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp index 34e7e3de0..a574d6b7e 100644 --- a/source/slang/slang-ir-diff-call.cpp +++ b/source/slang/slang-ir-diff-call.cpp @@ -62,9 +62,9 @@ struct DerivativeCallProcessContext // if (auto origSpecialize = as<IRSpecialize>(origCallable)) { - if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRForwardDerivativeDecoration>()) { - jvpCallable = jvpSpecRefDecorator->getJVPFunc(); + jvpCallable = jvpSpecRefDecorator->getForwardDerivativeFunc(); } } @@ -73,9 +73,9 @@ struct DerivativeCallProcessContext // Check for the 'JVPDerivativeReference' decorator on the // base function. // - if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto jvpRefDecorator = origCallable->findDecoration<IRForwardDerivativeDecoration>()) { - jvpCallable = jvpRefDecorator->getJVPFunc(); + jvpCallable = jvpRefDecorator->getForwardDerivativeFunc(); } SLANG_ASSERT(jvpCallable); diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 7e6fd30dd..73818dbb1 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -1116,12 +1116,12 @@ struct JVPTranscriber IRInst* diffCallee = nullptr; - if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>()) { // If the user has already provided an differentiated implementation, use that. - diffCallee = derivativeReferenceDecor->getJVPFunc(); + diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc(); } - else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>()) + else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>()) { // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass // to generate the implementation. @@ -1327,13 +1327,13 @@ struct JVPTranscriber auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>()) { - // Look for an IRJVPDerivativeReferenceDecoration on the specialize inst. + // Look for an IRForwardDerivativeDecoration on the specialize inst. // (Normally, this would be on the inner IRFunc, but in this case only the JVP func // can be specialized, so we put a decoration on the IRSpecialize) // - if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>()) { - auto jvpFunc = jvpFuncDecoration->getJVPFunc(); + auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc(); // Make sure this isn't itself a specialize . SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); @@ -1450,7 +1450,7 @@ struct JVPTranscriber IRInst* origBase = originalInst->getOperand(0); auto primalBase = findOrTranscribePrimalInst(builder, origBase); auto field = originalInst->getOperand(1); - auto derivativeRefDecor = field->findDecoration<IRJVPDerivativeMemberReferenceDecoration>(); + auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>(); auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType()); IRInst* primalOperands[] = { primalBase, field }; @@ -1957,8 +1957,8 @@ struct JVPDerivativeContext IRInst* lookupJVPReference(IRInst* primalFunction) { - if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>()) - return jvpDefinition->getJVPFunc(); + if(auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) + return jvpDefinition->getForwardDerivativeFunc(); return nullptr; } @@ -2010,13 +2010,13 @@ struct JVPDerivativeContext // if (lookupJVPReference(baseFunction)) continue; - if (isMarkedForJVP(baseFunction)) + if (isMarkedForForwardDifferentiation(baseFunction)) { if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) { IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction); SLANG_ASSERT(diffFunc); - builder->addJVPDerivativeReferenceDecoration(baseFunction, diffFunc); + builder->addForwardDerivativeDecoration(baseFunction, diffFunc); workQueue->push(diffFunc); } else @@ -2210,15 +2210,15 @@ struct JVPDerivativeContext } // Checks decorators to see if the function should - // be differentiated (kIROp_JVPDerivativeMarkerDecoration) + // be differentiated (kIROp_ForwardDifferentiableDecoration) // - bool isMarkedForJVP(IRGlobalValueWithCode* callable) + bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable) { for(auto decoration = callable->getFirstDecoration(); decoration; decoration = decoration->getNextDecoration()) { - if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration) + if (decoration->getOp() == kIROp_ForwardDifferentiableDecoration) { return true; } @@ -2226,24 +2226,7 @@ struct JVPDerivativeContext return false; } - // Removes the JVPDerivativeMarkerDecoration from the provided callable, - // if it exists. - // - void unmarkForJVP(IRGlobalValueWithCode* callable) - { - for(auto decoration = callable->getFirstDecoration(); - decoration; - decoration = decoration->getNextDecoration()) - { - if (decoration->getOp() == kIROp_JVPDerivativeMarkerDecoration) - { - decoration->removeAndDeallocate(); - return; - } - } - } - - IRStringLit* getJVPFuncName(IRBuilder* builder, + IRStringLit* getForwardDerivativeFuncName(IRBuilder* builder, IRInst* func) { auto oldLoc = builder->getInsertLoc(); @@ -2252,11 +2235,11 @@ struct JVPDerivativeContext IRStringLit* name = nullptr; if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) { - name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_jvp").getUnownedSlice()); + name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice()); } else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) { - name = builder->getStringValue((String(namehintDecoration->getName()) + "_jvp").getUnownedSlice()); + name = builder->getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice()); } builder->setInsertLoc(oldLoc); @@ -2309,7 +2292,7 @@ struct JVPDerivativeContext // Set up context and call main process method. // -bool processJVPDerivativeMarkers( +bool processForwardDifferentiableFuncs( IRModule* module, DiagnosticSink* sink, IRJVPDerivativePassOptions const&) @@ -2335,8 +2318,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) auto next = decor->getNextDecoration(); switch (decor->getOp()) { - case kIROp_JVPDerivativeReferenceDecoration: - case kIROp_JVPDerivativeMemberReferenceDecoration: + case kIROp_ForwardDerivativeDecoration: + case kIROp_DerivativeMemberDecoration: decor->removeAndDeallocate(); break; default: diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h index 8ab4e0e8f..01ac15d6c 100644 --- a/source/slang/slang-ir-diff-jvp.h +++ b/source/slang/slang-ir-diff-jvp.h @@ -13,7 +13,7 @@ namespace Slang // Nothing for now.. }; - bool processJVPDerivativeMarkers( + bool processForwardDifferentiableFuncs( IRModule* module, DiagnosticSink* sink, IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions()); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index ccde80476..1d1db14f9 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -700,15 +700,15 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(SPIRVOpDecoration, spirvOpDecoration, 1, 0) /// Decorated function is marked for the forward-mode differentiation pass. - INST(JVPDerivativeMarkerDecoration, differentiateJvp, 0, 0) + INST(ForwardDifferentiableDecoration, forwardDifferentiable, 0, 0) /// Used by the auto-diff pass to hold a reference to the /// generated derivative function. - INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0) + INST(ForwardDerivativeDecoration, jvpFnReference, 1, 0) /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. - INST(JVPDerivativeMemberReferenceDecoration, derivativeMemberDecoration, 1, 0) + INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) /// Marks a class type as a COM interface implementation, which enables /// the witness table to be easily picked up by emit. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 95202d9d0..deb81134b 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -546,33 +546,33 @@ struct IRSequentialIDDecoration : IRDecoration IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); } }; -struct IRJVPDerivativeMarkerDecoration : IRDecoration +struct IRForwardDifferentiableDecoration : IRDecoration { enum { - kOp = kIROp_JVPDerivativeMarkerDecoration + kOp = kIROp_ForwardDifferentiableDecoration }; - IR_LEAF_ISA(JVPDerivativeMarkerDecoration) + IR_LEAF_ISA(ForwardDifferentiableDecoration) }; -struct IRJVPDerivativeReferenceDecoration : IRDecoration +struct IRForwardDerivativeDecoration : IRDecoration { enum { - kOp = kIROp_JVPDerivativeReferenceDecoration + kOp = kIROp_ForwardDerivativeDecoration }; - IR_LEAF_ISA(JVPDerivativeReferenceDecoration) + IR_LEAF_ISA(ForwardDerivativeDecoration) - IRInst* getJVPFunc() { return getOperand(0); } + IRInst* getForwardDerivativeFunc() { return getOperand(0); } }; -struct IRJVPDerivativeMemberReferenceDecoration : IRDecoration +struct IRDerivativeMemberDecoration : IRDecoration { enum { - kOp = kIROp_JVPDerivativeMemberReferenceDecoration + kOp = kIROp_DerivativeMemberDecoration }; - IR_LEAF_ISA(JVPDerivativeMemberReferenceDecoration) + IR_LEAF_ISA(DerivativeMemberDecoration) IRInst* getDerivativeMemberStructKey() { return getOperand(0); } }; @@ -3206,14 +3206,14 @@ public: addDecoration(value, kIROp_ForceInlineDecoration); } - void addJVPDerivativeMarkerDecoration(IRInst* value) + void addForwardDifferentiableDecoration(IRInst* value) { - addDecoration(value, kIROp_JVPDerivativeMarkerDecoration); + addDecoration(value, kIROp_ForwardDifferentiableDecoration); } - void addJVPDerivativeReferenceDecoration(IRInst* value, IRInst* jvpFn) + void addForwardDerivativeDecoration(IRInst* value, IRInst* jvpFn) { - addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn); + addDecoration(value, kIROp_ForwardDerivativeDecoration, jvpFn); } void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable) diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 56688abae..eb899b69c 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -448,7 +448,7 @@ static void cloneExtraDecorations( case kIROp_LayoutDecoration: case kIROp_PublicDecoration: case kIROp_SequentialIDDecoration: - case kIROp_JVPDerivativeReferenceDecoration: + case kIROp_ForwardDerivativeDecoration: if(!clonedInst->findDecorationImpl(decoration->getOp())) { cloneInst(context, builder, decoration); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 53ea99a0c..406e5157c 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -388,15 +388,15 @@ struct SpecializationContext auto genericReturnVal = findInnerMostGenericReturnVal(genericVal); if (genericReturnVal->findDecoration<IRTargetIntrinsicDecoration>()) { - if (auto customDiffRef = genericReturnVal->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto customDiffRef = genericReturnVal->findDecoration<IRForwardDerivativeDecoration>()) { // If we already have a diff func on this specialize, skip. - if (auto specDiffRef = specInst->findDecoration<IRJVPDerivativeReferenceDecoration>()) + if (auto specDiffRef = specInst->findDecoration<IRForwardDerivativeDecoration>()) { return false; } - auto specDiffFunc = as<IRSpecialize>(customDiffRef->getJVPFunc()); + auto specDiffFunc = as<IRSpecialize>(customDiffRef->getForwardDerivativeFunc()); // If the base is specialized, the JVP version must be also be a specialized // generic. @@ -436,7 +436,7 @@ struct SpecializationContext addToWorkList(newDiffFuncType); addToWorkList(newDiffFunc); - builder.addJVPDerivativeReferenceDecoration(specInst, newDiffFunc); + builder.addForwardDerivativeDecoration(specInst, newDiffFunc); return true; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index acb7869e0..ae0590105 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7087,7 +7087,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; SLANG_RELEASE_ASSERT(as<IRStructKey>(key)); auto builder = getBuilder(); - builder->addDecoration(inst, kIROp_JVPDerivativeMemberReferenceDecoration, key); + builder->addDecoration(inst, kIROp_DerivativeMemberDecoration, key); } LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl) @@ -7807,7 +7807,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (decl->findModifier<ForwardDifferentiableAttribute>()) { - getBuilder()->addJVPDerivativeMarkerDecoration(irFunc); + getBuilder()->addForwardDifferentiableDecoration(irFunc); } // Always force inline diff setter accessor to prevent downstream compiler from complaining @@ -8222,7 +8222,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); IRInst* jvpFunc = loweredVal.val; - getBuilder()->addDecoration(irFunc, kIROp_JVPDerivativeReferenceDecoration, jvpFunc); + getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, jvpFunc); // Reset cursor. subContext->irBuilder->setInsertInto(irFunc); |
