From 1c2c4908c64396de2d1bee197c8f000ae2fed0fc Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 14 Dec 2022 09:37:55 -0800 Subject: Fix code generation for matrix reshape. (#2568) Co-authored-by: Yong He --- source/slang/slang-emit-cpp.cpp | 47 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) (limited to 'source/slang/slang-emit-cpp.cpp') diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index fad0c94f9..87b620ed2 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1178,6 +1178,7 @@ void CPPSourceEmitter::_emitInitDefinition(const UnownedStringSlice& funcName, c writer->emit("{ "); const Index paramCount = Index(funcType->getParamCount()); + bool handled = false; if (IRVectorType* vecType = as(retType)) { @@ -1211,7 +1212,7 @@ void CPPSourceEmitter::_emitInitDefinition(const UnownedStringSlice& funcName, c writer->emit("."); writer->emit(elemNames[paramSubIndex]); - paramSubIndex ++; + paramSubIndex++; if (paramSubIndex >= paramElementCount) { @@ -1226,9 +1227,51 @@ void CPPSourceEmitter::_emitInitDefinition(const UnownedStringSlice& funcName, c } } } + handled = true; } - else + else if (IRMatrixType* matType = as(retType)) + { + if (paramCount != 1) + goto fallback; + + auto paramMat = as(funcType->getParamType(0)); + if (!paramMat) + goto fallback; + + // We are constructing a matrix from a differently sized matrix. + + Index rows = Index(getIntVal(matType->getRowCount())); + Index cols = Index(getIntVal(matType->getColumnCount())); + Index paramRows = Index(getIntVal(paramMat->getRowCount())); + Index paramCols = Index(getIntVal(paramMat->getColumnCount())); + char elementNames[] = { 'x', 'y', 'z', 'w' }; + + for (Index r = 0; r < rows; r++) + { + for (Index c = 0; c < cols; c++) + { + if (r != 0 || c != 0) + writer->emit(", "); + + if (r < paramRows && c < paramCols && c < 4) + { + writer->emitRawText("a.rows["); + writer->emit(r); + writer->emitRawText("]."); + writer->emitChar(elementNames[c]); + } + else + { + writer->emit("0"); + } + } + } + handled = true; + } +fallback: + if (!handled) { + // Fallback default: just use all params to construct. for (Index i = 0; i < paramCount; ++i) { if (i > 0) -- cgit v1.2.3