diff options
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 75 |
1 files changed, 62 insertions, 13 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 011ea6bc7..e82fc03fd 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -592,11 +592,21 @@ struct IRGenContext // The element index if we are inside an `expand` expression. IRInst* expandIndex = nullptr; + // Callback function to call when after lowering a type. + std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> lowerTypeCallback = + nullptr; + explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder) : shared(inShared), astBuilder(inAstBuilder), env(&inShared->globalEnv), irBuilder(nullptr) { } + void registerTypeCallback( + std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> callback) + { + lowerTypeCallback = callback; + } + void setGlobalValue(Decl* decl, LoweredValInfo value) { shared->setGlobalValue(decl, value); } void setValue(Decl* decl, LoweredValInfo value) { env->mapDeclToValue[decl] = value; } @@ -2202,7 +2212,12 @@ IRType* lowerType(IRGenContext* context, Type* type) { ValLoweringVisitor visitor; visitor.context = context; - return (IRType*)getSimpleVal(context, visitor.dispatchType(type)); + IRType* loweredType = (IRType*)getSimpleVal(context, visitor.dispatchType(type)); + + if (context->lowerTypeCallback && loweredType) + context->lowerTypeCallback(context, type, loweredType); + + return loweredType; } void addVarDecorations(IRGenContext* context, IRInst* inst, Decl* decl) @@ -8105,6 +8120,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> subContextStorage.thisTypeWitness = outerContext->thisTypeWitness; subContextStorage.returnDestination = LoweredValInfo(); + subContextStorage.lowerTypeCallback = nullptr; } IRBuilder* getBuilder() { return &subBuilderStorage; } @@ -8629,7 +8645,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto finalVal = finishOuterGenerics(subBuilder, irInterface, outerGeneric); // Add `irInterface` to decl mapping now to prevent cyclic lowering. - context->setValue(decl, LoweredValInfo::simple(finalVal)); + context->setGlobalValue(decl, LoweredValInfo::simple(finalVal)); subBuilder->setInsertBefore(irInterface); @@ -8783,7 +8799,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } - addNameHint(context, irInterface, decl); addLinkageDecoration(context, irInterface, decl); if (auto anyValueSizeAttr = decl->findModifier<AnyValueSizeAttribute>()) @@ -9910,6 +9925,48 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> else outerGeneric = emitOuterGenerics(subContext, decl, decl); + // If our function is differentiable, register a callback so the derivative + // annotations for types can be lowered. + // + if (auto diffAttr = decl->findModifier<DifferentiableAttribute>()) + { + auto diffTypeWitnessMap = diffAttr->getMapTypeToIDifferentiableWitness(); + OrderedDictionary<DeclRefBase*, SubtypeWitness*> resolveddiffTypeWitnessMap; + + // Go through each entry in the map and resolve the key. + for (auto& entry : diffTypeWitnessMap) + { + auto resolvedKey = as<DeclRefBase>(entry.key->resolve()); + resolveddiffTypeWitnessMap[resolvedKey] = + as<SubtypeWitness>(as<Val>(entry.value)->resolve()); + } + + subContext->registerTypeCallback( + [=](IRGenContext* context, Type* type, IRType* irType) + { + if (!as<DeclRefType>(type)) + return irType; + + DeclRefBase* declRefBase = as<DeclRefType>(type)->getDeclRefBase(); + if (resolveddiffTypeWitnessMap.containsKey(declRefBase)) + { + auto irWitness = + lowerVal(subContext, resolveddiffTypeWitnessMap[declRefBase]).val; + if (irWitness) + { + IRInst* args[] = {irType, irWitness}; + context->irBuilder->emitIntrinsicInst( + context->irBuilder->getVoidType(), + kIROp_DifferentiableTypeAnnotation, + 2, + args); + } + } + + return irType; + }); + } + FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, @@ -10220,6 +10277,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } + subContext->registerTypeCallback(nullptr); + getBuilder()->addHighLevelDeclDecoration(irFunc, decl); addSpecializedForTargetDecorations(irFunc, decl); @@ -10467,16 +10526,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } - if (auto diffAttr = decl->findModifier<DifferentiableAttribute>()) - { - if (decl->body) - { - subContext->irBuilder->setInsertInto(irFunc->getParent()); - lowerDifferentiableAttribute(subContext, irFunc, diffAttr); - subContext->irBuilder->setInsertInto(irFunc); - } - } - // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list |
