summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-decl.cpp18
-rw-r--r--source/slang/slang-parser.cpp21
-rw-r--r--tests/autodiff/no-diff-static.slang31
3 files changed, 57 insertions, 13 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 4bfbde584..867c1daad 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -10241,6 +10241,17 @@ void SemanticsDeclHeaderVisitor::checkInterfaceRequirement(Decl* decl)
}
}
+bool doesTypeHaveNoDiffModifier(Type* type)
+{
+ if (auto modifiedType = as<ModifiedType>(type))
+ {
+ if (modifiedType->findModifier<NoDiffModifierVal>() != nullptr)
+ return true;
+ return doesTypeHaveNoDiffModifier(modifiedType->getBase());
+ }
+ return false;
+}
+
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
for (auto paramDecl : decl->getParameters())
@@ -10259,6 +10270,13 @@ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
}
decl->errorType = errorType;
+ if (doesTypeHaveNoDiffModifier(decl->returnType.type))
+ {
+ auto noDiffMod = m_astBuilder->create<NoDiffModifier>();
+ noDiffMod->loc = decl->loc;
+ addModifier(decl, noDiffMod);
+ }
+
checkDifferentiableCallableCommon(decl);
// If this method is intended to be a CUDA kernel, verify that the return type is void.
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 3519b6d43..7208ab67e 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2708,14 +2708,14 @@ static Expr* _applyModifiersToTypeExpr(Parser* parser, Expr* typeExpr, Modifiers
}
}
-/// Apply any type modifier in `ioBaseModifiers` to the given `typeExpr`.
+/// Move any type modifier in `ioBaseModifiers` to the given `typeExpr`.
///
/// If any type modifiers were present, `ioBaseModifiers` will be updated
/// to only include those modifiers that were not type modifiers (if any).
///
/// If no type modifiers were present, `ioBaseModifiers` will remain unchanged.
///
-static Expr* _applyTypeModifiersToTypeExpr(
+static Expr* _moveTypeModifiersToTypeExpr(
Parser* parser,
Expr* typeExpr,
Modifiers& ioBaseModifiers)
@@ -2763,7 +2763,7 @@ static Expr* _applyTypeModifiersToTypeExpr(
// a pointer to the type modifier into the "link" for
// the type modifier list, and updating the link to point
// to the `next` field of the current modifier (since that
- // fill be the location any further type modifiers need
+ // will be the location any further type modifiers need
// to be linked).
//
*typeModifierLink = typeModifier;
@@ -2792,10 +2792,7 @@ static Expr* _applyTypeModifiersToTypeExpr(
return _applyModifiersToTypeExpr(parser, typeExpr, typeModifiers);
}
-static TypeSpec _applyModifiersToTypeSpec(
- Parser* parser,
- TypeSpec typeSpec,
- Modifiers const& inModifiers)
+static TypeSpec _applyModifiersToTypeSpec(Parser* parser, TypeSpec typeSpec, Modifiers& modifiers)
{
// It is possible that the form of the type specifier will have
// included a declaration directly (e.g., using `struct { ... }`
@@ -2809,8 +2806,7 @@ static TypeSpec _applyModifiersToTypeSpec(
// and any modifiers that logically belong to the declaration to
// the declaration.
//
- Modifiers modifiers = inModifiers;
- typeSpec.expr = _applyTypeModifiersToTypeExpr(parser, typeSpec.expr, modifiers);
+ typeSpec.expr = _moveTypeModifiersToTypeExpr(parser, typeSpec.expr, modifiers);
// Any remaining modifiers should instead be applied to the declaration.
_addModifiers(decl, modifiers);
@@ -2821,7 +2817,7 @@ static TypeSpec _applyModifiersToTypeSpec(
// This may result in modifiers being applied that do not belong on a type;
// in that case we rely on downstream semantic checking to diagnose any error.
//
- typeSpec.expr = _applyModifiersToTypeExpr(parser, typeSpec.expr, inModifiers);
+ typeSpec.expr = _applyModifiersToTypeExpr(parser, typeSpec.expr, modifiers);
}
return typeSpec;
@@ -2962,7 +2958,7 @@ static TypeSpec _parseTypeSpec(Parser* parser, Modifiers& ioModifiers)
// or which of them might be type modifiers, so we will delegate
// figuring that out to a subroutine.
//
- typeSpec.expr = _applyTypeModifiersToTypeExpr(parser, typeSpec.expr, ioModifiers);
+ typeSpec.expr = _moveTypeModifiersToTypeExpr(parser, typeSpec.expr, ioModifiers);
return typeSpec;
}
@@ -2992,11 +2988,10 @@ static TypeSpec _parseTypeSpec(Parser* parser)
static DeclBase* ParseDeclaratorDecl(
Parser* parser,
ContainerDecl* containerDecl,
- Modifiers const& inModifiers)
+ Modifiers& modifiers)
{
SourceLoc startPosition = parser->tokenReader.peekLoc();
- Modifiers modifiers = inModifiers;
auto typeSpec = _parseTypeSpec(parser, modifiers);
if (typeSpec.expr == nullptr && typeSpec.decl == nullptr)
diff --git a/tests/autodiff/no-diff-static.slang b/tests/autodiff/no-diff-static.slang
new file mode 100644
index 000000000..15d316a90
--- /dev/null
+++ b/tests/autodiff/no-diff-static.slang
@@ -0,0 +1,31 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-slang -compute -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+// BUF: 1.000000
+// BUF: 1.000000
+
+struct TestStruct : IDifferentiable
+{
+ static const no_diff float foo = 1.0f;
+
+ int x;
+
+ void assignOut()
+ {
+ outputBuffer[0] = foo;
+ }
+}
+no_diff static const float foo = 1.0f;
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain(uint thread_idx: SV_DispatchThreadID)
+{
+ if (thread_idx == 0) {
+ TestStruct t;
+ t.x = 10;
+ t.assignOut();
+ }
+ outputBuffer[1] = foo;
+}