diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-14 09:31:51 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-14 09:31:51 -0700 |
| commit | e291f60c6b083eaa74aed5307a6e9461274c1642 (patch) | |
| tree | bde9b45a9e09ebbe173fae1821237b258a9ff800 /source/slang/slang-ir.cpp | |
| parent | a911ca6e06ce41e403b80fe6054162393491c8ac (diff) | |
Support `fwd_diff(bwd_diff(f))`. (#2697)
* Support `fwd_diff(bwd_diff(f))`.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
| -rw-r--r-- | source/slang/slang-ir.cpp | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 08c066f5d..f61e5a10e 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3509,6 +3509,104 @@ namespace Slang return nullptr; } + IRInst* IRBuilder::emitStructuralAddRaw(IRInst* val0, IRInst* val1) + { + IRInst* args[2] = { val0, val1 }; + return emitIntrinsicInst(val0->getFullType(), kIROp_StructuralAdd, 2, args); + } + + IRInst* IRBuilder::emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback) + { + auto type = val0->getFullType(); + SLANG_RELEASE_ASSERT(val0->getFullType() == val1->getFullType()); + IRType* actualType = val0->getFullType(); + for (;;) + { + if (auto attr = as<IRAttributedType>(actualType)) + actualType = attr->getBaseType(); + else if (auto rateQualified = as<IRRateQualifiedType>(actualType)) + actualType = rateQualified->getValueType(); + else + break; + } + if (as<IRBasicType>(actualType)) + return emitAdd(type, val0, val1); + + switch (actualType->getOp()) + { + case kIROp_PtrType: + case kIROp_VectorType: + case kIROp_MatrixType: + return emitAdd(type, val0, val1); + case kIROp_TupleType: + { + List<IRInst*> elements; + auto tupleType = as<IRTupleType>(actualType); + for (UInt i = 0; i < tupleType->getOperandCount(); i++) + { + auto operand = tupleType->getOperand(i); + if (as<IRAttr>(operand)) + break; + auto inner = emitStructuralAdd( + emitGetTupleElement((IRType*)operand, val0, i), + emitGetTupleElement((IRType*)operand, val1, i), + fallback); + if (!inner) + return nullptr; + elements.add(inner); + } + return emitMakeTuple(tupleType, elements); + } + case kIROp_StructType: + { + List<IRInst*> elements; + auto structType = as<IRStructType>(actualType); + for (auto field : structType->getFields()) + { + auto fieldType = field->getFieldType(); + auto inner = emitStructuralAdd( + emitFieldExtract(fieldType, val0, field->getKey()), + emitFieldExtract(fieldType, val1, field->getKey()), + fallback); + if (!inner) + return nullptr; + elements.add(inner); + } + return emitMakeStruct(type, elements); + } + case kIROp_ArrayType: + { + auto arrayType = as<IRArrayType>(actualType); + if (auto count = as<IRIntLit>(arrayType->getElementCount())) + { + auto elementType = arrayType->getElementType(); + List<IRInst*> elements; + constexpr int maxCount = 4096; + if (count->getValue() > maxCount) + break; + for (IRIntegerValue i = 0; i < count->getValue(); i++) + { + auto index = getIntValue(getIntType(), i); + auto element = emitStructuralAdd( + emitElementExtract(elementType, val0, index), + emitElementExtract(elementType, val1, index), + fallback); + elements.add(element); + } + return emitMakeArray(type, elements.getCount(), elements.getBuffer()); + } + break; + } + default: + break; + } + if (fallback) + { + return emitStructuralAddRaw(val0, val1); + } + return nullptr; + } + static int _getTypeStyleId(IRType* type) { if (auto vectorType = as<IRVectorType>(type)) @@ -3928,6 +4026,11 @@ namespace Slang return addDecoration(target, kIROp_PrimalElementTypeDecoration, type); } + IRInst* IRBuilder::addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type) + { + return addDecoration(target, kIROp_IntermediateContextFieldDifferentialTypeDecoration, type); + } + RefPtr<IRModule> IRModule::create(Session* session) { RefPtr<IRModule> module = new IRModule(session); @@ -7028,6 +7131,7 @@ namespace Slang case kIROp_Nop: case kIROp_undefined: case kIROp_DefaultConstruct: + case kIROp_StructuralAdd: case kIROp_Specialize: case kIROp_LookupWitness: case kIROp_GetSequentialID: |
