From 73f9aeb838bfaeaeae2c46d94000a4f98da47cea Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 16 Jul 2025 08:57:54 -0700 Subject: Fix duplicate DiffPair struct generation for row_major matrices in autodiff (#7728) * Initial plan * Fix duplicate DiffPair struct generation for row_major matrices in autodiff Co-authored-by: csyonghe <2652293+csyonghe@users.noreply.github.com> * Fix matrix layout conversion to use BuiltinCastExpr Address root cause in slang-check-conversion.cpp by creating proper cast expressions for matrix layout conversions instead of reusing expressions. This ensures autodiff sees proper type conversions and generates consistent DiffPair structs. Reverted the band-aid fix in autodiff system and implemented the proper front-end fix as suggested in code review. Co-authored-by: csyonghe <2652293+csyonghe@users.noreply.github.com> * Fix test to prevent dead code elimination and make it executable on CPU Co-authored-by: csyonghe <2652293+csyonghe@users.noreply.github.com> * Fix spirv emit of matrix layout cast insts. * Update test. * cleanup test. * Improve test with meaningful values that verify correct gradient computation Co-authored-by: csyonghe <2652293+csyonghe@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: csyonghe <2652293+csyonghe@users.noreply.github.com> Co-authored-by: Yong He --- source/slang/slang-emit-spirv.cpp | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'source/slang/slang-emit-spirv.cpp') diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b9627e1ee..2b6f1c821 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -7374,6 +7374,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_ASSERT(isIntegralType(fromType)); SLANG_ASSERT(isIntegralType(toType)); + if (isTypeEqual(fromType, toType)) + { + auto inner = ensureInst(inst->getOperand(0)); + registerInst(inst, inner); + return inner; + } const auto fromInfo = getIntTypeInfo(fromType); const auto toInfo = getIntTypeInfo(toType); @@ -7491,7 +7497,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_ASSERT(isFloatingType(fromType)); SLANG_ASSERT(isFloatingType(toType)); - SLANG_ASSERT(!isTypeEqual(fromType, toType)); + if (isTypeEqual(fromType, toType)) + { + auto inner = ensureInst(inst->getOperand(0)); + registerInst(inst, inner); + return inner; + } if (isMatrixCast) { -- cgit v1.2.3