diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-28 09:23:08 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-28 09:23:08 -0700 |
| commit | 638e5fb000d4e242a91e8b653da4a72daec0efda (patch) | |
| tree | cfcd15c1fc6bdee624eb33abac3268241b086dec | |
| parent | 16595a8379e9dbfa1845fd72f3531ff3372da3ef (diff) | |
Make tuple types work in autodiff. (#4923)
| -rw-r--r-- | source/slang/diff.meta.slang | 25 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-addr-inst-elimination.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-expand-type.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-tuple-types.cpp | 68 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 20 | ||||
| -rw-r--r-- | tests/language-feature/tuple/tuple-autodiff.slang | 49 |
14 files changed, 203 insertions, 32 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index a4c468ef7..80aca230a 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1210,6 +1210,31 @@ extension Array<T, N> : IDifferentiable } } +__generic<each T : IDifferentiable> +extension Tuple<T> : IDifferentiable +{ + typealias Differential = Tuple<expand(each T).Differential>; + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return makeTuple(expand (each T).dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return makeTuple(expand(each T).dadd(each a, each b)); + } + + __generic<U : __BuiltinRealType> + [__unsafeForceInlineEarly] + static Differential dmul(U a, Differential b) + { + return makeTuple(expand(each T).dmul(a, each b)); + } +} + // Matrix transpose __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> [ForceInline] diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index a13e13851..9879a4187 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -523,7 +523,7 @@ FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Typ return getOrCreate<FuncType>(parameters, result, errorType); } -TupleType* ASTBuilder::getTupleType(List<Type*>& types) +TupleType* ASTBuilder::getTupleType(ArrayView<Type*> types) { // The canonical form of a tuple type is always a DeclRefType(GenAppDeclRef(TupleDecl, ConcreteTypePack(types...))). // If `types` is already a single ConcreteTypePack, then we can use that directly. @@ -536,7 +536,7 @@ TupleType* ASTBuilder::getTupleType(List<Type*>& types) } // Otherwise, we need to create a ConcreteTypePack to hold the types. - auto typePack = getTypePack(types.getArrayView()); + auto typePack = getTypePack(types); return as<TupleType>(getSpecializedBuiltinType(typePack, "TupleType")); } diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 64282ce78..3e2a88dd8 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -508,7 +508,7 @@ public: Val* getSNormModifierVal(); Val* getNoDiffModifierVal(); - TupleType* getTupleType(List<Type*>& types); + TupleType* getTupleType(ArrayView<Type*> types); FuncType* getFuncType(ArrayView<Type*> parameters, Type* result, Type* errorType = nullptr); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 500407e26..ec064b5b3 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -4055,7 +4055,7 @@ namespace Slang { types.add(baseTupleType->getMember(index)); } - swizExpr->type = QualType(m_astBuilder->getTupleType(types)); + swizExpr->type = QualType(m_astBuilder->getTupleType(types.getArrayView())); } // A swizzle can be used as an l-value as long as there @@ -4908,7 +4908,7 @@ namespace Slang types.reserve(expr->members.getCount()); for(auto t : expr->members) types.add(t.type); - auto tupleType = m_astBuilder->getTupleType(types); + auto tupleType = m_astBuilder->getTupleType(types.getArrayView()); expr->type = m_astBuilder->getTypeType(tupleType); return expr; diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 7889a2f61..8a48936d7 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -25,7 +25,7 @@ struct AddressInstEliminationContext case kIROp_GetElementPtr: case kIROp_FieldAddress: { - IRInst* args[] = {getValue(builder, addr->getOperand(0)), addr->getOperand(1)}; + IRInst* args[] = { getValue(builder, addr->getOperand(0)), addr->getOperand(1) }; return builder.emitIntrinsicInst( cast<IRPtrTypeBase>(addr->getFullType())->getValueType(), (addr->getOp() == kIROp_GetElementPtr ? kIROp_GetElement : kIROp_FieldExtract), @@ -60,7 +60,7 @@ struct AddressInstEliminationContext if (accessChain.getCount()) { auto lastVal = builder.emitLoad(lastAddr); - auto update = builder.emitUpdateElement(lastVal, accessChain, val); + auto update = builder.emitUpdateElement(lastVal, accessChain.getArrayView(), val); builder.emitStore(lastAddr, update); } else diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 91d3e71cb..fe7c77ba0 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1164,7 +1164,7 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI auto primalVal = findOrTranscribePrimalInst(builder, origVal); IRInst* primalUpdateField = - builder->emitUpdateElement(primalBase, primalAccessChain, primalVal); + builder->emitUpdateElement(primalBase, primalAccessChain.getArrayView(), primalVal); IRInst* diffUpdateElement = nullptr; List<IRInst*> diffAccessChain; @@ -1198,7 +1198,7 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI auto primalElementType = primalVal->getDataType(); diffUpdateElement = builder->emitUpdateElement( - diffBase, diffAccessChain, diffVal); + diffBase, diffAccessChain.getArrayView(), diffVal); builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); } else @@ -1206,7 +1206,7 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI auto primalElementType = primalVal->getDataType(); auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType); diffUpdateElement = builder->emitUpdateElement( - diffBase, diffAccessChain, zeroElementDiff); + diffBase, diffAccessChain.getArrayView(), zeroElementDiff); builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType); } } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index f8f6b03ab..d42462e1b 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2016,7 +2016,7 @@ struct DiffTransposePass SLANG_ASSERT(diffZero); auto revRest = builder->emitUpdateElement( revValue, - accessChain, + accessChain.getArrayView(), diffZero); gradients.add(RevGradient( RevGradient::Flavor::Simple, diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index bf83d8d7f..8ca7dbe76 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -448,26 +448,26 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) IRBuilder subBuilder(item->getConcreteType()); if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType)) { - // For tuple types, register the differential type for each element, but don't register for the + // For tuple types with concrete element types, + // register the differential type for each element, but don't register for the // tuple/typepack itself. - auto witnessPack = as<IRMakeWitnessPack>(witness); - SLANG_ASSERT(witnessPack); - - for (UInt i = 0; i < concreteType->getOperandCount(); i++) + if (auto witnessPack = as<IRMakeWitnessPack>(witness)) { - auto element = concreteType->getOperand(i); - auto elementWitness = witnessPack->getOperand(i); - differentiableWitnessDictionary.addIfNotExists( - (IRType*)element, - _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey)); + + for (UInt i = 0; i < concreteType->getOperandCount(); i++) + { + auto element = concreteType->getOperand(i); + auto elementWitness = witnessPack->getOperand(i); + differentiableWitnessDictionary.addIfNotExists( + (IRType*)element, + _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey)); + } + return; } - return; - } - else - { - differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); } + differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); + if (!as<IRInterfaceType>(item->getConcreteType())) { differentiableWitnessDictionary.addIfNotExists( @@ -2241,12 +2241,16 @@ void releaseNullDifferentialType(AutoDiffSharedContext* context) { if (auto keepAliveDecoration = nullStruct->findDecoration<IRKeepAliveDecoration>()) keepAliveDecoration->removeAndDeallocate(); + if (auto exportDecoration = nullStruct->findDecoration<IRHLSLExportDecoration>()) + exportDecoration->removeAndDeallocate(); } if (auto nullWitness = context->nullDifferentialWitness) { if (auto keepAliveDecoration = nullWitness->findDecoration<IRKeepAliveDecoration>()) keepAliveDecoration->removeAndDeallocate(); + if (auto exportDecoration = nullWitness->findDecoration<IRHLSLExportDecoration>()) + exportDecoration->removeAndDeallocate(); } } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 3236bb2e6..79362799b 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4209,7 +4209,7 @@ public: IRInst* emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement); IRInst* emitUpdateElement(IRInst* base, IRIntegerValue index, IRInst* newElement); - IRInst* emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement); + IRInst* emitUpdateElement(IRInst* base, ArrayView<IRInst*> accessChain, IRInst* newElement); IRInst* emitGetAddress( IRType* type, diff --git a/source/slang/slang-ir-lower-expand-type.cpp b/source/slang/slang-ir-lower-expand-type.cpp index 8b68b1fc1..0f2c21dec 100644 --- a/source/slang/slang-ir-lower-expand-type.cpp +++ b/source/slang/slang-ir-lower-expand-type.cpp @@ -21,8 +21,9 @@ namespace Slang { auto eachInst = as<IREach>(val); auto packInst = eachInst->getElement(); + auto type = (IRType*)clonePatternVal(cloneEnv, builder, packInst->getFullType(), eachIndex); packInst = clonePatternValImpl(cloneEnv, builder, packInst, eachIndex); - auto result = builder->emitGetTupleElement(val->getFullType(), packInst, eachIndex); + auto result = builder->emitGetTupleElement(type, packInst, eachIndex); return result; } case kIROp_Specialize: diff --git a/source/slang/slang-ir-lower-tuple-types.cpp b/source/slang/slang-ir-lower-tuple-types.cpp index 6177cfec2..91d6bfc29 100644 --- a/source/slang/slang-ir-lower-tuple-types.cpp +++ b/source/slang/slang-ir-lower-tuple-types.cpp @@ -262,6 +262,71 @@ namespace Slang inst->removeAndDeallocate(); } + void processUpdateElement(IRUpdateElement* inst) + { + // For UpdateElement insts, we need to figure out all the intermediate types on the access chain, + // and if any of them are lowered tuples, we need to replace the access key with the new struct + // key for the lowered tuple struct. + // + ShortList<IRInst*> newAccessChain; + bool accessChainChanged = false; + auto baseType = inst->getOldValue()->getDataType(); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + for (UInt i = 0; i < inst->getAccessKeyCount(); i++) + { + auto key = inst->getAccessKey(i); + if (auto structKey = as<IRStructKey>(key)) + { + if (auto structType = as<IRStructType>(baseType)) + { + auto field = findStructField(structType, structKey); + baseType = field->getFieldType(); + newAccessChain.add(structKey); + } + else + { + // If we see anything not supported, just bail out. + return; + } + } + else if (auto arrayType = as<IRArrayTypeBase>(baseType)) + { + baseType = arrayType->getElementType(); + newAccessChain.add(key); + } + else if (auto loweredTupleInfo = getLoweredTupleType(&builder, baseType)) + { + auto fieldIndex = getIntVal(key); + if (fieldIndex >= 0 && (Index)fieldIndex < loweredTupleInfo->fields.getCount()) + { + auto field = loweredTupleInfo->fields[fieldIndex]; + baseType = field->getFieldType(); + newAccessChain.add(field->getKey()); + accessChainChanged = true; + } + else + { + // If we see anything not supported, just bail out. + break; + } + } + else + { + // If we see anything not supported, just bail out. + break; + } + } + + if (accessChainChanged) + { + auto newInst = builder.emitUpdateElement(inst->getOldValue(), newAccessChain.getArrayView().arrayView, inst->getElementValue()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + } + } + void processInst(IRInst* inst) { switch (inst->getOp()) @@ -291,6 +356,9 @@ namespace Slang case kIROp_IndexedFieldKey: processIndexedFieldKey((IRIndexedFieldKey*)inst); break; + case kIROp_UpdateElement: + processUpdateElement((IRUpdateElement*)inst); + break; default: break; } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 0b0a42617..6a76ccce3 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5108,6 +5108,11 @@ namespace Slang { type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount()); } + else if (auto tupleType = as<IRTupleType>(base->getDataType())) + { + type = (IRType*)tupleType->getOperand(getIntVal(index)); + return emitGetTupleElement(type, base, index); + } SLANG_RELEASE_ASSERT(type); return emitElementExtract(type, base, index); @@ -5211,6 +5216,11 @@ namespace Slang // HLSL support things like float.x, in which case we just return the base pointer. return basePtr; } + else if (const auto tupleType = as<IRTupleType>(valueType)) + { + SLANG_ASSERT(as<IRIntLit>(index)); + type = (IRType*)tupleType->getOperand(getIntVal(index)); + } SLANG_RELEASE_ASSERT(type); auto inst = createInst<IRGetElementPtr>( this, @@ -5281,7 +5291,7 @@ namespace Slang return emitUpdateElement(base, getIntValue(getIntType(), index), newElement); } - IRInst* IRBuilder::emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement) + IRInst* IRBuilder::emitUpdateElement(IRInst* base, ArrayView<IRInst*> accessChain, IRInst* newElement) { List<IRInst*> args; args.add(base); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 31427e616..87199734a 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -5401,6 +5401,8 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis } }; + LoweredValInfo result; + // As required by the implementation of 'assign' and as a small // optimization, we will detect if the base expression has also lowered // into a swizzle and only return a single swizzle instead of nested @@ -5435,7 +5437,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis swizzledLValue->elementIndices); context->shared->extValues.add(swizzledLValue); - return LoweredValInfo::swizzledLValue(swizzledLValue); + result = LoweredValInfo::swizzledLValue(swizzledLValue); } else if(loweredBase.flavor == LoweredValInfo::Flavor::SwizzledMatrixLValue) { @@ -5455,7 +5457,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis swizzledLValue->elementCoords); context->shared->extValues.add(swizzledLValue); - return LoweredValInfo::swizzledMatrixLValue(swizzledLValue); + result = LoweredValInfo::swizzledMatrixLValue(swizzledLValue); } else { @@ -5464,8 +5466,20 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis swizzledLValue->base = loweredBase; swizzledLValue->elementIndices = expr->elementIndices; context->shared->extValues.add(swizzledLValue); - return LoweredValInfo::swizzledLValue(swizzledLValue); + result = LoweredValInfo::swizzledLValue(swizzledLValue); + } + + // For a one-element swizzle on a tuple, we can just return the pointer to the member + // instead of a SwizzledLValue because they can't follow the same folding logic as + // vectors and matrices. + // + bool shouldUseSimpleLVal = elementCount == 1 && as<TupleType>(expr->base->type) != nullptr; + if (shouldUseSimpleLVal) + { + auto addr = getAddress(context, result, expr->loc); + return LoweredValInfo::ptr(addr); } + return result; } }; diff --git a/tests/language-feature/tuple/tuple-autodiff.slang b/tests/language-feature/tuple/tuple-autodiff.slang new file mode 100644 index 000000000..d42cc0159 --- /dev/null +++ b/tests/language-feature/tuple/tuple-autodiff.slang @@ -0,0 +1,49 @@ + +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cpu -compute -output-using-type -shaderobj + +// This is a test modified from autodiff/reverse-struct-multi-write.slang to test that +// tuple types can be autodiff'ed the same way as struct types. + +//TEST_INPUT:ubuffer(data=[1 2], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typealias A = Tuple<float, Tuple<float, float>>; + +[Differentiable] +A f(A a) +{ + // Read/writes to local struct variables won't be SSA'd out by default. + // The backward diff preparation pass will kick in to create temp vars for them. + A aout; + aout._1._1 = 2 * a._1._0; + aout._1._1 = aout._1._1 + 2 * a._1._0; + aout._1._0 = aout._1._1 + 5 * a._1._0; + + // The result should be equivalent to: + /* + A aout; + var tmp = 2 * a.x; + tmp = tmp + 2 * a.x; + aout.y = tmp; + aout.x = tmp + 5 * a.x; + */ + return aout; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a = makeTuple(1.0, makeTuple(1.0, 2.0)); + + var dpa = diffPair(a); + + A.Differential dout = makeTuple(1.0, makeTuple(1.0, 1.0)); + + bwd_diff(f)(dpa, dout); + // CHECK: 13 + outputBuffer[0] = dpa.d._1._0; // Expect: 13 + // CHECK: 0 + outputBuffer[1] = dpa.d._1._1; // Expect: 0 +} |
