diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-01 08:46:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-01 08:46:57 -0700 |
| commit | cbc1eff56057f199183bb7c17d8a360326512367 (patch) | |
| tree | 487865e928cd2ceecbb509f0bfd06aa8d9584411 /source/slang/slang-ir-diff-jvp.cpp | |
| parent | b707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff) | |
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 72 |
1 files changed, 5 insertions, 67 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 73818dbb1..3d02d4fc0 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -205,7 +205,7 @@ struct DifferentiableTypeConformanceContext { if (as<IRModuleInst>(inst) && differentiableInterfaceType) { - // Assume for now that IDifferentiable has exactly three fields. + // Assume for now that IDifferentiable has exactly four fields. SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4); if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) return as<IRStructKey>(entry->getRequirementKey()); @@ -462,45 +462,6 @@ struct DifferentialPairTypeBuilder } } - void _createGenericDiffPairType(IRBuilder* builder) - { - // Insert directly at top level (skip any generic scopes etc.) - auto insertLoc = builder->getInsertLoc(); - builder->setInsertInto(builder->getModule()->getModuleInst()); - - // Make a generic version of the pair struct. - auto irGeneric = builder->emitGeneric(); - irGeneric->setFullType(builder->getTypeKind()); - builder->setInsertInto(irGeneric); - - generatedTypeList.add(irGeneric); - - auto irBlock = builder->emitBlock(); - builder->setInsertInto(irBlock); - - auto pTypeParam = builder->emitParam(builder->getTypeType()); - builder->addNameHintDecoration(pTypeParam, UnownedTerminatedStringSlice("pT")); - - auto dTypeParam = builder->emitParam(builder->getTypeType()); - builder->addNameHintDecoration(dTypeParam, UnownedTerminatedStringSlice("dT")); - - auto irStructType = builder->createStructType(); - builder->emitReturn(irStructType); - - auto primalKey = _getOrCreatePrimalStructKey(builder); - builder->addNameHintDecoration(primalKey, UnownedTerminatedStringSlice("primal")); - builder->createStructField(irStructType, primalKey, (IRType*) pTypeParam); - - auto diffKey = _getOrCreateDiffStructKey(builder); - builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential")); - builder->createStructField(irStructType, diffKey, (IRType*) dTypeParam); - - // Reset cursor when done. - builder->setInsertLoc(insertLoc); - - this->genericDiffPairType = irGeneric; - } - IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder) { if (!this->globalDiffKey) @@ -535,17 +496,6 @@ struct DifferentialPairTypeBuilder return this->globalPrimalKey; } - IRInst* _getOrCreateGenericDiffPairType(IRBuilder* builder) - { - if (!this->genericDiffPairType) - { - _createGenericDiffPairType(builder); - } - - SLANG_ASSERT(this->genericDiffPairType); - return this->genericDiffPairType; - } - IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) { if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType)) @@ -1383,22 +1333,10 @@ struct JVPTranscriber } else { - // We special case a few non-differentiable types that sometimes appear in places - // where we're forced to provide a differential zero value. For instance, - // float3(float, float, int) is accepted by the compiler, but is tricky in the context - // of differentiation since int is non-differentiable, and should be cast to float first. - // In the absence of such casts, this piece of code generates appropriate zero values. - // - switch (primalType->getOp()) - { - case kIROp_IntType: - return builder->getIntValue(primalType, 0); - default: - getSink()->diagnose(primalType->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; - } + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; } } |
