diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-01-10 03:16:24 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-09 13:46:24 -0800 |
| commit | 87f00a36a123e36b415eeea82e02a8366cc5b881 (patch) | |
| tree | 719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-fwd.cpp | |
| parent | 6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff) | |
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch
* More fixes
* remove intermediate dumps
* Update slang-ast-type.h
* More fixes + add a workaround for existential no-diff
* Update reverse-control-flow-3.slang
* remove dumps
* remove more dumps
* Delete working-reverse-control-flow-3.hlsl
* Cleanup comments + unused variables
* More comment cleanup
* Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)`
* Fix array of issues in Falcor tests.
* Update slang-ir-autodiff-pairs.cpp
* More fixes for Falcor image tests
* Small fixups.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-fwd.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 143 |
1 files changed, 119 insertions, 24 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 9f26f9d55..30c14f706 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -160,6 +160,40 @@ InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRIns return InstPair(primalVal, diffVal); } +InstPair ForwardDiffTranscriber::transcribeDifferentiableTypeAnnotation( + IRBuilder* builder, + IRInst* origInst) +{ + auto primalAnnotation = + as<IRDifferentiableTypeAnnotation>(maybeCloneForPrimalInst(builder, origInst)); + + IRDifferentiableTypeAnnotation* annotation = as<IRDifferentiableTypeAnnotation>(origInst); + + differentiableTypeConformanceContext.addTypeToDictionary( + (IRType*)primalAnnotation->getBaseType(), + primalAnnotation->getWitness()); + + auto diffType = differentiateType(builder, (IRType*)annotation->getBaseType()); + if (!diffType) + return InstPair(primalAnnotation, nullptr); + + auto diffTypeDiffWitness = + tryGetDifferentiableWitness(builder, diffType, DiffConformanceKind::Any); + + IRInst* args[] = {diffType, diffTypeDiffWitness}; + + auto diffAnnotation = builder->emitIntrinsicInst( + builder->getVoidType(), + kIROp_DifferentiableTypeAnnotation, + 2, + args); + + builder->markInstAsPrimal(diffAnnotation); + builder->markInstAsPrimal(primalAnnotation); + + return InstPair(primalAnnotation, diffAnnotation); +} + InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar) { if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) @@ -752,9 +786,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig auto pairValType = as<IRDifferentialPairTypeBase>( pairPtrType ? pairPtrType->getValueType() : pairType); - auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType( - &argBuilder, - pairValType); + auto diffType = differentiateType(&argBuilder, primalType); if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType)) { // Create temp var to pass in/out arguments. @@ -795,7 +827,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (diffArg) { auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential( - (IRType*)diffType, + (IRType*)as<IRPtrTypeBase>(diffType)->getValueType(), newVal); markDiffTypeInst( &afterBuilder, @@ -827,17 +859,72 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig } } } + + { + // --WORKAROUND-- + // This is a temporary workaround for a very specific case.. + // + // If all the following are true: + // 1. the parameter type expects a differential pair, + // 2. the argument is derived from a no_diff type, and + // 3. the argument type is a run-time type (i.e. extract_existential_type), + // then we need to generate a differential 0, but the IR has no + // information on the diff witness. + // + // We will bypass the conformance system & brute-force the lookup for the interface + // keys, but the proper fix is to lower this key mapping during `no_diff` lowering. + // + + // Condition 1 + if (differentiableTypeConformanceContext.isDifferentiableType((originalParamType))) + { + // Condition 3 + if (auto extractExistentialType = as<IRExtractExistentialType>(primalType)) + { + // Condition 2 + if (isNoDiffType(extractExistentialType->getOperand(0)->getDataType())) + { + // Force-differentiate the type (this will perform a search for the witness + // without going through the diff-type annotation list) + // + IRInst* witnessTable = nullptr; + auto diffType = differentiateExtractExistentialType( + &argBuilder, + extractExistentialType, + witnessTable); + + auto pairType = + getOrCreateDiffPairType(&argBuilder, primalType, witnessTable); + auto zeroMethod = argBuilder.emitLookupInterfaceMethodInst( + differentiableTypeConformanceContext.sharedContext->zeroMethodType, + witnessTable, + differentiableTypeConformanceContext.sharedContext + ->zeroMethodStructKey); + auto diffZero = argBuilder.emitCallInst(diffType, zeroMethod, 0, nullptr); + auto diffPair = + argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffZero); + + args.add(diffPair); + continue; + } + } + } + } + // Argument is not differentiable. // Add original/primal argument. args.add(primalArg); } IRType* diffReturnType = nullptr; - diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType()); + auto primalReturnType = + (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType()); + + diffReturnType = tryGetDiffPairType(&argBuilder, primalReturnType); if (!diffReturnType) { - diffReturnType = (IRType*)findOrTranscribePrimalInst(&argBuilder, origCall->getFullType()); + diffReturnType = primalReturnType; } auto callInst = argBuilder.emitCallInst(diffReturnType, diffCallee, args); @@ -1035,6 +1122,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize( IRInst* diffBase = nullptr; if (instMapD.tryGetValue(origSpecialize->getBase(), diffBase)) { + auto diffType = differentiateType(builder, origSpecialize->getFullType()); if (diffBase) { List<IRInst*> args; @@ -1042,11 +1130,8 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize( { args.add(primalSpecialize->getArg(i)); } - auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), - diffBase, - args.getCount(), - args.getBuffer()); + auto diffSpecialize = + builder->emitSpecializeInst(diffType, diffBase, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } else @@ -1572,7 +1657,24 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc()); } - auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc); + IRFunc* diffFunc = nullptr; + + // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the + // insert location unchanged). If we're transcribing it as a declaration, we should + // insert into the module. + // + auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc)); + if (!origOuterGen || findInnerMostGenericReturnVal(origOuterGen) != origFunc) + { + // Dealing with a declaration.. insert into module scope. + IRBuilder subBuilder = *inBuilder; + subBuilder.setInsertInto(inBuilder->getModule()); + diffFunc = transcribeFuncHeaderImpl(&subBuilder, origFunc); + } + else + { + diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc); + } if (auto outerGen = findOuterGeneric(diffFunc)) { @@ -1605,7 +1707,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I IRBuilder builder = *inBuilder; maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); - differentiableTypeConformanceContext.setFunc(origFunc); auto diffFunc = builder.createFunc(); @@ -1632,12 +1733,6 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I // Transfer checkpoint hint decorations copyCheckpointHints(&builder, origFunc, diffFunc); - - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. - if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) - { - cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule()); - } return diffFunc; } @@ -2012,6 +2107,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_Reinterpret: return transcribeReinterpret(builder, origInst); + case kIROp_DifferentiableTypeAnnotation: + return transcribeDifferentiableTypeAnnotation(builder, origInst); + // Differentiable insts that should have been lowered in a previous pass. case kIROp_SwizzledStore: { @@ -2138,13 +2236,10 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam( if (as<IRDifferentialPairType>(diffPairType) || as<IRDifferentialPtrPairType>(diffPairType)) { + auto diffType = differentiateType(builder, (IRType*)origParam->getFullType()); return InstPair( builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( - builder, - as<IRDifferentialPairTypeBase>(diffPairType)), - diffPairParam)); + builder->emitDifferentialPairGetDifferential(diffType, diffPairParam)); } else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) { |
