diff options
Diffstat (limited to 'source')
22 files changed, 503 insertions, 106 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 1ffc45fbd..2fc18628e 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2129,7 +2129,7 @@ namespace Slang { derivType = outType->getValueType(); } - else if (!as<PtrTypeBase>(derivType)) + else if (as<DifferentialPairType>(derivType)) { derivType = m_astBuilder->getInOutType(derivType); } diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index d50cc45a3..00fa5d3cb 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -54,7 +54,6 @@ #include "slang-ir-liveness.h" #include "slang-ir-glsl-liveness.h" #include "slang-ir-string-hash.h" - #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" @@ -378,7 +377,9 @@ Result linkAndOptimizeIR( performMandatoryEarlyInlining(irModule); dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); + enableIRValidationAtInsert(); changed |= processAutodiffCalls(irModule, sink); + disableIRValidationAtInsert(); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); if (!changed) @@ -1009,7 +1010,7 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr this, linkingAndOptimizationOptions, linkedIR)); - + auto irModule = linkedIR.module; metadata = linkedIR.metadata; diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index e37415446..54d32ae3e 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -64,17 +64,16 @@ InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVa if (diffNameHint.getLength() > 0) builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice()); - return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar); + return InstPair(maybeCloneForPrimalInst(builder, origVar), diffVar); } - - return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr); + return InstPair(maybeCloneForPrimalInst(builder, origVar), nullptr); } InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRInst* origArith) { SLANG_ASSERT(origArith->getOperandCount() == 2); - IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith); + IRInst* primalArith = maybeCloneForPrimalInst(builder, origArith); auto origLeft = origArith->getOperand(0); auto origRight = origArith->getOperand(1); @@ -160,7 +159,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRIns // Boolean operations are not differentiable. For the linearization // pass, we do not need to do anything but copy them over to the ne // function. - auto primalLogic = cloneInst(&cloneEnv, builder, origLogic); + auto primalLogic = maybeCloneForPrimalInst(builder, origLogic); return InstPair(primalLogic, nullptr); } @@ -170,7 +169,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRIns InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad) { auto origPtr = origLoad->getPtr(); - auto primalPtr = lookupPrimalInst(origPtr, nullptr); + auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr); auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType(); if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType)) @@ -190,7 +189,7 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig return InstPair(primalElement, diffElement); } - auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + auto primalLoad = maybeCloneForPrimalInst(builder, origLoad); IRInst* diffLoad = nullptr; if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) { @@ -204,9 +203,9 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or { IRInst* origStoreLocation = origStore->getPtr(); IRInst* origStoreVal = origStore->getVal(); - auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr); + auto primalStoreLocation = lookupPrimalInst(builder, origStoreLocation, nullptr); auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); - auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr); + auto primalStoreVal = lookupPrimalInst(builder, origStoreVal, nullptr); auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); if (!diffStoreLocation) @@ -222,7 +221,7 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or } } - auto primalStore = cloneInst(&cloneEnv, builder, origStore); + auto primalStore = maybeCloneForPrimalInst(builder, origStore); IRInst* diffStore = nullptr; @@ -248,7 +247,7 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or // InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) { - IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct); + IRInst* primalConstruct = maybeCloneForPrimalInst(builder, origConstruct); // Check if the output type can be differentiated. If it cannot be // differentiated, don't differentiate the inst @@ -340,7 +339,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (!diffCallee) { // The callee is non differentiable, just return primal value with null diff value. - IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); + IRInst* primalCall = maybeCloneForPrimalInst(builder, origCall); return InstPair(primalCall, nullptr); } @@ -419,7 +418,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) { - IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle); + IRInst* primalSwizzle = maybeCloneForPrimalInst(builder, origSwizzle); if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr)) { @@ -441,7 +440,7 @@ InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle InstPair ForwardDiffTranscriber::transcribeByPassthrough(IRBuilder* builder, IRInst* origInst) { - IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst); + IRInst* primalInst = maybeCloneForPrimalInst(builder, origInst); UCount operandCount = origInst->getOperandCount(); @@ -462,7 +461,7 @@ InstPair ForwardDiffTranscriber::transcribeByPassthrough(IRBuilder* builder, IRI return InstPair( primalInst, builder->emitIntrinsicInst( - differentiateType(builder, primalInst->getDataType()), + differentiateType(builder, origInst->getDataType()), origInst->getOp(), operandCount, diffOperands.getBuffer())); @@ -481,10 +480,10 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) { auto origArg = origBranch->getArg(ii); - auto primalArg = lookupPrimalInst(origArg); + auto primalArg = lookupPrimalInst(builder, origArg); newArgs.add(primalArg); - if (differentiateType(builder, primalArg->getDataType())) + if (differentiateType(builder, origArg->getDataType())) { auto diffArg = lookupDiffInst(origArg, nullptr); if (diffArg) @@ -672,7 +671,7 @@ InstPair ForwardDiffTranscriber::transcribeFieldExtract(IRBuilder* builder, IRIn IRInst* diffFieldExtract = nullptr; - if (auto diffType = differentiateType(builder, primalType)) + if (auto diffType = differentiateType(builder, originalInst->getDataType())) { if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) { @@ -706,7 +705,7 @@ InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst IRInst* diffGetElementPtr = nullptr; - if (auto diffType = differentiateType(builder, primalType)) + if (auto diffType = differentiateType(builder, origGetElementPtr->getDataType())) { if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) { @@ -820,7 +819,7 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build auto primalPair = builder->emitMakeDifferentialPair( tryGetDiffPairType(builder, primalVal->getDataType()), primalVal, diffPrimalVal); auto diffPair = builder->emitMakeDifferentialPair( - tryGetDiffPairType(builder, differentiateType(builder, primalVal->getDataType())), + tryGetDiffPairType(builder, differentiateType(builder, origInst->getPrimalValue()->getDataType())), primalDiffVal, diffDiffVal); return InstPair(primalPair, diffPair); @@ -897,7 +896,7 @@ InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, I IRInst* diffResult = nullptr; - if (auto diffType = differentiateType(builder, primalType)) + if (auto diffType = differentiateType(builder, origInst->getDataType())) { List<IRInst*> diffArgs; for (UInt i = 0; i < origInst->getOperandCount(); i++) diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 817534065..af408a5b3 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -20,27 +20,29 @@ namespace Slang { bool noDiff = false; auto origType = funcType->getParamType(i); - if (auto attrType = as<IRAttributedType>(origType)) + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType); + + if (auto attrType = as<IRAttributedType>(primalType)) { if (attrType->findAttr<IRNoDiffAttr>()) { noDiff = true; - origType = attrType->getBaseType(); + primalType = attrType->getBaseType(); } } if (noDiff) { - newParameterTypes.add(origType); + newParameterTypes.add(primalType); } else { - if (auto diffPairType = tryGetDiffPairType(builder, origType)) + if (auto diffPairType = tryGetDiffPairType(builder, primalType)) { auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); newParameterTypes.add(inoutDiffPairType); } else - newParameterTypes.add(origType); + newParameterTypes.add(primalType); } } @@ -55,35 +57,47 @@ namespace Slang return builder->getFuncType(newParameterTypes, diffReturnType); } + + static IRInst* getOriginalFuncRef(IRBuilder& builder, IRInst* func, IRInst* useSite) + { + if (!func) return nullptr; + auto userGeneric = findOuterGeneric(useSite); + if (!userGeneric) return func; + auto funcGen = findOuterGeneric(func); + SLANG_RELEASE_ASSERT(funcGen); + return maybeSpecializeWithGeneric(builder, funcGen, userGeneric); + } IRFuncType* BackwardDiffPrimalTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) { - auto intermediateType = builder->getBackwardDiffIntermediateContextType(func); + auto funcRef = getOriginalFuncRef(*builder, func, builder->getInsertLoc().getParent()); + auto intermediateType = builder->getBackwardDiffIntermediateContextType(funcRef); auto outType = builder->getOutType(intermediateType); List<IRType*> paramTypes; for (UInt i = 0; i < funcType->getParamCount(); i++) { - paramTypes.add(funcType->getParamType(i)); + auto origType = funcType->getParamType(i); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType); + paramTypes.add(primalType); } paramTypes.add(outType); IRFuncType* primalFuncType = builder->getFuncType( - paramTypes, funcType->getResultType()); + paramTypes, (IRType*)findOrTranscribePrimalInst(builder, funcType->getResultType())); return primalFuncType; } InstPair BackwardDiffPrimalTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) { - SLANG_UNUSED(builder); - SLANG_UNUSED(diffFunc); - auto intermediateTypeDecor = primalFunc->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>(); - SLANG_RELEASE_ASSERT(intermediateTypeDecor); - auto primalDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>(); - return InstPair(primalFunc, primalDecor->getBackwardDerivativePrimalFunc()); + // Don't need to do anything other than add a decoration in the original func to point to the primal func. + // The body of the primal func will be generated by propagateTranscriber together with propagate func. + addTranscribedFuncDecoration(*builder, primalFunc, diffFunc); + return InstPair(primalFunc, primalFunc); } IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) { - auto intermediateType = builder->getBackwardDiffIntermediateContextType(func); + auto funcRef = getOriginalFuncRef(*builder, func, builder->getInsertLoc().getParent()); + auto intermediateType = builder->getBackwardDiffIntermediateContextType(funcRef); return differentiateFunctionTypeImpl(builder, funcType, intermediateType); } @@ -96,6 +110,7 @@ namespace Slang InstPair BackwardDiffPropagateTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) { IRGlobalValueWithCode* diffPrimalFunc = nullptr; + addTranscribedFuncDecoration(*builder, primalFunc, diffFunc); transcribeFuncImpl(builder, primalFunc, diffFunc, diffPrimalFunc); return InstPair(primalFunc, diffFunc); } @@ -211,8 +226,7 @@ namespace Slang if (!isMarkedForBackwardDifferentiation(origFunc)) return InstPair(nullptr, nullptr); - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertBefore(origFunc); + IRBuilder builder = *inBuilder; IRFunc* primalFunc = origFunc; @@ -221,6 +235,8 @@ namespace Slang auto diffFunc = builder.createFunc(); SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); + builder.setInsertBefore(diffFunc); + IRType* diffFuncType = this->differentiateFunctionType( &builder, origFunc, @@ -235,18 +251,6 @@ namespace Slang builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); } - if (auto outerGen = findOuterGeneric(diffFunc)) - { - builder.setInsertBefore(origFunc); - auto specialized = - specializeWithGeneric(builder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc))); - addExistingDiffFuncDecor(&builder, origFunc, specialized); - } - else - { - addExistingDiffFuncDecor(&builder, origFunc, diffFunc); - } - // Mark the generated derivative function itself as differentiable. builder.addBackwardDifferentiableDecoration(diffFunc); @@ -259,6 +263,22 @@ namespace Slang return InstPair(primalFunc, diffFunc); } + void BackwardDiffTranscriberBase::addTranscribedFuncDecoration(IRBuilder& builder, IRFunc* origFunc, IRFunc* transcribedFunc) + { + IRBuilder subBuilder = builder; + if (auto outerGen = findOuterGeneric(transcribedFunc)) + { + subBuilder.setInsertBefore(origFunc); + auto specialized = + specializeWithGeneric(subBuilder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc))); + addExistingDiffFuncDecor(&subBuilder, origFunc, specialized); + } + else + { + addExistingDiffFuncDecor(&subBuilder, origFunc, transcribedFunc); + } + } + InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { auto result = transcribeFuncHeaderImpl(inBuilder, origFunc); @@ -288,7 +308,7 @@ namespace Slang List<IRType*> primalTypes, propagateTypes; for (UInt i = 0; i < funcType->getParamCount(); i++) { - auto paramType = funcType->getParamType(i); + auto paramType = (IRType*)findOrTranscribePrimalInst(&builder, funcType->getParamType(i)); auto param = builder.emitParam(paramType); if (i != funcType->getParamCount() - 1) { @@ -368,10 +388,8 @@ namespace Slang { IRParam* nextParam = param->getNextParam(); - // Copy inst into the new parameter block. - auto clonedParam = cloneInst(&cloneEnv, &builder, param); - param->replaceUsesWith(clonedParam); - param->removeAndDeallocate(); + // Move inst into the new parameter block. + param->insertAtEnd(paramBlock); param = nextParam; } @@ -383,6 +401,62 @@ namespace Slang builder.emitBranch(firstBlock); } + // Create a copy of originalFunc's forward derivative in the same generic context (if any) of + // `diffPropagateFunc`. + IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc( + IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc) + { + auto primalOuterParent = findOuterGeneric(originalFunc); + if (!primalOuterParent) + primalOuterParent = originalFunc; + + // Make a clone of original func so we won't modify the original. + IRCloneEnv originalCloneEnv; + primalOuterParent = cloneInst(&originalCloneEnv, builder, primalOuterParent); + auto primalFunc = as<IRFunc>(getGenericReturnVal(primalOuterParent)); + + // Strip any existing derivative decorations off the clone. + stripDerivativeDecorations(primalFunc); + eliminateDeadCode(primalOuterParent); + + // Forward transcribe the clone of the original func. + ForwardDiffTranscriber fwdTranscriber(autoDiffSharedContext, builder->getSharedBuilder(), sink); + fwdTranscriber.pairBuilder = pairBuilder; + IRFunc* fwdDiffFunc = as<IRFunc>(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent))); + SLANG_ASSERT(fwdDiffFunc); + fwdTranscriber.transcribeFunc(builder, primalFunc, fwdDiffFunc); + + // Remove the clone of original func. + primalOuterParent->removeAndDeallocate(); + + // Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`. + if (auto fwdParentGeneric = as<IRGeneric>(findOuterGeneric(fwdDiffFunc))) + { + // Clone forward derivative func from its own generic into current generic parent. + GenericChildrenMigrationContext migrationContext; + auto diffOuterGeneric = as<IRGeneric>(findOuterGeneric(diffPropagateFunc)); + SLANG_RELEASE_ASSERT(diffOuterGeneric); + + migrationContext.init(fwdParentGeneric, diffOuterGeneric); + auto inst = fwdParentGeneric->getFirstBlock()->getFirstOrdinaryInst(); + builder->setInsertBefore(diffPropagateFunc); + while (inst) + { + auto next = inst->getNextInst(); + auto cloned = migrationContext.cloneInst(builder, inst); + if (inst == fwdDiffFunc) + { + fwdDiffFunc = as<IRFunc>(cloned); + break; + } + inst = next; + } + fwdParentGeneric->removeAndDeallocate(); + } + + return fwdDiffFunc; + } + // Transcribe a function definition. void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc) { @@ -393,12 +467,16 @@ namespace Slang // Generate a temporary forward derivative function as an intermediate step. IRBuilder tempBuilder = *builder; - tempBuilder.setInsertBefore(diffPropagateFunc); - ForwardDiffTranscriber* fwdTranscriber = static_cast<ForwardDiffTranscriber*>(autoDiffSharedContext->transcriberSet.forwardTranscriber); - IRFunc* fwdDiffFunc = as<IRFunc>(fwdTranscriber->transcribeFuncHeaderImpl(&tempBuilder, primalFunc)); - SLANG_ASSERT(fwdDiffFunc); + if (auto outerGeneric = findOuterGeneric(diffPropagateFunc)) + { + tempBuilder.setInsertBefore(outerGeneric); + } + else + { + tempBuilder.setInsertBefore(diffPropagateFunc); + } - fwdTranscriber->transcribeFunc(&tempBuilder, primalFunc, fwdDiffFunc); + auto fwdDiffFunc = generateNewForwardDerivativeForFunc(&tempBuilder, primalFunc, diffPropagateFunc); // Split first block into a paramter block. this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc)); @@ -466,11 +544,12 @@ namespace Slang // we have just created. auto primalOuterGeneric = findOuterGeneric(primalFunc); IRInst* specializedFunc = nullptr; - auto intermediateTypeGeneric = hoistValueFromGeneric(*builder, intermediateType, specializedFunc); + auto intermediateTypeGeneric = hoistValueFromGeneric(*builder, intermediateType, specializedFunc, true); + builder->setInsertBefore(primalFunc); auto specializedIntermeidateType = maybeSpecializeWithGeneric(*builder, intermediateTypeGeneric, primalOuterGeneric); builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, specializedIntermeidateType); - auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc); + auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc, true); builder->setInsertBefore(primalFunc); if (auto existingDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>()) @@ -568,7 +647,7 @@ namespace Slang return diffParam; } - return cloneInst(&cloneEnv, builder, origParam); + return maybeCloneForPrimalInst(builder, origParam); } InstPair BackwardDiffTranscriberBase::copyBinaryArith(IRBuilder* builder, IRInst* origArith) diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index decbdf150..02a100c80 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -85,10 +85,14 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize); + IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc); + void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc); InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); + void addTranscribedFuncDecoration(IRBuilder& builder, IRFunc* origFunc, IRFunc* transcribedFunc); + virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override; @@ -173,8 +177,10 @@ struct BackwardDiffTranscriber : BackwardDiffTranscriberBase virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override; virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override { - SLANG_UNUSED(builder); // Don't need to do anything here, the body is generated in transcribeFuncHeader. + + SLANG_UNUSED(builder); + addTranscribedFuncDecoration(*builder, primalFunc, diffFunc); return InstPair(primalFunc, diffFunc); } virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index c0404e036..deb1b2da9 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -75,36 +75,38 @@ bool AutoDiffTranscriberBase::hasDifferentialInst(IRInst* origInst) return instMapD.ContainsKey(origInst); } -bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* origInst) +bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* currentParent, IRInst* origInst) { if (as<IRGlobalValueWithCode>(origInst)) return true; if (origInst->parent && origInst->parent->getOp() == kIROp_Module) return true; + if (isChildInstOf(currentParent, origInst->getParent())) + return true; return false; } -IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* origInst) +IRInst* AutoDiffTranscriberBase::lookupPrimalInstImpl(IRInst* currentParent, IRInst* origInst) { if (!origInst) return nullptr; - if (shouldUseOriginalAsPrimal(origInst)) + if (shouldUseOriginalAsPrimal(currentParent, origInst)) return origInst; return cloneEnv.mapOldValToNew[origInst]; } -IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) +IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* currentParent, IRInst* origInst, IRInst* defaultInst) { if (!origInst) return nullptr; - return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst; + return (hasPrimalInst(currentParent, origInst)) ? lookupPrimalInstImpl(currentParent, origInst) : defaultInst; } -bool AutoDiffTranscriberBase::hasPrimalInst(IRInst* origInst) +bool AutoDiffTranscriberBase::hasPrimalInst(IRInst* currentParent, IRInst* origInst) { if (!origInst) return false; - if (shouldUseOriginalAsPrimal(origInst)) + if (shouldUseOriginalAsPrimal(currentParent, origInst)) return true; return cloneEnv.mapOldValToNew.ContainsKey(origInst); } @@ -124,26 +126,48 @@ IRInst* AutoDiffTranscriberBase::findOrTranscribePrimalInst(IRBuilder* builder, { if (!origInst) return origInst; + + auto currentParent = builder->getInsertLoc().getParent(); - if (shouldUseOriginalAsPrimal(origInst)) + if (shouldUseOriginalAsPrimal(currentParent, origInst)) return origInst; - if (!hasPrimalInst(origInst)) + if (!hasPrimalInst(currentParent, origInst)) { transcribe(builder, origInst); - SLANG_ASSERT(hasPrimalInst(origInst)); + SLANG_ASSERT(hasPrimalInst(currentParent, origInst)); } - return lookupPrimalInst(origInst); + return lookupPrimalInstImpl(currentParent, origInst); } IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst) { - IRInst* primal = lookupPrimalInst(inst, inst); - - if (primal == inst && - !isChildInstOf(builder->getInsertLoc().getParent(), inst->getParent())) - primal = cloneInst(&cloneEnv, builder, inst); + IRInst* primal = lookupPrimalInst(builder, inst, nullptr); + if (!primal) + { + IRInst* type = inst->getFullType(); + if (type) + { + type = maybeCloneForPrimalInst(builder, type); + } + List<IRInst*> operands; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = maybeCloneForPrimalInst(builder, inst->getOperand(i)); + operands.add(operand); + } + auto cloneResult = builder->emitIntrinsicInst( + (IRType*)type, inst->getOp(), operands.getCount(), operands.getBuffer()); + IRBuilder subBuilder = *builder; + subBuilder.setInsertInto(cloneResult); + for (auto child : inst->getDecorationsAndChildren()) + { + maybeCloneForPrimalInst(&subBuilder, child); + } + cloneEnv.mapOldValToNew[inst] = cloneResult; + return cloneResult; + } return primal; } @@ -223,7 +247,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy // If there is an explicit primal version of this type in the local scope, load that // otherwise use the original type. // - IRInst* primalType = lookupPrimalInst(origType, origType); + IRInst* primalType = lookupPrimalInst(builder, origType, origType); // Special case certain compound types (PtrType, FuncType, etc..) // otherwise try to lookup a differential definition for the given type. @@ -390,7 +414,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* if (lookupKeyPath.getCount()) { // `interfaceType` does conform to `IDifferentiable`. - outWitnessTable = builder->emitExtractExistentialWitnessTable(lookupPrimalInstIfExists(origType->getOperand(0))); + outWitnessTable = builder->emitExtractExistentialWitnessTable(lookupPrimalInstIfExists(builder, origType->getOperand(0))); for (auto node : lookupKeyPath) { outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey()); @@ -731,7 +755,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst // if (auto diffInst = lookupDiffInst(origInst, nullptr)) { - SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check. + SLANG_ASSERT(lookupPrimalInst(builder, origInst)); // Consistency check. return diffInst; } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index a6b832856..2d980145e 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -41,7 +41,7 @@ struct AutoDiffTranscriberBase , sharedBuilder(inSharedBuilder) , sink(inSink) { - + cloneEnv.squashChildrenMapping = true; } DiagnosticSink* getSink(); @@ -61,15 +61,29 @@ struct AutoDiffTranscriberBase bool hasDifferentialInst(IRInst* origInst); - bool shouldUseOriginalAsPrimal(IRInst* origInst); + bool shouldUseOriginalAsPrimal(IRInst* currentParent, IRInst* origInst); + + IRInst* lookupPrimalInstImpl(IRInst* currentParent, IRInst* origInst); + + IRInst* lookupPrimalInst(IRInst* currentParent, IRInst* origInst, IRInst* defaultInst); + + IRInst* lookupPrimalInstIfExists(IRBuilder* builder, IRInst* origInst) + { + return lookupPrimalInst(builder->getInsertLoc().getParent(), origInst, origInst); + } - IRInst* lookupPrimalInst(IRInst* origInst); + IRInst* lookupPrimalInst(IRBuilder* builder, IRInst* origInst) + { + return lookupPrimalInstImpl(builder->getInsertLoc().getParent(), origInst); + } - IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst); + IRInst* lookupPrimalInst(IRBuilder* builder, IRInst* origInst, IRInst* defaultInst) + { + return lookupPrimalInst(builder->getInsertLoc().getParent(), origInst, defaultInst); + } - IRInst* lookupPrimalInstIfExists(IRInst* origInst) { return lookupPrimalInst(origInst, origInst); } - bool hasPrimalInst(IRInst* origInst); + bool hasPrimalInst(IRInst* currentParent, IRInst* origInst); IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 43b48aa13..b8a4c4f08 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -4,6 +4,7 @@ namespace Slang { + struct ExtractPrimalFuncContext { SharedIRBuilder* sharedBuilder; @@ -74,14 +75,18 @@ struct ExtractPrimalFuncContext IRFuncType* originalFuncType = nullptr; outIntermediateType = createIntermediateType(destFunc); + GenericChildrenMigrationContext migrationContext; + migrationContext.init(as<IRGeneric>(findOuterGeneric(originalFunc)), as<IRGeneric>(findOuterGeneric(destFunc))); + originalFuncType = as<IRFuncType>(originalFunc->getDataType()); SLANG_RELEASE_ASSERT(originalFuncType); List<IRType*> paramTypes; for (UInt i = 0; i < originalFuncType->getParamCount() - 1; i++) - paramTypes.add(originalFuncType->getParamType(i)); + paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i))); paramTypes.add(builder.getInOutType((IRType*)outIntermediateType)); - auto newFuncType = builder.getFuncType(paramTypes, builder.getVoidType()); + auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType()); + auto newFuncType = builder.getFuncType(paramTypes, resultType); return newFuncType; } @@ -239,7 +244,10 @@ struct ExtractPrimalFuncContext auto ptrStructType = as<IRPtrTypeBase>(intermediateOutput->getDataType()); SLANG_RELEASE_ASSERT(ptrStructType); auto structType = as<IRStructType>(ptrStructType->getValueType()); - genTypeBuilder.setInsertBefore(structType); + if (auto outerGen = findOuterGeneric(structType)) + genTypeBuilder.setInsertBefore(outerGen); + else + genTypeBuilder.setInsertBefore(structType); auto fieldType = type; SLANG_RELEASE_ASSERT(structType); auto structKey = genTypeBuilder.createStructKey(); diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index ba1e425db..612212dd9 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -9,10 +9,44 @@ #include "slang-ir-autodiff-fwd.h" #include "slang-ir-autodiff-propagate.h" #include "slang-ir-autodiff-transcriber-base.h" +#include "slang-ir-validate.h" namespace Slang { +struct GenericChildrenMigrationContext +{ + IRCloneEnv cloneEnv; + IRGeneric* srcGeneric; + void init(IRGeneric* genericSrc, IRGeneric* genericDst) + { + srcGeneric = genericSrc; + if (!genericSrc) + return; + auto srcParam = genericSrc->getFirstBlock()->getFirstParam(); + auto dstParam = genericDst->getFirstBlock()->getFirstParam(); + while (srcParam && dstParam) + { + cloneEnv.mapOldValToNew[srcParam] = dstParam; + srcParam = srcParam->getNextParam(); + dstParam = dstParam->getNextParam(); + } + cloneEnv.mapOldValToNew[genericSrc] = genericDst; + cloneEnv.mapOldValToNew[genericSrc->getFirstBlock()] = genericDst->getFirstBlock(); + } + + IRInst* cloneInst(IRBuilder* builder, IRInst* src) + { + if (!srcGeneric) + return src; + if (findOuterGeneric(src) == srcGeneric) + { + return Slang::cloneInst(&cloneEnv, builder, src); + } + return src; + } +}; + struct DiffUnzipPass { AutoDiffSharedContext* autodiffContext; @@ -62,6 +96,7 @@ struct DiffUnzipPass // TODO: Looks like we get a copy of the decorations? IRCloneEnv subEnv; subEnv.parent = &cloneEnv; + builder->setInsertBefore(func); IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&subEnv, builder, func)); builder->setInsertInto(unzippedFunc); @@ -231,7 +266,10 @@ struct DiffUnzipPass newFwdCallee, diffArgs); diffBuilder->markInstAsDifferential(diffPairVal, primalType); + + disableIRValidationAtInsert(); diffBuilder->addBackwardDerivativePrimalContextDecoration(diffPairVal, intermediateVar); + enableIRValidationAtInsert(); auto diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, diffPairVal); diffBuilder->markInstAsDifferential(diffVal, primalType); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 94417ea00..74afa4002 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -393,6 +393,28 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() } } +void stripDerivativeDecorations(IRInst* inst) +{ + for (auto decor = inst->getFirstDecoration(); decor; ) + { + auto next = decor->getNextDecoration(); + switch (decor->getOp()) + { + case kIROp_ForwardDerivativeDecoration: + case kIROp_DerivativeMemberDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDerivativeIntermediateTypeDecoration: + case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_BackwardDerivativePrimalDecoration: + decor->removeAndDeallocate(); + break; + default: + break; + } + decor = next; + } +} + void stripAutoDiffDecorationsFromChildren(IRInst* parent) { for (auto inst : parent->getChildren()) @@ -702,7 +724,7 @@ struct AutoDiffPass : public InstPassBase forwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc); break; case FuncBodyTranscriptionTaskType::BackwardPrimal: - // Don't need to do anything, they will be filled by `backwardPropagateTranscriber`. + backwardPrimalTranscriber.transcribeFunc(builder, primalFunc, diffFunc); break; case FuncBodyTranscriptionTaskType::BackwardPropagate: backwardPropagateTranscriber.transcribeFunc(builder, primalFunc, diffFunc); diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index b4b97751f..f468b1fca 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -257,4 +257,6 @@ bool processAutodiffCalls( bool finalizeAutoDiffPass(IRModule* module); +void stripDerivativeDecorations(IRInst* inst); + }; diff --git a/source/slang/slang-ir-entry-point-uniforms.cpp b/source/slang/slang-ir-entry-point-uniforms.cpp index d98f39515..1f0bc13b1 100644 --- a/source/slang/slang-ir-entry-point-uniforms.cpp +++ b/source/slang/slang-ir-entry-point-uniforms.cpp @@ -404,6 +404,8 @@ struct CollectEntryPointUniformParams : PerEntryPointPass collectedParam = builder.createParam(paramStructType); } + collectedParam->insertBefore(m_entryPoint.func); + // No matter what, the global shader parameter should have the layout // information from the entry point attached to it, so that the // contained parameters will end up in the right place(s). diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index 806ea8826..6f412d579 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -48,9 +48,12 @@ namespace Slang IRCloneEnv cloneEnv; IRBuilder builder(sharedContext->sharedBuilderStorage); builder.setInsertBefore(genericParent); + // Do not clone func type (which would break IR def-use rules if we do it here) + // This is OK since we will lower the type immediately after the clone. + cloneEnv.mapOldValToNew[func->getFullType()] = builder.getTypeKind(); auto loweredFunc = cast<IRFunc>(cloneInstAndOperands(&cloneEnv, &builder, func)); auto loweredGenericType = - lowerGenericFuncType(&builder, cast<IRGeneric>(genericParent->getFullType())); + lowerGenericFuncType(&builder, genericParent, cast<IRFuncType>(func->getFullType())); SLANG_ASSERT(loweredGenericType); loweredFunc->setFullType(loweredGenericType); List<IRInst*> clonedParams; @@ -90,7 +93,7 @@ namespace Slang return loweredFunc; } - IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal) + IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal, IRFuncType* funcType) { ShortList<IRInst*> genericParamTypes; Dictionary<IRInst*, IRInst*> typeMapping; @@ -107,7 +110,7 @@ namespace Slang auto innerType = (IRFuncType*)lowerFuncType( builder, - cast<IRFuncType>(findGenericReturnVal(genericVal)), + funcType, typeMapping, genericParamTypes.getArrayView().arrayView); @@ -182,7 +185,10 @@ namespace Slang } else if (auto genericFuncType = as<IRGeneric>(requirementVal)) { - loweredVal = lowerGenericFuncType(&builder, genericFuncType); + loweredVal = lowerGenericFuncType( + &builder, + genericFuncType, + cast<IRFuncType>(findGenericReturnVal(genericFuncType))); } else if (requirementVal->getOp() == kIROp_AssociatedType) { diff --git a/source/slang/slang-ir-remove-unused-generic-param.cpp b/source/slang/slang-ir-remove-unused-generic-param.cpp new file mode 100644 index 000000000..9337a00bb --- /dev/null +++ b/source/slang/slang-ir-remove-unused-generic-param.cpp @@ -0,0 +1,134 @@ +#include "slang-ir-remove-unused-generic-param.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ +struct RemoveUnusedGenericParamContext : InstPassBase +{ + RemoveUnusedGenericParamContext(IRModule* inModule) + : InstPassBase(inModule) + {} + + bool processModule() + { + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->init(module); + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + IRBuilder builder(sharedBuilder); + bool changed = false; + for (auto inst : module->getModuleInst()->getChildren()) + { + if (auto genInst = as<IRGeneric>(inst)) + { + auto returnVal = findGenericReturnVal(genInst); + switch (returnVal->getOp()) + { + case kIROp_StructType: + case kIROp_ClassType: + break; + case kIROp_Func: + case kIROp_FuncType: + default: + // Don't simplify functions since this can break signature compatiblity with the + // interface. For example, if we have + // interface IFoo { void genFunc<T>(int x); } + // We can't simplify this by removing `T` even when the function type here does not depend on T. + continue; + } + if (returnVal->findDecoration<IRTargetIntrinsicDecoration>()) + continue; + + List<UInt> paramToPreserve; + UInt id = 0; + List<IRInst*> paramsToRemove; + for (auto param : genInst->getParams()) + { + if (param->hasUses()) + { + paramToPreserve.add(id); + } + else + { + paramsToRemove.add(param); + } + id++; + } + if (paramsToRemove.getCount() == 0) + continue; + changed = true; + if (paramToPreserve.getCount() == 0) + { + // Special case: the generic return value is not dependent on the generic param, + // we can hoist to global scope safely. + for (auto child = genInst->getFirstBlock()->getFirstOrdinaryInst(); child; ) + { + auto next = child->getNextInst(); + if (child->getOp() == kIROp_Return) + { + break; + } + child->insertBefore(genInst); + child = next; + } + SLANG_ASSERT(returnVal); + List<IRUse*> uses; + for (auto use = genInst->firstUse; use; use = use->nextUse) + uses.add(use); + for (auto use : uses) + { + if (use->getUser()->getOp() == kIROp_Specialize && + use == use->getUser()->getOperands()) + { + use->getUser()->replaceUsesWith(returnVal); + } + } + genInst->replaceUsesWith(returnVal); + genInst->removeAndDeallocate(); + } + else + { + // General case: remove unnecessary specialization arguments. + // Disabled this optimization for now since we still need to take care + // of the type of the generic, or change other passes to not + // use type info on a generic at all. + List<IRUse*> uses; + for (auto use = genInst->firstUse; use; use = use->nextUse) + uses.add(use); + for (auto use : uses) + { + if (use->getUser()->getOp() == kIROp_Specialize && + use == use->getUser()->getOperands()) + { + auto specialize = as<IRSpecialize>(use->getUser()); + builder.setInsertBefore(specialize); + List<IRInst*> newArgs; + for (auto i : paramToPreserve) + newArgs.add(specialize->getArg(i)); + auto newSpecialize = builder.emitSpecializeInst( + specialize->getFullType(), + specialize->getBase(), + newArgs.getCount(), + newArgs.getBuffer()); + specialize->transferDecorationsTo(newSpecialize); + specialize->replaceUsesWith(newSpecialize); + specialize->removeAndDeallocate(); + } + } + for (auto param : paramsToRemove) + param->removeAndDeallocate(); + } + } + } + return changed; + } +}; + +bool removeUnusedGenericParam(IRModule* module) +{ + RemoveUnusedGenericParamContext context = RemoveUnusedGenericParamContext(module); + return context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-remove-unused-generic-param.h b/source/slang/slang-ir-remove-unused-generic-param.h new file mode 100644 index 000000000..8f7a61945 --- /dev/null +++ b/source/slang/slang-ir-remove-unused-generic-param.h @@ -0,0 +1,9 @@ +// slang-ir-remove-unused-generic-param.h +#pragma once + +namespace Slang +{ + struct IRModule; + + bool removeUnusedGenericParam(IRModule* module); +} diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index 4b604e03a..938094551 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -7,6 +7,7 @@ #include "slang-ir-simplify-cfg.h" #include "slang-ir-peephole.h" #include "slang-ir-hoist-constants.h" +#include "slang-ir-remove-unused-generic-param.h" namespace Slang { @@ -31,7 +32,7 @@ namespace Slang eliminateDeadCode(module); changed |= constructSSA(module); - + changed |= removeUnusedGenericParam(module); iterationCounter++; } } diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 2dee189dc..2415f1388 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -4,6 +4,7 @@ #include "slang-ir.h" #include "slang-ir-clone.h" #include "slang-ir-insts.h" +#include "slang-ir-validate.h" namespace Slang { @@ -1195,7 +1196,6 @@ bool constructSSA(ConstructSSAContext* context) { var->removeAndDeallocate(); } - return true; } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 8e3e879ad..73d8865ed 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -156,11 +156,11 @@ IRInst* maybeSpecializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecaili return genericToSpecailize; } -IRInst* hoistValueFromGeneric(IRBuilder& builder, IRInst* value, IRInst*& outSpecializedVal, bool replaceExistingValue) +IRInst* hoistValueFromGeneric(IRBuilder& inBuilder, IRInst* value, IRInst*& outSpecializedVal, bool replaceExistingValue) { auto outerGeneric = as<IRGeneric>(findOuterGeneric(value)); if (!outerGeneric) return value; - + IRBuilder builder = inBuilder; builder.setInsertBefore(outerGeneric); auto newGeneric = builder.emitGeneric(); builder.setInsertInto(newGeneric); diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 49f46d0e3..4885dcd96 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -102,6 +102,7 @@ inline IRInst* unwrapAttributedType(IRInst* type) type = attrType->getBaseType(); return type; } + } #endif diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 46817e212..a49eda322 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -29,7 +29,14 @@ namespace Slang { if (!condition) { - context->getSink()->diagnose(inst, Diagnostics::irValidationFailed, message); + if (context) + { + context->getSink()->diagnose(inst, Diagnostics::irValidationFailed, message); + } + else + { + SLANG_ASSERT_FAILURE("IR validation failed"); + } } } @@ -143,7 +150,10 @@ namespace Slang // If `operandValue` precedes `inst`, then we should // have already seen it, because we scan parent instructions // in order. - validate(context, context->seenInsts.Contains(operandValue), inst, "def must come before use in same block"); + if (context) + { + validate(context, context->seenInsts.Contains(operandValue), inst, "def must come before use in same block"); + } return; } @@ -196,6 +206,34 @@ namespace Slang } } + static thread_local bool _enableIRValidationAtInsert = false; + void disableIRValidationAtInsert() + { + _enableIRValidationAtInsert = false; + } + void enableIRValidationAtInsert() + { + _enableIRValidationAtInsert = true; + } + void validateIRInstOperands(IRInst* inst) + { + if (!_enableIRValidationAtInsert) + return; + switch (inst->getOp()) + { + case kIROp_loop: + case kIROp_ifElse: + case kIROp_unconditionalBranch: + case kIROp_conditionalBranch: + case kIROp_Switch: + return; + default: + break; + } + + validateIRInstOperands(nullptr, inst); + } + void validateCodeBody(IRValidateContext* context, IRGlobalValueWithCode* code) { HashSet<IRBlock*> blocks; @@ -296,4 +334,5 @@ namespace Slang auto sink = codeGenContext->getSink(); validateIRModule(module, sink); } + } diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h index 3e8e8dc92..a1a9eb4f4 100644 --- a/source/slang/slang-ir-validate.h +++ b/source/slang/slang-ir-validate.h @@ -37,4 +37,8 @@ namespace Slang void validateIRModuleIfEnabled( CodeGenContext* codeGenContext, IRModule* module); + + void disableIRValidationAtInsert(); + void enableIRValidationAtInsert(); + } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f37a7a1a0..b36a2ebec 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2810,6 +2810,8 @@ namespace Slang IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType( IRInst* func) { + if (!func) + func = getVoidValue(); return (IRBackwardDiffIntermediateContextType*)getType( kIROp_BackwardDiffIntermediateContextType, 1, @@ -6260,6 +6262,8 @@ namespace Slang return type; } + void validateIRInstOperands(IRInst*); + void IRInst::replaceUsesWith(IRInst* other) { // Safety check: don't try to replace something with itself. @@ -6377,6 +6381,10 @@ namespace Slang this->prev = inPrev; this->next = inNext; this->parent = inParent; + +#if _DEBUG + validateIRInstOperands(this); +#endif } void IRInst::insertAfter(IRInst* other) |
