diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-08-27 21:13:00 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-27 18:13:00 -0700 |
| commit | 6bb32aa976494466bd6303f8ae6e348b297edb44 (patch) | |
| tree | 54765fd5168e1d6590f403c6df04b30404ee6346 | |
| parent | a9882c648c58e6f2821df11c7ee6ac77d9f09473 (diff) | |
Adds a warning for using `[PreferRecompute]` on methods that may contain side effects (#4707)
* Adds a warning for using prefer-recompute on methods that contain side effects
* Rename `SideEffects` -> `SideEffectBehavior`
---------
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/core.meta.slang | 10 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 60 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 9 | ||||
| -rw-r--r-- | tests/autodiff/warn-on-prefer-recompute-side-effects.slang | 47 |
10 files changed, 155 insertions, 7 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 0b57993ef..c1eb2597a 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2884,8 +2884,16 @@ attribute_syntax [payload] : PayloadAttribute; __attributeTarget(DeclBase) attribute_syntax [deprecated(message: String)] : DeprecatedAttribute; +enum SideEffectBehavior +{ + /// Causes a warning if the method is detected to have side-effects + Warn = 0, + + /// Suppresses the warning + Allow = 1 +}; __attributeTarget(FunctionDeclBase) -attribute_syntax [PreferRecompute] : PreferRecomputeAttribute; +attribute_syntax[PreferRecompute(behavior: SideEffectBehavior = SideEffectBehavior.Warn)] : PreferRecomputeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [PreferCheckpoint] : PreferCheckpointAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index c890b874c..14e945e25 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1275,6 +1275,14 @@ class PyExportAttribute : public Attribute class PreferRecomputeAttribute : public Attribute { SLANG_AST_CLASS(PreferRecomputeAttribute) + + enum SideEffectBehavior + { + Warn = 0, + Allow = 1 + }; + + SideEffectBehavior sideEffectBehavior; }; class PreferCheckpointAttribute : public Attribute diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 483ff0e18..705d0bb3b 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -842,6 +842,16 @@ namespace Slang else if (auto primalOfAttr = as<PrimalSubstituteOfAttribute>(attr)) primalOfAttr->funcExpr = attr->args[0]; } + else if (auto preferRecomputeAttr = as<PreferRecomputeAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + SLANG_ASSERT(as<Decl>(attrTarget)); + + auto val = checkConstantIntVal(attr->args[0]); + if (!val) return false; + + preferRecomputeAttr->sideEffectBehavior = (PreferRecomputeAttribute::SideEffectBehavior) val->getValue(); + } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 7288befe8..e058fcd91 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -785,6 +785,8 @@ DIAGNOSTIC(41904, Error, unableToAlignOf, "alignof could not be performed for ty DIAGNOSTIC(42001, Error, invalidUseOfTorchTensorTypeInDeviceFunc, "invalid use of TorchTensor type in device/kernel functions. use `TensorView` instead.") +DIAGNOSTIC(42050, Warning, potentialIssuesWithPreferRecomputeOnSideEffectMethod, "$0 has [PreferRecompute] and may have side effects. side effects may execute multiple times. use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [NoSideEffect]") + DIAGNOSTIC(45001, Error, unresolvedSymbol, "unresolved external symbol '$0'.") DIAGNOSTIC(41201, Warning, expectDynamicUniformArgument, "argument for '$0' might not be a dynamic uniform, use `asDynamicUniform()` to silence this warning.") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index e91209108..103cd15ab 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -637,6 +637,13 @@ Result linkAndOptimizeIR( default: break; } + + if (requiredLoweringPassSet.autodiff) + { + // Generate warnings for potentially incorrect or badly-performing autodiff patterns. + checkAutodiffPatterns(targetProgram, irModule, sink); + } + // Next, we need to ensure that the code we emit for // the target doesn't contain any operations that would // be illegal on the target platform. For example, diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 9a69d9aa8..35a197f29 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -147,7 +147,7 @@ namespace Slang if (!checkpointHint) checkpointHint = originalFunc->findDecoration<IRCheckpointHintDecoration>(); if (checkpointHint) - builder->addDecoration(existingPrimalFunc, checkpointHint->getOp()); + cloneCheckpointHint(builder, checkpointHint, cast<IRGlobalValueWithCode>(existingPrimalFunc)); builder->emitBlock(); params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc)); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index b7c2037e5..bf83d8d7f 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1123,19 +1123,33 @@ IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness( return nullptr; } - void copyCheckpointHints(IRBuilder* builder, IRGlobalValueWithCode* oldInst, IRGlobalValueWithCode* newInst) { for (auto decor = oldInst->getFirstDecoration(); decor; decor = decor->getNextDecoration()) { if (auto chkHint = as<IRCheckpointHintDecoration>(decor)) { - SLANG_ASSERT(chkHint->getOperandCount() == 0); - builder->addDecoration(newInst, chkHint->getOp()); + cloneCheckpointHint(builder, chkHint, newInst); } } } +void cloneCheckpointHint(IRBuilder* builder, IRCheckpointHintDecoration* chkHint, IRGlobalValueWithCode* target) +{ + // Grab all the operands + List<IRInst*> operands; + for (UCount operand = 0; operand < chkHint->getOperandCount(); operand++) + { + operands.add(chkHint->getOperand(operand)); + } + + builder->addDecoration( + target, + chkHint->getOp(), + operands.getBuffer(), + operands.getCount()); +} + void stripDerivativeDecorations(IRInst* inst) { for (auto decor = inst->getFirstDecoration(); decor; ) @@ -2096,6 +2110,46 @@ protected: }; +void checkAutodiffPatterns( + TargetProgram* target, + IRModule* module, + DiagnosticSink* sink) +{ + SLANG_UNUSED(target); + + enum SideEffectBehavior + { + Warn = 0, + Allow = 1, + }; + + // For now, we have only 1 check to see if methods that have side-effects + // are marked with prefer-recompute + // + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as<IRFunc>(inst)) + { + if (func->sourceLoc.isValid() && // Don't diagnose for synthesized functions + func->findDecoration<IRPreferRecomputeDecoration>() && + !func->findDecoration<IRNoSideEffectDecoration>()) + { + auto preferRecomputeDecor = func->findDecoration<IRPreferRecomputeDecoration>(); + auto sideEffectBehavior = as<IRIntLit>(preferRecomputeDecor->getOperand(0))->getValue(); + + if (sideEffectBehavior == SideEffectBehavior::Allow) + continue; + + // Find function name. (don't diagnose on nameless functions) + if (auto nameHint = func->findDecoration<IRNameHintDecoration>()) + { + sink->diagnose(func, Diagnostics::potentialIssuesWithPreferRecomputeOnSideEffectMethod, nameHint->getName()); + } + } + } + } +} + bool processAutodiffCalls( TargetProgram* target, IRModule* module, diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 23ae717be..812471fe3 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -363,6 +363,11 @@ struct IRAutodiffPassOptions // Nothing for now... }; +void checkAutodiffPatterns( + TargetProgram* target, + IRModule* module, + DiagnosticSink* sink); + bool processAutodiffCalls( TargetProgram* target, IRModule* module, @@ -375,6 +380,8 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module); void copyCheckpointHints(IRBuilder* builder, IRGlobalValueWithCode* oldInst, IRGlobalValueWithCode* newInst); +void cloneCheckpointHint(IRBuilder* builder, IRCheckpointHintDecoration* oldInst, IRGlobalValueWithCode* code); + void stripDerivativeDecorations(IRInst* inst); bool isBackwardDifferentiableFunc(IRInst* func); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b9d7a898f..31427e616 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10203,9 +10203,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addDecoration(irFunc, kIROp_PreferCheckpointDecoration); } - else if (as<PreferRecomputeAttribute>(modifier)) + else if (auto attr = as<PreferRecomputeAttribute>(modifier)) { - getBuilder()->addDecoration(irFunc, kIROp_PreferRecomputeDecoration); + getBuilder()->addDecoration( + irFunc, + kIROp_PreferRecomputeDecoration, + getBuilder()->getIntValue( + getBuilder()->getIntType(), + attr->sideEffectBehavior)); } else if (auto extensionMod = as<RequiredGLSLExtensionModifier>(modifier)) getBuilder()->addRequireGLSLExtensionDecoration(irFunc, extensionMod->extensionNameToken.getContent()); diff --git a/tests/autodiff/warn-on-prefer-recompute-side-effects.slang b/tests/autodiff/warn-on-prefer-recompute-side-effects.slang new file mode 100644 index 000000000..38543e67f --- /dev/null +++ b/tests/autodiff/warn-on-prefer-recompute-side-effects.slang @@ -0,0 +1,47 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +groupshared float s_shared; + +[BackwardDifferentiable] +[PreferRecompute] +float get_thread_5_value(float v, uint group_thread_id) +{ + if(group_thread_id == 5) + { + s_shared = detach(v); + // CHECK: tests/autodiff/warn-on-prefer-recompute-side-effects.slang(10): warning 42050: get_thread_5_value has [PreferRecompute] and may have side effects. side effects may execute multiple times. use [PreferRecompute(SideEffectBehavior.Allow)], or mark function with [NoSideEffect] + // CHECK: float get_thread_5_value(float v, uint group_thread_id) + // CHECK: ^~~~~~~~~~~~~~~~~~ + } + GroupMemoryBarrierWithGroupSync(); + return s_shared; +} + +[BackwardDifferentiable] +[PreferRecompute(SideEffectBehavior.Allow)] // Suppress warning here +float get_thread_6_value(float v, uint group_thread_id) +{ + if (group_thread_id == 6) + { + s_shared = detach(v); + // CHECK-NOT: warning 42050 + + } + GroupMemoryBarrierWithGroupSync(); + return s_shared; +} + +[shader("compute")] +[numthreads(128, 1, 1)] +void computeMain(uint3 group_thread_id: SV_GroupThreadID, uint3 dispatch_thread_id: SV_DispatchThreadID) +{ + DifferentialPair<float> value = diffPair(3.f, 0.f); + + bwd_diff(get_thread_5_value)(value, group_thread_id.x, 1.0f); + bwd_diff(get_thread_6_value)(value, group_thread_id.x, 1.0f); + + outputBuffer[dispatch_thread_id.x] = value.d; +}
\ No newline at end of file |
