summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-02-22 19:33:42 -0500
committerGitHub <noreply@github.com>2023-02-22 16:33:42 -0800
commit6eb0b4dea4da1fc21767c86cc0837d0c8b68063b (patch)
tree8ad8fe77e57db437be5f7403fd324e218db9c578 /source
parent0ef7aa85d3a6b2ff1d6b25576b4d9eff188c1a6a (diff)
Reverse-mode AD fixes for loops with non-trivial break region (#2671)
* Fix crash when applying autodiff to functions with no arguments * Fixes for loops where the break region is non-trivial * Minor fix * Implement array legalization correctly. * Fix array legalization. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp8
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h65
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h59
-rw-r--r--source/slang/slang-ir-legalize-types.cpp174
5 files changed, 267 insertions, 41 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 55c0ee46d..640f516ed 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -568,11 +568,9 @@ namespace Slang
{
DifferentiableTypeConformanceContext* diffTypeContext;
- virtual bool shouldConvertAddrInst(IRInst* addrInst) override
+ virtual bool shouldConvertAddrInst(IRInst*) override
{
- if (isDifferentiableType(*diffTypeContext, addrInst->getDataType()))
- return true;
- return false;
+ return true;
}
};
@@ -598,7 +596,9 @@ namespace Slang
if (SLANG_SUCCEEDED(result))
{
+ disableIRValidationAtInsert();
simplifyFunc(func);
+ enableIRValidationAtInsert();
}
return result;
}
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 70018b476..95ad58586 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -328,18 +328,6 @@ struct DiffTransposePass
getPhiGrads(trueBlock).getCount(),
getPhiGrads(trueBlock).getBuffer());
- // Old false-side starting block becomes end block
- // for the new pre-cond region (which could be empty)
- //
- IRBlock* revPreCondEndBlock = revBlockMap[falseBlock];
- if (!falseRegionInfo.isTrivial)
- {
- builder.setInsertInto(revPreCondEndBlock);
- builder.emitBranch(
- revCondBlock,
- getPhiGrads(falseBlock).getCount(),
- getPhiGrads(falseBlock).getBuffer());
- }
IRBlock* revBreakRegionExitBlock = revBlockMap[firstLoopBlock];
if (!preCondRegionInfo.isTrivial)
@@ -366,17 +354,42 @@ struct DiffTransposePass
ifElse->getCondition(),
revTrueBlock,
revFalseBlock,
- revLoopEndBlock);
+ revTrueBlock);
- // Emit loop into rev-version of the break block.
- auto revLoopBlock = revBlockMap[breakBlock];
- builder.setInsertInto(revLoopBlock);
- builder.emitLoop(
- revPreCondBlock,
- revBreakBlock,
- revLoopEndBlock,
- getPhiGrads(breakBlock).getCount(),
- getPhiGrads(breakBlock).getBuffer());
+ // Old false-side starting block becomes end block
+ // for the new pre-cond region (which could be empty)
+ //
+
+ if (!falseRegionInfo.isTrivial)
+ {
+ IRBlock* revPreCondEndBlock = revBlockMap[falseBlock];
+ builder.setInsertInto(revPreCondEndBlock);
+ builder.emitLoop(
+ revCondBlock,
+ revBreakBlock,
+ revLoopEndBlock,
+ getPhiGrads(falseBlock).getCount(),
+ getPhiGrads(falseBlock).getBuffer());
+
+ auto revLoopStartBlock = revBlockMap[breakBlock];
+ builder.setInsertInto(revLoopStartBlock);
+ builder.emitBranch(
+ revPreCondBlock,
+ getPhiGrads(breakBlock).getCount(),
+ getPhiGrads(breakBlock).getBuffer());
+ }
+ else
+ {
+ // Emit loop into rev-version of the break block.
+ auto revLoopBlock = revBlockMap[breakBlock];
+ builder.setInsertInto(revLoopBlock);
+ builder.emitLoop(
+ revPreCondBlock,
+ revBreakBlock,
+ revLoopEndBlock,
+ getPhiGrads(breakBlock).getCount(),
+ getPhiGrads(breakBlock).getBuffer());
+ }
currentBlock = breakBlock;
break;
@@ -1436,9 +1449,13 @@ struct DiffTransposePass
argRequiresLoad.add(false);
}
- args.add(builder->emitLoad(primalContextDecor->getBackwardDerivativePrimalContextVar()));
+ // Ensure availability of the primal context var
+ auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar());
+ SLANG_RELEASE_ASSERT(primalContextVar);
+
+ args.add(builder->emitLoad(primalContextVar));
argTypes.add(as<IRPtrTypeBase>(
- primalContextDecor->getBackwardDerivativePrimalContextVar()->getDataType())
+ primalContextVar->getDataType())
->getValueType());
argRequiresLoad.add(false);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 4e7539b48..50c5c4ea6 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -83,7 +83,7 @@ struct ExtractPrimalFuncContext
SLANG_RELEASE_ASSERT(originalFuncType);
List<IRType*> paramTypes;
- for (UInt i = 0; i < originalFuncType->getParamCount() - 1; i++)
+ for (Index i = 0; i < ((Count) originalFuncType->getParamCount()) - 1; i++)
paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i)));
paramTypes.add(builder.getInOutType((IRType*)outIntermediateType));
auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType());
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 2ccb8d8e2..e2c84ce8b 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -525,7 +525,32 @@ struct DiffUnzipPass
List<IRInst*> primalInsts;
for (auto child = primalBlock->getFirstChild(); child; child = child->getNextInst())
+ {
+ // TODO: This might be a decent place to enforce that each load has a single
+ // corresponding store (i.e. that everything is SSAd properly)?
+
+ // We're only interested in insts that generate values.
+ if (child->getDataType() == nullptr ||
+ as<IRVoidType>(child->getDataType()) ||
+ as<IRFuncType>(child->getDataType()) ||
+ as<IRTypeKind>(child->getDataType()))
+ continue;
+
+ // We also don't care about pointer types (only Loads)
+ if (auto ptrType = as<IRPtrTypeBase>(child->getDataType()))
+ {
+ // There's an exception to this, if the var is an intermediate context type
+ // variable since there won't be a load from this yet (the load will
+ // be inserted later during the transposition process)
+ //
+ if (as<IRBackwardDiffIntermediateContextType>(ptrType->getValueType()))
+ primalInsts.add(child);
+
+ continue;
+ }
+
primalInsts.add(child);
+ }
IRBuilder builder(autodiffContext->moduleInst->getModule());
@@ -545,7 +570,7 @@ struct DiffUnzipPass
bool shouldStore = false;
for (auto use = inst->firstUse; use; use = use->nextUse)
{
- IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent());
+ IRBlock* useBlock = getBlock(use->getUser());
if (isDifferentialInst(useBlock))
{
@@ -561,7 +586,14 @@ struct DiffUnzipPass
builder.setInsertBefore(firstPrimalBlock->getTerminator());
IRType* arrayType = inst->getDataType();
- SLANG_ASSERT(!as<IRPtrTypeBase>(arrayType)); // can't store pointers.
+ bool isPtrType = false;
+
+ if (auto ptrType = as<IRPtrTypeBase>(arrayType))
+ {
+ SLANG_RELEASE_ASSERT(as<IRBackwardDiffIntermediateContextType>(ptrType->getValueType()));
+ arrayType = ptrType->getValueType();
+ isPtrType = true;
+ }
for (auto region : regions)
{
@@ -582,11 +614,6 @@ struct DiffUnzipPass
auto storageVar = builder.emitVar(arrayType);
- // TODO(sai) STOPPED HERE: For some reason, we still have a direct param access
- // when trying to cover up the access to last value of loop counter.
- // Maybe we need a different way to access this? (use a var)
- // Special case?
-
// 3. Store current value into the array and replace uses with a load.
// TODO: If an index is missing, use the 'last' value of the primal index.
{
@@ -616,7 +643,8 @@ struct DiffUnzipPass
{
if (as<IRDecoration>(use->getUser()))
{
- if (!as<IRLoopExitPrimalValueDecoration>(use->getUser()))
+ if (!as<IRLoopExitPrimalValueDecoration>(use->getUser()) &&
+ !as<IRBackwardDerivativePrimalContextDecoration>(use->getUser()))
continue;
}
@@ -683,10 +711,17 @@ struct DiffUnzipPass
instsToTag.add(loadAddr);
}
- auto loadedValue = builder.emitLoad(loadAddr);
- instsToTag.add(loadedValue);
+ if (!isPtrType)
+ {
+ auto loadedValue = builder.emitLoad(loadAddr);
+ instsToTag.add(loadedValue);
- use->set(loadedValue);
+ use->set(loadedValue);
+ }
+ else
+ {
+ use->set(loadAddr);
+ }
}
}
@@ -744,6 +779,8 @@ struct DiffUnzipPass
}
auto intermediateVar = primalBuilder->emitVar((IRType*)intermediateType);
+ primalBuilder->markInstAsPrimal(intermediateVar);
+
primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar);
auto primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn);
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 7660c9526..d7ed1f63f 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -752,7 +752,7 @@ static LegalVal legalizeUnconditionalBranch(
SLANG_UNIMPLEMENTED_X("Unknown legalized val flavor.");
}
}
- context->builder->emitBranch(branchInst->getTargetBlock(), newArgs.getCount() - 1, newArgs.getBuffer() + 1);
+ context->builder->emitIntrinsicInst(nullptr, branchInst->getOp(), newArgs.getCount(), newArgs.getBuffer());
return LegalVal();
}
@@ -1665,6 +1665,169 @@ static LegalVal legalizeMakeStruct(
}
}
+static LegalVal legalizeMakeArray( // Rebuilds an array-construction inst according to the flavor of `legalType`.
+ IRTypeLegalizationContext* context,
+ LegalType legalType,
+ LegalVal const* legalArgs, // already-legalized operands; `argCount` entries are read
+ UInt argCount,
+ IROp constructOp) // opcode used for every array inst emitted below
+{
+ auto builder = context->builder;
+
+ switch (legalType.flavor)
+ {
+ case LegalType::Flavor::none:
+ return LegalVal();
+
+ case LegalType::Flavor::simple:
+ {
+ List<IRInst*> args;
+ // We need a valid default val for elements that are legalized to `none`.
+ // We grab the first simple (non-none) value from the legalized args and use it.
+ // If no arg is simple (although this shouldn't happen, since the entire array
+ // would have been legalized to none in that case), we use the defaultConstruct op.
+ // Use of defaultConstruct may lead to invalid HLSL/GLSL code, so we want to
+ // avoid that if possible.
+ IRInst* defaultVal = nullptr;
+ for (UInt aa = 0; aa < argCount; ++aa)
+ {
+ if (legalArgs[aa].flavor == LegalVal::Flavor::simple)
+ {
+ defaultVal = legalArgs[aa].getSimple();
+ break;
+ }
+ }
+ if (!defaultVal)
+ {
+ defaultVal = builder->emitDefaultConstruct(as<IRArrayTypeBase>(legalType.getSimple())->getElementType());
+ }
+ for (UInt aa = 0; aa < argCount; ++aa)
+ {
+ if (legalArgs[aa].flavor == LegalVal::Flavor::none)
+ args.add(defaultVal);
+ else
+ args.add(legalArgs[aa].getSimple());
+ }
+ return LegalVal::simple(
+ builder->emitIntrinsicInst(
+ legalType.getSimple(),
+ constructOp,
+ args.getCount(),
+ args.getBuffer()));
+ }
+
+ case LegalType::Flavor::pair:
+ {
+ // A pair has two sides, the ordinary and the special; the args are split
+ // per side and we recurse into each side that received at least one value.
+ auto pairType = legalType.getPair();
+ auto pairInfo = pairType->pairInfo;
+ LegalType ordinaryType = pairType->ordinaryType;
+ LegalType specialType = pairType->specialType;
+
+ List<LegalVal> ordinaryArgs;
+ List<LegalVal> specialArgs;
+ bool hasValidOrdinaryArgs = false;
+ bool hasValidSpecialArgs = false;
+ for (UInt argIndex = 0; argIndex < argCount; argIndex++)
+ {
+ LegalVal arg = legalArgs[argIndex];
+
+ // If the argument is itself a pair, it contributes a value to both sides.
+ if (arg.flavor == LegalVal::Flavor::pair)
+ {
+ auto argPair = arg.getPair();
+ ordinaryArgs.add(argPair->ordinaryVal);
+ specialArgs.add(argPair->specialVal);
+ hasValidOrdinaryArgs = true;
+ hasValidSpecialArgs = true;
+ }
+ else if (arg.flavor == LegalVal::Flavor::simple)
+ {
+ if (arg.getSimple()->getFullType() == ordinaryType.irType) // a simple arg goes to the side whose type it matches
+ {
+ ordinaryArgs.add(arg);
+ specialArgs.add(LegalVal());
+ hasValidOrdinaryArgs = true;
+ }
+ else
+ {
+ ordinaryArgs.add(LegalVal());
+ specialArgs.add(arg);
+ hasValidSpecialArgs = true;
+ }
+ }
+ else if (arg.flavor == LegalVal::Flavor::none)
+ {
+ ordinaryArgs.add(arg);
+ specialArgs.add(arg);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unhandled");
+ }
+ }
+
+ LegalVal ordinaryVal = LegalVal();
+ if (hasValidOrdinaryArgs)
+ ordinaryVal = legalizeMakeArray(
+ context,
+ ordinaryType,
+ ordinaryArgs.getBuffer(),
+ ordinaryArgs.getCount(),
+ constructOp);
+
+ LegalVal specialVal = LegalVal();
+ if (hasValidSpecialArgs)
+ specialVal = legalizeMakeArray(
+ context, specialType, specialArgs.getBuffer(), specialArgs.getCount(), constructOp);
+
+ return LegalVal::pair(ordinaryVal, specialVal, pairInfo);
+ }
+ break; // not reached: the block above always returns
+
+ case LegalType::Flavor::tuple:
+ {
+ // For array types that are legalized as tuples,
+ // we expect each element of the array to be legalized as the same tuple shape.
+ // We want to return a tuple whose i-th element is an array containing
+ // the i-th tuple-element of each legalized array-element.
+
+ auto tupleType = legalType.getTuple();
+
+ RefPtr<TuplePseudoVal> resTupleInfo = new TuplePseudoVal();
+ UInt elementCounter = 0;
+ for (auto typeElem : tupleType->elements)
+ {
+ auto elemKey = typeElem.key;
+ UInt elementIndex = elementCounter++;
+ List<LegalVal> subArray;
+ for (UInt i = 0; i < argCount; i++)
+ {
+ LegalVal argVal = legalArgs[i];
+ SLANG_RELEASE_ASSERT(argVal.flavor == LegalVal::Flavor::tuple);
+ auto argTuple = argVal.getTuple();
+ SLANG_RELEASE_ASSERT(
+ argTuple->elements.getCount() == tupleType->elements.getCount());
+ subArray.add(argTuple->elements[elementIndex].val);
+ }
+
+ auto legalSubArray = legalizeMakeArray(context, typeElem.type, subArray.getBuffer(), subArray.getCount(), constructOp);
+
+ TuplePseudoVal::Element resElem;
+ resElem.key = elemKey;
+ resElem.val = legalSubArray;
+ resTupleInfo->elements.add(resElem);
+ }
+ return LegalVal::tuple(resTupleInfo);
+ }
+
+ default:
+ SLANG_UNEXPECTED("unhandled");
+ UNREACHABLE_RETURN(LegalVal());
+ }
+}
+
static LegalVal legalizeDefaultConstruct(
IRTypeLegalizationContext* context,
LegalType legalType)
@@ -1762,11 +1925,20 @@ static LegalVal legalizeInst(
type,
args.getBuffer(),
inst->getOperandCount());
+ case kIROp_MakeArray:
+ case kIROp_MakeArrayFromElement:
+ return legalizeMakeArray(
+ context,
+ type,
+ args.getBuffer(),
+ inst->getOperandCount(),
+ inst->getOp());
case kIROp_DefaultConstruct:
return legalizeDefaultConstruct(
context,
type);
case kIROp_unconditionalBranch:
+ case kIROp_loop:
return legalizeUnconditionalBranch(context, args, (IRUnconditionalBranch*)inst);
case kIROp_undefined:
return LegalVal();