summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-11 15:33:28 -0800
committerGitHub <noreply@github.com>2023-01-11 15:33:28 -0800
commita3ac6e71cbc922b7c941c45f23ee18a9fc274d1f (patch)
treeacf8c18601f124e9290494f8b379d2420369fc35 /source
parent20262684bcbb707d16669b2670039df870b65ca8 (diff)
Make backward differentiation work with generics. (#2586)
* Make backward differentiation work with generics. * Fix. * Another fix. * More fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-expr.cpp2
-rw-r--r--source/slang/slang-emit.cpp5
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp41
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp163
-rw-r--r--source/slang/slang-ir-autodiff-rev.h8
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp62
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h26
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp14
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h38
-rw-r--r--source/slang/slang-ir-autodiff.cpp24
-rw-r--r--source/slang/slang-ir-autodiff.h2
-rw-r--r--source/slang/slang-ir-entry-point-uniforms.cpp2
-rw-r--r--source/slang/slang-ir-lower-generic-function.cpp14
-rw-r--r--source/slang/slang-ir-remove-unused-generic-param.cpp134
-rw-r--r--source/slang/slang-ir-remove-unused-generic-param.h9
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp3
-rw-r--r--source/slang/slang-ir-ssa.cpp2
-rw-r--r--source/slang/slang-ir-util.cpp4
-rw-r--r--source/slang/slang-ir-util.h1
-rw-r--r--source/slang/slang-ir-validate.cpp43
-rw-r--r--source/slang/slang-ir-validate.h4
-rw-r--r--source/slang/slang-ir.cpp8
22 files changed, 503 insertions, 106 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 1ffc45fbd..2fc18628e 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2129,7 +2129,7 @@ namespace Slang
{
derivType = outType->getValueType();
}
- else if (!as<PtrTypeBase>(derivType))
+ else if (as<DifferentialPairType>(derivType))
{
derivType = m_astBuilder->getInOutType(derivType);
}
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index d50cc45a3..00fa5d3cb 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -54,7 +54,6 @@
#include "slang-ir-liveness.h"
#include "slang-ir-glsl-liveness.h"
#include "slang-ir-string-hash.h"
-
#include "slang-legalize-types.h"
#include "slang-lower-to-ir.h"
#include "slang-mangle.h"
@@ -378,7 +377,9 @@ Result linkAndOptimizeIR(
performMandatoryEarlyInlining(irModule);
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
+ enableIRValidationAtInsert();
changed |= processAutodiffCalls(irModule, sink);
+ disableIRValidationAtInsert();
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
if (!changed)
@@ -1009,7 +1010,7 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
this,
linkingAndOptimizationOptions,
linkedIR));
-
+
auto irModule = linkedIR.module;
metadata = linkedIR.metadata;
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index e37415446..54d32ae3e 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -64,17 +64,16 @@ InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVa
if (diffNameHint.getLength() > 0)
builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice());
- return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar);
+ return InstPair(maybeCloneForPrimalInst(builder, origVar), diffVar);
}
-
- return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr);
+ return InstPair(maybeCloneForPrimalInst(builder, origVar), nullptr);
}
InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRInst* origArith)
{
SLANG_ASSERT(origArith->getOperandCount() == 2);
- IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith);
+ IRInst* primalArith = maybeCloneForPrimalInst(builder, origArith);
auto origLeft = origArith->getOperand(0);
auto origRight = origArith->getOperand(1);
@@ -160,7 +159,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRIns
// Boolean operations are not differentiable. For the linearization
// pass, we do not need to do anything but copy them over to the ne
// function.
- auto primalLogic = cloneInst(&cloneEnv, builder, origLogic);
+ auto primalLogic = maybeCloneForPrimalInst(builder, origLogic);
return InstPair(primalLogic, nullptr);
}
@@ -170,7 +169,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRIns
InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
{
auto origPtr = origLoad->getPtr();
- auto primalPtr = lookupPrimalInst(origPtr, nullptr);
+ auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr);
auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType();
if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType))
@@ -190,7 +189,7 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig
return InstPair(primalElement, diffElement);
}
- auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+ auto primalLoad = maybeCloneForPrimalInst(builder, origLoad);
IRInst* diffLoad = nullptr;
if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
{
@@ -204,9 +203,9 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or
{
IRInst* origStoreLocation = origStore->getPtr();
IRInst* origStoreVal = origStore->getVal();
- auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr);
+ auto primalStoreLocation = lookupPrimalInst(builder, origStoreLocation, nullptr);
auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
- auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr);
+ auto primalStoreVal = lookupPrimalInst(builder, origStoreVal, nullptr);
auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);
if (!diffStoreLocation)
@@ -222,7 +221,7 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or
}
}
- auto primalStore = cloneInst(&cloneEnv, builder, origStore);
+ auto primalStore = maybeCloneForPrimalInst(builder, origStore);
IRInst* diffStore = nullptr;
@@ -248,7 +247,7 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or
//
InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
{
- IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct);
+ IRInst* primalConstruct = maybeCloneForPrimalInst(builder, origConstruct);
// Check if the output type can be differentiated. If it cannot be
// differentiated, don't differentiate the inst
@@ -340,7 +339,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (!diffCallee)
{
// The callee is non differentiable, just return primal value with null diff value.
- IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
+ IRInst* primalCall = maybeCloneForPrimalInst(builder, origCall);
return InstPair(primalCall, nullptr);
}
@@ -419,7 +418,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
{
- IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle);
+ IRInst* primalSwizzle = maybeCloneForPrimalInst(builder, origSwizzle);
if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
{
@@ -441,7 +440,7 @@ InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle
InstPair ForwardDiffTranscriber::transcribeByPassthrough(IRBuilder* builder, IRInst* origInst)
{
- IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst);
+ IRInst* primalInst = maybeCloneForPrimalInst(builder, origInst);
UCount operandCount = origInst->getOperandCount();
@@ -462,7 +461,7 @@ InstPair ForwardDiffTranscriber::transcribeByPassthrough(IRBuilder* builder, IRI
return InstPair(
primalInst,
builder->emitIntrinsicInst(
- differentiateType(builder, primalInst->getDataType()),
+ differentiateType(builder, origInst->getDataType()),
origInst->getOp(),
operandCount,
diffOperands.getBuffer()));
@@ -481,10 +480,10 @@ InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRIns
for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
{
auto origArg = origBranch->getArg(ii);
- auto primalArg = lookupPrimalInst(origArg);
+ auto primalArg = lookupPrimalInst(builder, origArg);
newArgs.add(primalArg);
- if (differentiateType(builder, primalArg->getDataType()))
+ if (differentiateType(builder, origArg->getDataType()))
{
auto diffArg = lookupDiffInst(origArg, nullptr);
if (diffArg)
@@ -672,7 +671,7 @@ InstPair ForwardDiffTranscriber::transcribeFieldExtract(IRBuilder* builder, IRIn
IRInst* diffFieldExtract = nullptr;
- if (auto diffType = differentiateType(builder, primalType))
+ if (auto diffType = differentiateType(builder, originalInst->getDataType()))
{
if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
{
@@ -706,7 +705,7 @@ InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst
IRInst* diffGetElementPtr = nullptr;
- if (auto diffType = differentiateType(builder, primalType))
+ if (auto diffType = differentiateType(builder, origGetElementPtr->getDataType()))
{
if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
{
@@ -820,7 +819,7 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build
auto primalPair = builder->emitMakeDifferentialPair(
tryGetDiffPairType(builder, primalVal->getDataType()), primalVal, diffPrimalVal);
auto diffPair = builder->emitMakeDifferentialPair(
- tryGetDiffPairType(builder, differentiateType(builder, primalVal->getDataType())),
+ tryGetDiffPairType(builder, differentiateType(builder, origInst->getPrimalValue()->getDataType())),
primalDiffVal,
diffDiffVal);
return InstPair(primalPair, diffPair);
@@ -897,7 +896,7 @@ InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, I
IRInst* diffResult = nullptr;
- if (auto diffType = differentiateType(builder, primalType))
+ if (auto diffType = differentiateType(builder, origInst->getDataType()))
{
List<IRInst*> diffArgs;
for (UInt i = 0; i < origInst->getOperandCount(); i++)
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 817534065..af408a5b3 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -20,27 +20,29 @@ namespace Slang
{
bool noDiff = false;
auto origType = funcType->getParamType(i);
- if (auto attrType = as<IRAttributedType>(origType))
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType);
+
+ if (auto attrType = as<IRAttributedType>(primalType))
{
if (attrType->findAttr<IRNoDiffAttr>())
{
noDiff = true;
- origType = attrType->getBaseType();
+ primalType = attrType->getBaseType();
}
}
if (noDiff)
{
- newParameterTypes.add(origType);
+ newParameterTypes.add(primalType);
}
else
{
- if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ if (auto diffPairType = tryGetDiffPairType(builder, primalType))
{
auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
newParameterTypes.add(inoutDiffPairType);
}
else
- newParameterTypes.add(origType);
+ newParameterTypes.add(primalType);
}
}
@@ -55,35 +57,47 @@ namespace Slang
return builder->getFuncType(newParameterTypes, diffReturnType);
}
+
+ static IRInst* getOriginalFuncRef(IRBuilder& builder, IRInst* func, IRInst* useSite)
+ {
+ if (!func) return nullptr;
+ auto userGeneric = findOuterGeneric(useSite);
+ if (!userGeneric) return func;
+ auto funcGen = findOuterGeneric(func);
+ SLANG_RELEASE_ASSERT(funcGen);
+ return maybeSpecializeWithGeneric(builder, funcGen, userGeneric);
+ }
IRFuncType* BackwardDiffPrimalTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType)
{
- auto intermediateType = builder->getBackwardDiffIntermediateContextType(func);
+ auto funcRef = getOriginalFuncRef(*builder, func, builder->getInsertLoc().getParent());
+ auto intermediateType = builder->getBackwardDiffIntermediateContextType(funcRef);
auto outType = builder->getOutType(intermediateType);
List<IRType*> paramTypes;
for (UInt i = 0; i < funcType->getParamCount(); i++)
{
- paramTypes.add(funcType->getParamType(i));
+ auto origType = funcType->getParamType(i);
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType);
+ paramTypes.add(primalType);
}
paramTypes.add(outType);
IRFuncType* primalFuncType = builder->getFuncType(
- paramTypes, funcType->getResultType());
+ paramTypes, (IRType*)findOrTranscribePrimalInst(builder, funcType->getResultType()));
return primalFuncType;
}
InstPair BackwardDiffPrimalTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
{
- SLANG_UNUSED(builder);
- SLANG_UNUSED(diffFunc);
- auto intermediateTypeDecor = primalFunc->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>();
- SLANG_RELEASE_ASSERT(intermediateTypeDecor);
- auto primalDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>();
- return InstPair(primalFunc, primalDecor->getBackwardDerivativePrimalFunc());
+ // Don't need to do anything other than add a decoration in the original func to point to the primal func.
+ // The body of the primal func will be generated by propagateTranscriber together with propagate func.
+ addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
+ return InstPair(primalFunc, primalFunc);
}
IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType)
{
- auto intermediateType = builder->getBackwardDiffIntermediateContextType(func);
+ auto funcRef = getOriginalFuncRef(*builder, func, builder->getInsertLoc().getParent());
+ auto intermediateType = builder->getBackwardDiffIntermediateContextType(funcRef);
return differentiateFunctionTypeImpl(builder, funcType, intermediateType);
}
@@ -96,6 +110,7 @@ namespace Slang
InstPair BackwardDiffPropagateTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
{
IRGlobalValueWithCode* diffPrimalFunc = nullptr;
+ addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
transcribeFuncImpl(builder, primalFunc, diffFunc, diffPrimalFunc);
return InstPair(primalFunc, diffFunc);
}
@@ -211,8 +226,7 @@ namespace Slang
if (!isMarkedForBackwardDifferentiation(origFunc))
return InstPair(nullptr, nullptr);
- IRBuilder builder(inBuilder->getSharedBuilder());
- builder.setInsertBefore(origFunc);
+ IRBuilder builder = *inBuilder;
IRFunc* primalFunc = origFunc;
@@ -221,6 +235,8 @@ namespace Slang
auto diffFunc = builder.createFunc();
SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
+ builder.setInsertBefore(diffFunc);
+
IRType* diffFuncType = this->differentiateFunctionType(
&builder,
origFunc,
@@ -235,18 +251,6 @@ namespace Slang
builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
}
- if (auto outerGen = findOuterGeneric(diffFunc))
- {
- builder.setInsertBefore(origFunc);
- auto specialized =
- specializeWithGeneric(builder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc)));
- addExistingDiffFuncDecor(&builder, origFunc, specialized);
- }
- else
- {
- addExistingDiffFuncDecor(&builder, origFunc, diffFunc);
- }
-
// Mark the generated derivative function itself as differentiable.
builder.addBackwardDifferentiableDecoration(diffFunc);
@@ -259,6 +263,22 @@ namespace Slang
return InstPair(primalFunc, diffFunc);
}
+ void BackwardDiffTranscriberBase::addTranscribedFuncDecoration(IRBuilder& builder, IRFunc* origFunc, IRFunc* transcribedFunc)
+ {
+ IRBuilder subBuilder = builder;
+ if (auto outerGen = findOuterGeneric(transcribedFunc))
+ {
+ subBuilder.setInsertBefore(origFunc);
+ auto specialized =
+ specializeWithGeneric(subBuilder, outerGen, as<IRGeneric>(findOuterGeneric(origFunc)));
+ addExistingDiffFuncDecor(&subBuilder, origFunc, specialized);
+ }
+ else
+ {
+ addExistingDiffFuncDecor(&subBuilder, origFunc, transcribedFunc);
+ }
+ }
+
InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
auto result = transcribeFuncHeaderImpl(inBuilder, origFunc);
@@ -288,7 +308,7 @@ namespace Slang
List<IRType*> primalTypes, propagateTypes;
for (UInt i = 0; i < funcType->getParamCount(); i++)
{
- auto paramType = funcType->getParamType(i);
+ auto paramType = (IRType*)findOrTranscribePrimalInst(&builder, funcType->getParamType(i));
auto param = builder.emitParam(paramType);
if (i != funcType->getParamCount() - 1)
{
@@ -368,10 +388,8 @@ namespace Slang
{
IRParam* nextParam = param->getNextParam();
- // Copy inst into the new parameter block.
- auto clonedParam = cloneInst(&cloneEnv, &builder, param);
- param->replaceUsesWith(clonedParam);
- param->removeAndDeallocate();
+ // Move inst into the new parameter block.
+ param->insertAtEnd(paramBlock);
param = nextParam;
}
@@ -383,6 +401,62 @@ namespace Slang
builder.emitBranch(firstBlock);
}
+ // Create a copy of originalFunc's forward derivative in the same generic context (if any) of
+ // `diffPropagateFunc`.
+ IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc(
+ IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc)
+ {
+ auto primalOuterParent = findOuterGeneric(originalFunc);
+ if (!primalOuterParent)
+ primalOuterParent = originalFunc;
+
+ // Make a clone of original func so we won't modify the original.
+ IRCloneEnv originalCloneEnv;
+ primalOuterParent = cloneInst(&originalCloneEnv, builder, primalOuterParent);
+ auto primalFunc = as<IRFunc>(getGenericReturnVal(primalOuterParent));
+
+ // Strip any existing derivative decorations off the clone.
+ stripDerivativeDecorations(primalFunc);
+ eliminateDeadCode(primalOuterParent);
+
+ // Forward transcribe the clone of the original func.
+ ForwardDiffTranscriber fwdTranscriber(autoDiffSharedContext, builder->getSharedBuilder(), sink);
+ fwdTranscriber.pairBuilder = pairBuilder;
+ IRFunc* fwdDiffFunc = as<IRFunc>(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent)));
+ SLANG_ASSERT(fwdDiffFunc);
+ fwdTranscriber.transcribeFunc(builder, primalFunc, fwdDiffFunc);
+
+ // Remove the clone of original func.
+ primalOuterParent->removeAndDeallocate();
+
+ // Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`.
+ if (auto fwdParentGeneric = as<IRGeneric>(findOuterGeneric(fwdDiffFunc)))
+ {
+ // Clone forward derivative func from its own generic into current generic parent.
+ GenericChildrenMigrationContext migrationContext;
+ auto diffOuterGeneric = as<IRGeneric>(findOuterGeneric(diffPropagateFunc));
+ SLANG_RELEASE_ASSERT(diffOuterGeneric);
+
+ migrationContext.init(fwdParentGeneric, diffOuterGeneric);
+ auto inst = fwdParentGeneric->getFirstBlock()->getFirstOrdinaryInst();
+ builder->setInsertBefore(diffPropagateFunc);
+ while (inst)
+ {
+ auto next = inst->getNextInst();
+ auto cloned = migrationContext.cloneInst(builder, inst);
+ if (inst == fwdDiffFunc)
+ {
+ fwdDiffFunc = as<IRFunc>(cloned);
+ break;
+ }
+ inst = next;
+ }
+ fwdParentGeneric->removeAndDeallocate();
+ }
+
+ return fwdDiffFunc;
+ }
+
// Transcribe a function definition.
void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc)
{
@@ -393,12 +467,16 @@ namespace Slang
// Generate a temporary forward derivative function as an intermediate step.
IRBuilder tempBuilder = *builder;
- tempBuilder.setInsertBefore(diffPropagateFunc);
- ForwardDiffTranscriber* fwdTranscriber = static_cast<ForwardDiffTranscriber*>(autoDiffSharedContext->transcriberSet.forwardTranscriber);
- IRFunc* fwdDiffFunc = as<IRFunc>(fwdTranscriber->transcribeFuncHeaderImpl(&tempBuilder, primalFunc));
- SLANG_ASSERT(fwdDiffFunc);
+ if (auto outerGeneric = findOuterGeneric(diffPropagateFunc))
+ {
+ tempBuilder.setInsertBefore(outerGeneric);
+ }
+ else
+ {
+ tempBuilder.setInsertBefore(diffPropagateFunc);
+ }
- fwdTranscriber->transcribeFunc(&tempBuilder, primalFunc, fwdDiffFunc);
+ auto fwdDiffFunc = generateNewForwardDerivativeForFunc(&tempBuilder, primalFunc, diffPropagateFunc);
// Split first block into a paramter block.
this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));
@@ -466,11 +544,12 @@ namespace Slang
// we have just created.
auto primalOuterGeneric = findOuterGeneric(primalFunc);
IRInst* specializedFunc = nullptr;
- auto intermediateTypeGeneric = hoistValueFromGeneric(*builder, intermediateType, specializedFunc);
+ auto intermediateTypeGeneric = hoistValueFromGeneric(*builder, intermediateType, specializedFunc, true);
+ builder->setInsertBefore(primalFunc);
auto specializedIntermeidateType = maybeSpecializeWithGeneric(*builder, intermediateTypeGeneric, primalOuterGeneric);
builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, specializedIntermeidateType);
- auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc);
+ auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc, true);
builder->setInsertBefore(primalFunc);
if (auto existingDecor = primalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>())
@@ -568,7 +647,7 @@ namespace Slang
return diffParam;
}
- return cloneInst(&cloneEnv, builder, origParam);
+ return maybeCloneForPrimalInst(builder, origParam);
}
InstPair BackwardDiffTranscriberBase::copyBinaryArith(IRBuilder* builder, IRInst* origArith)
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index decbdf150..02a100c80 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -85,10 +85,14 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);
+ IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc);
+
void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc);
InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc);
+ void addTranscribedFuncDecoration(IRBuilder& builder, IRFunc* origFunc, IRFunc* transcribedFunc);
+
virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override;
@@ -173,8 +177,10 @@ struct BackwardDiffTranscriber : BackwardDiffTranscriberBase
virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override
{
- SLANG_UNUSED(builder);
// Don't need to do anything here, the body is generated in transcribeFuncHeader.
+
+ SLANG_UNUSED(builder);
+ addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
return InstPair(primalFunc, diffFunc);
}
virtual IRInst* findExistingDiffFunc(IRInst* originalFunc) override
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index c0404e036..deb1b2da9 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -75,36 +75,38 @@ bool AutoDiffTranscriberBase::hasDifferentialInst(IRInst* origInst)
return instMapD.ContainsKey(origInst);
}
-bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* origInst)
+bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* currentParent, IRInst* origInst)
{
if (as<IRGlobalValueWithCode>(origInst))
return true;
if (origInst->parent && origInst->parent->getOp() == kIROp_Module)
return true;
+ if (isChildInstOf(currentParent, origInst->getParent()))
+ return true;
return false;
}
-IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* origInst)
+IRInst* AutoDiffTranscriberBase::lookupPrimalInstImpl(IRInst* currentParent, IRInst* origInst)
{
if (!origInst)
return nullptr;
- if (shouldUseOriginalAsPrimal(origInst))
+ if (shouldUseOriginalAsPrimal(currentParent, origInst))
return origInst;
return cloneEnv.mapOldValToNew[origInst];
}
-IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
+IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* currentParent, IRInst* origInst, IRInst* defaultInst)
{
if (!origInst)
return nullptr;
- return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
+ return (hasPrimalInst(currentParent, origInst)) ? lookupPrimalInstImpl(currentParent, origInst) : defaultInst;
}
-bool AutoDiffTranscriberBase::hasPrimalInst(IRInst* origInst)
+bool AutoDiffTranscriberBase::hasPrimalInst(IRInst* currentParent, IRInst* origInst)
{
if (!origInst)
return false;
- if (shouldUseOriginalAsPrimal(origInst))
+ if (shouldUseOriginalAsPrimal(currentParent, origInst))
return true;
return cloneEnv.mapOldValToNew.ContainsKey(origInst);
}
@@ -124,26 +126,48 @@ IRInst* AutoDiffTranscriberBase::findOrTranscribePrimalInst(IRBuilder* builder,
{
if (!origInst)
return origInst;
+
+ auto currentParent = builder->getInsertLoc().getParent();
- if (shouldUseOriginalAsPrimal(origInst))
+ if (shouldUseOriginalAsPrimal(currentParent, origInst))
return origInst;
- if (!hasPrimalInst(origInst))
+ if (!hasPrimalInst(currentParent, origInst))
{
transcribe(builder, origInst);
- SLANG_ASSERT(hasPrimalInst(origInst));
+ SLANG_ASSERT(hasPrimalInst(currentParent, origInst));
}
- return lookupPrimalInst(origInst);
+ return lookupPrimalInstImpl(currentParent, origInst);
}
IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst)
{
- IRInst* primal = lookupPrimalInst(inst, inst);
-
- if (primal == inst &&
- !isChildInstOf(builder->getInsertLoc().getParent(), inst->getParent()))
- primal = cloneInst(&cloneEnv, builder, inst);
+ IRInst* primal = lookupPrimalInst(builder, inst, nullptr);
+ if (!primal)
+ {
+ IRInst* type = inst->getFullType();
+ if (type)
+ {
+ type = maybeCloneForPrimalInst(builder, type);
+ }
+ List<IRInst*> operands;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto operand = maybeCloneForPrimalInst(builder, inst->getOperand(i));
+ operands.add(operand);
+ }
+ auto cloneResult = builder->emitIntrinsicInst(
+ (IRType*)type, inst->getOp(), operands.getCount(), operands.getBuffer());
+ IRBuilder subBuilder = *builder;
+ subBuilder.setInsertInto(cloneResult);
+ for (auto child : inst->getDecorationsAndChildren())
+ {
+ maybeCloneForPrimalInst(&subBuilder, child);
+ }
+ cloneEnv.mapOldValToNew[inst] = cloneResult;
+ return cloneResult;
+ }
return primal;
}
@@ -223,7 +247,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
// If there is an explicit primal version of this type in the local scope, load that
// otherwise use the original type.
//
- IRInst* primalType = lookupPrimalInst(origType, origType);
+ IRInst* primalType = lookupPrimalInst(builder, origType, origType);
// Special case certain compound types (PtrType, FuncType, etc..)
// otherwise try to lookup a differential definition for the given type.
@@ -390,7 +414,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
if (lookupKeyPath.getCount())
{
// `interfaceType` does conform to `IDifferentiable`.
- outWitnessTable = builder->emitExtractExistentialWitnessTable(lookupPrimalInstIfExists(origType->getOperand(0)));
+ outWitnessTable = builder->emitExtractExistentialWitnessTable(lookupPrimalInstIfExists(builder, origType->getOperand(0)));
for (auto node : lookupKeyPath)
{
outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey());
@@ -731,7 +755,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
//
if (auto diffInst = lookupDiffInst(origInst, nullptr))
{
- SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check.
+ SLANG_ASSERT(lookupPrimalInst(builder, origInst)); // Consistency check.
return diffInst;
}
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index a6b832856..2d980145e 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -41,7 +41,7 @@ struct AutoDiffTranscriberBase
, sharedBuilder(inSharedBuilder)
, sink(inSink)
{
-
+ cloneEnv.squashChildrenMapping = true;
}
DiagnosticSink* getSink();
@@ -61,15 +61,29 @@ struct AutoDiffTranscriberBase
bool hasDifferentialInst(IRInst* origInst);
- bool shouldUseOriginalAsPrimal(IRInst* origInst);
+ bool shouldUseOriginalAsPrimal(IRInst* currentParent, IRInst* origInst);
+
+ IRInst* lookupPrimalInstImpl(IRInst* currentParent, IRInst* origInst);
+
+ IRInst* lookupPrimalInst(IRInst* currentParent, IRInst* origInst, IRInst* defaultInst);
+
+ IRInst* lookupPrimalInstIfExists(IRBuilder* builder, IRInst* origInst)
+ {
+ return lookupPrimalInst(builder->getInsertLoc().getParent(), origInst, origInst);
+ }
- IRInst* lookupPrimalInst(IRInst* origInst);
+ IRInst* lookupPrimalInst(IRBuilder* builder, IRInst* origInst)
+ {
+ return lookupPrimalInstImpl(builder->getInsertLoc().getParent(), origInst);
+ }
- IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst);
+ IRInst* lookupPrimalInst(IRBuilder* builder, IRInst* origInst, IRInst* defaultInst)
+ {
+ return lookupPrimalInst(builder->getInsertLoc().getParent(), origInst, defaultInst);
+ }
- IRInst* lookupPrimalInstIfExists(IRInst* origInst) { return lookupPrimalInst(origInst, origInst); }
- bool hasPrimalInst(IRInst* origInst);
+ bool hasPrimalInst(IRInst* currentParent, IRInst* origInst);
IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 43b48aa13..b8a4c4f08 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -4,6 +4,7 @@
namespace Slang
{
+
struct ExtractPrimalFuncContext
{
SharedIRBuilder* sharedBuilder;
@@ -74,14 +75,18 @@ struct ExtractPrimalFuncContext
IRFuncType* originalFuncType = nullptr;
outIntermediateType = createIntermediateType(destFunc);
+ GenericChildrenMigrationContext migrationContext;
+ migrationContext.init(as<IRGeneric>(findOuterGeneric(originalFunc)), as<IRGeneric>(findOuterGeneric(destFunc)));
+
originalFuncType = as<IRFuncType>(originalFunc->getDataType());
SLANG_RELEASE_ASSERT(originalFuncType);
List<IRType*> paramTypes;
for (UInt i = 0; i < originalFuncType->getParamCount() - 1; i++)
- paramTypes.add(originalFuncType->getParamType(i));
+ paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i)));
paramTypes.add(builder.getInOutType((IRType*)outIntermediateType));
- auto newFuncType = builder.getFuncType(paramTypes, builder.getVoidType());
+ auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType());
+ auto newFuncType = builder.getFuncType(paramTypes, resultType);
return newFuncType;
}
@@ -239,7 +244,10 @@ struct ExtractPrimalFuncContext
auto ptrStructType = as<IRPtrTypeBase>(intermediateOutput->getDataType());
SLANG_RELEASE_ASSERT(ptrStructType);
auto structType = as<IRStructType>(ptrStructType->getValueType());
- genTypeBuilder.setInsertBefore(structType);
+ if (auto outerGen = findOuterGeneric(structType))
+ genTypeBuilder.setInsertBefore(outerGen);
+ else
+ genTypeBuilder.setInsertBefore(structType);
auto fieldType = type;
SLANG_RELEASE_ASSERT(structType);
auto structKey = genTypeBuilder.createStructKey();
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index ba1e425db..612212dd9 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -9,10 +9,44 @@
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-propagate.h"
#include "slang-ir-autodiff-transcriber-base.h"
+#include "slang-ir-validate.h"
namespace Slang
{
+struct GenericChildrenMigrationContext
+{
+ IRCloneEnv cloneEnv;
+ IRGeneric* srcGeneric;
+ void init(IRGeneric* genericSrc, IRGeneric* genericDst)
+ {
+ srcGeneric = genericSrc;
+ if (!genericSrc)
+ return;
+ auto srcParam = genericSrc->getFirstBlock()->getFirstParam();
+ auto dstParam = genericDst->getFirstBlock()->getFirstParam();
+ while (srcParam && dstParam)
+ {
+ cloneEnv.mapOldValToNew[srcParam] = dstParam;
+ srcParam = srcParam->getNextParam();
+ dstParam = dstParam->getNextParam();
+ }
+ cloneEnv.mapOldValToNew[genericSrc] = genericDst;
+ cloneEnv.mapOldValToNew[genericSrc->getFirstBlock()] = genericDst->getFirstBlock();
+ }
+
+ IRInst* cloneInst(IRBuilder* builder, IRInst* src)
+ {
+ if (!srcGeneric)
+ return src;
+ if (findOuterGeneric(src) == srcGeneric)
+ {
+ return Slang::cloneInst(&cloneEnv, builder, src);
+ }
+ return src;
+ }
+};
+
struct DiffUnzipPass
{
AutoDiffSharedContext* autodiffContext;
@@ -62,6 +96,7 @@ struct DiffUnzipPass
// TODO: Looks like we get a copy of the decorations?
IRCloneEnv subEnv;
subEnv.parent = &cloneEnv;
+ builder->setInsertBefore(func);
IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&subEnv, builder, func));
builder->setInsertInto(unzippedFunc);
@@ -231,7 +266,10 @@ struct DiffUnzipPass
newFwdCallee,
diffArgs);
diffBuilder->markInstAsDifferential(diffPairVal, primalType);
+
+ disableIRValidationAtInsert();
diffBuilder->addBackwardDerivativePrimalContextDecoration(diffPairVal, intermediateVar);
+ enableIRValidationAtInsert();
auto diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, diffPairVal);
diffBuilder->markInstAsDifferential(diffVal, primalType);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 94417ea00..74afa4002 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -393,6 +393,28 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
}
}
+void stripDerivativeDecorations(IRInst* inst)
+{
+ for (auto decor = inst->getFirstDecoration(); decor; )
+ {
+ auto next = decor->getNextDecoration();
+ switch (decor->getOp())
+ {
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_DerivativeMemberDecoration:
+ case kIROp_BackwardDerivativeDecoration:
+ case kIROp_BackwardDerivativeIntermediateTypeDecoration:
+ case kIROp_BackwardDerivativePropagateDecoration:
+ case kIROp_BackwardDerivativePrimalDecoration:
+ decor->removeAndDeallocate();
+ break;
+ default:
+ break;
+ }
+ decor = next;
+ }
+}
+
void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
for (auto inst : parent->getChildren())
@@ -702,7 +724,7 @@ struct AutoDiffPass : public InstPassBase
forwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
break;
case FuncBodyTranscriptionTaskType::BackwardPrimal:
- // Don't need to do anything, they will be filled by `backwardPropagateTranscriber`.
+ backwardPrimalTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
break;
case FuncBodyTranscriptionTaskType::BackwardPropagate:
backwardPropagateTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index b4b97751f..f468b1fca 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -257,4 +257,6 @@ bool processAutodiffCalls(
bool finalizeAutoDiffPass(IRModule* module);
+void stripDerivativeDecorations(IRInst* inst);
+
};
diff --git a/source/slang/slang-ir-entry-point-uniforms.cpp b/source/slang/slang-ir-entry-point-uniforms.cpp
index d98f39515..1f0bc13b1 100644
--- a/source/slang/slang-ir-entry-point-uniforms.cpp
+++ b/source/slang/slang-ir-entry-point-uniforms.cpp
@@ -404,6 +404,8 @@ struct CollectEntryPointUniformParams : PerEntryPointPass
collectedParam = builder.createParam(paramStructType);
}
+ collectedParam->insertBefore(m_entryPoint.func);
+
// No matter what, the global shader parameter should have the layout
// information from the entry point attached to it, so that the
// contained parameters will end up in the right place(s).
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp
index 806ea8826..6f412d579 100644
--- a/source/slang/slang-ir-lower-generic-function.cpp
+++ b/source/slang/slang-ir-lower-generic-function.cpp
@@ -48,9 +48,12 @@ namespace Slang
IRCloneEnv cloneEnv;
IRBuilder builder(sharedContext->sharedBuilderStorage);
builder.setInsertBefore(genericParent);
+ // Do not clone func type (which would break IR def-use rules if we do it here)
+ // This is OK since we will lower the type immediately after the clone.
+ cloneEnv.mapOldValToNew[func->getFullType()] = builder.getTypeKind();
auto loweredFunc = cast<IRFunc>(cloneInstAndOperands(&cloneEnv, &builder, func));
auto loweredGenericType =
- lowerGenericFuncType(&builder, cast<IRGeneric>(genericParent->getFullType()));
+ lowerGenericFuncType(&builder, genericParent, cast<IRFuncType>(func->getFullType()));
SLANG_ASSERT(loweredGenericType);
loweredFunc->setFullType(loweredGenericType);
List<IRInst*> clonedParams;
@@ -90,7 +93,7 @@ namespace Slang
return loweredFunc;
}
- IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal)
+ IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal, IRFuncType* funcType)
{
ShortList<IRInst*> genericParamTypes;
Dictionary<IRInst*, IRInst*> typeMapping;
@@ -107,7 +110,7 @@ namespace Slang
auto innerType = (IRFuncType*)lowerFuncType(
builder,
- cast<IRFuncType>(findGenericReturnVal(genericVal)),
+ funcType,
typeMapping,
genericParamTypes.getArrayView().arrayView);
@@ -182,7 +185,10 @@ namespace Slang
}
else if (auto genericFuncType = as<IRGeneric>(requirementVal))
{
- loweredVal = lowerGenericFuncType(&builder, genericFuncType);
+ loweredVal = lowerGenericFuncType(
+ &builder,
+ genericFuncType,
+ cast<IRFuncType>(findGenericReturnVal(genericFuncType)));
}
else if (requirementVal->getOp() == kIROp_AssociatedType)
{
diff --git a/source/slang/slang-ir-remove-unused-generic-param.cpp b/source/slang/slang-ir-remove-unused-generic-param.cpp
new file mode 100644
index 000000000..9337a00bb
--- /dev/null
+++ b/source/slang/slang-ir-remove-unused-generic-param.cpp
@@ -0,0 +1,134 @@
+#include "slang-ir-remove-unused-generic-param.h"
+#include "slang-ir-inst-pass-base.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+struct RemoveUnusedGenericParamContext : InstPassBase
+{
+ RemoveUnusedGenericParamContext(IRModule* inModule)
+ : InstPassBase(inModule)
+ {}
+
+ bool processModule()
+ {
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->init(module);
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+ IRBuilder builder(sharedBuilder);
+ bool changed = false;
+ for (auto inst : module->getModuleInst()->getChildren())
+ {
+ if (auto genInst = as<IRGeneric>(inst))
+ {
+ auto returnVal = findGenericReturnVal(genInst);
+ switch (returnVal->getOp())
+ {
+ case kIROp_StructType:
+ case kIROp_ClassType:
+ break;
+ case kIROp_Func:
+ case kIROp_FuncType:
+ default:
+ // Don't simplify functions since this can break signature compatiblity with the
+ // interface. For example, if we have
+ // interface IFoo { void genFunc<T>(int x); }
+ // We can't simplify this by removing `T` even when the function type here does not depend on T.
+ continue;
+ }
+ if (returnVal->findDecoration<IRTargetIntrinsicDecoration>())
+ continue;
+
+ List<UInt> paramToPreserve;
+ UInt id = 0;
+ List<IRInst*> paramsToRemove;
+ for (auto param : genInst->getParams())
+ {
+ if (param->hasUses())
+ {
+ paramToPreserve.add(id);
+ }
+ else
+ {
+ paramsToRemove.add(param);
+ }
+ id++;
+ }
+ if (paramsToRemove.getCount() == 0)
+ continue;
+ changed = true;
+ if (paramToPreserve.getCount() == 0)
+ {
+ // Special case: the generic return value is not dependent on the generic param,
+ // we can hoist to global scope safely.
+ for (auto child = genInst->getFirstBlock()->getFirstOrdinaryInst(); child; )
+ {
+ auto next = child->getNextInst();
+ if (child->getOp() == kIROp_Return)
+ {
+ break;
+ }
+ child->insertBefore(genInst);
+ child = next;
+ }
+ SLANG_ASSERT(returnVal);
+ List<IRUse*> uses;
+ for (auto use = genInst->firstUse; use; use = use->nextUse)
+ uses.add(use);
+ for (auto use : uses)
+ {
+ if (use->getUser()->getOp() == kIROp_Specialize &&
+ use == use->getUser()->getOperands())
+ {
+ use->getUser()->replaceUsesWith(returnVal);
+ }
+ }
+ genInst->replaceUsesWith(returnVal);
+ genInst->removeAndDeallocate();
+ }
+ else
+ {
+ // General case: remove unnecessary specialization arguments.
+ // Disabled this optimization for now since we still need to take care
+ // of the type of the generic, or change other passes to not
+ // use type info on a generic at all.
+ List<IRUse*> uses;
+ for (auto use = genInst->firstUse; use; use = use->nextUse)
+ uses.add(use);
+ for (auto use : uses)
+ {
+ if (use->getUser()->getOp() == kIROp_Specialize &&
+ use == use->getUser()->getOperands())
+ {
+ auto specialize = as<IRSpecialize>(use->getUser());
+ builder.setInsertBefore(specialize);
+ List<IRInst*> newArgs;
+ for (auto i : paramToPreserve)
+ newArgs.add(specialize->getArg(i));
+ auto newSpecialize = builder.emitSpecializeInst(
+ specialize->getFullType(),
+ specialize->getBase(),
+ newArgs.getCount(),
+ newArgs.getBuffer());
+ specialize->transferDecorationsTo(newSpecialize);
+ specialize->replaceUsesWith(newSpecialize);
+ specialize->removeAndDeallocate();
+ }
+ }
+ for (auto param : paramsToRemove)
+ param->removeAndDeallocate();
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+bool removeUnusedGenericParam(IRModule* module)
+{
+ RemoveUnusedGenericParamContext context = RemoveUnusedGenericParamContext(module);
+ return context.processModule();
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-remove-unused-generic-param.h b/source/slang/slang-ir-remove-unused-generic-param.h
new file mode 100644
index 000000000..8f7a61945
--- /dev/null
+++ b/source/slang/slang-ir-remove-unused-generic-param.h
@@ -0,0 +1,9 @@
+// slang-ir-remove-unused-generic-param.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+
+ bool removeUnusedGenericParam(IRModule* module);
+}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index 4b604e03a..938094551 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -7,6 +7,7 @@
#include "slang-ir-simplify-cfg.h"
#include "slang-ir-peephole.h"
#include "slang-ir-hoist-constants.h"
+#include "slang-ir-remove-unused-generic-param.h"
namespace Slang
{
@@ -31,7 +32,7 @@ namespace Slang
eliminateDeadCode(module);
changed |= constructSSA(module);
-
+ changed |= removeUnusedGenericParam(module);
iterationCounter++;
}
}
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index 2dee189dc..2415f1388 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -4,6 +4,7 @@
#include "slang-ir.h"
#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
+#include "slang-ir-validate.h"
namespace Slang {
@@ -1195,7 +1196,6 @@ bool constructSSA(ConstructSSAContext* context)
{
var->removeAndDeallocate();
}
-
return true;
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 8e3e879ad..73d8865ed 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -156,11 +156,11 @@ IRInst* maybeSpecializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecaili
return genericToSpecailize;
}
-IRInst* hoistValueFromGeneric(IRBuilder& builder, IRInst* value, IRInst*& outSpecializedVal, bool replaceExistingValue)
+IRInst* hoistValueFromGeneric(IRBuilder& inBuilder, IRInst* value, IRInst*& outSpecializedVal, bool replaceExistingValue)
{
auto outerGeneric = as<IRGeneric>(findOuterGeneric(value));
if (!outerGeneric) return value;
-
+ IRBuilder builder = inBuilder;
builder.setInsertBefore(outerGeneric);
auto newGeneric = builder.emitGeneric();
builder.setInsertInto(newGeneric);
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 49f46d0e3..4885dcd96 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -102,6 +102,7 @@ inline IRInst* unwrapAttributedType(IRInst* type)
type = attrType->getBaseType();
return type;
}
+
}
#endif
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index 46817e212..a49eda322 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -29,7 +29,14 @@ namespace Slang
{
if (!condition)
{
- context->getSink()->diagnose(inst, Diagnostics::irValidationFailed, message);
+ if (context)
+ {
+ context->getSink()->diagnose(inst, Diagnostics::irValidationFailed, message);
+ }
+ else
+ {
+ SLANG_ASSERT_FAILURE("IR validation failed");
+ }
}
}
@@ -143,7 +150,10 @@ namespace Slang
// If `operandValue` precedes `inst`, then we should
// have already seen it, because we scan parent instructions
// in order.
- validate(context, context->seenInsts.Contains(operandValue), inst, "def must come before use in same block");
+ if (context)
+ {
+ validate(context, context->seenInsts.Contains(operandValue), inst, "def must come before use in same block");
+ }
return;
}
@@ -196,6 +206,34 @@ namespace Slang
}
}
+ static thread_local bool _enableIRValidationAtInsert = false;
+ void disableIRValidationAtInsert()
+ {
+ _enableIRValidationAtInsert = false;
+ }
+ void enableIRValidationAtInsert()
+ {
+ _enableIRValidationAtInsert = true;
+ }
+ void validateIRInstOperands(IRInst* inst)
+ {
+ if (!_enableIRValidationAtInsert)
+ return;
+ switch (inst->getOp())
+ {
+ case kIROp_loop:
+ case kIROp_ifElse:
+ case kIROp_unconditionalBranch:
+ case kIROp_conditionalBranch:
+ case kIROp_Switch:
+ return;
+ default:
+ break;
+ }
+
+ validateIRInstOperands(nullptr, inst);
+ }
+
void validateCodeBody(IRValidateContext* context, IRGlobalValueWithCode* code)
{
HashSet<IRBlock*> blocks;
@@ -296,4 +334,5 @@ namespace Slang
auto sink = codeGenContext->getSink();
validateIRModule(module, sink);
}
+
}
diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h
index 3e8e8dc92..a1a9eb4f4 100644
--- a/source/slang/slang-ir-validate.h
+++ b/source/slang/slang-ir-validate.h
@@ -37,4 +37,8 @@ namespace Slang
void validateIRModuleIfEnabled(
CodeGenContext* codeGenContext,
IRModule* module);
+
+ void disableIRValidationAtInsert();
+ void enableIRValidationAtInsert();
+
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index f37a7a1a0..b36a2ebec 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2810,6 +2810,8 @@ namespace Slang
IRBackwardDiffIntermediateContextType* IRBuilder::getBackwardDiffIntermediateContextType(
IRInst* func)
{
+ if (!func)
+ func = getVoidValue();
return (IRBackwardDiffIntermediateContextType*)getType(
kIROp_BackwardDiffIntermediateContextType,
1,
@@ -6260,6 +6262,8 @@ namespace Slang
return type;
}
+ void validateIRInstOperands(IRInst*);
+
void IRInst::replaceUsesWith(IRInst* other)
{
// Safety check: don't try to replace something with itself.
@@ -6377,6 +6381,10 @@ namespace Slang
this->prev = inPrev;
this->next = inNext;
this->parent = inParent;
+
+#if _DEBUG
+ validateIRInstOperands(this);
+#endif
}
void IRInst::insertAfter(IRInst* other)