From 011d4281647e3a2a3cf0dbdda1fa65cc1b8ed881 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 10 Nov 2023 13:55:14 -0800 Subject: Cleanup builtin arithmetic interfaces. (#3317) * wip: clean up IArithmetic * wip. * Cleanup builtin arithmetic interfaces. * Fix. * Fixes. * Fix. * Fix. * Fix. --------- Co-authored-by: Yong He --- source/slang/core.meta.slang | 579 +++++++++++++++++++++------------ source/slang/slang-ast-modifier.h | 12 + source/slang/slang-ast-support-types.h | 3 + source/slang/slang-check-expr.cpp | 4 + source/slang/slang-check-modifier.cpp | 13 + source/slang/slang-check-overload.cpp | 14 + source/slang/slang-ir-constexpr.cpp | 11 +- source/slang/slang-ir-peephole.cpp | 6 +- source/slang/slang-syntax.cpp | 19 +- 9 files changed, 442 insertions(+), 219 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 0f600d7c6..4dff9888d 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -74,6 +74,91 @@ syntax snorm : SNormModifier; /// syntax __extern_cpp : ExternCppModifier; +interface IComparable +{ + bool equals(This other); + bool lessThan(This other); + bool lessThanOrEquals(This other); +} + +interface IRangedValue +{ + static const This maxValue; + static const This minValue; +} + +__attributeTarget(DeclBase) +attribute_syntax [TreatAsDifferentiable] : TreatAsDifferentiableAttribute; + +interface IArithmetic : IComparable +{ + This add(This other); + This sub(This other); + This mul(This other); + This div(This other); + This mod(This other); + This neg(); + + __init(int val); + + /// Initialize from the same type. + __init(This value); +} + +interface ILogical : IComparable +{ + This shl(int value); + This shr(int value); + This bitAnd(This other); + This bitOr(This other); + This bitXor(This other); + This bitNot(); + This and(This other); + This or(This other); + This not(); + __init(int val); +} + +interface IInteger : IArithmetic, ILogical +{ + int toInt(); + int64_t toInt64(); + uint toUInt(); + uint64_t toUInt64(); +} + +interface IFloat : IArithmetic, IDifferentiable +{ + [TreatAsDifferentiable] + __init(float value); + + [TreatAsDifferentiable] + float toFloat(); + + [TreatAsDifferentiable] + This add(This other); + + [TreatAsDifferentiable] + This sub(This other); + + [TreatAsDifferentiable] + This mul(This other); + + [TreatAsDifferentiable] + This div(This other); + + [TreatAsDifferentiable] + This mod(This other); + + [TreatAsDifferentiable] + This neg(); + + [TreatAsDifferentiable] + __init(This value); + + [TreatAsDifferentiable] + This scale(T scale); +} /// A type that can be used as an operand for builtins [sealed] @@ -83,22 +168,15 @@ interface __BuiltinType {} /// A type that can be used for arithmetic operations [sealed] [builtin] -interface __BuiltinArithmeticType : __BuiltinType +interface __BuiltinArithmeticType : __BuiltinType, IArithmetic { - /// Initialize from a 32-bit signed integer value. - __init(int value); - - /// Initialize from the same type. - __init(This value); } /// A type that can be used for logical/bitwise operations [sealed] [builtin] -interface __BuiltinLogicalType : __BuiltinType +interface __BuiltinLogicalType : __BuiltinType, ILogical { - /// Initialize from a 32-bit signed integer value. - __init(int value); } /// A type that logically has a sign (positive/negative/zero) @@ -243,12 +321,10 @@ struct DifferentialPair : IDifferentiable [sealed] [builtin] [TreatAsDifferentiable] -interface __BuiltinFloatingPointType : __BuiltinRealType, IDifferentiable +interface __BuiltinFloatingPointType : __BuiltinRealType, IFloat { - /// Initialize from a 32-bit floating-point value. - __init(float value); - /// Get the value of the mathematical constant pi in this type. + [Differentiable] static This getPi(); } @@ -365,7 +441,8 @@ __generic __intrinsic_op(select) vector select(vector(U val); - +__intrinsic_op($(kIROp_IntCast)) + T __intCast(U val); ${{{{ // We are going to use code generation to produce the // declarations for all of our base types. @@ -446,8 +523,12 @@ ${{{{ if (kBaseTypes[tt].tag == BaseType::Double && kBaseTypes[ss].tag == BaseType::Float) builtinConversionKind = kBuiltinConversion_FloatToDouble; -}}}} + const char* attrib = ""; + if ((kBaseTypes[tt].flags & kBaseTypes[ss].flags & FLOAT_MASK) != 0) + attrib = "[TreatAsDifferentiable]"; +}}}} + $(attrib) __intrinsic_op($(intrinsicOpCode)) __implicit_conversion($(conversionCost), $(builtinConversionKind)) __init($(kBaseTypes[ss].name) value); @@ -455,23 +536,71 @@ ${{{{ ${{{{ } - // If this is a basic integer type, then define explicit - // initializers that take a value of an `enum` type. - // - // TODO: This should actually be restricted, so that this - // only applies `where T.__Tag == Self`, but we don't have - // the needed features in our type system to implement - // that constraint right now. - // + // Integer type implementations. switch (kBaseTypes[tt].tag) { - // TODO: should this cover the full gamut of integer types? - case BaseType::Int: + case BaseType::Bool: +}}}} + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Eql)) bool equals(This other); + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Less)) bool lessThan(This other); + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Leq)) bool lessThanOrEquals(This other); + [__unsafeForceInlineEarly] This shl(int other) { return __intCast(__shl(__intCast(this), other)); } + [__unsafeForceInlineEarly] This shr(int other) { return __intCast(__shr(__intCast(this), other)); } + [__unsafeForceInlineEarly] This bitAnd(This other) { return __intCast(__and(__intCast(this), __intCast(other))); } + [__unsafeForceInlineEarly] This bitOr(This other) { return __intCast(__or(__intCast(this), __intCast(other))); } + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_And)) This and(This other); + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Or)) This or(This other); + [__unsafeForceInlineEarly] This bitXor(This other) { return __intCast(__xor(__intCast(this), __intCast(other))); } + [__unsafeForceInlineEarly] This bitNot() { return __intCast(__not(__intCast(this))); } + [__unsafeForceInlineEarly] __intrinsic_op($(kIROp_Not)) This not(); +${{{{ + break; + case BaseType::UInt8: + case BaseType::UInt16: case BaseType::UInt: + case BaseType::UInt64: + case BaseType::Int8: + case BaseType::Int16: + case BaseType::Int: + case BaseType::Int64: + case BaseType::IntPtr: + case BaseType::UIntPtr: }}}} + // If this is a basic integer type, then define explicit + // initializers that take a value of an `enum` type. + // + // TODO: This should actually be restricted, so that this + // only applies `where T.__Tag == Self`, but we don't have + // the needed features in our type system to implement + // that constraint right now. + // __generic __intrinsic_op($(kIROp_IntCast)) __init(T value); + + // Implementation of the `IInteger` interface. + __intrinsic_op($(kIROp_Less)) bool lessThan(This other); + __intrinsic_op($(kIROp_Leq)) bool lessThanOrEquals(This other); + __intrinsic_op($(kIROp_Eql)) bool equals(This other); + __intrinsic_op($(kIROp_Add)) This add(This other); + __intrinsic_op($(kIROp_Sub)) This sub(This other); + __intrinsic_op($(kIROp_Mul)) This mul(This other); + __intrinsic_op($(kIROp_Div)) This div(This other); + __intrinsic_op($(kIROp_FRem)) This mod(This other); + __intrinsic_op($(kIROp_Neg)) This neg(); + __intrinsic_op($(kIROp_Lsh)) This shl(int other); + __intrinsic_op($(kIROp_Rsh)) This shr(int other); + __intrinsic_op($(kIROp_BitAnd)) This bitAnd(This other); + __intrinsic_op($(kIROp_BitOr)) This bitOr(This other); + [__unsafeForceInlineEarly] This and(This other) {return __intCast(__intCast(this) && __intCast(other)); } + [__unsafeForceInlineEarly] This or(This other) {return __intCast(__intCast(this) || __intCast(other)); } + __intrinsic_op($(kIROp_BitXor)) This bitXor(This other); + __intrinsic_op($(kIROp_BitNot)) This bitNot(); + [__unsafeForceInlineEarly] This not() {return __intCast(!__intCast(this)); } + __intrinsic_op($(kIROp_IntCast)) int toInt(); + __intrinsic_op($(kIROp_IntCast)) int64_t toInt64(); + __intrinsic_op($(kIROp_IntCast)) uint toUInt(); + __intrinsic_op($(kIROp_IntCast)) uint64_t toUInt64(); ${{{{ break; @@ -480,8 +609,8 @@ ${{{{ } // If this is a floating-point type, then we need to - // define the basic `getPi()` function that is used - // to implement generic versions of `degrees()` and + // implement the `IFloat` interface, which defines the basic `getPi()` + // function that is used to implement generic versions of `degrees()` and // `radians()`. // switch (kBaseTypes[tt].tag) @@ -492,8 +621,20 @@ ${{{{ case BaseType::Float: case BaseType::Double: }}}} + [TreatAsDifferentiable] static $(kBaseTypes[tt].name) getPi() { return $(kBaseTypes[tt].name)(3.14159265358979323846264338328); } + __intrinsic_op($(kIROp_Less)) bool lessThan(This other); + __intrinsic_op($(kIROp_Leq)) bool lessThanOrEquals(This other); + __intrinsic_op($(kIROp_Eql)) bool equals(This other); + __intrinsic_op($(kIROp_Add)) This add(This other); + __intrinsic_op($(kIROp_Sub)) This sub(This other); + __intrinsic_op($(kIROp_Mul)) This mul(This other); + __intrinsic_op($(kIROp_Div)) This div(This other); + __intrinsic_op($(kIROp_FRem)) This mod(This other); + __intrinsic_op($(kIROp_Neg)) This neg(); + __intrinsic_op($(kIROp_FloatCast)) float toFloat(); + __intrinsic_op($(kIROp_Mul)) This scale(T s) { return __mul(this, __realCast(s)); } typedef $(kBaseTypes[tt].name) Differential; [__unsafeForceInlineEarly] @@ -628,7 +769,7 @@ __generic __intrinsic_op($(kIROp_Eql)) bool operator==(Ptr p1, Ptr p2); -extension bool +extension bool : IRangedValue { __generic __implicit_conversion($(kConversionCost_PtrToBool)) @@ -639,7 +780,7 @@ extension bool static const bool minValue = false; } -extension uint64_t +extension uint64_t : IRangedValue { __generic __intrinsic_op($(kIROp_CastPtrToInt)) @@ -649,7 +790,7 @@ extension uint64_t static const uint64_t minValue = 0; } -extension int64_t +extension int64_t : IRangedValue { __generic __intrinsic_op($(kIROp_CastPtrToInt)) @@ -659,7 +800,7 @@ extension int64_t static const int64_t minValue = -0x8000000000000000LL; } -extension intptr_t +extension intptr_t : IRangedValue { __generic __intrinsic_op($(kIROp_CastPtrToInt)) @@ -669,7 +810,7 @@ extension intptr_t static const int size = $(SLANG_PROCESSOR_X86_64?"8":"4"); } -extension uintptr_t +extension uintptr_t : IRangedValue { __generic __intrinsic_op($(kIROp_CastPtrToInt)) @@ -844,56 +985,56 @@ __intrinsic_type($(kIROp_DynamicType)) struct __Dynamic {}; -extension half +extension half : IRangedValue { static const half maxValue = half(65504); static const half minValue = half(-65504); } -extension float +extension float : IRangedValue { static const float maxValue = 340282346638528859811704183484516925440.0f; static const float minValue = -340282346638528859811704183484516925440.0f; } -extension double +extension double : IRangedValue { static const double maxValue = 179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368.0; static const double minValue = -179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368.0; } -extension int +extension int : IRangedValue { static const int maxValue = 2147483647; static const int minValue = -2147483648; } -extension uint +extension uint : IRangedValue { static const uint maxValue = 4294967295; static const uint minValue = 0; } -extension int8_t +extension int8_t : IRangedValue { static const int8_t maxValue = 127; static const int8_t minValue = -128; } -extension uint8_t +extension uint8_t : IRangedValue { static const uint8_t maxValue = 255; static const uint8_t minValue = 0; } -extension uint16_t +extension uint16_t : IRangedValue { static const uint16_t maxValue = 65535; static const uint16_t minValue = 0; } -extension int16_t +extension int16_t : IRangedValue { static const int16_t maxValue = 32767; static const int16_t minValue = -32768; @@ -909,13 +1050,15 @@ struct vector : IArray /// Initialize a vector where all elements have the same scalar `value`. + [TreatAsDifferentiable] __implicit_conversion($(kConversionCost_ScalarToVector)) __intrinsic_op($(kIROp_MakeVectorFromScalar)) __init(T value); - /// Initialize a vector from a value of the same type + /// Initialize a vector from a value of the same type // TODO: we should revise semantic checking so this kind of "identity" conversion is not required __intrinsic_op(0) + [TreatAsDifferentiable] __init(vector value); [ForceInline] @@ -933,14 +1076,115 @@ __magic_type(MatrixExpressionType) struct matrix : IArray> { __intrinsic_op($(kIROp_MakeMatrixFromScalar)) + [TreatAsDifferentiable] __init(T val); + /// Initialize a vector from a value of the same type + // TODO: we should revise semantic checking so this kind of "identity" conversion is not required + __intrinsic_op(0) + [TreatAsDifferentiable] + __init(This value); + [ForceInline] int getCount() { return R; } __subscript(int index) -> vector { __intrinsic_op($(kIROp_GetElement)) get; } } +__generic +extension vector : IFloat +{ + [__unsafeForceInlineEarly] bool lessThan(This other) { return this < other; } + [__unsafeForceInlineEarly] bool lessThanOrEquals(This other) { return this <= other; } + [__unsafeForceInlineEarly] bool equals(This other) { return all(this == other); } + __intrinsic_op($(kIROp_Add)) This add(This other); + __intrinsic_op($(kIROp_Sub)) This sub(This other); + __intrinsic_op($(kIROp_Mul)) This mul(This other); + __intrinsic_op($(kIROp_Div)) This div(This other); + __intrinsic_op($(kIROp_FRem)) This mod(This other); + __intrinsic_op($(kIROp_Neg)) This neg(); + __intrinsic_op($(kIROp_Mul)) This scale(T1 s); + [__unsafeForceInlineEarly] float toFloat() { return __realCast(this[0]); } + + [OverloadRank(-1)] + [__unsafeForceInlineEarly] __init(int v) { return vector(T(v)); } + [OverloadRank(-1)] + [__unsafeForceInlineEarly] __init(float v) { return vector(T(v)); } + + // IDifferentiable + + typedef vector Differential; + + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dzero() + { + return Differential(__slang_noop_cast(T.dzero())); + } + + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + __generic + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dmul(U a, Differential b) + { + return __realCast(a) * b; + } +} + +__generic +extension matrix : IFloat +{ + [TreatAsDifferentiable][__unsafeForceInlineEarly] bool lessThan(This other) { return this < other; } + [TreatAsDifferentiable][__unsafeForceInlineEarly] bool lessThanOrEquals(This other) { return this <= other; } + [TreatAsDifferentiable][__unsafeForceInlineEarly] bool equals(This other) { return all(this == other); } + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Add)) This add(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Sub)) This sub(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Mul))This mul(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Div)) This div(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_FRem)) This mod(This other); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Neg)) This neg(); + [TreatAsDifferentiable] __intrinsic_op($(kIROp_Mul)) This scale(T1 s); + [TreatAsDifferentiable][__unsafeForceInlineEarly] This scale(T1 s); + [TreatAsDifferentiable][__unsafeForceInlineEarly] float toFloat() { return __realCast(this[0][0]); } + + [OverloadRank(-1)] + [TreatAsDifferentiable][__unsafeForceInlineEarly] __init(int v) { return matrix(T(v)); } + [OverloadRank(-1)] + [TreatAsDifferentiable][__unsafeForceInlineEarly] __init(float v) { return matrix(T(v)); } + + // IDifferentiable. + typedef matrix Differential; + + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dzero() + { + return matrix(__slang_noop_cast(T.dzero())); + } + + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dadd(Differential a, Differential b) + { + return a + b; + } + + __generic + [__unsafeForceInlineEarly] + [BackwardDifferentiable] + static Differential dmul(U a, Differential b) + { + return __realCast(a) * b; + } +} + ${{{{ static const struct { char const* name; @@ -1184,7 +1428,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) auto toType = kBaseTypes[tt].name; }}}} -__generic extension matrix<$(toType),R,C> +__generic extension matrix<$(toType),R,C,L> { ${{{{ for (int ff = 0; ff < kBaseTypeCount; ++ff) @@ -1202,7 +1446,7 @@ ${{{{ }}}} __implicit_conversion($(cost)) __intrinsic_op($(op)) - __init(matrix<$(fromType),R,C> value); + __init(matrix<$(fromType),R,C,L> value); ${{{{ } }}}} @@ -1215,61 +1459,6 @@ __generic __intrinsic_op(0) T __slang_noop_cast(U u); -__generic -extension vector : IDifferentiable -{ - typedef vector Differential; - - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dzero() - { - return Differential(__slang_noop_cast(T.dzero())); - } - - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - - __generic - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dmul(U a, Differential b) - { - return __realCast(a) * b; - } -} - -__generic -extension matrix : IDifferentiable -{ - typedef matrix Differential; - - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dzero() - { - return matrix(__slang_noop_cast(T.dzero())); - } - - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - - __generic - [__unsafeForceInlineEarly] - [BackwardDifferentiable] - static Differential dmul(U a, Differential b) - { - return __realCast(a) * b; - } -} //@ public: @@ -1333,17 +1522,20 @@ for (auto op : intrinsicUnaryOps) { char const* resultType = "T"; if (op.flags & BOOL_RESULT) resultType = "bool"; - + // scalar version sb << "__generic\n"; + sb << "[OverloadRank(10)]"; sb << "__prefix __intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << "T value);\n"; // vector version sb << "__generic "; + sb << "[OverloadRank(10)]"; sb << "__prefix __intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(vector value);\n"; // matrix version sb << "__generic "; + sb << "[OverloadRank(10)]"; sb << "__prefix __intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(matrix value);\n"; } } @@ -1369,7 +1561,7 @@ Ptr operator-(Ptr value, int64_t offset) return __getElementPtr(value, -offset); } -__generic +__generic [__unsafeForceInlineEarly] __prefix T operator+(T value) { return value; } @@ -1532,28 +1724,35 @@ for (auto op : intrinsicBinaryOps) // scalar version sb << "__generic\n"; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << leftType << " left, " << rightType << " right);\n"; // vector version sb << "__generic "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(vector<" << leftType << ",N> left, vector<" << rightType << ",N> right);\n"; // matrix version sb << "__generic "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(matrix<" << leftType << ",N,M> left, matrix<" << rightType << ",N,M> right);\n"; // scalar-vector and scalar-matrix sb << "__generic "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(" << leftType << " left, vector<" << rightType << ",N> right);\n"; sb << "__generic "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << leftType << " left, matrix<" << rightType << ",N,M> right);\n"; // vector-scalar and matrix-scalar sb << "__generic "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(vector<" << leftType << ",N> left, " << rightType << " right);\n"; sb << "__generic "; + sb << "[OverloadRank(10)]"; sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(matrix<" << leftType << ",N,M> left, " << rightType << " right);\n"; } } @@ -1848,152 +2047,164 @@ bool operator!=(E left, E right); // public interfaces for generic arithmetic types. -interface IComparable -{ - bool equals(This other); - bool lessThan(This other); - bool lessThanOrEquals(This other); -} - -__attributeTarget(DeclBase) -attribute_syntax [TreatAsDifferentiable] : TreatAsDifferentiableAttribute; - -[TreatAsDifferentiable] -interface IArithmetic : IComparable -{ - This add(This other); - This sub(This other); - This mul(This other); - This div(This other); - This mod(This other); - This neg(); - __init(int val); - static const This maxValue; - static const This minValue; -} - -interface IInteger : IArithmetic -{ - This shl(int value); - This shr(int value); - This bitAnd(This other); - This bitOr(This other); - This bitXor(This other); - This bitNot(); - int toInt(); - int64_t toInt64(); - uint toUInt(); - uint64_t toUInt64(); -} - -interface IFloat : IArithmetic -{ - __init(float value); - float toFloat(); -} - __generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator<(T v0, T v1) { return v0.lessThan(v1); } __generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator>(T v0, T v1) { return v1.lessThan(v0); } __generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator ==(T v0, T v1) { return v0.equals(v1); } __generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator >=(T v0, T v1) { return v1.lessThan(v1); } __generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator <=(T v0, T v1) { return v0.lessThanOrEquals(v1); } __generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] bool operator !=(T v0, T v1) { return !v0.equals(v1); } -__generic +${{{{ +const char* arithmeticInterfaces[] = {"IArithmetic", "IFloat"}; +const char* attribs[] = {"", "[TreatAsDifferentiable]"}; +for (Index i = 0; i < 2; i++) { + const auto interfaceName = arithmeticInterfaces[i]; + const auto attrib = attribs[i]; + Index overloadRank = i - 3; +}}}} +$(attrib) +__generic [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator +(T v0, T v1) { return v0.add(v1); } -__generic +$(attrib) +__generic [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator -(T v0, T v1) { return v0.sub(v1); } -__generic +$(attrib) +__generic [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator *(T v0, T v1) { return v0.mul(v1); } -__generic +$(attrib) +__generic [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator /(T v0, T v1) { return v0.div(v1); } -__generic +$(attrib) +__generic [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] T operator %(T v0, T v1) { return v0.mod(v1); } -__generic +$(attrib) +__generic [__unsafeForceInlineEarly] +[OverloadRank($(overloadRank))] __prefix T operator -(T v0) { return v0.neg(); } -__generic + +${{{{ + } // foreach ["IArithmetic", "IFloat"] +}}}} + +__generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] T operator &(T v0, T v1) { return v0.bitAnd(v1); } -__generic +__generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] +T operator &&(T v0, T v1) +{ + return v0.and(v1); +} +__generic +[__unsafeForceInlineEarly] +[OverloadRank(-10)] T operator |(T v0, T v1) { return v0.bitOr(v1); } -__generic +__generic +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +T operator ||(T v0, T v1) +{ + return v0.or(v1); +} +__generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] T operator ^(T v0, T v1) { return v0.bitXor(v1); } -__generic +__generic [__unsafeForceInlineEarly] +[OverloadRank(-10)] __prefix T operator ~(T v0) { return v0.bitNot(); } +__generic +[__unsafeForceInlineEarly] +[OverloadRank(-10)] +__prefix T operator !(T v0) +{ + return v0.not(); +} // IR level type traits. @@ -2089,61 +2300,6 @@ bool __isVector() return __isVector_impl(__declVal()); } -// Provide implementations to public generic arithmetic interfaces for builtin types. - -${{{{ -// Code gen integer type implementations. - -for (int tt = 0; tt < kBaseTypeCount; ++tt) -{ - if (kBaseTypes[tt].flags & (SINT_MASK | UINT_MASK)) - { -}}}} -extension $(kBaseTypes[tt].name) : IInteger -{ - [__unsafeForceInlineEarly] bool equals(This other){return this==other;} - [__unsafeForceInlineEarly] bool lessThan(This other){return this()` is more specialized than `T overload2()` + /// and preferred during overload resolution. +class OverloadRankAttribute : public Attribute +{ + SLANG_AST_CLASS(OverloadRankAttribute) + int32_t rank; +}; + /// An attribute that marks an interface for specialization use only. Any operation that triggers dynamic /// dispatch through the interface is a compile-time error. class SpecializeAttribute : public Attribute diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 43b73892c..93c53a975 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1482,6 +1482,9 @@ namespace Slang // Cached dictionary for looking up satisfying values. SLANG_UNREFLECTED RequirementDictionary m_requirementDictionary; + + RefPtr specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst); + }; struct SpecializationParam diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 5ca0af1c9..2f8b28afc 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2225,11 +2225,15 @@ namespace Slang { // check the base expression first expr->functionExpr = CheckTerm(expr->functionExpr); + + auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr; + m_treatAsDifferentiableExpr = nullptr; // Next check the argument expressions for (auto & arg : expr->arguments) { arg = CheckTerm(arg); } + m_treatAsDifferentiableExpr = treatAsDifferentiableExpr; // If we are in a differentiable function, register differential witness tables involved in // this call. diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index e8ae28c04..569804ff4 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -349,6 +349,19 @@ namespace Slang anyValueSizeAttr->size = int32_t(value->getValue()); } + else if (auto overloadRankAttr = as(attr)) + { + if (attr->args.getCount() != 1) + { + return false; + } + auto rank = checkConstantIntVal(attr->args[0]); + if (rank == nullptr) + { + return false; + } + overloadRankAttr->rank = int32_t(rank->getValue()); + } else if (auto bindingAttr = as(attr)) { // This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding) diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index c668155df..27062fc0c 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1070,6 +1070,15 @@ namespace Slang return 0; } + int getOverloadRank(DeclRef declRef) + { + if (!declRef.getDecl()) + return 0; + if (auto attr = declRef.getDecl()->findModifier()) + return attr->rank; + return 0; + } + int SemanticsVisitor::CompareOverloadCandidates( OverloadCandidate* left, OverloadCandidate* right) @@ -1142,6 +1151,11 @@ namespace Slang auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item); if(specificityDiff) return specificityDiff; + + // If we reach here, we will attempt to use overload rank to break the ties. + auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef); + if (overloadRankDiff) + return overloadRankDiff; } return 0; diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index 34b56bfef..63ca32650 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -56,6 +56,9 @@ bool isConstExpr(IRInst* value) case kIROp_FloatLit: case kIROp_BoolLit: case kIROp_Func: + case kIROp_StructKey: + case kIROp_WitnessTable: + case kIROp_Generic: return true; default: @@ -136,6 +139,8 @@ bool opCanBeConstExpr(IROp op) case kIROp_GetOptionalValue: case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: + case kIROp_LookupWitness: + case kIROp_Specialize: // TODO: more cases return true; @@ -146,10 +151,8 @@ bool opCanBeConstExpr(IROp op) bool opCanBeConstExprByForwardPass(IRInst* value) { - // TODO: realistically need to special-case `call` - // operations here, so that we check whether the - // callee function is fixed/known, and if it is - // whether it has been declared as constant-foldable + // TODO: handle call inst here. + if (value->getOp() == kIROp_Param) return false; return opCanBeConstExpr(value->getOp()); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index b6b7823c9..d7713618e 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -317,7 +317,7 @@ struct PeepholeContext : InstPassBase } else { - changed = tryFoldElementExtractFromUpdateInst(inst); + changed |= tryFoldElementExtractFromUpdateInst(inst); } break; case kIROp_GetElement: @@ -382,7 +382,7 @@ struct PeepholeContext : InstPassBase } else { - changed = tryFoldElementExtractFromUpdateInst(inst); + changed |= tryFoldElementExtractFromUpdateInst(inst); } break; case kIROp_UpdateElement: @@ -806,7 +806,7 @@ struct PeepholeContext : InstPassBase case kIROp_Div: case kIROp_And: case kIROp_Or: - changed = tryOptimizeArithmeticInst(inst); + changed |= tryOptimizeArithmeticInst(inst); break; case kIROp_Param: { diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index d24fd239d..ed2ce048b 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -245,6 +245,22 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return m_obj.as(); } + RefPtr WitnessTable::specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst) + { + auto newBaseType = baseType->substitute(astBuilder, subst); + auto newWitnessedType = witnessedType->substitute(astBuilder, subst); + if (newBaseType == baseType && newWitnessedType == witnessedType) + return this; + RefPtr result = new WitnessTable(); + result->baseType = as(newBaseType); + result->witnessedType = as(newWitnessedType); + for (auto requirement : m_requirements) + { + auto newRequirement = requirement.value.specialize(astBuilder, subst); + result->add(requirement.key, newRequirement); + } + return result; + } RequirementWitness RequirementWitness::specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst) { @@ -256,8 +272,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return RequirementWitness(); case RequirementWitness::Flavor::witnessTable: - SLANG_ASSERT(!subst); - return *this; + return RequirementWitness(this->getWitnessTable()->specialize(astBuilder, subst)); case RequirementWitness::Flavor::declRef: { -- cgit v1.2.3