summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-30 20:03:46 -0800
committerGitHub <noreply@github.com>2023-01-30 20:03:46 -0800
commit77cdbb2101f4e27bf1800d4bc1077c0510668c25 (patch)
tree418ea6776aaa6a65364ba6de9ec3e6c63d1c4c5a /source
parent499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f (diff)
Add transposition logic for constructor opcodes. (#2618)
* Add transposition logic for constructor opcodes. * Fix. * Add language server regression test. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-expr.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h188
-rw-r--r--source/slang/slang-ir-insts.h4
-rw-r--r--source/slang/slang-ir.cpp12
-rw-r--r--source/slang/slang-parser.cpp8
7 files changed, 199 insertions, 17 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index d99114e4f..b89eb85c4 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2314,6 +2314,7 @@ namespace Slang
{
resultDiffExpr->type = semantics->getASTBuilder()->getErrorType();
semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type);
+ return;
}
resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType);
if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 58c8aae93..abe3f718c 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1198,6 +1198,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_VectorReshape:
case kIROp_IntCast:
case kIROp_FloatCast:
+ case kIROp_MakeVectorFromScalar:
case kIROp_MakeStruct:
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
@@ -1212,7 +1213,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_swizzle:
return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
- case kIROp_MakeVectorFromScalar:
case kIROp_MakeTuple:
return transcribeByPassthrough(builder, origInst);
case kIROp_UpdateElement:
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 6f18a3d8a..8f218293d 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -220,7 +220,6 @@ namespace Slang
case kIROp_Specialize:
return transcribeSpecialize(builder, as<IRSpecialize>(origInst));
- case kIROp_MakeVectorFromScalar:
case kIROp_MakeTuple:
case kIROp_FloatLit:
case kIROp_IntLit:
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 901649f3c..f87aa7751 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1195,6 +1195,14 @@ struct DiffTransposePass
case kIROp_MakeVector:
return transposeMakeVector(builder, fwdInst, revValue);
+ case kIROp_MakeVectorFromScalar:
+ return transposeMakeVectorFromScalar(builder, fwdInst, revValue);
+ case kIROp_MakeMatrixFromScalar:
+ return transposeMakeMatrixFromScalar(builder, fwdInst, revValue);
+ case kIROp_MakeMatrix:
+ return transposeMakeMatrix(builder, fwdInst, revValue);
+ case kIROp_MatrixReshape:
+ return transposeMatrixReshape(builder, fwdInst, revValue);
case kIROp_MakeStruct:
return transposeMakeStruct(builder, fwdInst, revValue);
case kIROp_MakeArray:
@@ -1348,25 +1356,183 @@ struct DiffTransposePass
fwdGetDiff)));
}
- TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
+ TranspositionResult transposeMakeVectorFromScalar(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
{
- // For now, we support only vector types. Extend this to other built-in types if necessary.
- SLANG_ASSERT(fwdMakeVector->getOp() == kIROp_MakeVector);
+ auto vectorType = as<IRVectorType>(revValue->getDataType());
+ SLANG_RELEASE_ASSERT(vectorType);
+ auto vectorSize = as<IRIntLit>(vectorType->getElementCount());
+ SLANG_RELEASE_ASSERT(vectorSize);
List<RevGradient> gradients;
- for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++)
+ for (UIndex ii = 0; ii < (UIndex)vectorSize->getValue(); ii++)
{
- auto gradAtIndex = builder->emitElementExtract(
- fwdMakeVector->getOperand(ii)->getDataType(),
- revValue,
- builder->getIntValue(builder->getIntType(), ii));
-
+ auto revComp = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), ii));
gradients.add(RevGradient(
RevGradient::Flavor::Simple,
- fwdMakeVector->getOperand(ii),
- gradAtIndex,
+ fwdMakeVector->getOperand(0),
+ revComp,
fwdMakeVector));
}
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMakeMatrixFromScalar(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue)
+ {
+ auto matrixType = as<IRMatrixType>(revValue->getDataType());
+ SLANG_RELEASE_ASSERT(matrixType);
+ auto row = as<IRIntLit>(matrixType->getRowCount());
+ auto col = as<IRIntLit>(matrixType->getColumnCount());
+ SLANG_RELEASE_ASSERT(row && col);
+
+ List<RevGradient> gradients;
+ for (UIndex r = 0; r < (UIndex)row->getValue(); r++)
+ {
+ for (UIndex c = 0; c < (UIndex)col->getValue(); c++)
+ {
+ auto revRow = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r));
+ auto revCol = builder->emitElementExtract(revRow, builder->getIntValue(builder->getIntType(), c));
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeMatrix->getOperand(0),
+ revCol,
+ fwdMakeMatrix));
+ }
+ }
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMakeMatrix(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ auto matrixType = as<IRMatrixType>(fwdMakeMatrix->getDataType());
+ auto row = as<IRIntLit>(matrixType->getRowCount());
+ auto colCount = matrixType->getColumnCount();
+ IRType* rowVectorType = nullptr;
+ for (UIndex ii = 0; ii < fwdMakeMatrix->getOperandCount(); ii++)
+ {
+ auto argOperand = fwdMakeMatrix->getOperand(ii);
+ IRInst* gradAtIndex = nullptr;
+ if (auto vecType = as<IRVectorType>(argOperand->getDataType()))
+ {
+ gradAtIndex = builder->emitElementExtract(
+ argOperand->getDataType(),
+ revValue,
+ builder->getIntValue(builder->getIntType(), ii));
+ }
+ else
+ {
+ SLANG_RELEASE_ASSERT(row);
+ UInt rowIndex = ii / (UInt)row->getValue();
+ UInt colIndex = ii % (UInt)row->getValue();
+ if (!rowVectorType)
+ rowVectorType = builder->getVectorType(matrixType->getElementType(), colCount);
+ auto revRow = builder->emitElementExtract(
+ rowVectorType,
+ revValue,
+ builder->getIntValue(builder->getIntType(), rowIndex));
+ gradAtIndex = builder->emitElementExtract(
+ matrixType->getElementType(),
+ revRow,
+ builder->getIntValue(builder->getIntType(), colIndex));
+ }
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeMatrix->getOperand(ii),
+ gradAtIndex,
+ fwdMakeMatrix));
+ }
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMatrixReshape(IRBuilder* builder, IRInst* fwdMatrixReshape, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ auto operandMatrixType = as<IRMatrixType>(fwdMatrixReshape->getOperand(0)->getDataType());
+ SLANG_RELEASE_ASSERT(operandMatrixType);
+
+ auto operandRow = as<IRIntLit>(operandMatrixType->getRowCount());
+ auto operandCol = as<IRIntLit>(operandMatrixType->getColumnCount());
+ SLANG_RELEASE_ASSERT(operandRow && operandCol);
+
+ auto revMatrixType = as<IRMatrixType>(revValue->getDataType());
+ SLANG_RELEASE_ASSERT(revMatrixType);
+ auto revRow = as<IRIntLit>(revMatrixType->getRowCount());
+ auto revCol = as<IRIntLit>(revMatrixType->getColumnCount());
+ SLANG_RELEASE_ASSERT(revRow && revCol);
+
+ IRInst* dzero = nullptr;
+ List<IRInst*> elements;
+ for (IRIntegerValue r = 0; r < operandRow->getValue(); r++)
+ {
+ IRInst* dstRow = nullptr;
+ if (r < revRow->getValue())
+ dstRow = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r));
+ for (IRIntegerValue c = 0; c < operandCol->getValue(); c++)
+ {
+ IRInst* element = nullptr;
+ if (r < revRow->getValue() && c < revCol->getValue())
+ {
+ element = builder->emitElementExtract(dstRow, builder->getIntValue(builder->getIntType(), c));
+ }
+ else
+ {
+ if (!dzero)
+ {
+ dzero = builder->getFloatValue(operandMatrixType->getElementType(), 0.0f);
+ }
+ element = dzero;
+ }
+ elements.add(element);
+ }
+ }
+ auto gradToProp = builder->emitMakeMatrix(operandMatrixType, (UInt)elements.getCount(), elements.getBuffer());
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMatrixReshape->getOperand(0),
+ gradToProp,
+ fwdMatrixReshape));
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++)
+ {
+ auto argOperand = fwdMakeVector->getOperand(ii);
+ UInt componentCount = 1;
+ if (auto vecType = as<IRVectorType>(argOperand->getDataType()))
+ {
+ auto intConstant = as<IRIntLit>(vecType->getElementCount());
+ SLANG_RELEASE_ASSERT(intConstant);
+ componentCount = (UInt)intConstant->getValue();
+ }
+ IRInst* gradAtIndex = nullptr;
+ if (componentCount == 1)
+ {
+ gradAtIndex = builder->emitElementExtract(
+ argOperand->getDataType(),
+ revValue,
+ builder->getIntValue(builder->getIntType(), ii));
+ }
+ else
+ {
+ ShortList<UInt> componentIndices;
+ for (UInt index = ii; index < ii + componentCount; index++)
+ componentIndices.add(index);
+ gradAtIndex = builder->emitSwizzle(
+ argOperand->getDataType(),
+ revValue,
+ componentCount,
+ componentIndices.getArrayView().getBuffer());
+ }
+
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeVector->getOperand(ii),
+ gradAtIndex,
+ fwdMakeVector));
+ }
// (A = float3(X, Y, Z)) -> [(dX += dA), (dY += dA), (dZ += dA)]
return TranspositionResult(gradients);
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 5669a12d7..aca832c0c 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2926,6 +2926,9 @@ public:
IRType* type,
UInt argCount,
IRInst* const* args);
+ IRInst* emitMakeVectorFromScalar(
+ IRType* type,
+ IRInst* scalarValue);
IRInst* emitMakeVector(
IRType* type,
@@ -2933,6 +2936,7 @@ public:
{
return emitMakeVector(type, args.getCount(), args.getBuffer());
}
+ IRInst* emitMatrixReshape(IRType* type, IRInst* inst);
IRInst* emitMakeMatrix(
IRType* type,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index e72ba8c9f..2a4ae59a7 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3723,6 +3723,18 @@ namespace Slang
&defaultValue);
}
+ IRInst* IRBuilder::emitMakeVectorFromScalar(
+ IRType* type,
+ IRInst* scalarValue)
+ {
+ return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &scalarValue);
+ }
+
+ IRInst* IRBuilder::emitMatrixReshape(IRType* type, IRInst* inst)
+ {
+ return emitIntrinsicInst(type, kIROp_MatrixReshape, 1, &inst);
+ }
+
IRInst* IRBuilder::emitMakeVector(
IRType* type,
UInt argCount,
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index fd0810214..1bddfb9cf 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -5613,7 +5613,7 @@ namespace Slang
if(unknownCount)
{
parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix);
- suffixBaseType = BaseType::Void;
+ suffixBaseType = BaseType::Int;
}
// `u` or `ul` suffix -> `uint`
else if(uCount == 1 && (lCount <= 1) && zCount == 0)
@@ -5647,7 +5647,7 @@ namespace Slang
else
{
parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix);
- suffixBaseType = BaseType::Void;
+ suffixBaseType = BaseType::Int;
}
}
@@ -5711,7 +5711,7 @@ namespace Slang
if (unknownCount)
{
parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix);
- suffixBaseType = BaseType::Void;
+ suffixBaseType = BaseType::Float;
}
// `f` suffix -> `float`
if(fCount == 1 && !lCount && !hCount)
@@ -5732,7 +5732,7 @@ namespace Slang
else
{
parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix);
- suffixBaseType = BaseType::Void;
+ suffixBaseType = BaseType::Float;
}
}