summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-16 12:17:49 -0800
committerGitHub <noreply@github.com>2022-11-16 12:17:49 -0800
commit801aa3b44254341018a1acbe754f2ce3b0900e2a (patch)
treeb3066778522edb99bf64c0ac80c91b0b4cb788f8 /source/slang/slang-ir-diff-jvp.cpp
parent09d8e048d2264d89886cda8e87e8a452d4f913c1 (diff)
Clean up type checking of higher order expressions. (#2519)
* Clean up type checking of higher order expressions. * Replace `goto` with `break` to pacify clang. * Fix. * Fixes. * Fix more tests. * Fix lowerWitnessTable parameter error. * Exclude attributes from ast printing. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp134
1 files changed, 69 insertions, 65 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 1597c80d1..152601dbd 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -456,8 +456,15 @@ struct DifferentialPairTypeBuilder
IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType)
{
- SLANG_ASSERT(!as<IRParam>(origBaseType));
- SLANG_ASSERT(diffType);
+ switch (origBaseType->getOp())
+ {
+ case kIROp_lookup_interface_method:
+ case kIROp_Specialize:
+ case kIROp_Param:
+ return nullptr;
+ default:
+ break;
+ }
if (diffType->getOp() != kIROp_DifferentialBottomType)
{
IRBuilder builder(sharedContext->sharedBuilder);
@@ -511,6 +518,8 @@ struct DifferentialPairTypeBuilder
}
auto diffType = getDiffTypeFromPairType(builder, pairType);
+ if (!diffType)
+ return result;
result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType);
pairTypeCache.Add(originalPairType, result);
@@ -1431,10 +1440,10 @@ struct JVPTranscriber
else
{
getSink()->diagnose(origSpecialize->sourceLoc,
- Diagnostics::unexpected,
- "should not be attempting to differentiate anything specialized here.");
+ Diagnostics::unexpected,
+ "should not be attempting to differentiate anything specialized here.");
}
-
+
return InstPair(nullptr, nullptr);
}
@@ -2740,7 +2749,16 @@ struct JVPDerivativeContext : public InstPassBase
{
case kIROp_ForwardDifferentiate:
case kIROp_BackwardDifferentiate:
- autoDiffWorkList.add(inst);
+ // Only process now if the operand is a materialized function.
+ switch (inst->getOperand(0)->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Specialize:
+ autoDiffWorkList.add(inst);
+ break;
+ default:
+ break;
+ }
break;
default:
break;
@@ -2752,59 +2770,63 @@ struct JVPDerivativeContext : public InstPassBase
// Process collected `ForwardDifferentiate` insts and replace them with placeholders for
// differentiated functions.
+
transcriberStorage.followUpFunctionsToTranscribe.clear();
backwardTranscriberStorage.followUpFunctionsToTranscribe.clear();
for (auto differentiateInst : autoDiffWorkList)
{
IRInst* baseInst = differentiateInst->getOperand(0);
-
- if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst))
+ if (as<IRForwardDifferentiate>(differentiateInst))
{
- if (as<IRForwardDifferentiate>(differentiateInst))
+ if (auto existingDiffFunc = lookupJVPReference(baseInst))
+ {
+ differentiateInst->replaceUsesWith(existingDiffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else if (isMarkedForForwardDifferentiation(baseInst))
{
- if (auto existingDiffFunc = lookupJVPReference(baseFunction))
+ if (as<IRFunc>(baseInst) || as<IRGeneric>(baseInst))
{
- differentiateInst->replaceUsesWith(existingDiffFunc);
+ IRInst* diffFunc = transcriberStorage.transcribe(builder, baseInst);
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
differentiateInst->removeAndDeallocate();
}
- else if (isMarkedForForwardDifferentiation(baseFunction))
+ else
{
- if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
- {
- IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction);
- SLANG_ASSERT(diffFunc);
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- }
- else
- {
- getSink()->diagnose(differentiateInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Unexpected instruction. Expected func or generic");
- }
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
}
}
- else if (as<IRBackwardDifferentiate>(differentiateInst))
+ else
{
- if (isMarkedForBackwardDifferentiation(baseFunction))
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Requested differentiation on a function that isn't marked as differentiable.");
+ }
+
+ }
+ else if (as<IRBackwardDifferentiate>(differentiateInst))
+ {
+ if (isMarkedForBackwardDifferentiation(baseInst))
+ {
+ if (as<IRFunc>(baseInst))
{
- if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
- {
- IRInst* diffFunc =
- backwardTranscriberStorage
- .transcribeFuncHeader(builder, (IRFunc*)baseFunction)
- .differential;
- SLANG_ASSERT(diffFunc);
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- }
- else
- {
- getSink()->diagnose(differentiateInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Unexpected instruction. Expected func or generic");
- }
+ IRInst* diffFunc =
+ backwardTranscriberStorage
+ .transcribeFuncHeader(builder, (IRFunc*)baseInst)
+ .differential;
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else
+ {
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
}
}
}
@@ -3118,18 +3140,9 @@ struct JVPDerivativeContext : public InstPassBase
// Checks decorators to see if the function should
// be differentiated (kIROp_ForwardDifferentiableDecoration)
//
- bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable)
+ bool isMarkedForForwardDifferentiation(IRInst* callable)
{
- for (auto decoration = callable->getFirstDecoration();
- decoration;
- decoration = decoration->getNextDecoration())
- {
- if (decoration->getOp() == kIROp_ForwardDifferentiableDecoration)
- {
- return true;
- }
- }
- return false;
+ return callable->findDecoration<IRForwardDifferentiableDecoration>() != nullptr;
}
IRStringLit* getForwardDerivativeFuncName(IRInst* func)
@@ -3153,18 +3166,9 @@ struct JVPDerivativeContext : public InstPassBase
// Checks decorators to see if the function should
// be differentiated (kIROp_ForwardDifferentiableDecoration)
//
- bool isMarkedForBackwardDifferentiation(IRGlobalValueWithCode* callable)
+ bool isMarkedForBackwardDifferentiation(IRInst* callable)
{
- for (auto decoration = callable->getFirstDecoration();
- decoration;
- decoration = decoration->getNextDecoration())
- {
- if (decoration->getOp() == kIROp_BackwardDifferentiableDecoration)
- {
- return true;
- }
- }
- return false;
+ return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr;
}
IRStringLit* getBackwardDerivativeFuncName(IRInst* func)