diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-08-27 21:13:00 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-27 18:13:00 -0700 |
| commit | 6bb32aa976494466bd6303f8ae6e348b297edb44 (patch) | |
| tree | 54765fd5168e1d6590f403c6df04b30404ee6346 /source | |
| parent | a9882c648c58e6f2821df11c7ee6ac77d9f09473 (diff) | |
Adds a warning for using `[PreferRecompute]` on methods that may contain side effects (#4707)
* Adds a warning for using prefer-recompute on methods that contain side effects
* Rename `SideEffects` -> `SideEffectBehavior`
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 10 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 60 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 9 |
9 files changed, 108 insertions, 7 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 0b57993ef..c1eb2597a 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2884,8 +2884,16 @@ attribute_syntax [payload] : PayloadAttribute; __attributeTarget(DeclBase) attribute_syntax [deprecated(message: String)] : DeprecatedAttribute; +enum SideEffectBehavior +{ + /// Causes a warning if the method is detected to have side-effects + Warn = 0, + + /// Suppresses the warning + Allow = 1 +}; __attributeTarget(FunctionDeclBase) -attribute_syntax [PreferRecompute] : PreferRecomputeAttribute; +attribute_syntax[PreferRecompute(behavior: SideEffectBehavior = SideEffectBehavior.Warn)] : PreferRecomputeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [PreferCheckpoint] : PreferCheckpointAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index c890b874c..14e945e25 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1275,6 +1275,14 @@ class PyExportAttribute : public Attribute class PreferRecomputeAttribute : public Attribute { SLANG_AST_CLASS(PreferRecomputeAttribute) + + enum SideEffectBehavior + { + Warn = 0, + Allow = 1 + }; + + SideEffectBehavior sideEffectBehavior; }; class PreferCheckpointAttribute : public Attribute diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 483ff0e18..705d0bb3b 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -842,6 +842,16 @@ namespace Slang else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr)) primalOfAttr->funcExpr = attr->args[0]; } + else if (auto preferRecomputeAttr = as<PreferRecomputeAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + SLANG_ASSERT(as<Decl>(attrTarget)); + + auto val = checkConstantIntVal(attr->args[0]); + if (!val) return false; + + preferRecomputeAttr->sideEffectBehavior = (PreferRecomputeAttribute::SideEffectBehavior) val->getValue(); + } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 7288befe8..e058fcd91 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -785,6 +785,8 @@ DIAGNOSTIC(41904, Error, unableToAlignOf, "alignof could not be performed for ty DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.") +DIAGNOSTIC(42050, Warning, potentialIssuesWithPreferRecomputeOnSideEffectMethod, "$0 has [PreferRecompute] and may have side effects. side effects may execute multiple times. use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [NoSideEffect]") + DIAGNOSTIC(45001, Error, unresolvedSymbol, "unresolved external symbol '$0'.") DIAGNOSTIC(41201, Warning, expectDynamicUniformArgument, "argument for '$0' might not be a dynamic uniform, use `asDynamicUniform()` to silence this warning.") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index e91209108..103cd15ab 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -637,6 +637,13 @@ Result linkAndOptimizeIR( default: break; } + + if (requiredLoweringPassSet.autodiff) + { + // Generate warnings for potentially incorrect or badly-performing autodiff patterns. + checkAutodiffPatterns(targetProgram, irModule, sink); + } + // Next, we need to ensure that the code we emit for // the target doesn't contain any operations that would // be illegal on the target platform. For example, diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 9a69d9aa8..35a197f29 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -147,7 +147,7 @@ namespace Slang if (!checkpointHint) checkpointHint = originalFunc->findDecoration<IRCheckpointHintDecoration>(); if (checkpointHint) - builder->addDecoration(existingPrimalFunc, checkpointHint->getOp()); + cloneCheckpointHint(builder, checkpointHint, cast<IRGlobalValueWithCode>(existingPrimalFunc)); builder->emitBlock(); params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc)); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index b7c2037e5..bf83d8d7f 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1123,19 +1123,33 @@ IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness( return nullptr; } - void copyCheckpointHints(IRBuilder* builder, IRGlobalValueWithCode* oldInst, IRGlobalValueWithCode* newInst) { for (auto decor = oldInst->getFirstDecoration(); decor; decor = decor->getNextDecoration()) { if (auto chkHint = as<IRCheckpointHintDecoration>(decor)) { - SLANG_ASSERT(chkHint->getOperandCount() == 0); - builder->addDecoration(newInst, chkHint->getOp()); + cloneCheckpointHint(builder, chkHint, newInst); } } } +void cloneCheckpointHint(IRBuilder* builder, IRCheckpointHintDecoration* chkHint, IRGlobalValueWithCode* target) +{ + // Grab all the operands + List<IRInst*> operands; + for (UCount operand = 0; operand < chkHint->getOperandCount(); operand++) + { + operands.add(chkHint->getOperand(operand)); + } + + builder->addDecoration( + target, + chkHint->getOp(), + operands.getBuffer(), + operands.getCount()); +} + void stripDerivativeDecorations(IRInst* inst) { for (auto decor = inst->getFirstDecoration(); decor; ) @@ -2096,6 +2110,46 @@ protected: }; +void checkAutodiffPatterns( + TargetProgram* target, + IRModule* module, + DiagnosticSink* sink) +{ + SLANG_UNUSED(target); + + enum SideEffectBehavior + { + Warn = 0, + Allow = 1, + }; + + // For now, we have only 1 check to see if methods that have side-effects + // are marked with prefer-recompute + // + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(inst)) + { + if (func->sourceLoc.isValid() && // Don't diagnose for synthesized functions + func->findDecoration<IRPreferRecomputeDecoration>() && + !func->findDecoration<IRNoSideEffectDecoration>()) + { + auto preferRecomputeDecor = func->findDecoration<IRPreferRecomputeDecoration>(); + auto sideEffectBehavior = as<IRIntLit>(preferRecomputeDecor->getOperand(0))->getValue(); + + if (sideEffectBehavior == SideEffectBehavior::Allow) + continue; + + // Find function name. (don't diagnose on nameless functions) + if (auto nameHint = func->findDecoration<IRNameHintDecoration>()) + { + sink->diagnose(func, Diagnostics::potentialIssuesWithPreferRecomputeOnSideEffectMethod, nameHint->getName()); + } + } + } + } +} + bool processAutodiffCalls( TargetProgram* target, IRModule* module, diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 23ae717be..812471fe3 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -363,6 +363,11 @@ struct IRAutodiffPassOptions // Nothing for now... }; +void checkAutodiffPatterns( + TargetProgram* target, + IRModule* module, + DiagnosticSink* sink); + bool processAutodiffCalls( TargetProgram* target, IRModule* module, @@ -375,6 +380,8 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module); void copyCheckpointHints(IRBuilder* builder, IRGlobalValueWithCode* oldInst, IRGlobalValueWithCode* newInst); +void cloneCheckpointHint(IRBuilder* builder, IRCheckpointHintDecoration* oldInst, IRGlobalValueWithCode* code); + void stripDerivativeDecorations(IRInst* inst); bool isBackwardDifferentiableFunc(IRInst* func); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b9d7a898f..31427e616 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10203,9 +10203,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addDecoration(irFunc, kIROp_PreferCheckpointDecoration); } - else if (as<PreferRecomputeAttribute>(modifier)) + else if (auto attr = as<PreferRecomputeAttribute>(modifier)) { - getBuilder()->addDecoration(irFunc, kIROp_PreferRecomputeDecoration); + getBuilder()->addDecoration( + irFunc, + kIROp_PreferRecomputeDecoration, + getBuilder()->getIntValue( + getBuilder()->getIntType(), + attr->sideEffectBehavior)); } else if (auto extensionMod = as<RequiredGLSLExtensionModifier>(modifier)) getBuilder()->addRequireGLSLExtensionDecoration(irFunc, extensionMod->extensionNameToken.getContent()); |
