summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-08-27 21:13:00 -0400
committerGitHub <noreply@github.com>2024-08-27 18:13:00 -0700
commit6bb32aa976494466bd6303f8ae6e348b297edb44 (patch)
tree54765fd5168e1d6590f403c6df04b30404ee6346 /source
parenta9882c648c58e6f2821df11c7ee6ac77d9f09473 (diff)
Adds a warning for using `[PreferRecompute]` on methods that may contain side effects (#4707)
* Adds a warning for using prefer-recompute on methods that contain side effects * Rename `SideEffects` -> `SideEffectBehavior` --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang10
-rw-r--r--source/slang/slang-ast-modifier.h8
-rw-r--r--source/slang/slang-check-modifier.cpp10
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit.cpp7
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-autodiff.cpp60
-rw-r--r--source/slang/slang-ir-autodiff.h7
-rw-r--r--source/slang/slang-lower-to-ir.cpp9
9 files changed, 108 insertions, 7 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 0b57993ef..c1eb2597a 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2884,8 +2884,16 @@ attribute_syntax [payload] : PayloadAttribute;
__attributeTarget(DeclBase)
attribute_syntax [deprecated(message: String)] : DeprecatedAttribute;
+enum SideEffectBehavior
+{
+ /// Causes a warning if the method is detected to have side-effects
+ Warn = 0,
+
+ /// Suppresses the warning
+ Allow = 1
+};
__attributeTarget(FunctionDeclBase)
-attribute_syntax [PreferRecompute] : PreferRecomputeAttribute;
+attribute_syntax[PreferRecompute(behavior: SideEffectBehavior = SideEffectBehavior.Warn)] : PreferRecomputeAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [PreferCheckpoint] : PreferCheckpointAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index c890b874c..14e945e25 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1275,6 +1275,14 @@ class PyExportAttribute : public Attribute
class PreferRecomputeAttribute : public Attribute
{
SLANG_AST_CLASS(PreferRecomputeAttribute)
+
+ enum SideEffectBehavior
+ {
+ Warn = 0,
+ Allow = 1
+ };
+
+ SideEffectBehavior sideEffectBehavior;
};
class PreferCheckpointAttribute : public Attribute
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 483ff0e18..705d0bb3b 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -842,6 +842,16 @@ namespace Slang
else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr))
primalOfAttr->funcExpr = attr->args[0];
}
+ else if (auto preferRecomputeAttr = as<PreferRecomputeAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+ SLANG_ASSERT(as<Decl>(attrTarget));
+
+ auto val = checkConstantIntVal(attr->args[0]);
+ if (!val) return false;
+
+ preferRecomputeAttr->sideEffectBehavior = (PreferRecomputeAttribute::SideEffectBehavior) val->getValue();
+ }
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 7288befe8..e058fcd91 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -785,6 +785,8 @@ DIAGNOSTIC(41904, Error, unableToAlignOf, "alignof could not be performed for ty
DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.")
+DIAGNOSTIC(42050, Warning, potentialIssuesWithPreferRecomputeOnSideEffectMethod, "$0 has [PreferRecompute] and may have side effects. side effects may execute multiple times. use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [NoSideEffect]")
+
DIAGNOSTIC(45001, Error, unresolvedSymbol, "unresolved external symbol '$0'.")
DIAGNOSTIC(41201, Warning, expectDynamicUniformArgument, "argument for '$0' might not be a dynamic uniform, use `asDynamicUniform()` to silence this warning.")
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index e91209108..103cd15ab 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -637,6 +637,13 @@ Result linkAndOptimizeIR(
default:
break;
}
+
+ if (requiredLoweringPassSet.autodiff)
+ {
+ // Generate warnings for potentially incorrect or badly-performing autodiff patterns.
+ checkAutodiffPatterns(targetProgram, irModule, sink);
+ }
+
// Next, we need to ensure that the code we emit for
// the target doesn't contain any operations that would
// be illegal on the target platform. For example,
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 9a69d9aa8..35a197f29 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -147,7 +147,7 @@ namespace Slang
if (!checkpointHint)
checkpointHint = originalFunc->findDecoration<IRCheckpointHintDecoration>();
if (checkpointHint)
- builder->addDecoration(existingPrimalFunc, checkpointHint->getOp());
+ cloneCheckpointHint(builder, checkpointHint, cast<IRGlobalValueWithCode>(existingPrimalFunc));
builder->emitBlock();
params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc));
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index b7c2037e5..bf83d8d7f 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1123,19 +1123,33 @@ IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(
return nullptr;
}
-
void copyCheckpointHints(IRBuilder* builder, IRGlobalValueWithCode* oldInst, IRGlobalValueWithCode* newInst)
{
for (auto decor = oldInst->getFirstDecoration(); decor; decor = decor->getNextDecoration())
{
if (auto chkHint = as<IRCheckpointHintDecoration>(decor))
{
- SLANG_ASSERT(chkHint->getOperandCount() == 0);
- builder->addDecoration(newInst, chkHint->getOp());
+ cloneCheckpointHint(builder, chkHint, newInst);
}
}
}
+void cloneCheckpointHint(IRBuilder* builder, IRCheckpointHintDecoration* chkHint, IRGlobalValueWithCode* target)
+{
+ // Grab all the operands
+ List<IRInst*> operands;
+ for (UCount operand = 0; operand < chkHint->getOperandCount(); operand++)
+ {
+ operands.add(chkHint->getOperand(operand));
+ }
+
+ builder->addDecoration(
+ target,
+ chkHint->getOp(),
+ operands.getBuffer(),
+ operands.getCount());
+}
+
void stripDerivativeDecorations(IRInst* inst)
{
for (auto decor = inst->getFirstDecoration(); decor; )
@@ -2096,6 +2110,46 @@ protected:
};
+void checkAutodiffPatterns(
+ TargetProgram* target,
+ IRModule* module,
+ DiagnosticSink* sink)
+{
+ SLANG_UNUSED(target);
+
+ enum SideEffectBehavior
+ {
+ Warn = 0,
+ Allow = 1,
+ };
+
+ // For now, we have only 1 check to see if methods that have side-effects
+ // are marked with prefer-recompute
+ //
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRFunc>(inst))
+ {
+ if (func->sourceLoc.isValid() && // Don't diagnose for synthesized functions
+ func->findDecoration<IRPreferRecomputeDecoration>() &&
+ !func->findDecoration<IRNoSideEffectDecoration>())
+ {
+ auto preferRecomputeDecor = func->findDecoration<IRPreferRecomputeDecoration>();
+ auto sideEffectBehavior = as<IRIntLit>(preferRecomputeDecor->getOperand(0))->getValue();
+
+ if (sideEffectBehavior == SideEffectBehavior::Allow)
+ continue;
+
+ // Find function name. (don't diagnose on nameless functions)
+ if (auto nameHint = func->findDecoration<IRNameHintDecoration>())
+ {
+ sink->diagnose(func, Diagnostics::potentialIssuesWithPreferRecomputeOnSideEffectMethod, nameHint->getName());
+ }
+ }
+ }
+ }
+}
+
bool processAutodiffCalls(
TargetProgram* target,
IRModule* module,
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 23ae717be..812471fe3 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -363,6 +363,11 @@ struct IRAutodiffPassOptions
// Nothing for now...
};
+void checkAutodiffPatterns(
+ TargetProgram* target,
+ IRModule* module,
+ DiagnosticSink* sink);
+
bool processAutodiffCalls(
TargetProgram* target,
IRModule* module,
@@ -375,6 +380,8 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module);
void copyCheckpointHints(IRBuilder* builder, IRGlobalValueWithCode* oldInst, IRGlobalValueWithCode* newInst);
+void cloneCheckpointHint(IRBuilder* builder, IRCheckpointHintDecoration* oldInst, IRGlobalValueWithCode* code);
+
void stripDerivativeDecorations(IRInst* inst);
bool isBackwardDifferentiableFunc(IRInst* func);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index b9d7a898f..31427e616 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -10203,9 +10203,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
getBuilder()->addDecoration(irFunc, kIROp_PreferCheckpointDecoration);
}
- else if (as<PreferRecomputeAttribute>(modifier))
+ else if (auto attr = as<PreferRecomputeAttribute>(modifier))
{
- getBuilder()->addDecoration(irFunc, kIROp_PreferRecomputeDecoration);
+ getBuilder()->addDecoration(
+ irFunc,
+ kIROp_PreferRecomputeDecoration,
+ getBuilder()->getIntValue(
+ getBuilder()->getIntType(),
+ attr->sideEffectBehavior));
}
else if (auto extensionMod = as<RequiredGLSLExtensionModifier>(modifier))
getBuilder()->addRequireGLSLExtensionDecoration(irFunc, extensionMod->extensionNameToken.getContent());