summaryrefslogtreecommitdiff
path: root/source/slang/slang-emit-cuda.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
-rw-r--r--source/slang/slang-emit-cuda.cpp149
1 files changed, 149 insertions, 0 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index 702543fc8..ac1e1ea63 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -349,6 +349,138 @@ void CUDASourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* d
}
}
+static bool _areEquivalent(IRType* a, IRType* b)
+{
+ if (a == b)
+ {
+ return true;
+ }
+ if (a->op != b->op)
+ {
+ return false;
+ }
+
+ switch (a->op)
+ {
+ case kIROp_VectorType:
+ {
+ IRVectorType* vecA = static_cast<IRVectorType*>(a);
+ IRVectorType* vecB = static_cast<IRVectorType*>(b);
+ return GetIntVal(vecA->getElementCount()) == GetIntVal(vecB->getElementCount()) &&
+ _areEquivalent(vecA->getElementType(), vecB->getElementType());
+ }
+ case kIROp_MatrixType:
+ {
+ IRMatrixType* matA = static_cast<IRMatrixType*>(a);
+ IRMatrixType* matB = static_cast<IRMatrixType*>(b);
+ return GetIntVal(matA->getColumnCount()) == GetIntVal(matB->getColumnCount()) &&
+ GetIntVal(matA->getRowCount()) == GetIntVal(matB->getRowCount()) &&
+ _areEquivalent(matA->getElementType(), matB->getElementType());
+ }
+ default:
+ {
+ return as<IRBasicType>(a) != nullptr;
+ }
+ }
+}
+
+void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value)
+{
+ // When constructing a matrix or vector from a single value this is handled by the default path
+
+ switch (value->op)
+ {
+ case kIROp_Construct:
+ case kIROp_MakeMatrix:
+ case kIROp_makeVector:
+ {
+ IRType* type = value->getDataType();
+
+ // If the types are the same, we can can just break down and use
+ if (_areEquivalent(dstType, type))
+ {
+ if (auto vecType = as<IRVectorType>(type))
+ {
+ if (UInt(GetIntVal(vecType->getElementCount())) == value->getOperandCount())
+ {
+ _emitInitializerList(vecType->getElementType(), value->getOperands(), value->getOperandCount());
+ return;
+ }
+ }
+ else if (auto matType = as<IRMatrixType>(type))
+ {
+ const Index colCount = Index(GetIntVal(matType->getColumnCount()));
+ const Index rowCount = Index(GetIntVal(matType->getRowCount()));
+
+ // TODO(JS): If num cols = 1, then it *doesn't* actually return a vector.
+ // That could be argued is an error because we want swizzling or [] to work.
+ IRType* rowType = m_typeSet.addVectorType(matType->getElementType(), int(colCount));
+ IRVectorType* rowVectorType = as<IRVectorType>(rowType);
+ const Index operandCount = Index(value->getOperandCount());
+
+ // Can init, with vectors.
+ // For now special case if the rowVectorType is not actually a vector (when elementSize == 1)
+ if (operandCount == rowCount || rowVectorType == nullptr)
+ {
+ // We have to output vectors
+
+ // Emit the braces for the Matrix struct, contains an row array.
+ m_writer->emit("{\n");
+ m_writer->indent();
+ _emitInitializerList(rowType, value->getOperands(), rowCount);
+ m_writer->dedent();
+ m_writer->emit("\n}");
+ return;
+ }
+ else if (operandCount == rowCount * colCount)
+ {
+ // Handle if all are explicitly defined
+ IRType* elementType = matType->getElementType();
+ IRUse* operands = value->getOperands();
+
+ // Emit the braces for the Matrix struct, and the array of rows
+ m_writer->emit("{\n");
+ m_writer->indent();
+ m_writer->emit("{\n");
+ m_writer->indent();
+ for (Index i = 0; i < rowCount; ++i)
+ {
+ if (i != 0) m_writer->emit(", ");
+ _emitInitializerList(elementType, operands, colCount);
+ operands += colCount;
+ }
+ m_writer->dedent();
+ m_writer->emit("\n}");
+ m_writer->dedent();
+ m_writer->emit("\n}");
+ return;
+ }
+ }
+ }
+
+ break;
+ }
+ }
+
+ // All other cases we just use the default emitting - might not work on arrays defined in global scope on CUDA though
+ emitOperand(value, getInfo(EmitOp::General));
+}
+
+void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount)
+{
+ m_writer->emit("{\n");
+ m_writer->indent();
+
+ for (Index i = 0; i < operandCount; ++i)
+ {
+ if (i != 0) m_writer->emit(", ");
+ _emitInitializerListValue(elementType, operands[i].get());
+ }
+
+ m_writer->dedent();
+ m_writer->emit("\n}");
+}
+
bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
switch(inst->op)
@@ -369,6 +501,23 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
}
break;
}
+ case kIROp_makeArray:
+ {
+ IRType* dataType = inst->getDataType();
+ IRArrayType* arrayType = as<IRArrayType>(dataType);
+
+ IRType* elementType = arrayType->getElementType();
+
+ // Emit braces for the FixedArray struct.
+ m_writer->emit("{\n");
+ m_writer->indent();
+
+ _emitInitializerList(elementType, inst->getOperands(), Index(inst->getOperandCount()));
+
+ m_writer->dedent();
+ m_writer->emit("\n}");
+ return true;
+ }
default: break;
}