diff options
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index df657476a..f3f32add2 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1362,9 +1362,10 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod( IRBuilder* builder, IRType* origType, IRStructKey* key, - IRType* resultType) + IRType* resultType, + DiffConformanceKind kind) { - if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any)) + if (auto conformance = tryGetDifferentiableWitness(builder, origType, kind)) return _lookupWitness(builder, conformance, key, resultType); return nullptr; } @@ -2097,8 +2098,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( IRWitnessTable* table = nullptr; if (target == DiffConformanceKind::Value) { - SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType)); - auto addMethod = builder->createFunc(); auto zeroMethod = builder->createFunc(); @@ -2138,6 +2137,8 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( &b, (IRType*)elementType, DiffConformanceKind::Value); + + SLANG_ASSERT(isDifferentiableValueType((IRType*)elementType)); IRInst* elementResult = nullptr; if (!innerWitness) { @@ -2171,9 +2172,9 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( { // Zero method. IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); + b.setInsertInto(zeroMethod); + b.addBackwardDifferentiableDecoration(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); b.emitBlock(); List<IRInst*> results; for (UInt i = 0; i < inTupleType->getOperandCount(); i++) @@ -2214,7 +2215,6 @@ IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( else if (target == DiffConformanceKind::Ptr) { SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType)); - table = builder->createWitnessTable( sharedContext->differentiablePtrInterfaceType, (IRType*)inTupleType); |
