From 5c28677ff8bb1ab498954795ae3907f3b6c3b03f Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 30 May 2023 21:15:55 -0700 Subject: Fix type checking & loop value hoisting (#2907) * Fix type checking crash in language server. * Fix loop var hoisting logic. Fixes #2903. * fix. --------- Co-authored-by: Yong He --- source/slang/slang-check-constraint.cpp | 2 + source/slang/slang-check-overload.cpp | 13 ++-- source/slang/slang-ir-autodiff-cfg-norm.cpp | 30 +++++---- source/slang/slang-ir-validate.cpp | 9 ++- source/slang/slang-stdlib.cpp | 2 +- tests/autodiff/reverse-loop-diff-only-3.slang | 74 ++++++++++++++++++++++ .../reverse-loop-diff-only-3.slang.expected.txt | 6 ++ 7 files changed, 119 insertions(+), 17 deletions(-) create mode 100644 tests/autodiff/reverse-loop-diff-only-3.slang create mode 100644 tests/autodiff/reverse-loop-diff-only-3.slang.expected.txt diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index cdffcf004..42d01a996 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -792,6 +792,8 @@ namespace Slang Type* fst, Type* snd) { + if (!fst) return false; + if (fst->equals(snd)) return true; // An error type can unify with anything, just so we avoid cascading errors. diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 4bd8506ed..0d10b05be 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -409,11 +409,12 @@ namespace Slang auto& arg = context.getArg(ii); auto argType = context.getArgType(ii); auto paramType = paramTypes[ii]; - + if (!paramType) + return false; + if (!argType) + return false; if (context.mode == OverloadResolveContext::Mode::JustTrying) { - SLANG_ASSERT(argType); - ConversionCost cost = kConversionCost_None; if( context.disallowNestedConversions ) { @@ -1656,7 +1657,11 @@ namespace Slang for( UInt aa = 0; aa < argCount; ++aa ) { if(aa != 0) argsListBuilder << ", "; - context.getArgType(aa)->toText(argsListBuilder); + auto argType = context.getArgType(aa); + if (argType) + context.getArgType(aa)->toText(argsListBuilder); + else + argsListBuilder << "error"; } argsListBuilder << ")"; return argsListBuilder.produceString(); diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 1727f7d8f..30c8a934e 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -709,24 +709,30 @@ void normalizeCFG( IRBuilder builder(func); for (auto b : workList) { - for (auto inst : b->getChildren()) + for (auto inst : b->getModifiableChildren()) { // If inst has uses outside the loop body, we need to hoist it. IRVar* tempVar = nullptr; - for (auto use = inst->firstUse; use; use = use->nextUse) + if (auto var = as(inst)) { - auto userBlock = as(use->getUser()->getParent()); - if (userBlock && !workListSet.contains(userBlock)) + for (auto use = inst->firstUse; use; use = use->nextUse) { - // Hoist the inst. - if (auto var = as(inst)) + // If inst is an var, this is easy, we just move it to the + // loop header. + auto userBlock = as(use->getUser()->getParent()); + if (userBlock && !workListSet.contains(userBlock)) { - // If inst is an var, this is easy, we just move it to the - // loop header. var->insertBefore(insertionPoint); break; } - else + } + } + else + { + traverseUses(inst, [&](IRUse* use) + { + auto userBlock = as(use->getUser()->getParent()); + if (userBlock && !workListSet.contains(userBlock)) { // For all other insts, we need to create a local var for it. if (!tempVar) @@ -741,8 +747,7 @@ void normalizeCFG( auto load = builder.emitLoad(tempVar); builder.replaceOperand(use, load); } - break; - } + }); } } } @@ -751,6 +756,9 @@ void normalizeCFG( disableIRValidationAtInsert(); constructSSA(module, func); enableIRValidationAtInsert(); +#if _DEBUG + validateIRInst(func); +#endif } } // namespace Slang diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 18229c9b6..bf1ce1956 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -144,7 +144,14 @@ namespace Slang auto operandParent = operandValue->getParent(); - if (auto instParentBlock = as(instParent)) + auto instParentBlock = as(instParent); + if (!instParentBlock && as(inst)) + { + instParent = instParent->getParent(); + instParentBlock = as(instParent); + } + + if (instParentBlock) { if (auto operandParentBlock = as(operandParent)) { diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp index f80254ba6..1f8d0a97a 100644 --- a/source/slang/slang-stdlib.cpp +++ b/source/slang/slang-stdlib.cpp @@ -221,7 +221,7 @@ namespace Slang static const IntrinsicOpInfo intrinsicUnaryOps[] = { { kIROp_Neg, "neg", "-", "__BuiltinArithmeticType", ARITHMETIC_MASK }, { kIROp_Not, "logicalNot", "!", nullptr, BOOL_MASK | BOOL_RESULT }, - { kIROp_BitNot, "not", "~", "__BuiltinIntegerType", INT_MASK }, + { kIROp_BitNot, "not", "~", "__BuiltinLogicalType", INT_MASK }, }; static const IntrinsicOpInfo intrinsicBinaryOps[] = { diff --git a/tests/autodiff/reverse-loop-diff-only-3.slang b/tests/autodiff/reverse-loop-diff-only-3.slang new file mode 100644 index 000000000..e36b6c09c --- /dev/null +++ b/tests/autodiff/reverse-loop-diff-only-3.slang @@ -0,0 +1,74 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +typedef DifferentialPair dpfloat; +typedef float.Differential dfloat; + +// Test that compute does not have a context. +// CHECK-NOT: struct {{[a-zA-Z0-9_]*}}_compute_{{[a-zA-Z0-9_]*}} + +[BackwardDifferentiable] +[PreferRecompute] +float compute(float x, float y, out float k) +{ + k = y * 2; + return x * y; +} + +// Test that computeLoop compiles to just return 0. +// CHECK: float computeLoop{{[_0-9]*}}(float y{{[_0-9]*}}) +// CHECK-NOT: for{{.*}} +// CHECK: return 0 + +[BackwardDifferentiable] +[PreferRecompute] +float computeLoop(float y) +{ + float w = 0; + int i = 0; + [MaxIters(8)] + do + { + float k = float(0.f); + w += compute(i, y, k); + w += k * k; + i++; + } + while (i < 8); + + return w - detach(w); +} + +// Since computeLoop is recomputed, test_simple_loop should have nothing to store +// therefore we check that there is no intermediate context type generated for test_simple_loop. + +// CHECK-NOT: struct {{[a-zA-Z0-9_]*}}test_simple_loop{{[a-zA-Z0-9_]*}} +[BackwardDifferentiable] +float test_simple_loop(float y) +{ + float x = computeLoop(y); + return y + x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_loop)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 29.0 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_simple_loop)(dpa, 0.5f); + outputBuffer[1] = dpa.d; // Expect: 14.5 + } + + outputBuffer[2] = computeLoop(1.0); +} diff --git a/tests/autodiff/reverse-loop-diff-only-3.slang.expected.txt b/tests/autodiff/reverse-loop-diff-only-3.slang.expected.txt new file mode 100644 index 000000000..fedac0520 --- /dev/null +++ b/tests/autodiff/reverse-loop-diff-only-3.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +93.000000 +27.300001 +0.000000 +0.000000 +0.000000 -- cgit v1.2.3