diff options
| -rw-r--r-- | source/slang/core.meta.slang | 4 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 68 | ||||
| -rw-r--r-- | source/slang/slang-ir-inline.cpp | 23 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 19 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-sqrt.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt | 2 |
6 files changed, 68 insertions, 50 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index c45ad5bd6..9b2932446 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -87,6 +87,9 @@ interface __BuiltinArithmeticType : __BuiltinType { /// Initialize from a 32-bit signed integer value. __init(int value); + + /// Initialize from the same type. + __init(This value); } /// A type that can be used for logical/bitwise operations @@ -382,7 +385,6 @@ ${{{{ } }}}} { - ${{{{ // Declare initializers to convert from various other types for (int ss = 0; ss < kBaseTypeCount; ++ss) diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 2bdaccee3..cb87156f5 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -740,6 +740,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve [ForwardDerivativeOf(NAME)] \ DifferentialPair<T> __d_##NAME(DifferentialPair<T> dpx) \ { \ + typealias ReturnType = T; \ return DifferentialPair<T>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ @@ -747,40 +748,29 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve [ForwardDerivativeOf(NAME)] \ DifferentialPair<vector<T, N>> __d_##NAME##_vector(DifferentialPair<vector<T, N>> dpx) \ { \ - vector<T, N> result; \ - vector<T, N>.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < N; ++i) \ - { \ - DifferentialPair<T> dp_elem = __d_##NAME( \ - DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]))); \ - result[i] = dp_elem.p; \ - d_result[i] = __slang_noop_cast<T>(dp_elem.d); \ - } \ - return DifferentialPair<vector<T, N>>(result, d_result); \ + typealias ReturnType = vector<T, N>; \ + return DifferentialPair<ReturnType>(NAME(dpx.p), FWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ [BackwardDifferentiable] \ [ForwardDerivativeOf(NAME)] \ - DifferentialPair<matrix<T, M, N>> 
__d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpx) \ + DifferentialPair<matrix<T, M, N>> __d_##NAME##_m(DifferentialPair<matrix<T, M, N>> dpm) \ { \ - matrix<T, M, N> result; \ - matrix<T, M, N>.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < M; ++i) \ - [ForceUnroll] for (int j = 0; j < N; ++j) \ + typealias ReturnType = vector<T,N>; \ + matrix<T,M,N>.Differential diff; \ + [ForceUnroll] for (int i = 0; i < M; i++) \ { \ - DifferentialPair<T> dp_elem = __d_##NAME( \ - DifferentialPair<T>(dpx.p[i][j], \ - __slang_noop_cast<T.Differential>(dpx.d[i][j]))); \ - result[i][j] = dp_elem.p; \ - d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \ + var dpx = diffPair(dpm.p[i], dpm.d[i]); \ + diff[i] = FWD_DIFF_FUNC; \ } \ - return DifferentialPair<matrix<T, M, N>>(result, d_result); \ + return diffPair(NAME(dpm.p), diff); \ } \ __generic<T : __BuiltinFloatingPointType> \ [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME(inout DifferentialPair<T> dpx, T.Differential dOut) \ { \ + typealias ReturnType = T; \ dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let N : int> \ @@ -789,32 +779,26 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve void __d_##NAME##_vector( \ inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) \ { \ - vector<T, N>.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < N; ++i) \ - { \ - DifferentialPair<T> dp_elem = diffPair(dpx.p[i], T.dzero()); \ - __d_##NAME(dp_elem, __slang_noop_cast<T.Differential>(dOut[i])); \ - d_result[i] = __slang_noop_cast<T>(dp_elem.d); \ - } \ - dpx = diffPair(dpx.p, d_result); \ + typealias ReturnType = vector<T, N>; \ + dpx = diffPair(dpx.p, BWD_DIFF_FUNC); \ } \ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \ [BackwardDifferentiable] \ [BackwardDerivativeOf(NAME)] \ void __d_##NAME##_matrix( \ - inout DifferentialPair<matrix<T, M, N>> dpx, matrix<T, M, 
N>.Differential dOut) \ + inout DifferentialPair<matrix<T, M, N>> m, matrix<T, M, N>.Differential mdOut) \ { \ - matrix<T, M, N>.Differential d_result; \ - [ForceUnroll] for (int i = 0; i < M; ++i) \ - [ForceUnroll] for (int j = 0; j < N; ++j) \ + typealias ReturnType = vector<T, N>; \ + matrix<T, M, N>.Differential diff; \ + [ForceUnroll] for (int i = 0; i < M; i++) \ { \ - DifferentialPair<T> dp_elem = diffPair(dpx.p[i][j], T.dzero()); \ - __d_##NAME(dp_elem, __slang_noop_cast<T.Differential>(dOut[i][j])); \ - d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \ + var dpx = diffPair(m.p[i], m.d[i]); \ + var dOut = mdOut[i]; \ + diff[i] = BWD_DIFF_FUNC; \ } \ - dpx = diffPair(dpx.p, d_result); \ + m = diffPair(m.p, diff); \ } -#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, T.dmul(DIFF_FUNC, dpx.d), T.dmul(DIFF_FUNC, dOut)) +#define SIMPLE_UNARY_DERIVATIVE_IMPL(NAME, DIFF_FUNC) UNARY_DERIVATIVE_IMPL(NAME, ReturnType.dmul(DIFF_FUNC, dpx.d), ReturnType.dmul(DIFF_FUNC, dOut)) // Detach and set derivatives to zero __generic<T : IDifferentiable> @@ -824,9 +808,9 @@ T detach(T x); #define SLANG_SQR(x) ((x)*(x)) // Absolute value -UNARY_DERIVATIVE_IMPL(abs, (dpx.p > T(0.0) ? dpx.d : T.dmul(T(-1.0), dpx.d)), (T.dmul(__slang_noop_cast<T>(sign(dpx.p)), dOut))) +UNARY_DERIVATIVE_IMPL(abs, select(dpx.p > T(0.0), dpx.d, ReturnType.dmul(T(-1.0), dpx.d)), (ReturnType.dmul(__slang_noop_cast<ReturnType>(sign(dpx.p)), dOut))) // Saturate -UNARY_DERIVATIVE_IMPL(saturate, (dpx.p < T(0.0) || dpx.p > T(1.0) ? T.dzero() : dpx.d), (dpx.p < T(0.0) || dpx.p > T(1.0) ? 
T.dzero() : dOut)) +UNARY_DERIVATIVE_IMPL(saturate, select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dpx.d), select(dpx.p < T(0.0) || dpx.p > T(1.0), ReturnType.dzero(), dOut)) // frac UNARY_DERIVATIVE_IMPL(frac, dpx.d, dOut) // raidans, degrees @@ -849,9 +833,9 @@ SIMPLE_UNARY_DERIVATIVE_IMPL(log, T(1.0) / dpx.p) SIMPLE_UNARY_DERIVATIVE_IMPL(log10, T(1.0) / (dpx.p * T(52.3025850929940456840179914546844))) SIMPLE_UNARY_DERIVATIVE_IMPL(log2, T(1.0) / (dpx.p * T(50.69314718055994530941723212145818))) // Square root -SIMPLE_UNARY_DERIVATIVE_IMPL(sqrt, (dpx.p < T(1e-7) ? T(0.0) : T(0.5) / sqrt(dpx.p))) +SIMPLE_UNARY_DERIVATIVE_IMPL(sqrt, T(0.5) / sqrt(max(ReturnType(T(1e-7)), dpx.p))) // Reciprocal -SIMPLE_UNARY_DERIVATIVE_IMPL(rcp, (dpx.p < T(1e-7) ? T(0.0) : T(-1.0) / (dpx.p * dpx.p))) +SIMPLE_UNARY_DERIVATIVE_IMPL(rcp, T(-1.0) / max(ReturnType(T(1e-7)), dpx.p * dpx.p)) // rsqrt SIMPLE_UNARY_DERIVATIVE_IMPL(rsqrt, T(-0.5) / (dpx.p * sqrt(dpx.p))) // Arc-sin diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index d8486d7aa..5223e35cf 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -49,13 +49,26 @@ struct InliningPassBase changed = considerCallSite(call); } - // Note: we defensively iterate through the child instructions - // so that even if `child` gets removed (because of inlining) - // we automatically start at the next instruction after it. + // Note: we iterate until no more changes can be applied. + // This is defensive against changes made by inlining one callsite + // and makes sure we get to process all callsites. // - for (auto child : inst->getModifiableChildren()) + for (;;) { - changed |= considerAllCallSitesRec(child); + bool changedInThisIteration = false; + // Note: getModifiableChildren will skip any insts that are no + // longer the child of `inst`. 
If we process one callsite, the + // remaining insts of the block will be moved into a different + // block and therefore we won't process them during this iteration. + // However, those callsites will eventually be processed + // by the outer loop. + for (auto child : inst->getModifiableChildren()) + { + changedInThisIteration = considerAllCallSitesRec(child); + changed |= changedInThisIteration; + } + if (!changedInThisIteration) + break; } return changed; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 3affcff44..f0c30dd3c 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6168,6 +6168,22 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> #undef IGNORED_CASE + void ensureInsertAtGlobalScope(IRBuilder* builder) + { + auto inst = builder->getInsertLoc().getInst(); + if (inst->getOp() == kIROp_Module) + return; + + while (inst && inst->getParent() && inst->getParent()->getOp() != kIROp_Module) + { + inst = inst->getParent(); + } + if (inst) + { + builder->setInsertBefore(inst); + } + } + LoweredValInfo visitTypeDefDecl(TypeDefDecl* decl) { // A type alias declaration may be generic, if it is @@ -6176,6 +6192,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> NestedContext nested(this); auto subBuilder = nested.getBuilder(); auto subContext = nested.getContext(); + + ensureInsertAtGlobalScope(nested.getBuilder()); + IRGeneric* outerGeneric = emitOuterGenerics(subContext, decl, decl); // TODO: if a type alias declaration can have linkage, diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang b/tests/autodiff-dstdlib/dstdlib-sqrt.slang index d68a2697c..ee3fb94b7 100644 --- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang +++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang @@ -41,7 +41,7 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) dpfloat dpx = dpfloat(0.0, 1.0); dpfloat res = __fwd_diff(diffSqrt)(dpx); 
outputBuffer[4] = res.p; // Expect: 0.000000 - outputBuffer[5] = res.d; // Expect: 0.000000 + outputBuffer[5] = res.d; // Expect: 1581.138916 } { diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt index 7e0fdf02f..7b38d1520 100644 --- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt +++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt @@ -4,7 +4,7 @@ type: float 10.000000 -0.150000 0.000000 -0.000000 +1581.138916 0.158114 0.577350 0.250000 |
