summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-cpp.cpp
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2021-05-04 16:24:51 -0400
committerGitHub <noreply@github.com>2021-05-04 13:24:51 -0700
commit731f1fc6b26659dc8f62fbc1969c076b78ada24f (patch)
tree9fcc4d1d931049edeabe3cea46d0bd6942956042 /source/slang/slang-emit-cpp.cpp
parentdc571f1291f6b82b189a0db52c0468ae2fc7af4b (diff)
CUDA half comparison support (#1834)
* #include an absolute path didn't work - because paths were taken to always be relative. * Split out StringEscapeUtil. * Added StringEscapeUtil. * Fix typo in unix quoting type. * Small comment improvements. * Try to fix linux linking issue. * Fix typo. * Attempt to fix linux link issue. * Update VS proj even though nothing really changed. * Fix another typo issue. * Fix for windows issue. Fixed bug. * Make separate Utils for escaping. * Fix typo. * Split out into StringEscapeHandler. * Windows shell does handle removing quotes (so remove code to remove them). * Handle unescaping if not initiating using the shell. * Slight improvement around shell like decoding. * Simplify command extraction. * Add shared-library category type. * Fix bug in command extraction. * Typo in transcendental category. * Enable unit-test on in smoke test category. * Make parsing failing output as a failing test. * Fixes for transcendental tests. Disable tests that do not work. * Changed category parsing. * Removed the TestResult parameter from _gatherTestsForFile. Made testsList only output. * Remove testing if all tests were disabled. * Make args of CommandLine always unescaped. * Add category. * Don't need escaping on unix/linux. * Remove some no longer used functions. * Add requireSMVersion to CUDAExtensionTracker. * half-calc.slang now works for CUDA. * bit-cast-16-bit works on CUDA. * WIP handling of CUDA vector<half> types. * Half swizzle CUDA. * Half vector test. * Fix swizzle half bug. * Fix compilation issue with narrowing to Index. * Add unary ops. * Add some vector scalar maths ops. * Add half vector conversions for CUDA. * Fix erroneous comment. * Support for half comparisons. * First pass test for half compare. * Fix bug in CUDA specialized emit control. Updated tests to have pre and post inc/dec. * Removed unneeded parts of the cuda prelude. * Half structured buffer works on CUDA. Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
-rw-r--r--source/slang/slang-emit-cpp.cpp100
1 files changed, 84 insertions, 16 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index bbb974fd4..3c43485cc 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -61,7 +61,7 @@ When called we can have a structure that holds the thread local variables, and t
namespace Slang {
-static const char s_elemNames[] = "xyzw";
+static const char s_xyzwNames[] = "xyzw";
static UnownedStringSlice _getTypePrefix(IROp op)
{
@@ -219,6 +219,9 @@ void CPPSourceEmitter::emitTypeDefinition(IRType* inType)
case kIROp_VectorType:
{
auto vecType = static_cast<IRVectorType*>(type);
+
+ const UnownedStringSlice* elemNames = getVectorElementNames(vecType);
+
int count = int(getIntVal(vecType->getElementCount()));
SLANG_ASSERT(count > 0 && count < 4);
@@ -239,7 +242,7 @@ void CPPSourceEmitter::emitTypeDefinition(IRType* inType)
{
writer->emit(", ");
}
- writer->emitChar(s_elemNames[i]);
+ writer->emit(elemNames[i]);
}
writer->emit(";\n");
@@ -648,22 +651,43 @@ static IRBasicType* _getElementType(IRType* type)
case kIROp_VectorType:
{
auto vecType = static_cast<IRVectorType*>(type);
+
+ IRBasicType* elemBasicType = as<IRBasicType>(vecType->getElementType());
+ const BaseType baseType = elemBasicType->getBaseType();
+
const int elemCount = int(getIntVal(vecType->getElementCount()));
- return (!vecSwap) ? TypeDimension{1, elemCount} : TypeDimension{ elemCount, 1};
+ return (!vecSwap) ? TypeDimension{baseType, 1, elemCount} : TypeDimension{ baseType, elemCount, 1};
}
case kIROp_MatrixType:
{
auto matType = static_cast<IRMatrixType*>(type);
const int colCount = int(getIntVal(matType->getColumnCount()));
const int rowCount = int(getIntVal(matType->getRowCount()));
- return TypeDimension{rowCount, colCount};
+
+ IRBasicType* elemBasicType = as<IRBasicType>(matType->getElementType());
+ const BaseType baseType = elemBasicType->getBaseType();
+
+ return TypeDimension{baseType, rowCount, colCount};
+ }
+ default:
+ {
+ // Assume we don't know the type
+ BaseType baseType = BaseType::Void;
+
+ IRBasicType* basicType = as<IRBasicType>(type);
+ if (basicType)
+ {
+ baseType = basicType->getBaseType();
+ }
+
+ return TypeDimension{baseType, 1, 1};
}
- default: return TypeDimension{1, 1};
}
}
-/* static */void CPPSourceEmitter::_emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer)
+void CPPSourceEmitter::_emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer)
{
+
writer->emit(name);
const int comb = (dimension.colCount > 1 ? 2 : 0) | (dimension.rowCount > 1 ? 1 : 0);
switch (comb)
@@ -673,21 +697,32 @@ static IRBasicType* _getElementType(IRType* type)
break;
}
case 1:
+ {
+ // Vector, row count is biggest
+ const UnownedStringSlice* elemNames = getVectorElementNames(dimension.elemType, dimension.rowCount);
+ writer->emit(".");
+ const int index = (row > col) ? row : col;
+ writer->emit(elemNames[index]);
+ break;
+ }
case 2:
{
- // Vector
- int index = (row > col) ? row : col;
+ // Vector cols biggest dimension
+ const UnownedStringSlice* elemNames = getVectorElementNames(dimension.elemType, dimension.colCount);
writer->emit(".");
- writer->emitChar(s_elemNames[index]);
+ const int index = (row > col) ? row : col;
+ writer->emit(elemNames[index]);
break;
}
case 3:
- {
+ {
// Matrix
+ const UnownedStringSlice* elemNames = getVectorElementNames(dimension.elemType, dimension.colCount);
+
writer->emit(".rows[");
writer->emit(row);
writer->emit("].");
- writer->emitChar(s_elemNames[col]);
+ writer->emit(elemNames[col]);
break;
}
}
@@ -1158,9 +1193,11 @@ void CPPSourceEmitter::_emitInitDefinition(const UnownedStringSlice& funcName, c
{
Index paramElementCount = Index(getIntVal(paramVecType->getElementCount()));
+ const UnownedStringSlice* elemNames = getVectorElementNames(paramVecType);
+
writer->emitChar('a' + char(paramIndex));
writer->emit(".");
- writer->emitChar(s_elemNames[paramSubIndex]);
+ writer->emit(elemNames[paramSubIndex]);
paramSubIndex ++;
@@ -1348,6 +1385,11 @@ void CPPSourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const
auto swizzleInst = static_cast<IRSwizzle*>(inst);
const Index elementCount = Index(swizzleInst->getElementCount());
+ IRType* srcType = swizzleInst->getBase()->getDataType();
+ IRVectorType* srcVecType = as<IRVectorType>(srcType);
+
+ const UnownedStringSlice* elemNames = getVectorElementNames(srcVecType);
+
// TODO(JS): Not 100% sure this is correct on the parens handling front
IRType* retType = specOp->returnType;
emitType(retType);
@@ -1373,7 +1415,7 @@ void CPPSourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const
UInt elementIndex = (UInt)irConst->value.intVal;
SLANG_RELEASE_ASSERT(elementIndex < 4);
- writer->emitChar(s_elemNames[elementIndex]);
+ writer->emit(elemNames[elementIndex]);
}
writer->emit("}");
@@ -2119,6 +2161,32 @@ void CPPSourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* de
}
}
+const UnownedStringSlice* CPPSourceEmitter::getVectorElementNames(BaseType baseType, Index elemCount)
+{
+ SLANG_UNUSED(baseType);
+ SLANG_UNUSED(elemCount);
+
+ static const UnownedStringSlice elemNames[] =
+ {
+ UnownedStringSlice::fromLiteral("x"),
+ UnownedStringSlice::fromLiteral("y"),
+ UnownedStringSlice::fromLiteral("z"),
+ UnownedStringSlice::fromLiteral("w"),
+ };
+
+ return elemNames;
+}
+
+const UnownedStringSlice* CPPSourceEmitter::getVectorElementNames(IRVectorType* vectorType)
+{
+ Index elemCount = Index(getIntVal(vectorType->getElementCount()));
+
+ IRType* type = vectorType->getElementType()->getCanonicalType();
+ IRBasicType* basicType = as<IRBasicType>(type);
+ SLANG_ASSERT(basicType);
+ return getVectorElementNames(basicType->getBaseType(), elemCount);
+}
+
bool CPPSourceEmitter::_tryEmitInstExprAsIntrinsic(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
HLSLIntrinsic* specOp = m_intrinsicSet.add(inst);
@@ -2444,7 +2512,7 @@ void CPPSourceEmitter::_emitEntryPointGroup(const Int sizeAlongAxis[kThreadGroup
{
const auto& axis = axes[i];
builder.Clear();
- const char elem[2] = { s_elemNames[axis.axis], 0 };
+ const char elem[2] = { s_xyzwNames[axis.axis], 0 };
builder << "for (uint32_t " << elem << " = 0; " << elem << " < " << axis.size << "; ++" << elem << ")\n{\n";
m_writer->emit(builder);
m_writer->indent();
@@ -2478,7 +2546,7 @@ void CPPSourceEmitter::_emitEntryPointGroupRange(const Int sizeAlongAxis[kThread
{
const auto& axis = axes[i];
builder.Clear();
- const char elem[2] = { s_elemNames[axis.axis], 0 };
+ const char elem[2] = { s_xyzwNames[axis.axis], 0 };
builder << "for (uint32_t " << elem << " = vi.startGroupID." << elem << "; " << elem << " < vi.endGroupID." << elem << "; ++" << elem << ")\n{\n";
m_writer->emit(builder);
@@ -2511,7 +2579,7 @@ void CPPSourceEmitter::_emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupA
for (int i = 0; i < kThreadGroupAxisCount; ++i)
{
builder.Clear();
- const char elem[2] = { s_elemNames[i], 0 };
+ const char elem[2] = { s_xyzwNames[i], 0 };
builder << mulName << "." << elem << " * " << sizeAlongAxis[i];
if (addName.getLength() > 0)
{