summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-07 12:26:29 -0800
committerGitHub <noreply@github.com>2022-11-07 12:26:29 -0800
commitca882a1ef46a5a8bbff50e3a1a6f973e16358634 (patch)
tree1a4d37ad67d3844b6c69ebec68d3858f0c318747 /source
parentea99c274dea12fffdc89a8d4eeefcbb670232ba8 (diff)
Small cleanups on forward differentiation. (#2498)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-conformance.cpp39
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp135
2 files changed, 68 insertions, 106 deletions
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index eb072e9dd..d2335efbf 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -378,45 +378,6 @@ namespace Slang
}
}
}
-
- // If a generic type parameter does not declare itself to conform to `IDifferentiable`,
- // we treat it as a subtype of `DifferentialBottom` to make it conform to `IDifferentiable`.
- // Note: we only consider this option for `originalSubType` so a type that implements `IDifferential` but
- // inherits from some other non differentiable types don't get to inherit `DifferentialBottom`.
- if (m_astBuilder->isDifferentiableInterfaceAvailable() &&
- subType == originalSubType &&
- superTypeDeclRef.getDecl() == m_astBuilder->getDifferentiableInterface())
- {
- if (as<GenericTypeParamDecl>(declRefType->declRef.getDecl()) ||
- as<AssocTypeDecl>(declRefType->declRef.getDecl()))
- {
- auto sup = DeclRefType::create(m_astBuilder, superTypeDeclRef);
- auto differentialBottomType = as<DeclRefType>(m_astBuilder->getDifferentialBottomType());
- auto container = differentialBottomType->declRef.as<ContainerDecl>().getDecl();
- SLANG_RELEASE_ASSERT(container);
- auto inheritanceDecl = container->getMembersOfType<InheritanceDecl>().getFirst();
- auto witnessDifferentialBottomIsIDifferentiable =
- m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
- m_astBuilder->getDifferentialBottomType(),
- sup,
- inheritanceDecl,
- nullptr);
-
- auto witnessSubIsDifferentialBottom =
- m_astBuilder->getOrCreate<DifferentialBottomSubtypeWitness>(
- subType, differentialBottomType);
-
- TransitiveSubtypeWitness* transitiveWitness =
- m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(
- witnessSubIsDifferentialBottom, witnessDifferentialBottomIsIDifferentiable);
- transitiveWitness->sub = subType;
- transitiveWitness->sup = sup;
- transitiveWitness->midToSup = witnessDifferentialBottomIsIDifferentiable;
- transitiveWitness->subToMid = witnessSubIsDifferentialBottom;
- *outWitness = transitiveWitness;
- return true;
- }
- }
}
else if (auto extractExistentialType = as<ExtractExistentialType>(subType))
{
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 8a4fe23d0..3135f300d 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -42,6 +42,8 @@ struct AutoDiffSharedContext
{
IRModuleInst* moduleInst = nullptr;
+ SharedIRBuilder* sharedBuilder = nullptr;
+
// A reference to the builtin IDifferentiable interface type.
// We use this to look up all the other types (and type exprs)
// that conform to a base type.
@@ -422,49 +424,48 @@ struct DifferentialPairTypeBuilder
return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
}
- IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder)
+ IRStructKey* _getOrCreateDiffStructKey()
{
if (!this->globalDiffKey)
{
+ IRBuilder builder(sharedContext->sharedBuilder);
// Insert directly at top level (skip any generic scopes etc.)
- auto insertLoc = builder->getInsertLoc();
- builder->setInsertInto(builder->getModule()->getModuleInst());
-
- this->globalDiffKey = builder->createStructKey();
- builder->addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential"));
+ builder.setInsertInto(sharedContext->moduleInst);
- builder->setInsertLoc(insertLoc);
+ this->globalDiffKey = builder.createStructKey();
+ builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential"));
}
return this->globalDiffKey;
}
- IRStructKey* _getOrCreatePrimalStructKey(IRBuilder* builder)
+ IRStructKey* _getOrCreatePrimalStructKey()
{
if (!this->globalPrimalKey)
{
// Insert directly at top level (skip any generic scopes etc.)
- auto insertLoc = builder->getInsertLoc();
- builder->setInsertInto(builder->getModule()->getModuleInst());
+ IRBuilder builder(sharedContext->sharedBuilder);
+ builder.setInsertInto(sharedContext->moduleInst);
- this->globalPrimalKey = builder->createStructKey();
- builder->addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal"));
-
- builder->setInsertLoc(insertLoc);
+ this->globalPrimalKey = builder.createStructKey();
+ builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal"));
}
return this->globalPrimalKey;
}
- IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType)
+ IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType)
{
SLANG_ASSERT(!as<IRParam>(origBaseType));
SLANG_ASSERT(diffType);
if (diffType->getOp() != kIROp_DifferentialBottomType)
{
- auto pairStructType = builder->createStructType();
- builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
- builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType);
+ IRBuilder builder(sharedContext->sharedBuilder);
+ builder.setInsertBefore(diffType);
+
+ auto pairStructType = builder.createStructType();
+ builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType);
+ builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType);
return pairStructType;
}
return origBaseType;
@@ -510,7 +511,7 @@ struct DifferentialPairTypeBuilder
}
auto diffType = getDiffTypeFromPairType(builder, pairType);
- result.loweredType = _createDiffPairType(builder, pairType->getValueType(), (IRType*)diffType);
+ result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType);
pairTypeCache.Add(originalPairType, result);
@@ -1476,9 +1477,10 @@ struct JVPTranscriber
InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
{
- auto oldLoc = builder->getInsertLoc();
+ IRBuilder subBuilder(builder->getSharedBuilder());
+ subBuilder.setInsertLoc(builder->getInsertLoc());
- IRInst* diffBlock = builder->emitBlock();
+ IRInst* diffBlock = subBuilder.emitBlock();
// Note: for blocks, we setup the mapping _before_
// processing the children since we could encounter
@@ -1487,19 +1489,17 @@ struct JVPTranscriber
mapPrimalInst(origBlock, diffBlock);
mapDifferentialInst(origBlock, diffBlock);
- builder->setInsertInto(diffBlock);
+ subBuilder.setInsertInto(diffBlock);
// First transcribe every parameter in the block.
for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
- this->transcribe(builder, param);
+ this->transcribe(&subBuilder, param);
// Then, run through every instruction and use the transcriber to generate the appropriate
// derivative code.
//
for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
- this->transcribe(builder, child);
-
- builder->setInsertLoc(oldLoc);
+ this->transcribe(&subBuilder, child);
return InstPair(diffBlock, diffBlock);
}
@@ -1709,22 +1709,22 @@ struct JVPTranscriber
}
// Create an empty func to represent the transcribed func of `origFunc`.
- InstPair transcribeFuncHeader(IRBuilder* builder, IRFunc* origFunc)
+ InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
- auto oldLoc = builder->getInsertLoc();
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origFunc);
IRFunc* primalFunc = origFunc;
differentiableTypeConformanceContext.setFunc(origFunc);
- builder->setInsertBefore(origFunc);
primalFunc = origFunc;
- auto diffFunc = builder->createFunc();
+ auto diffFunc = builder.createFunc();
SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
IRType* diffFuncType = this->differentiateFunctionType(
- builder,
+ &builder,
as<IRFuncType>(origFunc->getFullType()));
diffFunc->setFullType(diffFuncType);
@@ -1732,13 +1732,13 @@ struct JVPTranscriber
{
auto originalName = nameHint->getName();
StringBuilder newNameSb;
- newNameSb << "s_jvp_" << originalName;
- builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ newNameSb << "s_fwd_" << originalName;
+ builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
}
- builder->addForwardDerivativeDecoration(origFunc, diffFunc);
+ builder.addForwardDerivativeDecoration(origFunc, diffFunc);
// Mark the generated derivative function itself as differentiable.
- builder->addForwardDifferentiableDecoration(diffFunc);
+ builder.addForwardDifferentiableDecoration(diffFunc);
// Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
@@ -1746,32 +1746,27 @@ struct JVPTranscriber
cloneDecoration(dictDecor, diffFunc);
}
- // Reset builder position
- builder->setInsertLoc(oldLoc);
auto result = InstPair(primalFunc, diffFunc);
followUpFunctionsToTranscribe.add(result);
return result;
}
// Transcribe a function definition.
- InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
+ InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
{
- auto oldLoc = builder->getInsertLoc();
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertInto(diffFunc);
differentiableTypeConformanceContext.setFunc(primalFunc);
// Transcribe children from origFunc into diffFunc
- builder->setInsertInto(diffFunc);
for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
- this->transcribe(builder, block);
-
- // Reset builder position
- builder->setInsertLoc(oldLoc);
+ this->transcribe(&builder, block);
return InstPair(primalFunc, diffFunc);
}
// Transcribe a generic definition
- InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric)
+ InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric)
{
auto innerVal = findInnerMostGenericReturnVal(origGeneric);
if (auto innerFunc = as<IRFunc>(innerVal))
@@ -1789,10 +1784,10 @@ struct JVPTranscriber
IRGeneric* primalGeneric = origGeneric;
- auto oldLoc = builder->getInsertLoc();
- builder->setInsertBefore(origGeneric);
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origGeneric);
- auto diffGeneric = builder->emitGeneric();
+ auto diffGeneric = builder.emitGeneric();
// Process type of generic. If the generic is a function, then it's type will also be a
// generic and this logic will transcribe that generic first before continuing with the
@@ -1803,7 +1798,7 @@ struct JVPTranscriber
IRType* diffType = nullptr;
if (primalType)
{
- diffType = (IRType*) findOrTranscribeDiffInst(builder, primalType);
+ diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType);
}
diffGeneric->setFullType(diffType);
@@ -1813,12 +1808,9 @@ struct JVPTranscriber
// builder->addNameHintDecoration(diffFunc, jvpName);
// Transcribe children from origFunc into diffFunc.
- builder->setInsertInto(diffGeneric);
+ builder.setInsertInto(diffGeneric);
for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
- this->transcribe(builder, block);
-
- // Reset builder position.
- builder->setInsertLoc(oldLoc);
+ this->transcribe(&builder, block);
return InstPair(primalGeneric, diffGeneric);
}
@@ -1846,12 +1838,23 @@ struct JVPTranscriber
mapDifferentialInst(origInst, pair.differential);
if (pair.differential)
{
- // Generate name hint for the inst.
- if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
+ switch (pair.differential->getOp())
{
- StringBuilder sb;
- sb << "s_diff_" << primalNameHint->getName();
- builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
+ case kIROp_Func:
+ case kIROp_Generic:
+ case kIROp_Block:
+ // Don't generate again for these.
+ // Functions already have their names generated in `transcribeFuncHeader`.
+ break;
+ default:
+ // Generate name hint for the inst.
+ if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder sb;
+ sb << "s_diff_" << primalNameHint->getName();
+ builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
+ }
+ break;
}
}
return pair.differential;
@@ -2408,24 +2411,21 @@ struct JVPDerivativeContext : public InstPassBase
return false;
}
- IRStringLit* getForwardDerivativeFuncName(IRBuilder* builder,
- IRInst* func)
+ IRStringLit* getForwardDerivativeFuncName(IRInst* func)
{
- auto oldLoc = builder->getInsertLoc();
- builder->setInsertBefore(func);
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(func);
IRStringLit* name = nullptr;
if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
{
- name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice());
+ name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice());
}
else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
{
- name = builder->getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice());
+ name = builder.getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice());
}
- builder->setInsertLoc(oldLoc);
-
return name;
}
@@ -2435,6 +2435,7 @@ struct JVPDerivativeContext : public InstPassBase
autoDiffSharedContextStorage(module->getModuleInst()),
transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage)
{
+ autoDiffSharedContextStorage.sharedBuilder = &sharedBuilderStorage;
pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage;
transcriberStorage.sink = sink;
transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage);