diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-01-10 03:16:24 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-09 13:46:24 -0800 |
| commit | 87f00a36a123e36b415eeea82e02a8366cc5b881 (patch) | |
| tree | 719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-lower-to-ir.cpp | |
| parent | 6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff) | |
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch
* More fixes
* remove intermediate dumps
* Update slang-ast-type.h
* More fixes + add a workaround for existential no-diff
* Update reverse-control-flow-3.slang
* remove dumps
* remove more dumps
* Delete working-reverse-control-flow-3.hlsl
* Cleanup comments + unused variables
* More comment cleanup
* Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)`
* Fix array of issues in Falcor tests.
* Update slang-ir-autodiff-pairs.cpp
* More fixes for Falcor image tests
* Small fixups.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 75 |
1 files changed, 62 insertions, 13 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 011ea6bc7..e82fc03fd 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -592,11 +592,21 @@ struct IRGenContext // The element index if we are inside an `expand` expression. IRInst* expandIndex = nullptr; + // Callback function to call when after lowering a type. + std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> lowerTypeCallback = + nullptr; + explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder) : shared(inShared), astBuilder(inAstBuilder), env(&inShared->globalEnv), irBuilder(nullptr) { } + void registerTypeCallback( + std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> callback) + { + lowerTypeCallback = callback; + } + void setGlobalValue(Decl* decl, LoweredValInfo value) { shared->setGlobalValue(decl, value); } void setValue(Decl* decl, LoweredValInfo value) { env->mapDeclToValue[decl] = value; } @@ -2202,7 +2212,12 @@ IRType* lowerType(IRGenContext* context, Type* type) { ValLoweringVisitor visitor; visitor.context = context; - return (IRType*)getSimpleVal(context, visitor.dispatchType(type)); + IRType* loweredType = (IRType*)getSimpleVal(context, visitor.dispatchType(type)); + + if (context->lowerTypeCallback && loweredType) + context->lowerTypeCallback(context, type, loweredType); + + return loweredType; } void addVarDecorations(IRGenContext* context, IRInst* inst, Decl* decl) @@ -8105,6 +8120,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> subContextStorage.thisTypeWitness = outerContext->thisTypeWitness; subContextStorage.returnDestination = LoweredValInfo(); + subContextStorage.lowerTypeCallback = nullptr; } IRBuilder* getBuilder() { return &subBuilderStorage; } @@ -8629,7 +8645,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric); // Add `irInterface` to decl mapping now to prevent cyclic lowering. - context->setValue(decl, LoweredValInfo::simple(finalVal)); + context->setGlobalValue(decl, LoweredValInfo::simple(finalVal)); subBuilder->setInsertBefore(irInterface); @@ -8783,7 +8799,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } - addNameHint(context, irInterface, decl); addLinkageDecoration(context, irInterface, decl); if (auto anyValueSizeAttr = decl->findModifier<AnyValueSizeAttribute>()) @@ -9910,6 +9925,48 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> else outerGeneric = emitOuterGenerics(subContext, decl, decl); + // If our function is differentiable, register a callback so the derivative + // annotations for types can be lowered. + // + if (auto diffAttr = decl->findModifier<DifferentiableAttribute>()) + { + auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness(); + OrderedDictionary<DeclRefBase*, SubtypeWitness*> resolveddiffTypeWitnessMap; + + // Go through each entry in the map and resolve the key. + for (auto& entry : diffTypeWitnessMap) + { + auto resolvedKey = as<DeclRefBase>(entry.key->resolve()); + resolveddiffTypeWitnessMap[resolvedKey] = + as<SubtypeWitness>(as<Val>(entry.value)->resolve()); + } + + subContext->registerTypeCallback( + [=](IRGenContext* context, Type* type, IRType* irType) + { + if (!as<DeclRefType>(type)) + return irType; + + DeclRefBase* declRefBase = as<DeclRefType>(type)->getDeclRefBase(); + if (resolveddiffTypeWitnessMap.containsKey(declRefBase)) + { + auto irWitness = + lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val; + if (irWitness) + { + IRInst* args[] = {irType, irWitness}; + context->irBuilder->emitIntrinsicInst( + context->irBuilder->getVoidType(), + kIROp_DifferentiableTypeAnnotation, + 2, + args); + } + } + + return irType; + }); + } + FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, @@ -10220,6 +10277,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } + subContext->registerTypeCallback(nullptr); + getBuilder()->addHighLevelDeclDecoration(irFunc, decl); addSpecializedForTargetDecorations(irFunc, decl); @@ -10467,16 +10526,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } - if (auto diffAttr = decl->findModifier<DifferentiableAttribute>()) - { - if (decl->body) - { - subContext->irBuilder->setInsertInto(irFunc->getParent()); - lowerDifferentiableAttribute(subContext, irFunc, diffAttr); - subContext->irBuilder->setInsertInto(irFunc); - } - } - // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list |
