From 87f00a36a123e36b415eeea82e02a8366cc5b881 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:16:24 +0530 Subject: [Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866) * Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He --- source/slang/slang-lower-to-ir.cpp | 75 +++++++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 13 deletions(-) (limited to 'source/slang/slang-lower-to-ir.cpp') diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 011ea6bc7..e82fc03fd 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -592,11 +592,21 @@ struct IRGenContext // The element index if we are inside an `expand` expression. IRInst* expandIndex = nullptr; + // Callback function to call when after lowering a type. + std::function lowerTypeCallback = + nullptr; + explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder) : shared(inShared), astBuilder(inAstBuilder), env(&inShared->globalEnv), irBuilder(nullptr) { } + void registerTypeCallback( + std::function callback) + { + lowerTypeCallback = callback; + } + void setGlobalValue(Decl* decl, LoweredValInfo value) { shared->setGlobalValue(decl, value); } void setValue(Decl* decl, LoweredValInfo value) { env->mapDeclToValue[decl] = value; } @@ -2202,7 +2212,12 @@ IRType* lowerType(IRGenContext* context, Type* type) { ValLoweringVisitor visitor; visitor.context = context; - return (IRType*)getSimpleVal(context, visitor.dispatchType(type)); + IRType* loweredType = (IRType*)getSimpleVal(context, visitor.dispatchType(type)); + + if (context->lowerTypeCallback && loweredType) + context->lowerTypeCallback(context, type, loweredType); + + return loweredType; } void addVarDecorations(IRGenContext* context, IRInst* inst, Decl* decl) @@ -8105,6 +8120,7 @@ struct DeclLoweringVisitor : DeclVisitor subContextStorage.thisTypeWitness = outerContext->thisTypeWitness; subContextStorage.returnDestination = LoweredValInfo(); + subContextStorage.lowerTypeCallback = nullptr; } IRBuilder* getBuilder() { return &subBuilderStorage; } @@ -8629,7 +8645,7 @@ struct DeclLoweringVisitor : DeclVisitor auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric); // Add `irInterface` to decl mapping now to prevent cyclic lowering. - context->setValue(decl, LoweredValInfo::simple(finalVal)); + context->setGlobalValue(decl, LoweredValInfo::simple(finalVal)); subBuilder->setInsertBefore(irInterface); @@ -8783,7 +8799,6 @@ struct DeclLoweringVisitor : DeclVisitor } } - addNameHint(context, irInterface, decl); addLinkageDecoration(context, irInterface, decl); if (auto anyValueSizeAttr = decl->findModifier()) @@ -9910,6 +9925,48 @@ struct DeclLoweringVisitor : DeclVisitor else outerGeneric = emitOuterGenerics(subContext, decl, decl); + // If our function is differentiable, register a callback so the derivative + // annotations for types can be lowered. + // + if (auto diffAttr = decl->findModifier()) + { + auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness(); + OrderedDictionary resolveddiffTypeWitnessMap; + + // Go through each entry in the map and resolve the key. + for (auto& entry : diffTypeWitnessMap) + { + auto resolvedKey = as(entry.key->resolve()); + resolveddiffTypeWitnessMap[resolvedKey] = + as(as(entry.value)->resolve()); + } + + subContext->registerTypeCallback( + [=](IRGenContext* context, Type* type, IRType* irType) + { + if (!as(type)) + return irType; + + DeclRefBase* declRefBase = as(type)->getDeclRefBase(); + if (resolveddiffTypeWitnessMap.containsKey(declRefBase)) + { + auto irWitness = + lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val; + if (irWitness) + { + IRInst* args[] = {irType, irWitness}; + context->irBuilder->emitIntrinsicInst( + context->irBuilder->getVoidType(), + kIROp_DifferentiableTypeAnnotation, + 2, + args); + } + } + + return irType; + }); + } + FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, @@ -10220,6 +10277,8 @@ struct DeclLoweringVisitor : DeclVisitor } } + subContext->registerTypeCallback(nullptr); + getBuilder()->addHighLevelDeclDecoration(irFunc, decl); addSpecializedForTargetDecorations(irFunc, decl); @@ -10467,16 +10526,6 @@ struct DeclLoweringVisitor : DeclVisitor } } - if (auto diffAttr = decl->findModifier()) - { - if (decl->body) - { - subContext->irBuilder->setInsertInto(irFunc->getParent()); - lowerDifferentiableAttribute(subContext, irFunc, diffAttr); - subContext->irBuilder->setInsertInto(irFunc); - } - } - // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list -- cgit v1.2.3