summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-11-10 13:55:14 -0800
committerGitHub <noreply@github.com>2023-11-10 13:55:14 -0800
commit011d4281647e3a2a3cf0dbdda1fa65cc1b8ed881 (patch)
tree70f91655e86d30529eda0a683e15f378eeae2cb5
parentbfd3f39d04047d7a46e75206cd125ed87b3f3f99 (diff)
Cleanup builtin arithmetic interfaces. (#3317)
* wip: clean up IArithmetic * wip. * Cleanup builtin arithmetic interfaces. * Fix. * Fixes. * Fix. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/core.meta.slang579
-rw-r--r--source/slang/slang-ast-modifier.h12
-rw-r--r--source/slang/slang-ast-support-types.h3
-rw-r--r--source/slang/slang-check-expr.cpp4
-rw-r--r--source/slang/slang-check-modifier.cpp13
-rw-r--r--source/slang/slang-check-overload.cpp14
-rw-r--r--source/slang/slang-ir-constexpr.cpp11
-rw-r--r--source/slang/slang-ir-peephole.cpp6
-rw-r--r--source/slang/slang-syntax.cpp19
-rw-r--r--tests/autodiff/generic-constructor.slang39
-rw-r--r--tests/ir/loop-inversion.slang8
-rw-r--r--tests/language-feature/generics/iarray.slang12
12 files changed, 487 insertions, 233 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 0f600d7c6..4dff9888d 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -74,6 +74,91 @@ syntax snorm : SNormModifier;
///
syntax __extern_cpp : ExternCppModifier;
+interface IComparable
+{
+ bool equals(This other);
+ bool lessThan(This other);
+ bool lessThanOrEquals(This other);
+}
+
+interface IRangedValue
+{
+ static const This maxValue;
+ static const This minValue;
+}
+
+__attributeTarget(DeclBase)
+attribute_syntax [TreatAsDifferentiable] : TreatAsDifferentiableAttribute;
+
+interface IArithmetic : IComparable
+{
+ This add(This other);
+ This sub(This other);
+ This mul(This other);
+ This div(This other);
+ This mod(This other);
+ This neg();
+
+ __init(int val);
+
+ /// Initialize from the same type.
+ __init(This value);
+}
+
+interface ILogical : IComparable
+{
+ This shl(int value);
+ This shr(int value);
+ This bitAnd(This other);
+ This bitOr(This other);
+ This bitXor(This other);
+ This bitNot();
+ This and(This other);
+ This or(This other);
+ This not();
+ __init(int val);
+}
+
+interface IInteger : IArithmetic, ILogical
+{
+ int toInt();
+ int64_t toInt64();
+ uint toUInt();
+ uint64_t toUInt64();
+}
+
+interface IFloat : IArithmetic, IDifferentiable
+{
+ [TreatAsDifferentiable]
+ __init(float value);
+
+ [TreatAsDifferentiable]
+ float toFloat();
+
+ [TreatAsDifferentiable]
+ This add(This other);
+
+ [TreatAsDifferentiable]
+ This sub(This other);
+
+ [TreatAsDifferentiable]
+ This mul(This other);
+
+ [TreatAsDifferentiable]
+ This div(This other);
+
+ [TreatAsDifferentiable]
+ This mod(This other);
+
+ [TreatAsDifferentiable]
+ This neg();
+
+ [TreatAsDifferentiable]
+ __init(This value);
+
+ [TreatAsDifferentiable]
+ This scale<T:__BuiltinFloatingPointType>(T scale);
+}
/// A type that can be used as an operand for builtins
[sealed]
@@ -83,22 +168,15 @@ interface __BuiltinType {}
/// A type that can be used for arithmetic operations
[sealed]
[builtin]
-interface __BuiltinArithmeticType : __BuiltinType
+interface __BuiltinArithmeticType : __BuiltinType, IArithmetic
{
- /// Initialize from a 32-bit signed integer value.
- __init(int value);
-
- /// Initialize from the same type.
- __init(This value);
}
/// A type that can be used for logical/bitwise operations
[sealed]
[builtin]
-interface __BuiltinLogicalType : __BuiltinType
+interface __BuiltinLogicalType : __BuiltinType, ILogical
{
- /// Initialize from a 32-bit signed integer value.
- __init(int value);
}
/// A type that logically has a sign (positive/negative/zero)
@@ -243,12 +321,10 @@ struct DifferentialPair : IDifferentiable
[sealed]
[builtin]
[TreatAsDifferentiable]
-interface __BuiltinFloatingPointType : __BuiltinRealType, IDifferentiable
+interface __BuiltinFloatingPointType : __BuiltinRealType, IFloat
{
- /// Initialize from a 32-bit floating-point value.
- __init(float value);
-
/// Get the value of the mathematical constant pi in this type.
+ [Differentiable]
static This getPi();
}
@@ -365,7 +441,8 @@ __generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool,
// Allow real-number types to be cast into each other
__intrinsic_op($(kIROp_FloatCast))
T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val);
-
+__intrinsic_op($(kIROp_IntCast))
+ T __intCast<T : __BuiltinType, U : __BuiltinType>(U val);
${{{{
// We are going to use code generation to produce the
// declarations for all of our base types.
@@ -446,8 +523,12 @@ ${{{{
if (kBaseTypes[tt].tag == BaseType::Double &&
kBaseTypes[ss].tag == BaseType::Float)
builtinConversionKind = kBuiltinConversion_FloatToDouble;
-}}}}
+ const char* attrib = "";
+ if ((kBaseTypes[tt].flags & kBaseTypes[ss].flags & FLOAT_MASK) != 0)
+ attrib = "[TreatAsDifferentiable]";
+}}}}
+ $(attrib)
__intrinsic_op($(intrinsicOpCode))
__implicit_conversion($(conversionCost), $(builtinConversionKind))
__init($(kBaseTypes[ss].name) value);
@@ -455,23 +536,71 @@ ${{{{
${{{{
}
- // If this is a basic integer type, then define explicit
- // initializers that take a value of an `enum` type.
- //
- // TODO: This should actually be restricted, so that this
- // only applies `where T.__Tag == Self`, but we don't have
- // the needed features in our type system to implement
- // that constraint right now.
- //
+ // Integer type implementations.
switch (kBaseTypes[tt].tag)
{
- // TODO: should this cover the full gamut of integer types?
- case BaseType::Int:
+ case BaseType::Bool:
+}}}}
+ [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Eql)) bool equals(This other);
+ [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Less)) bool lessThan(This other);
+ [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Leq)) bool lessThanOrEquals(This other);
+ [__unsafeForceInlineEarly] This shl(int other) { return __intCast<This>(__shl(__intCast<int>(this), other)); }
+ [__unsafeForceInlineEarly] This shr(int other) { return __intCast<This>(__shr(__intCast<int>(this), other)); }
+ [__unsafeForceInlineEarly] This bitAnd(This other) { return __intCast<This>(__and(__intCast<int>(this), __intCast<int>(other))); }
+ [__unsafeForceInlineEarly] This bitOr(This other) { return __intCast<This>(__or(__intCast<int>(this), __intCast<int>(other))); }
+ [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_And)) This and(This other);
+ [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Or)) This or(This other);
+ [__unsafeForceInlineEarly] This bitXor(This other) { return __intCast<This>(__xor(__intCast<int>(this), __intCast<int>(other))); }
+ [__unsafeForceInlineEarly] This bitNot() { return __intCast<This>(__not(__intCast<int>(this))); }
+ [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Not)) This not();
+${{{{
+ break;
+ case BaseType::UInt8:
+ case BaseType::UInt16:
case BaseType::UInt:
+ case BaseType::UInt64:
+ case BaseType::Int8:
+ case BaseType::Int16:
+ case BaseType::Int:
+ case BaseType::Int64:
+ case BaseType::IntPtr:
+ case BaseType::UIntPtr:
}}}}
+ // If this is a basic integer type, then define explicit
+ // initializers that take a value of an `enum` type.
+ //
+ // TODO: This should actually be restricted, so that this
+ // only applies `where T.__Tag == Self`, but we don't have
+ // the needed features in our type system to implement
+ // that constraint right now.
+ //
__generic<T:__EnumType>
__intrinsic_op($(kIROp_IntCast))
__init(T value);
+
+ // Implementation of the `IInteger` interface.
+ __intrinsic_op($(kIROp_Less)) bool lessThan(This other);
+ __intrinsic_op($(kIROp_Leq)) bool lessThanOrEquals(This other);
+ __intrinsic_op($(kIROp_Eql)) bool equals(This other);
+ __intrinsic_op($(kIROp_Add)) This add(This other);
+ __intrinsic_op($(kIROp_Sub)) This sub(This other);
+ __intrinsic_op($(kIROp_Mul)) This mul(This other);
+ __intrinsic_op($(kIROp_Div)) This div(This other);
+ __intrinsic_op($(kIROp_FRem)) This mod(This other);
+ __intrinsic_op($(kIROp_Neg)) This neg();
+ __intrinsic_op($(kIROp_Lsh)) This shl(int other);
+ __intrinsic_op($(kIROp_Rsh)) This shr(int other);
+ __intrinsic_op($(kIROp_BitAnd)) This bitAnd(This other);
+ __intrinsic_op($(kIROp_BitOr)) This bitOr(This other);
+ [__unsafeForceInlineEarly] This and(This other) {return __intCast<This>(__intCast<bool>(this) && __intCast<bool>(other)); }
+ [__unsafeForceInlineEarly] This or(This other) {return __intCast<This>(__intCast<bool>(this) || __intCast<bool>(other)); }
+ __intrinsic_op($(kIROp_BitXor)) This bitXor(This other);
+ __intrinsic_op($(kIROp_BitNot)) This bitNot();
+ [__unsafeForceInlineEarly] This not() {return __intCast<This>(!__intCast<bool>(this)); }
+ __intrinsic_op($(kIROp_IntCast)) int toInt();
+ __intrinsic_op($(kIROp_IntCast)) int64_t toInt64();
+ __intrinsic_op($(kIROp_IntCast)) uint toUInt();
+ __intrinsic_op($(kIROp_IntCast)) uint64_t toUInt64();
${{{{
break;
@@ -480,8 +609,8 @@ ${{{{
}
// If this is a floating-point type, then we need to
- // define the basic `getPi()` function that is used
- // to implement generic versions of `degrees()` and
+ // implement the `IFloat` interface, which defines the basic `getPi()`
+ // function that is used to implement generic versions of `degrees()` and
// `radians()`.
//
switch (kBaseTypes[tt].tag)
@@ -492,8 +621,20 @@ ${{{{
case BaseType::Float:
case BaseType::Double:
}}}}
+ [TreatAsDifferentiable]
static $(kBaseTypes[tt].name) getPi() { return $(kBaseTypes[tt].name)(3.14159265358979323846264338328); }
+ __intrinsic_op($(kIROp_Less)) bool lessThan(This other);
+ __intrinsic_op($(kIROp_Leq)) bool lessThanOrEquals(This other);
+ __intrinsic_op($(kIROp_Eql)) bool equals(This other);
+ __intrinsic_op($(kIROp_Add)) This add(This other);
+ __intrinsic_op($(kIROp_Sub)) This sub(This other);
+ __intrinsic_op($(kIROp_Mul)) This mul(This other);
+ __intrinsic_op($(kIROp_Div)) This div(This other);
+ __intrinsic_op($(kIROp_FRem)) This mod(This other);
+ __intrinsic_op($(kIROp_Neg)) This neg();
+ __intrinsic_op($(kIROp_FloatCast)) float toFloat();
+ __intrinsic_op($(kIROp_Mul)) This scale<T:__BuiltinFloatingPointType>(T s) { return __mul(this, __realCast<This>(s)); }
typedef $(kBaseTypes[tt].name) Differential;
[__unsafeForceInlineEarly]
@@ -628,7 +769,7 @@ __generic<T>
__intrinsic_op($(kIROp_Eql))
bool operator==(Ptr<T> p1, Ptr<T> p2);
-extension bool
+extension bool : IRangedValue
{
__generic<T>
__implicit_conversion($(kConversionCost_PtrToBool))
@@ -639,7 +780,7 @@ extension bool
static const bool minValue = false;
}
-extension uint64_t
+extension uint64_t : IRangedValue
{
__generic<T>
__intrinsic_op($(kIROp_CastPtrToInt))
@@ -649,7 +790,7 @@ extension uint64_t
static const uint64_t minValue = 0;
}
-extension int64_t
+extension int64_t : IRangedValue
{
__generic<T>
__intrinsic_op($(kIROp_CastPtrToInt))
@@ -659,7 +800,7 @@ extension int64_t
static const int64_t minValue = -0x8000000000000000LL;
}
-extension intptr_t
+extension intptr_t : IRangedValue
{
__generic<T>
__intrinsic_op($(kIROp_CastPtrToInt))
@@ -669,7 +810,7 @@ extension intptr_t
static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4");
}
-extension uintptr_t
+extension uintptr_t : IRangedValue
{
__generic<T>
__intrinsic_op($(kIROp_CastPtrToInt))
@@ -844,56 +985,56 @@ __intrinsic_type($(kIROp_DynamicType))
struct __Dynamic
{};
-extension half
+extension half : IRangedValue
{
static const half maxValue = half(65504);
static const half minValue = half(-65504);
}
-extension float
+extension float : IRangedValue
{
static const float maxValue = 340282346638528859811704183484516925440.0f;
static const float minValue = -340282346638528859811704183484516925440.0f;
}
-extension double
+extension double : IRangedValue
{
static const double maxValue = 179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368.0;
static const double minValue = -179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368.0;
}
-extension int
+extension int : IRangedValue
{
static const int maxValue = 2147483647;
static const int minValue = -2147483648;
}
-extension uint
+extension uint : IRangedValue
{
static const uint maxValue = 4294967295;
static const uint minValue = 0;
}
-extension int8_t
+extension int8_t : IRangedValue
{
static const int8_t maxValue = 127;
static const int8_t minValue = -128;
}
-extension uint8_t
+extension uint8_t : IRangedValue
{
static const uint8_t maxValue = 255;
static const uint8_t minValue = 0;
}
-extension uint16_t
+extension uint16_t : IRangedValue
{
static const uint16_t maxValue = 65535;
static const uint16_t minValue = 0;
}
-extension int16_t
+extension int16_t : IRangedValue
{
static const int16_t maxValue = 32767;
static const int16_t minValue = -32768;
@@ -909,13 +1050,15 @@ struct vector : IArray<T>
/// Initialize a vector where all elements have the same scalar `value`.
+ [TreatAsDifferentiable]
__implicit_conversion($(kConversionCost_ScalarToVector))
__intrinsic_op($(kIROp_MakeVectorFromScalar))
__init(T value);
- /// Initialize a vector from a value of the same type
+ /// Initialize a vector from a value of the same type
// TODO: we should revise semantic checking so this kind of "identity" conversion is not required
__intrinsic_op(0)
+ [TreatAsDifferentiable]
__init(vector<T,N> value);
[ForceInline]
@@ -933,14 +1076,115 @@ __magic_type(MatrixExpressionType)
struct matrix : IArray<vector<T,C>>
{
__intrinsic_op($(kIROp_MakeMatrixFromScalar))
+ [TreatAsDifferentiable]
__init(T val);
+ /// Initialize a vector from a value of the same type
+ // TODO: we should revise semantic checking so this kind of "identity" conversion is not required
+ __intrinsic_op(0)
+ [TreatAsDifferentiable]
+ __init(This value);
+
[ForceInline]
int getCount() { return R; }
__subscript(int index) -> vector<T,C> { __intrinsic_op($(kIROp_GetElement)) get; }
}
+__generic<T:__BuiltinFloatingPointType, let N : int>
+extension vector<T,N> : IFloat
+{
+ [__unsafeForceInlineEarly] bool lessThan(This other) { return this < other; }
+ [__unsafeForceInlineEarly] bool lessThanOrEquals(This other) { return this <= other; }
+ [__unsafeForceInlineEarly] bool equals(This other) { return all(this == other); }
+ __intrinsic_op($(kIROp_Add)) This add(This other);
+ __intrinsic_op($(kIROp_Sub)) This sub(This other);
+ __intrinsic_op($(kIROp_Mul)) This mul(This other);
+ __intrinsic_op($(kIROp_Div)) This div(This other);
+ __intrinsic_op($(kIROp_FRem)) This mod(This other);
+ __intrinsic_op($(kIROp_Neg)) This neg();
+ __intrinsic_op($(kIROp_Mul)) This scale<T1:__BuiltinFloatingPointType>(T1 s);
+ [__unsafeForceInlineEarly] float toFloat() { return __realCast<float>(this[0]); }
+
+ [OverloadRank(-1)]
+ [__unsafeForceInlineEarly] __init(int v) { return vector<T,N>(T(v)); }
+ [OverloadRank(-1)]
+ [__unsafeForceInlineEarly] __init(float v) { return vector<T,N>(T(v)); }
+
+ // IDifferentiable
+
+ typedef vector<T, N> Differential;
+
+ [__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
+ static Differential dzero()
+ {
+ return Differential(__slang_noop_cast<T>(T.dzero()));
+ }
+
+ [__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ __generic<U : __BuiltinRealType>
+ [__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
+ static Differential dmul(U a, Differential b)
+ {
+ return __realCast<T, U>(a) * b;
+ }
+}
+
+__generic<T:__BuiltinFloatingPointType, let N : int, let M : int, let L : int>
+extension matrix<T,N,M,L> : IFloat
+{
+ [TreatAsDifferentiable][__unsafeForceInlineEarly] bool lessThan(This other) { return this < other; }
+ [TreatAsDifferentiable][__unsafeForceInlineEarly] bool lessThanOrEquals(This other) { return this <= other; }
+ [TreatAsDifferentiable][__unsafeForceInlineEarly] bool equals(This other) { return all(this == other); }
+ [TreatAsDifferentiable] __intrinsic_op($(kIROp_Add)) This add(This other);
+ [TreatAsDifferentiable] __intrinsic_op($(kIROp_Sub)) This sub(This other);
+ [TreatAsDifferentiable] __intrinsic_op($(kIROp_Mul))This mul(This other);
+ [TreatAsDifferentiable] __intrinsic_op($(kIROp_Div)) This div(This other);
+ [TreatAsDifferentiable] __intrinsic_op($(kIROp_FRem)) This mod(This other);
+ [TreatAsDifferentiable] __intrinsic_op($(kIROp_Neg)) This neg();
+ [TreatAsDifferentiable] __intrinsic_op($(kIROp_Mul)) This scale<T1:__BuiltinFloatingPointType>(T1 s);
+ [TreatAsDifferentiable][__unsafeForceInlineEarly] This scale<T1:__BuiltinFloatingPointType>(T1 s);
+ [TreatAsDifferentiable][__unsafeForceInlineEarly] float toFloat() { return __realCast<float>(this[0][0]); }
+
+ [OverloadRank(-1)]
+ [TreatAsDifferentiable][__unsafeForceInlineEarly] __init(int v) { return matrix<T,N,M>(T(v)); }
+ [OverloadRank(-1)]
+ [TreatAsDifferentiable][__unsafeForceInlineEarly] __init(float v) { return matrix<T,N,M>(T(v)); }
+
+ // IDifferentiable.
+ typedef matrix<T, N,M,L> Differential;
+
+ [__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
+ static Differential dzero()
+ {
+ return matrix<T, N,M,L>(__slang_noop_cast<T>(T.dzero()));
+ }
+
+ [__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return a + b;
+ }
+
+ __generic<U : __BuiltinRealType>
+ [__unsafeForceInlineEarly]
+ [BackwardDifferentiable]
+ static Differential dmul(U a, Differential b)
+ {
+ return __realCast<T, U>(a) * b;
+ }
+}
+
${{{{
static const struct {
char const* name;
@@ -1184,7 +1428,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
auto toType = kBaseTypes[tt].name;
}}}}
-__generic<let R : int, let C : int> extension matrix<$(toType),R,C>
+__generic<let R : int, let C : int, let L : int> extension matrix<$(toType),R,C,L>
{
${{{{
for (int ff = 0; ff < kBaseTypeCount; ++ff)
@@ -1202,7 +1446,7 @@ ${{{{
}}}}
__implicit_conversion($(cost))
__intrinsic_op($(op))
- __init(matrix<$(fromType),R,C> value);
+ __init(matrix<$(fromType),R,C,L> value);
${{{{
}
}}}}
@@ -1215,61 +1459,6 @@ __generic<T, U>
__intrinsic_op(0)
T __slang_noop_cast(U u);
-__generic<T:__BuiltinFloatingPointType, let N: int>
-extension vector<T, N> : IDifferentiable
-{
- typedef vector<T, N> Differential;
-
- [__unsafeForceInlineEarly]
- [BackwardDifferentiable]
- static Differential dzero()
- {
- return Differential(__slang_noop_cast<T>(T.dzero()));
- }
-
- [__unsafeForceInlineEarly]
- [BackwardDifferentiable]
- static Differential dadd(Differential a, Differential b)
- {
- return a + b;
- }
-
- __generic<U : __BuiltinRealType>
- [__unsafeForceInlineEarly]
- [BackwardDifferentiable]
- static Differential dmul(U a, Differential b)
- {
- return __realCast<T, U>(a) * b;
- }
-}
-
-__generic<T:__BuiltinFloatingPointType, let R: int, let C: int, let L : int>
-extension matrix<T, R, C, L> : IDifferentiable
-{
- typedef matrix<T, R, C, L> Differential;
-
- [__unsafeForceInlineEarly]
- [BackwardDifferentiable]
- static Differential dzero()
- {
- return matrix<T, R, C, L>(__slang_noop_cast<T>(T.dzero()));
- }
-
- [__unsafeForceInlineEarly]
- [BackwardDifferentiable]
- static Differential dadd(Differential a, Differential b)
- {
- return a + b;
- }
-
- __generic<U : __BuiltinRealType>
- [__unsafeForceInlineEarly]
- [BackwardDifferentiable]
- static Differential dmul(U a, Differential b)
- {
- return __realCast<T, U>(a) * b;
- }
-}
//@ public:
@@ -1333,17 +1522,20 @@ for (auto op : intrinsicUnaryOps)
{
char const* resultType = "T";
if (op.flags & BOOL_RESULT) resultType = "bool";
-
+
// scalar version
sb << "__generic<T : " << op.interface << ">\n";
+ sb << "[OverloadRank(10)]";
sb << "__prefix __intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << "T value);\n";
// vector version
sb << "__generic<T : " << op.interface << ", let N : int> ";
+ sb << "[OverloadRank(10)]";
sb << "__prefix __intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(vector<T,N> value);\n";
// matrix version
sb << "__generic<T : " << op.interface << ", let N : int, let M : int> ";
+ sb << "[OverloadRank(10)]";
sb << "__prefix __intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(matrix<T,N,M> value);\n";
}
}
@@ -1369,7 +1561,7 @@ Ptr<T> operator-(Ptr<T> value, int64_t offset)
return __getElementPtr(value, -offset);
}
-__generic<T : __BuiltinArithmeticType>
+__generic<T : IArithmetic>
[__unsafeForceInlineEarly]
__prefix T operator+(T value)
{ return value; }
@@ -1532,28 +1724,35 @@ for (auto op : intrinsicBinaryOps)
// scalar version
sb << "__generic<T : " << op.interface << ">\n";
+ sb << "[OverloadRank(10)]";
sb << "__intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << leftType << " left, " << rightType << " right);\n";
// vector version
sb << "__generic<T : " << op.interface << ", let N : int> ";
+ sb << "[OverloadRank(10)]";
sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(vector<" << leftType << ",N> left, vector<" << rightType << ",N> right);\n";
// matrix version
sb << "__generic<T : " << op.interface << ", let N : int, let M : int> ";
+ sb << "[OverloadRank(10)]";
sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(matrix<" << leftType << ",N,M> left, matrix<" << rightType << ",N,M> right);\n";
// scalar-vector and scalar-matrix
sb << "__generic<T : " << op.interface << ", let N : int> ";
+ sb << "[OverloadRank(10)]";
sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(" << leftType << " left, vector<" << rightType << ",N> right);\n";
sb << "__generic<T : " << op.interface << ", let N : int, let M : int> ";
+ sb << "[OverloadRank(10)]";
sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << leftType << " left, matrix<" << rightType << ",N,M> right);\n";
// vector-scalar and matrix-scalar
sb << "__generic<T : " << op.interface << ", let N : int> ";
+ sb << "[OverloadRank(10)]";
sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(vector<" << leftType << ",N> left, " << rightType << " right);\n";
sb << "__generic<T : " << op.interface << ", let N : int, let M : int> ";
+ sb << "[OverloadRank(10)]";
sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(matrix<" << leftType << ",N,M> left, " << rightType << " right);\n";
}
}
@@ -1848,152 +2047,164 @@ bool operator!=(E left, E right);
// public interfaces for generic arithmetic types.
-interface IComparable
-{
- bool equals(This other);
- bool lessThan(This other);
- bool lessThanOrEquals(This other);
-}
-
-__attributeTarget(DeclBase)
-attribute_syntax [TreatAsDifferentiable] : TreatAsDifferentiableAttribute;
-
-[TreatAsDifferentiable]
-interface IArithmetic : IComparable
-{
- This add(This other);
- This sub(This other);
- This mul(This other);
- This div(This other);
- This mod(This other);
- This neg();
- __init(int val);
- static const This maxValue;
- static const This minValue;
-}
-
-interface IInteger : IArithmetic
-{
- This shl(int value);
- This shr(int value);
- This bitAnd(This other);
- This bitOr(This other);
- This bitXor(This other);
- This bitNot();
- int toInt();
- int64_t toInt64();
- uint toUInt();
- uint64_t toUInt64();
-}
-
-interface IFloat : IArithmetic
-{
- __init(float value);
- float toFloat();
-}
-
__generic<T : IComparable>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
bool operator<(T v0, T v1)
{
return v0.lessThan(v1);
}
__generic<T : IComparable>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
bool operator>(T v0, T v1)
{
return v1.lessThan(v0);
}
__generic<T : IComparable>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
bool operator ==(T v0, T v1)
{
return v0.equals(v1);
}
__generic<T : IComparable>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
bool operator >=(T v0, T v1)
{
return v1.lessThan(v1);
}
__generic<T : IComparable>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
bool operator <=(T v0, T v1)
{
return v0.lessThanOrEquals(v1);
}
__generic<T : IComparable>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
bool operator !=(T v0, T v1)
{
return !v0.equals(v1);
}
-__generic<T : IArithmetic>
+${{{{
+const char* arithmeticInterfaces[] = {"IArithmetic", "IFloat"};
+const char* attribs[] = {"", "[TreatAsDifferentiable]"};
+for (Index i = 0; i < 2; i++) {
+ const auto interfaceName = arithmeticInterfaces[i];
+ const auto attrib = attribs[i];
+ Index overloadRank = i - 3;
+}}}}
+$(attrib)
+__generic<T : $(interfaceName)>
[__unsafeForceInlineEarly]
+[OverloadRank($(overloadRank))]
T operator +(T v0, T v1)
{
return v0.add(v1);
}
-__generic<T : IArithmetic>
+$(attrib)
+__generic<T : $(interfaceName)>
[__unsafeForceInlineEarly]
+[OverloadRank($(overloadRank))]
T operator -(T v0, T v1)
{
return v0.sub(v1);
}
-__generic<T : IArithmetic>
+$(attrib)
+__generic<T : $(interfaceName)>
[__unsafeForceInlineEarly]
+[OverloadRank($(overloadRank))]
T operator *(T v0, T v1)
{
return v0.mul(v1);
}
-__generic<T : IArithmetic>
+$(attrib)
+__generic<T : $(interfaceName)>
[__unsafeForceInlineEarly]
+[OverloadRank($(overloadRank))]
T operator /(T v0, T v1)
{
return v0.div(v1);
}
-__generic<T : IArithmetic>
+$(attrib)
+__generic<T : $(interfaceName)>
[__unsafeForceInlineEarly]
+[OverloadRank($(overloadRank))]
T operator %(T v0, T v1)
{
return v0.mod(v1);
}
-__generic<T : IArithmetic>
+$(attrib)
+__generic<T : $(interfaceName)>
[__unsafeForceInlineEarly]
+[OverloadRank($(overloadRank))]
__prefix T operator -(T v0)
{
return v0.neg();
}
-__generic<T : IInteger>
+
+${{{{
+ } // foreach ["IArithmetic", "IFloat"]
+}}}}
+
+__generic<T : ILogical>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
T operator &(T v0, T v1)
{
return v0.bitAnd(v1);
}
-__generic<T : IInteger>
+__generic<T : ILogical>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
+T operator &&(T v0, T v1)
+{
+ return v0.and(v1);
+}
+__generic<T : ILogical>
+[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
T operator |(T v0, T v1)
{
return v0.bitOr(v1);
}
-__generic<T : IInteger>
+__generic<T : ILogical>
+[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
+T operator ||(T v0, T v1)
+{
+ return v0.or(v1);
+}
+__generic<T : ILogical>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
T operator ^(T v0, T v1)
{
return v0.bitXor(v1);
}
-__generic<T : IInteger>
+__generic<T : ILogical>
[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
__prefix T operator ~(T v0)
{
return v0.bitNot();
}
+__generic<T : ILogical>
+[__unsafeForceInlineEarly]
+[OverloadRank(-10)]
+__prefix T operator !(T v0)
+{
+ return v0.not();
+}
// IR level type traits.
@@ -2089,61 +2300,6 @@ bool __isVector()
return __isVector_impl(__declVal<T>());
}
-// Provide implementations to public generic arithmetic interfaces for builtin types.
-
-${{{{
-// Code gen integer type implementations.
-
-for (int tt = 0; tt < kBaseTypeCount; ++tt)
-{
- if (kBaseTypes[tt].flags & (SINT_MASK | UINT_MASK))
- {
-}}}}
-extension $(kBaseTypes[tt].name) : IInteger
-{
- [__unsafeForceInlineEarly] bool equals(This other){return this==other;}
- [__unsafeForceInlineEarly] bool lessThan(This other){return this<other;}
- [__unsafeForceInlineEarly] bool lessThanOrEquals(This other){return this<=other;}
- [__unsafeForceInlineEarly] This add(This other) { return __add(this, other); }
- [__unsafeForceInlineEarly] This sub(This other) { return __sub(this, other); }
- [__unsafeForceInlineEarly] This mul(This other) { return __mul(this, other); }
- [__unsafeForceInlineEarly] This div(This other) { return __div(this, other); }
- [__unsafeForceInlineEarly] This mod(This other) { return __irem(this, other); }
- [__unsafeForceInlineEarly] This neg() { return __neg(this); }
- [__unsafeForceInlineEarly] This shl(int other) { return __shl(this, other); }
- [__unsafeForceInlineEarly] This shr(int other) { return __shr(this, other); }
- [__unsafeForceInlineEarly] This bitAnd(This other) { return __add(this, other); }
- [__unsafeForceInlineEarly] This bitOr(This other) { return __or(this, other); }
- [__unsafeForceInlineEarly] This bitXor(This other) { return __xor(this, other); }
- [__unsafeForceInlineEarly] This bitNot() { return __not(this); }
- [__unsafeForceInlineEarly] int toInt() { return int(this); }
- [__unsafeForceInlineEarly] int64_t toInt64() { return int64_t(this); }
- [__unsafeForceInlineEarly] uint toUInt() { return uint(this); }
- [__unsafeForceInlineEarly] uint64_t toUInt64() { return uint64_t(this); }
-}
-${{{{
- }
- else if (kBaseTypes[tt].flags & FLOAT_MASK)
- {
-}}}}
-
-extension $(kBaseTypes[tt].name) : IFloat
-{
- [__unsafeForceInlineEarly] bool lessThan(This other) { return this < other; }
- [__unsafeForceInlineEarly] bool lessThanOrEquals(This other) { return this <= other; }
- [__unsafeForceInlineEarly] bool equals(This other) { return this == other; }
- [__unsafeForceInlineEarly] This add(This other) { return __add(this, other); }
- [__unsafeForceInlineEarly] This sub(This other) { return __sub(this, other); }
- [__unsafeForceInlineEarly] This mul(This other) { return __mul(this, other); }
- [__unsafeForceInlineEarly] This div(This other) { return __div(this, other); }
- [__unsafeForceInlineEarly] This mod(This other) { return __frem(this, other); }
- [__unsafeForceInlineEarly] This neg() { return __neg(this); }
- [__unsafeForceInlineEarly] float toFloat() { return float(this); }
-}
-${{{{
- }
-}
-}}}}
// Binding Attributes
@@ -2309,6 +2465,9 @@ attribute_syntax [__unsafeForceInlineEarly] : UnsafeForceInlineEarlyAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForceInline] : ForceInlineAttribute;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [OverloadRank] : OverloadRankAttribute;
+
__attributeTarget(FuncDecl)
attribute_syntax [DllImport(modulePath: String)] : DllImportAttribute;
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index f430c5f7a..4288017fa 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1075,6 +1075,18 @@ class AnyValueSizeAttribute : public Attribute
int32_t size;
};
+ /// This is a stop-gap solution to break overload ambiguity in stdlib.
+ /// When there is a function overload ambiguity, the compiler will pick the one with higher rank
+ /// specified by this attribute. An overload without this attribute will have a rank of 0.
+ /// In the future, we should enhance our type system to take into account the "specialized"-ness
+ /// of an overload, such that `T overload1<T:IDerived>()` is more specialized than `T overload2<T:IBase>()`
+ /// and preferred during overload resolution.
+class OverloadRankAttribute : public Attribute
+{
+ SLANG_AST_CLASS(OverloadRankAttribute)
+ int32_t rank;
+};
+
/// An attribute that marks an interface for specialization use only. Any operation that triggers dynamic
/// dispatch through the interface is a compile-time error.
class SpecializeAttribute : public Attribute
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 43b73892c..93c53a975 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1482,6 +1482,9 @@ namespace Slang
// Cached dictionary for looking up satisfying values.
SLANG_UNREFLECTED RequirementDictionary m_requirementDictionary;
+
+ RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst);
+
};
struct SpecializationParam
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 5ca0af1c9..2f8b28afc 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2225,11 +2225,15 @@ namespace Slang
{
// check the base expression first
expr->functionExpr = CheckTerm(expr->functionExpr);
+
+ auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr;
+ m_treatAsDifferentiableExpr = nullptr;
// Next check the argument expressions
for (auto & arg : expr->arguments)
{
arg = CheckTerm(arg);
}
+ m_treatAsDifferentiableExpr = treatAsDifferentiableExpr;
// If we are in a differentiable function, register differential witness tables involved in
// this call.
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index e8ae28c04..569804ff4 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -349,6 +349,19 @@ namespace Slang
anyValueSizeAttr->size = int32_t(value->getValue());
}
+ else if (auto overloadRankAttr = as<OverloadRankAttribute>(attr))
+ {
+ if (attr->args.getCount() != 1)
+ {
+ return false;
+ }
+ auto rank = checkConstantIntVal(attr->args[0]);
+ if (rank == nullptr)
+ {
+ return false;
+ }
+ overloadRankAttr->rank = int32_t(rank->getValue());
+ }
else if (auto bindingAttr = as<GLSLBindingAttribute>(attr))
{
// This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding)
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index c668155df..27062fc0c 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1070,6 +1070,15 @@ namespace Slang
return 0;
}
+ int getOverloadRank(DeclRef<Decl> declRef)
+ {
+ if (!declRef.getDecl())
+ return 0;
+ if (auto attr = declRef.getDecl()->findModifier<OverloadRankAttribute>())
+ return attr->rank;
+ return 0;
+ }
+
int SemanticsVisitor::CompareOverloadCandidates(
OverloadCandidate* left,
OverloadCandidate* right)
@@ -1142,6 +1151,11 @@ namespace Slang
auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item);
if(specificityDiff)
return specificityDiff;
+
+ // If we reach here, we will attempt to use overload rank to break the ties.
+ auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef);
+ if (overloadRankDiff)
+ return overloadRankDiff;
}
return 0;
diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp
index 34b56bfef..63ca32650 100644
--- a/source/slang/slang-ir-constexpr.cpp
+++ b/source/slang/slang-ir-constexpr.cpp
@@ -56,6 +56,9 @@ bool isConstExpr(IRInst* value)
case kIROp_FloatLit:
case kIROp_BoolLit:
case kIROp_Func:
+ case kIROp_StructKey:
+ case kIROp_WitnessTable:
+ case kIROp_Generic:
return true;
default:
@@ -136,6 +139,8 @@ bool opCanBeConstExpr(IROp op)
case kIROp_GetOptionalValue:
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
+ case kIROp_LookupWitness:
+ case kIROp_Specialize:
// TODO: more cases
return true;
@@ -146,10 +151,8 @@ bool opCanBeConstExpr(IROp op)
bool opCanBeConstExprByForwardPass(IRInst* value)
{
- // TODO: realistically need to special-case `call`
- // operations here, so that we check whether the
- // callee function is fixed/known, and if it is
- // whether it has been declared as constant-foldable
+ // TODO: handle call inst here.
+
if (value->getOp() == kIROp_Param)
return false;
return opCanBeConstExpr(value->getOp());
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index b6b7823c9..d7713618e 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -317,7 +317,7 @@ struct PeepholeContext : InstPassBase
}
else
{
- changed = tryFoldElementExtractFromUpdateInst(inst);
+ changed |= tryFoldElementExtractFromUpdateInst(inst);
}
break;
case kIROp_GetElement:
@@ -382,7 +382,7 @@ struct PeepholeContext : InstPassBase
}
else
{
- changed = tryFoldElementExtractFromUpdateInst(inst);
+ changed |= tryFoldElementExtractFromUpdateInst(inst);
}
break;
case kIROp_UpdateElement:
@@ -806,7 +806,7 @@ struct PeepholeContext : InstPassBase
case kIROp_Div:
case kIROp_And:
case kIROp_Or:
- changed = tryOptimizeArithmeticInst(inst);
+ changed |= tryOptimizeArithmeticInst(inst);
break;
case kIROp_Param:
{
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index d24fd239d..ed2ce048b 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -245,6 +245,22 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
return m_obj.as<WitnessTable>();
}
+ RefPtr<WitnessTable> WitnessTable::specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst)
+ {
+ auto newBaseType = baseType->substitute(astBuilder, subst);
+ auto newWitnessedType = witnessedType->substitute(astBuilder, subst);
+ if (newBaseType == baseType && newWitnessedType == witnessedType)
+ return this;
+ RefPtr<WitnessTable> result = new WitnessTable();
+ result->baseType = as<Type>(newBaseType);
+ result->witnessedType = as<Type>(newWitnessedType);
+ for (auto requirement : m_requirements)
+ {
+ auto newRequirement = requirement.value.specialize(astBuilder, subst);
+ result->add(requirement.key, newRequirement);
+ }
+ return result;
+ }
RequirementWitness RequirementWitness::specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst)
{
@@ -256,8 +272,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
return RequirementWitness();
case RequirementWitness::Flavor::witnessTable:
- SLANG_ASSERT(!subst);
- return *this;
+ return RequirementWitness(this->getWitnessTable()->specialize(astBuilder, subst));
case RequirementWitness::Flavor::declRef:
{
diff --git a/tests/autodiff/generic-constructor.slang b/tests/autodiff/generic-constructor.slang
new file mode 100644
index 000000000..aad9824ec
--- /dev/null
+++ b/tests/autodiff/generic-constructor.slang
@@ -0,0 +1,39 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+interface IFoo : IDifferentiable
+{
+ [Differentiable]
+ __init(Differential v);
+}
+
+struct Impl : IFoo
+{
+ float x;
+
+ [Differentiable]
+ __init(Differential v)
+ {
+ x = v.x;
+ }
+}
+
+[Differentiable]
+float test(float x)
+{
+ Impl.Differential v0 = { x };
+ var v1 = Impl(v0);
+ return v1.x * v1.x;
+}
+
+[numthreads(1,1,1)]
+void computeMain(uint tid : SV_DispatchThreadID)
+{
+ var p = diffPair(3.0, 0.0);
+ bwd_diff(test)(p, 1.0);
+ outputBuffer[tid] = p.d;
+ // CHECK: 6.0
+}
diff --git a/tests/ir/loop-inversion.slang b/tests/ir/loop-inversion.slang
index 03bdcc340..7e218a62a 100644
--- a/tests/ir/loop-inversion.slang
+++ b/tests/ir/loop-inversion.slang
@@ -19,7 +19,7 @@ RWStructuredBuffer<int> outputBuffer;
// A standard loop
// CHECK-LABEL: int a_{{.*}}()
// CHECK-NOT: break;
-// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
+// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
// CHECK: [[i]] + int(1);
// CHECK: if(
// CHECK: break;
@@ -35,7 +35,7 @@ int a()
// A vanilla while loop
// CHECK-LABEL: int b_{{.*}}()
// CHECK-NOT: break;
-// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
+// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
// CHECK: [[i]] + int(1);
// CHECK: if(
// CHECK: break;
@@ -55,7 +55,7 @@ int b()
// A while loop with a break on the false branch
// CHECK-LABEL: int c_{{.*}}()
// CHECK-NOT: break;
-// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
+// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
// CHECK: [[i]] + int(1);
// CHECK: if(
// CHECK: break;
@@ -79,7 +79,7 @@ int c()
// A while loop with a break on the true branch
// CHECK-LABEL: int d_{{.*}}()
// CHECK-NOT: break;
-// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
+// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
// CHECK: [[i]] + int(1);
// CHECK: if(
// CHECK: break;
diff --git a/tests/language-feature/generics/iarray.slang b/tests/language-feature/generics/iarray.slang
index c2314f106..b66c3ab27 100644
--- a/tests/language-feature/generics/iarray.slang
+++ b/tests/language-feature/generics/iarray.slang
@@ -1,7 +1,7 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
-T sum<T:__BuiltinArithmeticType>(IArray<T> array)
+T sum<T:IFloat>(IArray<T> array)
{
T result = T(0);
for (int i = 0; i < array.getCount(); i++)
@@ -10,15 +10,7 @@ T sum<T:__BuiltinArithmeticType>(IArray<T> array)
}
return result;
}
-vector<T,N> sum<T:__BuiltinArithmeticType, let N:int>(IArray<vector<T,N>> array)
-{
- vector<T,N> result = vector<T,N>(T(0));
- for (int i = 0; i < array.getCount(); i++)
- {
- result = result + array[i];
- }
- return result;
-}
+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;