summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-modifier.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-modifier.cpp')
-rw-r--r--source/slang/slang-check-modifier.cpp30
1 files changed, 10 insertions, 20 deletions
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index e6a524645..a068f19d6 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -654,33 +654,23 @@ namespace Slang
hitObjectAttributesAttr->location = (int32_t)val->value;
}
- else if (auto derivativeAttr = as<UserDefinedDerivativeAttribute>(attr))
+ else if (as<UserDefinedDerivativeAttribute>(attr) || as<PrimalSubstituteAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
SLANG_ASSERT(as<Decl>(attrTarget));
-
- // Ensure that the argument is a reference to a function definition or declaration.
- auto diffExpr = CheckTerm(attr->args[0]);
- if (diffExpr->type == getASTBuilder()->getErrorType())
- {
- // Could not resolve the term.
- getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
- return false;
- }
- // We store the partially checked funcExpr in the attribute, and
- // rely on `ResolveInvoke` to resolve it to the actual function decl.
- // The call to `ResolveInvoke` is deferred until we are checking the
- // body of the function.
- //
- // Set type to null to indicate that this needs expr needs to be further resolved.
- diffExpr->type.type = nullptr;
- derivativeAttr->funcExpr = diffExpr;
+ if (auto derivativeAttr = as<UserDefinedDerivativeAttribute>(attr))
+ derivativeAttr->funcExpr = attr->args[0];
+ else if (auto primalSubstAttr = as<PrimalSubstituteAttribute>(attr))
+ primalSubstAttr->funcExpr = attr->args[0];
}
- else if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr))
+ else if (as<DerivativeOfAttribute>(attr) || as<PrimalSubstituteOfAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
SLANG_ASSERT(as<Decl>(attrTarget));
- derivativeOfAttr->funcExpr = attr->args[0];
+ if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr))
+ derivativeOfAttr->funcExpr = attr->args[0];
+ else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr))
+ primalOfAttr->funcExpr = attr->args[0];
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{