diff options
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 188 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 8 | ||||
| -rw-r--r-- | tests/autodiff/reverse-matrix-ops.slang | 92 | ||||
| -rw-r--r-- | tests/autodiff/reverse-matrix-ops.slang.expected.txt | 12 | ||||
| -rw-r--r-- | tests/language-server/invalid-const-suffix.slang | 12 | ||||
| -rw-r--r-- | tests/language-server/invalid-const-suffix.slang.expected.txt | 11 |
11 files changed, 326 insertions, 17 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index d99114e4f..b89eb85c4 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2314,6 +2314,7 @@ namespace Slang { resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); + return; } resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 58c8aae93..abe3f718c 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1198,6 +1198,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_VectorReshape: case kIROp_IntCast: case kIROp_FloatCast: + case kIROp_MakeVectorFromScalar: case kIROp_MakeStruct: case kIROp_MakeArray: case kIROp_MakeArrayFromElement: @@ -1212,7 +1213,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_swizzle: return transcribeSwizzle(builder, as<IRSwizzle>(origInst)); - case kIROp_MakeVectorFromScalar: case kIROp_MakeTuple: return transcribeByPassthrough(builder, origInst); case kIROp_UpdateElement: diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 6f18a3d8a..8f218293d 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -220,7 +220,6 @@ namespace Slang case kIROp_Specialize: return transcribeSpecialize(builder, as<IRSpecialize>(origInst)); - case kIROp_MakeVectorFromScalar: case kIROp_MakeTuple: case kIROp_FloatLit: case kIROp_IntLit: diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 901649f3c..f87aa7751 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1195,6 +1195,14 @@ struct DiffTransposePass case kIROp_MakeVector: return transposeMakeVector(builder, fwdInst, revValue); + case kIROp_MakeVectorFromScalar: + return transposeMakeVectorFromScalar(builder, fwdInst, revValue); + case kIROp_MakeMatrixFromScalar: + return transposeMakeMatrixFromScalar(builder, fwdInst, revValue); + case kIROp_MakeMatrix: + return transposeMakeMatrix(builder, fwdInst, revValue); + case kIROp_MatrixReshape: + return transposeMatrixReshape(builder, fwdInst, revValue); case kIROp_MakeStruct: return transposeMakeStruct(builder, fwdInst, revValue); case kIROp_MakeArray: @@ -1348,25 +1356,183 @@ struct DiffTransposePass fwdGetDiff))); } - TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue) + TranspositionResult transposeMakeVectorFromScalar(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue) { - // For now, we support only vector types. Extend this to other built-in types if necessary. - SLANG_ASSERT(fwdMakeVector->getOp() == kIROp_MakeVector); + auto vectorType = as<IRVectorType>(revValue->getDataType()); + SLANG_RELEASE_ASSERT(vectorType); + auto vectorSize = as<IRIntLit>(vectorType->getElementCount()); + SLANG_RELEASE_ASSERT(vectorSize); List<RevGradient> gradients; - for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++) + for (UIndex ii = 0; ii < (UIndex)vectorSize->getValue(); ii++) { - auto gradAtIndex = builder->emitElementExtract( - fwdMakeVector->getOperand(ii)->getDataType(), - revValue, - builder->getIntValue(builder->getIntType(), ii)); - + auto revComp = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), ii)); gradients.add(RevGradient( RevGradient::Flavor::Simple, - fwdMakeVector->getOperand(ii), - gradAtIndex, + fwdMakeVector->getOperand(0), + revComp, fwdMakeVector)); } + return TranspositionResult(gradients); + } + + TranspositionResult transposeMakeMatrixFromScalar(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue) + { + auto matrixType = as<IRMatrixType>(revValue->getDataType()); + SLANG_RELEASE_ASSERT(matrixType); + auto row = as<IRIntLit>(matrixType->getRowCount()); + auto col = as<IRIntLit>(matrixType->getColumnCount()); + SLANG_RELEASE_ASSERT(row && col); + + List<RevGradient> gradients; + for (UIndex r = 0; r < (UIndex)row->getValue(); r++) + { + for (UIndex c = 0; c < (UIndex)col->getValue(); c++) + { + auto revRow = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r)); + auto revCol = builder->emitElementExtract(revRow, builder->getIntValue(builder->getIntType(), c)); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeMatrix->getOperand(0), + revCol, + fwdMakeMatrix)); + } + } + return TranspositionResult(gradients); + } + + TranspositionResult transposeMakeMatrix(IRBuilder* builder, IRInst* fwdMakeMatrix, IRInst* revValue) + { + List<RevGradient> gradients; + auto matrixType = as<IRMatrixType>(fwdMakeMatrix->getDataType()); + auto row = as<IRIntLit>(matrixType->getRowCount()); + auto colCount = matrixType->getColumnCount(); + IRType* rowVectorType = nullptr; + for (UIndex ii = 0; ii < fwdMakeMatrix->getOperandCount(); ii++) + { + auto argOperand = fwdMakeMatrix->getOperand(ii); + IRInst* gradAtIndex = nullptr; + if (auto vecType = as<IRVectorType>(argOperand->getDataType())) + { + gradAtIndex = builder->emitElementExtract( + argOperand->getDataType(), + revValue, + builder->getIntValue(builder->getIntType(), ii)); + } + else + { + SLANG_RELEASE_ASSERT(row); + UInt rowIndex = ii / (UInt)row->getValue(); + UInt colIndex = ii % (UInt)row->getValue(); + if (!rowVectorType) + rowVectorType = builder->getVectorType(matrixType->getElementType(), colCount); + auto revRow = builder->emitElementExtract( + rowVectorType, + revValue, + builder->getIntValue(builder->getIntType(), rowIndex)); + gradAtIndex = builder->emitElementExtract( + matrixType->getElementType(), + revRow, + builder->getIntValue(builder->getIntType(), colIndex)); + } + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeMatrix->getOperand(ii), + gradAtIndex, + fwdMakeMatrix)); + } + return TranspositionResult(gradients); + } + + TranspositionResult transposeMatrixReshape(IRBuilder* builder, IRInst* fwdMatrixReshape, IRInst* revValue) + { + List<RevGradient> gradients; + auto operandMatrixType = as<IRMatrixType>(fwdMatrixReshape->getOperand(0)->getDataType()); + SLANG_RELEASE_ASSERT(operandMatrixType); + + auto operandRow = as<IRIntLit>(operandMatrixType->getRowCount()); + auto operandCol = as<IRIntLit>(operandMatrixType->getColumnCount()); + SLANG_RELEASE_ASSERT(operandRow && operandCol); + + auto revMatrixType = as<IRMatrixType>(revValue->getDataType()); + SLANG_RELEASE_ASSERT(revMatrixType); + auto revRow = as<IRIntLit>(revMatrixType->getRowCount()); + auto revCol = as<IRIntLit>(revMatrixType->getColumnCount()); + SLANG_RELEASE_ASSERT(revRow && revCol); + + IRInst* dzero = nullptr; + List<IRInst*> elements; + for (IRIntegerValue r = 0; r < operandRow->getValue(); r++) + { + IRInst* dstRow = nullptr; + if (r < revRow->getValue()) + dstRow = builder->emitElementExtract(revValue, builder->getIntValue(builder->getIntType(), r)); + for (IRIntegerValue c = 0; c < operandCol->getValue(); c++) + { + IRInst* element = nullptr; + if (r < revRow->getValue() && c < revCol->getValue()) + { + element = builder->emitElementExtract(dstRow, builder->getIntValue(builder->getIntType(), c)); + } + else + { + if (!dzero) + { + dzero = builder->getFloatValue(operandMatrixType->getElementType(), 0.0f); + } + element = dzero; + } + elements.add(element); + } + } + auto gradToProp = builder->emitMakeMatrix(operandMatrixType, (UInt)elements.getCount(), elements.getBuffer()); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMatrixReshape->getOperand(0), + gradToProp, + fwdMatrixReshape)); + return TranspositionResult(gradients); + } + + TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue) + { + List<RevGradient> gradients; + for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++) + { + auto argOperand = fwdMakeVector->getOperand(ii); + UInt componentCount = 1; + if (auto vecType = as<IRVectorType>(argOperand->getDataType())) + { + auto intConstant = as<IRIntLit>(vecType->getElementCount()); + SLANG_RELEASE_ASSERT(intConstant); + componentCount = (UInt)intConstant->getValue(); + } + IRInst* gradAtIndex = nullptr; + if (componentCount == 1) + { + gradAtIndex = builder->emitElementExtract( + argOperand->getDataType(), + revValue, + builder->getIntValue(builder->getIntType(), ii)); + } + else + { + ShortList<UInt> componentIndices; + for (UInt index = ii; index < ii + componentCount; index++) + componentIndices.add(index); + gradAtIndex = builder->emitSwizzle( + argOperand->getDataType(), + revValue, + componentCount, + componentIndices.getArrayView().getBuffer()); + } + + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeVector->getOperand(ii), + gradAtIndex, + fwdMakeVector)); + } // (A = float3(X, Y, Z)) -> [(dX += dA), (dY += dA), (dZ += dA)] return TranspositionResult(gradients); diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5669a12d7..aca832c0c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2926,6 +2926,9 @@ public: IRType* type, UInt argCount, IRInst* const* args); + IRInst* emitMakeVectorFromScalar( + IRType* type, + IRInst* scalarValue); IRInst* emitMakeVector( IRType* type, @@ -2933,6 +2936,7 @@ public: { return emitMakeVector(type, args.getCount(), args.getBuffer()); } + IRInst* emitMatrixReshape(IRType* type, IRInst* inst); IRInst* emitMakeMatrix( IRType* type, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index e72ba8c9f..2a4ae59a7 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3723,6 +3723,18 @@ namespace Slang &defaultValue); } + IRInst* IRBuilder::emitMakeVectorFromScalar( + IRType* type, + IRInst* scalarValue) + { + return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &scalarValue); + } + + IRInst* IRBuilder::emitMatrixReshape(IRType* type, IRInst* inst) + { + return emitIntrinsicInst(type, kIROp_MatrixReshape, 1, &inst); + } + IRInst* IRBuilder::emitMakeVector( IRType* type, UInt argCount, diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index fd0810214..1bddfb9cf 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -5613,7 +5613,7 @@ namespace Slang if(unknownCount) { parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix); - suffixBaseType = BaseType::Void; + suffixBaseType = BaseType::Int; } // `u` or `ul` suffix -> `uint` else if(uCount == 1 && (lCount <= 1) && zCount == 0) @@ -5647,7 +5647,7 @@ namespace Slang else { parser->sink->diagnose(token, Diagnostics::invalidIntegerLiteralSuffix, suffix); - suffixBaseType = BaseType::Void; + suffixBaseType = BaseType::Int; } } @@ -5711,7 +5711,7 @@ namespace Slang if (unknownCount) { parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix); - suffixBaseType = BaseType::Void; + suffixBaseType = BaseType::Float; } // `f` suffix -> `float` if(fCount == 1 && !lCount && !hCount) @@ -5732,7 +5732,7 @@ namespace Slang else { parser->sink->diagnose(token, Diagnostics::invalidFloatingPointLiteralSuffix, suffix); - suffixBaseType = BaseType::Void; + suffixBaseType = BaseType::Float; } } diff --git a/tests/autodiff/reverse-matrix-ops.slang b/tests/autodiff/reverse-matrix-ops.slang new file mode 100644 index 000000000..e7be41811 --- /dev/null +++ b/tests/autodiff/reverse-matrix-ops.slang @@ -0,0 +1,92 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float3> dpfloat3; +typedef float3.Differential dfloat3; + +typedef DifferentialPair<float2> dpfloat2; +typedef float2.Differential dfloat2; + +[BackwardDifferentiable] +float2 test_reshape(float3 x, float3 y, int i, int j) +{ + float2x3 m = float2x3(x, y); + let mSmall = float2x2(m); + return mSmall[i] + mSmall[j]; +} + +[BackwardDifferentiable] +float3 test_vectorFromScalar(float x) +{ + return float3(x); +} + +[BackwardDifferentiable] +float3x3 test_matrixFromScalar(float x) +{ + return float3x3(x); +} + +[BackwardDifferentiable] +float2x2 test_matrixConstruct(float a, float b, float c, float d) +{ + return float2x2(a, b, c, d); +} + +[BackwardDifferentiable] +float3 test_makeVector(float x, float2 y) +{ + return float3(x, y); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat3 dpx = dpfloat3(float3(2.0, 3.0, 4.0), float3(0.0, 0.0, 0.0)); + dpfloat3 dpy = dpfloat3(float3(1.5, 2.5, 3.5), float3(0.0, 0.0, 0.0)); + + __bwd_diff(test_reshape)(dpx, dpy, 0, 1, dfloat2(1.0, 2.0)); + outputBuffer[0] = dpx.d.y; // Expect: 2 + outputBuffer[1] = dpy.d.y; // Expect: 2 + } + + { + DifferentialPair<float> dpx = diffPair(1.0, 0.0); + + __bwd_diff(test_vectorFromScalar)(dpx, dfloat3(2.0)); + outputBuffer[2] = dpx.d; // Expect: 6.0 + } + + { + DifferentialPair<float> dpx = diffPair(1.0, 0.0); + + __bwd_diff(test_matrixFromScalar)(dpx, float3x3(1.0)); + outputBuffer[3] = dpx.d; // Expect: 9.0 + } + { + DifferentialPair<float> dpa = diffPair(1.0, 0.0); + DifferentialPair<float> dpb = diffPair(1.0, 0.0); + DifferentialPair<float> dpc = diffPair(1.0, 0.0); + DifferentialPair<float> dpd = diffPair(1.0, 0.0); + + __bwd_diff(test_matrixConstruct)(dpa, dpb, dpc, dpd, float2x2(1.0, 2.0, 3.0, 4.0)); + outputBuffer[4] = dpa.d; // Expect: 1.0 + outputBuffer[5] = dpb.d; // Expect: 2.0 + outputBuffer[6] = dpc.d; // Expect: 3.0 + outputBuffer[7] = dpd.d; // Expect: 4.0 + } + { + DifferentialPair<float> dpx = diffPair(1.0, 0.0); + dpfloat2 dpy = dpfloat2(float2(1.5, 2.5), float2(0.0, 0.0)); + + __bwd_diff(test_makeVector)(dpx, dpy, float3(1.0, 1.5, 2.0)); + outputBuffer[8] = dpx.d; // Expect: 1.0 + outputBuffer[9] = dpy.d.x; // Expect: 1.5 + outputBuffer[10] = dpy.d.y; // Expect: 2.0 + } + +} diff --git a/tests/autodiff/reverse-matrix-ops.slang.expected.txt b/tests/autodiff/reverse-matrix-ops.slang.expected.txt new file mode 100644 index 000000000..990865983 --- /dev/null +++ b/tests/autodiff/reverse-matrix-ops.slang.expected.txt @@ -0,0 +1,12 @@ +type: float +2.0 +2.0 +6.0 +9.0 +1.0 +2.0 +3.0 +4.0 +1.0 +1.5 +2.0
\ No newline at end of file diff --git a/tests/language-server/invalid-const-suffix.slang b/tests/language-server/invalid-const-suffix.slang new file mode 100644 index 000000000..e2365a3d1 --- /dev/null +++ b/tests/language-server/invalid-const-suffix.slang @@ -0,0 +1,12 @@ +//TEST:LANG_SERVER: + +struct MyType {}; + +void m() +{ +//HOVER:8,9 + MyType b; + int t = 2tdf; + float c = 3.0ug; + reinterpret<MyType, MyType>(b); +} diff --git a/tests/language-server/invalid-const-suffix.slang.expected.txt b/tests/language-server/invalid-const-suffix.slang.expected.txt new file mode 100644 index 000000000..ec920b2eb --- /dev/null +++ b/tests/language-server/invalid-const-suffix.slang.expected.txt @@ -0,0 +1,11 @@ +-------- +range: 7,4 - 7,10 +content: +``` +struct MyType +``` + + +{REDACTED}.slang(3) + + |
