diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-23 21:45:59 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-23 21:45:59 -0700 |
| commit | b2ca2d5a4efeae807d3c3f48f60235e47413b559 (patch) | |
| tree | 643d2bab5776e5f8f7cfa722975af9e826d77c9d /source | |
| parent | e4088cd602bd4d5a72fea67a787b1319acfc044d (diff) | |
Make variadic generics work with interfaces and forward autodiff. (#4905)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 24 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 58 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 43 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 191 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 87 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 156 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-expand-type.cpp | 167 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-expand-type.h | 30 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 352 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 42 |
20 files changed, 878 insertions, 371 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 84e1b8168..0b57993ef 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -953,6 +953,30 @@ extension Tuple<T> : IComparable } } +interface IMutatingFunc<TR, each TP> +{ + [mutating] + TR __call(expand each TP p); +} + +interface IFunc<TR, each TP> : IMutatingFunc<TR, expand each TP> +{ + TR __call(expand each TP p); +} + +interface IDifferentiableMutatingFunc<TR : IDifferentiable, each TP : IDifferentiable> : IMutatingFunc<TR, expand each TP> +{ + [Differentiable] + [mutating] + TR __call(expand each TP p); +} + +interface IDifferentiableFunc<TR : IDifferentiable, each TP : IDifferentiable> : IFunc<TR, expand each TP>, IDifferentiableMutatingFunc<TR, expand each TP> +{ + [Differentiable] + TR __call(expand each TP p); +} + __generic<T> __magic_type(NativeRefType) __intrinsic_type($(kIROp_NativePtrType)) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c27e0c6f0..66707fc56 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3931,7 +3931,7 @@ namespace Slang { // Our synthesized method will have parameters matching the names // and types of those on the requirement, and it will use expressions - // that reference those parametesr as arguments for the call expresison + // that reference those parameters as arguments for the call expresison // that makes up the body. // for (auto paramDeclRef : getParameters(m_astBuilder, requirement)) @@ -3951,14 +3951,6 @@ namespace Slang synParamDecl->parentDecl = synthesized; synthesized->members.add(synParamDecl); - // For each paramter, we will create an argument expression - // for the call in the function body. - // - auto synArg = m_astBuilder->create<VarExpr>(); - synArg->declRef = makeDeclRef(synParamDecl); - synArg->type = paramType; - synArgs.add(synArg); - // Add modifiers for (auto modifier : paramDeclRef.getDecl()->modifiers) { @@ -3975,6 +3967,33 @@ namespace Slang addModifier(synParamDecl, clonedModifier); } } + + // Create an expression that references the parameter for use in arguments. + auto synArg = m_astBuilder->create<VarExpr>(); + synArg->declRef = makeDeclRef(synParamDecl); + synArg->type = paramType; + + if (auto typePack = as<ConcreteTypePack>(paramType)) + { + // If paramType is a concrete type pack, we want to expand it out into + // individual arguments. + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto elementType = typePack->getElementType(i); + auto synMemberExpr = m_astBuilder->create<SwizzleExpr>(); + synMemberExpr->base = synArg; + synMemberExpr->elementIndices.add((UInt)i); + synMemberExpr->type = elementType; + synArgs.add(synMemberExpr); + } + } + else + { + // For ordinary non-pack paramters, we will use synArg directly to + // referencing the parameter for the call in the function body. + // + synArgs.add(synArg); + } } } @@ -4156,8 +4175,6 @@ namespace Slang addModifier(synFuncDecl, m_astBuilder->create<ForceInlineAttribute>()); synFuncDecl->parentDecl = aggTypeDecl; - SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); - bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); } else { @@ -4281,6 +4298,11 @@ namespace Slang // synFuncDecl->parentDecl = context->parentDecl; + // If the synthesized func is differentiable, make sure to populate its + // differential type dictionary. + SemanticsDeclBodyVisitor bodyVisitor(withParentFunc(synFuncDecl)); + bodyVisitor.registerDifferentiableTypesForFunc(synFuncDecl); + // Once our synthesized declaration is complete, we need // to install it as the witness that satifies the given // requirement. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index fe43a4f8f..4d36299bb 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1174,6 +1174,23 @@ namespace Slang } } + if (auto typePack = as<ConcreteTypePack>(type)) + { + bool anyDifferentiableElement = false; + List<Type*> diffTypes; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto t = typePack->getElementType(i); + auto diffType = tryGetDifferentialType(builder, t); + if (!diffType) + diffType = m_astBuilder->getVoidType(); + else + anyDifferentiableElement = true; + diffTypes.add(diffType); + } + if (anyDifferentiableElement) + return builder->getTypePack(diffTypes.getArrayView()); + } return nullptr; } @@ -1368,6 +1385,13 @@ namespace Slang }); return; } + + if (auto typePack = as<ConcreteTypePack>(type)) + { + for (Index i = 0; i < typePack->getTypeCount(); i++) + maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i)); + return; + } } @@ -2797,6 +2821,36 @@ namespace Slang return modifiedType->getBase(); } + if (auto typePack = as<ConcreteTypePack>(primalType)) + { + // The differential pair of a type pack should be a type pack of differential pairs. + List<Type*> diffTypes; + for (Index i = 0; i < typePack->getTypeCount(); i++) + { + auto t = typePack->getElementType(i); + diffTypes.add(getDifferentialPairType(t)); + } + return m_astBuilder->getTypePack(diffTypes.getArrayView()); + } + else if (isAbstractTypePack(primalType)) + { + // The differential pair of an abstract type pack P should be `expand DifferentialPair<each P>`. + auto eachType = m_astBuilder->getEachType(primalType); + auto diffPairEachType = getDifferentialPairType(eachType); + if (auto expandType = as<ExpandType>(primalType)) + { + List<Type*> capturedTypePacks; + for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++) + { + capturedTypePacks.add(expandType->getCapturedTypePack(i)); + } + return m_astBuilder->getExpandType(diffPairEachType, capturedTypePacks.getArrayView()); + } + else + { + return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType)); + } + } // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType(); @@ -3598,6 +3652,10 @@ namespace Slang if (!isTypePack(baseType) && !as<TupleType>(baseType)) goto error; } + + if (auto tupleType = as<TupleType>(baseType)) + baseType = tupleType->getTypePack(); + { SLANG_ASSERT(m_capturedTypePacks); if (auto baseExpandType = as<ExpandType>(baseType)) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 9adbe42d5..91d3e71cb 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -444,6 +444,10 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType); diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); } + else + { + diffOperands.add(builder->getVoidValue()); + } } } @@ -1110,6 +1114,39 @@ InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst return InstPair(primalGetElementPtr, diffGetElementPtr); } +InstPair ForwardDiffTranscriber::transcribeGetTupleElement(IRBuilder* builder, IRInst* originalInst) +{ + IRInst* origBase = originalInst->getOperand(0); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalIndex = originalInst->getOperand(1); + + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType()); + + IRInst* primalOperands[] = { primalBase, primalIndex }; + IRInst* primalGetElement = builder->emitIntrinsicInst( + primalType, + originalInst->getOp(), + 2, + primalOperands); + + IRInst* diffGetElement = nullptr; + + if (auto diffType = differentiateType(builder, primalGetElement->getDataType())) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + IRInst* diffOperands[] = { diffBase, primalIndex }; + diffGetElement = builder->emitIntrinsicInst( + diffType, + originalInst->getOp(), + 2, + diffOperands); + } + } + + return InstPair(primalGetElement, diffGetElement); +} + InstPair ForwardDiffTranscriber::transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst) { auto updateInst = as<IRUpdateElement>(originalInst); @@ -1792,6 +1829,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_MakeVectorFromScalar: case kIROp_MakeArray: case kIROp_MakeArrayFromElement: + case kIROp_MakeTuple: + case kIROp_MakeValuePack: return transcribeConstruct(builder, origInst); case kIROp_MakeStruct: return transcribeMakeStruct(builder, origInst); @@ -1805,7 +1844,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_swizzle: return transcribeSwizzle(builder, as<IRSwizzle>(origInst)); - case kIROp_MakeTuple: case kIROp_Neg: return transcribeByPassthrough(builder, origInst); @@ -1832,6 +1870,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_GetElementPtr: return transcribeGetElement(builder, origInst); + case kIROp_GetTupleElement: + return transcribeGetTupleElement(builder, origInst); + case kIROp_ifElse: return transcribeIfElse(builder, as<IRIfElse>(origInst)); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index f88235558..f2659777d 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -48,6 +48,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct); InstPair transcribeMakeStruct(IRBuilder* builder, IRInst* origMakeStruct); + InstPair transcribeMakeTuple(IRBuilder* builder, IRInst* origMakeTuple); + // Differentiating a call instruction here is primarily about generating // an appropriate call list based on whichever parameters have differentials // in the current transcription context. @@ -68,6 +70,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr); + InstPair transcribeGetTupleElement(IRBuilder* builder, IRInst* origInst); + InstPair transcribeUpdateElement(IRBuilder* builder, IRInst* originalInst); InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index a1fa5f21a..da69ed8ae 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -174,179 +174,9 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); -// Get or construct `:IDifferentiable` conformance for a DifferentiablePair. -IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType) -{ - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)inOriginalDiffPairType); - - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalDiffPairType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); - builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, autoDiffSharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, autoDiffSharedContext->zeroMethodStructKey, zeroMethod); - - bool isUserCodeType = as<IRDifferentialPairUserCodeType>(inOriginalDiffPairType) ? true : false; - - // Fill in differential method implementations. - auto elementType = as<IRDifferentialPairTypeBase>(inPrimalDiffPairType)->getValueType(); - auto innerWitness = as<IRDifferentialPairTypeBase>(inPrimalDiffPairType)->getWitness(); - - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); - b.emitBlock(); - auto p0 = b.emitParam(diffDiffPairType); - auto p1 = b.emitParam(diffDiffPairType); - - // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, autoDiffSharedContext->addMethodStructKey); - IRInst* argsPrimal[2] = { - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; - auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); - IRInst* argsDiff[2] = { - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; - auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) - : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); - b.emitReturn(retVal); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); - b.emitBlock(); - auto innerZero = _lookupWitness(&b, innerWitness, autoDiffSharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) - : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); - b.emitReturn(retVal); - } - - // Record this in the context for future lookups - differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalDiffPairType] = table; - - return table; -} - -// Get or construct `:IDifferentiable` conformance for an Array. -IRWitnessTable* AutoDiffTranscriberBase::getArrayWitness(IRBuilder* builder, IRInst* inOriginalArrayType, IRInst* inPrimalArrayType) -{ - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)inOriginalArrayType); - - if (!diffArrayType) - return nullptr; - - auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(inOriginalArrayType)->getElementType()); - - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalArrayType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffArrayType); - builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, autoDiffSharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, autoDiffSharedContext->zeroMethodStructKey, zeroMethod); - - auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType(); - - // Fill in differential method implementations. - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffArrayType, diffArrayType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); - b.emitBlock(); - auto p0 = b.emitParam(diffArrayType); - auto p1 = b.emitParam(diffArrayType); - - // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, autoDiffSharedContext->addMethodStructKey); - auto resultVar = b.emitVar(diffArrayType); - IRBlock* loopBodyBlock = nullptr; - IRBlock* loopBreakBlock = nullptr; - auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); - b.setInsertBefore(loopBodyBlock->getTerminator()); - - IRInst* args[2] = { - b.emitElementExtract(p0, loopCounter), - b.emitElementExtract(p1, loopCounter) }; - auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); - auto addr = b.emitElementAddress(resultVar, loopCounter); - b.emitStore(addr, elementResult); - b.setInsertInto(loopBreakBlock); - b.emitReturn(b.emitLoad(resultVar)); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); - b.emitBlock(); - - auto innerZero = _lookupWitness(&b, innerWitness, autoDiffSharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); - b.emitReturn(retVal); - } - - // Record this in the context for future lookups - differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalArrayType] = table; - - return table; -} - IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) { - if (isNoDiffType((IRType*)originalType)) - return nullptr; - - IRInst* witness = - differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType); - if (witness) - { - witness = lookupPrimalInst(builder, witness, nullptr); - SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(originalType)); - } - if (!witness) - { - auto primalType = lookupPrimalInst(builder, originalType, nullptr); - SLANG_RELEASE_ASSERT(primalType); - if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType)) - { - witness = getDifferentialPairWitness(builder, originalType, primalPairType); - } - else if (auto arrayType = as<IRArrayType>(primalType)) - { - witness = getArrayWitness(builder, originalType, arrayType); - } - else if (auto extractExistential = as<IRExtractExistentialType>(originalType)) - { - differentiateExtractExistentialType(builder, extractExistential, witness); - } - } - return witness; + return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType); } IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) @@ -486,15 +316,20 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy } case kIROp_TupleType: + case kIROp_TypePack: { - auto tupleType = as<IRTupleType>(primalType); List<IRType*> diffTypeList; - // TODO: what if we have type parameters here? - for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) - diffTypeList.add( - differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - - return builder->getTupleType(diffTypeList); + for (UIndex ii = 0; ii < primalType->getOperandCount(); ii++) + { + auto diffElementType = differentiateType(builder, (IRType*)primalType->getOperand(ii)); + if (!diffElementType) + diffElementType = builder->getVoidType(); + diffTypeList.add(diffElementType); + } + if (primalType->getOp() == kIROp_TupleType) + return builder->getTupleType(diffTypeList); + else + return builder->getTypePack((UInt)diffTypeList.getCount(), diffTypeList.getBuffer()); } default: diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index f672631e3..f7f2dd6f2 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -91,10 +91,6 @@ struct AutoDiffTranscriberBase void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc); - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. - IRWitnessTable* getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType); - IRWitnessTable* getArrayWitness(IRBuilder* builder, IRInst* inOriginalArrayType, IRInst* inPrimalArrayType); - IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 05884d13d..f8f6b03ab 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1486,6 +1486,9 @@ struct DiffTransposePass return transposeMakeStruct(builder, fwdInst, revValue); case kIROp_MakeArray: return transposeMakeArray(builder, fwdInst, revValue); + case kIROp_MakeTuple: + case kIROp_MakeValuePack: + return transposeMakeTuple(builder, fwdInst, revValue); case kIROp_MakeArrayFromElement: return transposeMakeArrayFromElement(builder, fwdInst, revValue); @@ -1898,6 +1901,29 @@ struct DiffTransposePass return TranspositionResult(gradients); } + TranspositionResult transposeMakeTuple(IRBuilder* builder, IRInst* fwdMakeTuple, IRInst* revValue) + { + List<RevGradient> gradients; + auto type = fwdMakeTuple->getDataType(); + for (UInt ii = 0; ii < type->getOperandCount(); ii++) + { + auto elementType = (IRType*)type->getOperand(ii); + auto gradAtField = builder->emitGetTupleElement( + elementType, + revValue, + ii); + SLANG_RELEASE_ASSERT(ii < fwdMakeTuple->getOperandCount()); + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeTuple->getOperand(ii), + gradAtField, + fwdMakeTuple)); + } + + // (A = MakeTuple(F1, F2, F3)) -> [(dF1 += dA.F1), (dF2 += dA.F2), (dF3 += dA.F3)] + return TranspositionResult(gradients); + } + TranspositionResult transposeMakeStruct(IRBuilder* builder, IRInst* fwdMakeStruct, IRInst* revValue) { List<RevGradient> gradients; @@ -2429,25 +2455,38 @@ struct DiffTransposePass auto baseType = firstFwdSwizzleInst->getBase()->getDataType(); IRIntegerValue elementCount = 0; - IRType* elementType = nullptr; - IRType* primalElementType = nullptr; + List<IRType*> elementTypes; + List<IRType*> primalElementTypes; bool isVectorType = false; - + bool isTupleType = false; if (auto vectorType = as<IRVectorType>(baseType)) { IRInst* elementCountInst = vectorType->getElementCount(); - elementType = vectorType->getElementType(); - primalElementType = as<IRVectorType>(aggPrimalType)->getElementType(); - SLANG_ASSERT(as<IRIntLit>(elementCountInst)); elementCount = as<IRIntLit>(elementCountInst)->getValue(); + for (IRIntegerValue i = 0; i < elementCount; i++) + { + elementTypes.add(vectorType->getElementType()); + primalElementTypes.add(as<IRVectorType>(aggPrimalType)->getElementType()); + } + SLANG_ASSERT(as<IRIntLit>(elementCountInst)); isVectorType = true; } else if (auto basicType = as<IRBasicType>(baseType)) { - elementType = basicType; - primalElementType = aggPrimalType; + elementTypes.add(basicType); + primalElementTypes.add(aggPrimalType); elementCount = 1; } + else if (as<IRTupleType>(baseType) || as<IRTypePack>(baseType)) + { + isTupleType = true; + elementCount = baseType->getOperandCount(); + for (UInt i = 0; i < baseType->getOperandCount(); i++) + { + elementTypes.add((IRType*)baseType->getOperand(i)); + primalElementTypes.add((IRType*)(aggPrimalType->getOperand(i))); + } + } else { SLANG_UNREACHABLE("unknown operand type of swizzle."); @@ -2456,18 +2495,22 @@ struct DiffTransposePass IRInst* targetInst = firstGradient.targetInst; // Make a list of zeros of the base type. - auto zeroElement = emitDZeroOfDiffInstType(builder, primalElementType); List<IRInst*> elementGrads; + List<IRInst*> zeroElements; for (Index i = 0; i < elementCount; ++i) + { + auto zeroElement = emitDZeroOfDiffInstType(builder, primalElementTypes[i]); elementGrads.add(zeroElement); + zeroElements.add(zeroElement); + } auto accGrad = [&](UIndex i, IRInst* grad) { - if (elementGrads[i] == zeroElement) + if (elementGrads[i] == zeroElements[i]) elementGrads[i] = grad; else - elementGrads[i] = emitDAddOfDiffInstType(builder, primalElementType, elementGrads[i], grad); + elementGrads[i] = emitDAddOfDiffInstType(builder, primalElementTypes[i], elementGrads[i], grad); }; for (auto gradient : gradients) @@ -2493,12 +2536,19 @@ struct DiffTransposePass else if (isVectorType) accGrad((UIndex)targetIndex, builder->emitElementExtract( - elementType, + elementTypes[(UIndex)targetIndex], gradient.revGradInst, builder->getIntValue( builder->getIntType(), sourceIndex))); - // Case 3: Swizzled input is a scalar. + // Case 3: swizzled output is a tuple. + else if (isTupleType) + accGrad((UIndex)targetIndex, + builder->emitGetTupleElement( + elementTypes[(UIndex)targetIndex], + gradient.revGradInst, + (UInt)sourceIndex)); + // Case 4: Swizzled input is a scalar. else accGrad((UIndex)targetIndex, gradient.revGradInst); } @@ -2509,6 +2559,17 @@ struct DiffTransposePass targetInst, builder->emitMakeVector(baseType, (UInt)elementCount, elementGrads.getBuffer()), nullptr); + else if (isTupleType) + { + return RevGradient( + targetInst, + builder->emitIntrinsicInst( + baseType, + baseType->getOp()==kIROp_TupleType ? kIROp_MakeTuple : kIROp_MakeValuePack, + (UInt)elementCount, + elementGrads.getBuffer()), + nullptr); + } else return RevGradient( targetInst, diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 6b275179c..b7c2037e5 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -44,6 +44,13 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK return entry->getRequirementVal(); } } + else if (as<IRMakeWitnessPack>(witness)) + { + // We are looking up a witness from a type pack. + // This is only allowed if we are looking up a differential type. + // We should turn this into an actual witness table for the type pack/tuple type. + SLANG_UNEXPECTED("looking up from a witness pack is invalid and should have been lowered."); + } else { return builder->emitLookupInterfaceMethodInst( @@ -434,10 +441,33 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) } else { - differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); + auto witness = item->getWitness(); // Also register the type's differential type with the same witness. + auto concreteType = item->getConcreteType(); IRBuilder subBuilder(item->getConcreteType()); + if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType)) + { + // For tuple types, register the differential type for each element, but don't register for the + // tuple/typepack itself. + auto witnessPack = as<IRMakeWitnessPack>(witness); + SLANG_ASSERT(witnessPack); + + for (UInt i = 0; i < concreteType->getOperandCount(); i++) + { + auto element = concreteType->getOperand(i); + auto elementWitness = witnessPack->getOperand(i); + differentiableWitnessDictionary.addIfNotExists( + (IRType*)element, + _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey)); + } + return; + } + else + { + differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); + } + if (!as<IRInterfaceType>(item->getConcreteType())) { differentiableWitnessDictionary.addIfNotExists( @@ -768,16 +798,18 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build SLANG_UNIMPLEMENTED_X("Impl"); } + case kIROp_TypePack: case kIROp_TupleType: { - auto tupleType = as<IRTupleType>(primalType); List<IRType*> diffTypeList; // TODO: what if we have type parameters here? - for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + for (UIndex ii = 0; ii < primalType->getOperandCount(); ii++) diffTypeList.add( - differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - - return builder->getTupleType(diffTypeList); + differentiateType(builder, (IRType*)primalType->getOperand(ii))); + if (primalType->getOp() == kIROp_TupleType) + return builder->getTupleType(diffTypeList); + else + return builder->getTypePack((UInt)diffTypeList.getCount(), diffTypeList.getBuffer()); } default: @@ -795,6 +827,12 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil { SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(primalType)); } + if (as<IRMakeWitnessPack>(witness)) + { + // If registered witness is a witness pack for a type pack, + // we should reconstruct the true witness table. + witness = nullptr; + } if (!witness) { @@ -811,6 +849,14 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil { witness = getExtractExistensialTypeWitness(builder, extractExistential); } + else if (auto typePack = as<IRTypePack>(primalType)) + { + witness = getTupleWitness(builder, typePack); + } + else if (auto tupleType = as<IRTupleType>(primalType)) + { + witness = getTupleWitness(builder, tupleType); + } } return witness; } @@ -963,6 +1009,104 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder return table; } +IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType) +{ + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType); + + if (!diffTupleType) + return nullptr; + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffTupleType, diffTupleType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); + b.emitBlock(); + auto p0 = b.emitParam(diffTupleType); + auto p1 = b.emitParam(diffTupleType); + List<IRInst*> results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) + { + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); + auto iVal = b.getIntValue(b.getIntType(), i); + IRInst* args[2] = { + b.emitGetTupleElement(diffElementType, p0, iVal), + b.emitGetTupleElement(diffElementType, p1, iVal) }; + elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); + } + results.add(elementResult); + } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); + else + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); + b.emitBlock(); + List<IRInst*> results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) + { + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); + elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); + } + results.add(elementResult); + } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); + else + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); + } + + // Record this in the context for future lookups + differentiableWitnessDictionary[(IRType*)inTupleType] = table; + + return table; +} + IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness( IRBuilder* builder, IRExtractExistentialType* extractExistentialType) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index d8f0373ac..23ae717be 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -191,6 +191,8 @@ struct DifferentiableTypeConformanceContext IRInst* getArrayWitness(IRBuilder* builder, IRArrayType* pairType); + IRInst* getTupleWitness(IRBuilder* builder, IRInst* tupleType); + IRInst* getExtractExistensialTypeWitness(IRBuilder* builder, IRExtractExistentialType* extractExistentialType); IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); @@ -240,6 +242,11 @@ struct DifferentiableTypeConformanceContext diffElementType, as<IRArrayType>(origType)->getElementCount()); } + case kIROp_TupleType: + case kIROp_TypePack: + { + return differentiateType(builder, origType); + } case kIROp_DifferentialPairUserCodeType: { auto diffPairType = as<IRDifferentialPairTypeBase>(origType); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 80c810620..179ed3065 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -243,9 +243,11 @@ INST(AssociatedType, associated_type, 0, HOISTABLE) INST(ThisType, this_type, 0, HOISTABLE) INST(RTTIType, rtti_type, 0, HOISTABLE) INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE) -INST(TupleType, tuple_type, 0, HOISTABLE) +/*TupleTypeBase*/ + INST(TupleType, tuple_type, 0, HOISTABLE) + INST(TypePack, TypePack, 0, HOISTABLE) +INST_RANGE(TupleTypeBase, TupleType, TypePack) INST(TargetTupleType, TargetTuple, 0, HOISTABLE) -INST(TypePack, TypePack, 0, HOISTABLE) INST(ExpandTypeOrVal, ExpandTypeOrVal, 1, HOISTABLE) // A type that identifies it's contained type as being emittable as `spirv_literal. diff --git a/source/slang/slang-ir-lower-expand-type.cpp b/source/slang/slang-ir-lower-expand-type.cpp new file mode 100644 index 000000000..8b68b1fc1 --- /dev/null +++ b/source/slang/slang-ir-lower-expand-type.cpp @@ -0,0 +1,167 @@ +#include "slang-ir-lower-expand-type.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-clone.h" + +namespace Slang +{ + IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, IRInst* eachIndex); + + IRInst* clonePatternValImpl(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, IRInst* eachIndex) + { + if (!val) + return val; + + switch (val->getOp()) + { + case kIROp_ExpandTypeOrVal: + return val; + case kIROp_Each: + { + auto eachInst = as<IREach>(val); + auto packInst = eachInst->getElement(); + packInst = clonePatternValImpl(cloneEnv, builder, packInst, eachIndex); + auto result = builder->emitGetTupleElement(val->getFullType(), packInst, eachIndex); + return result; + } + case kIROp_Specialize: + case kIROp_LookupWitness: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialWitnessTable: + break; + default: + // If the value is not a type, and it is not in a block, then it is some global inst + // that shouldn't be deep copied into current block, such as a IRFunc. + if (!as<IRType>(val) && getBlock(val->getParent()) == nullptr) + return val; + break; + } + bool anyChange = false; + ShortList<IRInst*> operands; + for (UInt i = 0; i < val->getOperandCount(); i++) + { + auto newOperand = clonePatternVal(cloneEnv, builder, val->getOperand(i), eachIndex); + if (newOperand != val->getOperand(i)) + anyChange = true; + operands.add(newOperand); + } + auto newType = clonePatternVal(cloneEnv, builder, val->getFullType(), eachIndex); + if (newType != val->getFullType()) + anyChange = true; + if (!anyChange) + return val; + + auto newVal = builder->emitIntrinsicInst((IRType*)newType, val->getOp(), operands.getCount(), operands.getArrayView().getBuffer()); + if (newVal != val) + { + cloneInstDecorationsAndChildren(&cloneEnv, builder->getModule(), val, newVal); + } + return newVal; + } + + IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, IRInst* eachIndex) + { + if (auto clonedVal = cloneEnv.mapOldValToNew.tryGetValue(val)) + return *clonedVal; + cloneEnv.mapOldValToNew[val] = val; + auto result = clonePatternValImpl(cloneEnv, builder, val, eachIndex); + cloneEnv.mapOldValToNew[val] = result; + return result; + } + + // Translate a `IRExpandType` into an `IRExpand` where the `PatternType` is defined + // inside the `IRExpand` body. + // + IRInst* lowerExpandTypeImpl(IRExpandType* expandType) + { + // Turn `IRExpandType` into an `IRExpand` instruction. + IRBuilder builder(expandType); + builder.setInsertBefore(expandType); + List<IRInst*> capturedArgs; + IRCloneEnv cloneEnv; + for (UInt i = 0; i < expandType->getCaptureCount(); i++) + { + auto capturedArg = expandType->getCaptureType(i); + capturedArgs.add(capturedArg); + } + auto result = builder.emitExpandInst(expandType->getFullType(), expandType->getCaptureCount(), capturedArgs.getBuffer()); + builder.setInsertInto(result); + builder.emitBlock(); + auto eachIndex = builder.emitParam(builder.getIntType()); + auto newPatternType = clonePatternVal(cloneEnv, &builder, expandType->getPatternType(), eachIndex); + builder.emitYield(newPatternType); + return result; + } + + // Process the body of an `IRExpand` instruction, and replace the type of children insts if it + // is an `IRExpandType`. + // + void processExpandVal(IRExpand* expandVal) + { + IRBuilder builder(expandVal); + IRCloneEnv cloneEnv; + auto eachIndex = expandVal->getFirstBlock()->getFirstParam(); + for (auto block : expandVal->getBlocks()) + { + for (auto inst : block->getModifiableChildren()) + { + builder.setInsertBefore(inst); + auto newType = clonePatternVal(cloneEnv, &builder, inst->getFullType(), eachIndex); + if (newType != inst->getFullType()) + { + inst = builder.replaceOperand(&inst->typeUse, newType); + } + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto oldOperand = inst->getOperand(i); + if (!oldOperand) + continue; + if (isChildInstOf(oldOperand, expandVal)) + continue; + auto newOperand = clonePatternVal(cloneEnv, &builder, oldOperand, eachIndex); + if (newOperand != inst->getOperand(i)) + { + inst = builder.replaceOperand(inst->getOperands() + i, newOperand); + } + } + } + } + } + + void lowerExpandType(IRModule* module) + { + // Use a work list to process all instructions in the module, and lower any `IRExpandType` we see + // along the way. + + List<IRInst*> workList; + for (auto type : module->getGlobalInsts()) + { + workList.add(type); + } + + while (workList.getCount() != 0) + { + auto inst = workList.getLast(); + workList.removeLast(); + + if (auto expandType = as<IRExpandType>(inst)) + { + inst = lowerExpandTypeImpl(expandType); + if (inst != expandType) + { + expandType->replaceUsesWith(inst); + expandType->removeAndDeallocate(); + } + } + else if (auto expandVal = as<IRExpand>(inst)) + { + processExpandVal(expandVal); + } + for (auto child : inst->getChildren()) + { + workList.add(child); + } + } + } +} diff --git a/source/slang/slang-ir-lower-expand-type.h b/source/slang/slang-ir-lower-expand-type.h new file mode 100644 index 000000000..28136e8c0 --- /dev/null +++ b/source/slang/slang-ir-lower-expand-type.h @@ -0,0 +1,30 @@ +#pragma once + +namespace Slang +{ + struct IRModule; + + // After IR lowering, an `expand each X` type will be defined in the IR as: + // %X = ... + // %e = IREach(%X) + // %expand = IRExpandType(%e) + // This form allows our IR deduplication logic to find the deduplicate the same + // `exapnd` types into the same IR inst. + // However after lowering is done, we no longer need this deduplication service. + // But having expand types defined in this form is making it very difficult to + // specialize. + // This pass runs immediately after IR lowering process for a module (pre-linking) + // to turn `IRExpandType` into `IRExpand`, so that the above expand type will be + // represented as: + // %expand = IRExpand : IRTypeKind + // { + // %eachIndex = IRParam : int; + // %e = ...; // may use %eachIndex. + // yield %e; + // } + // + // After this translation, there should be no longer any IRExpandType/IREach instructions + // that are alive in the IR. All future passes will only need to deal with IRExpand. + // + void lowerExpandType(IRModule* module); +} diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index b5f5edb05..8405f9e78 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -329,6 +329,7 @@ struct PeepholeContext : InstPassBase case kIROp_MakeTuple: case kIROp_MakeValuePack: case kIROp_MakeWitnessPack: + case kIROp_TypePack: { auto element = inst->getOperand(1); if (auto intLit = as<IRIntLit>(element)) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index c9e94352e..a56dae025 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -8,6 +8,7 @@ #include "slang-ir-lower-witness-lookup.h" #include "slang-ir-dce.h" #include "slang-ir-sccp.h" +#include "slang-ir-util.h" #include "../core/slang-performance-profiler.h" namespace Slang @@ -85,6 +86,7 @@ struct SpecializationContext { case kIROp_GlobalGenericParam: case kIROp_LookupWitness: + case kIROp_GetTupleElement: return false; case kIROp_Specialize: // The `specialize` instruction is a bit sepcial, @@ -589,9 +591,6 @@ struct SpecializationContext case kIROp_Expand: return maybeSpecializeExpand(as<IRExpand>(inst)); - case kIROp_ExpandTypeOrVal: - return maybeSpecializeExpandTypeOrVal(as<IRExpandType>(inst)); - case kIROp_GetTupleElement: return maybeSpecializeFoldableInst(inst); @@ -605,6 +604,15 @@ struct SpecializationContext case kIROp_CountOf: return maybeSpecializeCountOf(inst); + + case kIROp_Func: + + if (tryExpandParameterPack(as<IRFunc>(inst))) + { + addUsersToWorkList(inst); + return true; + } + return false; } } @@ -1010,6 +1018,9 @@ struct SpecializationContext workList.removeLast(); workListSet.remove(inst); + if (!inst->getParent() && inst->getOp() != kIROp_Module) + continue; + // For each instruction we process, we want to perform // a few steps. // @@ -1182,11 +1193,8 @@ struct SpecializationContext auto newWrapExistential = builder.emitWrapExistential( resultType, newCall, slotOperandCount, slotOperands.getArrayView().getBuffer()); inst->replaceUsesWith(newWrapExistential); - workList.remove(inst); inst->removeAndDeallocate(); addUsersToWorkList(newWrapExistential); - - workList.remove(wrapExistential); SLANG_ASSERT(!wrapExistential->hasUses()); wrapExistential->removeAndDeallocate(); return true; @@ -1209,6 +1217,14 @@ struct SpecializationContext if (maybeSpecializeBufferLoadCall(inst)) return false; + // If any arguments are value packs, we need to flatten them. + bool isCalleeFullyExpanded = false; + tryExpandParameterPack(as<IRFunc>(inst->getCallee()), &isCalleeFullyExpanded); + if (isCalleeFullyExpanded) + { + inst = tryExpandArgPack((IRCall*)inst); + } + // We can only specialize a call when the callee function is known. // auto calleeFunc = as<IRFunc>(inst->getCallee()); @@ -2402,13 +2418,9 @@ struct SpecializationContext break; } } - auto type = clonePatternVal(*subEnv, builder, childInst->getFullType(), index); - for (UInt i = 0; i < childInst->getOperandCount(); i++) - { - clonePatternVal(*subEnv, builder, childInst->getOperand(i), index); - } auto newInst = cloneInst(subEnv, builder, childInst); - newInst = builder->replaceOperand(&newInst->typeUse, type); + if (newInst != childInst) + addToWorkList(newInst); subEnv->mapOldValToNew[childInst] = newInst; IRBuilder subBuilder(*builder); subBuilder.setInsertInto(newInst); @@ -2419,6 +2431,32 @@ struct SpecializationContext return newInst; } + // A helper function to emit a MakeWitnessPack, MakeTypePack or MakeValuePack inst from + // a collection of elements, dependending on `type`. + // + IRInst* makeSpecializedPack(IRBuilder& builder, IRType* type, ArrayView<IRInst*> elements) + { + IRInst* resultPack = nullptr; + if (as<IRWitnessTableType>(type)) + { + List<IRType*> types; + for (auto element : elements) + types.add(element->getDataType()); + auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer()); + resultPack = builder.emitMakeWitnessPack(newTypePack, elements); + } + else if (as<IRTypeKind>(type) || as<IRTypeType>(type)) + { + auto newTypePack = builder.getTypePack(elements.getCount(), (IRType* const*)elements.getBuffer()); + resultPack = newTypePack; + } + else + { + resultPack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer()); + } + return resultPack; + } + bool maybeSpecializeExpand(IRExpand* expandInst) { if (expandInst->getCaptureCount() == 0) @@ -2440,44 +2478,57 @@ struct SpecializationContext } if (elementCount == 0) { - auto resultValuePack = builder.emitMakeValuePack(0, (IRInst*const*)nullptr); - expandInst->replaceUsesWith(resultValuePack); + auto resultPack = makeSpecializedPack(builder, expandInst->getDataType(), elements.getArrayView()); + expandInst->replaceUsesWith(resultPack); expandInst->removeAndDeallocate(); - addUsersToWorkList(resultValuePack); + addUsersToWorkList(resultPack); return true; } + + bool isMultiBlock = as<IRYield>(expandInst->getFirstBlock()->getTerminator()) == nullptr; for (UInt i = 0; i < elementCount; i++) { IRCloneEnv cloneEnv; - IRBlock* firstBlock = nullptr; IRBuilder subBuilder = builder; - for (auto childBlock : expandInst->getBlocks()) + IRBlock* mergeBlock = nullptr; + if (isMultiBlock) { - auto newBlock = subBuilder.emitBlock(); - if (!firstBlock) - firstBlock = newBlock; - cloneEnv.mapOldValToNew[childBlock] = newBlock; + IRBlock* firstBlock = nullptr; + for (auto childBlock : expandInst->getBlocks()) + { + auto newBlock = subBuilder.emitBlock(); + if (!firstBlock) + firstBlock = newBlock; + cloneEnv.mapOldValToNew[childBlock] = newBlock; + } + + builder.emitBranch(firstBlock); + + mergeBlock = subBuilder.emitBlock(); + builder.setInsertInto(mergeBlock); } + auto indexParam = expandInst->getFirstBlock()->getFirstParam(); SLANG_ASSERT(indexParam); cloneEnv.mapOldValToNew[indexParam] = subBuilder.getIntValue(subBuilder.getIntType(), i); - builder.emitBranch(firstBlock); - - IRBlock* mergeBlock = subBuilder.emitBlock(); - builder.setInsertInto(mergeBlock); - for (auto childBlock : expandInst->getBlocks()) { - auto newBlock = cloneEnv.mapOldValToNew[childBlock]; - subBuilder.setInsertInto(newBlock); + if (isMultiBlock) + { + auto newBlock = cloneEnv.mapOldValToNew[childBlock]; + subBuilder.setInsertInto(newBlock); + } for (auto child : childBlock->getChildren()) { if (as<IRYield>(child)) { - elements.add(cloneEnv.mapOldValToNew[child->getOperand(0)]); - subBuilder.emitBranch(mergeBlock); + auto currentResult = child->getOperand(0); + currentResult = findCloneForOperand(&cloneEnv, currentResult); + elements.add(currentResult); + if (isMultiBlock) + subBuilder.emitBranch(mergeBlock); continue; } specializeExpandChildInst(cloneEnv, &subBuilder, child, i); @@ -2486,129 +2537,22 @@ struct SpecializationContext } } - auto resultValuePack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer()); - auto currentBlock = builder.getBlock(); - for (auto nextInst = expandInst->next; nextInst;) - { - auto next = nextInst->next; - nextInst->insertAtEnd(currentBlock); - nextInst = next; - } - addUsersToWorkList(expandInst); - expandInst->replaceUsesWith(resultValuePack); - expandInst->removeAndDeallocate(); - return true; - } - IRInst* clonePatternValImpl(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack) - { - if (!val) - return val; - - switch (val->getOp()) - { - case kIROp_ExpandTypeOrVal: - return val; - case kIROp_Each: + IRInst* resultPack = makeSpecializedPack(builder, expandInst->getDataType(), elements.getArrayView()); + if (isMultiBlock) { - auto eachInst = as<IREach>(val); - auto packInst = eachInst->getElement(); - if (auto typePack = as<IRTypePack>(packInst)) - { - SLANG_RELEASE_ASSERT(indexInPack < typePack->getOperandCount()); - return typePack->getOperand(indexInPack); - } - else if (auto makeValuePack = as<IRMakeValuePack>(packInst)) - { - SLANG_RELEASE_ASSERT(indexInPack < makeValuePack->getOperandCount()); - return makeValuePack->getOperand(indexInPack); - } - else if (!as<IRTypeKind>(packInst->getDataType())) + auto currentBlock = builder.getBlock(); + for (auto nextInst = expandInst->next; nextInst;) { - auto type = clonePatternVal(cloneEnv, builder, val, indexInPack); - return builder->emitGetTupleElement((IRType*)type, packInst, indexInPack); + auto next = nextInst->next; + nextInst->insertAtEnd(currentBlock); + nextInst = next; } - return val; - } - default: - break; - } - bool anyChange = false; - ShortList<IRInst*> operands; - for (UInt i = 0; i < val->getOperandCount(); i++) - { - auto newOperand = clonePatternVal(cloneEnv, builder, val->getOperand(i), indexInPack); - if (newOperand != val->getOperand(i)) - anyChange = true; - operands.add(newOperand); - } - auto newType = clonePatternVal(cloneEnv, builder, val->getFullType(), indexInPack); - if (newType != val->getFullType()) - anyChange = true; - if (!anyChange) - return val; - - auto newVal = builder->emitIntrinsicInst((IRType*)newType, val->getOp(), operands.getCount(), operands.getArrayView().getBuffer()); - if (newVal != val) - { - cloneInstDecorationsAndChildren(&cloneEnv, module, val, newVal); - } - return newVal; - } - - IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack) - { - if (auto clonedVal = cloneEnv.mapOldValToNew.tryGetValue(val)) - return *clonedVal; - cloneEnv.mapOldValToNew[val] = val; - auto result = clonePatternValImpl(cloneEnv, builder, val, indexInPack); - cloneEnv.mapOldValToNew[val] = result; - return result; - } - - bool maybeSpecializeExpandTypeOrVal(IRExpandType* expandInst) - { - if (expandInst->getCaptureCount() == 0) - return false; - - for (UInt i = 0; i < expandInst->getCaptureCount(); i++) - { - if (!as<IRTypePack>(expandInst->getCaptureType(i))) - return false; - } - IRBuilder builder(expandInst); - builder.setInsertBefore(expandInst); - List<IRInst*> elements; - UInt elementCount = 0; - if (auto firstTypePack = as<IRTypePack>(expandInst->getCaptureType(0))) - { - elementCount = firstTypePack->getOperandCount(); - } - for (UInt i = 0; i < elementCount; i++) - { - IRCloneEnv cloneEnv; - auto element = clonePatternVal(cloneEnv, &builder, expandInst->getPatternType(), i); - elements.add(element); } addUsersToWorkList(expandInst); - if (as<IRWitnessTableType>(expandInst->getDataType())) - { - List<IRType*> types; - for (auto element : elements) - types.add(element->getDataType()); - auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer()); - auto result = builder.emitMakeWitnessPack(newTypePack, elements.getArrayView()); - expandInst->replaceUsesWith(result); - expandInst->removeAndDeallocate(); - return true; - } - else - { - auto newTypePack = builder.getTypePack(elements.getCount(), (IRType*const*)elements.getBuffer()); - expandInst->replaceUsesWith(newTypePack); - expandInst->removeAndDeallocate(); - return true; - } + expandInst->replaceUsesWith(resultPack); + expandInst->removeAndDeallocate(); + return true; } // The handling of specialization for global generic type @@ -2680,6 +2624,108 @@ struct SpecializationContext } } } + + + // If `func` has any parameters whose types are `IRTypePack`, then we will expand them + // into multiple parameters, so that the function has no parameters of type `IRTypePack`. + // returns true if changes are made. + // For example, this function turns `int f(TypePack<int, float> v)` into + // ``` + // int f(int v0, float v1) + // { + // v = MakeValuePack(v0,. v1); + // ... + // } + // ``` + // + bool tryExpandParameterPack(IRFunc* func, bool* outIsFullyExpanded = nullptr) + { + if (!func) + return false; + if (outIsFullyExpanded) + *outIsFullyExpanded = true; + ShortList<IRInst*> params; + for (auto param : func->getParams()) + { + if (as<IRTypePack>(param->getDataType())) + params.add(param); + if (as<IRExpand>(param->getDataType())) + { + if (outIsFullyExpanded) + *outIsFullyExpanded = false; + return false; + } + } + if (params.getCount() == 0) + return false; + + IRBuilder builder(func); + for (auto param : params) + { + builder.setInsertBefore(param); + auto typePack = as<IRTypePack>(param->getDataType()); + ShortList<IRInst*> newParams; + for (UInt i = 0; i < typePack->getOperandCount(); i++) + { + auto newParam = builder.createParam((IRType*)typePack->getOperand(i)); + newParam->insertBefore(param); + newParams.add(newParam); + } + setInsertBeforeOrdinaryInst(&builder, param); + auto val = builder.emitMakeValuePack(typePack, (UInt)newParams.getCount(), newParams.getArrayView().getBuffer()); + param->replaceUsesWith(val); + param->removeAndDeallocate(); + addUsersToWorkList(val); + } + + fixUpFuncType(func); + return true; + } + + // If any arguments in a call is a value pack, we will expand them into the argument list, + // so that the call has no arguments of type `IRTypePack`. + // For example, we will turn `f(MakeValuePack(a, b))` into `f(a, b)`. + // + IRCall* tryExpandArgPack(IRCall* call) + { + bool anyArgPack = false; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + if (as<IRTypePack>(arg->getDataType())) + { + anyArgPack = true; + break; + } + } + if (!anyArgPack) + return call; + IRBuilder builder(call); + builder.setInsertBefore(call); + List<IRInst*> newArgs; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + if (auto typePack = as<IRTypePack>(arg->getDataType())) + { + for (UInt elementIndex = 0; elementIndex < typePack->getOperandCount(); elementIndex++) + { + auto newArg = builder.emitGetTupleElement((IRType*)typePack->getOperand(elementIndex), arg, elementIndex); + newArgs.add(newArg); + } + } + else + { + newArgs.add(arg); + } + } + auto newCall = builder.emitCallInst(call->getFullType(), call->getCallee(), newArgs.getArrayView()); + call->replaceUsesWith(newCall); + call->transferDecorationsTo(newCall); + call->removeAndDeallocate(); + return newCall; + } + }; bool specializeModule( @@ -2785,6 +2831,13 @@ IRInst* specializeGenericImpl( IRBuilder* builder = &builderStorage; builder->setInsertBefore(genericVal); + List<IRInst*> pendingWorkList; + SLANG_DEFER + ( + for (Index ii = pendingWorkList.getCount() - 1; ii >= 0; ii--) + context->addToWorkList(pendingWorkList[ii]); + ); + // Now we will run through the body of the generic and // clone each of its instructions into the global scope, // until we reach a `return` instruction. @@ -2825,10 +2878,11 @@ IRInst* specializeGenericImpl( { if (auto func = as<IRFunc>(specializedVal)) { + context->tryExpandParameterPack(func); simplifyFunc(context->targetProgram, func, IRSimplificationOptions::getFast(context->targetProgram)); } } - + pendingWorkList.add(specializedVal); return specializedVal; } @@ -2848,7 +2902,7 @@ IRInst* specializeGenericImpl( // if (context) { - context->addToWorkList(clonedInst); + pendingWorkList.add(clonedInst); } } } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index e030b6d24..817c10ec2 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1695,6 +1695,7 @@ struct GenericChildrenMigrationContextImpl case kIROp_ClassType: case kIROp_Func: case kIROp_Generic: + case kIROp_Expand: return false; default: break; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index e0769686c..0b0a42617 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4114,12 +4114,17 @@ namespace Slang // `getTupleElement(makeTuple(a_0, a_1, ... a_N), i)` then we should // just return `a_i`, provided that the index is properly in range. // - if( auto makeTuple = as<IRMakeTuple>(tuple) ) + switch(tuple->getOp()) { - if( element < makeTuple->getOperandCount() ) + case kIROp_MakeTuple: + case kIROp_MakeValuePack: + case kIROp_MakeWitnessPack: + case kIROp_TypePack: + if( element < tuple->getOperandCount() ) { - return makeTuple->getOperand(element); + return tuple->getOperand(element); } + break; } return emitGetTupleElement(type, tuple, getIntValue(getIntType(), element)); } @@ -8345,6 +8350,7 @@ namespace Slang case kIROp_DifferentialPairGetDifferential: case kIROp_MakeDifferentialPair: case kIROp_MakeTuple: + case kIROp_MakeValuePack: case kIROp_GetTupleElement: case kIROp_StructuredBufferLoad: case kIROp_RWStructuredBufferLoad: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index b1c2b001e..4a3a04404 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1933,18 +1933,22 @@ struct IRAttributedType : IRType IRInst* getAttr() { return getOperand(1); } }; +struct IRTupleTypeBase : IRType +{ + IR_PARENT_ISA(TupleTypeBase) +}; + /// Represents a tuple. Tuples are created by `IRMakeTuple` and its elements /// are accessed via `GetTupleElement(tupleValue, IRIntLit)`. -struct IRTupleType : IRType +struct IRTupleType : IRTupleTypeBase { IR_LEAF_ISA(TupleType) }; - /// Represents a type pack. Type packs behave like tuples, but they have a /// "flattening" semantics, so that MakeTypePack(MakeTypePack(T1,T2), T3) is /// MakeTypePack(T1,T2,T3). -struct IRTypePack : IRType +struct IRTypePack : IRTupleTypeBase { IR_LEAF_ISA(TypePack) }; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 2828752a0..02c4fae68 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -11,6 +11,7 @@ #include "slang-check.h" #include "slang-ir-bit-field-accessors.h" #include "slang-ir-loop-inversion.h" +#include "slang-ir-lower-expand-type.h" #include "slang-ir.h" #include "slang-ir-util.h" #include "slang-ir-constexpr.h" @@ -1987,7 +1988,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower } else { - return lowerType(context, type->getTypePack()); + return context->irBuilder->getTupleType(lowerType(context, type->getTypePack())); } } @@ -11080,6 +11081,13 @@ RefPtr<IRModule> generateIRForTranslationUnit( // Synthesize some code we want to make sure is inlined and simplified synthesizeBitFieldAccessors(module); + // Lower `IRExpandType` types to use `IRExpand`, where the pattern type + // is nested inside the `IRExpand` as its children, instead of being same + // level entities as the ExpandType itself. + // This will unify the specialization logic for both type and value level + // expansion. + lowerExpandType(module); + // Generate DebugValue insts to store values into debug variables, // if debug symbols are enabled. if (context->includeDebugInfo) diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index 7951ddc38..5a0d41f09 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -608,6 +608,48 @@ namespace Slang { emitType(context, getResultType(context->astBuilder, callableDeclRef)); } + + // Include key modifiers in the mangled name so we never deduplicate + // things like a nonmutating method and a mutating method. + bool isMutating = false; + bool isRefThis = false; + bool isFwdDiff = false; + bool isBwdDiff = false; + bool isNoDiffThis = false; + for (auto modifier : callableDeclRef.getDecl()->modifiers) + { + if (as<MutatingAttribute>(modifier)) + { + isMutating = true; + } + else if (as<RefAttribute>(modifier)) + { + isRefThis = true; + } + else if (as<ForwardDifferentiableAttribute>(modifier)) + { + isFwdDiff = true; + } + else if (as<BackwardDifferentiableAttribute>(modifier)) + { + isBwdDiff = true; + } + else if (as<NoDiffThisAttribute>(modifier)) + { + isNoDiffThis = true; + } + } + + if (isMutating) + emitRaw(context, "m"); + if (isRefThis) + emitRaw(context, "r"); + if (isFwdDiff) + emitRaw(context, "f"); + if (isBwdDiff) + emitRaw(context, "b"); + if (isNoDiffThis) + emitRaw(context, "n"); } } |
