summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
author Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> 2023-09-23 12:11:45 -0400
committer GitHub <noreply@github.com> 2023-09-23 12:11:45 -0400
commit ab04bd0dd7dd6a818bbac8c5fef9372c4f597352 (patch)
tree d37f49273bc48c55ea3e16a243817907af0ebcbc /source
parent 263f807285c93272abb0c0352be6f8553f01a373 (diff)
More `slangpy` features + polishing (#3233)
* Update user-guide with new slangpy features
* More polishing of new slangpy docs
* Update a1-02-slangpy.md
* Only require contiguity for vector element types
* Added `loadOnce/storeOnce` and subscript operations
* Added docs, `DiffTensorView.dims()` & `DiffTensorView.stride(uint)`
* Add constructors, remove storeOnce/loadOnce test
* Adjusted intrinsic definitions
Diffstat (limited to 'source')
-rw-r--r-- source/slang/diff.meta.slang 383
-rw-r--r-- source/slang/slang-check-overload.cpp 12
-rw-r--r-- source/slang/slang-emit-torch.cpp 7
-rw-r--r-- source/slang/slang-ir-autodiff-cfg-norm.cpp 1
4 files changed, 268 insertions, 135 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 75c57018c..5fe1440e6 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -88,6 +88,11 @@ struct TensorView
[__NoSideEffect]
T load(uint i0, uint i1, uint i2, uint i3, uint i4);
+ [__NoSideEffect]
+ __generic<let N : int>
+ __target_intrinsic(cuda, "$0.load<$TR>($1)")
+ T load(vector<uint, N> index);
+
__target_intrinsic(cuda, "$0.store<$G0>($1, $2)")
void store(uint x, T val);
__target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3)")
@@ -99,6 +104,11 @@ struct TensorView
__target_intrinsic(cuda, "$0.store<$G0>($1, $2, $3, $4, $5, $6)")
void store(uint i0, uint i1, uint i2, uint i3, uint i4, T val);
+ // Vector-index store: forwards index ($1) and value ($2), templated on val's type ($T2). Not [__NoSideEffect]: it writes memory.
+ __generic<let N : int>
+ __target_intrinsic(cuda, "$0.store<$T2>($1, $2)")
+ void store(vector<uint, N> index, T val);
+
__target_intrinsic(cuda, "*($3) = atomicAdd($0.data_ptr_at<$T2>($1), $2)")
void InterlockedAdd(uint index, T val, out T oldVal);
@@ -266,165 +276,173 @@ extension TensorView<float>
interface IDiffTensorWrapper
{
- __generic<T : __BuiltinFloatingPointType>
- T load_forward(uint offset);
+ // Derivatives for universal load/store operations.
__generic<T : __BuiltinFloatingPointType>
- T load_forward_2(uint2 offset);
+ T load_forward(uint i);
- __generic<T : __BuiltinFloatingPointType>
- T load_forward_3(uint3 offset);
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ T load_forward(vector<uint, N> i);
__generic<T : __BuiltinFloatingPointType>
- T load_forward_4(uint4 offset);
+ void load_backward(uint i, T dOut);
- __generic<T : __BuiltinFloatingPointType>
- void load_backward(uint offset, T dOut);
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ void load_backward(vector<uint, N> i, T dOut);
__generic<T : __BuiltinFloatingPointType>
- void load_backward_2(uint2 offset, T dOut);
+ void store_forward(uint i, T dx);
- __generic<T : __BuiltinFloatingPointType>
- void load_backward_3(uint3 offset, T dOut);
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ void store_forward(vector<uint, N> i, T dx);
__generic<T : __BuiltinFloatingPointType>
- void load_backward_4(uint4 offset, T dOut);
+ T store_backward(uint i);
- __generic<T : __BuiltinFloatingPointType>
- void store_forward(uint offset, T dx);
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ T store_backward(vector<uint, N> i);
- __generic<T : __BuiltinFloatingPointType>
- void store_forward_2(uint2 offset, T dx);
+ // Derivatives for loadOnce/storeOnce operations. These operations
+ // are designed to only run once per-address and don't need atomic
+ // gradient handling.
+ //
__generic<T : __BuiltinFloatingPointType>
- void store_forward_3(uint3 offset, T dx);
+ T loadOnce_forward(uint i);
- __generic<T : __BuiltinFloatingPointType>
- void store_forward_4(uint4 offset, T dx);
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ T loadOnce_forward(vector<uint, N> i);
__generic<T : __BuiltinFloatingPointType>
- T store_backward(uint offset);
+ void loadOnce_backward(uint i, T dOut);
- __generic<T : __BuiltinFloatingPointType>
- T store_backward_2(uint2 offset);
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ void loadOnce_backward(vector<uint, N> i, T dOut);
__generic<T : __BuiltinFloatingPointType>
- T store_backward_3(uint3 offset);
+ void storeOnce_forward(uint i, T dx);
+
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ void storeOnce_forward(vector<uint, N> i, T dx);
__generic<T : __BuiltinFloatingPointType>
- T store_backward_4(uint4 offset);
+ T storeOnce_backward(uint i);
+
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ T storeOnce_backward(vector<uint, N> i);
};
struct AtomicAdd : IDiffTensorWrapper
{
TensorView<float> diff;
+ // Derivatives for universal load/store operations.
+
__generic<T : __BuiltinFloatingPointType>
T load_forward(uint i)
{
return __realCast<T, float>(diff.load(i));
}
- __generic<T : __BuiltinFloatingPointType>
- T load_forward_2(uint2 i)
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ T load_forward(vector<uint, N> i)
{
- return __realCast<T, float>(diff.load(i.x, i.y));
+ return __realCast<T, float>(diff.load(i));
}
__generic<T : __BuiltinFloatingPointType>
- T load_forward_3(uint3 i)
+ void load_backward(uint i, T dOut)
{
- return __realCast<T, float>(diff.load(i.x, i.y, i.z));
+ float oldVal;
+ diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
}
- __generic<T : __BuiltinFloatingPointType>
- T load_forward_4(uint4 i)
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ void load_backward(vector<uint, N> i, T dOut)
{
- return __realCast<T, float>(diff.load(i.x, i.y, i.z, i.w));
+ float oldVal;
+ diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
}
__generic<T : __BuiltinFloatingPointType>
- void load_backward(uint i, T dOut)
+ void store_forward(uint i, T dx)
{
- float oldVal;
- diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
+ diff.store(i, __realCast<float, T>(dx));
}
- __generic<T : __BuiltinFloatingPointType>
- void load_backward_2(uint2 i, T dOut)
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ void store_forward(vector<uint, N> i, T dx)
{
- float oldVal;
- diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
+ diff.store(i, __realCast<float, T>(dx));
}
__generic<T : __BuiltinFloatingPointType>
- void load_backward_3(uint3 i, T dOut)
+ T store_backward(uint i)
{
float oldVal;
- diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
+ diff.InterlockedExchange(i, (float)0, oldVal);
+ return __realCast<T, float>(oldVal);
}
- __generic<T : __BuiltinFloatingPointType>
- void load_backward_4(uint4 i, T dOut)
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ T store_backward(vector<uint, N> i)
{
float oldVal;
- diff.InterlockedAdd(i, __realCast<float, T>(dOut), oldVal);
+ diff.InterlockedExchange(i, (float)0, oldVal);
+ return __realCast<T, float>(oldVal);
}
+ // Derivatives for loadOnce/storeOnce operations. These operations
+ // are designed to only run once per-address and don't need atomic
+ // gradient handling.
+ //
+
__generic<T : __BuiltinFloatingPointType>
- void store_forward(uint i, T dx)
+ T loadOnce_forward(uint i)
{
- diff.store(i, __realCast<float, T>(dx));
+ return __realCast<T, float>(diff.load(i));
}
- __generic<T : __BuiltinFloatingPointType>
- void store_forward_2(uint2 i, T dx)
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ T loadOnce_forward(vector<uint, N> i)
{
- diff.store(i.x, i.y, __realCast<float, T>(dx));
+ return __realCast<T, float>(diff.load(i));
}
__generic<T : __BuiltinFloatingPointType>
- void store_forward_3(uint3 i, T dx)
+ void loadOnce_backward(uint i, T dOut)
{
- diff.store(i.x, i.y, i.z, __realCast<float, T>(dx));
+ diff.store(i, __realCast<float, T>(dOut));
}
- __generic<T : __BuiltinFloatingPointType>
- void store_forward_4(uint4 i, T dx)
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ void loadOnce_backward(vector<uint, N> i, T dOut)
{
- diff.store(i.x, i.y, i.z, i.w, __realCast<float, T>(dx));
+ diff.store(i, __realCast<float, T>(dOut));
}
__generic<T : __BuiltinFloatingPointType>
- T store_backward(uint i)
+ void storeOnce_forward(uint i, T dx)
{
- float oldVal;
- diff.InterlockedExchange(i, (float)0, oldVal);
- return __realCast<T, float>(oldVal);
+ diff.store(i, __realCast<float, T>(dx));
}
- __generic<T : __BuiltinFloatingPointType>
- T store_backward_2(uint2 i)
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ void storeOnce_forward(vector<uint, N> i, T dx)
{
- float oldVal;
- diff.InterlockedExchange(i, (float)0, oldVal);
- return __realCast<T, float>(oldVal);
+ diff.store(i, __realCast<float, T>(dx));
}
__generic<T : __BuiltinFloatingPointType>
- T store_backward_3(uint3 i)
+ T storeOnce_backward(uint i)
{
- float oldVal;
- diff.InterlockedExchange(i, (float)0, oldVal);
- return __realCast<T, float>(oldVal);
+ return __realCast<T, float>(diff.load(i));
}
- __generic<T : __BuiltinFloatingPointType>
- T store_backward_4(uint4 i)
+ __generic<T : __BuiltinFloatingPointType, let N : int>
+ T storeOnce_backward(vector<uint, N> i)
{
- float oldVal;
- diff.InterlockedExchange(i, (float)0, oldVal);
- return __realCast<T, float>(oldVal);
+ return __realCast<T, float>(diff.load(i));
}
};
@@ -439,120 +457,223 @@ struct DiffTensorView
return primal.size(i);
}
- [BackwardDerivative(load_backward)]
- [ForwardDerivative(load_forward)]
- T load(uint i) { return primal.load(i); }
+ uint dims()
+ {
+ return primal.dims();
+ }
- [BackwardDerivative(load_backward)]
- [ForwardDerivative(load_forward)]
- T load(uint2 i) { return primal.load(i.x, i.y); }
+ uint stride(uint i)
+ {
+ return primal.stride(i);
+ }
- [BackwardDerivative(load_backward)]
- [ForwardDerivative(load_forward)]
- T load(uint3 i) { return primal.load(i.x, i.y, i.z); }
+ // Constructors
+ __init(TensorView<T> primal, A diff)
+ {
+ this.primal = primal;
+ this.diff = diff;
+ }
- [BackwardDerivative(load_backward)]
- [ForwardDerivative(load_forward)]
- T load(uint4 i) { return primal.load(i.x, i.y, i.z, i.w); }
+ __init(TensorView<T> primal)
+ {
+ this.primal = primal;
+ }
+
+ // Universal load/store operations.
+
+ [BackwardDerivative(__load_backward)]
+ [ForwardDerivative(__load_forward)]
+ T load(uint i) { return primal.load(i); }
+
+ [BackwardDerivative(__load_backward)]
+ [ForwardDerivative(__load_forward)]
+ __generic<let N : int>
+ T load(vector<uint, N> i) { return primal.load(i); }
- DifferentialPair<T> load_forward(uint x)
+ DifferentialPair<T> __load_forward(uint x)
{
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T>(x)));
}
- DifferentialPair<T> load_forward(uint2 x)
+ __generic<let N : int>
+ DifferentialPair<T> __load_forward(vector<uint, N> x)
{
- return diffPair(primal.load(x.x, x.y), reinterpret<T.Differential, T>(diff.load_forward_2<T>(x)));
+ return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T, N>(x)));
}
- DifferentialPair<T> load_forward(uint3 x)
+ void __load_backward(uint x, T.Differential dOut)
{
- return diffPair(primal.load(x.x, x.y, x.z), reinterpret<T.Differential, T>(diff.load_forward_3<T>(x)));
+ diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut));
}
- DifferentialPair<T> load_forward(uint4 x)
+ __generic<let N : int>
+ void __load_backward(vector<uint, N> x, T.Differential dOut)
{
- return diffPair(primal.load(x.x, x.y, x.z, x.w), reinterpret<T.Differential, T>(diff.load_forward_4<T>(x)));
+ diff.load_backward<T, N>(x, reinterpret<T, T.Differential>(dOut));
}
- void load_backward(uint x, T.Differential dOut)
+ [BackwardDerivative(__store_backward)]
+ [ForwardDerivative(__store_forward)]
+ void store(uint x, T val) { primal.store(x, val); }
+
+ [BackwardDerivative(__store_backward)]
+ [ForwardDerivative(__store_forward)]
+ __generic<let N : int>
+ void store(vector<uint, N> x, T val) { primal.store(x, val); }
+
+ void __store_forward(uint x, DifferentialPair<T> dpval)
{
- diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut));
+ primal.store(x, dpval.p);
+ diff.store_forward<T>(x, reinterpret<T, T.Differential>(dpval.d));
}
- void load_backward(uint2 x, T.Differential dOut)
+ __generic<let N : int>
+ void __store_forward(vector<uint, N> x, DifferentialPair<T> dpval)
{
- diff.load_backward_2<T>(x, reinterpret<T, T.Differential>(dOut));
+ primal.store(x, dpval.p);
+ diff.store_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d));
}
- void load_backward(uint3 x, T.Differential dOut)
+ void __store_backward(uint x, inout DifferentialPair<T> dpval)
{
- diff.load_backward_3<T>(x, reinterpret<T, T.Differential>(dOut));
+ dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x)));
}
- void load_backward(uint4 x, T.Differential dOut)
+ __generic<let N : int>
+ void __store_backward(vector<uint, N> x, inout DifferentialPair<T> dpval)
{
- diff.load_backward_4<T>(x, reinterpret<T, T.Differential>(dOut));
+ dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T, N>(x)));
}
- [BackwardDerivative(store_backward)]
- [ForwardDerivative(store_forward)]
- void store(uint x, T val) { primal.store(x, val); }
+ __subscript(uint index)->T
+ {
+ [__unsafeForceInlineEarly] [Differentiable] [__NoSideEffect] get { return load(index); }
+ [__unsafeForceInlineEarly] [Differentiable] set { store(index, newValue); }
- [BackwardDerivative(store_backward)]
- [ForwardDerivative(store_forward)]
- void store(uint2 x, T val) { primal.store(x.x, x.y, val); }
+ [__NoSideEffect]
+ ref;
+ }
- [BackwardDerivative(store_backward)]
- [ForwardDerivative(store_forward)]
- void store(uint3 x, T val) { primal.store(x.x, x.y, x.z, val); }
+ __subscript(uint2 index)->T
+ {
+ [__unsafeForceInlineEarly] [Differentiable] [__NoSideEffect] get { return load(index); }
+ [__unsafeForceInlineEarly] [Differentiable] set { store(index, newValue); }
- [BackwardDerivative(store_backward)]
- [ForwardDerivative(store_forward)]
- void store(uint4 x, T val) { primal.store(x.x, x.y, x.z, x.w, val); }
+ [__NoSideEffect]
+ ref;
+ }
- void store_forward(uint x, DifferentialPair<T> dpval)
+ __subscript(uint x, uint y)->T
{
- primal.store(x, dpval.p);
- diff.store_forward<T>(x, reinterpret<T, T.Differential>(dpval.d));
+ [__unsafeForceInlineEarly] [Differentiable] [__NoSideEffect] get { return load(uint2(x, y)); }
+ [__unsafeForceInlineEarly] [Differentiable] set { store(uint2(x, y), newValue); }
+
+ [__NoSideEffect]
+ ref;
}
- void store_forward(uint2 x, DifferentialPair<T> dpval)
+ __subscript(uint3 index)->T
{
- primal.store(x.x, x.y, dpval.p);
- diff.store_forward_2<T>(x, reinterpret<T, T.Differential>(dpval.d));
+ [__unsafeForceInlineEarly] [Differentiable] [__NoSideEffect] get { return load(index); }
+ [__unsafeForceInlineEarly] [Differentiable] set { store(index, newValue); }
+
+ [__NoSideEffect]
+ ref;
}
- void store_forward(uint3 x, DifferentialPair<T> dpval)
+ __subscript(uint x, uint y, uint z)->T
{
- primal.store(x.x, x.y, x.z, dpval.p);
- diff.store_forward_3<T>(x, reinterpret<T, T.Differential>(dpval.d));
+ [__unsafeForceInlineEarly] [Differentiable] [__NoSideEffect] get { return load(uint3(x, y, z)); }
+ [__unsafeForceInlineEarly] [Differentiable] set { store(uint3(x, y, z), newValue); }
+
+ [__NoSideEffect]
+ ref;
}
- void store_forward(uint4 x, DifferentialPair<T> dpval)
+ __subscript(uint4 index)->T
{
- primal.store(x.x, x.y, x.z, x.w, dpval.p);
- diff.store_forward_4<T>(x, reinterpret<T, T.Differential>(dpval.d));
+ [__unsafeForceInlineEarly] [Differentiable] [__NoSideEffect] get { return load(index); }
+ [__unsafeForceInlineEarly] [Differentiable] set { store(index, newValue); }
+
+ [__NoSideEffect]
+ ref;
}
- void store_backward(uint x, inout DifferentialPair<T> dpval)
+ __subscript(uint x, uint y, uint z, uint w)->T
{
- dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x)));
+ [__unsafeForceInlineEarly] [Differentiable] [__NoSideEffect] get { return load(uint4(x, y, z, w)); }
+ [__unsafeForceInlineEarly] [Differentiable] set { store(uint4(x, y, z, w), newValue); }
+
+ [__NoSideEffect]
+ ref;
}
- void store_backward(uint2 x, inout DifferentialPair<T> dpval)
+ // loadOnce/storeOnce operations. These operations are designed to only run once per-address and
+ // don't need atomic gradient handling.
+ //
+
+ [BackwardDerivative(__loadOnce_backward)]
+ [ForwardDerivative(__loadOnce_forward)]
+ T loadOnce(uint i) { return primal.load(i); }
+
+ [BackwardDerivative(__loadOnce_backward)]
+ [ForwardDerivative(__loadOnce_forward)]
+ __generic<let N : int>
+ T loadOnce(vector<uint, N> i) { return primal.load(i); }
+
+ DifferentialPair<T> __loadOnce_forward(uint x)
{
- dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_2<T>(x)));
+ return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.loadOnce_forward<T>(x)));
}
- void store_backward(uint3 x, inout DifferentialPair<T> dpval)
+ __generic<let N : int>
+ DifferentialPair<T> __loadOnce_forward(vector<uint, N> x)
{
- dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_3<T>(x)));
+ return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.loadOnce_forward<T, N>(x)));
}
- void store_backward(uint4 x, inout DifferentialPair<T> dpval)
+ void __loadOnce_backward(uint x, T.Differential dOut)
+ {
+ diff.loadOnce_backward<T>(x, reinterpret<T, T.Differential>(dOut));
+ }
+
+ __generic<let N : int>
+ void __loadOnce_backward(vector<uint, N> x, T.Differential dOut)
+ {
+ diff.loadOnce_backward<T, N>(x, reinterpret<T, T.Differential>(dOut));
+ }
+
+ [BackwardDerivative(__storeOnce_backward)]
+ [ForwardDerivative(__storeOnce_forward)]
+ void storeOnce(uint x, T val) { primal.store(x, val); }
+
+ [BackwardDerivative(__storeOnce_backward)]
+ [ForwardDerivative(__storeOnce_forward)]
+ __generic<let N : int>
+ void storeOnce(vector<uint, N> x, T val) { primal.store(x, val); }
+
+ void __storeOnce_forward(uint x, DifferentialPair<T> dpval)
+ {
+ primal.store(x, dpval.p);
+ diff.storeOnce_forward<T>(x, reinterpret<T, T.Differential>(dpval.d));
+ }
+
+ __generic<let N : int>
+ void __storeOnce_forward(vector<uint, N> x, DifferentialPair<T> dpval)
+ {
+ primal.store(x, dpval.p);
+ diff.storeOnce_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d));
+ }
+
+ void __storeOnce_backward(uint x, inout DifferentialPair<T> dpval)
+ {
+ dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.storeOnce_backward<T>(x)));
+ }
+
+ __generic<let N : int>
+ void __storeOnce_backward(vector<uint, N> x, inout DifferentialPair<T> dpval)
{
- dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward_4<T>(x)));
+ dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.storeOnce_backward<T, N>(x)));
}
};
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index d7ed5975d..c668155df 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -964,10 +964,14 @@ namespace Slang
{
auto leftType = DeclRefType::create(m_astBuilder, left.declRef.getParent());
auto rightType = DeclRefType::create(m_astBuilder, right.declRef.getParent());
- if (isSubtype(leftType, rightType))
- return -1;
- if (isSubtype(rightType, leftType))
- return 1;
+
+ if (!leftType->equals(rightType))
+ {
+ if (isSubtype(leftType, rightType))
+ return -1;
+ if (isSubtype(rightType, leftType))
+ return 1;
+ }
}
// TODO: We should generalize above rules such that in a tie a declaration
diff --git a/source/slang/slang-emit-torch.cpp b/source/slang/slang-emit-torch.cpp
index 7cd793ec1..54408aa80 100644
--- a/source/slang/slang-emit-torch.cpp
+++ b/source/slang/slang-emit-torch.cpp
@@ -118,6 +118,13 @@ bool TorchCppSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo&
emitStringLiteral(getUnmangledName(inst->getOperand(0)));
m_writer->emit(", ");
emitTorchScalarTypeName(m_writer, inst->getOperand(0)->getDataType());
+ m_writer->emit(", ");
+
+ if (as<IRVectorType>(inst->getOperand(0)->getDataType()))
+ m_writer->emit("true");
+ else
+ m_writer->emit("false");
+
m_writer->emit(")");
return true;
}
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index a9db3aecc..e7c269756 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -619,6 +619,7 @@ struct CFGNormalizationPass
// SLANG_UNEXPECTED("Switch-case normalization not implemented yet.");
BreakableRegionInfo info;
info.breakBlock = as<IRSwitch>(branchInst)->getBreakLabel();
+ info.headerBlock = as<IRBlock>(branchInst->getParent());
// Emit var into parent block.
builder.setInsertBefore(as<IRBlock>(branchInst->getParent())->getTerminator());