summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-28 09:23:08 -0700
committerGitHub <noreply@github.com>2024-08-28 09:23:08 -0700
commit638e5fb000d4e242a91e8b653da4a72daec0efda (patch)
treecfcd15c1fc6bdee624eb33abac3268241b086dec
parent16595a8379e9dbfa1845fd72f3531ff3372da3ef (diff)
Make tuple types work in autodiff. (#4923)
-rw-r--r--source/slang/diff.meta.slang25
-rw-r--r--source/slang/slang-ast-builder.cpp4
-rw-r--r--source/slang/slang-ast-builder.h2
-rw-r--r--source/slang/slang-check-expr.cpp4
-rw-r--r--source/slang/slang-ir-addr-inst-elimination.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp6
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h2
-rw-r--r--source/slang/slang-ir-autodiff.cpp34
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir-lower-expand-type.cpp3
-rw-r--r--source/slang/slang-ir-lower-tuple-types.cpp68
-rw-r--r--source/slang/slang-ir.cpp12
-rw-r--r--source/slang/slang-lower-to-ir.cpp20
-rw-r--r--tests/language-feature/tuple/tuple-autodiff.slang49
14 files changed, 203 insertions, 32 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index a4c468ef7..80aca230a 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -1210,6 +1210,31 @@ extension Array<T, N> : IDifferentiable
}
}
+__generic<each T : IDifferentiable>
+extension Tuple<T> : IDifferentiable
+{
+ typealias Differential = Tuple<expand(each T).Differential>;
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return makeTuple(expand (each T).dzero());
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return makeTuple(expand(each T).dadd(each a, each b));
+ }
+
+ __generic<U : __BuiltinRealType>
+ [__unsafeForceInlineEarly]
+ static Differential dmul(U a, Differential b)
+ {
+ return makeTuple(expand(each T).dmul(a, each b));
+ }
+}
+
// Matrix transpose
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index a13e13851..9879a4187 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -523,7 +523,7 @@ FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Typ
return getOrCreate<FuncType>(parameters, result, errorType);
}
-TupleType* ASTBuilder::getTupleType(List<Type*>& types)
+TupleType* ASTBuilder::getTupleType(ArrayView<Type*> types)
{
// The canonical form of a tuple type is always a DeclRefType(GenAppDeclRef(TupleDecl, ConcreteTypePack(types...))).
// If `types` is already a single ConcreteTypePack, then we can use that directly.
@@ -536,7 +536,7 @@ TupleType* ASTBuilder::getTupleType(List<Type*>& types)
}
// Otherwise, we need to create a ConcreteTypePack to hold the types.
- auto typePack = getTypePack(types.getArrayView());
+ auto typePack = getTypePack(types);
return as<TupleType>(getSpecializedBuiltinType(typePack, "TupleType"));
}
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 64282ce78..3e2a88dd8 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -508,7 +508,7 @@ public:
Val* getSNormModifierVal();
Val* getNoDiffModifierVal();
- TupleType* getTupleType(List<Type*>& types);
+ TupleType* getTupleType(ArrayView<Type*> types);
FuncType* getFuncType(ArrayView<Type*> parameters, Type* result, Type* errorType = nullptr);
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 500407e26..ec064b5b3 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -4055,7 +4055,7 @@ namespace Slang
{
types.add(baseTupleType->getMember(index));
}
- swizExpr->type = QualType(m_astBuilder->getTupleType(types));
+ swizExpr->type = QualType(m_astBuilder->getTupleType(types.getArrayView()));
}
// A swizzle can be used as an l-value as long as there
@@ -4908,7 +4908,7 @@ namespace Slang
types.reserve(expr->members.getCount());
for(auto t : expr->members)
types.add(t.type);
- auto tupleType = m_astBuilder->getTupleType(types);
+ auto tupleType = m_astBuilder->getTupleType(types.getArrayView());
expr->type = m_astBuilder->getTypeType(tupleType);
return expr;
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
index 7889a2f61..8a48936d7 100644
--- a/source/slang/slang-ir-addr-inst-elimination.cpp
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -25,7 +25,7 @@ struct AddressInstEliminationContext
case kIROp_GetElementPtr:
case kIROp_FieldAddress:
{
- IRInst* args[] = {getValue(builder, addr->getOperand(0)), addr->getOperand(1)};
+ IRInst* args[] = { getValue(builder, addr->getOperand(0)), addr->getOperand(1) };
return builder.emitIntrinsicInst(
cast<IRPtrTypeBase>(addr->getFullType())->getValueType(),
(addr->getOp() == kIROp_GetElementPtr ? kIROp_GetElement : kIROp_FieldExtract),
@@ -60,7 +60,7 @@ struct AddressInstEliminationContext
if (accessChain.getCount())
{
auto lastVal = builder.emitLoad(lastAddr);
- auto update = builder.emitUpdateElement(lastVal, accessChain, val);
+ auto update = builder.emitUpdateElement(lastVal, accessChain.getArrayView(), val);
builder.emitStore(lastAddr, update);
}
else
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 91d3e71cb..fe7c77ba0 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1164,7 +1164,7 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
auto primalVal = findOrTranscribePrimalInst(builder, origVal);
IRInst* primalUpdateField =
- builder->emitUpdateElement(primalBase, primalAccessChain, primalVal);
+ builder->emitUpdateElement(primalBase, primalAccessChain.getArrayView(), primalVal);
IRInst* diffUpdateElement = nullptr;
List<IRInst*> diffAccessChain;
@@ -1198,7 +1198,7 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
auto primalElementType = primalVal->getDataType();
diffUpdateElement = builder->emitUpdateElement(
- diffBase, diffAccessChain, diffVal);
+ diffBase, diffAccessChain.getArrayView(), diffVal);
builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
}
else
@@ -1206,7 +1206,7 @@ InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRI
auto primalElementType = primalVal->getDataType();
auto zeroElementDiff = getDifferentialZeroOfType(builder, primalElementType);
diffUpdateElement = builder->emitUpdateElement(
- diffBase, diffAccessChain, zeroElementDiff);
+ diffBase, diffAccessChain.getArrayView(), zeroElementDiff);
builder->addPrimalElementTypeDecoration(diffUpdateElement, primalElementType);
}
}
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index f8f6b03ab..d42462e1b 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -2016,7 +2016,7 @@ struct DiffTransposePass
SLANG_ASSERT(diffZero);
auto revRest = builder->emitUpdateElement(
revValue,
- accessChain,
+ accessChain.getArrayView(),
diffZero);
gradients.add(RevGradient(
RevGradient::Flavor::Simple,
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index bf83d8d7f..8ca7dbe76 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -448,26 +448,26 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
IRBuilder subBuilder(item->getConcreteType());
if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType))
{
- // For tuple types, register the differential type for each element, but don't register for the
+ // For tuple types with concrete element types,
+ // register the differential type for each element, but don't register for the
// tuple/typepack itself.
- auto witnessPack = as<IRMakeWitnessPack>(witness);
- SLANG_ASSERT(witnessPack);
-
- for (UInt i = 0; i < concreteType->getOperandCount(); i++)
+ if (auto witnessPack = as<IRMakeWitnessPack>(witness))
{
- auto element = concreteType->getOperand(i);
- auto elementWitness = witnessPack->getOperand(i);
- differentiableWitnessDictionary.addIfNotExists(
- (IRType*)element,
- _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey));
+
+ for (UInt i = 0; i < concreteType->getOperandCount(); i++)
+ {
+ auto element = concreteType->getOperand(i);
+ auto elementWitness = witnessPack->getOperand(i);
+ differentiableWitnessDictionary.addIfNotExists(
+ (IRType*)element,
+ _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey));
+ }
+ return;
}
- return;
- }
- else
- {
- differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness());
}
+ differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness());
+
if (!as<IRInterfaceType>(item->getConcreteType()))
{
differentiableWitnessDictionary.addIfNotExists(
@@ -2241,12 +2241,16 @@ void releaseNullDifferentialType(AutoDiffSharedContext* context)
{
if (auto keepAliveDecoration = nullStruct->findDecoration<IRKeepAliveDecoration>())
keepAliveDecoration->removeAndDeallocate();
+ if (auto exportDecoration = nullStruct->findDecoration<IRHLSLExportDecoration>())
+ exportDecoration->removeAndDeallocate();
}
if (auto nullWitness = context->nullDifferentialWitness)
{
if (auto keepAliveDecoration = nullWitness->findDecoration<IRKeepAliveDecoration>())
keepAliveDecoration->removeAndDeallocate();
+ if (auto exportDecoration = nullWitness->findDecoration<IRHLSLExportDecoration>())
+ exportDecoration->removeAndDeallocate();
}
}
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 3236bb2e6..79362799b 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4209,7 +4209,7 @@ public:
IRInst* emitUpdateElement(IRInst* base, IRInst* index, IRInst* newElement);
IRInst* emitUpdateElement(IRInst* base, IRIntegerValue index, IRInst* newElement);
- IRInst* emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement);
+ IRInst* emitUpdateElement(IRInst* base, ArrayView<IRInst*> accessChain, IRInst* newElement);
IRInst* emitGetAddress(
IRType* type,
diff --git a/source/slang/slang-ir-lower-expand-type.cpp b/source/slang/slang-ir-lower-expand-type.cpp
index 8b68b1fc1..0f2c21dec 100644
--- a/source/slang/slang-ir-lower-expand-type.cpp
+++ b/source/slang/slang-ir-lower-expand-type.cpp
@@ -21,8 +21,9 @@ namespace Slang
{
auto eachInst = as<IREach>(val);
auto packInst = eachInst->getElement();
+ auto type = (IRType*)clonePatternVal(cloneEnv, builder, packInst->getFullType(), eachIndex);
packInst = clonePatternValImpl(cloneEnv, builder, packInst, eachIndex);
- auto result = builder->emitGetTupleElement(val->getFullType(), packInst, eachIndex);
+ auto result = builder->emitGetTupleElement(type, packInst, eachIndex);
return result;
}
case kIROp_Specialize:
diff --git a/source/slang/slang-ir-lower-tuple-types.cpp b/source/slang/slang-ir-lower-tuple-types.cpp
index 6177cfec2..91d6bfc29 100644
--- a/source/slang/slang-ir-lower-tuple-types.cpp
+++ b/source/slang/slang-ir-lower-tuple-types.cpp
@@ -262,6 +262,71 @@ namespace Slang
inst->removeAndDeallocate();
}
+ void processUpdateElement(IRUpdateElement* inst)
+ {
+ // For UpdateElement insts, we need to figure out all the intermediate types on the access chain,
+ // and if any of them are lowered tuples, we need to replace the access key with the new struct
+ // key for the lowered tuple struct.
+ //
+ ShortList<IRInst*> newAccessChain;
+ bool accessChainChanged = false;
+ auto baseType = inst->getOldValue()->getDataType();
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+
+ for (UInt i = 0; i < inst->getAccessKeyCount(); i++)
+ {
+ auto key = inst->getAccessKey(i);
+ if (auto structKey = as<IRStructKey>(key))
+ {
+ if (auto structType = as<IRStructType>(baseType))
+ {
+ auto field = findStructField(structType, structKey);
+ baseType = field->getFieldType();
+ newAccessChain.add(structKey);
+ }
+ else
+ {
+ // If we see anything not supported, just bail out.
+ return;
+ }
+ }
+ else if (auto arrayType = as<IRArrayTypeBase>(baseType))
+ {
+ baseType = arrayType->getElementType();
+ newAccessChain.add(key);
+ }
+ else if (auto loweredTupleInfo = getLoweredTupleType(&builder, baseType))
+ {
+ auto fieldIndex = getIntVal(key);
+ if (fieldIndex >= 0 && (Index)fieldIndex < loweredTupleInfo->fields.getCount())
+ {
+ auto field = loweredTupleInfo->fields[fieldIndex];
+ baseType = field->getFieldType();
+ newAccessChain.add(field->getKey());
+ accessChainChanged = true;
+ }
+ else
+ {
+ // If we see anything not supported, just bail out.
+ break;
+ }
+ }
+ else
+ {
+ // If we see anything not supported, just bail out.
+ break;
+ }
+ }
+
+ if (accessChainChanged)
+ {
+ auto newInst = builder.emitUpdateElement(inst->getOldValue(), newAccessChain.getArrayView().arrayView, inst->getElementValue());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ }
+ }
+
void processInst(IRInst* inst)
{
switch (inst->getOp())
@@ -291,6 +356,9 @@ namespace Slang
case kIROp_IndexedFieldKey:
processIndexedFieldKey((IRIndexedFieldKey*)inst);
break;
+ case kIROp_UpdateElement:
+ processUpdateElement((IRUpdateElement*)inst);
+ break;
default:
break;
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 0b0a42617..6a76ccce3 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5108,6 +5108,11 @@ namespace Slang
{
type = getVectorType(matrixType->getElementType(), matrixType->getColumnCount());
}
+ else if (auto tupleType = as<IRTupleType>(base->getDataType()))
+ {
+ type = (IRType*)tupleType->getOperand(getIntVal(index));
+ return emitGetTupleElement(type, base, index);
+ }
SLANG_RELEASE_ASSERT(type);
return emitElementExtract(type, base, index);
@@ -5211,6 +5216,11 @@ namespace Slang
// HLSL support things like float.x, in which case we just return the base pointer.
return basePtr;
}
+ else if (const auto tupleType = as<IRTupleType>(valueType))
+ {
+ SLANG_ASSERT(as<IRIntLit>(index));
+ type = (IRType*)tupleType->getOperand(getIntVal(index));
+ }
SLANG_RELEASE_ASSERT(type);
auto inst = createInst<IRGetElementPtr>(
this,
@@ -5281,7 +5291,7 @@ namespace Slang
return emitUpdateElement(base, getIntValue(getIntType(), index), newElement);
}
- IRInst* IRBuilder::emitUpdateElement(IRInst* base, const List<IRInst*>& accessChain, IRInst* newElement)
+ IRInst* IRBuilder::emitUpdateElement(IRInst* base, ArrayView<IRInst*> accessChain, IRInst* newElement)
{
List<IRInst*> args;
args.add(base);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 31427e616..87199734a 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -5401,6 +5401,8 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
}
};
+ LoweredValInfo result;
+
// As required by the implementation of 'assign' and as a small
// optimization, we will detect if the base expression has also lowered
// into a swizzle and only return a single swizzle instead of nested
@@ -5435,7 +5437,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
swizzledLValue->elementIndices);
context->shared->extValues.add(swizzledLValue);
- return LoweredValInfo::swizzledLValue(swizzledLValue);
+ result = LoweredValInfo::swizzledLValue(swizzledLValue);
}
else if(loweredBase.flavor == LoweredValInfo::Flavor::SwizzledMatrixLValue)
{
@@ -5455,7 +5457,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
swizzledLValue->elementCoords);
context->shared->extValues.add(swizzledLValue);
- return LoweredValInfo::swizzledMatrixLValue(swizzledLValue);
+ result = LoweredValInfo::swizzledMatrixLValue(swizzledLValue);
}
else
{
@@ -5464,8 +5466,20 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis
swizzledLValue->base = loweredBase;
swizzledLValue->elementIndices = expr->elementIndices;
context->shared->extValues.add(swizzledLValue);
- return LoweredValInfo::swizzledLValue(swizzledLValue);
+ result = LoweredValInfo::swizzledLValue(swizzledLValue);
+ }
+
+ // For a one-element swizzle on a tuple, we can just return the pointer to the member
+ // instead of a SwizzledLValue because they can't follow the same folding logic as
+ // vectors and matrices.
+ //
+ bool shouldUseSimpleLVal = elementCount == 1 && as<TupleType>(expr->base->type) != nullptr;
+ if (shouldUseSimpleLVal)
+ {
+ auto addr = getAddress(context, result, expr->loc);
+ return LoweredValInfo::ptr(addr);
}
+ return result;
}
};
diff --git a/tests/language-feature/tuple/tuple-autodiff.slang b/tests/language-feature/tuple/tuple-autodiff.slang
new file mode 100644
index 000000000..d42cc0159
--- /dev/null
+++ b/tests/language-feature/tuple/tuple-autodiff.slang
@@ -0,0 +1,49 @@
+
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cpu -compute -output-using-type -shaderobj
+
+// This is a test modified from autodiff/reverse-struct-multi-write.slang to test that
+// tuple types can be autodiff'ed the same way as struct types.
+
+//TEST_INPUT:ubuffer(data=[1 2], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typealias A = Tuple<float, Tuple<float, float>>;
+
+[Differentiable]
+A f(A a)
+{
+ // Read/writes to local struct variables won't be SSA'd out by default.
+ // The backward diff preparation pass will kick in to create temp vars for them.
+ A aout;
+ aout._1._1 = 2 * a._1._0;
+ aout._1._1 = aout._1._1 + 2 * a._1._0;
+ aout._1._0 = aout._1._1 + 5 * a._1._0;
+
+ // The result should be equivalent to:
+ /*
+ A aout;
+ var tmp = 2 * a.x;
+ tmp = tmp + 2 * a.x;
+ aout.y = tmp;
+ aout.x = tmp + 5 * a.x;
+ */
+ return aout;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ A a = makeTuple(1.0, makeTuple(1.0, 2.0));
+
+ var dpa = diffPair(a);
+
+ A.Differential dout = makeTuple(1.0, makeTuple(1.0, 1.0));
+
+ bwd_diff(f)(dpa, dout);
+ // CHECK: 13
+ outputBuffer[0] = dpa.d._1._0; // Expect: 13
+ // CHECK: 0
+ outputBuffer[1] = dpa.d._1._1; // Expect: 0
+}