summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-30 20:03:46 -0800
committerGitHub <noreply@github.com>2023-01-30 20:03:46 -0800
commit77cdbb2101f4e27bf1800d4bc1077c0510668c25 (patch)
tree418ea6776aaa6a65364ba6de9ec3e6c63d1c4c5a
parent499b0253c224e68ceed6e5b6b1ee9cd7d65aad0f (diff)
Add transposition logic for constructor opcodes. (#2618)
* Add transposition logic for constructor opcodes. * Fix. * Add language server regression test. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-check-expr.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h188
-rw-r--r--source/slang/slang-ir-insts.h4
-rw-r--r--source/slang/slang-ir.cpp12
-rw-r--r--source/slang/slang-parser.cpp8
-rw-r--r--tests/autodiff/reverse-matrix-ops.slang92
-rw-r--r--tests/autodiff/reverse-matrix-ops.slang.expected.txt12
-rw-r--r--tests/language-server/invalid-const-suffix.slang12
-rw-r--r--tests/language-server/invalid-const-suffix.slang.expected.txt11
11 files changed, 326 insertions, 17 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index d99114e4f..b89eb85c4 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2314,6 +2314,7 @@ namespace Slang
{
resultDiffExpr->type = semantics->getASTBuilder()->getErrorType();
semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type);
+ return;
}
resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType);
if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr)))
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 58c8aae93..abe3f718c 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1198,6 +1198,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_VectorReshape:
case kIROp_IntCast:
case kIROp_FloatCast:
+ case kIROp_MakeVectorFromScalar:
case kIROp_MakeStruct:
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
@@ -1212,7 +1213,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_swizzle:
return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
- case kIROp_MakeVectorFromScalar:
case kIROp_MakeTuple:
return transcribeByPassthrough(builder, origInst);
case kIROp_UpdateElement:
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 6f18a3d8a..8f218293d 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -220,7 +220,6 @@ namespace Slang
case kIROp_Specialize:
return transcribeSpecialize(builder, as<IRSpecialize>(origInst));
- case kIROp_MakeVectorFromScalar:
case kIROp_MakeTuple:
case kIROp_FloatLit:
case kIROp_IntLit:
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 901649f3c..f87aa7751 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1195,6 +1195,14 @@ struct DiffTransposePass
case kIROp_MakeVector:
return transposeMakeVector(builder, fwdInst, revValue);
+ case kIROp_MakeVectorFromScalar:
+ return transposeMakeVectorFromScalar(builder, fwdInst, revValue);
+ case kIROp_MakeMatrixFromScalar:
+ return transposeMakeMatrixFromScalar(builder, fwdInst, revValue);
+ case kIROp_MakeMatrix:
+ return transposeMakeMatrix(builder, fwdInst, revValue);
+ case kIROp_MatrixReshape:
+ return transposeMatrixReshape(builder, fwdInst, revValue);
case kIROp_MakeStruct:
return transposeMakeStruct(builder, fwdInst, revValue);
case kIROp_MakeArray:
@@ -1348,25 +1356,183 @@ struct DiffTransposePass
fwdGetDiff)));
}
- TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
+ TranspositionResult transposeMakeVectorFromScalar(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
{
- // For now, we support only vector types. Extend this to other built-in types if necessary.
- SLANG_ASSERT(fwdMakeVector->getOp() == kIROp_MakeVector);
+ auto vectorType = as<IRVectorType>(revValue->getDataType());
+ SLANG_RELEASE_ASSERT(vectorType);
+ auto vectorSize = as<IRIntLit>(vectorType->getElementCount());
+ SLANG_RELEASE_ASSERT(vectorSize);
List<RevGradient> gradients;
- for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++)
+ for (UIndex ii = 0; ii < (UIndex)vectorSize->getValue(); ii++)
{
- auto gradAtIndex = builder->emitElementExtract(
- fwdMakeVector->getOperand(ii)->getDataType(),
- revValue,
- builder->getIntValue(builder->getIntType(), ii));
-
+ auto revComp = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), ii));
gradients.add(RevGradient(
RevGradient::Flavor::Simple,
- fwdMakeVector->getOperand(ii),
- gradAtIndex,
+ fwdMakeVector->getOperand(0),
+ revComp,
fwdMakeVector));
}
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMakeMatrixFromScalar(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue)
+ {
+ auto matrixType = as<IRMatrixType>(revValue->getDataType());
+ SLANG_RELEASE_ASSERT(matrixType);
+ auto row = as<IRIntLit>(matrixType->getRowCount());
+ auto col = as<IRIntLit>(matrixType->getColumnCount());
+ SLANG_RELEASE_ASSERT(row && col);
+
+ List<RevGradient> gradients;
+ for (UIndex r = 0; r < (UIndex)row->getValue(); r++)
+ {
+ for (UIndex c = 0; c < (UIndex)col->getValue(); c++)
+ {
+ auto revRow = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r));
+ auto revCol = builder->emitElementExtract(revRow, builder->getIntValue(builder->getIntType(), c));
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeMatrix->getOperand(0),
+ revCol,
+ fwdMakeMatrix));
+ }
+ }
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMakeMatrix(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ auto matrixType = as<IRMatrixType>(fwdMakeMatrix->getDataType());
+ auto row = as<IRIntLit>(matrixType->getRowCount());
+ auto colCount = matrixType->getColumnCount();
+ IRType* rowVectorType = nullptr;
+ for (UIndex ii = 0; ii < fwdMakeMatrix->getOperandCount(); ii++)
+ {
+ auto argOperand = fwdMakeMatrix->getOperand(ii);
+ IRInst* gradAtIndex = nullptr;
+ if (auto vecType = as<IRVectorType>(argOperand->getDataType()))
+ {
+ gradAtIndex = builder->emitElementExtract(
+ argOperand->getDataType(),
+ revValue,
+ builder->getIntValue(builder->getIntType(), ii));
+ }
+ else
+ {
+ SLANG_RELEASE_ASSERT(row);
+ UInt rowIndex = ii / (UInt)row->getValue();
+ UInt colIndex = ii % (UInt)row->getValue();
+ if (!rowVectorType)
+ rowVectorType = builder->getVectorType(matrixType->getElementType(), colCount);
+ auto revRow = builder->emitElementExtract(
+ rowVectorType,
+ revValue,
+ builder->getIntValue(builder->getIntType(), rowIndex));
+ gradAtIndex = builder->emitElementExtract(
+ matrixType->getElementType(),
+ revRow,
+ builder->getIntValue(builder->getIntType(), colIndex));
+ }
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeMatrix->getOperand(ii),
+ gradAtIndex,
+ fwdMakeMatrix));
+ }
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMatrixReshape(IRBuilder* builder, IRInst* fwdMatrixReshape, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ auto operandMatrixType = as<IRMatrixType>(fwdMatrixReshape->getOperand(0)->getDataType());
+ SLANG_RELEASE_ASSERT(operandMatrixType);
+
+ auto operandRow = as<IRIntLit>(operandMatrixType->getRowCount());
+ auto operandCol = as<IRIntLit>(operandMatrixType->getColumnCount());
+ SLANG_RELEASE_ASSERT(operandRow && operandCol);
+
+ auto revMatrixType = as<IRMatrixType>(revValue->getDataType());
+ SLANG_RELEASE_ASSERT(revMatrixType);
+ auto revRow = as<IRIntLit>(revMatrixType->getRowCount());
+ auto revCol = as<IRIntLit>(revMatrixType->getColumnCount());
+ SLANG_RELEASE_ASSERT(revRow && revCol);
+
+ IRInst* dzero = nullptr;
+ List<IRInst*> elements;
+ for (IRIntegerValue r = 0; r < operandRow->getValue(); r++)
+ {
+ IRInst* dstRow = nullptr;
+ if (r < revRow->getValue())
+ dstRow = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r));
+ for (IRIntegerValue c = 0; c < operandCol->getValue(); c++)
+ {
+ IRInst* element = nullptr;
+ if (r < revRow->getValue() && c < revCol->getValue())
+ {
+ element = builder->emitElementExtract(dstRow, builder->getIntValue(builder->getIntType(), c));
+ }
+ else
+ {
+ if (!dzero)
+ {
+ dzero = builder->getFloatValue(operandMatrixType->getElementType(), 0.0f);
+ }
+ element = dzero;
+ }
+ elements.add(element);
+ }
+ }
+ auto gradToProp = builder->emitMakeMatrix(operandMatrixType, (UInt)elements.getCount(), elements.getBuffer());
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMatrixReshape->getOperand(0),
+ gradToProp,
+ fwdMatrixReshape));
+ return TranspositionResult(gradients);
+ }
+
+ TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue)
+ {
+ List<RevGradient> gradients;
+ for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++)
+ {
+ auto argOperand = fwdMakeVector->getOperand(ii);
+ UInt componentCount = 1;
+ if (auto vecType = as<IRVectorType>(argOperand->getDataType()))
+ {
+ auto intConstant = as<IRIntLit>(vecType->getElementCount());
+ SLANG_RELEASE_ASSERT(intConstant);
+ componentCount = (UInt)intConstant->getValue();
+ }
+ IRInst* gradAtIndex = nullptr;
+ if (componentCount == 1)
+ {
+ gradAtIndex = builder->emitElementExtract(
+ argOperand->getDataType(),
+ revValue,
+ builder->getIntValue(builder->getIntType(), ii));
+ }
+ else
+ {
+ ShortList<UInt> componentIndices;
+ for (UInt index = ii; index < ii + componentCount; index++)
+ componentIndices.add(index);
+ gradAtIndex = builder->emitSwizzle(
+ argOperand->getDataType(),
+ revValue,
+ componentCount,
+ componentIndices.getArrayView().getBuffer());
+ }
+
+ gradients.add(RevGradient(
+ RevGradient::Flavor::Simple,
+ fwdMakeVector->getOperand(ii),
+ gradAtIndex,
+ fwdMakeVector));
+ }
// (A = float3(X, Y, Z)) -> [(dX += dA), (dY += dA), (dZ += dA)]
return TranspositionResult(gradients);
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 5669a12d7..aca832c0c 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2926,6 +2926,9 @@ public:
IRType* type,
UInt argCount,
IRInst* const* args);
+ IRInst* emitMakeVectorFromScalar(
+ IRType* type,
+ IRInst* scalarValue);
IRInst* emitMakeVector(
IRType* type,
@@ -2933,6 +2936,7 @@ public:
{
return emitMakeVector(type, args.getCount(), args.getBuffer());
}
+ IRInst* emitMatrixReshape(IRType* type, IRInst* inst);
IRInst* emitMakeMatrix(
IRType* type,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index e72ba8c9f..2a4ae59a7 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3723,6 +3723,18 @@ namespace Slang
&defaultValue);
}
+ IRInst* IRBuilder::emitMakeVectorFromScalar(
+ IRType* type,
+ IRInst* scalarValue)
+ {
+ return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &scalarValue);
+ }
+
+ IRInst* IRBuilder::emitMatrixReshape(IRType* type, IRInst* inst)
+ {
+ return emitIntrinsicInst(type, kIROp_MatrixReshape, 1, &inst);
+ }
+
IRInst* IRBuilder::emitMakeVector(
IRType* type,
UInt argCount,
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index fd0810214..1bddfb9cf 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -5613,7 +5613,7 @@ namespace Slang
if(unknownCount)
{
parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix);
- suffixBaseType = BaseType::Void;
+ suffixBaseType = BaseType::Int;
}
// `u` or `ul` suffix -> `uint`
else if(uCount == 1 && (lCount <= 1) && zCount == 0)
@@ -5647,7 +5647,7 @@ namespace Slang
else
{
parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix);
- suffixBaseType = BaseType::Void;
+ suffixBaseType = BaseType::Int;
}
}
@@ -5711,7 +5711,7 @@ namespace Slang
if (unknownCount)
{
parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix);
- suffixBaseType = BaseType::Void;
+ suffixBaseType = BaseType::Float;
}
// `f` suffix -> `float`
if(fCount == 1 && !lCount && !hCount)
@@ -5732,7 +5732,7 @@ namespace Slang
else
{
parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix);
- suffixBaseType = BaseType::Void;
+ suffixBaseType = BaseType::Float;
}
}
diff --git a/tests/autodiff/reverse-matrix-ops.slang b/tests/autodiff/reverse-matrix-ops.slang
new file mode 100644
index 000000000..e7be41811
--- /dev/null
+++ b/tests/autodiff/reverse-matrix-ops.slang
@@ -0,0 +1,92 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float3> dpfloat3;
+typedef float3.Differential dfloat3;
+
+typedef DifferentialPair<float2> dpfloat2;
+typedef float2.Differential dfloat2;
+
+[BackwardDifferentiable]
+float2 test_reshape(float3 x, float3 y, int i, int j)
+{
+ float2x3 m = float2x3(x, y);
+ let mSmall = float2x2(m);
+ return mSmall[i] + mSmall[j];
+}
+
+[BackwardDifferentiable]
+float3 test_vectorFromScalar(float x)
+{
+ return float3(x);
+}
+
+[BackwardDifferentiable]
+float3x3 test_matrixFromScalar(float x)
+{
+ return float3x3(x);
+}
+
+[BackwardDifferentiable]
+float2x2 test_matrixConstruct(float a, float b, float c, float d)
+{
+ return float2x2(a, b, c, d);
+}
+
+[BackwardDifferentiable]
+float3 test_makeVector(float x, float2 y)
+{
+ return float3(x, y);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat3 dpx = dpfloat3(float3(2.0, 3.0, 4.0), float3(0.0, 0.0, 0.0));
+ dpfloat3 dpy = dpfloat3(float3(1.5, 2.5, 3.5), float3(0.0, 0.0, 0.0));
+
+ __bwd_diff(test_reshape)(dpx, dpy, 0, 1, dfloat2(1.0, 2.0));
+ outputBuffer[0] = dpx.d.y; // Expect: 2
+ outputBuffer[1] = dpy.d.y; // Expect: 2
+ }
+
+ {
+ DifferentialPair<float> dpx = diffPair(1.0, 0.0);
+
+ __bwd_diff(test_vectorFromScalar)(dpx, dfloat3(2.0));
+ outputBuffer[2] = dpx.d; // Expect: 6.0
+ }
+
+ {
+ DifferentialPair<float> dpx = diffPair(1.0, 0.0);
+
+ __bwd_diff(test_matrixFromScalar)(dpx, float3x3(1.0));
+ outputBuffer[3] = dpx.d; // Expect: 9.0
+ }
+ {
+ DifferentialPair<float> dpa = diffPair(1.0, 0.0);
+ DifferentialPair<float> dpb = diffPair(1.0, 0.0);
+ DifferentialPair<float> dpc = diffPair(1.0, 0.0);
+ DifferentialPair<float> dpd = diffPair(1.0, 0.0);
+
+ __bwd_diff(test_matrixConstruct)(dpa, dpb, dpc, dpd, float2x2(1.0, 2.0, 3.0, 4.0));
+ outputBuffer[4] = dpa.d; // Expect: 1.0
+ outputBuffer[5] = dpb.d; // Expect: 2.0
+ outputBuffer[6] = dpc.d; // Expect: 3.0
+ outputBuffer[7] = dpd.d; // Expect: 4.0
+ }
+ {
+ DifferentialPair<float> dpx = diffPair(1.0, 0.0);
+ dpfloat2 dpy = dpfloat2(float2(1.5, 2.5), float2(0.0, 0.0));
+
+ __bwd_diff(test_makeVector)(dpx, dpy, float3(1.0, 1.5, 2.0));
+ outputBuffer[8] = dpx.d; // Expect: 1.0
+ outputBuffer[9] = dpy.d.x; // Expect: 1.5
+ outputBuffer[10] = dpy.d.y; // Expect: 2.0
+ }
+
+}
diff --git a/tests/autodiff/reverse-matrix-ops.slang.expected.txt b/tests/autodiff/reverse-matrix-ops.slang.expected.txt
new file mode 100644
index 000000000..990865983
--- /dev/null
+++ b/tests/autodiff/reverse-matrix-ops.slang.expected.txt
@@ -0,0 +1,12 @@
+type: float
+2.0
+2.0
+6.0
+9.0
+1.0
+2.0
+3.0
+4.0
+1.0
+1.5
+2.0 \ No newline at end of file
diff --git a/tests/language-server/invalid-const-suffix.slang b/tests/language-server/invalid-const-suffix.slang
new file mode 100644
index 000000000..e2365a3d1
--- /dev/null
+++ b/tests/language-server/invalid-const-suffix.slang
@@ -0,0 +1,12 @@
+//TEST:LANG_SERVER:
+
+struct MyType {};
+
+void m()
+{
+//HOVER:8,9
+ MyType b;
+ int t = 2tdf;
+ float c = 3.0ug;
+ reinterpret<MyType, MyType>(b);
+}
diff --git a/tests/language-server/invalid-const-suffix.slang.expected.txt b/tests/language-server/invalid-const-suffix.slang.expected.txt
new file mode 100644
index 000000000..ec920b2eb
--- /dev/null
+++ b/tests/language-server/invalid-const-suffix.slang.expected.txt
@@ -0,0 +1,11 @@
+--------
+range: 7,4 - 7,10
+content:
+```
+struct MyType
+```
+
+
+{REDACTED}.slang(3)
+
+