summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp75
1 files changed, 62 insertions, 13 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 011ea6bc7..e82fc03fd 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -592,11 +592,21 @@ struct IRGenContext
// The element index if we are inside an `expand` expression.
IRInst* expandIndex = nullptr;
+ // Callback function to call when after lowering a type.
+ std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> lowerTypeCallback =
+ nullptr;
+
explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder)
: shared(inShared), astBuilder(inAstBuilder), env(&inShared->globalEnv), irBuilder(nullptr)
{
}
+ void registerTypeCallback(
+ std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> callback)
+ {
+ lowerTypeCallback = callback;
+ }
+
void setGlobalValue(Decl* decl, LoweredValInfo value) { shared->setGlobalValue(decl, value); }
void setValue(Decl* decl, LoweredValInfo value) { env->mapDeclToValue[decl] = value; }
@@ -2202,7 +2212,12 @@ IRType* lowerType(IRGenContext* context, Type* type)
{
ValLoweringVisitor visitor;
visitor.context = context;
- return (IRType*)getSimpleVal(context, visitor.dispatchType(type));
+ IRType* loweredType = (IRType*)getSimpleVal(context, visitor.dispatchType(type));
+
+ if (context->lowerTypeCallback && loweredType)
+ context->lowerTypeCallback(context, type, loweredType);
+
+ return loweredType;
}
void addVarDecorations(IRGenContext* context, IRInst* inst, Decl* decl)
@@ -8105,6 +8120,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subContextStorage.thisTypeWitness = outerContext->thisTypeWitness;
subContextStorage.returnDestination = LoweredValInfo();
+ subContextStorage.lowerTypeCallback = nullptr;
}
IRBuilder* getBuilder() { return &subBuilderStorage; }
@@ -8629,7 +8645,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric);
// Add `irInterface` to decl mapping now to prevent cyclic lowering.
- context->setValue(decl, LoweredValInfo::simple(finalVal));
+ context->setGlobalValue(decl, LoweredValInfo::simple(finalVal));
subBuilder->setInsertBefore(irInterface);
@@ -8783,7 +8799,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
-
addNameHint(context, irInterface, decl);
addLinkageDecoration(context, irInterface, decl);
if (auto anyValueSizeAttr = decl->findModifier<AnyValueSizeAttribute>())
@@ -9910,6 +9925,48 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
else
outerGeneric = emitOuterGenerics(subContext, decl, decl);
+ // If our function is differentiable, register a callback so the derivative
+ // annotations for types can be lowered.
+ //
+ if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
+ {
+ auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness();
+ OrderedDictionary<DeclRefBase*, SubtypeWitness*> resolveddiffTypeWitnessMap;
+
+ // Go through each entry in the map and resolve the key.
+ for (auto& entry : diffTypeWitnessMap)
+ {
+ auto resolvedKey = as<DeclRefBase>(entry.key->resolve());
+ resolveddiffTypeWitnessMap[resolvedKey] =
+ as<SubtypeWitness>(as<Val>(entry.value)->resolve());
+ }
+
+ subContext->registerTypeCallback(
+ [=](IRGenContext* context, Type* type, IRType* irType)
+ {
+ if (!as<DeclRefType>(type))
+ return irType;
+
+ DeclRefBase* declRefBase = as<DeclRefType>(type)->getDeclRefBase();
+ if (resolveddiffTypeWitnessMap.containsKey(declRefBase))
+ {
+ auto irWitness =
+ lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val;
+ if (irWitness)
+ {
+ IRInst* args[] = {irType, irWitness};
+ context->irBuilder->emitIntrinsicInst(
+ context->irBuilder->getVoidType(),
+ kIROp_DifferentiableTypeAnnotation,
+ 2,
+ args);
+ }
+ }
+
+ return irType;
+ });
+ }
+
FuncDeclBaseTypeInfo info;
_lowerFuncDeclBaseTypeInfo(
subContext,
@@ -10220,6 +10277,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
+ subContext->registerTypeCallback(nullptr);
+
getBuilder()->addHighLevelDeclDecoration(irFunc, decl);
addSpecializedForTargetDecorations(irFunc, decl);
@@ -10467,16 +10526,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
- if (auto diffAttr = decl->findModifier<DifferentiableAttribute>())
- {
- if (decl->body)
- {
- subContext->irBuilder->setInsertInto(irFunc->getParent());
- lowerDifferentiableAttribute(subContext, irFunc, diffAttr);
- subContext->irBuilder->setInsertInto(irFunc);
- }
- }
-
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list