diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-06 13:39:06 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-06 13:39:06 -0800 |
| commit | 33fb95980b0120cdd4d4f2d51f5f116e808dd4aa (patch) | |
| tree | 318b1669a0e52aabd11f8694de1278ef7dbc0e3b /source | |
| parent | e70cbe76ce74769069b7384f5f05c62da1ca45ed (diff) | |
Split bwd_diff op into separate ops for primal and propagate func. (#2582)
* Split bwd_diff op into separate ops for primal and propagate func.
* Fix.
* Download swiftshader with github actions instead of curl on linux.
* Fix github action.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
22 files changed, 929 insertions, 306 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index ccbac0286..81a6e3f7d 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -546,6 +546,21 @@ class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl SLANG_AST_CLASS(BackwardDerivativeRequirementDecl) }; +class BackwardDerivativePrimalRequirementDecl : public DerivativeRequirementDecl +{ + SLANG_AST_CLASS(BackwardDerivativePrimalRequirementDecl) +}; + +class BackwardDerivativePropagateRequirementDecl : public DerivativeRequirementDecl +{ + SLANG_AST_CLASS(BackwardDerivativePropagateRequirementDecl) +}; + +class BackwardDerivativeIntermediateTypeRequirementDecl : public DerivativeRequirementDecl +{ + SLANG_AST_CLASS(BackwardDerivativeIntermediateTypeRequirementDecl) +}; + bool isInterfaceRequirement(Decl* decl); InterfaceDecl* findParentInterfaceDecl(Decl* decl); diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 503d63a76..8e5192536 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -521,4 +521,19 @@ class BackwardDifferentiateVal : public DifferentiateVal SLANG_AST_CLASS(BackwardDifferentiateVal) }; +class BackwardDifferentiateIntermediateTypeVal : public DifferentiateVal +{ + SLANG_AST_CLASS(BackwardDifferentiateIntermediateTypeVal) +}; + +class BackwardDifferentiatePrimalVal : public DifferentiateVal +{ + SLANG_AST_CLASS(BackwardDifferentiatePrimalVal) +}; + +class BackwardDifferentiatePropagateVal : public DifferentiateVal +{ + SLANG_AST_CLASS(BackwardDifferentiatePropagateVal) +}; + } // namespace Slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 80bf74e53..7c8e320c4 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2665,10 +2665,28 @@ namespace Slang } else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) { - BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); + DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); val->func = satisfyingMemberDeclRef; witnessTable->add(bwdReq, RequirementWitness(val)); } + else if (auto primalReq = as<BackwardDerivativePrimalRequirementDecl>(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePrimalVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(primalReq, RequirementWitness(val)); + } + else if (auto propReq = as<BackwardDerivativePropagateRequirementDecl>(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePropagateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(propReq, RequirementWitness(val)); + } + else if (auto itypeReq = as<BackwardDerivativeIntermediateTypeRequirementDecl>(reqRefDecl->referencedDecl)) + { + DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateIntermediateTypeVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(itypeReq, RequirementWitness(val)); + } } witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef)); } @@ -5652,18 +5670,70 @@ namespace Slang } if (decl->hasModifier<BackwardDifferentiableAttribute>()) { - auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>(); - cloneModifiers(reqDecl, decl); + // Requirement for backward derivative. auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); - auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); - setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - - auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); + auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef))); + { + auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>(); + cloneModifiers(reqDecl, decl); + setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + // Requirement for backward derivative intermediate type. + auto intermediateTypeReqDecl = m_astBuilder->create<BackwardDerivativeIntermediateTypeRequirementDecl>(); + auto intermediateType = m_astBuilder->getOrCreateDeclRefType( + intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl)); + { + cloneModifiers(intermediateTypeReqDecl, decl); + interfaceDecl->members.add(intermediateTypeReqDecl); + intermediateTypeReqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = intermediateTypeReqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + // Requirement for backward derivative primal func. + { + auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>(); + cloneModifiers(reqDecl, decl); + FuncType* primalFuncType = m_astBuilder->create<FuncType>(); + primalFuncType->resultType = diffFuncType->resultType; + primalFuncType->paramTypes.addRange(diffFuncType->paramTypes); + auto outType = m_astBuilder->getOutType(intermediateType); + primalFuncType->paramTypes.add(outType); + setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + // Requirement for backward derivative propagate func. + { + auto reqDecl = m_astBuilder->create<BackwardDerivativePropagateRequirementDecl>(); + cloneModifiers(reqDecl, decl); + interfaceDecl->members.add(reqDecl); + reqDecl->parentDecl = interfaceDecl; + FuncType* propagateFuncType = m_astBuilder->create<FuncType>(); + propagateFuncType->resultType = diffFuncType->resultType; + propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes); + propagateFuncType->paramTypes.add(intermediateType); + setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType); + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); + reqRef->referencedDecl = reqDecl; + reqRef->parentDecl = decl; + decl->members.add(reqRef); + } + isDiffFunc = true; } if (isDiffFunc) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index c245701df..19678f402 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -11,8 +11,10 @@ namespace Slang { -IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) +IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) { + SLANG_UNUSED(func); + List<IRType*> newParameterTypes; IRType* diffReturnType; @@ -330,7 +332,8 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass // to generate the implementation. diffCallee = builder->emitForwardDifferentiateInst( - differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), + differentiateFunctionType( + builder, primalCallee, as<IRFuncType>(primalCallee->getFullType())), primalCallee); } @@ -615,8 +618,16 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec { args.add(primalSpecialize->getArg(i)); } + + // A `ForwardDerivative` decoration on an inner func of a generic should always be a `specialize`. + auto diffBaseSpecialize = as<IRSpecialize>(diffBase); + SLANG_RELEASE_ASSERT(diffBaseSpecialize); + + // Note: this assumes that the generic arguments to specialize the derivative is the same as the + // generic args to specialize the primal function. This is true for all of our stdlib functions, + // but we may need to rely on more general substitution logic here. auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>()) @@ -933,6 +944,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); IRType* diffFuncType = this->differentiateFunctionType( &builder, + origFunc, as<IRFuncType>(origFunc->getFullType())); diffFunc->setFullType(diffFuncType); @@ -943,7 +955,17 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu newNameSb << "s_fwd_" << originalName; builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); } - builder.addForwardDerivativeDecoration(origFunc, diffFunc); + + if (auto outerGen = findOuterGeneric(diffFunc)) + { + auto specialized = + specializeWithGeneric(builder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc))); + builder.addForwardDerivativeDecoration(origFunc, specialized); + } + else + { + builder.addForwardDerivativeDecoration(origFunc, diffFunc); + } // Mark the generated derivative function itself as differentiable. builder.addForwardDifferentiableDecoration(diffFunc); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 22ebf9d95..869b25ffd 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -73,7 +73,7 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeWrapExistential(IRBuilder* builder, IRInst* origInst); - virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) override; + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; // Transcribe a function definition. InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index ae9b69f61..b6704011c 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -11,7 +11,7 @@ namespace Slang { - IRFuncType* BackwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermeidateType) { List<IRType*> newParameterTypes; IRType* diffReturnType; @@ -46,12 +46,53 @@ namespace Slang newParameterTypes.add(differentiateType(builder, funcType->getResultType())); + if (intermeidateType) + { + newParameterTypes.add((IRType*)intermeidateType); + } + diffReturnType = builder->getVoidType(); return builder->getFuncType(newParameterTypes, diffReturnType); } + + IRFuncType* BackwardDiffPrimalTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) + { + auto intermediateType = builder->getBackwardDiffIntermediateContextType(func); + auto outType = builder->getOutType(intermediateType); + return differentiateFunctionTypeImpl(builder, funcType, outType); + } + + InstPair BackwardDiffPrimalTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) + { + SLANG_UNUSED(builder); + SLANG_UNUSED(diffFunc); + auto intermediateTypeDecor = primalFunc->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>(); + SLANG_RELEASE_ASSERT(intermediateTypeDecor); + auto primalDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>(); + return InstPair(primalFunc, primalDecor->getBackwardDerivativePrimalFunc()); + } + + IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) + { + auto intermediateType = builder->getBackwardDiffIntermediateContextType(func); + return differentiateFunctionTypeImpl(builder, funcType, intermediateType); + } + + IRFuncType* BackwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) + { + SLANG_UNUSED(func); + return differentiateFunctionTypeImpl(builder, funcType, nullptr); + } + + InstPair BackwardDiffPropagateTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) + { + IRGlobalValueWithCode* diffPrimalFunc = nullptr; + transcribeFuncImpl(builder, primalFunc, diffFunc, diffPrimalFunc); + return InstPair(primalFunc, diffFunc); + } - InstPair BackwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* origInst) + InstPair BackwardDiffTranscriberBase::transcribeInstImpl(IRBuilder* builder, IRInst* origInst) { switch (origInst->getOp()) { @@ -90,7 +131,7 @@ namespace Slang // Returns "dp<var-name>" to use as a name hint for parameters. // If no primal name is available, returns a blank string. // - String BackwardDiffTranscriber::makeDiffPairName(IRInst* origVar) + String BackwardDiffTranscriberBase::makeDiffPairName(IRInst* origVar) { if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) { @@ -100,47 +141,7 @@ namespace Slang return String(""); } - - // In differential computation, the 'default' differential value is always zero. - // This is a consequence of differential computing being inherently linear. As a - // result, it's useful to have a method to generate zero literals of any (arithmetic) type. - // The current implementation requires that types are defined linearly. - // - IRInst* BackwardDiffTranscriber::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) - { - if (auto diffType = differentiateType(builder, primalType)) - { - switch (diffType->getOp()) - { - case kIROp_DifferentialPairType: - return builder->emitMakeDifferentialPair( - diffType, - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); - } - // Since primalType has a corresponding differential type, we can lookup the - // definition for zero(). - auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); - SLANG_ASSERT(zeroMethod); - - auto emptyArgList = List<IRInst*>(); - return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); - } - else - { - if (isScalarIntegerType(primalType)) - { - return builder->getIntValue(primalType, 0); - } - - getSink()->diagnose(primalType->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; - } - } - - InstPair BackwardDiffTranscriber::transposeBlock(IRBuilder* builder, IRBlock* origBlock) + InstPair BackwardDiffTranscriberBase::transposeBlock(IRBuilder* builder, IRBlock* origBlock) { IRBuilder subBuilder(builder->getSharedBuilder()); subBuilder.setInsertLoc(builder->getInsertLoc()); @@ -194,10 +195,10 @@ namespace Slang } // Create an empty func to represent the transcribed func of `origFunc`. - InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) + InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc) { - if (auto bwdDecor = origFunc->findDecoration<IRBackwardDerivativeDecoration>()) - return InstPair(origFunc, bwdDecor->getBackwardDerivativeFunc()); + if (auto bwdDiffFunc = findExistingDiffFunc(origFunc)) + return InstPair(origFunc, bwdDiffFunc); if (!isMarkedForBackwardDifferentiation(origFunc)) return InstPair(nullptr, nullptr); @@ -216,6 +217,7 @@ namespace Slang SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); IRType* diffFuncType = this->differentiateFunctionType( &builder, + origFunc, as<IRFuncType>(origFunc->getFullType())); diffFunc->setFullType(diffFuncType); @@ -226,7 +228,18 @@ namespace Slang newNameSb << "s_bwd_" << originalName; builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); } - builder.addBackwardDerivativeDecoration(origFunc, diffFunc); + + if (auto outerGen = findOuterGeneric(diffFunc)) + { + builder.setInsertBefore(origFunc); + auto specialized = + specializeWithGeneric(builder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc))); + addExistingDiffFuncDecor(&builder, origFunc, specialized); + } + else + { + addExistingDiffFuncDecor(&builder, origFunc, diffFunc); + } // Mark the generated derivative function itself as differentiable. builder.addBackwardDifferentiableDecoration(diffFunc); @@ -237,17 +250,61 @@ namespace Slang cloneDecoration(dictDecor, diffFunc); } + return InstPair(primalFunc, diffFunc); + } + + InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) + { + auto result = transcribeFuncHeaderImpl(inBuilder, origFunc); + FuncBodyTranscriptionTask task; - task.originalFunc = primalFunc; - task.resultFunc = diffFunc; - task.type = FuncBodyTranscriptionTaskType::Backward; - autoDiffSharedContext->followUpFunctionsToTranscribe.add(task); + task.originalFunc = as<IRFunc>(result.primal); + task.resultFunc = as<IRFunc>(result.differential); + task.type = diffTaskType; + if (task.resultFunc) + { + autoDiffSharedContext->followUpFunctionsToTranscribe.add(task); + } + return result; + } - return InstPair(primalFunc, diffFunc); + InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) + { + auto header = transcribeFuncHeaderImpl(inBuilder, origFunc); + if (!header.differential) + return header; + + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertInto(header.differential); + builder.emitBlock(); + auto funcType = as<IRFuncType>(header.differential->getDataType()); + List<IRInst*> args; + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + auto paramType = funcType->getParamType(i); + args.add(builder.emitParam(paramType)); + } + auto outerGeneric = findOuterGeneric(origFunc); + IRInst* specializedOriginalFunc = origFunc; + if (outerGeneric) + { + specializedOriginalFunc = maybeSpecializeWithGeneric(builder, outerGeneric, findOuterGeneric(header.differential)); + } + auto intermediateType = builder.getBackwardDiffIntermediateContextType(specializedOriginalFunc); + auto intermediateVar = builder.emitVar(intermediateType); + auto primalFunc = builder.emitBackwardDifferentiatePrimalInst(builder.getTypeKind(), specializedOriginalFunc); + auto propagateFunc = builder.emitBackwardDifferentiatePropagateInst(builder.getTypeKind(), specializedOriginalFunc); + args.add(intermediateVar); + builder.emitCallInst(builder.getVoidType(), primalFunc, args); + args.removeLast(); + args.add(builder.emitLoad(intermediateVar)); + builder.emitCallInst(builder.getVoidType(), propagateFunc, args); + builder.emitReturn(); + return header; } // Puts parameters into their own block. - void BackwardDiffTranscriber::makeParameterBlock(IRBuilder* inBuilder, IRFunc* func) + void BackwardDiffTranscriberBase::makeParameterBlock(IRBuilder* inBuilder, IRFunc* func) { IRBuilder builder(inBuilder->getSharedBuilder()); @@ -282,7 +339,7 @@ namespace Slang builder.emitBranch(firstBlock); } - void BackwardDiffTranscriber::cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType) + void BackwardDiffTranscriberBase::cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType) { IRStructType* structType = as<IRStructType>(intermediateType); if (!structType) @@ -375,22 +432,21 @@ namespace Slang } // Transcribe a function definition. - InstPair BackwardDiffTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) + void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc) { SLANG_ASSERT(primalFunc); - SLANG_ASSERT(diffFunc); + SLANG_ASSERT(diffPropagateFunc); // Reverse-mode transcription uses 4 separate steps: // TODO(sai): Fill in documentation. // Generate a temporary forward derivative function as an intermediate step. IRBuilder tempBuilder = *builder; - tempBuilder.setInsertBefore(diffFunc); - IRFunc* fwdDiffFunc = as<IRFunc>(fwdDiffTranscriber->transcribeFuncHeader(&tempBuilder, (IRFunc*)primalFunc).differential); + tempBuilder.setInsertBefore(diffPropagateFunc); + IRFunc* fwdDiffFunc = as<IRFunc>( + fwdDiffTranscriber->transcribeFuncHeader(&tempBuilder, primalFunc).differential); SLANG_ASSERT(fwdDiffFunc); - // Transcribe the body of the primal function into it's linear (fwd-diff) form. - // TODO(sai): Handle the case when we already have a user-defined fwd-derivative function. - fwdDiffTranscriber->transcribeFunc(&tempBuilder, primalFunc, as<IRFunc>(fwdDiffFunc)); + fwdDiffTranscriber->transcribeFunc(&tempBuilder, primalFunc, fwdDiffFunc); // Split first block into a paramter block. this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc)); @@ -416,7 +472,7 @@ namespace Slang // only blocks, and right now there's no provision in slang-ir-clone.h // for that. // - builder->setInsertInto(diffFunc->getParent()); + builder->setInsertInto(diffPropagateFunc->getParent()); auto tempDiffFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, unzippedFwdDiffFunc)); // Move blocks to the diffFunc shell. @@ -426,37 +482,63 @@ namespace Slang workList.add(block); for (auto block : workList) - block->insertAtEnd(diffFunc); + block->insertAtEnd(diffPropagateFunc); } // Transpose the first block (parameter block) - transposeParameterBlock(builder, diffFunc); + transposeParameterBlock(builder, diffPropagateFunc); - builder->setInsertInto(diffFunc); + builder->setInsertInto(diffPropagateFunc); - auto dOutParameter = diffFunc->getLastParam(); + auto dOutParameter = diffPropagateFunc->getLastParam()->getPrevParam(); // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr}; - diffTransposePass->transposeDiffBlocksInFunc(diffFunc, info); + diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); // Extracts the primal computations into its own func, and replace the primal insts // with the intermediate results computed from the extracted func. IRInst* intermediateType = nullptr; - auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType); + auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffPropagateFunc, unzippedFwdDiffFunc, intermediateType); // Clean up by deallocating intermediate versions. tempDiffFunc->removeAndDeallocate(); unzippedFwdDiffFunc->removeAndDeallocate(); fwdDiffFunc->removeAndDeallocate(); - eliminateDeadCode(diffFunc); - cleanUpUnusedPrimalIntermediate(diffFunc, extractedPrimalFunc, intermediateType); - - return InstPair(primalFunc, diffFunc); + eliminateDeadCode(diffPropagateFunc); + cleanUpUnusedPrimalIntermediate(diffPropagateFunc, extractedPrimalFunc, intermediateType); + + // If primal function is nested in a generic, we want to create separate generics for all the associated things + // we have just created. + auto primalOuterGeneric = findOuterGeneric(primalFunc); + IRInst* specializedFunc = nullptr; + auto intermediateTypeGeneric = hoistValueFromGeneric(*builder, intermediateType, specializedFunc); + auto specializedIntermeidateType = maybeSpecializeWithGeneric(*builder, intermediateTypeGeneric, primalOuterGeneric); + builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, specializedIntermeidateType); + + auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc); + builder->setInsertBefore(primalFunc); + + if (auto existingDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>()) + { + // If we already created a header for primal func, move the body into the existing primal func header. + auto existingPrimalHeader = existingDecor->getBackwardDerivativePrimalFunc(); + if (auto spec = as<IRSpecialize>(existingPrimalHeader)) + existingPrimalHeader = spec->getBase(); + moveInstChildren(existingPrimalHeader, primalFuncGeneric); + primalFuncGeneric->replaceUsesWith(existingPrimalHeader); + primalFuncGeneric->removeAndDeallocate(); + } + else + { + auto specializedBackwardPrimalFunc = maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric); + builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc); + } + diffPrimalFunc = as<IRGlobalValueWithCode>(primalOuterGeneric); } - void BackwardDiffTranscriber::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc) + void BackwardDiffTranscriberBase::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc) { IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock(); @@ -499,16 +581,19 @@ namespace Slang auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); // 2. Add a parameter for 'derivative of the output' (d_out). - // The type is the last parameter type of the function. + // The type is the second last parameter type of the function. // - auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1); + auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 2); SLANG_ASSERT(dOutParamType); builder->emitParam(dOutParamType); + + // Add a parameter for intermediate val. + builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1)); } - IRInst* BackwardDiffTranscriber::copyParam(IRBuilder* builder, IRParam* origParam) + IRInst* BackwardDiffTranscriberBase::copyParam(IRBuilder* builder, IRParam* origParam) { auto primalDataType = origParam->getDataType(); @@ -533,7 +618,7 @@ namespace Slang return cloneInst(&cloneEnv, builder, origParam); } - InstPair BackwardDiffTranscriber::copyBinaryArith(IRBuilder* builder, IRInst* origArith) + InstPair BackwardDiffTranscriberBase::copyBinaryArith(IRBuilder* builder, IRInst* origArith) { SLANG_ASSERT(origArith->getOperandCount() == 2); @@ -577,7 +662,7 @@ namespace Slang return InstPair(newInst, nullptr); } - IRInst* BackwardDiffTranscriber::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) + IRInst* BackwardDiffTranscriberBase::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) { SLANG_ASSERT(origArith->getOperandCount() == 2); @@ -645,7 +730,7 @@ namespace Slang return nullptr; } - InstPair BackwardDiffTranscriber::copyInst(IRBuilder* builder, IRInst* origInst) + InstPair BackwardDiffTranscriberBase::copyInst(IRBuilder* builder, IRInst* origInst) { // Handle common SSA-style operations switch (origInst->getOp()) @@ -670,7 +755,7 @@ namespace Slang return InstPair(nullptr, nullptr); } - IRInst* BackwardDiffTranscriber::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) + IRInst* BackwardDiffTranscriberBase::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) { IRInOutType* inoutParam = as<IRInOutType>(param->getDataType()); auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType()); @@ -687,7 +772,7 @@ namespace Slang return store; } - IRInst* BackwardDiffTranscriber::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) + IRInst* BackwardDiffTranscriberBase::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) { // Handle common SSA-style operations switch (origInst->getOp()) @@ -727,7 +812,7 @@ namespace Slang return nullptr; } - InstPair BackwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) + InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) { auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase()); List<IRInst*> primalArgs; @@ -739,8 +824,7 @@ namespace Slang auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst( (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer()); - IRInst* diffBase = nullptr; - if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase)) + if (auto diffBase = instMapD.TryGetValue(origSpecialize->getBase())) { List<IRInst*> args; for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) @@ -748,7 +832,7 @@ namespace Slang args.add(primalSpecialize->getArg(i)); } auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + builder->getTypeKind(), *diffBase, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } @@ -757,25 +841,31 @@ namespace Slang // (Normally, this would be on the inner IRFunc, but in this case only the JVP func // can be specialized, so we put a decoration on the IRSpecialize) // - if (auto backDecor = origSpecialize->findDecoration<IRBackwardDerivativeDecoration>()) + if (auto derivativeFunc = findExistingDiffFunc(origSpecialize)) { - auto derivativeFunc = backDecor->getBackwardDerivativeFunc(); - // Make sure this isn't itself a specialize . SLANG_RELEASE_ASSERT(!as<IRSpecialize>(derivativeFunc)); return InstPair(primalSpecialize, derivativeFunc); } - else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRBackwardDerivativeDecoration>()) + else if (auto diffBase = findExistingDiffFunc(genericInnerVal)) { - diffBase = derivativeDecoration->getBackwardDerivativeFunc(); List<IRInst*> args; for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) { args.add(primalSpecialize->getArg(i)); } + + // A `BackwardDerivative` decoration on an inner func of a generic should always be a `specialize`. + auto diffBaseSpecialize = as<IRSpecialize>(diffBase); + SLANG_RELEASE_ASSERT(diffBaseSpecialize); + + // Note: this assumes that the generic arguments to specialize the derivative is the same as the + // generic args to specialize the primal function. This is true for all of our stdlib functions, + // but we may need to rely on more general substitution logic here. auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); } else if (auto diffDecor = genericInnerVal->findDecoration<IRBackwardDifferentiableDecoration>()) @@ -785,9 +875,9 @@ namespace Slang { args.add(primalSpecialize->getArg(i)); } - diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); + auto diffCallee = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + builder->getTypeKind(), diffCallee, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } else diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index f9ca6110c..378300789 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -20,8 +20,10 @@ struct IRReverseDerivativePassOptions // Nothing for now.. }; -struct BackwardDiffTranscriber : AutoDiffTranscriberBase +struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase { + FuncBodyTranscriptionTaskType diffTaskType; + // Map that stores the upper gradient given an IRInst* Dictionary<IRInst*, List<IRInst*>> upperGradients; Dictionary<IRInst*, IRInst*> primalToDiffPair; @@ -38,8 +40,9 @@ struct BackwardDiffTranscriber : AutoDiffTranscriberBase DiffPropagationPass diffPropagationPassStorage; DiffUnzipPass diffUnzipPassStorage; - BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType taskType, AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) : AutoDiffTranscriberBase(shared, inSharedBuilder, inSink) + , diffTaskType(taskType) , diffTransposePassStorage(shared) , diffPropagationPassStorage(shared) , diffUnzipPassStorage(shared) @@ -52,13 +55,8 @@ struct BackwardDiffTranscriber : AutoDiffTranscriberBase // If no primal name is available, returns a blank string. // String makeDiffPairName(IRInst* origVar); - - // In differential computation, the 'default' differential value is always zero. - // This is a consequence of differential computing being inherently linear. As a - // result, it's useful to have a method to generate zero literals of any (arithmetic) type. - // The current implementation requires that types are defined linearly. - // - IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType); + + IRFuncType* differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermediateType); InstPair transposeBlock(IRBuilder* builder, IRBlock* origBlock); @@ -68,7 +66,7 @@ struct BackwardDiffTranscriber : AutoDiffTranscriberBase void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType); // Transcribe a function definition. - InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc); + virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0; void transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc); @@ -86,18 +84,98 @@ struct BackwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize); - // Create an empty func to represent the transcribed func of `origFunc`. - virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; + void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc); - virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) override; + InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); + + virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override; + virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) = 0; + virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) = 0; + virtual IROp getDifferentiableMethodDictionaryItemOp() override { - return kIROp_ForwardDifferentiableMethodRequirementDictionaryItem; + return kIROp_BackwardDifferentiableMethodRequirementDictionaryItem; + } +}; + +struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase +{ + BackwardDiffPrimalTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::BackwardPrimal, shared, inSharedBuilder, inSink) + { } + + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; + virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override; + virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override + { + if (auto backDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>()) + { + return backDecor->getBackwardDerivativePrimalFunc(); + } + return nullptr; + } + virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override + { + builder->addBackwardDerivativePrimalDecoration(inst, diffFunc); + } +}; + +struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase +{ + BackwardDiffPropagateTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::BackwardPropagate, shared, inSharedBuilder, inSink) + { } + + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; + virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override; + virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override + { + if (auto backDecor = originalFunc->findDecoration<IRBackwardDerivativePropagateDecoration>()) + { + return backDecor->getBackwardDerivativePropagateFunc(); + } + return nullptr; } + virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override + { + builder->addBackwardDerivativePropagateDecoration(inst, diffFunc); + } +}; + +// A backward derivative function combines both primal + propagate functions and accepts no +// intermediate value input. +struct BackwardDiffTranscriber : BackwardDiffTranscriberBase +{ + BackwardDiffTranscriber( + AutoDiffSharedContext* shared, + SharedIRBuilder* inSharedBuilder, + DiagnosticSink* inSink) + : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::Backward, shared, inSharedBuilder, inSink) + { } + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; + virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; + virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override + { + SLANG_UNUSED(builder); + // Don't need to do anything here, the body is generated in transcribeFuncHeader. + return InstPair(primalFunc, diffFunc); + } + virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override + { + if (auto backDecor = originalFunc->findDecoration<IRBackwardDerivativeDecoration>()) + { + return backDecor->getBackwardDerivativeFunc(); + } + return nullptr; + } + virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override + { + builder->addBackwardDerivativeDecoration(inst, diffFunc); + } }; } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 69cef941c..4aab0f835 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -259,7 +259,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy } case kIROp_FuncType: - return differentiateFunctionType(builder, as<IRFuncType>(primalType)); + return differentiateFunctionType(builder, nullptr, as<IRFuncType>(primalType)); case kIROp_OutType: if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) @@ -436,7 +436,7 @@ InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* o { auto primalDataType = findOrTranscribePrimalInst(builder, origParam->getDataType()); // Do not differentiate generic type (and witness table) parameters - if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) + if (isGenericParam(origParam)) { return InstPair( cloneInst(&cloneEnv, builder, origParam), diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index 8e4b7a901..4c3bbe05f 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -116,7 +116,7 @@ struct AutoDiffTranscriberBase IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType); - virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) = 0; + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) = 0; // Create an empty func to represent the transcribed func of `origFunc`. virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) = 0; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 2fd53dbd0..1496ae60f 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -72,23 +72,11 @@ struct ExtractPrimalFuncContext IRFuncType* originalFuncType = nullptr; outIntermediateType = createIntermediateType(destFunc); - if (auto gen = as<IRGeneric>(destFunc)) - { - auto func = findGenericReturnVal(gen); - builder.setInsertBefore(func); - outIntermediateType = - specializeWithGeneric(builder, outIntermediateType, gen); - SLANG_RELEASE_ASSERT(func); - originalFuncType = as<IRFuncType>(as<IRGeneric>(fwdFunc)->getDataType()); - } - else - { - originalFuncType = as<IRFuncType>(fwdFunc->getDataType()); - } + originalFuncType = as<IRFuncType>(fwdFunc->getDataType()); SLANG_RELEASE_ASSERT(originalFuncType); List<IRType*> paramTypes; - for (UInt i = 0; i < originalFuncType->getParamCount(); i++) + for (UInt i = 0; i < originalFuncType->getParamCount() - 1; i++) paramTypes.add(originalFuncType->getParamType(i)); paramTypes.add(builder.getInOutType((IRType*)outIntermediateType)); auto newFuncType = builder.getFuncType(paramTypes, builder.getVoidType()); @@ -243,75 +231,9 @@ struct ExtractPrimalFuncContext return true; } - // Given a `genericA<Param1, Param1,...> { instX(Param1, Param2) }`, - // and a clone of it `genericB<ParamB_1, ParamB_2,...> { }`. - // `GenericChildrenMigrationContext(genericA, genericB)::getCorrespondingInst(instX)` - // returns a clone of `instX` in `genericB` that references the new generic params - // as `instX_clone` in `genericB<ParamB_1, ParamB_2,...> { instX_clone(ParamB_1, ParamB_2) }`. - struct GenericChildrenMigrationContext - { - IRCloneEnv cloneEnv; - IRGeneric* oldGeneric = nullptr; - IRGeneric* newGeneric = nullptr; - IRInst* newGenericRetVal = nullptr; - - void init(IRGeneric* oldGen, IRGeneric* newGen) - { - oldGeneric = oldGen; - newGeneric = newGen; - newGenericRetVal = findGenericReturnVal(newGen); - - IRInst* oldParam = oldGen->getFirstParam(); - IRInst* newParam = newGen->getFirstParam(); - while (oldParam) - { - oldParam = as<IRParam>(oldParam->getNextInst()); - newParam = as<IRParam>(newParam->getNextInst()); - if (!oldParam) - { - SLANG_RELEASE_ASSERT(!newParam); - break; - } - SLANG_RELEASE_ASSERT(newParam); - cloneEnv.mapOldValToNew[oldParam] = newParam; - } - } - IRInst* getCorrespondingInst(IRBuilder& builder, IRInst* oldChild) - { - if (!oldGeneric) - return oldChild; - auto parent = oldChild->getParent(); - bool found = false; - while (parent) - { - if (parent == oldGeneric) - { - found = true; - break; - } - parent = parent->getParent(); - } - if (!found) - return oldChild; - for (UInt i = 0; i < oldChild->getOperandCount(); i++) - { - auto operand = oldChild->getOperand(i); - if (cloneEnv.mapOldValToNew.ContainsKey(operand)) - {} - else - { - getCorrespondingInst(builder, operand); - } - } - auto cloned = cloneInst(&cloneEnv, &builder, oldChild); - return cloned; - } - }; - void storeInst( IRBuilder& builder, IRInst* inst, - GenericChildrenMigrationContext& genericContext, IRInst* intermediateOutput) { IRBuilder genTypeBuilder(sharedBuilder); @@ -319,7 +241,7 @@ struct ExtractPrimalFuncContext SLANG_RELEASE_ASSERT(ptrStructType); auto structType = as<IRStructType>(ptrStructType->getValueType()); genTypeBuilder.setInsertBefore(structType); - auto fieldType = genericContext.getCorrespondingInst(genTypeBuilder, inst->getDataType()); + auto fieldType = inst->getDataType(); SLANG_RELEASE_ASSERT(structType); auto structKey = genTypeBuilder.createStructKey(); if (auto nameHint = inst->findDecoration<IRNameHintDecoration>()) @@ -333,30 +255,16 @@ struct ExtractPrimalFuncContext inst); } - IRGlobalValueWithCode* turnUnzippedFuncIntoPrimalFunc(IRGlobalValueWithCode* unzippedFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType) + IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* fwdFunc, IRInst*& outIntermediateType) { // Note: this transformation assumes the original func has only one return. IRBuilder builder(sharedBuilder); - IRFunc* func = nullptr; + IRFunc* func = unzippedFunc; IRInst* intermediateType = nullptr; auto newFuncType = generatePrimalFuncType(unzippedFunc, fwdFunc, intermediateType); - if (auto gen = as<IRGeneric>(unzippedFunc)) - { - func = as<IRFunc>(findGenericReturnVal(gen)); - SLANG_RELEASE_ASSERT(func); - builder.setInsertBefore(func); - auto spec = as<IRSpecialize>(intermediateType); - SLANG_RELEASE_ASSERT(spec); - outIntermediateType = spec->getBase(); - } - else - { - func = as<IRFunc>(unzippedFunc); - SLANG_RELEASE_ASSERT(func); - outIntermediateType = intermediateType; - } + outIntermediateType = intermediateType; func->setFullType((IRType*)newFuncType); // Go through all the insts and preserve the primal blocks. @@ -375,19 +283,14 @@ struct ExtractPrimalFuncContext auto paramBlock = func->getFirstBlock(); builder.setInsertInto(paramBlock); + auto oldIntermediateParam = func->getLastParam(); auto outIntermediary = builder.emitParam(builder.getInOutType((IRType*)intermediateType)); + oldIntermediateParam->replaceUsesWith(outIntermediary); + oldIntermediateParam->removeAndDeallocate(); auto firstBlock = *(paramBlock->getSuccessors().begin()); - GenericChildrenMigrationContext genericMigrationContext; - if (auto gen = as<IRGeneric>(unzippedFunc)) - { - auto spec = as<IRSpecialize>(intermediateType); - SLANG_RELEASE_ASSERT(spec); - genericMigrationContext.init(gen, as<IRGeneric>(spec->getBase())); - } - List<IRBlock*> diffBlocksList; List<IRBlock*> primalBlocksList; @@ -412,7 +315,7 @@ struct ExtractPrimalFuncContext if (shouldStoreInst(inst)) { builder.setInsertAfter(inst); - storeInst(builder, inst, genericMigrationContext, outIntermediary); + storeInst(builder, inst, outIntermediary); } } } @@ -482,8 +385,8 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE } } -IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc( - IRGlobalValueWithCode* func, IRGlobalValueWithCode* fwdFunc, IRInst*& intermediateType) +IRFunc* DiffUnzipPass::extractPrimalFunc( + IRFunc* func, IRFunc* fwdFunc, IRInst*& intermediateType) { IRBuilder builder(this->autodiffContext->sharedBuilder); builder.setInsertBefore(func); @@ -491,46 +394,31 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc( IRCloneEnv subEnv; subEnv.squashChildrenMapping = true; subEnv.parent = &cloneEnv; - auto clonedFunc = as<IRGlobalValueWithCode>(cloneInst(&subEnv, &builder, func)); + auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func)); ExtractPrimalFuncContext context; context.init(autodiffContext->sharedBuilder); intermediateType = nullptr; auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, fwdFunc, intermediateType); - IRInst* specializedPrimalFunc = primalFunc; - - // Copy PrimalValueStructKey decorations from primal func. - copyPrimalValueStructKeyDecorations(func, subEnv); - - IRInst* specializedIntermediateType = intermediateType; - auto innerFunc = as<IRFunc>(func); - if (auto genFunc = as<IRGeneric>(func)) + if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>()) { - innerFunc = as<IRFunc>(findGenericReturnVal(genFunc)); - builder.setInsertBefore(innerFunc); - specializedIntermediateType = specializeWithGeneric(builder, intermediateType, genFunc); - specializedPrimalFunc = specializeWithGeneric(builder, primalFunc, genFunc); + auto primalName = String(nameHint->getName()) + "_primal"; + nameHint->setOperand(0, builder.getStringValue(primalName.getUnownedSlice())); } - SLANG_RELEASE_ASSERT(innerFunc); - // Insert a call to primal func at start of the function. - auto paramBlock = innerFunc->getFirstBlock(); + // Copy PrimalValueStructKey decorations from primal func. + copyPrimalValueStructKeyDecorations(func, subEnv); + + auto paramBlock = func->getFirstBlock(); auto firstBlock = *(paramBlock->getSuccessors().begin()); builder.setInsertBefore(firstBlock->getFirstInst()); - auto intermediateVar = builder.emitVar((IRType*)specializedIntermediateType); - List<IRInst*> args; - for (auto param : paramBlock->getParams()) - { - args.add(param); - } - args.add(intermediateVar); - builder.emitCallInst(innerFunc->getResultType(), specializedPrimalFunc, args); + auto intermediateVar = func->getLastParam(); // Replace all insts that has intermediate results with a load of the intermediate. List<IRInst*> instsToRemove; - for (auto block : innerFunc->getBlocks()) + for (auto block : func->getBlocks()) { for (auto inst : block->getOrdinaryInsts()) { @@ -554,8 +442,8 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc( } // Run simplification to DCE unnecessary insts. - eliminateDeadCode(innerFunc); - eliminateDeadCode(specializedPrimalFunc); + eliminateDeadCode(func); + eliminateDeadCode(primalFunc); return primalFunc; } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 2c55b390b..f2ce3dc62 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -132,7 +132,7 @@ struct DiffUnzipPass return unzippedFunc; } - IRGlobalValueWithCode* extractPrimalFunc(IRGlobalValueWithCode* func, IRGlobalValueWithCode* fwdFunc, IRInst*& intermediateType); + IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* fwdFunc, IRInst*& intermediateType); bool isRelevantDifferentialPair(IRType* type) { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 40c24d11d..d23271704 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -401,6 +401,10 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_DifferentiableTypeDictionaryDecoration: case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDerivativeIntermediateTypeDecoration: + case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_BackwardDerivativePrimalDecoration: decor->removeAndDeallocate(); break; default: @@ -489,7 +493,7 @@ struct AutoDiffPass : public InstPassBase // TODO(sai): Move this call. forwardTranscriber.differentiableTypeConformanceContext.buildGlobalWitnessDictionary(); - IRBuilder builderStorage(this->autodiffContext->sharedBuilder); + IRBuilder builderStorage(&sharedBuilderStorage); IRBuilder* builder = &builderStorage; // Process all ForwardDifferentiate and BackwardDifferentiate instructions by @@ -500,6 +504,81 @@ struct AutoDiffPass : public InstPassBase return modified; } + IRInst* processIntermediateContextTypeBase(IRBuilder* builder, IRInst* base) + { + if (auto spec = as<IRSpecialize>(base)) + { + List<IRInst*> args; + auto subBase = processIntermediateContextTypeBase(builder, spec->getBase()); + for (UInt a = 0; a < spec->getArgCount(); a++) + args.add(spec->getArg(a)); + auto actualType = builder->emitSpecializeInst( + builder->getTypeKind(), + subBase, + args.getCount(), + args.getBuffer()); + return actualType; + } + else if (auto baseGeneric = as<IRGeneric>(base)) + { + auto inner = findGenericReturnVal(baseGeneric); + if (auto typeDecor = inner->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>()) + { + auto typeSpec = cast<IRSpecialize>(typeDecor->getBackwardDerivativeIntermediateType()); + auto typeSpecBase = typeSpec->getBase(); + return typeSpecBase; + } + } + else if (auto func = as<IRFunc>(base)) + { + if (auto typeDecor = func->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>()) + { + return typeDecor->getBackwardDerivativeIntermediateType(); + } + } + else if (auto lookup = as<IRLookupWitnessMethod>(base)) + { + auto key = lookup->getRequirementKey(); + if (auto typeDecor = key->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>()) + { + auto typeKey = typeDecor->getBackwardDerivativeIntermediateType(); + auto typeLookup = builder->emitLookupInterfaceMethodInst(builder->getTypeKind(), lookup->getWitnessTable(), typeKey); + return typeLookup; + } + } + return nullptr; + } + + bool lowerIntermediateContextType(IRBuilder* builder) + { + bool changed = false; + processAllInsts([&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_BackwardDiffIntermediateContextType: + { + auto differentiateInst = as<IRBackwardDiffIntermediateContextType>(inst); + + auto baseFunc = differentiateInst->getOperand(0); + IRBuilder subBuilder = *builder; + subBuilder.setInsertBefore(inst); + auto type = processIntermediateContextTypeBase(&subBuilder, baseFunc); + if (type) + { + inst->replaceUsesWith(type); + inst->removeAndDeallocate(); + changed = true; + } + } + break; + default: + break; + } + }); + return changed; + } + // Process all differentiate calls, and recursively generate code for forward and backward // derivative functions. // @@ -518,6 +597,9 @@ struct AutoDiffPass : public InstPassBase { case kIROp_ForwardDifferentiate: case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: + case kIROp_BackwardDiffIntermediateContextType: // Only process now if the operand is a materialized function. switch (inst->getOperand(0)->getOp()) { @@ -538,29 +620,49 @@ struct AutoDiffPass : public InstPassBase // Process collected differentiate insts and replace them with placeholders for // differentiated functions. - for (auto differentiateInst : autoDiffWorkList) + for (Index i = 0; i < autoDiffWorkList.getCount(); i++) { - if (auto diffInst = as<IRForwardDifferentiate>(differentiateInst)) + auto differentiateInst = autoDiffWorkList[i]; + + IRInst* diffFunc = nullptr; + IRBuilder subBuilder(*builder); + subBuilder.setInsertBefore(differentiateInst); + switch (differentiateInst->getOp()) { - IRBuilder subBuilder(*builder); - subBuilder.setInsertBefore(differentiateInst); - if (auto diffFunc = forwardTranscriber.transcribe(&subBuilder, diffInst->getBaseFn())) + case kIROp_ForwardDifferentiate: { - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - changed = true; + auto baseFunc = as<IRForwardDifferentiate>(differentiateInst)->getBaseFn(); + diffFunc = forwardTranscriber.transcribe(&subBuilder, baseFunc); } - } - else if (auto backDiffInst = as<IRBackwardDifferentiate>(differentiateInst)) - { - auto baseInst = backDiffInst->getBaseFn(); - if (auto diffFunc = backwardTranscriber.transcribe(builder, (IRFunc*)baseInst)) + break; + case kIROp_BackwardDifferentiatePrimal: + { + auto baseFunc = differentiateInst->getOperand(0); + diffFunc = backwardPrimalTranscriber.transcribe(&subBuilder, baseFunc); + } + break; + case kIROp_BackwardDifferentiatePropagate: { - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - changed = true; + auto baseFunc = differentiateInst->getOperand(0); + diffFunc = backwardPropagateTranscriber.transcribe(&subBuilder, baseFunc); } + break; + case kIROp_BackwardDifferentiate: + { + auto baseFunc = differentiateInst->getOperand(0); + diffFunc = backwardTranscriber.transcribe(&subBuilder, baseFunc); + } + break; + default: + break; + } + + if (diffFunc) + { + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + changed = true; } } @@ -591,8 +693,11 @@ struct AutoDiffPass : public InstPassBase case FuncBodyTranscriptionTaskType::Forward: forwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc); break; - case FuncBodyTranscriptionTaskType::Backward: - backwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc); + case FuncBodyTranscriptionTaskType::BackwardPrimal: + // Don't need to do anything, they will be filled by `backwardPropagateTranscriber`. + break; + case FuncBodyTranscriptionTaskType::BackwardPropagate: + backwardPropagateTranscriber.transcribeFunc(builder, primalFunc, diffFunc); break; default: break; @@ -616,6 +721,11 @@ struct AutoDiffPass : public InstPassBase hasChanges |= changed; } + if (lowerIntermediateContextType(builder)) + { + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + hasChanges = true; + } return hasChanges; } @@ -651,12 +761,28 @@ struct AutoDiffPass : public InstPassBase AutoDiffPass(AutoDiffSharedContext* context, DiagnosticSink* sink) : InstPassBase(context->moduleInst->getModule()), sink(sink), - forwardTranscriber(context, context->sharedBuilder, sink), - backwardTranscriber(context, context->sharedBuilder, sink), + forwardTranscriber(context, &sharedBuilderStorage, sink), + backwardPrimalTranscriber(context, &sharedBuilderStorage, sink), + backwardPropagateTranscriber(context, &sharedBuilderStorage, sink), + backwardTranscriber(context, &sharedBuilderStorage, sink), pairBuilderStorage(context), autodiffContext(context) { + + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + sharedBuilderStorage.init(module); + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + + context->sharedBuilder = &sharedBuilderStorage; + forwardTranscriber.pairBuilder = &pairBuilderStorage; + backwardPrimalTranscriber.pairBuilder = &pairBuilderStorage; + backwardPrimalTranscriber.fwdDiffTranscriber = &forwardTranscriber; + backwardPropagateTranscriber.pairBuilder = &pairBuilderStorage; + backwardPropagateTranscriber.fwdDiffTranscriber = &forwardTranscriber; backwardTranscriber.pairBuilder = &pairBuilderStorage; backwardTranscriber.fwdDiffTranscriber = &forwardTranscriber; } @@ -667,8 +793,13 @@ protected: // ForwardDiffTranscriber forwardTranscriber; + BackwardDiffPrimalTranscriber backwardPrimalTranscriber; + + BackwardDiffPropagateTranscriber backwardPropagateTranscriber; + BackwardDiffTranscriber backwardTranscriber; + // Diagnostic object from the compile request for // error messages. DiagnosticSink* sink; @@ -691,16 +822,6 @@ bool processAutodiffCalls( // Create shared context for all auto-diff related passes AutoDiffSharedContext autodiffContext(module->getModuleInst()); - // We start by initializing our shared IR building state, - // since we will re-use that state for any code we - // generate along the way. - // - SharedIRBuilder sharedBuilder; - sharedBuilder.init(module); - sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); - - autodiffContext.sharedBuilder = &sharedBuilder; - AutoDiffPass pass(&autodiffContext, sink); modified |= pass.processModule(); diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index e0508cef7..1415618e1 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -37,7 +37,7 @@ typedef DiffInstPair<IRInst*, IRInst*> InstPair; enum class FuncBodyTranscriptionTaskType { - Forward, Backward, Primal + Forward, BackwardPrimal, BackwardPropagate, Backward }; struct FuncBodyTranscriptionTask diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 8440f4181..b721f4225 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -60,6 +60,7 @@ INST(Nop, nop, 0, 0) INST(OptionalType, Optional, 1, 0) INST(DifferentialPairType, DiffPair, 1, 0) + INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, 0) /* BindExistentialsTypeBase */ @@ -731,6 +732,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0) /// Decorated function is marked for the reverse-mode differentiation pass. + INST(BackwardDerivativePrimalDecoration, backwardDiffPrimalReference, 1, 0) + INST(BackwardDerivativePropagateDecoration, backwardDiffPropagateReference, 1, 0) + INST(BackwardDerivativeIntermediateTypeDecoration, backwardDiffIntermediateTypeReference, 1, 0) INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0) /// Used by the auto-diff pass to mark insts that compute @@ -815,8 +819,18 @@ INST(CastToVoid, castToVoid, 1, 0) INST(IsType, IsType, 3, 0) INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0) -INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) -INST(DifferentialEqualityTypeCast, DifferentialEqualityTypeCast, 1, 0) + +// Produces the primal computation of backward derivatives, will return an intermediate context for +// backward derivative func. +INST(BackwardDifferentiatePrimal, BackwardDifferentiatePrimal, 1, 0) + +// Produces the actual backward derivative propagate function, using the intermediate context returned by the +// primal func produced from `BackwardDifferentiatePrimal`. +INST(BackwardDifferentiatePropagate, BackwardDifferentiatePropagate, 1, 0) + +// Represents the conceptual backward derivative function. Only produced by lower-to-ir and will be +// replaced with `BackwardDifferentiatePrimal` and `BackwardDifferentiatePropagate`. +INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) @@ -875,6 +889,11 @@ INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0) /* DifferentiableMethodRequirementDictionaryItem */ INST(ForwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0) INST(BackwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0) + INST(BackwardDifferentiablePrimalMethodRequirementDictionaryItem, DifferentiablePrimalMethodRequirementDictionaryItem, 0, 0) + INST(BackwardDifferentiablePropagateMethodRequirementDictionaryItem, DifferentiablePropagateMethodRequirementDictionaryItem, 0, 0) + INST(BackwardDifferentiableIntermediateTypeRequirementDictionaryItem, DifferentiableIntermediateTypeRequirementDictionaryItem, 0, 0) + + INST_RANGE(DifferentiableMethodRequirementDictionaryItem, ForwardDifferentiableMethodRequirementDictionaryItem, BackwardDifferentiableMethodRequirementDictionaryItem) #undef PARENT diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 03a3fb063..d2a4c7ae3 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -585,6 +585,38 @@ struct IRForwardDerivativeDecoration : IRDecoration IRInst* getForwardDerivativeFunc() { return getOperand(0); } }; +struct IRBackwardDerivativeIntermediateTypeDecoration : IRDecoration +{ + enum + { + kOp = kIROp_BackwardDerivativeIntermediateTypeDecoration + }; + IR_LEAF_ISA(BackwardDerivativeIntermediateTypeDecoration) + + IRInst* getBackwardDerivativeIntermediateType() { return getOperand(0); } +}; + +struct IRBackwardDerivativePrimalDecoration : IRDecoration +{ + enum + { + kOp = kIROp_BackwardDerivativePrimalDecoration + }; + IR_LEAF_ISA(BackwardDerivativePrimalDecoration) + + IRInst* getBackwardDerivativePrimalFunc() { return getOperand(0); } +}; + +struct IRBackwardDerivativePropagateDecoration : IRDecoration +{ + enum + { + kOp = kIROp_BackwardDerivativePropagateDecoration + }; + IR_LEAF_ISA(BackwardDerivativePropagateDecoration) + + IRInst* getBackwardDerivativePropagateFunc() { return getOperand(0); } +}; struct IRBackwardDerivativeDecoration : IRDecoration { @@ -681,7 +713,45 @@ struct IRForwardDifferentiate : IRInst }; // An instruction that replaces the function symbol -// with it's derivative function. +// with its backward derivative primal function. +// A backward derivative primal function is the first pass +// of backward derivative computation. It performs the primal +// computations and returns the intermediates that will be used +// by the actual backward derivative function. +struct IRBackwardDifferentiatePrimal : IRInst +{ + enum + { + kOp = kIROp_BackwardDifferentiatePrimal + }; + // The base function for the call. + IRUse base; + IRInst* getBaseFn() { return getOperand(0); } + + IR_LEAF_ISA(BackwardDifferentiatePrimal) +}; + +// An instruction that replaces the function symbol with its backward derivative propagate function. +// A backward derivative propagate function is the second pass of backward derivative computation. It uses the +// intermediates computed in the bacward derivative primal function to perform the actual backward +// derivative propagation. +struct IRBackwardDifferentiatePropagate : IRInst +{ + enum + { + kOp = kIROp_BackwardDifferentiatePropagate + }; + // The base function for the call. + IRUse base; + IRInst* getBaseFn() { return getOperand(0); } + + IR_LEAF_ISA(BackwardDifferentiatePropagate) +}; + +// An instruction that replaces the function symbol with its backward derivative function. +// A backward derivative function is a concept that combines both passes of backward derivative +// computation. This inst should only be produced by lower-to-ir, and will be replaced with calls to +// the primal function followed by the propagate function in the auto-diff pass. struct IRBackwardDifferentiate : IRInst { enum @@ -2556,6 +2626,8 @@ public: IRType* valueType, IRInst* witnessTable); + IRBackwardDiffIntermediateContextType* getBackwardDiffIntermediateContextType(IRInst* func); + IRFuncType* getFuncType( UInt paramCount, IRType* const* paramTypes, @@ -2664,6 +2736,8 @@ public: IRInst* emitForwardDifferentiateInst(IRType* type, IRInst* baseFn); IRInst* emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn); + IRInst* emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn); + IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); @@ -3399,11 +3473,26 @@ public: addDecoration(value, kIROp_ForwardDerivativeDecoration, fwdFunc); } + void addBackwardDerivativePrimalDecoration(IRInst* value, IRInst* jvpFn) + { + addDecoration(value, kIROp_BackwardDerivativePrimalDecoration, jvpFn); + } + + void addBackwardDerivativePropagateDecoration(IRInst* value, IRInst* jvpFn) + { + addDecoration(value, kIROp_BackwardDerivativePropagateDecoration, jvpFn); + } + void addBackwardDerivativeDecoration(IRInst* value, IRInst* jvpFn) { addDecoration(value, kIROp_BackwardDerivativeDecoration, jvpFn); } + void addBackwardDerivativeIntermediateTypeDecoration(IRInst* value, IRInst* jvpFn) + { + addDecoration(value, kIROp_BackwardDerivativeIntermediateTypeDecoration, jvpFn); + } + void markInstAsDifferential(IRInst* value) { addDecoration(value, kIROp_DifferentialInstDecoration, nullptr); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 46c7b3363..de970fbca 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -403,7 +403,7 @@ struct SpecializationContext // If the base is specialized, the JVP version must be also be a specialized // generic. // - SLANG_ASSERT(specDiffFunc); + SLANG_RELEASE_ASSERT(specDiffFunc); // Build specialization arguments from specInst. // Note that if we've reached this point, we can safely assume diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 81b5d636a..8e3e879ad 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1,5 +1,6 @@ #include "slang-ir-util.h" #include "slang-ir-insts.h" +#include "slang-ir-clone.h" namespace Slang { @@ -143,4 +144,77 @@ IRInst* specializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecialize, I genArgs.getBuffer()); } +IRInst* maybeSpecializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecailize, IRInst* userGeneric) +{ + if (auto gen = as<IRGeneric>(userGeneric)) + { + if (auto toSpecialize = as<IRGeneric>(genericToSpecailize)) + { + return specializeWithGeneric(builder, toSpecialize, gen); + } + } + return genericToSpecailize; +} + +IRInst* hoistValueFromGeneric(IRBuilder& builder, IRInst* value, IRInst*& outSpecializedVal, bool replaceExistingValue) +{ + auto outerGeneric = as<IRGeneric>(findOuterGeneric(value)); + if (!outerGeneric) return value; + + builder.setInsertBefore(outerGeneric); + auto newGeneric = builder.emitGeneric(); + builder.setInsertInto(newGeneric); + builder.emitBlock(); + IRInst* newResultVal = nullptr; + + // Clone insts in outerGeneric up until `value`. + IRCloneEnv cloneEnv; + for (auto inst : outerGeneric->getFirstBlock()->getChildren()) + { + auto newInst = cloneInst(&cloneEnv, &builder, inst); + if (inst == value) + { + builder.emitReturn(newInst); + newResultVal = newInst; + break; + } + } + SLANG_RELEASE_ASSERT(newResultVal); + if (newResultVal->getOp() == kIROp_Func) + { + IRBuilder subBuilder = builder; + IRInst* subOutSpecialized = nullptr; + auto genericFuncType = hoistValueFromGeneric(subBuilder, newResultVal->getFullType(), subOutSpecialized, false); + newGeneric->setFullType((IRType*)genericFuncType); + } + else + { + newGeneric->setFullType(builder.getTypeKind()); + } + if (replaceExistingValue) + { + builder.setInsertBefore(value); + outSpecializedVal = specializeWithGeneric(builder, newGeneric, outerGeneric); + value->replaceUsesWith(outSpecializedVal); + value->removeAndDeallocate(); + } + return newGeneric; +} + +void moveInstChildren(IRInst* dest, IRInst* src) +{ + for (auto child = dest->getFirstDecorationOrChild(); child; ) + { + auto next = child->getNextInst(); + child->removeAndDeallocate(); + child = next; + } + for (auto child = src->getFirstDecorationOrChild(); child; ) + { + auto next = child->getNextInst(); + child->insertAtEnd(dest); + child = next; + } +} + } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 2087ee4a7..49f46d0e3 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -61,6 +61,40 @@ inline bool isChildInstOf(IRInst* inst, IRInst* parent) IRInst* specializeWithGeneric( IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric); +IRInst* maybeSpecializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecailize, IRInst* userGeneric); + + // For a value inside a generic, create a standalone generic wrapping just the value, and replace the use of + // the original value with a specialization of the new generic using the current generic arguments if + // `replaceExistingValue` is true. + // For example, if we have + // ``` + // generic G { param T; v = x(T); f = y(v); return f; } + // ``` + // hoistValueFromGeneric(G, v) turns the code into: + // ``` + // generic G1 { param T1; v1 = x(T); return v1; } + // generic G { param T; v = specialize(G1, T); f = y(v); return f; } + // ``` + // This function returns newly created generic inst. + // if `value` is not inside any generic, this function makes no change to IR, and returns `value`. +IRInst* hoistValueFromGeneric( + IRBuilder& builder, + IRInst* value, + IRInst*& outSpecializedVal, + bool replaceExistingValue = false); + +// Clear dest and move all chidlren from src to dest. +void moveInstChildren(IRInst* dest, IRInst* src); + +inline bool isGenericParam(IRInst* param) +{ + auto parent = param->getParent(); + if (auto block = as<IRBlock>(parent)) + parent = block->getParent(); + if (as<IRGeneric>(parent)) + return true; + return false; +} inline IRInst* unwrapAttributedType(IRInst* type) { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index d8a8fb7c4..9e0e328bd 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -300,6 +300,11 @@ namespace Slang return as<IRParam>(getNextInst()); } + IRParam* IRParam::getPrevParam() + { + return as<IRParam>(getPrevInst()); + } + // IRArrayTypeBase IRInst* IRArrayTypeBase::getElementCount() @@ -2802,6 +2807,15 @@ namespace Slang operands); } + IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType( + IRInst* func) + { + return (IRBackwardDiffIntermediateContextType*)getType( + kIROp_BackwardDiffIntermediateContextType, + 1, + &func); + } + IRFuncType* IRBuilder::getFuncType( UInt paramCount, IRType* const* paramTypes, @@ -3129,6 +3143,28 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitBackwardDifferentiatePrimalInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst<IRBackwardDifferentiatePrimal>( + this, + kIROp_BackwardDifferentiatePrimal, + type, + baseFn); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst<IRBackwardDifferentiatePropagate>( + this, + kIROp_BackwardDifferentiatePropagate, + type, + baseFn); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)); @@ -6622,6 +6658,7 @@ namespace Slang case kIROp_UnpackAnyValue: case kIROp_Reinterpret: case kIROp_GetNativePtr: + case kIROp_BackwardDiffIntermediateContextType: return false; case kIROp_ForwardDifferentiate: @@ -6904,6 +6941,16 @@ namespace Slang } return nullptr; } + + IRInst* getGenericReturnVal(IRInst* inst) + { + if (auto gen = as<IRGeneric>(inst)) + { + return findGenericReturnVal(gen); + } + return inst; + } + } // namespace Slang #if SLANG_VC @@ -6917,4 +6964,3 @@ SLANG_API const int SlangDebug__IROpStringLit = Slang::kIROp_StringLit; SLANG_API const int SlangDebug__IROpIntLit = Slang::kIROp_IntLit; #endif #endif - diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 56a33c02b..b4a657545 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1347,6 +1347,12 @@ struct IRDifferentialPairType : IRType IR_LEAF_ISA(DifferentialPairType) }; +struct IRBackwardDiffIntermediateContextType : IRType +{ + IRInst* getFunc() { return getOperand(0); } + IR_LEAF_ISA(BackwardDiffIntermediateContextType) +}; + struct IRVectorType : IRType { IRType* getElementType() { return (IRType*)getOperand(0); } @@ -1743,6 +1749,9 @@ IRInst* findGenericReturnVal(IRGeneric* generic); // Recursively find the inner most generic return value. IRInst* findInnerMostGenericReturnVal(IRGeneric* generic); +// Returns the generic return val if `inst` is a generic, otherwise returns `inst`. +IRInst* getGenericReturnVal(IRInst* inst); + // Find the generic container, if any, that this inst is contained in // Returns nullptr if there is no outer container. IRInst* findOuterGeneric(IRInst* inst); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a84cf9b8d..6803e1cb4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1407,6 +1407,33 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(diff); } + LoweredValInfo visitBackwardDifferentiatePropagateVal(BackwardDifferentiatePropagateVal* val) + { + auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + + auto diff = getBuilder()->emitBackwardDifferentiatePropagateInst(getBuilder()->getTypeKind(), funcVal.val); + return LoweredValInfo::simple(diff); + } + + LoweredValInfo visitBackwardDifferentiatePrimalVal(BackwardDifferentiatePrimalVal* val) + { + auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + + auto diff = getBuilder()->emitBackwardDifferentiatePrimalInst(getBuilder()->getTypeKind(), funcVal.val); + return LoweredValInfo::simple(diff); + } + + LoweredValInfo visitBackwardDifferentiateIntermediateTypeVal(BackwardDifferentiateIntermediateTypeVal* val) + { + auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + + auto diff = getBuilder()->getBackwardDiffIntermediateContextType(funcVal.val); + return LoweredValInfo::simple(diff); + } + LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*) { return LoweredValInfo(); @@ -6816,9 +6843,23 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> context->irBuilder->addDecoration( interfaceType, kIROp_DifferentiableMethodRequirementDictionaryDecoration); } - auto op = as<ForwardDerivativeRequirementDecl>(requirementDecl) - ? kIROp_ForwardDifferentiableMethodRequirementDictionaryItem - : kIROp_BackwardDifferentiableMethodRequirementDictionaryItem; + IROp op = kIROp_ForwardDifferentiableMethodRequirementDictionaryItem; + if (as<BackwardDerivativeRequirementDecl>(requirementDecl)) + { + op = kIROp_BackwardDifferentiableMethodRequirementDictionaryItem; + } + else if (as<BackwardDerivativePropagateRequirementDecl>(requirementDecl)) + { + op = kIROp_BackwardDifferentiablePropagateMethodRequirementDictionaryItem; + } + else if (as<BackwardDerivativePrimalRequirementDecl>(requirementDecl)) + { + op = kIROp_BackwardDifferentiablePrimalMethodRequirementDictionaryItem; + } + else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(requirementDecl)) + { + op = kIROp_BackwardDifferentiableIntermediateTypeRequirementDictionaryItem; + } IRInst* args[] = {originalKey, associatedKey}; auto assoc = context->irBuilder->emitIntrinsicInst(nullptr, op, 2, args); assoc->insertAtEnd(decor); @@ -8405,6 +8446,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitBackwardDerivativeIntermediateTypeRequirementDecl(BackwardDerivativeIntermediateTypeRequirementDecl* decl) + { + SLANG_UNUSED(decl); + return LoweredValInfo(getBuilder()->getTypeKind()); + } + LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) { // A function declaration may have multiple, target-specific diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 6ea9ea01e..58d6aaae3 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -521,6 +521,12 @@ namespace Slang emitRaw(context, "FwdReq_"); else if (as<BackwardDerivativeRequirementDecl>(decl)) emitRaw(context, "BwdReq_"); + else if (as<BackwardDerivativePropagateRequirementDecl>(decl)) + emitRaw(context, "BwdReq_Prop_"); + else if (as<BackwardDerivativePrimalRequirementDecl>(decl)) + emitRaw(context, "BwdReq_Primal_"); + else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(decl)) + emitRaw(context, "BwdReq_CtxType_"); else { // TODO: handle other cases |
