diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 65 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 59 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-types.cpp | 174 |
5 files changed, 267 insertions, 41 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 55c0ee46d..640f516ed 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -568,11 +568,9 @@ namespace Slang { DifferentiableTypeConformanceContext* diffTypeContext; - virtual bool shouldConvertAddrInst(IRInst* addrInst) override + virtual bool shouldConvertAddrInst(IRInst*) override { - if (isDifferentiableType(*diffTypeContext, addrInst->getDataType())) - return true; - return false; + return true; } }; @@ -598,7 +596,9 @@ namespace Slang if (SLANG_SUCCEEDED(result)) { + disableIRValidationAtInsert(); simplifyFunc(func); + enableIRValidationAtInsert(); } return result; } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 70018b476..95ad58586 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -328,18 +328,6 @@ struct DiffTransposePass getPhiGrads(trueBlock).getCount(), getPhiGrads(trueBlock).getBuffer()); - // Old false-side starting block becomes end block - // for the new pre-cond region (which could be empty) - // - IRBlock* revPreCondEndBlock = revBlockMap[falseBlock]; - if (!falseRegionInfo.isTrivial) - { - builder.setInsertInto(revPreCondEndBlock); - builder.emitBranch( - revCondBlock, - getPhiGrads(falseBlock).getCount(), - getPhiGrads(falseBlock).getBuffer()); - } IRBlock* revBreakRegionExitBlock = revBlockMap[firstLoopBlock]; if (!preCondRegionInfo.isTrivial) @@ -366,17 +354,42 @@ struct DiffTransposePass ifElse->getCondition(), revTrueBlock, revFalseBlock, - revLoopEndBlock); + revTrueBlock); - // Emit loop into rev-version of the break block. - auto revLoopBlock = revBlockMap[breakBlock]; - builder.setInsertInto(revLoopBlock); - builder.emitLoop( - revPreCondBlock, - revBreakBlock, - revLoopEndBlock, - getPhiGrads(breakBlock).getCount(), - getPhiGrads(breakBlock).getBuffer()); + // Old false-side starting block becomes end block + // for the new pre-cond region (which could be empty) + // + + if (!falseRegionInfo.isTrivial) + { + IRBlock* revPreCondEndBlock = revBlockMap[falseBlock]; + builder.setInsertInto(revPreCondEndBlock); + builder.emitLoop( + revCondBlock, + revBreakBlock, + revLoopEndBlock, + getPhiGrads(falseBlock).getCount(), + getPhiGrads(falseBlock).getBuffer()); + + auto revLoopStartBlock = revBlockMap[breakBlock]; + builder.setInsertInto(revLoopStartBlock); + builder.emitBranch( + revPreCondBlock, + getPhiGrads(breakBlock).getCount(), + getPhiGrads(breakBlock).getBuffer()); + } + else + { + // Emit loop into rev-version of the break block. + auto revLoopBlock = revBlockMap[breakBlock]; + builder.setInsertInto(revLoopBlock); + builder.emitLoop( + revPreCondBlock, + revBreakBlock, + revLoopEndBlock, + getPhiGrads(breakBlock).getCount(), + getPhiGrads(breakBlock).getBuffer()); + } currentBlock = breakBlock; break; @@ -1436,9 +1449,13 @@ struct DiffTransposePass argRequiresLoad.add(false); } - args.add(builder->emitLoad(primalContextDecor->getBackwardDerivativePrimalContextVar())); + // Ensure availability of the primal context var + auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar()); + SLANG_RELEASE_ASSERT(primalContextVar); + + args.add(builder->emitLoad(primalContextVar)); argTypes.add(as<IRPtrTypeBase>( - primalContextDecor->getBackwardDerivativePrimalContextVar()->getDataType()) + primalContextVar->getDataType()) ->getValueType()); argRequiresLoad.add(false); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 4e7539b48..50c5c4ea6 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -83,7 +83,7 @@ struct ExtractPrimalFuncContext SLANG_RELEASE_ASSERT(originalFuncType); List<IRType*> paramTypes; - for (UInt i = 0; i < originalFuncType->getParamCount() - 1; i++) + for (Index i = 0; i < ((Count) originalFuncType->getParamCount()) - 1; i++) paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i))); paramTypes.add(builder.getInOutType((IRType*)outIntermediateType)); auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType()); diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 2ccb8d8e2..e2c84ce8b 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -525,7 +525,32 @@ struct DiffUnzipPass List<IRInst*> primalInsts; for (auto child = primalBlock->getFirstChild(); child; child = child->getNextInst()) + { + // TODO: This might be a decent place to enforce that each load has a single + // corresponding store (i.e. that everything is SSAd properly)? + + // We're only interested in insts that generate values. + if (child->getDataType() == nullptr || + as<IRVoidType>(child->getDataType()) || + as<IRFuncType>(child->getDataType()) || + as<IRTypeKind>(child->getDataType())) + continue; + + // We also don't care about pointer types (only Loads) + if (auto ptrType = as<IRPtrTypeBase>(child->getDataType())) + { + // There's an exception to this, if the var is an intermediate context type + // variable since there won't be a load from this yet (the load will + // be inserted later during the transposition process) + // + if (as<IRBackwardDiffIntermediateContextType>(ptrType->getValueType())) + primalInsts.add(child); + + continue; + } + primalInsts.add(child); + } IRBuilder builder(autodiffContext->moduleInst->getModule()); @@ -545,7 +570,7 @@ struct DiffUnzipPass bool shouldStore = false; for (auto use = inst->firstUse; use; use = use->nextUse) { - IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent()); + IRBlock* useBlock = getBlock(use->getUser()); if (isDifferentialInst(useBlock)) { @@ -561,7 +586,14 @@ struct DiffUnzipPass builder.setInsertBefore(firstPrimalBlock->getTerminator()); IRType* arrayType = inst->getDataType(); - SLANG_ASSERT(!as<IRPtrTypeBase>(arrayType)); // can't store pointers. + bool isPtrType = false; + + if (auto ptrType = as<IRPtrTypeBase>(arrayType)) + { + SLANG_RELEASE_ASSERT(as<IRBackwardDiffIntermediateContextType>(ptrType->getValueType())); + arrayType = ptrType->getValueType(); + isPtrType = true; + } for (auto region : regions) { @@ -582,11 +614,6 @@ struct DiffUnzipPass auto storageVar = builder.emitVar(arrayType); - // TODO(sai) STOPPED HERE: For some reason, we still have a direct param access - // when trying to cover up the access to last value of loop counter. - // Maybe we need a different way to access this? (use a var) - // Special case? - // 3. Store current value into the array and replace uses with a load. // TODO: If an index is missing, use the 'last' value of the primal index. { @@ -616,7 +643,8 @@ struct DiffUnzipPass { if (as<IRDecoration>(use->getUser())) { - if (!as<IRLoopExitPrimalValueDecoration>(use->getUser())) + if (!as<IRLoopExitPrimalValueDecoration>(use->getUser()) && + !as<IRBackwardDerivativePrimalContextDecoration>(use->getUser())) continue; } @@ -683,10 +711,17 @@ struct DiffUnzipPass instsToTag.add(loadAddr); } - auto loadedValue = builder.emitLoad(loadAddr); - instsToTag.add(loadedValue); + if (!isPtrType) + { + auto loadedValue = builder.emitLoad(loadAddr); + instsToTag.add(loadedValue); - use->set(loadedValue); + use->set(loadedValue); + } + else + { + use->set(loadAddr); + } } } @@ -744,6 +779,8 @@ struct DiffUnzipPass } auto intermediateVar = primalBuilder->emitVar((IRType*)intermediateType); + primalBuilder->markInstAsPrimal(intermediateVar); + primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar); auto primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 7660c9526..d7ed1f63f 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -752,7 +752,7 @@ static LegalVal legalizeUnconditionalBranch( SLANG_UNIMPLEMENTED_X("Unknown legalized val flavor."); } } - context->builder->emitBranch(branchInst->getTargetBlock(), newArgs.getCount() - 1, newArgs.getBuffer() + 1); + context->builder->emitIntrinsicInst(nullptr, branchInst->getOp(), newArgs.getCount(), newArgs.getBuffer()); return LegalVal(); } @@ -1665,6 +1665,169 @@ static LegalVal legalizeMakeStruct( } } +static LegalVal legalizeMakeArray( + IRTypeLegalizationContext* context, + LegalType legalType, + LegalVal const* legalArgs, + UInt argCount, + IROp constructOp) +{ + auto builder = context->builder; + + switch (legalType.flavor) + { + case LegalType::Flavor::none: + return LegalVal(); + + case LegalType::Flavor::simple: + { + List<IRInst*> args; + // We need a valid default val for elements that are legalized to `none`. + // We grab the first non-none value from the legalized args and use it. + // If all args are none (althoguh this shouldn't happen, since the entire array + // would have been legalized to none in this case.), we use defaultConstruct op. + // Use of defaultConstruct may lead to invalid HLSL/GLSL code, so we want to + // avoid that if possible. + IRInst* defaultVal = nullptr; + for (UInt aa = 0; aa < argCount; ++aa) + { + if (legalArgs[aa].flavor == LegalVal::Flavor::simple) + { + defaultVal = legalArgs[aa].getSimple(); + break; + } + } + if (!defaultVal) + { + defaultVal = builder->emitDefaultConstruct(as<IRArrayTypeBase>(legalType.getSimple())->getElementType()); + } + for (UInt aa = 0; aa < argCount; ++aa) + { + if (legalArgs[aa].flavor == LegalVal::Flavor::none) + args.add(defaultVal); + else + args.add(legalArgs[aa].getSimple()); + } + return LegalVal::simple( + builder->emitIntrinsicInst( + legalType.getSimple(), + constructOp, + args.getCount(), + args.getBuffer())); + } + + case LegalType::Flavor::pair: + { + // There are two sides, the ordinary and the special, + // and we basically just dispatch to both of them. + auto pairType = legalType.getPair(); + auto pairInfo = pairType->pairInfo; + LegalType ordinaryType = pairType->ordinaryType; + LegalType specialType = pairType->specialType; + + List<LegalVal> ordinaryArgs; + List<LegalVal> specialArgs; + bool hasValidOrdinaryArgs = false; + bool hasValidSpecialArgs = false; + for (UInt argIndex = 0; argIndex < argCount; argIndex++) + { + LegalVal arg = legalArgs[argIndex]; + + // The argument must be a pair. + if (arg.flavor == LegalVal::Flavor::pair) + { + auto argPair = arg.getPair(); + ordinaryArgs.add(argPair->ordinaryVal); + specialArgs.add(argPair->specialVal); + hasValidOrdinaryArgs = true; + hasValidSpecialArgs = true; + } + else if (arg.flavor == LegalVal::Flavor::simple) + { + if (arg.getSimple()->getFullType() == ordinaryType.irType) + { + ordinaryArgs.add(arg); + specialArgs.add(LegalVal()); + hasValidOrdinaryArgs = true; + } + else + { + ordinaryArgs.add(LegalVal()); + specialArgs.add(arg); + hasValidSpecialArgs = true; + } + } + else if (arg.flavor == LegalVal::Flavor::none) + { + ordinaryArgs.add(arg); + specialArgs.add(arg); + } + else + { + SLANG_UNEXPECTED("unhandled"); + } + } + + LegalVal ordinaryVal = LegalVal(); + if (hasValidOrdinaryArgs) + ordinaryVal = legalizeMakeArray( + context, + ordinaryType, + ordinaryArgs.getBuffer(), + ordinaryArgs.getCount(), + constructOp); + + LegalVal specialVal = LegalVal(); + if (hasValidSpecialArgs) + specialVal = legalizeMakeArray( + context, specialType, specialArgs.getBuffer(), specialArgs.getCount(), constructOp); + + return LegalVal::pair(ordinaryVal, specialVal, pairInfo); + } + break; + + case LegalType::Flavor::tuple: + { + // For array types that are legalized as tuples, + // we expect each element of the array to be legalized as the same tuples. + // We want to return a tuple, where i-th element is an array containing + // the i-th tuple-element of each legalized array-element. + + auto tupleType = legalType.getTuple(); + + RefPtr<TuplePseudoVal> resTupleInfo = new TuplePseudoVal(); + UInt elementCounter = 0; + for (auto typeElem : tupleType->elements) + { + auto elemKey = typeElem.key; + UInt elementIndex = elementCounter++; + List<LegalVal> subArray; + for (UInt i = 0; i < argCount; i++) + { + LegalVal argVal = legalArgs[i]; + SLANG_RELEASE_ASSERT(argVal.flavor == LegalVal::Flavor::tuple); + auto argTuple = argVal.getTuple(); + SLANG_RELEASE_ASSERT( + argTuple->elements.getCount() == tupleType->elements.getCount()); + subArray.add(argTuple->elements[elementIndex].val); + } + + auto legalSubArray = legalizeMakeArray(context, typeElem.type, subArray.getBuffer(), subArray.getCount(), constructOp); + + TuplePseudoVal::Element resElem; + resElem.key = elemKey; + resElem.val = legalSubArray; + resTupleInfo->elements.add(resElem); + } + return LegalVal::tuple(resTupleInfo); + } + + default: + SLANG_UNEXPECTED("unhandled"); + UNREACHABLE_RETURN(LegalVal()); + } +} + static LegalVal legalizeDefaultConstruct( IRTypeLegalizationContext* context, LegalType legalType) @@ -1762,11 +1925,20 @@ static LegalVal legalizeInst( type, args.getBuffer(), inst->getOperandCount()); + case kIROp_MakeArray: + case kIROp_MakeArrayFromElement: + return legalizeMakeArray( + context, + type, + args.getBuffer(), + inst->getOperandCount(), + inst->getOp()); case kIROp_DefaultConstruct: return legalizeDefaultConstruct( context, type); case kIROp_unconditionalBranch: + case kIROp_loop: return legalizeUnconditionalBranch(context, args, (IRUnconditionalBranch*)inst); case kIROp_undefined: return LegalVal(); |
