diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-16 12:17:49 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-16 12:17:49 -0800 |
| commit | 801aa3b44254341018a1acbe754f2ce3b0900e2a (patch) | |
| tree | b3066778522edb99bf64c0ac80c91b0b4cb788f8 /source/slang/slang-ir-diff-jvp.cpp | |
| parent | 09d8e048d2264d89886cda8e87e8a452d4f913c1 (diff) | |
Clean up type checking of higher order expressions. (#2519)
* Clean up type checking of higher order expressions.
* Replace `goto` with `break` to pacify clang.
* Fix.
* Fixes.
* Fix more tests.
* Fix lowerWitnessTable parameter error.
* Exclude attributes from ast printing.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 134 |
1 files changed, 69 insertions, 65 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 1597c80d1..152601dbd 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -456,8 +456,15 @@ struct DifferentialPairTypeBuilder IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType) { - SLANG_ASSERT(!as<IRParam>(origBaseType)); - SLANG_ASSERT(diffType); + switch (origBaseType->getOp()) + { + case kIROp_lookup_interface_method: + case kIROp_Specialize: + case kIROp_Param: + return nullptr; + default: + break; + } if (diffType->getOp() != kIROp_DifferentialBottomType) { IRBuilder builder(sharedContext->sharedBuilder); @@ -511,6 +518,8 @@ struct DifferentialPairTypeBuilder } auto diffType = getDiffTypeFromPairType(builder, pairType); + if (!diffType) + return result; result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType); pairTypeCache.Add(originalPairType, result); @@ -1431,10 +1440,10 @@ struct JVPTranscriber else { getSink()->diagnose(origSpecialize->sourceLoc, - Diagnostics::unexpected, - "should not be attempting to differentiate anything specialized here."); + Diagnostics::unexpected, + "should not be attempting to differentiate anything specialized here."); } - + return InstPair(nullptr, nullptr); } @@ -2740,7 +2749,16 @@ struct JVPDerivativeContext : public InstPassBase { case kIROp_ForwardDifferentiate: case kIROp_BackwardDifferentiate: - autoDiffWorkList.add(inst); + // Only process now if the operand is a materialized function. + switch (inst->getOperand(0)->getOp()) + { + case kIROp_Func: + case kIROp_Specialize: + autoDiffWorkList.add(inst); + break; + default: + break; + } break; default: break; @@ -2752,59 +2770,63 @@ struct JVPDerivativeContext : public InstPassBase // Process collected `ForwardDifferentiate` insts and replace them with placeholders for // differentiated functions. + transcriberStorage.followUpFunctionsToTranscribe.clear(); backwardTranscriberStorage.followUpFunctionsToTranscribe.clear(); for (auto differentiateInst : autoDiffWorkList) { IRInst* baseInst = differentiateInst->getOperand(0); - - if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst)) + if (as<IRForwardDifferentiate>(differentiateInst)) { - if (as<IRForwardDifferentiate>(differentiateInst)) + if (auto existingDiffFunc = lookupJVPReference(baseInst)) + { + differentiateInst->replaceUsesWith(existingDiffFunc); + differentiateInst->removeAndDeallocate(); + } + else if (isMarkedForForwardDifferentiation(baseInst)) { - if (auto existingDiffFunc = lookupJVPReference(baseFunction)) + if (as<IRFunc>(baseInst) || as<IRGeneric>(baseInst)) { - differentiateInst->replaceUsesWith(existingDiffFunc); + IRInst* diffFunc = transcriberStorage.transcribe(builder, baseInst); + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); differentiateInst->removeAndDeallocate(); } - else if (isMarkedForForwardDifferentiation(baseFunction)) + else { - if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) - { - IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction); - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. Expected func or generic"); - } + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); } } - else if (as<IRBackwardDifferentiate>(differentiateInst)) + else { - if (isMarkedForBackwardDifferentiation(baseFunction)) + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Requested differentiation on a function that isn't marked as differentiable."); + } + + } + else if (as<IRBackwardDifferentiate>(differentiateInst)) + { + if (isMarkedForBackwardDifferentiation(baseInst)) + { + if (as<IRFunc>(baseInst)) { - if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) - { - IRInst* diffFunc = - backwardTranscriberStorage - .transcribeFuncHeader(builder, (IRFunc*)baseFunction) - .differential; - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. Expected func or generic"); - } + IRInst* diffFunc = + backwardTranscriberStorage + .transcribeFuncHeader(builder, (IRFunc*)baseInst) + .differential; + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + } + else + { + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); } } } @@ -3118,18 +3140,9 @@ struct JVPDerivativeContext : public InstPassBase // Checks decorators to see if the function should // be differentiated (kIROp_ForwardDifferentiableDecoration) // - bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable) + bool isMarkedForForwardDifferentiation(IRInst* callable) { - for (auto decoration = callable->getFirstDecoration(); - decoration; - decoration = decoration->getNextDecoration()) - { - if (decoration->getOp() == kIROp_ForwardDifferentiableDecoration) - { - return true; - } - } - return false; + return callable->findDecoration<IRForwardDifferentiableDecoration>() != nullptr; } IRStringLit* getForwardDerivativeFuncName(IRInst* func) @@ -3153,18 +3166,9 @@ struct JVPDerivativeContext : public InstPassBase // Checks decorators to see if the function should // be differentiated (kIROp_ForwardDifferentiableDecoration) // - bool isMarkedForBackwardDifferentiation(IRGlobalValueWithCode* callable) + bool isMarkedForBackwardDifferentiation(IRInst* callable) { - for (auto decoration = callable->getFirstDecoration(); - decoration; - decoration = decoration->getNextDecoration()) - { - if (decoration->getOp() == kIROp_BackwardDifferentiableDecoration) - { - return true; - } - } - return false; + return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr; } IRStringLit* getBackwardDerivativeFuncName(IRInst* func) |
