diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-08 21:52:34 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-08 21:52:34 -0800 |
| commit | 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch) | |
| tree | b4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang/slang-lower-to-ir.cpp | |
| parent | 257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff) | |
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`.
* Fix
* Fix.
* Cleanup.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 357 |
1 files changed, 190 insertions, 167 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 261e08168..d8912cbd4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3165,6 +3165,17 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> baseVal.val)); } + LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) + { + auto baseVal = lowerSubExpr(expr->baseFunction); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + + return LoweredValInfo::simple( + getBuilder()->emitPrimalSubstituteInst( + lowerType(context, expr->type), + baseVal.val)); + } + LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) { auto baseVal = lowerSubExpr(expr->innerExpr); @@ -7970,14 +7981,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(subContext, irFunc, decl); addLinkageDecoration(subContext, irFunc, decl); - if (decl->findModifier<ForwardDifferentiableAttribute>()) - { - getBuilder()->addForwardDifferentiableDecoration(irFunc); - } - if (decl->findModifier<BackwardDifferentiableAttribute>()) - { - getBuilder()->addBackwardDifferentiableDecoration(irFunc); - } if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>()) { lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); @@ -8291,156 +8294,156 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addRequireCUDASMVersionDecoration(irFunc, versionMod->version); } - if (decl->findModifier<RequiresNVAPIAttribute>()) - { - getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc); - } - - if (decl->findModifier<AlwaysFoldIntoUseSiteAttribute>()) - { - getBuilder()->addSimpleDecoration<IRAlwaysFoldIntoUseSiteDecoration>(irFunc); - } - - if (decl->findModifier<NoInlineAttribute>()) - { - getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc); - } - - if (auto attr = decl->findModifier<InstanceAttribute>()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); - } - - if (auto attr = decl->findModifier<MaxVertexCountAttribute>()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); - } - - if (auto attr = decl->findModifier<NumThreadsAttribute>()) - { - auto builder = getBuilder(); - IRType* intType = builder->getIntType(); - - IRInst* operands[3] = { - builder->getIntValue(intType, attr->x), - builder->getIntValue(intType, attr->y), - builder->getIntValue(intType, attr->z) - }; - - builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); - } - - if (decl->findModifier<ReadNoneAttribute>()) - { - getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc); - } - - if (decl->findModifier<EarlyDepthStencilAttribute>()) - { - getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc); - } - - if (auto attr = decl->findModifier<DomainAttribute>()) - { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit); - } + // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute + setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - if (auto attr = decl->findModifier<PartitioningAttribute>()) + for (auto modifier : decl->modifiers) { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit); - } + if (as<RequiresNVAPIAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc); + } + else if (as<AlwaysFoldIntoUseSiteAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IRAlwaysFoldIntoUseSiteDecoration>(irFunc); + } + else if (as<NoInlineAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc); + } + else if (auto instanceAttr = as<InstanceAttribute>(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), instanceAttr); + getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); + } + else if (auto maxVertCountAttr = as<MaxVertexCountAttribute>(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), maxVertCountAttr); + getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); + } + else if (auto numThreadsAttr = as<NumThreadsAttribute>(modifier)) + { + auto builder = getBuilder(); + IRType* intType = builder->getIntType(); - if (auto attr = decl->findModifier<OutputTopologyAttribute>()) - { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit); - } + IRInst* operands[3] = { + builder->getIntValue(intType, numThreadsAttr->x), + builder->getIntValue(intType, numThreadsAttr->y), + builder->getIntValue(intType, numThreadsAttr->z) + }; - if (auto attr = decl->findModifier<OutputControlPointsAttribute>()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit); - } + builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); + } + else if (as<ReadNoneAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc); + } + else if (as<EarlyDepthStencilAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc); + } + else if (auto domainAttr = as<DomainAttribute>(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), domainAttr); + getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit); + } + else if (auto partitionAttr = as<PartitioningAttribute>(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), partitionAttr); + getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit); + } + else if (auto outputTopAttr = as<OutputTopologyAttribute>(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), outputTopAttr); + getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit); + } + else if (auto outputCtrlPtAttr = as<OutputControlPointsAttribute>(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), outputCtrlPtAttr); + getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit); + } + else if (auto spvInstOpAttr = as<SPIRVInstructionOpAttribute>(modifier)) + { + auto builder = getBuilder(); + IRIntLit* intLit = _getIntLitFromAttribute(builder, spvInstOpAttr, 0); - if (auto attr = decl->findModifier<SPIRVInstructionOpAttribute>()) - { - auto builder = getBuilder(); - IRIntLit* intLit = _getIntLitFromAttribute(builder, attr, 0); + IRStringLit* setStringLit = nullptr; + if (spvInstOpAttr->args.getCount() > 1) + { + IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, spvInstOpAttr, 1); + if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0) + { + setStringLit = checkSetStringLit; + } + } - IRStringLit* setStringLit = nullptr; - if (attr->args.getCount() > 1) - { - IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, attr, 1); - if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0) + // If it has a `set` defined, set it on the decoration + if (setStringLit) { - setStringLit = checkSetStringLit; + builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit); + } + else + { + builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit); } } - - // If it has a `set` defined, set it on the decoration - if (setStringLit) + else if (as<UnsafeForceInlineEarlyAttribute>(modifier)) { - builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit); + getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); } - else + else if (as<ForceInlineAttribute>(modifier)) { - builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit); + getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); } - } - - if (decl->findModifier<UnsafeForceInlineEarlyAttribute>()) - { - getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); - } - - if (decl->findModifier<ForceInlineAttribute>()) - { - getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); - } - - if (decl->findModifier<TreatAsDifferentiableAttribute>()) - { - getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); - } - - if (auto intrinsicOp = decl->findModifier<IntrinsicOpModifier>()) - { - auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op); - getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op); - } - - // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute - setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - - if (auto attr = decl->findModifier<UserDefinedDerivativeAttribute>()) - { - // We need to lower the decl ref to the custom derivative function to IR. - // The IR insts correspond to the decl ref is not part of the function we - // are processing. If we emit it directly to within the function, it could - // mess up the assumption on the form of the IR (e.g. having non decoration insts - // appearing in the middle of decoration insts). so we emit the decl ref to the - // function's parent for now. - - subContext->irBuilder->setInsertInto(irFunc->getParent()); - - auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr); - - SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - IRInst* derivativeFunc = loweredVal.val; - - if (as<ForwardDerivativeAttribute>(attr)) - getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); - else - getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); + else if (as<TreatAsDifferentiableAttribute>(modifier)) + { + getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); + } + else if (auto intrinsicOp = as<IntrinsicOpModifier>(modifier)) + { + auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op); + getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op); + } + else if (as<UserDefinedDerivativeAttribute>(modifier) || as<PrimalSubstituteAttribute>(modifier)) + { + // We need to lower the decl ref to the custom derivative function to IR. + // The IR insts correspond to the decl ref is not part of the function we + // are processing. If we emit it directly to within the function, it could + // mess up the assumption on the form of the IR (e.g. having non decoration insts + // appearing in the middle of decoration insts). so we emit the decl ref to the + // function's parent for now. + + subContext->irBuilder->setInsertInto(irFunc->getParent()); + Expr* funcExpr = nullptr; + if (auto udAttr = as<UserDefinedDerivativeAttribute>(modifier)) + funcExpr = udAttr->funcExpr; + else if (auto primalAttr = as<PrimalSubstituteAttribute>(modifier)) + funcExpr = primalAttr->funcExpr; + + auto loweredVal = lowerRValueExpr(subContext, funcExpr); + + SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); + IRInst* derivativeFunc = loweredVal.val; + + if (as<ForwardDerivativeAttribute>(modifier)) + getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); + else if (as<BackwardDerivativeAttribute>(modifier)) + getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); + else + getBuilder()->addPrimalSubstituteDecoration(irFunc, derivativeFunc); - // Reset cursor. - subContext->irBuilder->setInsertInto(irFunc); + // Reset cursor. + subContext->irBuilder->setInsertInto(irFunc); + } + else if (as<ForwardDifferentiableAttribute>(modifier)) + { + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else if (as<BackwardDifferentiableAttribute>(modifier)) + { + getBuilder()->addBackwardDifferentiableDecoration(irFunc); + } } - // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list @@ -8451,39 +8454,59 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // the interface's type definition. auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric); - if (auto attr = decl->findModifier<DerivativeOfAttribute>()) + for (auto modifier : decl->modifiers) { - if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr)) + if (as<DerivativeOfAttribute>(modifier) || as<PrimalSubstituteOfAttribute>(modifier)) { - NestedContext originalContextFunc(this); - auto originalSubBuilder = originalContextFunc.getBuilder(); - auto originalSubContext = originalContextFunc.getContext(); - if (auto outterGeneric = getOuterGeneric(irFunc)) - originalSubBuilder->setInsertBefore(outterGeneric); - else - originalSubBuilder->setInsertBefore(irFunc); - auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl()); - SLANG_RELEASE_ASSERT(originalFuncDecl); - - auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val; - if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal)) + Expr* funcExpr = nullptr; + Expr* backDeclRef = nullptr; + if (auto attr = as<DerivativeOfAttribute>(modifier)) { - originalFuncVal = findGenericReturnVal(originalFuncGeneric); + funcExpr = attr->funcExpr; + backDeclRef = attr->backDeclRef; } - originalSubBuilder->setInsertBefore(originalFuncVal); - auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef); - if (as<ForwardDerivativeOfAttribute>(attr)) + else if (auto primalAttr = as<PrimalSubstituteOfAttribute>(modifier)) { - originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); - getBuilder()->addForwardDifferentiableDecoration(irFunc); + funcExpr = primalAttr->funcExpr; + backDeclRef = primalAttr->backDeclRef; } - else + + if (auto originalDeclRefExpr = as<DeclRefExpr>(funcExpr)) { - originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + NestedContext originalContextFunc(this); + auto originalSubBuilder = originalContextFunc.getBuilder(); + auto originalSubContext = originalContextFunc.getContext(); + if (auto outterGeneric = getOuterGeneric(irFunc)) + originalSubBuilder->setInsertBefore(outterGeneric); + else + originalSubBuilder->setInsertBefore(irFunc); + auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl()); + SLANG_RELEASE_ASSERT(originalFuncDecl); + + auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val; + if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal)) + { + originalFuncVal = findGenericReturnVal(originalFuncGeneric); + } + originalSubBuilder->setInsertBefore(originalFuncVal); + auto derivativeFuncVal = lowerRValueExpr(originalSubContext, backDeclRef); + if (as<ForwardDerivativeOfAttribute>(modifier)) + { + originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else if (as<BackwardDerivativeOfAttribute>(modifier)) + { + originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + } + else + { + originalSubBuilder->addPrimalSubstituteDecoration(originalFuncVal, derivativeFuncVal.val); + } } + subContext->irBuilder->setInsertInto(irFunc); + finalVal->moveToEnd(); } - subContext->irBuilder->setInsertInto(irFunc); - finalVal->moveToEnd(); } return LoweredValInfo::simple(finalVal); } |
