diff options
| author | Yong He <yonghe@outlook.com> | 2023-11-10 13:55:14 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-10 13:55:14 -0800 |
| commit | 011d4281647e3a2a3cf0dbdda1fa65cc1b8ed881 (patch) | |
| tree | 70f91655e86d30529eda0a683e15f378eeae2cb5 | |
| parent | bfd3f39d04047d7a46e75206cd125ed87b3f3f99 (diff) | |
Cleanup builtin arithmetic interfaces. (#3317)
* wip: clean up IArithmetic
* wip.
* Cleanup builtin arithmetic interfaces.
* Fix.
* Fixes.
* Fix.
* Fix.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/core.meta.slang | 579 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-constexpr.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 19 | ||||
| -rw-r--r-- | tests/autodiff/generic-constructor.slang | 39 | ||||
| -rw-r--r-- | tests/ir/loop-inversion.slang | 8 | ||||
| -rw-r--r-- | tests/language-feature/generics/iarray.slang | 12 |
12 files changed, 487 insertions, 233 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 0f600d7c6..4dff9888d 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -74,6 +74,91 @@ syntax snorm : SNormModifier; /// syntax __extern_cpp : ExternCppModifier; +interface IComparable +{ + bool equals(This other); + bool lessThan(This other); + bool lessThanOrEquals(This other); +} + +interface IRangedValue +{ + static const This maxValue; + static const This minValue; +} + +__attributeTarget(DeclBase) +attribute_syntax [TreatAsDifferentiable] : TreatAsDifferentiableAttribute; + +interface IArithmetic : IComparable +{ + This add(This other); + This sub(This other); + This mul(This other); + This div(This other); + This mod(This other); + This neg(); + + __init(int val); + + /// Initialize from the same type. + __init(This value); +} + +interface ILogical : IComparable +{ + This shl(int value); + This shr(int value); + This bitAnd(This other); + This bitOr(This other); + This bitXor(This other); + This bitNot(); + This and(This other); + This or(This other); + This not(); + __init(int val); +} + +interface IInteger : IArithmetic, ILogical +{ + int toInt(); + int64_t toInt64(); + uint toUInt(); + uint64_t toUInt64(); +} + +interface IFloat : IArithmetic, IDifferentiable +{ + [TreatAsDifferentiable] + __init(float value); + + [TreatAsDifferentiable] + float toFloat(); + + [TreatAsDifferentiable] + This add(This other); + + [TreatAsDifferentiable] + This sub(This other); + + [TreatAsDifferentiable] + This mul(This other); + + [TreatAsDifferentiable] + This div(This other); + + [TreatAsDifferentiable] + This mod(This other); + + [TreatAsDifferentiable] + This neg(); + + [TreatAsDifferentiable] + __init(This value); + + [TreatAsDifferentiable] + This scale<T:__BuiltinFloatingPointType>(T scale); +} /// A type that can be used as an operand for builtins [sealed] @@ -83,22 +168,15 @@ interface __BuiltinType {} /// A type that can be used for arithmetic operations [sealed] [builtin] -interface __BuiltinArithmeticType : __BuiltinType +interface __BuiltinArithmeticType : __BuiltinType, IArithmetic { - /// Initialize from a 32-bit signed integer value. - __init(int value); - - /// Initialize from the same type. - __init(This value); } /// A type that can be used for logical/bitwise operations [sealed] [builtin] -interface __BuiltinLogicalType : __BuiltinType +interface __BuiltinLogicalType : __BuiltinType, ILogical { - /// Initialize from a 32-bit signed integer value. - __init(int value); } /// A type that logically has a sign (positive/negative/zero) @@ -243,12 +321,10 @@ struct DifferentialPair : IDifferentiable [sealed] [builtin] [TreatAsDifferentiable] -interface __BuiltinFloatingPointType : __BuiltinRealType, IDifferentiable +interface __BuiltinFloatingPointType : __BuiltinRealType, IFloat { - /// Initialize from a 32-bit floating-point value. - __init(float value); - /// Get the value of the mathematical constant pi in this type. + [Differentiable] static This getPi(); } @@ -365,7 +441,8 @@ __generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool, // Allow real-number types to be cast into each other __intrinsic_op($(kIROp_FloatCast)) T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val); - +__intrinsic_op($(kIROp_IntCast)) + T __intCast<T : __BuiltinType, U : __BuiltinType>(U val); ${{{{ // We are going to use code generation to produce the // declarations for all of our base types. @@ -446,8 +523,12 @@ ${{{{ if (kBaseTypes[tt].tag == BaseType::Double && kBaseTypes[ss].tag == BaseType::Float) builtinConversionKind = kBuiltinConversion_FloatToDouble; -}}}} + const char* attrib = ""; + if ((kBaseTypes[tt].flags & kBaseTypes[ss].flags & FLOAT_MASK) != 0) + attrib = "[TreatAsDifferentiable]"; +}}}} + $(attrib) __intrinsic_op($(intrinsicOpCode)) __implicit_conversion($(conversionCost), $(builtinConversionKind)) __init($(kBaseTypes[ss].name) value); @@ -455,23 +536,71 @@ ${{{{ ${{{{ } - // If this is a basic integer type, then define explicit - // initializers that take a value of an `enum` type. - // - // TODO: This should actually be restricted, so that this - // only applies `where T.__Tag == Self`, but we don't have - // the needed features in our type system to implement - // that constraint right now. - // + // Integer type implementations. switch (kBaseTypes[tt].tag) { - // TODO: should this cover the full gamut of integer types? - case BaseType::Int: + case BaseType::Bool: +}}}} + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Eql)) bool equals(This other); + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Less)) bool lessThan(This other); + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Leq)) bool lessThanOrEquals(This other); + [__unsafeForceInlineEarly] This shl(int other) { return __intCast<This>(__shl(__intCast<int>(this), other)); } + [__unsafeForceInlineEarly] This shr(int other) { return __intCast<This>(__shr(__intCast<int>(this), other)); } + [__unsafeForceInlineEarly] This bitAnd(This other) { return __intCast<This>(__and(__intCast<int>(this), __intCast<int>(other))); } + [__unsafeForceInlineEarly] This bitOr(This other) { return __intCast<This>(__or(__intCast<int>(this), __intCast<int>(other))); } + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_And)) This and(This other); + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Or)) This or(This other); + [__unsafeForceInlineEarly] This bitXor(This other) { return __intCast<This>(__xor(__intCast<int>(this), __intCast<int>(other))); } + [__unsafeForceInlineEarly] This bitNot() { return __intCast<This>(__not(__intCast<int>(this))); } + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Not)) This not(); +${{{{ + break; + case BaseType::UInt8: + case BaseType::UInt16: case BaseType::UInt: + case BaseType::UInt64: + case BaseType::Int8: + case BaseType::Int16: + case BaseType::Int: + case BaseType::Int64: + case BaseType::IntPtr: + case BaseType::UIntPtr: }}}} + // If this is a basic integer type, then define explicit + // initializers that take a value of an `enum` type. + // + // TODO: This should actually be restricted, so that this + // only applies `where T.__Tag == Self`, but we don't have + // the needed features in our type system to implement + // that constraint right now. + // __generic<T:__EnumType> __intrinsic_op($(kIROp_IntCast)) __init(T value); + + // Implementation of the `IInteger` interface. + __intrinsic_op($(kIROp_Less)) bool lessThan(This other); + __intrinsic_op($(kIROp_Leq)) bool lessThanOrEquals(This other); + __intrinsic_op($(kIROp_Eql)) bool equals(This other); + __intrinsic_op($(kIROp_Add)) This add(This other); + __intrinsic_op($(kIROp_Sub)) This sub(This other); + __intrinsic_op($(kIROp_Mul)) This mul(This other); + __intrinsic_op($(kIROp_Div)) This div(This other); + __intrinsic_op($(kIROp_FRem)) This mod(This other); + __intrinsic_op($(kIROp_Neg)) This neg(); + __intrinsic_op($(kIROp_Lsh)) This shl(int other); + __intrinsic_op($(kIROp_Rsh)) This shr(int other); + __intrinsic_op($(kIROp_BitAnd)) This bitAnd(This other); + __intrinsic_op($(kIROp_BitOr)) This bitOr(This other); + [__unsafeForceInlineEarly] This and(This other) {return __intCast<This>(__intCast<bool>(this) && __intCast<bool>(other)); } + [__unsafeForceInlineEarly] This or(This other) {return __intCast<This>(__intCast<bool>(this) || __intCast<bool>(other)); } + __intrinsic_op($(kIROp_BitXor)) This bitXor(This other); + __intrinsic_op($(kIROp_BitNot)) This bitNot(); + [__unsafeForceInlineEarly] This not() {return __intCast<This>(!__intCast<bool>(this)); } + __intrinsic_op($(kIROp_IntCast)) int toInt(); + __intrinsic_op($(kIROp_IntCast)) int64_t toInt64(); + __intrinsic_op($(kIROp_IntCast)) uint toUInt(); + __intrinsic_op($(kIROp_IntCast)) uint64_t toUInt64(); ${{{{ break; @@ -480,8 +609,8 @@ ${{{{ } // If this is a floating-point type, then we need to - // define the basic `getPi()` function that is used - // to implement generic versions of `degrees()` and + // implement the `IFloat` interface, which defines the basic `getPi()` + // function that is used to implement generic versions of `degrees()` and // `radians()`. // switch (kBaseTypes[tt].tag) @@ -492,8 +621,20 @@ ${{{{ case BaseType::Float: case BaseType::Double: }}}} + [TreatAsDifferentiable] static $(kBaseTypes[tt].name) getPi() { return $(kBaseTypes[tt].name)(3.14159265358979323846264338328); } + __intrinsic_op($(kIROp_Less)) bool lessThan(This other); + __intrinsic_op($(kIROp_Leq)) bool lessThanOrEquals(This other); + __intrinsic_op($(kIROp_Eql)) bool equals(This other); + __intrinsic_op($(kIROp_Add)) This add(This other); + __intrinsic_op($(kIROp_Sub)) This sub(This other); + __intrinsic_op($(kIROp_Mul)) This mul(This other); + __intrinsic_op($(kIROp_Div)) This div(This other); + __intrinsic_op($(kIROp_FRem)) This mod(This other); + __intrinsic_op($(kIROp_Neg)) This neg(); + __intrinsic_op($(kIROp_FloatCast)) float toFloat(); + __intrinsic_op($(kIROp_Mul)) This scale<T:__BuiltinFloatingPointType>(T s) { return __mul(this, __realCast<This>(s)); } typedef $(kBaseTypes[tt].name) Differential; [__unsafeForceInlineEarly] @@ -628,7 +769,7 @@ __generic<T> __intrinsic_op($(kIROp_Eql)) bool operator==(Ptr<T> p1, Ptr<T> p2); -extension bool +extension bool : IRangedValue { __generic<T> __implicit_conversion($(kConversionCost_PtrToBool)) @@ -639,7 +780,7 @@ extension bool static const bool minValue = false; } -extension uint64_t +extension uint64_t : IRangedValue { __generic<T> __intrinsic_op($(kIROp_CastPtrToInt)) @@ -649,7 +790,7 @@ extension uint64_t static const uint64_t minValue = 0; } -extension int64_t +extension int64_t : IRangedValue { __generic<T> __intrinsic_op($(kIROp_CastPtrToInt)) @@ -659,7 +800,7 @@ extension int64_t static const int64_t minValue = -0x8000000000000000LL; } -extension intptr_t +extension intptr_t : IRangedValue { __generic<T> __intrinsic_op($(kIROp_CastPtrToInt)) @@ -669,7 +810,7 @@ extension intptr_t static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4"); } -extension uintptr_t +extension uintptr_t : IRangedValue { __generic<T> __intrinsic_op($(kIROp_CastPtrToInt)) @@ -844,56 +985,56 @@ __intrinsic_type($(kIROp_DynamicType)) struct __Dynamic {}; -extension half +extension half : IRangedValue { static const half maxValue = half(65504); static const half minValue = half(-65504); } -extension float +extension float : IRangedValue { static const float maxValue = 340282346638528859811704183484516925440.0f; static const float minValue = -340282346638528859811704183484516925440.0f; } -extension double +extension double : IRangedValue { static const double maxValue = 179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368.0; static const double minValue = -179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368.0; } -extension int +extension int : IRangedValue { static const int maxValue = 2147483647; static const int minValue = -2147483648; } -extension uint +extension uint : IRangedValue { static const uint maxValue = 4294967295; static const uint minValue = 0; } -extension int8_t +extension int8_t : IRangedValue { static const int8_t maxValue = 127; static const int8_t minValue = -128; } -extension uint8_t +extension uint8_t : IRangedValue { static const uint8_t maxValue = 255; static const uint8_t minValue = 0; } -extension uint16_t +extension uint16_t : IRangedValue { static const uint16_t maxValue = 65535; static const uint16_t minValue = 0; } -extension int16_t +extension int16_t : IRangedValue { static const int16_t maxValue = 32767; static const int16_t minValue = -32768; @@ -909,13 +1050,15 @@ struct vector : IArray<T> /// Initialize a vector where all elements have the same scalar `value`. + [TreatAsDifferentiable] __implicit_conversion($(kConversionCost_ScalarToVector)) __intrinsic_op($(kIROp_MakeVectorFromScalar)) __init(T value); - /// Initialize a vector from a value of the same type + /// Initialize a vector from a value of the same type // TODO: we should revise semantic checking so this kind of "identity" conversion is not required __intrinsic_op(0) + [TreatAsDifferentiable] __init(vector<T,N> value); [ForceInline] @@ -933,14 +1076,115 @@ __magic_type(MatrixExpressionType) struct matrix : IArray<vector<T,C>> { __intrinsic_op($(kIROp_MakeMatrixFromScalar)) + [TreatAsDifferentiable] __init(T val); + /// Initialize a vector from a value of the same type + // TODO: we should revise semantic checking so this kind of "identity" conversion is not required + __intrinsic_op(0) + [TreatAsDifferentiable] + __init(This value); + [ForceInline] int getCount() { return R; } __subscript(int index) -> vector<T,C> { __intrinsic_op($(kIROp_GetElement)) get; } } +__generic<T:__BuiltinFloatingPointType, let N : int> +extension vector<T,N> : IFloat +{ + [__unsafeForceInlineEarly] bool lessThan(This other) { return this < other; } + [__unsafeForceInlineEarly] bool lessThanOrEquals(This other) { return this <= other; } + [__unsafeForceInlineEarly] bool equals(This other) { return all(this == other); } + __intrinsic_op($(kIROp_Add)) This add(This other); + __intrinsic_op($(kIROp_Sub)) This sub(This other); + __intrinsic_op($(kIROp_Mul)) This mul(This other); + __intrinsic_op($(kIROp_Div)) This div(This other); + __intrinsic_op($(kIROp_FRem)) This mod(This other); + __intrinsic_op($(kIROp_Neg)) This neg(); + __intrinsic_op($(kIROp_Mul)) This scale<T1:__BuiltinFloatingPointType>(T1 s); + [__unsafeForceInlineEarly] float toFloat() { return __realCast<float>(this[0]); } + + [OverloadRank(-1)] + [__unsafeForceInlineEarly] __init(int v) { return vector<T,N>(T(v)); } + [OverloadRank(-1)] + [__unsafeForceInlineEarly] __init(float v) { return vector<T,N>(T(v)); } + + // IDifferentiable + + typedef vector<T, N> Differential; + + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dzero() + { + return Differential(__slang_noop_cast<T>(T.dzero())); + } + + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + __generic<U : __BuiltinRealType> + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dmul(U a, Differential b) + { + return __realCast<T, U>(a) * b; + } +} + +__generic<T:__BuiltinFloatingPointType, let N : int, let M : int, let L : int> +extension matrix<T,N,M,L> : IFloat +{ + [TreatAsDifferentiable][__unsafeForceInlineEarly] bool lessThan(This other) { return this < other; } + [TreatAsDifferentiable][__unsafeForceInlineEarly] bool lessThanOrEquals(This other) { return this <= other; } + [TreatAsDifferentiable][__unsafeForceInlineEarly] bool equals(This other) { return all(this == other); } + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Add)) This add(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Sub)) This sub(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Mul))This mul(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Div)) This div(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_FRem)) This mod(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Neg)) This neg(); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Mul)) This scale<T1:__BuiltinFloatingPointType>(T1 s); + [TreatAsDifferentiable][__unsafeForceInlineEarly] This scale<T1:__BuiltinFloatingPointType>(T1 s); + [TreatAsDifferentiable][__unsafeForceInlineEarly] float toFloat() { return __realCast<float>(this[0][0]); } + + [OverloadRank(-1)] + [TreatAsDifferentiable][__unsafeForceInlineEarly] __init(int v) { return matrix<T,N,M>(T(v)); } + [OverloadRank(-1)] + [TreatAsDifferentiable][__unsafeForceInlineEarly] __init(float v) { return matrix<T,N,M>(T(v)); } + + // IDifferentiable. + typedef matrix<T, N,M,L> Differential; + + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dzero() + { + return matrix<T, N,M,L>(__slang_noop_cast<T>(T.dzero())); + } + + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + __generic<U : __BuiltinRealType> + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dmul(U a, Differential b) + { + return __realCast<T, U>(a) * b; + } +} + ${{{{ static const struct { char const* name; @@ -1184,7 +1428,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) auto toType = kBaseTypes[tt].name; }}}} -__generic<let R : int, let C : int> extension matrix<$(toType),R,C> +__generic<let R : int, let C : int, let L : int> extension matrix<$(toType),R,C,L> { ${{{{ for (int ff = 0; ff < kBaseTypeCount; ++ff) @@ -1202,7 +1446,7 @@ ${{{{ }}}} __implicit_conversion($(cost)) __intrinsic_op($(op)) - __init(matrix<$(fromType),R,C> value); + __init(matrix<$(fromType),R,C,L> value); ${{{{ } }}}} @@ -1215,61 +1459,6 @@ __generic<T, U> __intrinsic_op(0) T __slang_noop_cast(U u); -__generic<T:__BuiltinFloatingPointType, let N: int> -extension vector<T, N> : IDifferentiable -{ - typedef vector<T, N> Differential; - - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dzero() - { - return Differential(__slang_noop_cast<T>(T.dzero())); - } - - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - - __generic<U : __BuiltinRealType> - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dmul(U a, Differential b) - { - return __realCast<T, U>(a) * b; - } -} - -__generic<T:__BuiltinFloatingPointType, let R: int, let C: int, let L : int> -extension matrix<T, R, C, L> : IDifferentiable -{ - typedef matrix<T, R, C, L> Differential; - - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dzero() - { - return matrix<T, R, C, L>(__slang_noop_cast<T>(T.dzero())); - } - - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - - __generic<U : __BuiltinRealType> - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dmul(U a, Differential b) - { - return __realCast<T, U>(a) * b; - } -} //@ public: @@ -1333,17 +1522,20 @@ for (auto op : intrinsicUnaryOps) { char const* resultType = "T"; if (op.flags & BOOL_RESULT) resultType = "bool"; - + // scalar version sb << "__generic<T : " << op.interface << ">\n"; + sb << "[OverloadRank(10)]"; sb << "__prefix __intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << "T value);\n"; // vector version sb << "__generic<T : " << op.interface << ", let N : int> "; + sb << "[OverloadRank(10)]"; sb << "__prefix __intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(vector<T,N> value);\n"; // matrix version sb << "__generic<T : " << op.interface << ", let N : int, let M : int> "; + sb << "[OverloadRank(10)]"; sb << "__prefix __intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(matrix<T,N,M> value);\n"; } } @@ -1369,7 +1561,7 @@ Ptr<T> operator-(Ptr<T> value, int64_t offset) return __getElementPtr(value, -offset); } -__generic<T : __BuiltinArithmeticType> +__generic<T : IArithmetic> [__unsafeForceInlineEarly] __prefix T operator+(T value) { return value; } @@ -1532,28 +1724,35 @@ for (auto op : intrinsicBinaryOps) // scalar version sb << "__generic<T : " << op.interface << ">\n"; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << leftType << " left, " << rightType << " right);\n"; // vector version sb << "__generic<T : " << op.interface << ", let N : int> "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(vector<" << leftType << ",N> left, vector<" << rightType << ",N> right);\n"; // matrix version sb << "__generic<T : " << op.interface << ", let N : int, let M : int> "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(matrix<" << leftType << ",N,M> left, matrix<" << rightType << ",N,M> right);\n"; // scalar-vector and scalar-matrix sb << "__generic<T : " << op.interface << ", let N : int> "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(" << leftType << " left, vector<" << rightType << ",N> right);\n"; sb << "__generic<T : " << op.interface << ", let N : int, let M : int> "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << leftType << " left, matrix<" << rightType << ",N,M> right);\n"; // vector-scalar and matrix-scalar sb << "__generic<T : " << op.interface << ", let N : int> "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(vector<" << leftType << ",N> left, " << rightType << " right);\n"; sb << "__generic<T : " << op.interface << ", let N : int, let M : int> "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(matrix<" << leftType << ",N,M> left, " << rightType << " right);\n"; } } @@ -1848,152 +2047,164 @@ bool operator!=(E left, E right); // public interfaces for generic arithmetic types. -interface IComparable -{ - bool equals(This other); - bool lessThan(This other); - bool lessThanOrEquals(This other); -} - -__attributeTarget(DeclBase) -attribute_syntax [TreatAsDifferentiable] : TreatAsDifferentiableAttribute; - -[TreatAsDifferentiable] -interface IArithmetic : IComparable -{ - This add(This other); - This sub(This other); - This mul(This other); - This div(This other); - This mod(This other); - This neg(); - __init(int val); - static const This maxValue; - static const This minValue; -} - -interface IInteger : IArithmetic -{ - This shl(int value); - This shr(int value); - This bitAnd(This other); - This bitOr(This other); - This bitXor(This other); - This bitNot(); - int toInt(); - int64_t toInt64(); - uint toUInt(); - uint64_t toUInt64(); -} - -interface IFloat : IArithmetic -{ - __init(float value); - float toFloat(); -} - __generic<T : IComparable> [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator<(T v0, T v1) { return v0.lessThan(v1); } __generic<T : IComparable> [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator>(T v0, T v1) { return v1.lessThan(v0); } __generic<T : IComparable> [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator ==(T v0, T v1) { return v0.equals(v1); } __generic<T : IComparable> [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator >=(T v0, T v1) { return v1.lessThan(v1); } __generic<T : IComparable> [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator <=(T v0, T v1) { return v0.lessThanOrEquals(v1); } __generic<T : IComparable> [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator !=(T v0, T v1) { return !v0.equals(v1); } -__generic<T : IArithmetic> +${{{{ +const char* arithmeticInterfaces[] = {"IArithmetic", "IFloat"}; +const char* attribs[] = {"", "[TreatAsDifferentiable]"}; +for (Index i = 0; i < 2; i++) { + const auto interfaceName = arithmeticInterfaces[i]; + const auto attrib = attribs[i]; + Index overloadRank = i - 3; +}}}} +$(attrib) +__generic<T : $(interfaceName)> [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator +(T v0, T v1) { return v0.add(v1); } -__generic<T : IArithmetic> +$(attrib) +__generic<T : $(interfaceName)> [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator -(T v0, T v1) { return v0.sub(v1); } -__generic<T : IArithmetic> +$(attrib) +__generic<T : $(interfaceName)> [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator *(T v0, T v1) { return v0.mul(v1); } -__generic<T : IArithmetic> +$(attrib) +__generic<T : $(interfaceName)> [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator /(T v0, T v1) { return v0.div(v1); } -__generic<T : IArithmetic> +$(attrib) +__generic<T : $(interfaceName)> [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator %(T v0, T v1) { return v0.mod(v1); } -__generic<T : IArithmetic> +$(attrib) +__generic<T : $(interfaceName)> [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] __prefix T operator -(T v0) { return v0.neg(); } -__generic<T : IInteger> + +${{{{ + } // foreach ["IArithmetic", "IFloat"] +}}}} + +__generic<T : ILogical> [__unsafeForceInlineEarly] +[OverloadRank(-10)] T operator &(T v0, T v1) { return v0.bitAnd(v1); } -__generic<T : IInteger> +__generic<T : ILogical> [__unsafeForceInlineEarly] +[OverloadRank(-10)] +T operator &&(T v0, T v1) +{ + return v0.and(v1); +} +__generic<T : ILogical> +[__unsafeForceInlineEarly] +[OverloadRank(-10)] T operator |(T v0, T v1) { return v0.bitOr(v1); } -__generic<T : IInteger> +__generic<T : ILogical> +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +T operator ||(T v0, T v1) +{ + return v0.or(v1); +} +__generic<T : ILogical> [__unsafeForceInlineEarly] +[OverloadRank(-10)] T operator ^(T v0, T v1) { return v0.bitXor(v1); } -__generic<T : IInteger> +__generic<T : ILogical> [__unsafeForceInlineEarly] +[OverloadRank(-10)] __prefix T operator ~(T v0) { return v0.bitNot(); } +__generic<T : ILogical> +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +__prefix T operator !(T v0) +{ + return v0.not(); +} // IR level type traits. @@ -2089,61 +2300,6 @@ bool __isVector() return __isVector_impl(__declVal<T>()); } -// Provide implementations to public generic arithmetic interfaces for builtin types. - -${{{{ -// Code gen integer type implementations. - -for (int tt = 0; tt < kBaseTypeCount; ++tt) -{ - if (kBaseTypes[tt].flags & (SINT_MASK | UINT_MASK)) - { -}}}} -extension $(kBaseTypes[tt].name) : IInteger -{ - [__unsafeForceInlineEarly] bool equals(This other){return this==other;} - [__unsafeForceInlineEarly] bool lessThan(This other){return this<other;} - [__unsafeForceInlineEarly] bool lessThanOrEquals(This other){return this<=other;} - [__unsafeForceInlineEarly] This add(This other) { return __add(this, other); } - [__unsafeForceInlineEarly] This sub(This other) { return __sub(this, other); } - [__unsafeForceInlineEarly] This mul(This other) { return __mul(this, other); } - [__unsafeForceInlineEarly] This div(This other) { return __div(this, other); } - [__unsafeForceInlineEarly] This mod(This other) { return __irem(this, other); } - [__unsafeForceInlineEarly] This neg() { return __neg(this); } - [__unsafeForceInlineEarly] This shl(int other) { return __shl(this, other); } - [__unsafeForceInlineEarly] This shr(int other) { return __shr(this, other); } - [__unsafeForceInlineEarly] This bitAnd(This other) { return __add(this, other); } - [__unsafeForceInlineEarly] This bitOr(This other) { return __or(this, other); } - [__unsafeForceInlineEarly] This bitXor(This other) { return __xor(this, other); } - [__unsafeForceInlineEarly] This bitNot() { return __not(this); } - [__unsafeForceInlineEarly] int toInt() { return int(this); } - [__unsafeForceInlineEarly] int64_t toInt64() { return int64_t(this); } - [__unsafeForceInlineEarly] uint toUInt() { return uint(this); } - [__unsafeForceInlineEarly] uint64_t toUInt64() { return uint64_t(this); } -} -${{{{ - } - else if (kBaseTypes[tt].flags & FLOAT_MASK) - { -}}}} - -extension $(kBaseTypes[tt].name) : IFloat -{ - [__unsafeForceInlineEarly] bool lessThan(This other) { return this < other; } - [__unsafeForceInlineEarly] bool lessThanOrEquals(This other) { return this <= other; } - [__unsafeForceInlineEarly] bool equals(This other) { return this == other; } - [__unsafeForceInlineEarly] This add(This other) { return __add(this, other); } - [__unsafeForceInlineEarly] This sub(This other) { return __sub(this, other); } - [__unsafeForceInlineEarly] This mul(This other) { return __mul(this, other); } - [__unsafeForceInlineEarly] This div(This other) { return __div(this, other); } - [__unsafeForceInlineEarly] This mod(This other) { return __frem(this, other); } - [__unsafeForceInlineEarly] This neg() { return __neg(this); } - [__unsafeForceInlineEarly] float toFloat() { return float(this); } -} -${{{{ - } -} -}}}} // Binding Attributes @@ -2309,6 +2465,9 @@ attribute_syntax [__unsafeForceInlineEarly] : UnsafeForceInlineEarlyAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [ForceInline] : ForceInlineAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [OverloadRank] : OverloadRankAttribute; + __attributeTarget(FuncDecl) attribute_syntax [DllImport(modulePath: String)] : DllImportAttribute; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index f430c5f7a..4288017fa 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1075,6 +1075,18 @@ class AnyValueSizeAttribute : public Attribute int32_t size; }; + /// This is a stop-gap solution to break overload ambiguity in stdlib. + /// When there is a function overload ambiguity, the compiler will pick the one with higher rank + /// specified by this attribute. An overload without this attribute will have a rank of 0. + /// In the future, we should enhance our type system to take into account the "specialized"-ness + /// of an overload, such that `T overload1<T:IDerived>()` is more specialized than `T overload2<T:IBase>()` + /// and preferred during overload resolution. +class OverloadRankAttribute : public Attribute +{ + SLANG_AST_CLASS(OverloadRankAttribute) + int32_t rank; +}; + /// An attribute that marks an interface for specialization use only. Any operation that triggers dynamic /// dispatch through the interface is a compile-time error. class SpecializeAttribute : public Attribute diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 43b73892c..93c53a975 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1482,6 +1482,9 @@ namespace Slang // Cached dictionary for looking up satisfying values. SLANG_UNREFLECTED RequirementDictionary m_requirementDictionary; + + RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); + }; struct SpecializationParam diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 5ca0af1c9..2f8b28afc 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2225,11 +2225,15 @@ namespace Slang { // check the base expression first expr->functionExpr = CheckTerm(expr->functionExpr); + + auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr; + m_treatAsDifferentiableExpr = nullptr; // Next check the argument expressions for (auto & arg : expr->arguments) { arg = CheckTerm(arg); } + m_treatAsDifferentiableExpr = treatAsDifferentiableExpr; // If we are in a differentiable function, register differential witness tables involved in // this call. diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index e8ae28c04..569804ff4 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -349,6 +349,19 @@ namespace Slang anyValueSizeAttr->size = int32_t(value->getValue()); } + else if (auto overloadRankAttr = as<OverloadRankAttribute>(attr)) + { + if (attr->args.getCount() != 1) + { + return false; + } + auto rank = checkConstantIntVal(attr->args[0]); + if (rank == nullptr) + { + return false; + } + overloadRankAttr->rank = int32_t(rank->getValue()); + } else if (auto bindingAttr = as<GLSLBindingAttribute>(attr)) { // This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding) diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index c668155df..27062fc0c 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1070,6 +1070,15 @@ namespace Slang return 0; } + int getOverloadRank(DeclRef<Decl> declRef) + { + if (!declRef.getDecl()) + return 0; + if (auto attr = declRef.getDecl()->findModifier<OverloadRankAttribute>()) + return attr->rank; + return 0; + } + int SemanticsVisitor::CompareOverloadCandidates( OverloadCandidate* left, OverloadCandidate* right) @@ -1142,6 +1151,11 @@ namespace Slang auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item); if(specificityDiff) return specificityDiff; + + // If we reach here, we will attempt to use overload rank to break the ties. + auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef); + if (overloadRankDiff) + return overloadRankDiff; } return 0; diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index 34b56bfef..63ca32650 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -56,6 +56,9 @@ bool isConstExpr(IRInst* value) case kIROp_FloatLit: case kIROp_BoolLit: case kIROp_Func: + case kIROp_StructKey: + case kIROp_WitnessTable: + case kIROp_Generic: return true; default: @@ -136,6 +139,8 @@ bool opCanBeConstExpr(IROp op) case kIROp_GetOptionalValue: case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: + case kIROp_LookupWitness: + case kIROp_Specialize: // TODO: more cases return true; @@ -146,10 +151,8 @@ bool opCanBeConstExpr(IROp op) bool opCanBeConstExprByForwardPass(IRInst* value) { - // TODO: realistically need to special-case `call` - // operations here, so that we check whether the - // callee function is fixed/known, and if it is - // whether it has been declared as constant-foldable + // TODO: handle call inst here. + if (value->getOp() == kIROp_Param) return false; return opCanBeConstExpr(value->getOp()); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index b6b7823c9..d7713618e 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -317,7 +317,7 @@ struct PeepholeContext : InstPassBase } else { - changed = tryFoldElementExtractFromUpdateInst(inst); + changed |= tryFoldElementExtractFromUpdateInst(inst); } break; case kIROp_GetElement: @@ -382,7 +382,7 @@ struct PeepholeContext : InstPassBase } else { - changed = tryFoldElementExtractFromUpdateInst(inst); + changed |= tryFoldElementExtractFromUpdateInst(inst); } break; case kIROp_UpdateElement: @@ -806,7 +806,7 @@ struct PeepholeContext : InstPassBase case kIROp_Div: case kIROp_And: case kIROp_Or: - changed = tryOptimizeArithmeticInst(inst); + changed |= tryOptimizeArithmeticInst(inst); break; case kIROp_Param: { diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index d24fd239d..ed2ce048b 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -245,6 +245,22 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return m_obj.as<WitnessTable>(); } + RefPtr<WitnessTable> WitnessTable::specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst) + { + auto newBaseType = baseType->substitute(astBuilder, subst); + auto newWitnessedType = witnessedType->substitute(astBuilder, subst); + if (newBaseType == baseType && newWitnessedType == witnessedType) + return this; + RefPtr<WitnessTable> result = new WitnessTable(); + result->baseType = as<Type>(newBaseType); + result->witnessedType = as<Type>(newWitnessedType); + for (auto requirement : m_requirements) + { + auto newRequirement = requirement.value.specialize(astBuilder, subst); + result->add(requirement.key, newRequirement); + } + return result; + } RequirementWitness RequirementWitness::specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst) { @@ -256,8 +272,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return RequirementWitness(); case RequirementWitness::Flavor::witnessTable: - SLANG_ASSERT(!subst); - return *this; + return RequirementWitness(this->getWitnessTable()->specialize(astBuilder, subst)); case RequirementWitness::Flavor::declRef: { diff --git a/tests/autodiff/generic-constructor.slang b/tests/autodiff/generic-constructor.slang new file mode 100644 index 000000000..aad9824ec --- /dev/null +++ b/tests/autodiff/generic-constructor.slang @@ -0,0 +1,39 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +interface IFoo : IDifferentiable +{ + [Differentiable] + __init(Differential v); +} + +struct Impl : IFoo +{ + float x; + + [Differentiable] + __init(Differential v) + { + x = v.x; + } +} + +[Differentiable] +float test(float x) +{ + Impl.Differential v0 = { x }; + var v1 = Impl(v0); + return v1.x * v1.x; +} + +[numthreads(1,1,1)] +void computeMain(uint tid : SV_DispatchThreadID) +{ + var p = diffPair(3.0, 0.0); + bwd_diff(test)(p, 1.0); + outputBuffer[tid] = p.d; + // CHECK: 6.0 +} diff --git a/tests/ir/loop-inversion.slang b/tests/ir/loop-inversion.slang index 03bdcc340..7e218a62a 100644 --- a/tests/ir/loop-inversion.slang +++ b/tests/ir/loop-inversion.slang @@ -19,7 +19,7 @@ RWStructuredBuffer<int> outputBuffer; // A standard loop // CHECK-LABEL: int a_{{.*}}() // CHECK-NOT: break; -// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]] +// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]] // CHECK: [[i]] + int(1); // CHECK: if( // CHECK: break; @@ -35,7 +35,7 @@ int a() // A vanilla while loop // CHECK-LABEL: int b_{{.*}}() // CHECK-NOT: break; -// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]] +// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]] // CHECK: [[i]] + int(1); // CHECK: if( // CHECK: break; @@ -55,7 +55,7 @@ int b() // A while loop with a break on the false branch // CHECK-LABEL: int c_{{.*}}() // CHECK-NOT: break; -// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]] +// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]] // CHECK: [[i]] + int(1); // CHECK: if( // CHECK: break; @@ -79,7 +79,7 @@ int c() // A while loop with a break on the true branch // CHECK-LABEL: int d_{{.*}}() // CHECK-NOT: break; -// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]] +// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]] // CHECK: [[i]] + int(1); // CHECK: if( // CHECK: break; diff --git a/tests/language-feature/generics/iarray.slang b/tests/language-feature/generics/iarray.slang index c2314f106..b66c3ab27 100644 --- a/tests/language-feature/generics/iarray.slang +++ b/tests/language-feature/generics/iarray.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type -T sum<T:__BuiltinArithmeticType>(IArray<T> array) +T sum<T:IFloat>(IArray<T> array) { T result = T(0); for (int i = 0; i < array.getCount(); i++) @@ -10,15 +10,7 @@ T sum<T:__BuiltinArithmeticType>(IArray<T> array) } return result; } -vector<T,N> sum<T:__BuiltinArithmeticType, let N:int>(IArray<vector<T,N>> array) -{ - vector<T,N> result = vector<T,N>(T(0)); - for (int i = 0; i < array.getCount(); i++) - { - result = result + array[i]; - } - return result; -} + //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; |
