diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-10 12:42:55 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-10 12:42:55 -0800 |
| commit | 2f422087ed04940f6b6b351605e61d48ce1989ce (patch) | |
| tree | 522f8027173732d903a906081238b12863d73fb8 /source | |
| parent | eb813fbd8750ed1ab66d73f5fa29ae8f2407e8af (diff) | |
Nested bwd-diff func call context save/restore. (#2584)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 48 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 177 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.h | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 28 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 171 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 31 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 34 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 55 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 11 |
19 files changed, 407 insertions, 239 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7c8e320c4..b8732a67f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5672,7 +5672,8 @@ namespace Slang { // Requirement for backward derivative. auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); - auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef))); + auto originalFuncType = getFuncType(m_astBuilder, declRef); + auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType)); { auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>(); cloneModifiers(reqDecl, decl); @@ -5704,8 +5705,8 @@ namespace Slang auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>(); cloneModifiers(reqDecl, decl); FuncType* primalFuncType = m_astBuilder->create<FuncType>(); - primalFuncType->resultType = diffFuncType->resultType; - primalFuncType->paramTypes.addRange(diffFuncType->paramTypes); + primalFuncType->resultType = originalFuncType->resultType; + primalFuncType->paramTypes.addRange(originalFuncType->paramTypes); auto outType = m_astBuilder->getOutType(intermediateType); primalFuncType->paramTypes.add(outType); setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 00db6bd96..d50cc45a3 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -387,6 +387,8 @@ Result linkAndOptimizeIR( finalizeAutoDiffPass(irModule); + finalizeSpecialization(irModule); + // If we have a target that is GPU like we use the string hashing mechanism // but for that to work we need to inline such that calls (or returns) of strings // boil down into getStringHash(stringLiteral) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 19678f402..e37415446 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -864,7 +864,7 @@ InstPair ForwardDiffTranscriber::transcribeSingleOperandInst(IRBuilder* builder, IRInst* diffResult = nullptr; - if (auto diffType = differentiateType(builder, primalType)) + if (auto diffType = differentiateType(builder, origInst->getDataType())) { if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) { @@ -930,7 +930,33 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu { if (auto bwdDecor = origFunc->findDecoration<IRForwardDerivativeDecoration>()) return InstPair(origFunc, bwdDecor->getForwardDerivativeFunc()); + + auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc); + + if (auto outerGen = findOuterGeneric(diffFunc)) + { + IRBuilder subBuilder = *inBuilder; + subBuilder.setInsertBefore(origFunc); + auto specialized = + specializeWithGeneric(subBuilder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc))); + subBuilder.addForwardDerivativeDecoration(origFunc, specialized); + } + else + { + inBuilder->addForwardDerivativeDecoration(origFunc, diffFunc); + } + + FuncBodyTranscriptionTask task; + task.type = FuncBodyTranscriptionTaskType::Forward; + task.originalFunc = origFunc; + task.resultFunc = diffFunc; + autoDiffSharedContext->followUpFunctionsToTranscribe.add(task); + + return InstPair(origFunc, diffFunc); +} +IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc) +{ IRBuilder builder = *inBuilder; IRFunc* primalFunc = origFunc; @@ -955,17 +981,6 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu newNameSb << "s_fwd_" << originalName; builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); } - - if (auto outerGen = findOuterGeneric(diffFunc)) - { - auto specialized = - specializeWithGeneric(builder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc))); - builder.addForwardDerivativeDecoration(origFunc, specialized); - } - else - { - builder.addForwardDerivativeDecoration(origFunc, diffFunc); - } // Mark the generated derivative function itself as differentiable. builder.addForwardDifferentiableDecoration(diffFunc); @@ -975,14 +990,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu { cloneDecoration(dictDecor, diffFunc); } - - FuncBodyTranscriptionTask task; - task.type = FuncBodyTranscriptionTaskType::Forward; - task.originalFunc = primalFunc; - task.resultFunc = diffFunc; - autoDiffSharedContext->followUpFunctionsToTranscribe.add(task); - - return InstPair(primalFunc, diffFunc); + return diffFunc; } // Transcribe a function definition. diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 869b25ffd..828916c01 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -81,6 +81,9 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase // Transcribe a generic definition InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric); + // Transcribe a function without marking the result as a decoration. + IRFunc* transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); + // Create an empty func to represent the transcribed func of `origFunc`. virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index b6704011c..817534065 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -60,7 +60,15 @@ namespace Slang { auto intermediateType = builder->getBackwardDiffIntermediateContextType(func); auto outType = builder->getOutType(intermediateType); - return differentiateFunctionTypeImpl(builder, funcType, outType); + List<IRType*> paramTypes; + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + paramTypes.add(funcType->getParamType(i)); + } + paramTypes.add(outType); + IRFuncType* primalFuncType = builder->getFuncType( + paramTypes, funcType->getResultType()); + return primalFuncType; } InstPair BackwardDiffPrimalTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) @@ -210,8 +218,6 @@ namespace Slang differentiableTypeConformanceContext.setFunc(origFunc); - primalFunc = origFunc; - auto diffFunc = builder.createFunc(); SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); @@ -278,27 +284,65 @@ namespace Slang builder.setInsertInto(header.differential); builder.emitBlock(); auto funcType = as<IRFuncType>(header.differential->getDataType()); - List<IRInst*> args; + List<IRInst*> primalArgs, propagateArgs; + List<IRType*> primalTypes, propagateTypes; for (UInt i = 0; i < funcType->getParamCount(); i++) { auto paramType = funcType->getParamType(i); - args.add(builder.emitParam(paramType)); + auto param = builder.emitParam(paramType); + if (i != funcType->getParamCount() - 1) + { + primalArgs.add(param); + } + propagateArgs.add(param); + propagateTypes.add(paramType); + } + + // Fetch primal values to use as arguments in primal func call. + for (auto& arg : primalArgs) + { + IRInst* valueType = arg->getDataType(); + auto inoutType = as<IRPtrTypeBase>(arg->getDataType()); + if (inoutType) + { + valueType = inoutType->getValueType(); + arg = builder.emitLoad(arg); + } + auto diffPairType = as<IRDifferentialPairType>(valueType); + if (!diffPairType) continue; + arg = builder.emitDifferentialPairGetPrimal(arg); } + + for (auto& arg : primalArgs) + { + primalTypes.add(arg->getFullType()); + } + auto outerGeneric = findOuterGeneric(origFunc); IRInst* specializedOriginalFunc = origFunc; if (outerGeneric) { specializedOriginalFunc = maybeSpecializeWithGeneric(builder, outerGeneric, findOuterGeneric(header.differential)); } + auto intermediateType = builder.getBackwardDiffIntermediateContextType(specializedOriginalFunc); auto intermediateVar = builder.emitVar(intermediateType); - auto primalFunc = builder.emitBackwardDifferentiatePrimalInst(builder.getTypeKind(), specializedOriginalFunc); - auto propagateFunc = builder.emitBackwardDifferentiatePropagateInst(builder.getTypeKind(), specializedOriginalFunc); - args.add(intermediateVar); - builder.emitCallInst(builder.getVoidType(), primalFunc, args); - args.removeLast(); - args.add(builder.emitLoad(intermediateVar)); - builder.emitCallInst(builder.getVoidType(), propagateFunc, args); + + auto origFuncType = as<IRFuncType>(origFunc->getDataType()); + auto primalFuncType = builder.getFuncType( + primalTypes, + origFuncType->getResultType()); + primalArgs.add(intermediateVar); + primalTypes.add(builder.getOutType(intermediateType)); + auto primalFunc = builder.emitBackwardDifferentiatePrimalInst(primalFuncType, specializedOriginalFunc); + builder.emitCallInst(origFuncType->getResultType(), primalFunc, primalArgs); + + propagateTypes.add(intermediateType); + propagateArgs.add(builder.emitLoad(intermediateVar)); + auto propagateFuncType = builder.getFuncType(propagateTypes, builder.getVoidType()); + auto propagateFunc = builder.emitBackwardDifferentiatePropagateInst(propagateFuncType, specializedOriginalFunc); + builder.emitCallInst(builder.getVoidType(), propagateFunc, propagateArgs); + builder.emitReturn(); return header; } @@ -339,98 +383,6 @@ namespace Slang builder.emitBranch(firstBlock); } - void BackwardDiffTranscriberBase::cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType) - { - IRStructType* structType = as<IRStructType>(intermediateType); - if (!structType) - { - auto genType = as<IRGeneric>(intermediateType); - structType = as<IRStructType>(findGenericReturnVal(genType)); - SLANG_RELEASE_ASSERT(structType); - } - - // Collect fields that are never fetched by reverse func. - OrderedHashSet<IRStructKey*> fieldsToCleanup; - for (auto children : structType->getChildren()) - { - if (auto field = as<IRStructField>(children)) - { - auto structKey = field->getKey(); - bool usedByRevFunc = false; - for (auto use = structKey->firstUse; use; use = use->nextUse) - { - if (isChildInstOf(use->getUser(), func)) - { - usedByRevFunc = true; - break; - } - } - if (!usedByRevFunc) - { - List<IRInst*> users; - for (auto use = structKey->firstUse; use; use = use->nextUse) - { - users.add(use->getUser()); - } - for (auto user : users) - { - if (!isChildInstOf(user, primalFunc)) - continue; - if (auto addr = as<IRFieldAddress>(user)) - { - if (addr->hasMoreThanOneUse()) - continue; - if (addr->firstUse) - { - if (addr->firstUse->getUser()->getOp() == kIROp_Store) - { - addr->firstUse->getUser()->removeAndDeallocate(); - } - addr->removeAndDeallocate(); - } - } - } - - bool hasNonTrivialUse = false; - for (auto use = structKey->firstUse; use; use = use->nextUse) - { - switch (use->getUser()->getOp()) - { - case kIROp_PrimalValueStructKeyDecoration: - case kIROp_StructField: - continue; - default: - hasNonTrivialUse = true; - break; - } - } - if (!hasNonTrivialUse) - { - fieldsToCleanup.Add(structKey); - } - } - } - } - - // Actually remove fields from struct. - for (auto children : structType->getChildren()) - { - if (auto field = as<IRStructField>(children)) - { - if (fieldsToCleanup.Contains(field->getKey())) - { - auto key = field->getKey(); - List<IRInst*> keyUsers; - for (auto use = key->firstUse; use; use = use->nextUse) - keyUsers.add(use->getUser()); - for (auto keyUser : keyUsers) - keyUser->removeAndDeallocate(); - key->removeAndDeallocate(); - } - } - } - } - // Transcribe a function definition. void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc) { @@ -442,11 +394,11 @@ namespace Slang // Generate a temporary forward derivative function as an intermediate step. IRBuilder tempBuilder = *builder; tempBuilder.setInsertBefore(diffPropagateFunc); - IRFunc* fwdDiffFunc = as<IRFunc>( - fwdDiffTranscriber->transcribeFuncHeader(&tempBuilder, primalFunc).differential); + ForwardDiffTranscriber* fwdTranscriber = static_cast<ForwardDiffTranscriber*>(autoDiffSharedContext->transcriberSet.forwardTranscriber); + IRFunc* fwdDiffFunc = as<IRFunc>(fwdTranscriber->transcribeFuncHeaderImpl(&tempBuilder, primalFunc)); SLANG_ASSERT(fwdDiffFunc); - fwdDiffTranscriber->transcribeFunc(&tempBuilder, primalFunc, fwdDiffFunc); + fwdTranscriber->transcribeFunc(&tempBuilder, primalFunc, fwdDiffFunc); // Split first block into a paramter block. this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc)); @@ -473,7 +425,8 @@ namespace Slang // for that. // builder->setInsertInto(diffPropagateFunc->getParent()); - auto tempDiffFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, unzippedFwdDiffFunc)); + IRCloneEnv subCloneEnv; + auto tempDiffFunc = as<IRFunc>(cloneInst(&subCloneEnv, builder, unzippedFwdDiffFunc)); // Move blocks to the diffFunc shell. { @@ -496,18 +449,18 @@ namespace Slang DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr}; diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); + eliminateDeadCode(diffPropagateFunc); + // Extracts the primal computations into its own func, and replace the primal insts // with the intermediate results computed from the extracted func. IRInst* intermediateType = nullptr; - auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffPropagateFunc, unzippedFwdDiffFunc, intermediateType); + auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc( + diffPropagateFunc, primalFunc, intermediateType); // Clean up by deallocating intermediate versions. tempDiffFunc->removeAndDeallocate(); unzippedFwdDiffFunc->removeAndDeallocate(); fwdDiffFunc->removeAndDeallocate(); - - eliminateDeadCode(diffPropagateFunc); - cleanUpUnusedPrimalIntermediate(diffPropagateFunc, extractedPrimalFunc, intermediateType); // If primal function is nested in a generic, we want to create separate generics for all the associated things // we have just created. diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 378300789..decbdf150 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -30,7 +30,6 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase Dictionary<IRInst*, IRInst*> orginalToTranscribed; // References to other passes that for reverse-mode transcription. - ForwardDiffTranscriber* fwdDiffTranscriber; DiffTransposePass* diffTransposePass; DiffPropagationPass* diffPropagationPass; DiffUnzipPass* diffUnzipPass; @@ -40,7 +39,11 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase DiffPropagationPass diffPropagationPassStorage; DiffUnzipPass diffUnzipPassStorage; - BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType taskType, AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + BackwardDiffTranscriberBase( + FuncBodyTranscriptionTaskType taskType, + AutoDiffSharedContext* shared, + SharedIRBuilder* inSharedBuilder, + DiagnosticSink* inSink) : AutoDiffTranscriberBase(shared, inSharedBuilder, inSink) , diffTaskType(taskType) , diffTransposePassStorage(shared) @@ -49,7 +52,7 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase , diffTransposePass(&diffTransposePassStorage) , diffPropagationPass(&diffPropagationPassStorage) , diffUnzipPass(&diffUnzipPassStorage) - { } + {} // Returns "dp<var-name>" to use as a name hint for parameters. // If no primal name is available, returns a blank string. @@ -63,8 +66,6 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase // Puts parameters into their own block. void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func); - void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType); - // Transcribe a function definition. virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0; @@ -103,8 +104,12 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase { - BackwardDiffPrimalTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) - : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::BackwardPrimal, shared, inSharedBuilder, inSink) + BackwardDiffPrimalTranscriber( + AutoDiffSharedContext* shared, + SharedIRBuilder* inSharedBuilder, + DiagnosticSink* inSink) + : BackwardDiffTranscriberBase( + FuncBodyTranscriptionTaskType::BackwardPrimal, shared, inSharedBuilder, inSink) { } virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; @@ -125,8 +130,15 @@ struct BackwardDiffPrimalTranscriber : BackwardDiffTranscriberBase struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase { - BackwardDiffPropagateTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) - : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::BackwardPropagate, shared, inSharedBuilder, inSink) + BackwardDiffPropagateTranscriber( + AutoDiffSharedContext* shared, + SharedIRBuilder* inSharedBuilder, + DiagnosticSink* inSink) + : BackwardDiffTranscriberBase( + FuncBodyTranscriptionTaskType::BackwardPropagate, + shared, + inSharedBuilder, + inSink) { } virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; @@ -153,7 +165,8 @@ struct BackwardDiffTranscriber : BackwardDiffTranscriberBase AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) - : BackwardDiffTranscriberBase(FuncBodyTranscriptionTaskType::Backward, shared, inSharedBuilder, inSink) + : BackwardDiffTranscriberBase( + FuncBodyTranscriptionTaskType::Backward, shared, inSharedBuilder, inSink) { } virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 4aab0f835..c0404e036 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -390,7 +390,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* if (lookupKeyPath.getCount()) { // `interfaceType` does conform to `IDifferentiable`. - outWitnessTable = builder->emitExtractExistentialWitnessTable(origType->getOperand(0)); + outWitnessTable = builder->emitExtractExistentialWitnessTable(lookupPrimalInstIfExists(origType->getOperand(0))); for (auto node : lookupKeyPath) { outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey()); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index 4c3bbe05f..a6b832856 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -67,6 +67,8 @@ struct AutoDiffTranscriberBase IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst); + IRInst* lookupPrimalInstIfExists(IRInst* origInst) { return lookupPrimalInst(origInst, origInst); } + bool hasPrimalInst(IRInst* origInst); IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 436a17a7f..fa9f4ffb2 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -155,12 +155,15 @@ struct DiffTransposePass // firstRevDiffBlockMap[revDiffFunc] = revBlockMap[workList[0]]; + IRInst* retVal = nullptr; + for (auto block : workList) { // Set dOutParameter as the transpose gradient for the return inst, if any. if (auto returnInst = as<IRReturn>(block->getTerminator())) { this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr)); + retVal = returnInst->getVal(); } IRBlock* revBlock = revBlockMap[block]; @@ -187,7 +190,18 @@ struct DiffTransposePass // There should be no parameters in the first reverse-mode block. SLANG_ASSERT(terminalRevBlock->getFirstParam() == nullptr); - subBuilder.emitBranch(terminalRevBlock); + auto branch = subBuilder.emitBranch(terminalRevBlock); + + if (!retVal) + { + retVal = subBuilder.getVoidValue(); + } + else + { + auto makePair = cast<IRMakeDifferentialPair>(retVal); + retVal = makePair->getPrimalValue(); + } + subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal); } // Remove fwd-mode blocks. @@ -498,6 +512,10 @@ struct DiffTransposePass } } + // The call must have been decorated with the continuation context after splitting. + auto primalContextDecor = fwdCall->findDecoration<IRBackwardDerivativePrimalContextDecoration>(); + SLANG_RELEASE_ASSERT(primalContextDecor); + auto baseFn = fwdDiffCallee->getBaseFn(); List<IRInst*> args; @@ -543,8 +561,14 @@ struct DiffTransposePass args.add(revValue); argTypes.add(revValue->getDataType()); + args.add(primalContextDecor->getBackwardDerivativePrimalContextVar()); + argTypes.add(builder->getOutType( + as<IRPtrTypeBase>( + primalContextDecor->getBackwardDerivativePrimalContextVar()->getDataType()) + ->getValueType())); + auto revFnType = builder->getFuncType(argTypes, builder->getVoidType()); - auto revCallee = builder->emitBackwardDifferentiateInst( + auto revCallee = builder->emitBackwardDifferentiatePropagateInst( revFnType, baseFn); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 1496ae60f..43b48aa13 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -7,10 +7,12 @@ namespace Slang struct ExtractPrimalFuncContext { SharedIRBuilder* sharedBuilder; + AutoDiffTranscriberBase* backwardPrimalTranscriber; - void init(SharedIRBuilder* inSharedBuilder) + void init(SharedIRBuilder* inSharedBuilder, AutoDiffTranscriberBase* transcriber) { sharedBuilder = inSharedBuilder; + backwardPrimalTranscriber = transcriber; } IRInst* cloneGenericHeader(IRBuilder& builder, IRCloneEnv& cloneEnv, IRGeneric* gen) @@ -65,14 +67,14 @@ struct ExtractPrimalFuncContext } IRInst* generatePrimalFuncType( - IRGlobalValueWithCode* destFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType) + IRGlobalValueWithCode* destFunc, IRGlobalValueWithCode* originalFunc, IRInst*& outIntermediateType) { IRBuilder builder(sharedBuilder); builder.setInsertBefore(destFunc); IRFuncType* originalFuncType = nullptr; outIntermediateType = createIntermediateType(destFunc); - originalFuncType = as<IRFuncType>(fwdFunc->getDataType()); + originalFuncType = as<IRFuncType>(originalFunc->getDataType()); SLANG_RELEASE_ASSERT(originalFuncType); List<IRType*> paramTypes; @@ -231,56 +233,46 @@ struct ExtractPrimalFuncContext return true; } - void storeInst( - IRBuilder& builder, - IRInst* inst, - IRInst* intermediateOutput) + IRStructField* addIntermediateContextField(IRInst* type, IRInst* intermediateOutput) { IRBuilder genTypeBuilder(sharedBuilder); - auto ptrStructType = as<IRPtrTypeBase>(intermediateOutput->getDataType() ); + auto ptrStructType = as<IRPtrTypeBase>(intermediateOutput->getDataType()); SLANG_RELEASE_ASSERT(ptrStructType); auto structType = as<IRStructType>(ptrStructType->getValueType()); genTypeBuilder.setInsertBefore(structType); - auto fieldType = inst->getDataType(); + auto fieldType = type; SLANG_RELEASE_ASSERT(structType); auto structKey = genTypeBuilder.createStructKey(); - if (auto nameHint = inst->findDecoration<IRNameHintDecoration>()) - cloneDecoration(nameHint, structKey); genTypeBuilder.setInsertInto(structType); - genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); - builder.addPrimalValueStructKeyDecoration(inst, structKey); + return genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); + } + + void storeInst( + IRBuilder& builder, + IRInst* inst, + IRInst* intermediateOutput) + { + auto field = addIntermediateContextField(inst->getDataType(), intermediateOutput); + auto key = field->getKey(); + if (auto nameHint = inst->findDecoration<IRNameHintDecoration>()) + cloneDecoration(nameHint, key); + builder.addPrimalValueStructKeyDecoration(inst, key); builder.emitStore( builder.emitFieldAddress( - builder.getPtrType(inst->getFullType()), intermediateOutput, structKey), + builder.getPtrType(inst->getFullType()), intermediateOutput, key), inst); } - IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* fwdFunc, IRInst*& outIntermediateType) + IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, IRInst*& outIntermediateType) { - // Note: this transformation assumes the original func has only one return. - IRBuilder builder(sharedBuilder); IRFunc* func = unzippedFunc; IRInst* intermediateType = nullptr; - auto newFuncType = generatePrimalFuncType(unzippedFunc, fwdFunc, intermediateType); + auto newFuncType = generatePrimalFuncType(unzippedFunc, originalFunc, intermediateType); outIntermediateType = intermediateType; func->setFullType((IRType*)newFuncType); - // Go through all the insts and preserve the primal blocks. - // Create a return block to replace all branches into a non-primal block. - builder.setInsertInto(func); - auto returnBlock = builder.emitBlock(); - for (auto block : func->getBlocks()) - { - auto term = block->getTerminator(); - if (auto ret = as<IRReturn>(term)) - { - insertIntoReturnBlock(builder, ret); - break; - } - } - auto paramBlock = func->getFirstBlock(); builder.setInsertInto(paramBlock); auto oldIntermediateParam = func->getLastParam(); @@ -317,53 +309,76 @@ struct ExtractPrimalFuncContext builder.setInsertAfter(inst); storeInst(builder, inst, outIntermediary); } + else if (inst->getOp() == kIROp_Var) + { + // Always store intermediate context var. + if (inst->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) + { + auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary); + builder.setInsertBefore(inst); + auto fieldAddr = builder.emitFieldAddress( + inst->getFullType(), outIntermediary, field->getKey()); + inst->replaceUsesWith(fieldAddr); + builder.addPrimalValueStructKeyDecoration(inst, field->getKey()); + } + } } } - // Go over differential blocks and complete - for (auto block : diffBlocksList) + for (auto block : primalBlocksList) { - - if (block->getFirstParam() == nullptr) - { - // If the block does not have any PHI nodes, just remove it and - // replace all its uses with returnBlock. - - // TODO: This invalides the next block in the chain. Make a list first. - block->replaceUsesWith(returnBlock); - block->removeAndDeallocate(); - } - else + auto term = block->getTerminator(); + builder.setInsertBefore(term); + if (auto decor = term->findDecoration<IRBackwardDerivativePrimalReturnDecoration>()) { - // If the block has Phi nodes, we can't directly replace it with - // `returnBlock`, but we can turn the block into a trivial branch - // into `returnBlock` to safely preserve the invariants of Phi nodes. - auto inst = block->getLastParam()->getNextInst(); - for (; inst;) - { - auto nextInst = inst->getNextInst(); - inst->removeAndDeallocate(); - inst = nextInst; - } - - builder.setInsertInto(block); - builder.emitBranch(returnBlock); + builder.emitReturn(decor->getBackwardDerivativePrimalReturnValue()); + term->removeAndDeallocate(); } } - + List<IRBlock*> unusedBlocks; for (auto block : func->getBlocks()) { - if (!block->hasUses() && isDiffInst(block)) + if (isDiffInst(block)) unusedBlocks.add(block); } - for (auto block : unusedBlocks) block->removeAndDeallocate(); builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType); builder.emitStore(outIntermediary, defVal); + + // The primal func will not have the result derivative param (second to last param), so we remove it. + auto resultDerivativeParam = func->getLastParam()->getPrevParam(); + SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses()); + resultDerivativeParam->removeAndDeallocate(); + + // Finally, go through parameters and turn DifferentiablePair<T> back to T. + for (auto param : func->getParams()) + { + IRInst* valueType = param->getDataType(); + auto inoutType = as<IRPtrTypeBase>(param->getDataType()); + if (inoutType) valueType = inoutType->getValueType(); + auto diffPairType = as<IRDifferentialPairType>(valueType); + if (!diffPairType) continue; + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + + auto originalValueType = diffPairType->getValueType(); + + // Create a local var to act as the old param. + auto tempVar = builder.emitVar(diffPairType); + param->replaceUsesWith(tempVar); + auto pairValue = builder.emitMakeDifferentialPair( + diffPairType, + param, + backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType)); + builder.emitStore(tempVar, pairValue); + + // Change the param type to original type. + param->setFullType(originalValueType); + } + return unzippedFunc; } }; @@ -386,7 +401,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE } IRFunc* DiffUnzipPass::extractPrimalFunc( - IRFunc* func, IRFunc* fwdFunc, IRInst*& intermediateType) + IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType) { IRBuilder builder(this->autodiffContext->sharedBuilder); builder.setInsertBefore(func); @@ -397,15 +412,19 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func)); ExtractPrimalFuncContext context; - context.init(autodiffContext->sharedBuilder); + context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber); intermediateType = nullptr; - auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, fwdFunc, intermediateType); + auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, intermediateType); if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>()) { - auto primalName = String(nameHint->getName()) + "_primal"; - nameHint->setOperand(0, builder.getStringValue(primalName.getUnownedSlice())); + nameHint->removeAndDeallocate(); + } + if (auto originalNameHint = originalFunc->findDecoration<IRNameHintDecoration>()) + { + auto primalName = String("s_bwd_primal_") + UnownedStringSlice(originalNameHint->getName()); + builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice())); } // Copy PrimalValueStructKey decorations from primal func. @@ -429,10 +448,26 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( builder.getPtrType(inst->getDataType()), intermediateVar, structKeyDecor->getStructKey()); - auto val = builder.emitLoad(addr); - inst->replaceUsesWith(val); + if (inst->getOp() == kIROp_Var) + { + // This is a var for intermediate context. + inst->replaceUsesWith(addr); + } + else + { + // Orindary value. + auto val = builder.emitLoad(addr); + inst->replaceUsesWith(val); + } instsToRemove.add(inst); } + else if (auto primalCtx = inst->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) + { + if (inst->getOp() == kIROp_Call) + { + builder.addSimpleDecoration<IRNoSideEffectDecoration>(inst); + } + } } } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index f2ce3dc62..ba1e425db 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -8,6 +8,7 @@ #include "slang-ir-autodiff.h" #include "slang-ir-autodiff-fwd.h" #include "slang-ir-autodiff-propagate.h" +#include "slang-ir-autodiff-transcriber-base.h" namespace Slang { @@ -31,10 +32,10 @@ struct DiffUnzipPass // might run into an issue here? IRBlock* firstDiffBlock; - // Dictionary<IRBlock*, List<IRBlock*>> - - DiffUnzipPass(AutoDiffSharedContext* autodiffContext) : - autodiffContext(autodiffContext), diffTypeContext(autodiffContext) + DiffUnzipPass( + AutoDiffSharedContext* autodiffContext) + : autodiffContext(autodiffContext) + , diffTypeContext(autodiffContext) { } IRInst* lookupPrimalInst(IRInst* inst) @@ -71,9 +72,6 @@ struct DiffUnzipPass SLANG_ASSERT(unzippedFunc->getFirstBlock() != nullptr); SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock() != nullptr); - // Ignore the first block (this is reserved for parameters), start - // at the second block. (For now, we work with only a single block of insts) - // TODO: expand to handle multi-block functions later. IRBlock* firstBlock = unzippedFunc->getFirstBlock()->getNextBlock(); List<IRBlock*> mixedBlocks; @@ -132,7 +130,7 @@ struct DiffUnzipPass return unzippedFunc; } - IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* fwdFunc, IRInst*& intermediateType); + IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType); bool isRelevantDifferentialPair(IRType* type) { @@ -160,6 +158,14 @@ struct DiffUnzipPass auto fwdCalleeType = as<IRFuncType>(fwdCallee->getDataType()); auto baseFn = fwdCallee->getBaseFn(); + auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->differentiateFunctionType( + primalBuilder, baseFn, as<IRFuncType>(baseFn->getDataType())); + + auto intermediateVar = primalBuilder->emitVar(primalBuilder->getBackwardDiffIntermediateContextType(baseFn)); + primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar); + + auto primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn); + List<IRInst*> primalArgs; for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) { @@ -174,6 +180,7 @@ struct DiffUnzipPass primalArgs.add(arg); } } + primalArgs.add(intermediateVar); auto mixedDecoration = mixedCall->findDecoration<IRMixedDifferentialInstDecoration>(); SLANG_ASSERT(mixedDecoration); @@ -184,8 +191,9 @@ struct DiffUnzipPass auto primalType = fwdPairResultType->getValueType(); auto diffType = (IRType*) diffTypeContext.getDifferentialForType(&globalBuilder, primalType); - auto primalVal = primalBuilder->emitCallInst(primalType, baseFn, primalArgs); - + auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs); + primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar); + List<IRInst*> diffArgs; for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) { @@ -215,6 +223,7 @@ struct DiffUnzipPass } auto newFwdCallee = diffBuilder->emitForwardDifferentiateInst(fwdCalleeType, baseFn); + diffBuilder->markInstAsDifferential(newFwdCallee); auto diffPairVal = diffBuilder->emitCallInst( @@ -222,6 +231,7 @@ struct DiffUnzipPass newFwdCallee, diffArgs); diffBuilder->markInstAsDifferential(diffPairVal, primalType); + diffBuilder->addBackwardDerivativePrimalContextDecoration(diffPairVal, intermediateVar); auto diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, diffPairVal); diffBuilder->markInstAsDifferential(diffVal, primalType); @@ -272,7 +282,6 @@ struct DiffUnzipPass // Check that we have an unambiguous 'first' differential block. SLANG_ASSERT(firstDiffBlock); auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); - auto pairVal = diffBuilder->emitMakeDifferentialPair( pairType, lookupPrimalInst(mixedReturn->getVal()), diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index d23271704..53e2ed0be 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -2,6 +2,7 @@ #include "slang-ir-autodiff-rev.h" #include "slang-ir-autodiff-fwd.h" #include "slang-ir-autodiff-pairs.h" +#include "slang-ir-validate.h" namespace Slang { @@ -405,6 +406,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_BackwardDerivativeIntermediateTypeDecoration: case kIROp_BackwardDerivativePropagateDecoration: case kIROp_BackwardDerivativePrimalDecoration: + case kIROp_BackwardDerivativePrimalContextDecoration: + case kIROp_BackwardDerivativePrimalReturnDecoration: decor->removeAndDeallocate(); break; default: @@ -716,6 +719,10 @@ struct AutoDiffPass : public InstPassBase autodiffCleanupList.clear(); +#if _DEBUG + validateIRModule(module, sink); +#endif + if (!changed) break; hasChanges |= changed; @@ -780,11 +787,14 @@ struct AutoDiffPass : public InstPassBase forwardTranscriber.pairBuilder = &pairBuilderStorage; backwardPrimalTranscriber.pairBuilder = &pairBuilderStorage; - backwardPrimalTranscriber.fwdDiffTranscriber = &forwardTranscriber; backwardPropagateTranscriber.pairBuilder = &pairBuilderStorage; - backwardPropagateTranscriber.fwdDiffTranscriber = &forwardTranscriber; backwardTranscriber.pairBuilder = &pairBuilderStorage; - backwardTranscriber.fwdDiffTranscriber = &forwardTranscriber; + + // Make the transcribers available to all sub passes via shared context. + context->transcriberSet.primalTranscriber = &backwardPrimalTranscriber; + context->transcriberSet.propagateTranscriber = &backwardPropagateTranscriber; + context->transcriberSet.forwardTranscriber = &forwardTranscriber; + context->transcriberSet.backwardTranscriber = &backwardTranscriber; } protected: diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 1415618e1..b4b97751f 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -47,6 +47,16 @@ struct FuncBodyTranscriptionTask IRFunc* resultFunc; }; +struct AutoDiffTranscriberBase; + +struct DiffTranscriberSet +{ + AutoDiffTranscriberBase* forwardTranscriber = nullptr; + AutoDiffTranscriberBase* primalTranscriber = nullptr; + AutoDiffTranscriberBase* propagateTranscriber = nullptr; + AutoDiffTranscriberBase* backwardTranscriber = nullptr; +}; + struct AutoDiffSharedContext { IRModuleInst* moduleInst = nullptr; @@ -93,6 +103,8 @@ struct AutoDiffSharedContext List<FuncBodyTranscriptionTask> followUpFunctionsToTranscribe; + DiffTranscriberSet transcriberSet; + AutoDiffSharedContext(IRModuleInst* inModuleInst); private: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index b721f4225..06f8b0e5d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -645,6 +645,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// A `[keepAlive]` decoration marks an instruction that should not be eliminated. INST(KeepAliveDecoration, keepAlive, 0, 0) + /// A `[NoSideEffect]` decoration marks a callee to be side-effect free. + INST(NoSideEffectDecoration, noSideEffect, 0, 0) + INST(BindExistentialSlotsDecoration, bindExistentialSlots, 0, 0) /// A `[format(f)]` decoration specifies that the format of an image should be `f` @@ -737,6 +740,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(BackwardDerivativeIntermediateTypeDecoration, backwardDiffIntermediateTypeReference, 1, 0) INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0) + INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0) + INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0) + /// Used by the auto-diff pass to mark insts that compute /// a differential value. INST(DifferentialInstDecoration, diffInstDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index d2a4c7ae3..1ff61a774 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -308,6 +308,7 @@ struct IRRequireGLSLExtensionDecoration : IRDecoration }; IR_SIMPLE_DECORATION(ReadNoneDecoration) +IR_SIMPLE_DECORATION(NoSideEffectDecoration) IR_SIMPLE_DECORATION(EarlyDepthStencilDecoration) IR_SIMPLE_DECORATION(GloballyCoherentDecoration) IR_SIMPLE_DECORATION(PreciseDecoration) @@ -607,6 +608,29 @@ struct IRBackwardDerivativePrimalDecoration : IRDecoration IRInst* getBackwardDerivativePrimalFunc() { return getOperand(0); } }; +// Used to associate the restore context var to use in a call to splitted backward propgate function. +struct IRBackwardDerivativePrimalContextDecoration : IRDecoration +{ + enum + { + kOp = kIROp_BackwardDerivativePrimalContextDecoration + }; + IR_LEAF_ISA(BackwardDerivativePrimalContextDecoration) + + IRInst* getBackwardDerivativePrimalContextVar() { return getOperand(0); } +}; + +struct IRBackwardDerivativePrimalReturnDecoration : IRDecoration +{ + enum + { + kOp = kIROp_BackwardDerivativePrimalReturnDecoration + }; + IR_LEAF_ISA(BackwardDerivativePrimalReturnDecoration) + + IRInst* getBackwardDerivativePrimalReturnValue() { return getOperand(0); } +}; + struct IRBackwardDerivativePropagateDecoration : IRDecoration { enum @@ -3478,6 +3502,11 @@ public: addDecoration(value, kIROp_BackwardDerivativePrimalDecoration, jvpFn); } + void addBackwardDerivativePrimalReturnDecoration(IRInst* value, IRInst* retVal) + { + addDecoration(value, kIROp_BackwardDerivativePrimalReturnDecoration, retVal); + } + void addBackwardDerivativePropagateDecoration(IRInst* value, IRInst* jvpFn) { addDecoration(value, kIROp_BackwardDerivativePropagateDecoration, jvpFn); @@ -3493,6 +3522,11 @@ public: addDecoration(value, kIROp_BackwardDerivativeIntermediateTypeDecoration, jvpFn); } + void addBackwardDerivativePrimalContextDecoration(IRInst* value, IRInst* ctx) + { + addDecoration(value, kIROp_BackwardDerivativePrimalContextDecoration, ctx); + } + void markInstAsDifferential(IRInst* value) { addDecoration(value, kIROp_DifferentialInstDecoration, nullptr); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 5b0afdd12..38503155d 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -1511,6 +1511,10 @@ static LegalVal legalizeMakeStruct( List<IRInst*> args; for(UInt aa = 0; aa < argCount; ++aa) { + // Ignore none values. + if (legalArgs[aa].flavor == LegalVal::Flavor::none) + continue; + // Note: we assume that all the arguments // must be simple here, because otherwise // the `struct` type with them as fields @@ -1521,7 +1525,7 @@ static LegalVal legalizeMakeStruct( return LegalVal::simple( builder->emitMakeStruct( legalType.getSimple(), - argCount, + args.getCount(), args.getBuffer())); } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index de970fbca..cbb5ccf09 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -720,14 +720,27 @@ struct SpecializationContext if (!item) continue; IRSimpleSpecializationKey key; bool shouldSkip = false; - for (UInt i = 1; i < item->getOperandCount(); i++) + for (UInt i = 0; i < item->getOperandCount(); i++) { if (item->getOperand(i) == nullptr) { shouldSkip = true; break; } - key.vals.add(item->getOperand(i)); + if (item->getOperand(i)->getParent() == nullptr) + { + shouldSkip = true; + break; + } + if (item->getOperand(i)->getOp() == kIROp_undefined) + { + shouldSkip = true; + break; + } + if (i > 0) + { + key.vals.add(item->getOperand(i)); + } } if (shouldSkip) continue; @@ -768,10 +781,19 @@ struct SpecializationContext builder.setInsertInto(dictInst); for (auto kv : dict) { - List<IRInst*> args; - args.add(kv.Value); - args.addRange(kv.Key.vals); - builder.emitIntrinsicInst(nullptr, kIROp_SpecializationDictionaryItem, args.getCount(), args.getBuffer()); + if (!kv.Value->parent) + continue; + for (auto keyVal : kv.Key.vals) + { + if (!keyVal->parent) goto next; + } + { + List<IRInst*> args; + args.add(kv.Value); + args.addRange(kv.Key.vals); + builder.emitIntrinsicInst(nullptr, kIROp_SpecializationDictionaryItem, args.getCount(), args.getBuffer()); + } + next:; } } void writeSpecializationDictionaries() @@ -2312,6 +2334,27 @@ bool specializeModule( return context.changed; } +void finalizeSpecialization(IRModule* module) +{ + for (auto inst : module->getModuleInst()->getChildren()) + { + for (auto decor = inst->getFirstDecoration(); decor; ) + { + auto next = decor->getNextDecoration(); + switch (decor->getOp()) + { + case kIROp_ExistentialFuncSpecializationDictionary: + case kIROp_ExistentialTypeSpecializationDictionary: + case kIROp_GenericSpecializationDictionary: + decor->removeAndDeallocate(); + break; + default: + break; + } + decor = next; + } + } +} IRInst* specializeGenericImpl( IRGeneric* genericVal, diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h index 1503c238e..20d65cb67 100644 --- a/source/slang/slang-ir-specialize.h +++ b/source/slang/slang-ir-specialize.h @@ -9,4 +9,6 @@ struct IRModule; bool specializeModule( IRModule* module); +void finalizeSpecialization(IRModule* module); + } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9e0e328bd..f37a7a1a0 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6526,7 +6526,7 @@ namespace Slang // By default, assume that we might have side effects, // to safely cover all the instructions we haven't had time to think about. default: - return true; + break; case kIROp_Call: { @@ -6553,7 +6553,7 @@ namespace Slang return false; } } - return true; + break; // All of the cases for "global values" are side-effect-free. case kIROp_StructType: @@ -6665,6 +6665,13 @@ namespace Slang case kIROp_BackwardDifferentiate: return false; } + + // Check if the calle has been marked with a catch-all no-side-effect decoration. + if (findDecoration<IRNoSideEffectDecoration>()) + { + return false; + } + return true; } IRModule* IRInst::getModule() |
