diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-07 12:26:29 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-07 12:26:29 -0800 |
| commit | ca882a1ef46a5a8bbff50e3a1a6f973e16358634 (patch) | |
| tree | 1a4d37ad67d3844b6c69ebec68d3858f0c318747 /source/slang | |
| parent | ea99c274dea12fffdc89a8d4eeefcbb670232ba8 (diff) | |
Small cleanups on forward differentiation. (#2498)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-check-conformance.cpp | 39 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 135 |
2 files changed, 68 insertions, 106 deletions
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index eb072e9dd..d2335efbf 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -378,45 +378,6 @@ namespace Slang } } } - - // If a generic type parameter does not declare itself to conform to `IDifferentiable`, - // we treat it as a subtype of `DifferentialBottom` to make it conform to `IDifferentiable`. - // Note: we only consider this option for `originalSubType` so a type that implements `IDifferential` but - // inherits from some other non differentiable types don't get to inherit `DifferentialBottom`. - if (m_astBuilder->isDifferentiableInterfaceAvailable() && - subType == originalSubType && - superTypeDeclRef.getDecl() == m_astBuilder->getDifferentiableInterface()) - { - if (as<GenericTypeParamDecl>(declRefType->declRef.getDecl()) || - as<AssocTypeDecl>(declRefType->declRef.getDecl())) - { - auto sup = DeclRefType::create(m_astBuilder, superTypeDeclRef); - auto differentialBottomType = as<DeclRefType>(m_astBuilder->getDifferentialBottomType()); - auto container = differentialBottomType->declRef.as<ContainerDecl>().getDecl(); - SLANG_RELEASE_ASSERT(container); - auto inheritanceDecl = container->getMembersOfType<InheritanceDecl>().getFirst(); - auto witnessDifferentialBottomIsIDifferentiable = - m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( - m_astBuilder->getDifferentialBottomType(), - sup, - inheritanceDecl, - nullptr); - - auto witnessSubIsDifferentialBottom = - m_astBuilder->getOrCreate<DifferentialBottomSubtypeWitness>( - subType, differentialBottomType); - - TransitiveSubtypeWitness* transitiveWitness = - m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>( - witnessSubIsDifferentialBottom, witnessDifferentialBottomIsIDifferentiable); - transitiveWitness->sub = subType; - transitiveWitness->sup = sup; - transitiveWitness->midToSup = witnessDifferentialBottomIsIDifferentiable; - transitiveWitness->subToMid = witnessSubIsDifferentialBottom; - *outWitness = transitiveWitness; - return true; - } - } } else if (auto extractExistentialType = as<ExtractExistentialType>(subType)) { diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 8a4fe23d0..3135f300d 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -42,6 +42,8 @@ struct AutoDiffSharedContext { IRModuleInst* moduleInst = nullptr; + SharedIRBuilder* sharedBuilder = nullptr; + // A reference to the builtin IDifferentiable interface type. // We use this to look up all the other types (and type exprs) // that conform to a base type. @@ -422,49 +424,48 @@ struct DifferentialPairTypeBuilder return emitFieldAccessor(builder, baseInst, this->globalDiffKey); } - IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder) + IRStructKey* _getOrCreateDiffStructKey() { if (!this->globalDiffKey) { + IRBuilder builder(sharedContext->sharedBuilder); // Insert directly at top level (skip any generic scopes etc.) - auto insertLoc = builder->getInsertLoc(); - builder->setInsertInto(builder->getModule()->getModuleInst()); - - this->globalDiffKey = builder->createStructKey(); - builder->addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential")); + builder.setInsertInto(sharedContext->moduleInst); - builder->setInsertLoc(insertLoc); + this->globalDiffKey = builder.createStructKey(); + builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential")); } return this->globalDiffKey; } - IRStructKey* _getOrCreatePrimalStructKey(IRBuilder* builder) + IRStructKey* _getOrCreatePrimalStructKey() { if (!this->globalPrimalKey) { // Insert directly at top level (skip any generic scopes etc.) - auto insertLoc = builder->getInsertLoc(); - builder->setInsertInto(builder->getModule()->getModuleInst()); + IRBuilder builder(sharedContext->sharedBuilder); + builder.setInsertInto(sharedContext->moduleInst); - this->globalPrimalKey = builder->createStructKey(); - builder->addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal")); - - builder->setInsertLoc(insertLoc); + this->globalPrimalKey = builder.createStructKey(); + builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal")); } return this->globalPrimalKey; } - IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType) + IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType) { SLANG_ASSERT(!as<IRParam>(origBaseType)); SLANG_ASSERT(diffType); if (diffType->getOp() != kIROp_DifferentialBottomType) { - auto pairStructType = builder->createStructType(); - builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); - builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType); + IRBuilder builder(sharedContext->sharedBuilder); + builder.setInsertBefore(diffType); + + auto pairStructType = builder.createStructType(); + builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType); + builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType); return pairStructType; } return origBaseType; @@ -510,7 +511,7 @@ struct DifferentialPairTypeBuilder } auto diffType = getDiffTypeFromPairType(builder, pairType); - result.loweredType = _createDiffPairType(builder, pairType->getValueType(), (IRType*)diffType); + result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType); pairTypeCache.Add(originalPairType, result); @@ -1476,9 +1477,10 @@ struct JVPTranscriber InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) { - auto oldLoc = builder->getInsertLoc(); + IRBuilder subBuilder(builder->getSharedBuilder()); + subBuilder.setInsertLoc(builder->getInsertLoc()); - IRInst* diffBlock = builder->emitBlock(); + IRInst* diffBlock = subBuilder.emitBlock(); // Note: for blocks, we setup the mapping _before_ // processing the children since we could encounter @@ -1487,19 +1489,17 @@ struct JVPTranscriber mapPrimalInst(origBlock, diffBlock); mapDifferentialInst(origBlock, diffBlock); - builder->setInsertInto(diffBlock); + subBuilder.setInsertInto(diffBlock); // First transcribe every parameter in the block. for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - this->transcribe(builder, param); + this->transcribe(&subBuilder, param); // Then, run through every instruction and use the transcriber to generate the appropriate // derivative code. // for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - this->transcribe(builder, child); - - builder->setInsertLoc(oldLoc); + this->transcribe(&subBuilder, child); return InstPair(diffBlock, diffBlock); } @@ -1709,22 +1709,22 @@ struct JVPTranscriber } // Create an empty func to represent the transcribed func of `origFunc`. - InstPair transcribeFuncHeader(IRBuilder* builder, IRFunc* origFunc) + InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { - auto oldLoc = builder->getInsertLoc(); + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origFunc); IRFunc* primalFunc = origFunc; differentiableTypeConformanceContext.setFunc(origFunc); - builder->setInsertBefore(origFunc); primalFunc = origFunc; - auto diffFunc = builder->createFunc(); + auto diffFunc = builder.createFunc(); SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); IRType* diffFuncType = this->differentiateFunctionType( - builder, + &builder, as<IRFuncType>(origFunc->getFullType())); diffFunc->setFullType(diffFuncType); @@ -1732,13 +1732,13 @@ struct JVPTranscriber { auto originalName = nameHint->getName(); StringBuilder newNameSb; - newNameSb << "s_jvp_" << originalName; - builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); + newNameSb << "s_fwd_" << originalName; + builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); } - builder->addForwardDerivativeDecoration(origFunc, diffFunc); + builder.addForwardDerivativeDecoration(origFunc, diffFunc); // Mark the generated derivative function itself as differentiable. - builder->addForwardDifferentiableDecoration(diffFunc); + builder.addForwardDifferentiableDecoration(diffFunc); // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) @@ -1746,32 +1746,27 @@ struct JVPTranscriber cloneDecoration(dictDecor, diffFunc); } - // Reset builder position - builder->setInsertLoc(oldLoc); auto result = InstPair(primalFunc, diffFunc); followUpFunctionsToTranscribe.add(result); return result; } // Transcribe a function definition. - InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) + InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) { - auto oldLoc = builder->getInsertLoc(); + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertInto(diffFunc); differentiableTypeConformanceContext.setFunc(primalFunc); // Transcribe children from origFunc into diffFunc - builder->setInsertInto(diffFunc); for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribe(builder, block); - - // Reset builder position - builder->setInsertLoc(oldLoc); + this->transcribe(&builder, block); return InstPair(primalFunc, diffFunc); } // Transcribe a generic definition - InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric) + InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) { auto innerVal = findInnerMostGenericReturnVal(origGeneric); if (auto innerFunc = as<IRFunc>(innerVal)) @@ -1789,10 +1784,10 @@ struct JVPTranscriber IRGeneric* primalGeneric = origGeneric; - auto oldLoc = builder->getInsertLoc(); - builder->setInsertBefore(origGeneric); + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origGeneric); - auto diffGeneric = builder->emitGeneric(); + auto diffGeneric = builder.emitGeneric(); // Process type of generic. If the generic is a function, then it's type will also be a // generic and this logic will transcribe that generic first before continuing with the @@ -1803,7 +1798,7 @@ struct JVPTranscriber IRType* diffType = nullptr; if (primalType) { - diffType = (IRType*) findOrTranscribeDiffInst(builder, primalType); + diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType); } diffGeneric->setFullType(diffType); @@ -1813,12 +1808,9 @@ struct JVPTranscriber // builder->addNameHintDecoration(diffFunc, jvpName); // Transcribe children from origFunc into diffFunc. - builder->setInsertInto(diffGeneric); + builder.setInsertInto(diffGeneric); for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribe(builder, block); - - // Reset builder position. - builder->setInsertLoc(oldLoc); + this->transcribe(&builder, block); return InstPair(primalGeneric, diffGeneric); } @@ -1846,12 +1838,23 @@ struct JVPTranscriber mapDifferentialInst(origInst, pair.differential); if (pair.differential) { - // Generate name hint for the inst. - if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>()) + switch (pair.differential->getOp()) { - StringBuilder sb; - sb << "s_diff_" << primalNameHint->getName(); - builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); + case kIROp_Func: + case kIROp_Generic: + case kIROp_Block: + // Don't generate again for these. + // Functions already have their names generated in `transcribeFuncHeader`. + break; + default: + // Generate name hint for the inst. + if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>()) + { + StringBuilder sb; + sb << "s_diff_" << primalNameHint->getName(); + builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); + } + break; } } return pair.differential; @@ -2408,24 +2411,21 @@ struct JVPDerivativeContext : public InstPassBase return false; } - IRStringLit* getForwardDerivativeFuncName(IRBuilder* builder, - IRInst* func) + IRStringLit* getForwardDerivativeFuncName(IRInst* func) { - auto oldLoc = builder->getInsertLoc(); - builder->setInsertBefore(func); + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(func); IRStringLit* name = nullptr; if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) { - name = builder->getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice()); + name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice()); } else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) { - name = builder->getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice()); + name = builder.getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice()); } - builder->setInsertLoc(oldLoc); - return name; } @@ -2435,6 +2435,7 @@ struct JVPDerivativeContext : public InstPassBase autoDiffSharedContextStorage(module->getModuleInst()), transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage) { + autoDiffSharedContextStorage.sharedBuilder = &sharedBuilderStorage; pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage; transcriberStorage.sink = sink; transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage); |
