summaryrefslogtreecommitdiff
path: root/source/slang/slang-lower-to-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-08 21:52:34 -0800
committerGitHub <noreply@github.com>2023-03-08 21:52:34 -0800
commit86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch)
treeb4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang/slang-lower-to-ir.cpp
parent257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff)
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-lower-to-ir.cpp')
-rw-r--r--source/slang/slang-lower-to-ir.cpp357
1 files changed, 190 insertions, 167 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 261e08168..d8912cbd4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3165,6 +3165,17 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
baseVal.val));
}
+ LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr)
+ {
+ auto baseVal = lowerSubExpr(expr->baseFunction);
+ SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitPrimalSubstituteInst(
+ lowerType(context, expr->type),
+ baseVal.val));
+ }
+
LoweredValInfo visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
{
auto baseVal = lowerSubExpr(expr->innerExpr);
@@ -7970,14 +7981,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
addNameHint(subContext, irFunc, decl);
addLinkageDecoration(subContext, irFunc, decl);
- if (decl->findModifier<ForwardDifferentiableAttribute>())
- {
- getBuilder()->addForwardDifferentiableDecoration(irFunc);
- }
- if (decl->findModifier<BackwardDifferentiableAttribute>())
- {
- getBuilder()->addBackwardDifferentiableDecoration(irFunc);
- }
if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>())
{
lowerDifferentiableAttribute(subContext, irFunc, differentialAttr);
@@ -8291,156 +8294,156 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addRequireCUDASMVersionDecoration(irFunc, versionMod->version);
}
- if (decl->findModifier<RequiresNVAPIAttribute>())
- {
- getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc);
- }
-
- if (decl->findModifier<AlwaysFoldIntoUseSiteAttribute>())
- {
- getBuilder()->addSimpleDecoration<IRAlwaysFoldIntoUseSiteDecoration>(irFunc);
- }
-
- if (decl->findModifier<NoInlineAttribute>())
- {
- getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc);
- }
-
- if (auto attr = decl->findModifier<InstanceAttribute>())
- {
- IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit);
- }
-
- if (auto attr = decl->findModifier<MaxVertexCountAttribute>())
- {
- IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit);
- }
-
- if (auto attr = decl->findModifier<NumThreadsAttribute>())
- {
- auto builder = getBuilder();
- IRType* intType = builder->getIntType();
-
- IRInst* operands[3] = {
- builder->getIntValue(intType, attr->x),
- builder->getIntValue(intType, attr->y),
- builder->getIntValue(intType, attr->z)
- };
-
- builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3);
- }
-
- if (decl->findModifier<ReadNoneAttribute>())
- {
- getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc);
- }
-
- if (decl->findModifier<EarlyDepthStencilAttribute>())
- {
- getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc);
- }
-
- if (auto attr = decl->findModifier<DomainAttribute>())
- {
- IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit);
- }
+ // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
+ setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
- if (auto attr = decl->findModifier<PartitioningAttribute>())
+ for (auto modifier : decl->modifiers)
{
- IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit);
- }
+ if (as<RequiresNVAPIAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc);
+ }
+ else if (as<AlwaysFoldIntoUseSiteAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IRAlwaysFoldIntoUseSiteDecoration>(irFunc);
+ }
+ else if (as<NoInlineAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc);
+ }
+ else if (auto instanceAttr = as<InstanceAttribute>(modifier))
+ {
+ IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), instanceAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit);
+ }
+ else if (auto maxVertCountAttr = as<MaxVertexCountAttribute>(modifier))
+ {
+ IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), maxVertCountAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit);
+ }
+ else if (auto numThreadsAttr = as<NumThreadsAttribute>(modifier))
+ {
+ auto builder = getBuilder();
+ IRType* intType = builder->getIntType();
- if (auto attr = decl->findModifier<OutputTopologyAttribute>())
- {
- IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit);
- }
+ IRInst* operands[3] = {
+ builder->getIntValue(intType, numThreadsAttr->x),
+ builder->getIntValue(intType, numThreadsAttr->y),
+ builder->getIntValue(intType, numThreadsAttr->z)
+ };
- if (auto attr = decl->findModifier<OutputControlPointsAttribute>())
- {
- IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
- getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit);
- }
+ builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3);
+ }
+ else if (as<ReadNoneAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc);
+ }
+ else if (as<EarlyDepthStencilAttribute>(modifier))
+ {
+ getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc);
+ }
+ else if (auto domainAttr = as<DomainAttribute>(modifier))
+ {
+ IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), domainAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_DomainDecoration, stringLit);
+ }
+ else if (auto partitionAttr = as<PartitioningAttribute>(modifier))
+ {
+ IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), partitionAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_PartitioningDecoration, stringLit);
+ }
+ else if (auto outputTopAttr = as<OutputTopologyAttribute>(modifier))
+ {
+ IRStringLit* stringLit = _getStringLitFromAttribute(getBuilder(), outputTopAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_OutputTopologyDecoration, stringLit);
+ }
+ else if (auto outputCtrlPtAttr = as<OutputControlPointsAttribute>(modifier))
+ {
+ IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), outputCtrlPtAttr);
+ getBuilder()->addDecoration(irFunc, kIROp_OutputControlPointsDecoration, intLit);
+ }
+ else if (auto spvInstOpAttr = as<SPIRVInstructionOpAttribute>(modifier))
+ {
+ auto builder = getBuilder();
+ IRIntLit* intLit = _getIntLitFromAttribute(builder, spvInstOpAttr, 0);
- if (auto attr = decl->findModifier<SPIRVInstructionOpAttribute>())
- {
- auto builder = getBuilder();
- IRIntLit* intLit = _getIntLitFromAttribute(builder, attr, 0);
+ IRStringLit* setStringLit = nullptr;
+ if (spvInstOpAttr->args.getCount() > 1)
+ {
+ IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, spvInstOpAttr, 1);
+ if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0)
+ {
+ setStringLit = checkSetStringLit;
+ }
+ }
- IRStringLit* setStringLit = nullptr;
- if (attr->args.getCount() > 1)
- {
- IRStringLit* checkSetStringLit = _getStringLitFromAttribute(builder, attr, 1);
- if (checkSetStringLit && checkSetStringLit->getStringSlice().getLength() > 0)
+ // If it has a `set` defined, set it on the decoration
+ if (setStringLit)
{
- setStringLit = checkSetStringLit;
+ builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit);
+ }
+ else
+ {
+ builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit);
}
}
-
- // If it has a `set` defined, set it on the decoration
- if (setStringLit)
+ else if (as<UnsafeForceInlineEarlyAttribute>(modifier))
{
- builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit, setStringLit);
+ getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration);
}
- else
+ else if (as<ForceInlineAttribute>(modifier))
{
- builder->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit);
+ getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
}
- }
-
- if (decl->findModifier<UnsafeForceInlineEarlyAttribute>())
- {
- getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration);
- }
-
- if (decl->findModifier<ForceInlineAttribute>())
- {
- getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
- }
-
- if (decl->findModifier<TreatAsDifferentiableAttribute>())
- {
- getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration);
- }
-
- if (auto intrinsicOp = decl->findModifier<IntrinsicOpModifier>())
- {
- auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op);
- getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op);
- }
-
- // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
- setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
-
- if (auto attr = decl->findModifier<UserDefinedDerivativeAttribute>())
- {
- // We need to lower the decl ref to the custom derivative function to IR.
- // The IR insts correspond to the decl ref is not part of the function we
- // are processing. If we emit it directly to within the function, it could
- // mess up the assumption on the form of the IR (e.g. having non decoration insts
- // appearing in the middle of decoration insts). so we emit the decl ref to the
- // function's parent for now.
-
- subContext->irBuilder->setInsertInto(irFunc->getParent());
-
- auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr);
-
- SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
- IRInst* derivativeFunc = loweredVal.val;
-
- if (as<ForwardDerivativeAttribute>(attr))
- getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc);
- else
- getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc);
+ else if (as<TreatAsDifferentiableAttribute>(modifier))
+ {
+ getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration);
+ }
+ else if (auto intrinsicOp = as<IntrinsicOpModifier>(modifier))
+ {
+ auto op = getBuilder()->getIntValue(getBuilder()->getIntType(), intrinsicOp->op);
+ getBuilder()->addDecoration(irFunc, kIROp_IntrinsicOpDecoration, op);
+ }
+ else if (as<UserDefinedDerivativeAttribute>(modifier) || as<PrimalSubstituteAttribute>(modifier))
+ {
+ // We need to lower the decl ref to the custom derivative function to IR.
+ // The IR insts correspond to the decl ref is not part of the function we
+ // are processing. If we emit it directly to within the function, it could
+ // mess up the assumption on the form of the IR (e.g. having non decoration insts
+ // appearing in the middle of decoration insts). so we emit the decl ref to the
+ // function's parent for now.
+
+ subContext->irBuilder->setInsertInto(irFunc->getParent());
+ Expr* funcExpr = nullptr;
+ if (auto udAttr = as<UserDefinedDerivativeAttribute>(modifier))
+ funcExpr = udAttr->funcExpr;
+ else if (auto primalAttr = as<PrimalSubstituteAttribute>(modifier))
+ funcExpr = primalAttr->funcExpr;
+
+ auto loweredVal = lowerRValueExpr(subContext, funcExpr);
+
+ SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
+ IRInst* derivativeFunc = loweredVal.val;
+
+ if (as<ForwardDerivativeAttribute>(modifier))
+ getBuilder()->addForwardDerivativeDecoration(irFunc, derivativeFunc);
+ else if (as<BackwardDerivativeAttribute>(modifier))
+ getBuilder()->addUserDefinedBackwardDerivativeDecoration(irFunc, derivativeFunc);
+ else
+ getBuilder()->addPrimalSubstituteDecoration(irFunc, derivativeFunc);
- // Reset cursor.
- subContext->irBuilder->setInsertInto(irFunc);
+ // Reset cursor.
+ subContext->irBuilder->setInsertInto(irFunc);
+ }
+ else if (as<ForwardDifferentiableAttribute>(modifier))
+ {
+ getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ }
+ else if (as<BackwardDifferentiableAttribute>(modifier))
+ {
+ getBuilder()->addBackwardDifferentiableDecoration(irFunc);
+ }
}
-
// For convenience, ensure that any additional global
// values that were emitted while outputting the function
// body appear before the function itself in the list
@@ -8451,39 +8454,59 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// the interface's type definition.
auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric);
- if (auto attr = decl->findModifier<DerivativeOfAttribute>())
+ for (auto modifier : decl->modifiers)
{
- if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr))
+ if (as<DerivativeOfAttribute>(modifier) || as<PrimalSubstituteOfAttribute>(modifier))
{
- NestedContext originalContextFunc(this);
- auto originalSubBuilder = originalContextFunc.getBuilder();
- auto originalSubContext = originalContextFunc.getContext();
- if (auto outterGeneric = getOuterGeneric(irFunc))
- originalSubBuilder->setInsertBefore(outterGeneric);
- else
- originalSubBuilder->setInsertBefore(irFunc);
- auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl());
- SLANG_RELEASE_ASSERT(originalFuncDecl);
-
- auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val;
- if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal))
+ Expr* funcExpr = nullptr;
+ Expr* backDeclRef = nullptr;
+ if (auto attr = as<DerivativeOfAttribute>(modifier))
{
- originalFuncVal = findGenericReturnVal(originalFuncGeneric);
+ funcExpr = attr->funcExpr;
+ backDeclRef = attr->backDeclRef;
}
- originalSubBuilder->setInsertBefore(originalFuncVal);
- auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef);
- if (as<ForwardDerivativeOfAttribute>(attr))
+ else if (auto primalAttr = as<PrimalSubstituteOfAttribute>(modifier))
{
- originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
- getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ funcExpr = primalAttr->funcExpr;
+ backDeclRef = primalAttr->backDeclRef;
}
- else
+
+ if (auto originalDeclRefExpr = as<DeclRefExpr>(funcExpr))
{
- originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ NestedContext originalContextFunc(this);
+ auto originalSubBuilder = originalContextFunc.getBuilder();
+ auto originalSubContext = originalContextFunc.getContext();
+ if (auto outterGeneric = getOuterGeneric(irFunc))
+ originalSubBuilder->setInsertBefore(outterGeneric);
+ else
+ originalSubBuilder->setInsertBefore(irFunc);
+ auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl());
+ SLANG_RELEASE_ASSERT(originalFuncDecl);
+
+ auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val;
+ if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal))
+ {
+ originalFuncVal = findGenericReturnVal(originalFuncGeneric);
+ }
+ originalSubBuilder->setInsertBefore(originalFuncVal);
+ auto derivativeFuncVal = lowerRValueExpr(originalSubContext, backDeclRef);
+ if (as<ForwardDerivativeOfAttribute>(modifier))
+ {
+ originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ getBuilder()->addForwardDifferentiableDecoration(irFunc);
+ }
+ else if (as<BackwardDerivativeOfAttribute>(modifier))
+ {
+ originalSubBuilder->addUserDefinedBackwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ }
+ else
+ {
+ originalSubBuilder->addPrimalSubstituteDecoration(originalFuncVal, derivativeFuncVal.val);
+ }
}
+ subContext->irBuilder->setInsertInto(irFunc);
+ finalVal->moveToEnd();
}
- subContext->irBuilder->setInsertInto(irFunc);
- finalVal->moveToEnd();
}
return LoweredValInfo::simple(finalVal);
}