diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-11-21 10:29:57 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-21 10:29:57 -0500 |
| commit | 545de51298ddda52ac51ded03ad489c98bdda397 (patch) | |
| tree | def78374f743d2c722fbde45eba60951a6f5c8f9 /source/slang/slang-ir-diff-jvp.cpp | |
| parent | d58e08f8237a1888ceaad53402d534679ea83b1a (diff) | |
WIP: Fixed inout struct and added testing for calls to non-differentiable functions (#2505)
* Added non-differentiable call test
* Extended testing for nondifferentiable calls
* Fixed subtle issue with extensions on generic types not applying the correct substitutions, leading to unspecialized generics at the emit stage
* More fixes. inout struct params now work fine
* Update inout-struct-parameters-jvp.slang
* Update slang-ir.cpp
* Fixed hoisting lookup_interface_method
* Fixed non-diff call return value
* Fixed issue with phi nodes
* Fixed problem with IRSpecialize preventing hoisitng of DifferentialPairType
* Fixed non-diff call test to conform to the new 'no_diff' system
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 222 |
1 files changed, 139 insertions, 83 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 4ee16aafc..c9ca687e4 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -92,16 +92,33 @@ struct DifferentialPairTypeBuilder IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) { - auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); - if (baseTypeInfo.isTrivial) + IRInst* pairType = nullptr; + if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType())) { - if (key == globalPrimalKey) - return baseInst; - else - return builder->getDifferentialBottom(); + auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType()); + + // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types, + // especially since we probably don't need diff bottom anymore. + // + SLANG_ASSERT(!baseTypeInfo.isTrivial); + + pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType); + } + else + { + auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); + if (baseTypeInfo.isTrivial) + { + if (key == globalPrimalKey) + return baseInst; + else + return builder->getDifferentialBottom(); + } + + pairType = baseTypeInfo.loweredType; } - if (auto basePairStructType = as<IRStructType>(baseTypeInfo.loweredType)) + if (auto basePairStructType = as<IRStructType>(pairType)) { return as<IRFieldExtract>(builder->emitFieldExtract( findField(basePairStructType, key)->getFieldType(), @@ -109,7 +126,7 @@ struct DifferentialPairTypeBuilder key )); } - else if (auto ptrType = as<IRPtrTypeBase>(baseTypeInfo.loweredType)) + else if (auto ptrType = as<IRPtrTypeBase>(pairType)) { if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType())) { @@ -135,7 +152,7 @@ struct DifferentialPairTypeBuilder key)); } } - else if (auto specializedType = as<IRSpecialize>(baseTypeInfo.loweredType)) + else if (auto specializedType = as<IRSpecialize>(pairType)) { // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's // type, emit the specialization type. @@ -333,7 +350,9 @@ struct JVPTranscriber JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder) : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder) - {} + { + + } DiagnosticSink* getSink() { @@ -449,6 +468,17 @@ struct JVPTranscriber return builder->getFuncType(newParameterTypes, diffReturnType); } + IRWitnessTable* getDifferentialBottomWitness() + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); + auto result = + as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( + builder.getDifferentialBottomType())); + SLANG_ASSERT(result); + return result; + } + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) { @@ -456,23 +486,20 @@ struct JVPTranscriber builder.setInsertInto(inDiffPairType->parent); auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); SLANG_ASSERT(diffPairType); - auto diffType = differentiateType(&builder, diffPairType->getValueType()); - IRInst* tableInst = nullptr; - if (!differentiableTypeConformanceContext.differentiableWitnessDictionary.TryGetValue(diffPairType, tableInst)) - { - IRWitnessTable* table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); - // The witness that `diffType` - auto differentialType = builder.getDifferentialPairType( - diffType, - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffType] - .GetValue()); - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - tableInst = table; - } - return as<IRWitnessTable>(tableInst); + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = differentiateType(&builder, diffPairType); + + // And place it in the synthesized witness table. + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + + // Record this in the context for future lookups + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + + return table; } IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) @@ -490,10 +517,19 @@ struct JVPTranscriber builder.setInsertInto(primalType->parent); auto witness = as<IRWitnessTable>( differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - if (!witness && as<IRDifferentialPairType>(primalType)) + + if (!witness) { - witness = getDifferentialPairWitness(primalType); + if (auto primalPairType = as<IRDifferentialPairType>(primalType)) + { + witness = getDifferentialPairWitness(primalPairType); + } + else + { + witness = getDifferentialBottomWitness(); + } } + return builder.getDifferentialPairType( (IRType*)primalType, witness); @@ -630,8 +666,8 @@ struct JVPTranscriber builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); SLANG_ASSERT(diffPairParam); - - if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) + + if (auto pairType = as<IRDifferentialPairType>(diffPairType)) { return InstPair( builder->emitDifferentialPairGetPrimal(diffPairParam), @@ -639,16 +675,23 @@ struct JVPTranscriber (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), diffPairParam)); } - // If this is an `in/inout DifferentialPair<>` parameter, we can't produce - // its primal and diff parts right now because they would represent a reference - // to a pair field, which doesn't make sense since pair types are considered mutable. - // We encode the result as if the param is non-differentiable, and handle it - // with special care at load/store. - return InstPair(diffPairParam, nullptr); + else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) + { + auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType()); + + return InstPair( + builder->emitDifferentialPairAddressPrimal(diffPairParam), + builder->emitDifferentialPairAddressDifferential( + builder->getPtrType( + kIROp_PtrType, + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), + diffPairParam)); + } } + return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); + cloneInst(&cloneEnv, builder, origParam), + nullptr); } else { @@ -660,6 +703,7 @@ struct JVPTranscriber } return InstPair(primal, diff); } + } // Returns "d<var-name>" to use as a name hint for variables and parameters. @@ -784,6 +828,7 @@ struct JVPTranscriber { // Special case load from an `out` param, which will not have corresponding `diff` and // `primal` insts yet. + auto load = builder->emitLoad(primalPtr); auto primalElement = builder->emitDifferentialPairGetPrimal(load); auto diffElement = builder->emitDifferentialPairGetDifferential( @@ -1401,30 +1446,25 @@ struct JVPTranscriber InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) { - // The loop comes with three blocks.. we just need to transcribe each one - // and assemble the new loop instruction. + // IfElse Statements come with 4 blocks. We transcribe each block into it's + // linear form, and then wire them up in the same way as the original if-else - // Transcribe the target block (this is the 'condition' part of the loop, which - // will branch into the loop body). - // Note that for the condition we use the primal inst (condition values should not have a - // differential) + // Transcribe condition block auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition()); SLANG_ASSERT(primalConditionBlock); - // Transcribe the break block (this is the block after the exiting the loop) + // Transcribe 'true' block (condition block branches into this if true) auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock()); SLANG_ASSERT(diffTrueBlock); - // Transcribe the continue block (this is the 'update' part of the loop, which will - // branch into the condition block) + // Transcribe 'false' block (condition block branches into this if true) + // TODO (sai): What happens if there's no false block? auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock()); SLANG_ASSERT(diffFalseBlock); - // Transcribe the continue block (this is the 'update' part of the loop, which will - // branch into the condition block) + // Transcribe 'after' block (true and false blocks branch into this) auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock()); SLANG_ASSERT(diffAfterBlock); - List<IRInst*> diffIfElseArgs; diffIfElseArgs.add(primalConditionBlock); @@ -2462,6 +2502,9 @@ struct JVPDerivativeContext : public InstPassBase sharedBuilder->init(module); sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + // TODO(sai): Move this call. + transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary(); + IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; @@ -2477,6 +2520,9 @@ struct JVPDerivativeContext : public InstPassBase // modified |= simplifyDifferentialBottomType(builder); + // De-duplicate any remaining types. + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + modified |= processPairTypes(builder, module->getModuleInst()); modified |= eliminateDifferentialBottomType(builder); @@ -2665,7 +2711,13 @@ struct JVPDerivativeContext : public InstPassBase { if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) { - if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), nullptr)) + auto pairType = getDiffInst->getBase()->getDataType(); + if (auto pairPtrType = as<IRPtrTypeBase>(pairType)) + { + pairType = pairPtrType->getValueType(); + } + + if (lowerPairType(builder, pairType, nullptr)) { builder->setInsertBefore(getDiffInst); IRInst* diffFieldExtract = nullptr; @@ -2677,7 +2729,13 @@ struct JVPDerivativeContext : public InstPassBase } else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) { - if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), nullptr)) + auto pairType = getPrimalInst->getBase()->getDataType(); + if (auto pairPtrType = as<IRPtrTypeBase>(pairType)) + { + pairType = pairPtrType->getValueType(); + } + + if (lowerPairType(builder, pairType, nullptr)) { builder->setInsertBefore(getPrimalInst); @@ -2695,41 +2753,29 @@ struct JVPDerivativeContext : public InstPassBase bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) { bool modified = false; - // Hoist and deduplicate all pair types to global scope when possible. - // This avoids emitting different struct types for equivalent pair types. + // Hoist all pair types to global scope when possible. auto moduleInst = module->getModuleInst(); - Dictionary<IRInst*, IRInst*> diffPairTypes; - for (;;) - { - bool changed = false; - sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); - processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* originalPairType) + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType) + { + if (originalPairType->parent != moduleInst) { - IRInst* finalType = nullptr; - if (diffPairTypes.TryGetValue(originalPairType->getValueType(), finalType)) - { - if (finalType != originalPairType) - { - originalPairType->replaceUsesWith(finalType); - originalPairType->removeAndDeallocate(); - changed = true; - return; - } - } - diffPairTypes[originalPairType->getValueType()] = originalPairType; - if (originalPairType->parent != moduleInst) + originalPairType->removeFromParent(); + ShortList<IRInst*> operands; + for (UInt i = 0; i < originalPairType->getOperandCount(); i++) { - if (originalPairType->getValueType()->getParent() != originalPairType->getParent()) - { - originalPairType->insertAfter(originalPairType->getValueType()); - changed = true; - return; - } + operands.add(originalPairType->getOperand(i)); } - }); - if (!changed) - break; - } + auto newPairType = builder->findOrEmitHoistableInst( + originalPairType->getFullType(), + originalPairType->getOp(), + originalPairType->getOperandCount(), + operands.getArrayView().getBuffer()); + originalPairType->replaceUsesWith(newPairType); + originalPairType->removeAndDeallocate(); + } + }); + + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); processAllInsts([&](IRInst* inst) { @@ -3138,4 +3184,14 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b return nullptr; } +void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() +{ + for (auto globalInst : sharedContext->moduleInst->getChildren()) + { + if (auto pairType = as<IRDifferentialPairType>(globalInst)) + { + differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness()); + } + } +} } |
