diff options
Diffstat (limited to 'source/slang/slang-emit-cuda.cpp')
| -rw-r--r-- | source/slang/slang-emit-cuda.cpp | 149 |
1 files changed, 149 insertions, 0 deletions
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 702543fc8..ac1e1ea63 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -349,6 +349,138 @@ void CUDASourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* d } } +static bool _areEquivalent(IRType* a, IRType* b) +{ + if (a == b) + { + return true; + } + if (a->op != b->op) + { + return false; + } + + switch (a->op) + { + case kIROp_VectorType: + { + IRVectorType* vecA = static_cast<IRVectorType*>(a); + IRVectorType* vecB = static_cast<IRVectorType*>(b); + return GetIntVal(vecA->getElementCount()) == GetIntVal(vecB->getElementCount()) && + _areEquivalent(vecA->getElementType(), vecB->getElementType()); + } + case kIROp_MatrixType: + { + IRMatrixType* matA = static_cast<IRMatrixType*>(a); + IRMatrixType* matB = static_cast<IRMatrixType*>(b); + return GetIntVal(matA->getColumnCount()) == GetIntVal(matB->getColumnCount()) && + GetIntVal(matA->getRowCount()) == GetIntVal(matB->getRowCount()) && + _areEquivalent(matA->getElementType(), matB->getElementType()); + } + default: + { + return as<IRBasicType>(a) != nullptr; + } + } +} + +void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value) +{ + // When constructing a matrix or vector from a single value this is handled by the default path + + switch (value->op) + { + case kIROp_Construct: + case kIROp_MakeMatrix: + case kIROp_makeVector: + { + IRType* type = value->getDataType(); + + // If the types are the same, we can can just break down and use + if (_areEquivalent(dstType, type)) + { + if (auto vecType = as<IRVectorType>(type)) + { + if (UInt(GetIntVal(vecType->getElementCount())) == value->getOperandCount()) + { + _emitInitializerList(vecType->getElementType(), value->getOperands(), value->getOperandCount()); + return; + } + } + else if (auto matType = as<IRMatrixType>(type)) + { + const Index colCount = Index(GetIntVal(matType->getColumnCount())); + const Index rowCount = Index(GetIntVal(matType->getRowCount())); + + // TODO(JS): If num cols = 1, then it *doesn't* actually return a vector. + // That could be argued is an error because we want swizzling or [] to work. + IRType* rowType = m_typeSet.addVectorType(matType->getElementType(), int(colCount)); + IRVectorType* rowVectorType = as<IRVectorType>(rowType); + const Index operandCount = Index(value->getOperandCount()); + + // Can init, with vectors. + // For now special case if the rowVectorType is not actually a vector (when elementSize == 1) + if (operandCount == rowCount || rowVectorType == nullptr) + { + // We have to output vectors + + // Emit the braces for the Matrix struct, contains an row array. + m_writer->emit("{\n"); + m_writer->indent(); + _emitInitializerList(rowType, value->getOperands(), rowCount); + m_writer->dedent(); + m_writer->emit("\n}"); + return; + } + else if (operandCount == rowCount * colCount) + { + // Handle if all are explicitly defined + IRType* elementType = matType->getElementType(); + IRUse* operands = value->getOperands(); + + // Emit the braces for the Matrix struct, and the array of rows + m_writer->emit("{\n"); + m_writer->indent(); + m_writer->emit("{\n"); + m_writer->indent(); + for (Index i = 0; i < rowCount; ++i) + { + if (i != 0) m_writer->emit(", "); + _emitInitializerList(elementType, operands, colCount); + operands += colCount; + } + m_writer->dedent(); + m_writer->emit("\n}"); + m_writer->dedent(); + m_writer->emit("\n}"); + return; + } + } + } + + break; + } + } + + // All other cases we just use the default emitting - might not work on arrays defined in global scope on CUDA though + emitOperand(value, getInfo(EmitOp::General)); +} + +void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount) +{ + m_writer->emit("{\n"); + m_writer->indent(); + + for (Index i = 0; i < operandCount; ++i) + { + if (i != 0) m_writer->emit(", "); + _emitInitializerListValue(elementType, operands[i].get()); + } + + m_writer->dedent(); + m_writer->emit("\n}"); +} + bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { switch(inst->op) @@ -369,6 +501,23 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } break; } + case kIROp_makeArray: + { + IRType* dataType = inst->getDataType(); + IRArrayType* arrayType = as<IRArrayType>(dataType); + + IRType* elementType = arrayType->getElementType(); + + // Emit braces for the FixedArray struct. + m_writer->emit("{\n"); + m_writer->indent(); + + _emitInitializerList(elementType, inst->getOperands(), Index(inst->getOperandCount())); + + m_writer->dedent(); + m_writer->emit("\n}"); + return true; + } default: break; } |
