summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-18 21:57:24 -0700
committerGitHub <noreply@github.com>2024-08-18 21:57:24 -0700
commitecf85df6eee3da76ef54b14e4ab083f22da89e46 (patch)
tree4656f9c11a1f7f40550d469fecbcd7a16c541f52 /source
parentca5d303748517889a5d5849224671fa8945e1c6d (diff)
Variadic Generics Part 2: IR lowering and specialization. (#4849)
* Variadic Generics Part 2: IR lowering and specialization. * Update design doc status. * Update design doc. * Resolve review comments.
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang8
-rw-r--r--source/slang/slang-ast-builder.cpp14
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-insts.h60
-rw-r--r--source/slang/slang-ir-link.cpp16
-rw-r--r--source/slang/slang-ir-loop-inversion.cpp2
-rw-r--r--source/slang/slang-ir-peephole.cpp18
-rw-r--r--source/slang/slang-ir-peephole.h1
-rw-r--r--source/slang/slang-ir-specialize.cpp255
-rw-r--r--source/slang/slang-ir-ssa.cpp16
-rw-r--r--source/slang/slang-ir.cpp71
-rw-r--r--source/slang/slang-ir.h30
-rw-r--r--source/slang/slang-legalize-types.cpp19
-rw-r--r--source/slang/slang-lower-to-ir.cpp168
-rw-r--r--source/slang/slang-syntax.cpp2
15 files changed, 640 insertions, 46 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 695423285..6c51ccef0 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -870,6 +870,14 @@ bool operator!=(__none_t noneVal, Optional<T> val)
return val.hasValue;
}
+__generic<each T>
+__magic_type(TupleType)
+struct Tuple
+{
+ __intrinsic_op($(0))
+ __init(expand each T);
+}
+
__generic<T>
__magic_type(NativeRefType)
__intrinsic_type($(kIROp_NativePtrType))
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index faf15470f..3a2b2933d 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -525,7 +525,19 @@ FuncType* ASTBuilder::getFuncType(ArrayView<Type*> parameters, Type* result, Typ
TupleType* ASTBuilder::getTupleType(List<Type*>& types)
{
- return getOrCreate<TupleType>(types.getArrayView());
+ // The canonical form of a tuple type is always a DeclRefType(GenAppDeclRef(TupleDecl, ConcreteTypePack(types...))).
+ // If `types` is already a single ConcreteTypePack, then we can use that directly.
+ if (types.getCount() == 1)
+ {
+ if (isTypePack(types[0]))
+ {
+ return as<TupleType>(getSpecializedBuiltinType(types[0], "TupleType"));
+ }
+ }
+
+ // Otherwise, we need to create a ConcreteTypePack to hold the types.
+ auto typePack = getTypePack(types.getArrayView());
+ return as<TupleType>(getSpecializedBuiltinType(typePack, "TupleType"));
}
TypeType* ASTBuilder::getTypeType(Type* type)
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 39de083f0..a4225d041 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -101,6 +101,7 @@ INST(Nop, nop, 0, 0)
//
/* Kind */
INST(TypeKind, Type, 0, HOISTABLE)
+ INST(TypeParameterPackKind, TypeParameterPack, 0, HOISTABLE)
INST(RateKind, Rate, 0, HOISTABLE)
INST(GenericKind, Generic, 0, HOISTABLE)
INST_RANGE(Kind, TypeKind, GenericKind)
@@ -244,6 +245,7 @@ INST(RTTIType, rtti_type, 0, HOISTABLE)
INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE)
INST(TupleType, tuple_type, 0, HOISTABLE)
INST(TargetTupleType, TargetTuple, 0, HOISTABLE)
+INST(ExpandTypeOrVal, ExpandTypeOrVal, 1, HOISTABLE)
// A type that identifies it's contained type as being emittable as `spirv_literal.
INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE)
@@ -343,6 +345,9 @@ INST(MakeTuple, makeTuple, 0, 0)
INST(MakeTargetTuple, makeTuple, 0, 0)
INST(GetTargetTupleElement, getTargetTupleElement, 0, 0)
INST(GetTupleElement, getTupleElement, 2, 0)
+INST(MakeWitnessPack, MakeWitnessPack, 0, HOISTABLE)
+INST(Expand, Expand, 1, 0)
+INST(Each, Each, 1, HOISTABLE)
INST(MakeResultValue, makeResultValue, 1, 0)
INST(MakeResultError, makeResultError, 1, 0)
INST(IsResultError, isResultError, 1, 0)
@@ -566,6 +571,7 @@ INST(SwizzledStore, swizzledStore, 2, 0)
/* IRTerminatorInst */
INST(Return, return_val, 1, 0)
+ INST(Yield, yield, 1, 0)
/* IRUnconditionalBranch */
// unconditionalBranch <target>
INST(unconditionalBranch, unconditionalBranch, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index e30b903b5..dc5fb2744 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2484,6 +2484,13 @@ struct IRReturn : IRTerminatorInst
IRInst* getVal() { return getOperand(0); }
};
+struct IRYield : IRTerminatorInst
+{
+ IR_LEAF_ISA(Yield);
+
+ IRInst* getVal() { return getOperand(0); }
+};
+
struct IRDiscard : IRTerminatorInst
{};
@@ -2825,12 +2832,36 @@ struct IRBindGlobalGenericParam : IRInst
IR_LEAF_ISA(BindGlobalGenericParam)
};
+struct IRExpand : IRInst
+{
+ IR_LEAF_ISA(Expand)
+ UInt getCaptureCount() { return getOperandCount(); }
+ IRInst* getCapture(UInt index) { return getOperand(index); }
+ IRInstList<IRBlock> getBlocks()
+ {
+ return IRInstList<IRBlock>(getChildren());
+ }
+};
+
+
+struct IREach : IRInst
+{
+ IR_LEAF_ISA(Each)
+
+ IRInst* getElement() { return getOperand(0); }
+};
+
// An Instruction that creates a tuple value.
struct IRMakeTuple : IRInst
{
IR_LEAF_ISA(MakeTuple)
};
+struct IRMakeWitnessPack : IRInst
+{
+ IR_LEAF_ISA(MakeWitnessPack)
+};
+
struct IRGetTupleElement : IRInst
{
IR_LEAF_ISA(GetTupleElement)
@@ -3328,7 +3359,7 @@ public:
// Get the current function (or other value with code)
// that we are inserting into (if any).
- IRGlobalValueWithCode* getFunc() { return m_insertLoc.getFunc(); }
+ IRInst* getFunc() { return m_insertLoc.getFunc(); }
void setInsertInto(IRInst* insertInto) { setInsertLoc(IRInsertLoc::atEnd(insertInto)); }
void setInsertBefore(IRInst* insertBefore) { setInsertLoc(IRInsertLoc::before(insertBefore)); }
@@ -3478,6 +3509,8 @@ public:
IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2);
IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2, IRType* type3);
+ IRExpandType* getExpandTypeOrVal(IRType* type, IRInst* pattern, ArrayView<IRInst*> capture);
+
IRResultType* getResultType(IRType* valueType, IRType* errorType);
IROptionalType* getOptionalType(IRType* valueType);
@@ -3485,6 +3518,7 @@ public:
IRWitnessTableType* getWitnessTableType(IRType* baseType);
IRWitnessTableIDType* getWitnessTableIDType(IRType* baseType);
IRType* getTypeType() { return getType(IROp::kIROp_TypeType); }
+ IRType* getTypeParameterPackKind() { return getType(IROp::kIROp_TypeParameterPackKind); }
IRType* getKeyType() { return nullptr; }
IRTypeKind* getTypeKind();
@@ -3715,6 +3749,9 @@ public:
return emitSpecializeInst(type, genericVal, args.getCount(), args.begin());
}
+ IRInst* emitExpandInst(IRType* type, UInt capturedArgCount, IRInst* const* capturedArgs);
+ IRInst* emitEachInst(IRType* type, IRInst* base, IRInst* indexArg = nullptr);
+
IRInst* emitLookupInterfaceMethodInst(
IRType* type,
IRInst* witnessTableVal,
@@ -3814,6 +3851,13 @@ public:
IRInst* emitMakeTuple(IRType* type, List<IRInst*> const& args)
{
+ if (args.getCount() == 1)
+ {
+ if (args[0]->getOp() == kIROp_Expand)
+ {
+ return args[0];
+ }
+ }
return emitMakeTuple(type, args.getCount(), args.getBuffer());
}
@@ -3828,11 +3872,22 @@ public:
return emitMakeTuple(SLANG_COUNT_OF(args), args);
}
+ IRInst* emitMakeWitnessPack(IRType* type, ArrayView<IRInst*> args)
+ {
+ return emitIntrinsicInst(type, kIROp_MakeWitnessPack, (UInt)args.getCount(), args.getBuffer());
+ }
+
IRInst* emitMakeString(IRInst* nativeStr);
IRInst* emitGetNativeString(IRInst* str);
+ IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, int element)
+ {
+ return emitGetTupleElement(type, tuple, (UInt)element);
+ }
+
IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element);
+ IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element);
IRInst* emitMakeResultError(IRType* resultType, IRInst* errorVal);
IRInst* emitMakeResultValue(IRType* resultType, IRInst* val);
@@ -4186,6 +4241,9 @@ public:
IRInst* emitReturn(
IRInst* val);
+ IRInst* emitYield(
+ IRInst* val);
+
IRInst* emitReturn();
IRInst* emitThrow(IRInst* val);
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 5bb485b22..8b08b9045 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -1170,6 +1170,17 @@ IRFunc* cloneFuncImpl(
return clonedFunc;
}
+// Can an inst with `opcode` contain basic blocks as children?
+bool canInstContainBasicBlocks(IROp opcode)
+{
+ switch (opcode)
+ {
+ case kIROp_Expand:
+ return true;
+ default:
+ return false;
+ }
+}
IRInst* cloneInst(
IRSpecContextBase* context,
@@ -1238,7 +1249,10 @@ IRInst* cloneInst(
argCount, newArgs.getArrayView().getBuffer());
builder->addInst(clonedInst);
registerClonedValue(context, clonedInst, originalValues);
- cloneDecorationsAndChildren(context, clonedInst, originalInst);
+ if (canInstContainBasicBlocks(clonedInst->getOp()))
+ cloneGlobalValueWithCodeCommon(context, (IRGlobalValueWithCode*)clonedInst, (IRGlobalValueWithCode*)originalInst, originalValues);
+ else
+ cloneDecorationsAndChildren(context, clonedInst, originalInst);
cloneExtraDecorations(context, clonedInst, originalValues);
return clonedInst;
}
diff --git a/source/slang/slang-ir-loop-inversion.cpp b/source/slang/slang-ir-loop-inversion.cpp
index c811bf357..9ae734507 100644
--- a/source/slang/slang-ir-loop-inversion.cpp
+++ b/source/slang/slang-ir-loop-inversion.cpp
@@ -140,7 +140,7 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
builder.setInsertInto(loop->getParent());
const auto s = as<IRBlock>(loop->getParent());
- auto domTree = computeDominatorTree(s->getParent());
+ auto domTree = computeDominatorTree((IRGlobalValueWithCode*)s->getParent());
SLANG_ASSERT(s);
const auto c1 = loop->getTargetBlock();
const auto c1Terminator = as<IRIfElse>(c1->getTerminator());
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 232633d69..aa8dfddab 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -324,7 +324,10 @@ struct PeepholeContext : InstPassBase
}
break;
case kIROp_GetTupleElement:
- if (inst->getOperand(0)->getOp() == kIROp_MakeTuple)
+ switch (inst->getOperand(0)->getOp())
+ {
+ case kIROp_MakeTuple:
+ case kIROp_MakeWitnessPack:
{
auto element = inst->getOperand(1);
if (auto intLit = as<IRIntLit>(element))
@@ -333,6 +336,10 @@ struct PeepholeContext : InstPassBase
maybeRemoveOldInst(inst);
changed = true;
}
+ break;
+ }
+ default:
+ break;
}
break;
case kIROp_FieldExtract:
@@ -1181,6 +1188,15 @@ bool peepholeOptimize(TargetProgram* target, IRInst* func)
return context.processFunc(func);
}
+bool peepholeOptimizeInst(TargetProgram* target, IRModule* module, IRInst* inst)
+{
+ PeepholeContext context = PeepholeContext(module);
+ context.targetProgram = target;
+ context.useFastAnalysis = true;
+ context.processInst(inst);
+ return context.changed;
+}
+
bool peepholeOptimizeGlobalScope(TargetProgram* target, IRModule* module)
{
PeepholeContext context = PeepholeContext(module);
diff --git a/source/slang/slang-ir-peephole.h b/source/slang/slang-ir-peephole.h
index 411267072..3fdb74450 100644
--- a/source/slang/slang-ir-peephole.h
+++ b/source/slang/slang-ir-peephole.h
@@ -22,6 +22,7 @@ namespace Slang
/// Apply peephole optimizations.
bool peepholeOptimize(TargetProgram* target, IRModule* module, PeepholeOptimizationOptions options);
bool peepholeOptimize(TargetProgram* target, IRInst* func);
+ bool peepholeOptimizeInst(TargetProgram* target, IRModule* module, IRInst* inst);
bool peepholeOptimizeGlobalScope(TargetProgram* target, IRModule* module);
bool tryReplaceInstUsesWithSimplifiedValue(TargetProgram* target, IRModule* module, IRInst* inst);
}
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index c86906b2d..2eb16112f 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -1,6 +1,6 @@
// slang-ir-specialize.cpp
#include "slang-ir-specialize.h"
-
+#include "slang-ir-peephole.h"
#include "slang-ir.h"
#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
@@ -585,6 +585,15 @@ struct SpecializationContext
case kIROp_BindExistentialsType:
return maybeSpecializeBindExistentialsType(as<IRBindExistentialsType>(inst));
+
+ case kIROp_Expand:
+ return maybeSpecializeExpand(as<IRExpand>(inst));
+
+ case kIROp_ExpandTypeOrVal:
+ return maybeSpecializeExpandTypeOrVal(as<IRExpandType>(inst));
+
+ case kIROp_GetTupleElement:
+ return maybeSpecializeFoldableInst(inst);
}
}
@@ -597,7 +606,7 @@ struct SpecializationContext
{
// Note: While we currently have named the instruction
// `lookup_witness_method`, the `method` part is a misnomer
- // and the same instruction can look up *any* interface
+ // and the same instruction can look up *any* interfacemay
// requirement based on the witness table that provides
// a conformance, and the "key" that indicates the interface
// requirement.
@@ -609,7 +618,9 @@ struct SpecializationContext
//
auto witnessTable = as<IRWitnessTable>(lookupInst->getWitnessTable());
if (!witnessTable)
+ {
return false;
+ }
// Because we have a concrete witness table, we can
// use it to look up the IR value that satisfies
@@ -642,6 +653,19 @@ struct SpecializationContext
return true;
}
+ bool maybeSpecializeFoldableInst(IRInst* inst)
+ {
+ auto firstUse = inst->firstUse;
+ bool instChanged = peepholeOptimizeInst(targetProgram, module, inst);
+
+ for (auto use = firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ addToWorkList(user);
+ }
+ return instChanged;
+ }
+
// The above subroutine needed a way to look up
// the satisfying value for a given requirement
// key in a concrete witness table, so let's
@@ -2208,6 +2232,233 @@ struct SpecializationContext
return false;
}
+ IRInst* specializeExpandChildInst(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* childInst, UInt index)
+ {
+ IRCloneEnv freshEnv;
+ IRCloneEnv* subEnv = &cloneEnv;
+ switch (childInst->getOp())
+ {
+ case kIROp_Expand:
+ {
+ subEnv = &freshEnv;
+ break;
+ }
+ }
+ auto type = clonePatternVal(*subEnv, builder, childInst->getFullType(), index);
+ for (UInt i = 0; i < childInst->getOperandCount(); i++)
+ {
+ clonePatternVal(*subEnv, builder, childInst->getOperand(i), index);
+ }
+ auto newInst = cloneInst(subEnv, builder, childInst);
+ newInst = builder->replaceOperand(&newInst->typeUse, type);
+ subEnv->mapOldValToNew[childInst] = newInst;
+ IRBuilder subBuilder(*builder);
+ subBuilder.setInsertInto(newInst);
+ for (auto child : childInst->getChildren())
+ {
+ specializeExpandChildInst(*subEnv, &subBuilder, child, index);
+ }
+ return newInst;
+ }
+
+ bool maybeSpecializeExpand(IRExpand* expandInst)
+ {
+ if (expandInst->getCaptureCount() == 0)
+ return false;
+
+ for (UInt i = 0; i < expandInst->getCaptureCount(); i++)
+ {
+ if (!as<IRTupleType>(expandInst->getCapture(i)))
+ return false;
+ }
+
+ IRBuilder builder(expandInst);
+ builder.setInsertBefore(expandInst);
+ List<IRInst*> elements;
+ UInt elementCount = 0;
+ if (auto firstTupleType = as<IRTupleType>(expandInst->getCapture(0)))
+ {
+ elementCount = firstTupleType->getOperandCount();
+ }
+ if (elementCount == 0)
+ {
+ auto resultTuple = builder.emitMakeTuple(0, (IRInst*const*)nullptr);
+ expandInst->replaceUsesWith(resultTuple);
+ expandInst->removeAndDeallocate();
+ addUsersToWorkList(resultTuple);
+ return true;
+ }
+
+ for (UInt i = 0; i < elementCount; i++)
+ {
+ IRCloneEnv cloneEnv;
+ IRBlock* firstBlock = nullptr;
+ IRBuilder subBuilder = builder;
+ for (auto childBlock : expandInst->getBlocks())
+ {
+ auto newBlock = subBuilder.emitBlock();
+ if (!firstBlock)
+ firstBlock = newBlock;
+ cloneEnv.mapOldValToNew[childBlock] = newBlock;
+ }
+ auto indexParam = expandInst->getFirstBlock()->getFirstParam();
+ SLANG_ASSERT(indexParam);
+ cloneEnv.mapOldValToNew[indexParam] = subBuilder.getIntValue(subBuilder.getIntType(), i);
+
+ builder.emitBranch(firstBlock);
+
+ IRBlock* mergeBlock = subBuilder.emitBlock();
+ builder.setInsertInto(mergeBlock);
+
+ for (auto childBlock : expandInst->getBlocks())
+ {
+ auto newBlock = cloneEnv.mapOldValToNew[childBlock];
+ subBuilder.setInsertInto(newBlock);
+ for (auto child : childBlock->getChildren())
+ {
+ if (as<IRYield>(child))
+ {
+ elements.add(cloneEnv.mapOldValToNew[child->getOperand(0)]);
+ subBuilder.emitBranch(mergeBlock);
+ continue;
+ }
+ specializeExpandChildInst(cloneEnv, &subBuilder, child, i);
+ addToWorkList(childBlock);
+ }
+ }
+
+ }
+ auto resultTuple = builder.emitMakeTuple(elements);
+ auto currentBlock = builder.getBlock();
+ for (auto nextInst = expandInst->next; nextInst;)
+ {
+ auto next = nextInst->next;
+ nextInst->insertAtEnd(currentBlock);
+ nextInst = next;
+ }
+ addUsersToWorkList(expandInst);
+ expandInst->replaceUsesWith(resultTuple);
+ expandInst->removeAndDeallocate();
+ return true;
+ }
+
+ IRInst* clonePatternValImpl(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack)
+ {
+ if (!val)
+ return val;
+
+ switch (val->getOp())
+ {
+ case kIROp_ExpandTypeOrVal:
+ return val;
+ case kIROp_Each:
+ {
+ auto eachInst = as<IREach>(val);
+ auto packInst = eachInst->getElement();
+ if (auto tuple = as<IRTupleType>(packInst))
+ {
+ SLANG_RELEASE_ASSERT(indexInPack < tuple->getOperandCount());
+ return tuple->getOperand(indexInPack);
+ }
+ else if (auto makeTuple = as<IRMakeTuple>(packInst))
+ {
+ SLANG_RELEASE_ASSERT(indexInPack < makeTuple->getOperandCount());
+ return makeTuple->getOperand(indexInPack);
+ }
+ else if (!as<IRTypeKind>(packInst->getDataType()))
+ {
+ auto type = clonePatternVal(cloneEnv, builder, val, indexInPack);
+ return builder->emitGetTupleElement((IRType*)type, packInst, indexInPack);
+ }
+ return val;
+ }
+ default:
+ break;
+ }
+ bool anyChange = false;
+ ShortList<IRInst*> operands;
+ for (UInt i = 0; i < val->getOperandCount(); i++)
+ {
+ auto newOperand = clonePatternVal(cloneEnv, builder, val->getOperand(i), indexInPack);
+ if (newOperand != val->getOperand(i))
+ anyChange = true;
+ operands.add(newOperand);
+ }
+ auto newType = clonePatternVal(cloneEnv, builder, val->getFullType(), indexInPack);
+ if (newType != val->getFullType())
+ anyChange = true;
+ if (!anyChange)
+ return val;
+
+ auto newVal = builder->emitIntrinsicInst((IRType*)newType, val->getOp(), operands.getCount(), operands.getArrayView().getBuffer());
+ if (newVal != val)
+ {
+ cloneInstDecorationsAndChildren(&cloneEnv, module, val, newVal);
+ }
+ return newVal;
+ }
+
+ IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack)
+ {
+ if (auto clonedVal = cloneEnv.mapOldValToNew.tryGetValue(val))
+ return *clonedVal;
+ cloneEnv.mapOldValToNew[val] = val;
+ auto result = clonePatternValImpl(cloneEnv, builder, val, indexInPack);
+ cloneEnv.mapOldValToNew[val] = result;
+ return result;
+ }
+
+ bool maybeSpecializeExpandTypeOrVal(IRExpandType* expandInst)
+ {
+ if (expandInst->getCaptureCount() == 0)
+ return false;
+
+ bool anyAbstractPack = false;
+ for (UInt i = 0; i < expandInst->getCaptureCount(); i++)
+ {
+ if (!as<IRTupleType>(expandInst->getCaptureType(i)))
+ {
+ anyAbstractPack = true;
+ break;
+ }
+ }
+ if (anyAbstractPack)
+ return false;
+ IRBuilder builder(expandInst);
+ builder.setInsertBefore(expandInst);
+ List<IRInst*> elements;
+ UInt elementCount = 0;
+ if (auto firstTupleType = as<IRTupleType>(expandInst->getCaptureType(0)))
+ {
+ elementCount = firstTupleType->getOperandCount();
+ }
+ for (UInt i = 0; i < elementCount; i++)
+ {
+ IRCloneEnv cloneEnv;
+ auto element = clonePatternVal(cloneEnv, &builder, expandInst->getPatternType(), i);
+ elements.add(element);
+ }
+ addUsersToWorkList(expandInst);
+ if (as<IRWitnessTableType>(expandInst->getDataType()))
+ {
+ List<IRType*> types;
+ for (auto element : elements)
+ types.add(element->getDataType());
+ auto newTupleType = builder.getTupleType(types);
+ auto result = builder.emitMakeWitnessPack(newTupleType, elements.getArrayView());
+ expandInst->replaceUsesWith(result);
+ expandInst->removeAndDeallocate();
+ return true;
+ }
+ else
+ {
+ auto newTupleType = builder.getTupleType(elements.getCount(), (IRType*const*)elements.getBuffer());
+ expandInst->replaceUsesWith(newTupleType);
+ expandInst->removeAndDeallocate();
+ return true;
+ }
+ }
+
// The handling of specialization for global generic type
// parameters involves searching for all `bind_global_generic_param`
// instructions in the input module.
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index 788c9a391..e44c4079b 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -146,7 +146,8 @@ bool allUsesLeadToLoads(IRInst* inst)
// Is the given variable one that we can promote to SSA form?
bool isPromotableVar(
ConstructSSAContext* /*context*/,
- IRVar* var)
+ IRVar* var,
+ HashSet<IRBlock*> &knownBlocks)
{
// We want to identify variables such that we can always
// determine what they will contain at a point in the
@@ -226,8 +227,13 @@ bool isPromotableVar(
}
break;
}
+
+ // If the use is outside of known blocks, then we can't promote it.
+ if (!knownBlocks.contains(getBlock(user)))
+ return false;
}
+
// If all of the uses passed our checking, then
// we are good to go.
return true;
@@ -237,6 +243,12 @@ bool isPromotableVar(
void identifyPromotableVars(
ConstructSSAContext* context)
{
+ HashSet<IRBlock*> knownBlocks;
+ for (auto bb = context->globalVal->getFirstBlock(); bb; bb = bb->getNextBlock())
+ {
+ knownBlocks.add(bb);
+ }
+
for (auto bb = context->globalVal->getFirstBlock(); bb; bb = bb->getNextBlock())
{
for (auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst())
@@ -246,7 +258,7 @@ void identifyPromotableVars(
IRVar* var = (IRVar*)ii;
- if (isPromotableVar(context, var))
+ if (isPromotableVar(context, var, knownBlocks))
{
context->promotableVars.add(var);
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index bfd6c20cf..c97c04f88 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -808,12 +808,7 @@ namespace Slang
IRType* IRFunc::getResultType() { return getDataType()->getResultType(); }
UInt IRFunc::getParamCount() { return getDataType()->getParamCount(); }
IRType* IRFunc::getParamType(UInt index) { return getDataType()->getParamType(index); }
-
- void IRGlobalValueWithCode::addBlock(IRBlock* block)
- {
- block->insertAtEnd(this);
- }
-
+
void fixUpFuncType(IRFunc* func, IRType* resultType)
{
SLANG_ASSERT(func);
@@ -1279,14 +1274,16 @@ namespace Slang
// Get the current function (or other value with code)
// that we are inserting into (if any).
- IRGlobalValueWithCode* IRInsertLoc::getFunc() const
+ IRInst* IRInsertLoc::getFunc() const
{
auto pp = getParent();
if (const auto block = as<IRBlock>(pp))
{
pp = pp->getParent();
}
- return as<IRGlobalValueWithCode>(pp);
+ if (as<IRGlobalValueWithCode>(pp) || as<IRExpand>(pp))
+ return pp;
+ return nullptr;
}
void addHoistableInst(
@@ -2805,6 +2802,14 @@ namespace Slang
return getTupleType(SLANG_COUNT_OF(operands), operands);
}
+ IRExpandType* IRBuilder::getExpandTypeOrVal(IRType* type, IRInst* pattern, ArrayView<IRInst*> capture)
+ {
+ ShortList<IRInst*> args;
+ args.add(pattern);
+ args.addRange(capture);
+ return (IRExpandType*)emitIntrinsicInst(type, kIROp_ExpandTypeOrVal, args.getCount(), args.getArrayView().getBuffer());
+ }
+
IRResultType* IRBuilder::getResultType(IRType* valueType, IRType* errorType)
{
IRInst* operands[] = {valueType, errorType};
@@ -3548,6 +3553,26 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitExpandInst(IRType* type, UInt capturedArgCount, IRInst* const* capturedArgs)
+ {
+ auto inst = createInstWithTrailingArgs<IRSpecialize>(
+ this,
+ kIROp_Expand,
+ type,
+ capturedArgCount,
+ capturedArgs,
+ 0,
+ nullptr);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitEachInst(IRType* type, IRInst* base, IRInst* indexArg)
+ {
+ IRInst* args[] = { base, indexArg };
+ return emitIntrinsicInst(type, kIROp_Each, indexArg ? 2 : 1, args);
+ }
+
IRInst* IRBuilder::emitLookupInterfaceMethodInst(
IRType* type,
IRInst* witnessTableVal,
@@ -4057,6 +4082,12 @@ namespace Slang
return emitIntrinsicInst(getNativeStringType(), kIROp_getNativeStr, 1, &str);
}
+ IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, IRInst* element)
+ {
+ IRInst* args[] = { tuple, element };
+ return emitIntrinsicInst(type, kIROp_GetTupleElement, 2, args);
+ }
+
IRInst* IRBuilder::emitGetTupleElement(IRType* type, IRInst* tuple, UInt element)
{
// As a quick simplification/optimization, if the user requests
@@ -4070,9 +4101,7 @@ namespace Slang
return makeTuple->getOperand(element);
}
}
-
- IRInst* args[] = { tuple, getIntValue(getIntType(), element) };
- return emitIntrinsicInst(type, kIROp_GetTupleElement, 2, args);
+ return emitGetTupleElement(type, tuple, getIntValue(getIntType(), element));
}
IRInst* IRBuilder::emitMakeResultError(IRType* resultType, IRInst* errorVal)
@@ -5409,6 +5438,18 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitYield(
+ IRInst* val)
+ {
+ auto inst = createInst<IRYield>(
+ this,
+ kIROp_Yield,
+ nullptr,
+ val);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitReturn()
{
auto voidVal = getVoidValue();
@@ -7238,6 +7279,7 @@ namespace Slang
case kIROp_Func:
case kIROp_GlobalVar:
case kIROp_Generic:
+ case kIROp_Expand:
dumpIRGlobalValueWithCode(context, (IRGlobalValueWithCode*)inst);
return;
@@ -8230,6 +8272,7 @@ namespace Slang
case kIROp_WitnessTableEntry:
case kIROp_InterfaceRequirementEntry:
case kIROp_Block:
+ case kIROp_Each:
return false;
/// Liveness markers have no side effects
@@ -8250,6 +8293,7 @@ namespace Slang
case kIROp_MakeMatrixFromScalar:
case kIROp_MatrixReshape:
case kIROp_VectorReshape:
+ case kIROp_MakeWitnessPack:
case kIROp_MakeArray:
case kIROp_MakeArrayFromElement:
case kIROp_MakeStruct:
@@ -8806,6 +8850,11 @@ namespace Slang
}
}
+ void IRInst::addBlock(IRBlock* block)
+ {
+ block->insertAtEnd(this);
+ }
+
void IRInst::dump()
{
if (auto intLit = as<IRIntLit>(this))
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 719b383c3..ececdad43 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -534,7 +534,7 @@ public:
/// This searches up the parent chain starting with `getParent()` looking for a code-bearing
/// value that things are being inserted into (could be a function, generic, etc.)
///
- IRGlobalValueWithCode* getFunc() const;
+ IRInst* getFunc() const;
private:
/// Internal constructor
@@ -567,6 +567,8 @@ enum class IRTypeLayoutRuleName
_Count,
};
+struct IRBlock;
+
// Every value in the IR is an instruction (even things
// like literal values).
//
@@ -833,6 +835,13 @@ struct IRInst
/// Print the IR to stdout for debugging purposes
///
void dump();
+
+ /// Insert a basic block at the end of this func/code containing inst.
+ void addBlock(IRBlock* block);
+
+ IRBlock* getFirstBlock() { return (IRBlock*)getFirstChild(); }
+ IRBlock* getLastBlock() { return (IRBlock*)getLastChild(); }
+
};
enum class IRDynamicCastBehavior
@@ -1291,11 +1300,6 @@ struct IRBlock : IRInst
getLastOrdinaryInst());
}
- // The parent of a basic block is assumed to be a
- // value with code (e.g., a function, global variable
- // with initializer, etc.).
- IRGlobalValueWithCode* getParent() { return cast<IRGlobalValueWithCode>(IRInst::getParent()); }
-
// The predecessor and successor lists of a block are needed
// when we want to work with the control flow graph (CFG) of
// a function. Rather than store these explicitly (and thus
@@ -1620,6 +1624,7 @@ struct IRRateQualifiedType : IRType
// same type.
SIMPLE_IR_PARENT_TYPE(Kind, Type);
SIMPLE_IR_TYPE(TypeKind, Kind);
+SIMPLE_IR_TYPE(TypeParameterPackKind, Kind);
// The kind of any and all generics.
//
@@ -1941,6 +1946,16 @@ struct IRTargetTupleType : IRType
IR_LEAF_ISA(TargetTupleType)
};
+/// Represents a `expand T` type used in variadic generic decls in Slang. Expected to be substituted
+/// by actual types during specialization.
+struct IRExpandType : IRType
+{
+ IR_LEAF_ISA(ExpandTypeOrVal)
+ IRType* getPatternType() { return (IRType*)(getOperand(0)); }
+ UInt getCaptureCount() { return getOperandCount() - 1; }
+ IRType* getCaptureType(UInt index) { return (IRType*)(getOperand(index + 1)); }
+};
+
/// Represents an `Result<T,E>`, used by functions that throws error codes.
struct IRResultType : IRType
{
@@ -2040,9 +2055,6 @@ struct IRGlobalValueWithCode : IRInst
return IRInstList<IRBlock>(getChildren());
}
- // Add a block to the end of this function.
- void addBlock(IRBlock* block);
-
IR_PARENT_ISA(GlobalValueWithCode)
};
diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp
index 5e8390cef..6009ef33e 100644
--- a/source/slang/slang-legalize-types.cpp
+++ b/source/slang/slang-legalize-types.cpp
@@ -287,6 +287,11 @@ struct TupleTypeBuilder
{
specialType = legalFieldType;
}
+
+ // `void` is currently legalized to simple, but we don't want to add a
+ // `void` field to the struct.
+ if (legalLeafType.getSimple()->getOp() == kIROp_VoidType)
+ return;
}
break;
@@ -419,7 +424,6 @@ struct TupleTypeBuilder
bool isSpecialField = context->isSpecialType(fieldType);
auto legalFieldType = legalizeType(context, fieldType);
-
addField(
field->getKey(),
legalFieldType,
@@ -1385,10 +1389,15 @@ LegalType legalizeTypeImpl(
context,
arrayType->getElementType());
- // If element type hasn't change, return original type.
- if (legalElementType.flavor == LegalType::Flavor::simple &&
- legalElementType.getSimple() == arrayType->getElementType())
- return LegalType::simple(arrayType);
+ if (legalElementType.flavor == LegalType::Flavor::simple)
+ {
+ if (legalElementType.getSimple()->getOp() == kIROp_VoidType)
+ return LegalType();
+
+ // If element type hasn't change, return original type.
+ if (legalElementType.getSimple() == arrayType->getElementType())
+ return LegalType::simple(arrayType);
+ }
ArrayLegalTypeWrapper wrapper;
wrapper.arrayType = arrayType;
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 95e9d96da..9ceb3074a 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -594,6 +594,9 @@ struct IRGenContext
bool includeDebugInfo = false;
+ // The element index if we are inside an `expand` expression.
+ IRInst* expandIndex = nullptr;
+
explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder)
: shared(inShared)
, astBuilder(inAstBuilder)
@@ -1653,6 +1656,86 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(resultVal);
}
+ LoweredValInfo visitConcreteTypePack(ConcreteTypePack* typePack)
+ {
+ ShortList<IRType*> types;
+ for (Index i = 0; i < typePack->getTypeCount(); i++)
+ {
+ auto loweredType = lowerType(context, typePack->getElementType(i));
+ types.add(loweredType);
+ }
+ auto irBuilder = getBuilder();
+ IRType* irTypePack = irBuilder->getTupleType((UInt)types.getCount(), types.getArrayView().getBuffer());
+ return LoweredValInfo::simple(irTypePack);
+ }
+
+ LoweredValInfo visitEachType(EachType* eachType)
+ {
+ auto type = lowerType(context, eachType->getElementType());
+ return LoweredValInfo::simple(getBuilder()->emitEachInst(
+ getBuilder()->getTypeKind(),
+ type));
+ }
+
+ LoweredValInfo visitExpandType(ExpandType* expandType)
+ {
+ auto irBuilder = getBuilder();
+ auto type = lowerType(context, expandType->getPatternType());
+ ShortList<IRInst*> capturedTypes;
+ for (Index i = 0; i < expandType->getCapturedTypePackCount(); i++)
+ {
+ auto loweredType = lowerType(context, expandType->getCapturedTypePack(i));
+ capturedTypes.add(loweredType);
+ }
+ return LoweredValInfo::simple(irBuilder->getExpandTypeOrVal(
+ irBuilder->getTypeKind(), type, capturedTypes.getArrayView().arrayView));
+ }
+
+ LoweredValInfo visitTypePackSubtypeWitness(TypePackSubtypeWitness* witnessPack)
+ {
+ auto irBuilder = getBuilder();
+ ShortList<IRInst*> witnesses;
+ ShortList<IRType*> elementTypes;
+ for (Index i = 0; i < witnessPack->getCount(); i++)
+ {
+ auto loweredWitness = lowerVal(context, witnessPack->getWitness(i));
+ witnesses.add(loweredWitness.val);
+ elementTypes.add(loweredWitness.val->getFullType());
+ }
+ auto irWitnessPack = irBuilder->emitMakeWitnessPack(
+ irBuilder->getTupleType((UInt)elementTypes.getCount(), elementTypes.getArrayView().getBuffer()),
+ witnesses.getArrayView().arrayView);
+ return LoweredValInfo::simple(irWitnessPack);
+ }
+
+ LoweredValInfo visitExpandSubtypeWitness(ExpandSubtypeWitness* witness)
+ {
+ auto irBuilder = getBuilder();
+
+ auto patternWitnessVal = lowerVal(context, witness->getPatternTypeWitness());
+ auto subType = lowerType(context, witness->getSub());
+ auto supType = lowerType(context, witness->getSup());
+ auto witnessTableType = irBuilder->getWitnessTableType(supType);
+ ShortList<IRInst*> captures;
+ if (auto expandType = as<IRExpandType>(subType))
+ {
+ for (UInt i = 0; i < expandType->getCaptureCount(); i++)
+ {
+ captures.add(expandType->getCaptureType(i));
+ }
+ }
+ return LoweredValInfo::simple(irBuilder->getExpandTypeOrVal(witnessTableType, patternWitnessVal.val, captures.getArrayView().arrayView));
+ }
+
+ LoweredValInfo visitEachSubtypeWitness(EachSubtypeWitness* witness)
+ {
+ auto elementWitness = lowerVal(context, witness->getPatternTypeWitness());
+ auto irBuilder = getBuilder();
+ auto subType = lowerType(context, witness->getSub());
+ auto witnessTableType = irBuilder->getWitnessTableType(subType);
+ return LoweredValInfo::simple(irBuilder->emitEachInst(witnessTableType, getSimpleVal(context, elementWitness)));
+ }
+
LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val)
{
if (as<ThisTypeConstraintDecl>(val->getDeclRef()))
@@ -1885,6 +1968,23 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
context->irBuilder->getTypeKind()));
}
+ IRType* visitTupleType(TupleType* type)
+ {
+ List<IRType*> elementTypes;
+ if (as<ConcreteTypePack>(type->getTypePack()))
+ {
+ for (Index i = 0; i < type->getMemberCount(); i++)
+ {
+ elementTypes.add(lowerType(context, type->getMember(i)));
+ }
+ return context->irBuilder->getTupleType(elementTypes);
+ }
+ else
+ {
+ return lowerType(context, type->getTypePack());
+ }
+ }
+
IRType* visitNamedExpressionType(NamedExpressionType* type)
{
return (IRType*)getSimpleVal(context, dispatchType(type->getCanonicalType()));
@@ -4315,19 +4415,54 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
return lowerSubExpr(expr->base);
}
- LoweredValInfo visitPackExpr(PackExpr*)
+ LoweredValInfo visitPackExpr(PackExpr* expr)
{
- SLANG_UNIMPLEMENTED_X("codegen for pack expression");
+ List<IRInst*> irArgs;
+ for (auto arg : expr->args)
+ {
+ irArgs.add(getSimpleVal(context, lowerSubExpr(arg)));
+ }
+ auto irMakeTuple = getBuilder()->emitMakeTuple(irArgs);
+ return LoweredValInfo::simple(irMakeTuple);
}
- LoweredValInfo visitEachExpr(EachExpr*)
+ LoweredValInfo visitEachExpr(EachExpr* expr)
{
- SLANG_UNIMPLEMENTED_X("codegen for each expression");
+ auto subVal = lowerSubExpr(expr->baseExpr);
+ SLANG_ASSERT(context->expandIndex);
+ auto irEach = getBuilder()->emitGetTupleElement(lowerType(context, expr->type), getSimpleVal(context, subVal), context->expandIndex);
+ return LoweredValInfo::simple(irEach);
}
- LoweredValInfo visitExpandExpr(ExpandExpr*)
+ LoweredValInfo visitExpandExpr(ExpandExpr* expr)
{
- SLANG_UNIMPLEMENTED_X("codegen for expand expression");
+ auto irBuilder = getBuilder();
+ auto irType = lowerType(context, expr->type);
+ List<IRInst*> irCapturedPacks;
+ if (auto expandType = as<IRExpandType>(irType))
+ {
+ for (UInt i = 0; i < expandType->getCaptureCount(); i++)
+ {
+ irCapturedPacks.add(expandType->getCaptureType(i));
+ }
+ }
+ else
+ {
+ // If the type of the expression is not an ExpandType, then it must be
+ // a DeclRefType to a generic type pack parameter.
+ // In this case, the captured type is just the DeclRefType itself.
+ irCapturedPacks.add(irType);
+ }
+ auto expandInst = irBuilder->emitExpandInst(irType, (UInt)irCapturedPacks.getCount(), irCapturedPacks.getBuffer());
+ irBuilder->setInsertInto(expandInst);
+ irBuilder->emitBlock();
+ auto eachIndex = irBuilder->emitParam(irBuilder->getIntType());
+ IRInst* oldExpandIndex = context->expandIndex;
+ context->expandIndex = eachIndex;
+ SLANG_DEFER(context->expandIndex = oldExpandIndex);
+ irBuilder->emitYield(getSimpleVal(context, lowerSubExpr(expr->baseExpr)));
+ irBuilder->setInsertAfter(expandInst);
+ return LoweredValInfo::simple(expandInst);
}
LoweredValInfo getSimpleDefaultVal(IRType* type)
@@ -8968,11 +9103,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// in the order they were declared.
for (auto member : genericDecl->members)
{
- if (auto typeParamDecl = as<GenericTypeParamDecl>(member))
+ if (auto typeParamDecl = as<GenericTypeParamDeclBase>(member))
{
- // TODO: use a `TypeKind` to represent the
- // classifier of the parameter.
- auto param = subBuilder->emitParam(subBuilder->getTypeType());
+ IRType* typeKind = nullptr;
+ if (as<GenericTypePackParamDecl>(member))
+ typeKind = subBuilder->getTypeParameterPackKind();
+ else
+ typeKind = subBuilder->getTypeType();
+ auto param = subBuilder->emitParam(typeKind);
addNameHint(context, param, typeParamDecl);
subContext->setValue(typeParamDecl, LoweredValInfo::simple(param));
}
@@ -10289,7 +10427,15 @@ LoweredValInfo ensureDecl(
}
IRBuilder subIRBuilder(context->irBuilder->getModule());
- subIRBuilder.setInsertInto(subIRBuilder.getModule());
+ if (as<VarDecl>(decl) && decl->findModifier<LocalTempVarModifier>())
+ {
+ // Do not modify insert location.
+ subIRBuilder.setInsertLoc(context->irBuilder->getInsertLoc());
+ }
+ else
+ {
+ subIRBuilder.setInsertInto(subIRBuilder.getModule());
+ }
IRGenEnv subEnv;
subEnv.outer = context->env;
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 7fa9e8fc0..a55f0eb1a 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -507,7 +507,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
// TODO: need to figure out how to unify this with the logic
// in the generic case...
Type* DeclRefType::create(
- ASTBuilder* astBuilder,
+ ASTBuilder* astBuilder,
DeclRef<Decl> declRef)
{
if (declRef.getDecl()->findModifier<BuiltinTypeModifier>())