diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 188 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 8 |
7 files changed, 199 insertions, 17 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index d99114e4f..b89eb85c4 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2314,6 +2314,7 @@ namespace Slang { resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); + return; } resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 58c8aae93..abe3f718c 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1198,6 +1198,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_VectorReshape: case kIROp_IntCast: case kIROp_FloatCast: + case kIROp_MakeVectorFromScalar: case kIROp_MakeStruct: case kIROp_MakeArray: case kIROp_MakeArrayFromElement: @@ -1212,7 +1213,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_swizzle: return transcribeSwizzle(builder, as<IRSwizzle>(origInst)); - case kIROp_MakeVectorFromScalar: case kIROp_MakeTuple: return transcribeByPassthrough(builder, origInst); case kIROp_UpdateElement: diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 6f18a3d8a..8f218293d 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -220,7 +220,6 @@ namespace Slang case kIROp_Specialize: return transcribeSpecialize(builder, as<IRSpecialize>(origInst)); - case kIROp_MakeVectorFromScalar: case kIROp_MakeTuple: case kIROp_FloatLit: case kIROp_IntLit: diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 901649f3c..f87aa7751 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1195,6 +1195,14 @@ struct DiffTransposePass case kIROp_MakeVector: return transposeMakeVector(builder, fwdInst, revValue); + case kIROp_MakeVectorFromScalar: + return transposeMakeVectorFromScalar(builder, fwdInst, revValue); + case kIROp_MakeMatrixFromScalar: + return transposeMakeMatrixFromScalar(builder, fwdInst, revValue); + case kIROp_MakeMatrix: + return transposeMakeMatrix(builder, fwdInst, revValue); + case kIROp_MatrixReshape: + return transposeMatrixReshape(builder, fwdInst, revValue); case kIROp_MakeStruct: return transposeMakeStruct(builder, fwdInst, revValue); case kIROp_MakeArray: @@ -1348,25 +1356,183 @@ struct DiffTransposePass fwdGetDiff))); } - TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue) + TranspositionResult transposeMakeVectorFromScalar(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue) { - // For now, we support only vector types. Extend this to other built-in types if necessary. - SLANG_ASSERT(fwdMakeVector->getOp() == kIROp_MakeVector); + auto vectorType = as<IRVectorType>(revValue->getDataType()); + SLANG_RELEASE_ASSERT(vectorType); + auto vectorSize = as<IRIntLit>(vectorType->getElementCount()); + SLANG_RELEASE_ASSERT(vectorSize); List<RevGradient> gradients; - for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++) + for (UIndex ii = 0; ii < (UIndex)vectorSize->getValue(); ii++) { - auto gradAtIndex = builder->emitElementExtract( - fwdMakeVector->getOperand(ii)->getDataType(), - revValue, - builder->getIntValue(builder->getIntType(), ii)); - + auto revComp = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), ii)); gradients.add(RevGradient( RevGradient::Flavor::Simple, - fwdMakeVector->getOperand(ii), - gradAtIndex, + fwdMakeVector->getOperand(0), + revComp, fwdMakeVector)); } + return TranspositionResult(gradients); + } + + TranspositionResult transposeMakeMatrixFromScalar(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue) + { + auto matrixType = as<IRMatrixType>(revValue->getDataType()); + SLANG_RELEASE_ASSERT(matrixType); + auto row = as<IRIntLit>(matrixType->getRowCount()); + auto col = as<IRIntLit>(matrixType->getColumnCount()); + SLANG_RELEASE_ASSERT(row && col); + + List<RevGradient> gradients; + for (UIndex r = 0; r < (UIndex)row->getValue(); r++) + { + for (UIndex c = 0; c < (UIndex)col->getValue(); c++) + { + auto revRow = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r)); + auto revCol = builder->emitElementExtract(revRow, builder->getIntValue(builder->getIntType(), c)); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeMatrix->getOperand(0), + revCol, + fwdMakeMatrix)); + } + } + return TranspositionResult(gradients); + } + + TranspositionResult transposeMakeMatrix(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue) + { + List<RevGradient> gradients; + auto matrixType = as<IRMatrixType>(fwdMakeMatrix->getDataType()); + auto row = as<IRIntLit>(matrixType->getRowCount()); + auto colCount = matrixType->getColumnCount(); + IRType* rowVectorType = nullptr; + for (UIndex ii = 0; ii < fwdMakeMatrix->getOperandCount(); ii++) + { + auto argOperand = fwdMakeMatrix->getOperand(ii); + IRInst* gradAtIndex = nullptr; + if (auto vecType = as<IRVectorType>(argOperand->getDataType())) + { + gradAtIndex = builder->emitElementExtract( + argOperand->getDataType(), + revValue, + builder->getIntValue(builder->getIntType(), ii)); + } + else + { + SLANG_RELEASE_ASSERT(row); + UInt rowIndex = ii / (UInt)row->getValue(); + UInt colIndex = ii % (UInt)row->getValue(); + if (!rowVectorType) + rowVectorType = builder->getVectorType(matrixType->getElementType(), colCount); + auto revRow = builder->emitElementExtract( + rowVectorType, + revValue, + builder->getIntValue(builder->getIntType(), rowIndex)); + gradAtIndex = builder->emitElementExtract( + matrixType->getElementType(), + revRow, + builder->getIntValue(builder->getIntType(), colIndex)); + } + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeMatrix->getOperand(ii), + gradAtIndex, + fwdMakeMatrix)); + } + return TranspositionResult(gradients); + } + + TranspositionResult transposeMatrixReshape(IRBuilder* builder, IRInst* fwdMatrixReshape, IRInst* revValue) + { + List<RevGradient> gradients; + auto operandMatrixType = as<IRMatrixType>(fwdMatrixReshape->getOperand(0)->getDataType()); + SLANG_RELEASE_ASSERT(operandMatrixType); + + auto operandRow = as<IRIntLit>(operandMatrixType->getRowCount()); + auto operandCol = as<IRIntLit>(operandMatrixType->getColumnCount()); + SLANG_RELEASE_ASSERT(operandRow && operandCol); + + auto revMatrixType = as<IRMatrixType>(revValue->getDataType()); + SLANG_RELEASE_ASSERT(revMatrixType); + auto revRow = as<IRIntLit>(revMatrixType->getRowCount()); + auto revCol = as<IRIntLit>(revMatrixType->getColumnCount()); + SLANG_RELEASE_ASSERT(revRow && revCol); + + IRInst* dzero = nullptr; + List<IRInst*> elements; + for (IRIntegerValue r = 0; r < operandRow->getValue(); r++) + { + IRInst* dstRow = nullptr; + if (r < revRow->getValue()) + dstRow = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r)); + for (IRIntegerValue c = 0; c < operandCol->getValue(); c++) + { + IRInst* element = nullptr; + if (r < revRow->getValue() && c < revCol->getValue()) + { + element = builder->emitElementExtract(dstRow, builder->getIntValue(builder->getIntType(), c)); + } + else + { + if (!dzero) + { + dzero = builder->getFloatValue(operandMatrixType->getElementType(), 0.0f); + } + element = dzero; + } + elements.add(element); + } + } + auto gradToProp = builder->emitMakeMatrix(operandMatrixType, (UInt)elements.getCount(), elements.getBuffer()); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMatrixReshape->getOperand(0), + gradToProp, + fwdMatrixReshape)); + return TranspositionResult(gradients); + } + + TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue) + { + List<RevGradient> gradients; + for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++) + { + auto argOperand = fwdMakeVector->getOperand(ii); + UInt componentCount = 1; + if (auto vecType = as<IRVectorType>(argOperand->getDataType())) + { + auto intConstant = as<IRIntLit>(vecType->getElementCount()); + SLANG_RELEASE_ASSERT(intConstant); + componentCount = (UInt)intConstant->getValue(); + } + IRInst* gradAtIndex = nullptr; + if (componentCount == 1) + { + gradAtIndex = builder->emitElementExtract( + argOperand->getDataType(), + revValue, + builder->getIntValue(builder->getIntType(), ii)); + } + else + { + ShortList<UInt> componentIndices; + for (UInt index = ii; index < ii + componentCount; index++) + componentIndices.add(index); + gradAtIndex = builder->emitSwizzle( + argOperand->getDataType(), + revValue, + componentCount, + componentIndices.getArrayView().getBuffer()); + } + + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeVector->getOperand(ii), + gradAtIndex, + fwdMakeVector)); + } // (A = float3(X, Y, Z)) -> [(dX += dA), (dY += dA), (dZ += dA)] return TranspositionResult(gradients); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5669a12d7..aca832c0c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2926,6 +2926,9 @@ public: IRType* type, UInt argCount, IRInst* const* args); + IRInst* emitMakeVectorFromScalar( + IRType* type, + IRInst* scalarValue); IRInst* emitMakeVector( IRType* type, @@ -2933,6 +2936,7 @@ public: { return emitMakeVector(type, args.getCount(), args.getBuffer()); } + IRInst* emitMatrixReshape(IRType* type, IRInst* inst); IRInst* emitMakeMatrix( IRType* type, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index e72ba8c9f..2a4ae59a7 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3723,6 +3723,18 @@ namespace Slang &defaultValue); } + IRInst* IRBuilder::emitMakeVectorFromScalar( + IRType* type, + IRInst* scalarValue) + { + return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &scalarValue); + } + + IRInst* IRBuilder::emitMatrixReshape(IRType* type, IRInst* inst) + { + return emitIntrinsicInst(type, kIROp_MatrixReshape, 1, &inst); + } + IRInst* IRBuilder::emitMakeVector( IRType* type, UInt argCount, diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index fd0810214..1bddfb9cf 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -5613,7 +5613,7 @@ namespace Slang if(unknownCount) { parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix); - suffixBaseType = BaseType::Void; + suffixBaseType = BaseType::Int; } // `u` or `ul` suffix -> `uint` else if(uCount == 1 && (lCount <= 1) && zCount == 0) @@ -5647,7 +5647,7 @@ namespace Slang else { parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix); - suffixBaseType = BaseType::Void; + suffixBaseType = BaseType::Int; } } @@ -5711,7 +5711,7 @@ namespace Slang if (unknownCount) { parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix); - suffixBaseType = BaseType::Void; + suffixBaseType = BaseType::Float; } // `f` suffix -> `float` if(fCount == 1 && !lCount && !hCount) @@ -5732,7 +5732,7 @@ namespace Slang else { parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix); - suffixBaseType = BaseType::Void; + suffixBaseType = BaseType::Float; } } |
