From 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 8 Mar 2023 21:52:34 -0800 Subject: Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691) * Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He --- source/slang/slang-lower-to-ir.cpp | 357 ++++++++++++++++++++----------------- 1 file changed, 190 insertions(+), 167 deletions(-) (limited to 'source/slang/slang-lower-to-ir.cpp') diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 261e08168..d8912cbd4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3165,6 +3165,17 @@ struct ExprLoweringVisitorBase : ExprVisitor baseVal.val)); } + LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) + { + auto baseVal = lowerSubExpr(expr->baseFunction); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + + return LoweredValInfo::simple( + getBuilder()->emitPrimalSubstituteInst( + lowerType(context, expr->type), + baseVal.val)); + } + LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) { auto baseVal = lowerSubExpr(expr->innerExpr); @@ -7970,14 +7981,6 @@ struct DeclLoweringVisitor : DeclVisitor addNameHint(subContext, irFunc, decl); addLinkageDecoration(subContext, irFunc, decl); - if (decl->findModifier()) - { - getBuilder()->addForwardDifferentiableDecoration(irFunc); - } - if (decl->findModifier()) - { - getBuilder()->addBackwardDifferentiableDecoration(irFunc); - } if (auto differentialAttr = decl->findModifier()) { lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); @@ -8291,156 +8294,156 @@ struct DeclLoweringVisitor : DeclVisitor getBuilder()->addRequireCUDASMVersionDecoration(irFunc, versionMod->version); } - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (auto attr = decl->findModifier()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); - } - - if (auto attr = decl->findModifier()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); - } - - if (auto attr = decl->findModifier()) - { - auto builder = getBuilder(); - IRType* intType = builder->getIntType(); - - IRInst* operands[3] = { - builder->getIntValue(intType, attr->x), - builder->getIntValue(intType, attr->y), - builder->getIntValue(intType, attr->z) - }; - - builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); - } - - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (decl->findModifier()) - { - getBuilder()->addSimpleDecoration(irFunc); - } - - if (auto attr = decl->findModifier()) - { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit); - } + // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute + setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - if (auto attr = decl->findModifier()) + for (auto modifier : decl->modifiers) { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit); - } + if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (auto instanceAttr = as(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), instanceAttr); + getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); + } + else if (auto maxVertCountAttr = as(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), maxVertCountAttr); + getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); + } + else if (auto numThreadsAttr = as(modifier)) + { + auto builder = getBuilder(); + IRType* intType = builder->getIntType(); - if (auto attr = decl->findModifier()) - { - IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit); - } + IRInst* operands[3] = { + builder->getIntValue(intType, numThreadsAttr->x), + builder->getIntValue(intType, numThreadsAttr->y), + builder->getIntValue(intType, numThreadsAttr->z) + }; - if (auto attr = decl->findModifier()) - { - IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); - getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit); - } + builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); + } + else if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addSimpleDecoration(irFunc); + } + else if (auto domainAttr = as(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), domainAttr); + getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit); + } + else if (auto partitionAttr = as(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), partitionAttr); + getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit); + } + else if (auto outputTopAttr = as(modifier)) + { + IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), outputTopAttr); + getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit); + } + else if (auto outputCtrlPtAttr = as(modifier)) + { + IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), outputCtrlPtAttr); + getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit); + } + else if (auto spvInstOpAttr = as(modifier)) + { + auto builder = getBuilder(); + IRIntLit* intLit = _getIntLitFromAttribute(builder, spvInstOpAttr, 0); - if (auto attr = decl->findModifier()) - { - auto builder = getBuilder(); - IRIntLit* intLit = _getIntLitFromAttribute(builder, attr, 0); + IRStringLit* setStringLit = nullptr; + if (spvInstOpAttr->args.getCount() > 1) + { + IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, spvInstOpAttr, 1); + if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0) + { + setStringLit = checkSetStringLit; + } + } - IRStringLit* setStringLit = nullptr; - if (attr->args.getCount() > 1) - { - IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, attr, 1); - if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0) + // If it has a `set` defined, set it on the decoration + if (setStringLit) { - setStringLit = checkSetStringLit; + builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit); + } + else + { + builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit); } } - - // If it has a `set` defined, set it on the decoration - if (setStringLit) + else if (as(modifier)) { - builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit); + getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); } - else + else if (as(modifier)) { - builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit); + getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); } - } - - if (decl->findModifier()) - { - getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); - } - - if (decl->findModifier()) - { - getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); - } - - if (decl->findModifier()) - { - getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); - } - - if (auto intrinsicOp = decl->findModifier()) - { - auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op); - getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op); - } - - // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute - setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); - - if (auto attr = decl->findModifier()) - { - // We need to lower the decl ref to the custom derivative function to IR. - // The IR insts correspond to the decl ref is not part of the function we - // are processing. If we emit it directly to within the function, it could - // mess up the assumption on the form of the IR (e.g. having non decoration insts - // appearing in the middle of decoration insts). so we emit the decl ref to the - // function's parent for now. - - subContext->irBuilder->setInsertInto(irFunc->getParent()); - - auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr); - - SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - IRInst* derivativeFunc = loweredVal.val; - - if (as(attr)) - getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); - else - getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); + else if (as(modifier)) + { + getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); + } + else if (auto intrinsicOp = as(modifier)) + { + auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op); + getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op); + } + else if (as(modifier) || as(modifier)) + { + // We need to lower the decl ref to the custom derivative function to IR. + // The IR insts correspond to the decl ref is not part of the function we + // are processing. If we emit it directly to within the function, it could + // mess up the assumption on the form of the IR (e.g. having non decoration insts + // appearing in the middle of decoration insts). so we emit the decl ref to the + // function's parent for now. + + subContext->irBuilder->setInsertInto(irFunc->getParent()); + Expr* funcExpr = nullptr; + if (auto udAttr = as(modifier)) + funcExpr = udAttr->funcExpr; + else if (auto primalAttr = as(modifier)) + funcExpr = primalAttr->funcExpr; + + auto loweredVal = lowerRValueExpr(subContext, funcExpr); + + SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); + IRInst* derivativeFunc = loweredVal.val; + + if (as(modifier)) + getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc); + else if (as(modifier)) + getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc); + else + getBuilder()->addPrimalSubstituteDecoration(irFunc, derivativeFunc); - // Reset cursor. - subContext->irBuilder->setInsertInto(irFunc); + // Reset cursor. + subContext->irBuilder->setInsertInto(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else if (as(modifier)) + { + getBuilder()->addBackwardDifferentiableDecoration(irFunc); + } } - // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list @@ -8451,39 +8454,59 @@ struct DeclLoweringVisitor : DeclVisitor // the interface's type definition. auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric); - if (auto attr = decl->findModifier()) + for (auto modifier : decl->modifiers) { - if (auto originalDeclRefExpr = as(attr->funcExpr)) + if (as(modifier) || as(modifier)) { - NestedContext originalContextFunc(this); - auto originalSubBuilder = originalContextFunc.getBuilder(); - auto originalSubContext = originalContextFunc.getContext(); - if (auto outterGeneric = getOuterGeneric(irFunc)) - originalSubBuilder->setInsertBefore(outterGeneric); - else - originalSubBuilder->setInsertBefore(irFunc); - auto originalFuncDecl = as(originalDeclRefExpr->declRef.getDecl()); - SLANG_RELEASE_ASSERT(originalFuncDecl); - - auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val; - if (auto originalFuncGeneric = as(originalFuncVal)) + Expr* funcExpr = nullptr; + Expr* backDeclRef = nullptr; + if (auto attr = as(modifier)) { - originalFuncVal = findGenericReturnVal(originalFuncGeneric); + funcExpr = attr->funcExpr; + backDeclRef = attr->backDeclRef; } - originalSubBuilder->setInsertBefore(originalFuncVal); - auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef); - if (as(attr)) + else if (auto primalAttr = as(modifier)) { - originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); - getBuilder()->addForwardDifferentiableDecoration(irFunc); + funcExpr = primalAttr->funcExpr; + backDeclRef = primalAttr->backDeclRef; } - else + + if (auto originalDeclRefExpr = as(funcExpr)) { - originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + NestedContext originalContextFunc(this); + auto originalSubBuilder = originalContextFunc.getBuilder(); + auto originalSubContext = originalContextFunc.getContext(); + if (auto outterGeneric = getOuterGeneric(irFunc)) + originalSubBuilder->setInsertBefore(outterGeneric); + else + originalSubBuilder->setInsertBefore(irFunc); + auto originalFuncDecl = as(originalDeclRefExpr->declRef.getDecl()); + SLANG_RELEASE_ASSERT(originalFuncDecl); + + auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val; + if (auto originalFuncGeneric = as(originalFuncVal)) + { + originalFuncVal = findGenericReturnVal(originalFuncGeneric); + } + originalSubBuilder->setInsertBefore(originalFuncVal); + auto derivativeFuncVal = lowerRValueExpr(originalSubContext, backDeclRef); + if (as(modifier)) + { + originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + getBuilder()->addForwardDifferentiableDecoration(irFunc); + } + else if (as(modifier)) + { + originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + } + else + { + originalSubBuilder->addPrimalSubstituteDecoration(originalFuncVal, derivativeFuncVal.val); + } } + subContext->irBuilder->setInsertInto(irFunc); + finalVal->moveToEnd(); } - subContext->irBuilder->setInsertInto(irFunc); - finalVal->moveToEnd(); } return LoweredValInfo::simple(finalVal); } -- cgit v1.2.3