diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-24 22:16:21 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-24 22:16:21 -0800 |
| commit | 951ad25e0a9c3b0089c6b996b8e821ac93cf5766 (patch) | |
| tree | 7bed99484204611a4669d7c2c11019795e37f7cb /source | |
| parent | a3b0eff62e59f3a05461bf3edee5e100e804e4d5 (diff) | |
Reimplement address elimination. (#2605)
* Reimplement address elimination pass.
* Fix error.
* Update test references.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
26 files changed, 1158 insertions, 578 deletions
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 72d8ec50a..bc95ed63d 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -419,6 +419,40 @@ protected: }; +struct ValSet +{ + struct ValItem + { + Val* val = nullptr; + ValItem() = default; + ValItem(Val* v) : val(v) {} + + HashCode getHashCode() + { + return val ? val->getHashCode() : 0; + } + bool operator==(ValItem other) + { + if (val == other.val) + return true; + if (val) + return val->equalsVal(other.val); + else if (other.val) + return other.val->equalsVal(val); + return false; + } + }; + HashSet<ValItem> set; + bool add(Val* val) + { + return set.Add(ValItem(val)); + } + bool contains(Val* val) + { + return set.Contains(ValItem(val)); + } +}; + } // namespace Slang #endif diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp index e5a9ff75f..bab651ed9 100644 --- a/source/slang/slang-ast-synthesis.cpp +++ b/source/slang/slang-ast-synthesis.cpp @@ -29,8 +29,11 @@ Expr* ASTSynthesizer::emitPostfixExpr(UnownedStringSlice operatorToken, Expr* ba ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl* &outIndexVar) { + auto parentStmt = getCurrentScope().m_parentSeqStmt; + auto seqStmt = m_builder->create<SeqStmt>(); auto scopeDecl = pushVarScope()->containerDecl; auto stmt = m_builder->create<ForStmt>(); + stmt->statement = seqStmt; stmt->scopeDecl = (ScopeDecl*)scopeDecl; auto declStmt = emitVarDeclStmt(nullptr, m_namePool->getName("S_synth_loop_index"), initVal); stmt->initialStatement = declStmt; @@ -38,7 +41,8 @@ ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl* &outInd auto predicateExpr = emitBinaryExpr(UnownedStringSlice("<"), emitVarExpr(outIndexVar), finalVal); stmt->predicateExpression = predicateExpr; stmt->sideEffectExpression = emitPrefixExpr(UnownedStringSlice("++"), emitVarExpr(outIndexVar)); - getCurrentScope().m_parentSeqStmt->stmts.add(stmt); + parentStmt->stmts.add(stmt); + m_scopeStack.getLast().m_parentSeqStmt = seqStmt; return stmt; } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7b5f85b60..5add89312 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3308,7 +3308,6 @@ namespace Slang synth.popScope(); if (!assignStmt) return nullptr; - forStmt->statement = assignStmt; return forStmt; } @@ -4962,6 +4961,7 @@ namespace Slang auto thisType = calcThisType(parentDeclRef); maybeRegisterDifferentiableType(m_astBuilder, thisType); } + completeDifferentiableTypeDictionary(); m_parentDifferentiableAttr = oldAttr; } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 43124b535..2853c1eb9 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -922,7 +922,6 @@ namespace Slang } } - void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) { if (!builder->isDifferentiableInterfaceAvailable()) @@ -962,6 +961,80 @@ namespace Slang } } + void SemanticsVisitor::maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet) + { + if (workingSet.contains(type)) + return; + + if (!builder->isDifferentiableInterfaceAvailable()) + { + return; + } + + if (!m_parentDifferentiableAttr) + { + return; + } + + workingSet.add(type); + + // Check for special cases such as PtrTypeBase<T> or Array<T> + // This could potentially be handled later by simply defining extensions + // for Ptr<T:IDifferentiable> etc.. + // + if (auto ptrType = as<PtrTypeBase>(type)) + { + maybeRegisterDifferentiableTypeRecursive(builder, ptrType->getValueType(), workingSet); + return; + } + + if (auto arrayType = as<ArrayExpressionType>(type)) + { + maybeRegisterDifferentiableTypeRecursive(builder, arrayType->baseType, workingSet); + return; + } + + if (auto declRefType = as<DeclRefType>(type)) + { + if (auto subtypeWitness = as<SubtypeWitness>( + tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) + { + registerDifferentiableType((DeclRefType*)type, subtypeWitness); + if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>()) + { + foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) + { + auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr); + maybeRegisterDifferentiableTypeRecursive(m_astBuilder, subType, workingSet); + }); + foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member) + { + auto fieldType = getType(m_astBuilder, member); + maybeRegisterDifferentiableTypeRecursive(m_astBuilder, fieldType, workingSet); + }); + } + } + return; + } + } + + void SemanticsVisitor::completeDifferentiableTypeDictionary() + { + ValSet workingSet; + for (auto type : m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness) + { + if (auto aggTypeDeclRef = type.Key.as<AggTypeDecl>()) + { + maybeRegisterDifferentiableTypeRecursive( + m_astBuilder, + m_astBuilder->getOrCreateDeclRefType( + aggTypeDeclRef.getDecl(), aggTypeDeclRef.substitutions), + workingSet); + } + } + } + + Expr* SemanticsVisitor::CheckTerm(Expr* term) { auto checkedTerm = _CheckTerm(term); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 6099febb5..1b59094e2 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -298,7 +298,7 @@ namespace Slang bool isDifferentiableFunc(FunctionDeclBase* func); bool isBackwardDifferentiableFunc(FunctionDeclBase* func); FunctionDifferentiableLevel getFuncDifferentiableLevel(FunctionDeclBase* func); - + private: /// Mapping from type declarations to the known extensiosn that apply to them Dictionary<AggTypeDecl*, RefPtr<CandidateExtensionList>> m_mapTypeDeclToCandidateExtensions; @@ -760,6 +760,8 @@ namespace Slang /// describing the relationship. /// void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness); + void maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet); + void completeDifferentiableTypeDictionary(); // Construct the differential for 'type', if it exists. Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc); diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index fbcd0dbbd..ffb469b9d 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1087,6 +1087,8 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) // Never fold these, because their result cannot be computed // as a sub-expression (they must be emitted as a declaration // or statement). + case kIROp_UpdateField: + case kIROp_UpdateElement: case kIROp_DefaultConstruct: return false; @@ -1135,6 +1137,7 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) case kIROp_MakeStruct: case kIROp_MakeArray: case kIROp_swizzleSet: + case kIROp_MakeArrayFromElement: return false; } @@ -2190,7 +2193,24 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO m_writer->emit(" }"); } break; + case kIROp_MakeArrayFromElement: + { + // TODO: initializer-list syntax may not always + // be appropriate, depending on the context + // of the expression. + m_writer->emit("{ "); + UInt argCount = + (UInt)cast<IRIntLit>(cast<IRArrayType>(inst->getDataType())->getElementCount()) + ->getValue(); + for (UInt aa = 0; aa < argCount; ++aa) + { + if (aa != 0) m_writer->emit(", "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + } + m_writer->emit(" }"); + } + break; case kIROp_BitCast: { // Note: we are currently emitting casts as plain old @@ -2461,6 +2481,54 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst) m_writer->emit(";\n"); } break; + + case kIROp_UpdateElement: + { + auto ii = (IRUpdateElement*)inst; + auto subscriptOuter = getInfo(EmitOp::General); + auto subscriptPrec = getInfo(EmitOp::Postfix); + auto arraySize = as<IRIntLit>(as<IRArrayType>(inst->getDataType())->getElementCount()); + SLANG_RELEASE_ASSERT(arraySize); + emitInstResultDecl(inst); + m_writer->emit("{"); + for (UInt i = 0; i < (UInt)arraySize->getValue(); i++) + { + if (i > 0) + m_writer->emit(", "); + emitOperand(ii->getOldValue(), leftSide(subscriptOuter, subscriptPrec)); + m_writer->emit("["); + m_writer->emit(i); + m_writer->emit("]"); + } + + m_writer->emit("}"); + m_writer->emit(";\n"); + + emitOperand(ii, leftSide(subscriptOuter, subscriptPrec)); + m_writer->emit("["); + emitOperand(ii->getIndex(), getInfo(EmitOp::General)); + m_writer->emit("] = "); + emitOperand(ii->getElementValue(), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + } + break; + case kIROp_UpdateField: + { + auto ii = (IRUpdateField*)inst; + emitInstResultDecl(inst); + emitOperand(ii->getOldValue(), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + + auto subscriptOuter = getInfo(EmitOp::General); + auto subscriptPrec = getInfo(EmitOp::Postfix); + emitOperand(ii, leftSide(subscriptOuter, subscriptPrec)); + m_writer->emit("."); + m_writer->emit(getName(ii->getFieldKey())); + m_writer->emit(" = "); + emitOperand(ii->getElementValue(), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + } + break; } } diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp new file mode 100644 index 000000000..877be1406 --- /dev/null +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -0,0 +1,216 @@ +#include "slang-ir-addr-inst-elimination.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +// Rewrites address load/store into value extract/updates to allow SSA transform to apply to struct and array elements. +// For example, +// load(elementPtr(arr, 1)) ==> elementExtract(load(arr), 1) +// store(fieldAddr(s, field_key), val) ==> store(s, updateField(load(s), fieldKey, val)) +// After this transform, all address operands of `load` and `store` insts will be either a var or a param. + +struct AddressInstEliminationContext +{ + SharedIRBuilder* sharedBuilder; + DiagnosticSink* sink; + + IRInst* getValue(IRBuilder& builder, IRInst* addr) + { + switch (addr->getOp()) + { + default: + return builder.emitLoad(addr); + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + IRInst* args[] = {getValue(builder, addr->getOperand(0)), addr->getOperand(1)}; + return builder.emitIntrinsicInst( + cast<IRPtrTypeBase>(addr->getFullType())->getValueType(), + (addr->getOp() == kIROp_GetElementPtr ? kIROp_GetElement : kIROp_FieldExtract), + 2, + args); + } + } + } + + void storeValue(IRBuilder& builder, IRInst* addr, IRInst* val) + { + List<IRInst*> baseAddrs; + + for (auto inst = addr; inst;) + { + switch (inst->getOp()) + { + default: + baseAddrs.add(inst); + goto endLoop; + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + baseAddrs.add(inst); + inst = inst->getOperand(0); + break; + } + } + endLoop:; + List<IRInst*> values; + values.setCount(baseAddrs.getCount()); + if (values.getCount() > 1) + { + IRInst* currentVal = builder.emitLoad(baseAddrs.getLast()); + values.getLast() = currentVal; + for (Index i = baseAddrs.getCount() - 2; i >= 1; i--) + { + auto inst = baseAddrs[i]; + switch (inst->getOp()) + { + default: + sink->diagnose(inst->sourceLoc, Diagnostics::unsupportedUseOfLValueForAutoDiff); + return; + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + IRInst* args[] = { currentVal, inst->getOperand(1) }; + currentVal = builder.emitIntrinsicInst( + cast<IRPtrTypeBase>(inst->getFullType())->getValueType(), + (inst->getOp() == kIROp_GetElementPtr ? kIROp_GetElement : kIROp_FieldExtract), + 2, + args); + values[i] = currentVal; + } + break; + } + } + } + values[0] = val; + for (Index i = 1; i < values.getCount(); i++) + { + auto inst = baseAddrs[i - 1]; + switch (inst->getOp()) + { + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + IRInst* args[] = {values[i], inst->getOperand(1), values[i - 1]}; + values[i] = builder.emitIntrinsicInst( + values[i]->getFullType(), + (inst->getOp() == kIROp_GetElementPtr ? kIROp_UpdateElement : kIROp_UpdateField), + 3, + args); + } + break; + } + } + builder.emitStore(baseAddrs.getLast(), values.getLast()); + } + + void transformLoadAddr(IRUse* use) + { + auto addr = use->get(); + auto load = as<IRLoad>(use->getUser()); + + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(use->getUser()); + auto value = getValue(builder, addr); + load->replaceUsesWith(value); + load->removeAndDeallocate(); + } + + void transformStoreAddr(IRUse* use) + { + auto addr = use->get(); + auto store = as<IRStore>(use->getUser()); + + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(use->getUser()); + storeValue(builder, addr, store->getVal()); + store->removeAndDeallocate(); + } + + void transformCallAddr(IRUse* use) + { + auto addr = use->get(); + auto call = as<IRCall>(use->getUser()); + + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(call); + auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType()); + builder.emitStore(tempVar, getValue(builder, addr)); + builder.setInsertAfter(call); + storeValue(builder, addr, builder.emitLoad(tempVar)); + use->set(tempVar); + } + + SlangResult eliminateAddressInstsImpl( + SharedIRBuilder* inSharedBuilder, + AddressConversionPolicy* policy, + IRFunc* func, + DiagnosticSink* inSink) + { + sharedBuilder = inSharedBuilder; + sink = inSink; + + IRBuilder builder(sharedBuilder); + + List<IRInst*> workList; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (as<IRPtrTypeBase>(inst->getDataType())) + { + workList.add(inst); + } + } + } + + for (Index workListIndex = 0; workListIndex < workList.getCount(); workListIndex++) + { + auto addrInst = workList[workListIndex]; + + if (!policy->shouldConvertAddrInst(addrInst)) + continue; + + for (auto use = addrInst->firstUse; use; ) + { + if (as<IRDecoration>(use->getUser())) + continue; + + auto nextUse = use->nextUse; + + switch (use->getUser()->getOp()) + { + case kIROp_Load: + transformLoadAddr(use); + break; + case kIROp_Store: + transformStoreAddr(use); + break; + case kIROp_Call: + transformCallAddr(use); + break; + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + break; + default: + sink->diagnose(use->getUser()->sourceLoc, Diagnostics::unsupportedUseOfLValueForAutoDiff); + break; + } + use = nextUse; + } + } + + return SLANG_OK; + } +}; + +SlangResult eliminateAddressInsts( + SharedIRBuilder* sharedBuilder, + AddressConversionPolicy* policy, + IRFunc* func, + DiagnosticSink* sink) +{ + AddressInstEliminationContext ctx; + return ctx.eliminateAddressInstsImpl(sharedBuilder, policy, func, sink); +} +} // namespace Slang diff --git a/source/slang/slang-ir-addr-inst-elimination.h b/source/slang/slang-ir-addr-inst-elimination.h new file mode 100644 index 000000000..53e5628f3 --- /dev/null +++ b/source/slang/slang-ir-addr-inst-elimination.h @@ -0,0 +1,21 @@ +// slang-ir-addr-inst-elimination.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +struct SharedIRBuilder; +class DiagnosticSink; + +struct AddressConversionPolicy +{ + virtual bool shouldConvertAddrInst(IRInst* addrInst) = 0; +}; +SlangResult eliminateAddressInsts( + SharedIRBuilder* sharedBuilder, + AddressConversionPolicy* policy, + IRFunc* func, + DiagnosticSink* sink); + +} diff --git a/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp b/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp deleted file mode 100644 index c60995595..000000000 --- a/source/slang/slang-ir-autodiff-addr-inst-elimination.cpp +++ /dev/null @@ -1,476 +0,0 @@ -#include "slang-ir-address-analysis.h" -#include "slang-ir-autodiff-fwd.h" -#include "slang-ir-autodiff-pairs.h" -#include "slang-ir-autodiff-rev.h" -#include "slang-ir-autodiff.h" -#include "slang-ir-single-return.h" -#include "slang-ir-ssa-simplification.h" -#include "slang-ir-validate.h" - -namespace Slang -{ -bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst); - -struct AddressInstEliminationContext -{ - OrderedDictionary<IRInst*, IRInst*> mapAddrInstToTempVar; - - IRInst* _reconstructStruct( - IRBuilder& builder, IRStructType* type, IRInst* tempVar, List<AddressInfo*>& childAddrs) - { - List<IRInst*> args; - IRInst* loadedTempVar = nullptr; - for (auto child : type->getChildren()) - { - if (auto field = as<IRStructField>(child)) - { - IRInst* childVar = nullptr; - for (auto subAddr : childAddrs) - { - auto fieldAddrInst = cast<IRFieldAddress>(subAddr->addrInst); - if (fieldAddrInst->getField() == field->getKey()) - { - mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar); - break; - } - } - if (childVar) - { - args.add(builder.emitLoad(childVar)); - } - else - { - if (!loadedTempVar) - loadedTempVar = builder.emitLoad(tempVar); - args.add(builder.emitFieldExtract( - field->getFieldType(), loadedTempVar, field->getKey())); - } - } - } - return builder.emitMakeStruct(type, args); - } - - IRInst* _reconstructArray( - IRBuilder& builder, - IRArrayType* type, - IRIntegerValue arraySize, - IRInst* tempVar, - List<AddressInfo*>& childAddrs) - { - IRInst* loadedTempVar = nullptr; - List<IRInst*> args; - for (IRIntegerValue index = 0; index < arraySize; index++) - { - IRInst* childVar = nullptr; - for (auto subAddr : childAddrs) - { - auto elementPtrInst = cast<IRGetElementPtr>(subAddr->addrInst); - auto elementIndex = as<IRIntLit>(elementPtrInst->getIndex()); - if (elementIndex && elementIndex->getValue() == index) - { - mapAddrInstToTempVar.TryGetValue(subAddr->addrInst, childVar); - break; - } - } - if (childVar) - { - args.add(builder.emitLoad(childVar)); - } - else - { - if (!loadedTempVar) - loadedTempVar = builder.emitLoad(tempVar); - args.add(builder.emitElementExtract( - type->getElementType(), - loadedTempVar, - builder.getIntValue(builder.getIntType(), index))); - } - } - return builder.emitMakeArray(type, args.getCount(), args.getBuffer()); - } - - void updateChildTempVarRecursive( - IRBuilder& builder, - AddressInfo* addr, - IRInst* val) - { - for (auto child : addr->children) - { - IRInst* childVar = nullptr; - if (mapAddrInstToTempVar.TryGetValue(child->addrInst, childVar)) - { - switch (child->addrInst->getOp()) - { - case kIROp_FieldAddress: - { - auto subVal = builder.emitFieldExtract( - cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(), - val, - child->addrInst->getOperand(1)); - builder.emitStore(childVar, subVal); - updateChildTempVarRecursive(builder, child, subVal); - } - break; - case kIROp_GetElementPtr: - { - auto subVal = builder.emitElementExtract( - cast<IRPtrTypeBase>(child->addrInst->getDataType())->getValueType(), - val, - child->addrInst->getOperand(1)); - builder.emitStore(childVar, subVal); - updateChildTempVarRecursive(builder, child, subVal); - } - break; - default: - { - } - break; - } - } - } - } - - IRInst* getLoadedValue( - IRBuilder& builder, - AddressInfo* addr, - IRInst* tempVar) - { - if (addr->children.getCount()) - { - // Reconstruct val. - auto type = - cast<IRPtrTypeBase>(unwrapAttributedType(tempVar->getFullType()))->getValueType(); - switch (type->getOp()) - { - case kIROp_StructType: - return _reconstructStruct( - builder, as<IRStructType>(type), tempVar, addr->children); - case kIROp_ArrayType: - { - auto arrayType = as<IRArrayType>(type); - auto size = as<IRIntLit>(arrayType->getElementCount()); - if (!size || size->getValue() < 0) - { - // Unsupported array type. - } - else - { - return _reconstructArray( - builder, - arrayType, - size->getValue(), - tempVar, - addr->children); - } - } - break; - default: - // Unsupported address type. - break; - } - } - return builder.emitLoad(tempVar); - }; - - void updateParentTempVarRecursive( - IRBuilder& builder, - AddressInfo* addr) - { - for (auto parent = addr->parentAddress; parent; parent = parent->parentAddress) - { - IRInst* parentVar = nullptr; - if (mapAddrInstToTempVar.TryGetValue(parent->addrInst, parentVar)) - { - auto val = getLoadedValue(builder, parent, parentVar); - builder.emitStore(parentVar, val); - } - } - } - - String getAddrName(IRInst* addrInst) - { - StringBuilder sb; - List<IRInst*> bases; - bases.add(addrInst); - for (; addrInst;) - { - if (auto fieldAddr = as<IRFieldAddress>(addrInst)) - bases.add(fieldAddr->getBase()); - else if (auto index = as<IRGetElementPtr>(addrInst)) - bases.add(index->getBase()); - else - break; - } - for (Index i = bases.getCount() - 1; i >= 0; i--) - { - if (bases[i]->getOp() == kIROp_FieldAddress) - { - sb << "."; - auto field = bases[i]->getOperand(1); - auto nameDecor = field->findDecoration<IRNameHintDecoration>(); - sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>")); - } - else if (bases[i]->getOp() == kIROp_FieldAddress) - { - sb << "["; - auto index = bases[i]->getOperand(1); - auto nameDecor = index->findDecoration<IRNameHintDecoration>(); - if (nameDecor) - { - sb << nameDecor->getName(); - } - else if (auto intLit = as<IRIntLit>(index)) - { - sb << intLit->getValue(); - } - else - { - sb << "..."; - } - sb << "]"; - } - else - { - auto nameDecor = bases[i]->findDecoration<IRNameHintDecoration>(); - sb << (nameDecor ? nameDecor->getName() : UnownedStringSlice("<unknown>")); - } - } - return sb.ProduceString(); - } - - SlangResult eliminateAddressInstsImpl( - SharedIRBuilder* sharedBuilder, - DifferentiableTypeConformanceContext& diffContext, - IRFunc* func, - DiagnosticSink* sink) - { - bool hasError = false; - - if (!isSingleReturnFunc(func)) - { - convertFuncToSingleReturnForm(func->getModule(), func); - } - - IRBuilder builder(sharedBuilder); - - auto dom = computeDominatorTree(func); - auto addrUse = analyzeAddressUse(dom, func); - List<AddressInfo*> workList; - HashSet<AddressInfo*> workListSet; - - // Process leaf addresses first. - for (auto addr : addrUse.addressInfos) - { - if (addr.Value->children.getCount() == 0) - workList.add(addr.Value); - } - - auto createTempVarForAddr = [&](IRInst* addrInst) - { - if (as<IRParam>(addrInst)) - builder.setInsertAfter(as<IRBlock>(addrInst->getParent())->getLastParam()); - else - builder.setInsertAfter(addrInst); - auto ptrType = as<IRPtrTypeBase>(addrInst->getFullType()); - SLANG_RELEASE_ASSERT(ptrType); - auto tempVar = builder.emitVar(ptrType->getValueType()); - mapAddrInstToTempVar[addrInst] = tempVar; - }; - - // In the first pass, we create temp vars for addresses with non-trivial access pattern. - for (Index workListIndex = 0; workListIndex < workList.getCount(); workListIndex++) - { - auto addr = workList[workListIndex]; - - if (!isDifferentiableType(diffContext, addr->addrInst->getDataType())) - continue; - - List<IRUse*> readUses, writeUses, callUses, subAddrUses, unknownUses; - - for (auto node = addr; node; node = node->parentAddress) - { - auto addrInst = node->addrInst; - - for (auto use = addrInst->firstUse; use; use = use->nextUse) - { - if (as<IRDecoration>(use->getUser())) - continue; - switch (use->getUser()->getOp()) - { - case kIROp_Load: - readUses.add(use); - break; - case kIROp_Store: - writeUses.add(use); - break; - case kIROp_Call: - callUses.add(use); - break; - case kIROp_GetElementPtr: - case kIROp_FieldAddress: - if (node == addr) - subAddrUses.add(use); - break; - default: - unknownUses.add(use); - break; - } - } - } - - if (unknownUses.getCount() != 0) - { - // Diagnose about unknown use. - sink->diagnose( - unknownUses.getFirst()->getUser(), - Diagnostics::unsupportedUseOfLValueForAutoDiff); - hasError = true; - continue; - } - - if (addr->isConstant) - { - // Otherwise, the address must be a constant, and we need to create a temp var for - // it. The exception is when the variable is a temp var for a call. - if (callUses.getCount() == 1 && writeUses.getCount() <= 1 && - readUses.getCount() <= 1) - { - if (writeUses.getCount() == 0) - continue; - - // The uses must be in write->call->read order. - auto callUse = callUses.getFirst(); - auto writeUse = writeUses.getFirst(); - auto readUse = readUses.getCount() ? readUses.getFirst() : writeUse; - if (dom->dominates(writeUse->getUser(), callUse->getUser()) && - dom->dominates(callUse->getUser(), readUse->getUser())) - { - continue; - } - } - - // Create a temp var for the address and replace all uses of the address to the temp - // var. - createTempVarForAddr(addr->addrInst); - } - else - { - // This is a dynamic address. We can only allow at most one write access to it. - bool hasNonTrivialAccess = false; - if (readUses.getCount() + callUses.getCount() != 0 && - writeUses.getCount() + callUses.getCount() > 1) - hasNonTrivialAccess = true; - - if (hasNonTrivialAccess) - { - // Mixed use of a non-constant address is unsupported right now. - sink->diagnose( - addr->addrInst, - Diagnostics::cannotDifferentiateDynamicallyIndexedData, - getAddrName(addr->addrInst)); - } - } - if (addr->parentAddress && workListSet.Add(addr->parentAddress)) - workList.add(addr->parentAddress); - } - - if (hasError) - return SLANG_FAIL; - - // Actually replace addresses with temp vars. - for (auto addr : workList) - { - IRInst* tempVar = nullptr; - if (!mapAddrInstToTempVar.TryGetValue(addr->addrInst, tempVar)) - continue; - for (auto use = addr->addrInst->firstUse; use;) - { - auto nextUse = use->nextUse; - auto user = use->getUser(); - - builder.setInsertBefore(user); - switch (user->getOp()) - { - case kIROp_Load: - use->set(tempVar); - break; - case kIROp_Store: - use->set(tempVar); - updateChildTempVarRecursive( - builder, addr, as<IRStore>(user)->getVal()); - updateParentTempVarRecursive(builder, addr); - case kIROp_Call: - { - use->set(tempVar); - builder.setInsertAfter(user); - auto newVal = builder.emitLoad(tempVar); - updateChildTempVarRecursive(builder, addr, newVal); - updateParentTempVarRecursive(builder, addr); - } - break; - default: - use->set(tempVar); - break; - } - use = nextUse; - } - } - - // Assign initial values to tempVar. - for (auto tempVar : mapAddrInstToTempVar) - { - builder.setInsertAfter(tempVar.Value); - IRInst* initVal = nullptr; - if (tempVar.Key->getOp() == kIROp_Var || - tempVar.Key->getOp() == kIROp_Param && as<IROutType>(tempVar.Key->getFullType())) - { - initVal = builder.emitDefaultConstruct( - cast<IRPtrTypeBase>(tempVar.Key->getFullType())->getValueType()); - } - else - { - initVal = builder.emitLoad(tempVar.Key); - } - builder.emitStore(tempVar.Value, initVal); - } - - // Store final values to out parameters before exiting function. - IRInst* returnInst = nullptr; - for (auto block : func->getBlocks()) - { - for (auto inst : block->getChildren()) - { - if (inst->getOp() == kIROp_Return) - { - returnInst = inst; - break; - } - } - } - SLANG_RELEASE_ASSERT(returnInst); - builder.setInsertBefore(returnInst); - for (auto param : func->getParams()) - { - IRInst* tempVar = nullptr; - if (mapAddrInstToTempVar.TryGetValue(param, tempVar)) - { - auto val = builder.emitLoad(tempVar); - builder.emitStore(param, val); - } - } - if (hasError) - return SLANG_FAIL; - return SLANG_OK; - } -}; - -SlangResult eliminateAddressInsts( - SharedIRBuilder* sharedBuilder, - DifferentiableTypeConformanceContext& diffContext, - IRFunc* func, - DiagnosticSink* sink) -{ - AddressInstEliminationContext ctx; - return ctx.eliminateAddressInstsImpl(sharedBuilder, diffContext, func, sink); -} -} // namespace Slang diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 3f3618b44..f5fa17fae 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -53,6 +53,21 @@ String ForwardDiffTranscriber::getJVPVarName(IRInst* origVar) return String(""); } +InstPair ForwardDiffTranscriber::transcribeUndefined(IRBuilder* builder, IRInst* origInst) +{ + auto primalVal = maybeCloneForPrimalInst(builder, origInst); + + if (IRType* diffType = differentiateType(builder, origInst->getFullType())) + { + auto dzero = getDifferentialZeroOfType(builder, origInst->getFullType()); + if (dzero) + { + return InstPair(primalVal, dzero); + } + } + return InstPair(primalVal, nullptr); +} + InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar) { if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) @@ -745,6 +760,91 @@ InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst return InstPair(primalGetElementPtr, diffGetElementPtr); } +InstPair ForwardDiffTranscriber::transcribeUpdateField(IRBuilder* builder, IRInst* originalInst) +{ + auto updateInst = as<IRUpdateField>(originalInst); + + IRInst* origBase = updateInst->getOldValue(); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto field = updateInst->getFieldKey(); + auto primalVal = findOrTranscribePrimalInst(builder, updateInst->getElementValue()); + auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>(); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType()); + + IRInst* primalOperands[] = { primalBase, field, primalVal }; + IRInst* primalUpdateField = builder->emitIntrinsicInst( + primalType, + originalInst->getOp(), + 3, + primalOperands); + + if (!derivativeRefDecor) + { + return InstPair(primalUpdateField, nullptr); + } + + IRInst* diffUpdateField = nullptr; + + if (auto diffType = differentiateType(builder, originalInst->getDataType())) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + if (auto diffVal = findOrTranscribeDiffInst(builder, updateInst->getElementValue())) + { + auto primalElementType = primalVal->getDataType(); + + IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey(), diffVal, primalElementType }; + diffUpdateField = builder->emitIntrinsicInst( + diffType, + originalInst->getOp(), + 4, + diffOperands); + } + } + } + return InstPair(primalUpdateField, diffUpdateField); +} + +InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst) +{ + auto updateInst = as<IRUpdateElement>(originalInst); + + IRInst* origBase = updateInst->getOldValue(); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalIndex = findOrTranscribePrimalInst(builder, updateInst->getIndex()); + auto origVal = updateInst->getElementValue(); + auto primalVal = findOrTranscribePrimalInst(builder, origVal); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType()); + + IRInst* primalOperands[] = { primalBase, primalIndex, primalVal }; + IRInst* primalUpdateField = builder->emitIntrinsicInst( + primalType, + originalInst->getOp(), + 3, + primalOperands); + + IRInst* diffUpdateElement = nullptr; + + if (auto diffType = differentiateType(builder, originalInst->getDataType())) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + if (auto diffVal = findOrTranscribeDiffInst(builder, origVal)) + { + auto primalElementType = primalVal->getDataType(); + + IRInst* diffOperands[] = { diffBase, primalIndex, diffVal, primalElementType }; + diffUpdateElement = builder->emitIntrinsicInst( + diffType, + originalInst->getOp(), + 4, + diffOperands); + } + } + } + return InstPair(primalUpdateField, diffUpdateElement); +} + InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop) { // The loop comes with three blocks.. we just need to transcribe each one @@ -1132,6 +1232,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_FloatCast: case kIROp_MakeStruct: case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: return transcribeConstruct(builder, origInst); case kIROp_LookupWitness: @@ -1146,7 +1247,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_MakeVectorFromScalar: case kIROp_MakeTuple: return transcribeByPassthrough(builder, origInst); - + case kIROp_UpdateElement: + return transcribeUpdateElement(builder, origInst); + case kIROp_UpdateField: + return transcribeUpdateField(builder, origInst); case kIROp_unconditionalBranch: return transcribeControlFlow(builder, origInst); @@ -1194,6 +1298,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* return transcribeExtractExistentialWitnessTable(builder, origInst); case kIROp_WrapExistential: return transcribeWrapExistential(builder, origInst); + case kIROp_undefined: + return transcribeUndefined(builder, origInst); + case kIROp_CreateExistentialObject: // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, // so we treat this inst as non differentiable. diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index b09b57974..f8186a96e 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -23,6 +23,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase // String makeDiffPairName(IRInst* origVar); + InstPair transcribeUndefined(IRBuilder* builder, IRInst* origInst); + InstPair transcribeVar(IRBuilder* builder, IRVar* origVar); InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith); @@ -61,6 +63,10 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr); + InstPair transcribeUpdateField(IRBuilder* builder, IRInst* originalInst); + + InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst); + InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop); InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 779a4f1a3..000921c7e 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -501,15 +501,7 @@ namespace Slang stripDerivativeDecorations(primalFunc); eliminateDeadCode(primalOuterParent); - // Perform preparation and simplification. - differentiableTypeConformanceContext.setFunc(primalFunc); - if (SLANG_FAILED(eliminateAddressInsts( - builder->getSharedBuilder(), - differentiableTypeConformanceContext, - primalFunc, - sink))) - return nullptr; - + // Perform simplification. simplifyFunc(primalFunc); // Forward transcribe the clone of the original func. diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index cfbc9638a..89adbe6a0 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -624,6 +624,20 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); } + + if (auto arrayType = as<IRArrayType>(primalType)) + { + auto diffElementType = + (IRType*)differentiableTypeConformanceContext.getDifferentialForType( + builder, arrayType->getElementType()); + SLANG_RELEASE_ASSERT(diffElementType); + auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount()); + auto diffElementZero = getDifferentialZeroOfType(builder, arrayType->getElementType()); + auto result = builder->emitMakeArrayFromElement(diffArrayType, diffElementZero); + builder->markInstAsDifferential(result, primalType); + return result; + } + // Since primalType has a corresponding differential type, we can lookup the // definition for zero(). auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index e799456bb..51dcd9f45 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -278,21 +278,12 @@ struct DiffTransposePass auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst); auto diffType = fwdInst->getDataType(); - auto zeroMethod = diffTypeContext.getZeroMethodForType( - &tempVarBuilder, - primalType); - - SLANG_ASSERT(zeroMethod); + auto zero = emitDZeroOfDiffInstType(&tempVarBuilder, primalType); // Emit a var in the top-level differential block to hold the gradient, // and initialize it. auto tempRevVar = tempVarBuilder.emitVar(diffType); - auto diffZero = tempVarBuilder.emitCallInst( - diffType, - zeroMethod, - List<IRInst*>()); - tempVarBuilder.emitStore(tempRevVar, diffZero); - + tempVarBuilder.emitStore(tempRevVar, zero); revAccumulatorVarMap[fwdInst] = tempRevVar; return tempRevVar; @@ -1044,6 +1035,9 @@ struct DiffTransposePass case kIROp_FieldExtract: return transposeFieldExtract(builder, as<IRFieldExtract>(fwdInst), revValue); + case kIROp_GetElement: + return transposeGetElement(builder, as<IRGetElement>(fwdInst), revValue); + case kIROp_Return: return transposeReturn(builder, as<IRReturn>(fwdInst), revValue); @@ -1065,6 +1059,14 @@ struct DiffTransposePass return transposeMakeStruct(builder, fwdInst, revValue); case kIROp_MakeArray: return transposeMakeArray(builder, fwdInst, revValue); + case kIROp_MakeArrayFromElement: + return transposeMakeArrayFromElement(builder, fwdInst, revValue); + + case kIROp_UpdateElement: + return transposeUpdateElement(builder, fwdInst, revValue); + + case kIROp_UpdateField: + return transposeUpdateField(builder, fwdInst, revValue); case kIROp_Specialize: case kIROp_unconditionalBranch: @@ -1170,6 +1172,17 @@ struct DiffTransposePass fwdExtract))); } + TranspositionResult transposeGetElement(IRBuilder*, IRGetElement* fwdGetElement, IRInst* revValue) + { + return TranspositionResult( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::GetElement, + fwdGetElement->getBase(), + revValue, + fwdGetElement))); + } + TranspositionResult transposeMakePair(IRBuilder*, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue) { // Even though makePair returns a pair of (primal, differential) @@ -1250,27 +1263,108 @@ struct DiffTransposePass { List<RevGradient> gradients; auto arrayType = cast<IRArrayType>(fwdMakeArray->getFullType()); - auto arraySize = cast<IRIntLit>(arrayType->getElementCount()); - for (UInt ii = 0; ii < (UInt)arraySize->getValue(); ii++) + for (UInt ii = 0; ii < fwdMakeArray->getOperandCount(); ii++) { auto gradAtField = builder->emitElementExtract( arrayType->getElementType(), revValue, builder->getIntValue(builder->getIntType(), ii)); - SLANG_RELEASE_ASSERT(ii < fwdMakeArray->getOperandCount()); gradients.add(RevGradient( RevGradient::Flavor::Simple, fwdMakeArray->getOperand(ii), gradAtField, fwdMakeArray)); - ii++; } // (A = MakeArray(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)] return TranspositionResult(gradients); } + TranspositionResult transposeMakeArrayFromElement(IRBuilder* builder, IRInst* fwdMakeArrayFromElement, IRInst* revValue) + { + List<RevGradient> gradients; + auto arrayType = cast<IRArrayType>(fwdMakeArrayFromElement->getFullType()); + auto arraySize = cast<IRIntLit>(arrayType->getElementCount()); + SLANG_RELEASE_ASSERT(arraySize); + // TODO: if arraySize is a generic value, we can't statically expand things here. + // In that case we probably need another opcode e.g. `Sum(arrayValue)` that can be expand + // later in the pipeline when `arrayValue` becomes a known value. + for (UInt ii = 0; ii < (UInt)arraySize->getValue(); ii++) + { + auto gradAtField = builder->emitElementExtract( + arrayType->getElementType(), + revValue, + builder->getIntValue(builder->getIntType(), ii)); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeArrayFromElement->getOperand(0), + gradAtField, + fwdMakeArrayFromElement)); + } + + // (A = MakeArrayFromElement(E)) -> [(dE += dA.F1), (dE += dA.F2), (dE += dA.F3)] + return TranspositionResult(gradients); + } + + TranspositionResult transposeUpdateElement(IRBuilder* builder, IRInst* fwdUpdate, IRInst* revValue) + { + auto updateInst = as<IRUpdateElement>(fwdUpdate); + + List<RevGradient> gradients; + auto arrayType = cast<IRArrayType>(fwdUpdate->getFullType()); + auto revElement = builder->emitElementExtract(arrayType->getElementType(), revValue, updateInst->getIndex()); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + updateInst->getElementValue(), + revElement, + fwdUpdate)); + + auto primalElementType = updateInst->getPrimalElementType(); + auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementType); + SLANG_ASSERT(diffZero); + auto revRest = builder->emitUpdateElement( + revValue, + updateInst->getIndex(), + diffZero); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + updateInst->getOldValue(), + revRest, + fwdUpdate)); + // (A = UpdateElement(arr, index, V)) -> [(dV += dA[index], d_arr += UpdateElement(revValue, index, 0)] + return TranspositionResult(gradients); + } + + TranspositionResult transposeUpdateField(IRBuilder* builder, IRInst* fwdUpdate, IRInst* revValue) + { + auto updateInst = as<IRUpdateField>(fwdUpdate); + + List<RevGradient> gradients; + IRType* fieldType = updateInst->getElementValue()->getFullType(); + auto revElement = builder->emitFieldExtract(fieldType, revValue, updateInst->getFieldKey()); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + updateInst->getElementValue(), + revElement, + fwdUpdate)); + + auto primalElementType = updateInst->getPrimalElementType(); + auto diffZero = emitDZeroOfDiffInstType(builder, (IRType*)primalElementType); + SLANG_ASSERT(diffZero); + auto revRest = builder->emitUpdateField( + revValue, + updateInst->getFieldKey(), + diffZero); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + updateInst->getOldValue(), + revRest, + fwdUpdate)); + // (A = UpdateField(s, fieldKey, V)) -> [(dV += dA.fieldKey, d_s += UpdateField(revValue, fieldKey, 0)] + return TranspositionResult(gradients); + } + // Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr. // void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad) @@ -1548,11 +1642,80 @@ struct DiffTransposePass case RevGradient::Flavor::FieldExtract: return materializeFieldExtractGradients(builder, aggPrimalType, gradients); + case RevGradient::Flavor::GetElement: + return materializeGetElementGradients(builder, aggPrimalType, gradients); + default: SLANG_ASSERT_FAILURE("Unhandled gradient flavor for materialization"); } } + RevGradient materializeGetElementGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients) + { + // Setup a temporary variable to aggregate gradients. + // TODO: We can extend this later to grab an existing ptr to allow aggregation of + // gradients across blocks without constructing new variables. + // Looking up an existing pointer could also allow chained accesses like x.a.b[1] to directly + // write into the specific sub-field that is affected without constructing intermediate vars. + // + auto revGradVar = builder->emitVar( + (IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType)); + + // Initialize with T.dzero() + auto zeroValueInst = emitDZeroOfDiffInstType(builder, aggPrimalType); + + builder->emitStore(revGradVar, zeroValueInst); + + OrderedDictionary<IRInst*, List<RevGradient>> bucketedGradients; + for (auto gradient : gradients) + { + // Grab the element affected by this gradient. + auto getElementInst = as<IRGetElement>(gradient.fwdGradInst); + SLANG_ASSERT(getElementInst); + + auto index = getElementInst->getIndex(); + SLANG_ASSERT(index); + + if (!bucketedGradients.ContainsKey(index)) + { + bucketedGradients[index] = List<RevGradient>(); + } + + bucketedGradients[index].GetValue().add(RevGradient( + RevGradient::Flavor::Simple, + gradient.targetInst, + gradient.revGradInst, + gradient.fwdGradInst + )); + + } + + for (auto pair : bucketedGradients) + { + auto subGrads = pair.Value; + + auto primalType = tryGetPrimalTypeFromDiffInst(subGrads[0].fwdGradInst); + + SLANG_ASSERT(primalType); + + // Construct address to this field in revGradVar. + auto revGradTargetAddress = builder->emitElementAddress( + builder->getPtrType(subGrads[0].revGradInst->getDataType()), + revGradVar, + pair.Key); + + builder->emitStore(revGradTargetAddress, emitAggregateValue(builder, primalType, subGrads)); + } + + // Load the entire var and return it. + return RevGradient( + RevGradient::Flavor::Simple, + gradients[0].targetInst, + builder->emitLoad(revGradVar), + nullptr); + } + + RevGradient materializeFieldExtractGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients) { // Setup a temporary variable to aggregate gradients. @@ -1569,7 +1732,7 @@ struct DiffTransposePass builder->emitStore(revGradVar, zeroValueInst); - Dictionary<IRStructKey*, List<RevGradient>> bucketedGradients; + OrderedDictionary<IRStructKey*, List<RevGradient>> bucketedGradients; for (auto gradient : gradients) { // Grab the field affected by this gradient. @@ -1601,7 +1764,7 @@ struct DiffTransposePass SLANG_ASSERT(primalType); - // Consruct address to this field in revGradVar. + // Construct address to this field in revGradVar. auto revGradTargetAddress = builder->emitFieldAddress( builder->getPtrType(subGrads[0].revGradInst->getDataType()), revGradVar, @@ -1734,6 +1897,14 @@ struct DiffTransposePass IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType) { + if (auto arrayType = as<IRArrayType>(primalType)) + { + auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(builder, arrayType->getElementType()); + SLANG_RELEASE_ASSERT(diffElementType); + auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount()); + auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType()); + return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero); + } auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType); // Should exist. @@ -1747,6 +1918,30 @@ struct DiffTransposePass IRInst* emitDAddOfDiffInstType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2) { + if (auto arrayType = as<IRArrayType>(primalType)) + { + auto diffElementType = (IRType*)diffTypeContext.getDifferentialForType(builder, arrayType->getElementType()); + SLANG_RELEASE_ASSERT(diffElementType); + auto arraySize = arrayType->getElementCount(); + if (auto constArraySize = as<IRIntLit>(arraySize)) + { + List<IRInst*> args; + for (IRIntegerValue i = 0; i < constArraySize->getValue(); i++) + { + auto index = builder->getIntValue(builder->getIntType(), i); + auto op1Val = builder->emitElementExtract(diffElementType, op1, index); + auto op2Val = builder->emitElementExtract(diffElementType, op2, index); + args.add(emitDAddOfDiffInstType(builder, arrayType->getElementType(), op1Val, op2Val)); + } + auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount()); + return builder->emitMakeArray(diffArrayType, (UInt)args.getCount(), args.getBuffer()); + } + else + { + // TODO: insert a runtime loop here. + SLANG_UNIMPLEMENTED_X("dadd of dynamic array."); + } + } auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType); // Should exist. diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 640041ecf..8d9a01b75 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -185,6 +185,7 @@ struct ExtractPrimalFuncContext case kIROp_MakeStruct: case kIROp_MakeTuple: case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: case kIROp_MakeDifferentialPair: case kIROp_MakeOptionalNone: case kIROp_MakeOptionalValue: @@ -194,6 +195,8 @@ struct ExtractPrimalFuncContext case kIROp_GetElement: case kIROp_FieldExtract: case kIROp_swizzle: + case kIROp_UpdateElement: + case kIROp_UpdateField: case kIROp_OptionalHasValue: case kIROp_GetOptionalValue: case kIROp_MatrixReshape: diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 4d33d3743..ae8359251 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -532,6 +532,46 @@ void stripNoDiffTypeAttribute(IRModule* module) pass.processModule(); } +bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) +{ + HashSet<IRInst*> processedSet; + for (;typeInst;) + { + if (as<IRArrayTypeBase>(typeInst) || as<IRPtrTypeBase>(typeInst)) + { + typeInst = typeInst->getOperand(0); + if (!processedSet.Add(typeInst)) + return false; + } + else + { + break; + } + } + if (!typeInst) + return false; + switch (typeInst->getOp()) + { + case kIROp_FloatType: + case kIROp_DifferentialPairType: + return true; + default: + break; + } + if (context.lookUpConformanceForType(typeInst)) + return true; + // Look for equivalent types. + for (auto type : context.differentiableWitnessDictionary) + { + if (isTypeEqual(type.Key, (IRType*)typeInst)) + { + context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value; + return true; + } + } + return false; +} + struct AutoDiffPass : public InstPassBase { DiagnosticSink* getSink() diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index cb767c20a..6da4ea6a6 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -175,20 +175,40 @@ struct DifferentiableTypeConformanceContext case kIROp_DoubleType: case kIROp_VectorType: return origType; + case kIROp_ArrayType: + { + auto diffElementType = (IRType*)getDifferentialForType( + builder, as<IRArrayType>(origType)->getElementType()); + if (!diffElementType) + return nullptr; + return builder->getArrayType( + diffElementType, + as<IRArrayType>(origType)->getElementCount()); + } + default: + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); } - return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); } IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) { - return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + if (result && !result->findDecoration<IRNoSideEffectDecoration>()) + { + builder->addDecoration(result, kIROp_NoSideEffectDecoration); + } + return result; } IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) { - return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); + auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); + if (result && !result->findDecoration<IRNoSideEffectDecoration>()) + { + builder->addDecoration(result, kIROp_NoSideEffectDecoration); + } + return result; } - }; struct DifferentialPairTypeBuilder @@ -261,10 +281,6 @@ void stripDerivativeDecorations(IRInst* inst); bool isBackwardDifferentiableFunc(IRInst* func); -SlangResult eliminateAddressInsts( - SharedIRBuilder* sharedBuilder, - DifferentiableTypeConformanceContext& diffContext, - IRFunc* func, - DiagnosticSink* sink); +bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst); }; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 67b7e92b0..f8d70c8ed 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -2,40 +2,32 @@ #include "slang-ir-autodiff.h" #include "slang-ir-inst-pass-base.h" +#include "slang-ir-single-return.h" +#include "slang-ir-addr-inst-elimination.h" namespace Slang { -bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) +IRInst* getSpecializedVal(IRInst* inst) { - HashSet<IRInst*> processedSet; - while (auto ptrType = as<IRPtrTypeBase>(typeInst)) + int loopLimit = 1024; + while (inst && inst->getOp() == kIROp_Specialize) { - typeInst = ptrType->getValueType(); - if (!processedSet.Add(typeInst)) - return false; - } - if (!typeInst) - return false; - switch (typeInst->getOp()) - { - case kIROp_FloatType: - case kIROp_DifferentialPairType: - return true; - default: - break; - } - if (context.lookUpConformanceForType(typeInst)) - return true; - // Look for equivalent types. - for (auto type : context.differentiableWitnessDictionary) - { - if (isTypeEqual(type.Key, (IRType*)typeInst)) - { - context.differentiableWitnessDictionary[(IRType*)typeInst] = type.Value; - return true; - } + inst = as<IRSpecialize>(inst)->getBase(); + loopLimit--; + if (loopLimit == 0) + return inst; } - return false; + return inst; +} + +IRInst* getLeafFunc(IRInst* func) +{ + func = getSpecializedVal(func); + if (!func) + return nullptr; + if (auto genericFunc = as<IRGeneric>(func)) + return findInnerMostGenericReturnVal(genericFunc); + return func; } struct CheckDifferentiabilityPassContext : public InstPassBase @@ -55,29 +47,6 @@ public: : InstPassBase(inModule), sharedBuilder(inSharedBuilder), sink(inSink), sharedContext(inModule->getModuleInst()) {} - IRInst* getSpecializedVal(IRInst* inst) - { - int loopLimit = 1024; - while (inst && inst->getOp() == kIROp_Specialize) - { - inst = as<IRSpecialize>(inst)->getBase(); - loopLimit--; - if (loopLimit == 0) - return inst; - } - return inst; - } - - IRInst* getLeafFunc(IRInst* func) - { - func = getSpecializedVal(func); - if (!func) - return nullptr; - if (auto genericFunc = as<IRGeneric>(func)) - return findInnerMostGenericReturnVal(genericFunc); - return func; - } - bool _isFuncMarkedForAutoDiff(IRInst* func) { func = getLeafFunc(func); @@ -219,6 +188,30 @@ public: } return false; } + + struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy + { + DifferentiableTypeConformanceContext* diffTypeContext; + + virtual bool shouldConvertAddrInst(IRInst* addrInst) override + { + if (isDifferentiableType(*diffTypeContext, addrInst->getDataType())) + return true; + return false; + } + }; + + SlangResult prepareFuncForAutoDiff(DifferentiableTypeConformanceContext& diffTypeContext, IRFunc* func) + { + if (!isSingleReturnFunc(func)) + { + convertFuncToSingleReturnForm(func->getModule(), func); + } + AutoDiffAddressConversionPolicy cvtPolicty; + cvtPolicty.diffTypeContext = &diffTypeContext; + return eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink); + } + void processFunc(IRGlobalValueWithCode* funcInst) { if (!_isFuncMarkedForAutoDiff(funcInst)) @@ -232,7 +225,7 @@ public: { if (auto func = as<IRFunc>(funcInst)) { - if (SLANG_FAILED(eliminateAddressInsts(sharedBuilder, diffTypeContext, func, sink))) + if (SLANG_FAILED(prepareFuncForAutoDiff(diffTypeContext, func))) return; } } diff --git a/source/slang/slang-ir-constexpr.cpp b/source/slang/slang-ir-constexpr.cpp index 0fedf695d..a8cdb5cca 100644 --- a/source/slang/slang-ir-constexpr.cpp +++ b/source/slang/slang-ir-constexpr.cpp @@ -115,9 +115,12 @@ bool opCanBeConstExpr(IROp op) case kIROp_MakeString: case kIROp_MakeUInt64: case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: case kIROp_swizzle: case kIROp_GetElement: case kIROp_FieldExtract: + case kIROp_UpdateField: + case kIROp_UpdateElement: case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialValue: case kIROp_ExtractExistentialWitnessTable: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 134a45bf5..c2a1886fb 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -293,6 +293,7 @@ INST(MakeMatrixFromScalar, makeMatrixFromScalar, 1, 0) INST(MatrixReshape, matrixReshape, 1, 0) INST(VectorReshape, vectorReshape, 1, 0) INST(MakeArray, makeArray, 0, 0) +INST(MakeArrayFromElement, makeArrayFromElement, 1, 0) INST(MakeStruct, makeStruct, 0, 0) INST(MakeTuple, makeTuple, 0, 0) INST(GetTupleElement, getTupleElement, 2, 0) @@ -310,6 +311,9 @@ INST(Call, call, 1, 0) INST(RTTIObject, rtti_object, 0, 0) INST(Alloca, alloca, 1, 0) +INST(UpdateElement, updateElement, 3, 0) +INST(UpdateField, updateField, 3, 0) + INST(PackAnyValue, packAnyValue, 1, 0) INST(UnpackAnyValue, unpackAnyValue, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 22da763b3..4887b1c79 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2185,6 +2185,36 @@ struct IRDifferentialPairGetPrimal : IRInst IRInst* getBase() { return getOperand(0); } }; +struct IRUpdateElement : IRInst +{ + IR_LEAF_ISA(UpdateElement) + + IRInst* getOldValue() { return getOperand(0); } + IRInst* getIndex() { return getOperand(1); } + IRInst* getElementValue() { return getOperand(2); } + IRInst* getPrimalElementType() + { + if (getOperandCount() != 4) + return nullptr; + return getOperand(3); + } +}; + +struct IRUpdateField : IRInst +{ + IR_LEAF_ISA(UpdateField) + + IRInst* getOldValue() { return getOperand(0); } + IRInst* getFieldKey() { return getOperand(1); } + IRInst* getElementValue() { return getOperand(2); } + IRInst* getPrimalElementType() + { + if (getOperandCount() != 4) + return nullptr; + return getOperand(3); + } +}; + // Constructs an `Result<T,E>` value from an error code. struct IRMakeResultError : IRInst { @@ -2925,6 +2955,10 @@ public: UInt argCount, IRInst* const* args); + IRInst* emitMakeArrayFromElement( + IRType* type, + IRInst* element); + IRInst* emitMakeStruct( IRType* type, UInt argCount, @@ -3139,6 +3173,9 @@ public: IRInst* basePtr, IRInst* index); + IRInst* emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement); + IRInst* emitUpdateField(IRInst* base, IRInst* fieldKey, IRInst* newFieldVal); + IRInst* emitGetAddress( IRType* type, IRInst* value); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 21e17b546..16f6cd9b9 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -69,7 +69,7 @@ struct PeepholeContext : InstPassBase Index i = 0; for (auto sfield : structType->getFields()) { - if (sfield == field) + if (sfield->getKey() == field) { fieldIndex = i; break; @@ -84,6 +84,142 @@ struct PeepholeContext : InstPassBase } } } + else if (auto updateField = as<IRUpdateField>(inst->getOperand(0))) + { + if (inst->getOperand(1) == updateField->getFieldKey()) + { + inst->replaceUsesWith(updateField->getElementValue()); + inst->removeAndDeallocate(); + changed = true; + } + else + { + inst->setOperand(0, updateField->getOldValue()); + changed = true; + } + } + break; + case kIROp_GetElement: + if (inst->getOperand(0)->getOp() == kIROp_MakeArray) + { + auto index = as<IRIntLit>(as<IRGetElement>(inst)->getIndex()); + if (!index) + break; + auto opCount = inst->getOperand(0)->getOperandCount(); + if ((UInt)index->getValue() < opCount) + { + inst->replaceUsesWith(inst->getOperand(0)->getOperand((UInt)index->getValue())); + inst->removeAndDeallocate(); + changed = true; + } + } + else if (inst->getOperand(0)->getOp() == kIROp_MakeArrayFromElement) + { + inst->replaceUsesWith(inst->getOperand(0)->getOperand(0)); + inst->removeAndDeallocate(); + changed = true; + } + else if (auto updateElement = as<IRUpdateElement>(inst->getOperand(0))) + { + if (inst->getOperand(1) == updateElement->getIndex()) + { + inst->replaceUsesWith(updateElement->getElementValue()); + inst->removeAndDeallocate(); + changed = true; + } + else if (auto constIndex1 = as<IRIntLit>(inst->getOperand(1))) + { + if (auto constIndex2 = as<IRIntLit>(updateElement->getIndex())) + { + // If we can determine that the indices does not match, + // then reduce the original value operand to before the update. + if (constIndex1->getValue() != constIndex2->getValue()) + { + inst->setOperand(0, updateElement->getOldValue()); + changed = true; + } + } + } + } + break; + case kIROp_UpdateElement: + { + if (auto constIndex = as<IRIntLit>(inst->getOperand(1))) + { + auto oldVal = inst->getOperand(0); + if (oldVal->getOp() == kIROp_MakeArray || + oldVal->getOp() == kIROp_MakeArrayFromElement) + { + auto arrayType = as<IRArrayType>(inst->getDataType()); + if (!arrayType) break; + auto arraySize = as<IRIntLit>(arrayType->getElementCount()); + if (!arraySize) break; + List<IRInst*> args; + for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) + { + IRInst* arg = nullptr; + if (i < (IRIntegerValue)oldVal->getOperandCount()) + arg = oldVal->getOperand((UInt)i); + else if (oldVal->getOperandCount() != 0) + arg = oldVal->getOperand(0); + else + break; + if (i == (IRIntegerValue)constIndex->getValue()) + arg = inst->getOperand(2); + args.add(arg); + } + if (args.getCount() == arraySize->getValue()) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(inst); + auto makeArray = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); + inst->replaceUsesWith(makeArray); + inst->removeAndDeallocate(); + changed = true; + } + } + } + } + break; + case kIROp_UpdateField: + { + auto oldVal = inst->getOperand(0); + if (oldVal->getOp() == kIROp_MakeStruct) + { + auto structType = as<IRStructType>(inst->getDataType()); + if (!structType) break; + List<IRInst*> args; + UInt i = 0; + bool isValid = true; + for (auto field : structType->getFields()) + { + IRInst* arg = nullptr; + if (i < oldVal->getOperandCount()) + arg = oldVal->getOperand(i); + if (field->getKey() == inst->getOperand(1)) + arg = inst->getOperand(2); + if (arg) + { + args.add(arg); + } + else + { + isValid = false; + break; + } + i++; + } + if (isValid) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(inst); + auto makeStruct = builder.emitMakeStruct(structType, (UInt)args.getCount(), args.getBuffer()); + inst->replaceUsesWith(makeStruct); + inst->removeAndDeallocate(); + changed = true; + } + } + } break; case kIROp_CastPtrToBool: { @@ -246,10 +382,17 @@ struct PeepholeContext : InstPassBase SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); - - changed = false; - processChildInsts(func, [this](IRInst* inst) { processInst(inst); }); - return changed; + bool result = false; + for (;;) + { + changed = false; + processChildInsts(func, [this](IRInst* inst) { processInst(inst); }); + if (changed) + result = true; + else + break; + } + return result; } bool processModule() diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index a57bfce3e..bcf0907df 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -37,6 +37,8 @@ struct RedundancyRemovalContext case kIROp_FieldAddress: case kIROp_GetElement: case kIROp_GetElementPtr: + case kIROp_UpdateElement: + case kIROp_UpdateField: case kIROp_LookupWitness: case kIROp_Specialize: case kIROp_OptionalHasValue: @@ -46,6 +48,7 @@ struct RedundancyRemovalContext case kIROp_GetTupleElement: case kIROp_MakeStruct: case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: case kIROp_MakeVector: case kIROp_MakeMatrix: case kIROp_MakeMatrixFromScalar: @@ -75,6 +78,8 @@ struct RedundancyRemovalContext case kIROp_Neq: case kIROp_Eql: return true; + case kIROp_Call: + return isPureFunctionalCall(as<IRCall>(inst)); default: return false; } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 319a23989..3ffbb75f7 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -227,6 +227,31 @@ String dumpIRToString(IRInst* root) return sb.ToString(); } +bool isPureFunctionalCall(IRCall* call) +{ + auto callee = getResolvedInstForDecorations(call->getCallee()); + if (callee->findDecoration<IRReadNoneDecoration>()) + { + return true; + } + if (callee->findDecoration<IRNoSideEffectDecoration>()) + { + // If the function has no side effect and is not writing to any outputs, + // we can safely treat the call as a normal inst. + bool hasOutArg = false; + for (UInt i = 0; i < call->getArgCount(); i++) + { + if (as<IRPtrTypeBase>(call->getArg(i)->getDataType())) + { + hasOutArg = true; + break; + } + } + return !hasOutArg; + } + return false; +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index a250fc6a6..26ad4bc68 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -155,6 +155,9 @@ inline IRInst* unwrapAttributedType(IRInst* type) String dumpIRToString(IRInst* root); +// Returns whether a call insts can be treated as a pure functional inst +// (no writes to memory, no side effects). +bool isPureFunctionalCall(IRCall* callInst); } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b79221900..2960d942c 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1810,6 +1810,24 @@ namespace Slang } template<typename T> + static T* createInst( + IRBuilder* builder, + IROp op, + IRType* type, + IRInst* arg1, + IRInst* arg2, + IRInst* arg3) + { + IRInst* args[] = { arg1, arg2, arg3 }; + return createInstImpl<T>( + builder, + op, + type, + 3, + &args[0]); + } + + template<typename T> static T* createInstWithTrailingArgs( IRBuilder* builder, IROp op, @@ -3768,6 +3786,13 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeArray, argCount, args); } + IRInst* IRBuilder::emitMakeArrayFromElement( + IRType* type, + IRInst* element) + { + return emitIntrinsicInst(type, kIROp_MakeArrayFromElement, 1, &element); + } + IRInst* IRBuilder::emitMakeStruct( IRType* type, UInt argCount, @@ -4346,6 +4371,34 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement) + { + auto inst = createInst<IRUpdateElement>( + this, + kIROp_UpdateElement, + base->getFullType(), + base, + index, + newElement); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitUpdateField(IRInst* base, IRInst* fieldKey, IRInst* newFieldVal) + { + auto inst = createInst<IRUpdateField>( + this, + kIROp_UpdateField, + base->getFullType(), + base, + fieldKey, + newFieldVal); + + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitGetAddress( IRType* type, IRInst* value) @@ -6545,11 +6598,7 @@ namespace Slang // common subexpression elimination, etc. // auto call = cast<IRCall>(this); - auto callee = getResolvedInstForDecorations(call->getCallee()); - if(callee->findDecoration<IRReadNoneDecoration>()) - { - return false; - } + return !isPureFunctionalCall(call); } break; @@ -6592,6 +6641,7 @@ namespace Slang case kIROp_MatrixReshape: case kIROp_VectorReshape: case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: case kIROp_MakeStruct: case kIROp_MakeString: case kIROp_getNativeStr: @@ -6612,6 +6662,8 @@ namespace Slang case kIROp_FieldAddress: case kIROp_GetElement: case kIROp_GetElementPtr: + case kIROp_UpdateElement: + case kIROp_UpdateField: case kIROp_MeshOutputRef: case kIROp_MakeVectorFromScalar: case kIROp_swizzle: |
