summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-01 08:46:57 -0700
committerGitHub <noreply@github.com>2022-11-01 08:46:57 -0700
commitcbc1eff56057f199183bb7c17d8a360326512367 (patch)
tree487865e928cd2ceecbb509f0bfd06aa8d9584411 /source/slang/slang-ir-diff-jvp.cpp
parentb707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff)
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp72
1 files changed, 5 insertions, 67 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 73818dbb1..3d02d4fc0 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -205,7 +205,7 @@ struct DifferentiableTypeConformanceContext
{
if (as<IRModuleInst>(inst) && differentiableInterfaceType)
{
- // Assume for now that IDifferentiable has exactly three fields.
+ // Assume for now that IDifferentiable has exactly four fields.
SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4);
if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
return as<IRStructKey>(entry->getRequirementKey());
@@ -462,45 +462,6 @@ struct DifferentialPairTypeBuilder
}
}
- void _createGenericDiffPairType(IRBuilder* builder)
- {
- // Insert directly at top level (skip any generic scopes etc.)
- auto insertLoc = builder->getInsertLoc();
- builder->setInsertInto(builder->getModule()->getModuleInst());
-
- // Make a generic version of the pair struct.
- auto irGeneric = builder->emitGeneric();
- irGeneric->setFullType(builder->getTypeKind());
- builder->setInsertInto(irGeneric);
-
- generatedTypeList.add(irGeneric);
-
- auto irBlock = builder->emitBlock();
- builder->setInsertInto(irBlock);
-
- auto pTypeParam = builder->emitParam(builder->getTypeType());
- builder->addNameHintDecoration(pTypeParam, UnownedTerminatedStringSlice("pT"));
-
- auto dTypeParam = builder->emitParam(builder->getTypeType());
- builder->addNameHintDecoration(dTypeParam, UnownedTerminatedStringSlice("dT"));
-
- auto irStructType = builder->createStructType();
- builder->emitReturn(irStructType);
-
- auto primalKey = _getOrCreatePrimalStructKey(builder);
- builder->addNameHintDecoration(primalKey, UnownedTerminatedStringSlice("primal"));
- builder->createStructField(irStructType, primalKey, (IRType*) pTypeParam);
-
- auto diffKey = _getOrCreateDiffStructKey(builder);
- builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential"));
- builder->createStructField(irStructType, diffKey, (IRType*) dTypeParam);
-
- // Reset cursor when done.
- builder->setInsertLoc(insertLoc);
-
- this->genericDiffPairType = irGeneric;
- }
-
IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder)
{
if (!this->globalDiffKey)
@@ -535,17 +496,6 @@ struct DifferentialPairTypeBuilder
return this->globalPrimalKey;
}
- IRInst* _getOrCreateGenericDiffPairType(IRBuilder* builder)
- {
- if (!this->genericDiffPairType)
- {
- _createGenericDiffPairType(builder);
- }
-
- SLANG_ASSERT(this->genericDiffPairType);
- return this->genericDiffPairType;
- }
-
IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
{
if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType))
@@ -1383,22 +1333,10 @@ struct JVPTranscriber
}
else
{
- // We special case a few non-differentiable types that sometimes appear in places
- // where we're forced to provide a differential zero value. For instance,
- // float3(float, float, int) is accepted by the compiler, but is tricky in the context
- // of differentiation since int is non-differentiable, and should be cast to float first.
- // In the absence of such casts, this piece of code generates appropriate zero values.
- //
- switch (primalType->getOp())
- {
- case kIROp_IntType:
- return builder->getIntValue(primalType, 0);
- default:
- getSink()->diagnose(primalType->sourceLoc,
- Diagnostics::internalCompilerError,
- "could not generate zero value for given type");
- return nullptr;
- }
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
}
}