summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-14 09:31:51 -0700
committerGitHub <noreply@github.com>2023-03-14 09:31:51 -0700
commite291f60c6b083eaa74aed5307a6e9461274c1642 (patch)
treebde9b45a9e09ebbe173fae1821237b258a9ff800 /source/slang/slang-ir.cpp
parenta911ca6e06ce41e403b80fe6054162393491c8ac (diff)
Support `fwd_diff(bwd_diff(f))`. (#2697)
* Support `fwd_diff(bwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
-rw-r--r--source/slang/slang-ir.cpp104
1 files changed, 104 insertions, 0 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 08c066f5d..f61e5a10e 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3509,6 +3509,104 @@ namespace Slang
return nullptr;
}
+ IRInst* IRBuilder::emitStructuralAddRaw(IRInst* val0, IRInst* val1)
+ {
+ IRInst* args[2] = { val0, val1 };
+ return emitIntrinsicInst(val0->getFullType(), kIROp_StructuralAdd, 2, args);
+ }
+
+ IRInst* IRBuilder::emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback)
+ {
+ auto type = val0->getFullType();
+ SLANG_RELEASE_ASSERT(val0->getFullType() == val1->getFullType());
+ IRType* actualType = val0->getFullType();
+ for (;;)
+ {
+ if (auto attr = as<IRAttributedType>(actualType))
+ actualType = attr->getBaseType();
+ else if (auto rateQualified = as<IRRateQualifiedType>(actualType))
+ actualType = rateQualified->getValueType();
+ else
+ break;
+ }
+ if (as<IRBasicType>(actualType))
+ return emitAdd(type, val0, val1);
+
+ switch (actualType->getOp())
+ {
+ case kIROp_PtrType:
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ return emitAdd(type, val0, val1);
+ case kIROp_TupleType:
+ {
+ List<IRInst*> elements;
+ auto tupleType = as<IRTupleType>(actualType);
+ for (UInt i = 0; i < tupleType->getOperandCount(); i++)
+ {
+ auto operand = tupleType->getOperand(i);
+ if (as<IRAttr>(operand))
+ break;
+ auto inner = emitStructuralAdd(
+ emitGetTupleElement((IRType*)operand, val0, i),
+ emitGetTupleElement((IRType*)operand, val1, i),
+ fallback);
+ if (!inner)
+ return nullptr;
+ elements.add(inner);
+ }
+ return emitMakeTuple(tupleType, elements);
+ }
+ case kIROp_StructType:
+ {
+ List<IRInst*> elements;
+ auto structType = as<IRStructType>(actualType);
+ for (auto field : structType->getFields())
+ {
+ auto fieldType = field->getFieldType();
+ auto inner = emitStructuralAdd(
+ emitFieldExtract(fieldType, val0, field->getKey()),
+ emitFieldExtract(fieldType, val1, field->getKey()),
+ fallback);
+ if (!inner)
+ return nullptr;
+ elements.add(inner);
+ }
+ return emitMakeStruct(type, elements);
+ }
+ case kIROp_ArrayType:
+ {
+ auto arrayType = as<IRArrayType>(actualType);
+ if (auto count = as<IRIntLit>(arrayType->getElementCount()))
+ {
+ auto elementType = arrayType->getElementType();
+ List<IRInst*> elements;
+ constexpr int maxCount = 4096;
+ if (count->getValue() > maxCount)
+ break;
+ for (IRIntegerValue i = 0; i < count->getValue(); i++)
+ {
+ auto index = getIntValue(getIntType(), i);
+ auto element = emitStructuralAdd(
+ emitElementExtract(elementType, val0, index),
+ emitElementExtract(elementType, val1, index),
+ fallback);
+ elements.add(element);
+ }
+ return emitMakeArray(type, elements.getCount(), elements.getBuffer());
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ if (fallback)
+ {
+ return emitStructuralAddRaw(val0, val1);
+ }
+ return nullptr;
+ }
+
static int _getTypeStyleId(IRType* type)
{
if (auto vectorType = as<IRVectorType>(type))
@@ -3928,6 +4026,11 @@ namespace Slang
return addDecoration(target, kIROp_PrimalElementTypeDecoration, type);
}
+ IRInst* IRBuilder::addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type)
+ {
+ return addDecoration(target, kIROp_IntermediateContextFieldDifferentialTypeDecoration, type);
+ }
+
RefPtr<IRModule> IRModule::create(Session* session)
{
RefPtr<IRModule> module = new IRModule(session);
@@ -7028,6 +7131,7 @@ namespace Slang
case kIROp_Nop:
case kIROp_undefined:
case kIROp_DefaultConstruct:
+ case kIROp_StructuralAdd:
case kIROp_Specialize:
case kIROp_LookupWitness:
case kIROp_GetSequentialID: