summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--prelude/slang-cuda-prelude.h88
-rw-r--r--source/slang/slang-emit-cpp.cpp100
-rw-r--r--source/slang/slang-emit-cpp.h9
-rw-r--r--source/slang/slang-emit-cuda.cpp21
-rw-r--r--source/slang/slang-emit-cuda.h2
-rw-r--r--tests/compute/half-calc.slang4
-rw-r--r--tests/compute/half-calc.slang.expected.txt8
-rw-r--r--tests/compute/half-structured-buffer.slang2
-rw-r--r--tests/compute/half-vector-calc.slang4
-rw-r--r--tests/compute/half-vector-calc.slang.expected.txt8
-rw-r--r--tests/compute/half-vector-compare.slang98
-rw-r--r--tests/compute/half-vector-compare.slang.expected.txt5
12 files changed, 279 insertions, 70 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 05b978cf6..a627cc652 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -5,6 +5,9 @@
// are passed down.
#ifdef SLANG_CUDA_ENABLE_HALF
+// We don't want half2 operators, because it will implement comparison operators that return a bool(!). We want to generate
+// those functions. Doing so means that we will have to define all the other half2 operators.
+# define __CUDA_NO_HALF2_OPERATORS__
# include <cuda_fp16.h>
#endif
@@ -155,6 +158,7 @@ union Union64
struct __half3 { __half2 xy; __half z; };
struct __half4 { __half2 xy; __half2 zw; };
+// *** convert ***
// half -> other
@@ -196,7 +200,43 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 convert___half2(const double2& v) { r
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 convert___half3(const double3& v) { return __half3{ __float22half2_rn(float2{v.x, v.y}), __float2half_rn(v.z) }; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 convert___half4(const double4& v) { return __half4{ __float22half2_rn(float2{v.x, v.y}), __float22half2_rn(float2{v.z, v.w}) }; }
-// half2
+// *** make ***
+
+// Mechanism to make half vectors
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 make___half2(__half x, __half y) { return __halves2half2(x, y); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 make___half3(__half x, __half y, __half z) { return __half3{ __halves2half2(x, y), z }; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 make___half4(__half x, __half y, __half z, __half w) { return __half4{ __halves2half2(x, y), __halves2half2(z, w)}; }
+
+// *** constructFromScalar ***
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 constructFromScalar___half2(half x) { return __half2half2(x); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 constructFromScalar___half3(half x) { return __half3{__half2half2(x), x}; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 constructFromScalar___half4(half x) { const __half2 v = __half2half2(x); return __half4{v, v}; }
+
+// *** half2 ***
+
+// half2 maths ops
+
+// NOTE! That by default these are in cuda_fp16.hpp, but we disable them, because we need to define the comparison operators
+// as we need versions that will return vector<bool>
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& lh, const __half2& rh) { return __hadd2(lh, rh); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2& lh, const __half2& rh) { return __hsub2(lh, rh); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(const __half2& lh, const __half2& rh) { return __hmul2(lh, rh); }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(const __half2& lh, const __half2& rh) { return __h2div(lh, rh); }
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator+=(__half2& lh, const __half2& rh) { lh = __hadd2(lh, rh); return lh; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator-=(__half2& lh, const __half2& rh) { lh = __hsub2(lh, rh); return lh; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator*=(__half2& lh, const __half2& rh) { lh = __hmul2(lh, rh); return lh; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator/=(__half2& lh, const __half2& rh) { lh = __h2div(lh, rh); return lh; }
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 &operator++(__half2 &h) { __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hadd2(h, one); return h; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 &operator--(__half2 &h) { __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hsub2(h, one); return h; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator++(__half2 &h, int) { __half2 ret = h; __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hadd2(h, one); return ret; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator--(__half2 &h, int) { __half2 ret = h; __half2_raw one; one.x = 0x3C00; one.y = 0x3C00; h = __hsub2(h, one); return ret; }
+
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2 &h) { return h; }
+SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2 &h) { return __hneg2(h); }
// vec op scalar
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& lh, __half rh) { return __hadd2(lh, __half2half2(rh)); }
@@ -210,16 +250,7 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(__half lh, const __half2& r
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(__half lh, const __half2& rh) { return __hmul2(__half2half2(lh), rh); }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(__half lh, const __half2& rh) { return __h2div(__half2half2(lh), rh); }
-// Mechanism to make half vectors
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 make___half2(__half x, __half y) { return __halves2half2(x, y); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 make___half3(__half x, __half y, __half z) { return __half3{ __halves2half2(x, y), z }; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 make___half4(__half x, __half y, __half z, __half w) { return __half4{ __halves2half2(x, y), __halves2half2(z, w)}; }
-
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 constructFromScalar___half2(half x) { return __half2half2(x); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 constructFromScalar___half3(half x) { return __half3{__half2half2(x), x}; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 constructFromScalar___half4(half x) { const __half2 v = __half2half2(x); return __half4{v, v}; }
-
-// Half3 maths ops
+// *** half3 ***
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator+(const __half3& lh, const __half3& rh) { return __half3{__hadd2(lh.xy, rh.xy), __hadd(lh.z, rh.z)}; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(const __half3& lh, const __half3& rh) { return __half3{__hsub2(lh.xy, rh.xy), __hsub(lh.z, rh.z)}; }
@@ -241,18 +272,7 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator-(__half lh, const __half3& r
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator*(__half lh, const __half3& rh) { return __half3{__hmul2(__half2half2(lh), rh.xy), __hmul(lh, rh.z)}; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 operator/(__half lh, const __half3& rh) { return __half3{__h2div(__half2half2(lh), rh.xy), __hdiv(lh, rh.z)}; }
-
-#if 0
-// We need to return the vector<bool> type
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator==(const __half3& lh, const __half3& rh) { return __hbeq2(lh.xy, rh.xy) && __heq(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator!=(const __half3& lh, const __half3& rh) { return __hbneu2(lh.xy, rh.xy) && __hneu(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>(const __half3& lh, const __half3& rh) { return __hbgt2(lh.xy, rh.xy) && __hgt(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<(const __half3& lh, const __half3& rh) { return __hblt2(lh.xy, rh.xy) && __hlt(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>=(const __half3& lh, const __half3& rh) { return __hbge2(lh.xy, rh.xy) && __hge(lh.z, rh.z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<=(const __half3& lh, const __half3& rh) { return __hble2(lh.xy, rh.xy) && __hle(lh.z, rh.z); }
-#endif
-
-// Half4 maths ops
+// *** half4 ***
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator+(const __half4& lh, const __half4& rh) { return __half4{__hadd2(lh.xy, rh.xy), __hadd2(lh.zw, rh.zw)}; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(const __half4& lh, const __half4& rh) { return __half4{__hsub2(lh.xy, rh.xy), __hsub2(lh.zw, rh.zw)}; }
@@ -274,28 +294,6 @@ SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator/(__half lh, const __half4& r
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator-(const __half4& h) { return __half4{__hneg2(h.xy), __hneg2(h.zw)}; }
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 operator+(const __half4& h) { return h; }
-#if 0
-// We need to return vector<bool> type
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator==(const __half4& lh, const __half4& rh) { return __hbeq2(lh.xy, rh.xy) && __hbeq2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator!=(const __half4& lh, const __half4& rh) { return __hbneu2(lh.xy, rh.xy) && __hbneu2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>(const __half4& lh, const __half4& rh) { return __hbgt2(lh.xy, rh.xy) && __hbgt2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<(const __half4& lh, const __half4& rh) { return __hblt2(lh.xy, rh.xy) && __hblt2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator>=(const __half4& lh, const __half4& rh) { return __hbge2(lh.xy, rh.xy) && __hbge2(lh.zw, rh.zw); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL bool operator<=(const __half4& lh, const __half4& rh) { return __hble2(lh.xy, rh.xy) && __hble2(lh.zw, rh.zw); }
-#endif
-
-// Use the round nearest as the default - it is the only one defined
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __float22half2(const float2 a) { return __float22half2_rn(a); }
-
-// Implement the vector versions
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __float2half(float2 a) { return __float22half2(a); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 __float2half(float3 a) { __half3 o; o.xy = __float22half2(make_float2(a.x, a.y)); o.z = __float2half(a.z); return o; }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 __float2half(float4 a) { __half4 o; o.xy = __float22half2(make_float2(a.x, a.y)); o.zw = __float22half2(make_float2(a.z, a.w)); return o; }
-
-SLANG_FORCE_INLINE SLANG_CUDA_CALL float2 __half2float(__half2 a) { return __half22float2(a); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL float3 __half2float(__half3 a) { float2 xy = __half22float2(a.xy); float z = __half2float(a.z); return make_float3(xy.x, xy.y, z); }
-SLANG_FORCE_INLINE SLANG_CUDA_CALL float4 __half2float(__half4 a) { float2 xy = __half22float2(a.xy); float2 zw = __half22float2(a.zw); return make_float4(xy.x, xy.y, zw.x, zw.y); }
-
#endif
// ----------------------------- F32 -----------------------------------------
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index bbb974fd4..3c43485cc 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -61,7 +61,7 @@ When called we can have a structure that holds the thread local variables, and t
namespace Slang {
-static const char s_elemNames[] = "xyzw";
+static const char s_xyzwNames[] = "xyzw";
static UnownedStringSlice _getTypePrefix(IROp op)
{
@@ -219,6 +219,9 @@ void CPPSourceEmitter::emitTypeDefinition(IRType* inType)
case kIROp_VectorType:
{
auto vecType = static_cast<IRVectorType*>(type);
+
+ const UnownedStringSlice* elemNames = getVectorElementNames(vecType);
+
int count = int(getIntVal(vecType->getElementCount()));
SLANG_ASSERT(count > 0 && count < 4);
@@ -239,7 +242,7 @@ void CPPSourceEmitter::emitTypeDefinition(IRType* inType)
{
writer->emit(", ");
}
- writer->emitChar(s_elemNames[i]);
+ writer->emit(elemNames[i]);
}
writer->emit(";\n");
@@ -648,22 +651,43 @@ static IRBasicType* _getElementType(IRType* type)
case kIROp_VectorType:
{
auto vecType = static_cast<IRVectorType*>(type);
+
+ IRBasicType* elemBasicType = as<IRBasicType>(vecType->getElementType());
+ const BaseType baseType = elemBasicType->getBaseType();
+
const int elemCount = int(getIntVal(vecType->getElementCount()));
- return (!vecSwap) ? TypeDimension{1, elemCount} : TypeDimension{ elemCount, 1};
+ return (!vecSwap) ? TypeDimension{baseType, 1, elemCount} : TypeDimension{ baseType, elemCount, 1};
}
case kIROp_MatrixType:
{
auto matType = static_cast<IRMatrixType*>(type);
const int colCount = int(getIntVal(matType->getColumnCount()));
const int rowCount = int(getIntVal(matType->getRowCount()));
- return TypeDimension{rowCount, colCount};
+
+ IRBasicType* elemBasicType = as<IRBasicType>(matType->getElementType());
+ const BaseType baseType = elemBasicType->getBaseType();
+
+ return TypeDimension{baseType, rowCount, colCount};
+ }
+ default:
+ {
+ // Assume we don't know the type
+ BaseType baseType = BaseType::Void;
+
+ IRBasicType* basicType = as<IRBasicType>(type);
+ if (basicType)
+ {
+ baseType = basicType->getBaseType();
+ }
+
+ return TypeDimension{baseType, 1, 1};
}
- default: return TypeDimension{1, 1};
}
}
-/* static */void CPPSourceEmitter::_emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer)
+void CPPSourceEmitter::_emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer)
{
+
writer->emit(name);
const int comb = (dimension.colCount > 1 ? 2 : 0) | (dimension.rowCount > 1 ? 1 : 0);
switch (comb)
@@ -673,21 +697,32 @@ static IRBasicType* _getElementType(IRType* type)
break;
}
case 1:
+ {
+ // Vector, row count is biggest
+ const UnownedStringSlice* elemNames = getVectorElementNames(dimension.elemType, dimension.rowCount);
+ writer->emit(".");
+ const int index = (row > col) ? row : col;
+ writer->emit(elemNames[index]);
+ break;
+ }
case 2:
{
- // Vector
- int index = (row > col) ? row : col;
+ // Vector cols biggest dimension
+ const UnownedStringSlice* elemNames = getVectorElementNames(dimension.elemType, dimension.colCount);
writer->emit(".");
- writer->emitChar(s_elemNames[index]);
+ const int index = (row > col) ? row : col;
+ writer->emit(elemNames[index]);
break;
}
case 3:
- {
+ {
// Matrix
+ const UnownedStringSlice* elemNames = getVectorElementNames(dimension.elemType, dimension.colCount);
+
writer->emit(".rows[");
writer->emit(row);
writer->emit("].");
- writer->emitChar(s_elemNames[col]);
+ writer->emit(elemNames[col]);
break;
}
}
@@ -1158,9 +1193,11 @@ void CPPSourceEmitter::_emitInitDefinition(const UnownedStringSlice& funcName, c
{
Index paramElementCount = Index(getIntVal(paramVecType->getElementCount()));
+ const UnownedStringSlice* elemNames = getVectorElementNames(paramVecType);
+
writer->emitChar('a' + char(paramIndex));
writer->emit(".");
- writer->emitChar(s_elemNames[paramSubIndex]);
+ writer->emit(elemNames[paramSubIndex]);
paramSubIndex ++;
@@ -1348,6 +1385,11 @@ void CPPSourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const
auto swizzleInst = static_cast<IRSwizzle*>(inst);
const Index elementCount = Index(swizzleInst->getElementCount());
+ IRType* srcType = swizzleInst->getBase()->getDataType();
+ IRVectorType* srcVecType = as<IRVectorType>(srcType);
+
+ const UnownedStringSlice* elemNames = getVectorElementNames(srcVecType);
+
// TODO(JS): Not 100% sure this is correct on the parens handling front
IRType* retType = specOp->returnType;
emitType(retType);
@@ -1373,7 +1415,7 @@ void CPPSourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const
UInt elementIndex = (UInt)irConst->value.intVal;
SLANG_RELEASE_ASSERT(elementIndex < 4);
- writer->emitChar(s_elemNames[elementIndex]);
+ writer->emit(elemNames[elementIndex]);
}
writer->emit("}");
@@ -2119,6 +2161,32 @@ void CPPSourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* de
}
}
+const UnownedStringSlice* CPPSourceEmitter::getVectorElementNames(BaseType baseType, Index elemCount)
+{
+ SLANG_UNUSED(baseType);
+ SLANG_UNUSED(elemCount);
+
+ static const UnownedStringSlice elemNames[] =
+ {
+ UnownedStringSlice::fromLiteral("x"),
+ UnownedStringSlice::fromLiteral("y"),
+ UnownedStringSlice::fromLiteral("z"),
+ UnownedStringSlice::fromLiteral("w"),
+ };
+
+ return elemNames;
+}
+
+const UnownedStringSlice* CPPSourceEmitter::getVectorElementNames(IRVectorType* vectorType)
+{
+ Index elemCount = Index(getIntVal(vectorType->getElementCount()));
+
+ IRType* type = vectorType->getElementType()->getCanonicalType();
+ IRBasicType* basicType = as<IRBasicType>(type);
+ SLANG_ASSERT(basicType);
+ return getVectorElementNames(basicType->getBaseType(), elemCount);
+}
+
bool CPPSourceEmitter::_tryEmitInstExprAsIntrinsic(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
HLSLIntrinsic* specOp = m_intrinsicSet.add(inst);
@@ -2444,7 +2512,7 @@ void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroup
{
const auto& axis = axes[i];
builder.Clear();
- const char elem[2] = { s_elemNames[axis.axis], 0 };
+ const char elem[2] = { s_xyzwNames[axis.axis], 0 };
builder << "for (uint32_t " << elem << " = 0; " << elem << " < " << axis.size << "; ++" << elem << ")\n{\n";
m_writer->emit(builder);
m_writer->indent();
@@ -2478,7 +2546,7 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThread
{
const auto& axis = axes[i];
builder.Clear();
- const char elem[2] = { s_elemNames[axis.axis], 0 };
+ const char elem[2] = { s_xyzwNames[axis.axis], 0 };
builder << "for (uint32_t " << elem << " = vi.startGroupID." << elem << "; " << elem << " < vi.endGroupID." << elem << "; ++" << elem << ")\n{\n";
m_writer->emit(builder);
@@ -2511,7 +2579,7 @@ void CPPSourceEmitter::_emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupA
for (int i = 0; i < kThreadGroupAxisCount; ++i)
{
builder.Clear();
- const char elem[2] = { s_elemNames[i], 0 };
+ const char elem[2] = { s_xyzwNames[i], 0 };
builder << mulName << "." << elem << " * " << sizeAlongAxis[i];
if (addName.getLength() > 0)
{
diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h
index 2ff922421..5ba509e7d 100644
--- a/source/slang/slang-emit-cpp.h
+++ b/source/slang/slang-emit-cpp.h
@@ -33,6 +33,7 @@ public:
{
bool isScalar() const { return rowCount <= 1 && colCount <= 1; }
+ BaseType elemType;
int rowCount;
int colCount;
};
@@ -74,11 +75,15 @@ protected:
virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE;
+ virtual const UnownedStringSlice* getVectorElementNames(BaseType elemType, Index elemCount);
+
// Replaceable for classes derived from CPPSourceEmitter
virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out);
virtual SlangResult calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& out);
virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder);
+ const UnownedStringSlice* getVectorElementNames(IRVectorType* vectorType);
+
void _maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp);
void _emitForwardDeclarations(const List<EmitAction>& actions);
@@ -96,10 +101,12 @@ protected:
void _emitInOutParamType(IRType* type, String const& name, IRType* valueType);
+
UnownedStringSlice _getAndEmitSpecializedOperationDefinition(HLSLIntrinsic::Op op, IRType*const* argTypes, Int argCount, IRType* retType);
static TypeDimension _getTypeDimension(IRType* type, bool vecSwap);
- static void _emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer);
+
+ void _emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer);
UnownedStringSlice _getScalarFuncName(HLSLIntrinsic::Op operation, IRBasicType* scalarType);
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index dbe089723..09ea7ef9e 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -207,9 +207,11 @@ void CUDASourceEmitter::emitSpecializedOperationDefinition(const HLSLIntrinsic*
switch (specOp->op)
{
case Op::Init:
+
case Op::Add:
case Op::Mul:
case Op::Div:
+ case Op::Sub:
case Op::Neg:
@@ -331,6 +333,25 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target,
return Super::calcTypeName(type, target, out);
}
+const UnownedStringSlice* CUDASourceEmitter::getVectorElementNames(BaseType baseType, Index elemCount)
+{
+ static const UnownedStringSlice normal[] = { UnownedStringSlice::fromLiteral("x"), UnownedStringSlice::fromLiteral("y"), UnownedStringSlice::fromLiteral("z"), UnownedStringSlice::fromLiteral("w") };
+ static const UnownedStringSlice half3[] = { UnownedStringSlice::fromLiteral("xy.x"), UnownedStringSlice::fromLiteral("xy.y"), UnownedStringSlice::fromLiteral("z") };
+ static const UnownedStringSlice half4[] = { UnownedStringSlice::fromLiteral("xy.x"), UnownedStringSlice::fromLiteral("xy.y"), UnownedStringSlice::fromLiteral("zw.x"), UnownedStringSlice::fromLiteral("zw.y")};
+
+ if (baseType == BaseType::Half)
+ {
+ switch (elemCount)
+ {
+ default: break;
+ case 3: return half3;
+ case 4: return half4;
+ }
+ }
+
+ return normal;
+}
+
void CUDASourceEmitter::emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling)
{
Super::emitLayoutSemanticsImpl(inst, uniformSemanticSpelling);
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index b73948525..d91ea504c 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -77,6 +77,8 @@ protected:
virtual void emitFunctionPreambleImpl(IRInst* inst) SLANG_OVERRIDE;
virtual String generateEntryPointNameImpl(IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;
+ virtual const UnownedStringSlice* getVectorElementNames(BaseType baseType, Index elemCount) SLANG_OVERRIDE;
+
virtual void emitGlobalRTTISymbolPrefix() SLANG_OVERRIDE;
virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE;
diff --git a/tests/compute/half-calc.slang b/tests/compute/half-calc.slang
index e0dd01315..0f321ef98 100644
--- a/tests/compute/half-calc.slang
+++ b/tests/compute/half-calc.slang
@@ -29,5 +29,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
v += half(1.0f);
v += offset;
+ v++;
+ --v;
+ v--;
+
outputBuffer[tid] = v;
} \ No newline at end of file
diff --git a/tests/compute/half-calc.slang.expected.txt b/tests/compute/half-calc.slang.expected.txt
index 2915a0dbc..389e44adf 100644
--- a/tests/compute/half-calc.slang.expected.txt
+++ b/tests/compute/half-calc.slang.expected.txt
@@ -1,4 +1,4 @@
-3F800000
-40800000
-40E00000
-41200000
+0
+40400000
+40C00000
+41100000
diff --git a/tests/compute/half-structured-buffer.slang b/tests/compute/half-structured-buffer.slang
index db0837d53..e701bb0fa 100644
--- a/tests/compute/half-structured-buffer.slang
+++ b/tests/compute/half-structured-buffer.slang
@@ -1,4 +1,6 @@
//TEST(compute):COMPARE_COMPUTE:-vk -compute -profile cs_6_2 -render-features half -shaderobj
+//TEST(compute):COMPARE_COMPUTE:-cuda -compute -render-features half -shaderobj
+
//Disable on Dx12 for now - because writing to structured buffer produces unexpected results
//TEST_DISABLED(compute):COMPARE_COMPUTE:-dx12 -compute -use-dxil -profile cs_6_2 -render-features half -shaderobj
diff --git a/tests/compute/half-vector-calc.slang b/tests/compute/half-vector-calc.slang
index 3ae204796..b145e27ec 100644
--- a/tests/compute/half-vector-calc.slang
+++ b/tests/compute/half-vector-calc.slang
@@ -23,6 +23,10 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
v1 += v2.wzy;
v2 += v0.xyxy;
+ v1 ++;
+ --v2;
+ v3++;
+
// Unary
v2 = +v2.yxwz;
v2 = -v2.zwxy;
diff --git a/tests/compute/half-vector-calc.slang.expected.txt b/tests/compute/half-vector-calc.slang.expected.txt
index 49c339529..2e80e4e2a 100644
--- a/tests/compute/half-vector-calc.slang.expected.txt
+++ b/tests/compute/half-vector-calc.slang.expected.txt
@@ -1,5 +1,5 @@
type: float
-30.000000
-161.500000
-492.000000
-1021.500000
+73.000000
+206.500000
+539.000000
+1070.000000
diff --git a/tests/compute/half-vector-compare.slang b/tests/compute/half-vector-compare.slang
new file mode 100644
index 000000000..5f4670456
--- /dev/null
+++ b/tests/compute/half-vector-compare.slang
@@ -0,0 +1,98 @@
+//DISABLE_TEST(compute):COMPARE_COMPUTE:-dx12 -compute -output-using-type -use-dxil -profile cs_6_2 -render-features half -shaderobj
+//TEST(compute):COMPARE_COMPUTE:-vk -compute -output-using-type -profile cs_6_2 -render-features half -shaderobj
+//TEST(compute):COMPARE_COMPUTE:-cuda -compute -output-using-type -render-features half -shaderobj
+
+// Test for doing a calculation using half
+
+//TEST_INPUT:ubuffer(data=[0.2 10.0 12.0 16.0], stride=4):name=inputBuffer
+RWStructuredBuffer<int> inputBuffer;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct Values
+{
+ __init(int index)
+ {
+ m_index = index;
+ }
+
+ [mutating] half next()
+ {
+ float v = inputBuffer[m_index & 3];
+ m_index++;
+ return half(v);
+ }
+
+ int m_index = 0;
+};
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+
+ Values values = Values(int(tid));
+
+ int r = 0;
+
+ half s0 = values.next();
+ half s1 = values.next();
+
+ if (s0 < s1)
+ {
+ r += 0x1;
+ }
+
+ half2 h2_0 = half2(values.next(), values.next());
+ half2 h2_1 = half2(values.next(), values.next());
+
+ if (any(h2_0 < h2_1))
+ {
+ r += 0x2;
+ }
+
+ if (all(h2_0 < h2_1))
+ {
+ r += 0x4;
+ }
+
+ half3 h3_0 = half3(values.next(), values.next(), values.next());
+ half3 h3_1 = half3(values.next(), values.next(), values.next());
+
+ if (any(h3_0 > h3_1))
+ {
+ r += 0x8;
+ }
+
+ if (all(h3_0 <= h3_1))
+ {
+ r += 0x10;
+ }
+
+ half4 h4_0 = half4(values.next(), values.next(), values.next(), values.next());
+ half4 h4_1 = half4(values.next(), values.next(), values.next(), values.next());
+
+
+ if (any(h4_0 > h4_1))
+ {
+ r += 0x8;
+ }
+
+ if (all(h4_0 <= h4_1))
+ {
+ r += 0x10;
+ }
+
+ if (any(!(h4_0 == h4_1)))
+ {
+ r += 0x20;
+ }
+
+ if (all(h4_0 != h4_1))
+ {
+ r += 0x40;
+ }
+
+ outputBuffer[tid] = r;
+}
diff --git a/tests/compute/half-vector-compare.slang.expected.txt b/tests/compute/half-vector-compare.slang.expected.txt
new file mode 100644
index 000000000..51c83b301
--- /dev/null
+++ b/tests/compute/half-vector-compare.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+32.000000
+32.000000
+32.000000
+32.000000