summaryrefslogtreecommitdiff
path: root/source/slang/lower-to-ir.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/lower-to-ir.cpp')
-rw-r--r--source/slang/lower-to-ir.cpp320
1 files changed, 188 insertions, 132 deletions
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 06ad66bc4..7fce0c385 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -79,7 +79,7 @@ struct SubscriptInfo : ExtendedValueInfo
struct BoundSubscriptInfo : ExtendedValueInfo
{
DeclRef<SubscriptDecl> declRef;
- IRType* type;
+ RefPtr<Type> type;
List<IRValue*> args;
UInt genericArgCount;
};
@@ -218,8 +218,8 @@ struct BoundMemberInfo : ExtendedValueInfo
//
struct SwizzledLValueInfo : ExtendedValueInfo
{
- // IR-level The type of the expression.
- IRType* type;
+ // The type of the expression.
+ RefPtr<Type> type;
// The base expression (this should be an l-value)
LoweredValInfo base;
@@ -355,7 +355,7 @@ LoweredValInfo emitCompoundAssignOp(
auto leftVal = builder->emitLoad(leftPtr);
- IRInst* innerArgs[] = { leftVal, rightVal };
+ IRValue* innerArgs[] = { leftVal, rightVal };
auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
builder->emitStore(leftPtr, innerOp);
@@ -363,23 +363,31 @@ LoweredValInfo emitCompoundAssignOp(
return LoweredValInfo::ptr(leftPtr);
}
-IRInst* getOneValOfType(
+IRValue* getOneValOfType(
IRGenContext* context,
IRType* type)
{
- switch(type->op)
+ if (auto basicType = dynamic_cast<BasicExpressionType*>(type))
{
- case kIROp_Int32Type:
- case kIROp_UInt32Type:
- return context->irBuilder->getIntValue(type, 1);
+ switch (basicType->baseType)
+ {
+ case BaseType::Int:
+ case BaseType::UInt:
+ case BaseType::UInt64:
+ return context->irBuilder->getIntValue(type, 1);
- case kIROp_Float32Type:
- return context->irBuilder->getFloatValue(type, 1.0);
+ case BaseType::Float:
+ case BaseType::Double:
+ return context->irBuilder->getFloatValue(type, 1.0);
- default:
- SLANG_UNEXPECTED("inc/dec type");
- return nullptr;
+ default:
+ break;
+ }
}
+ // TODO: should make sure to handle vector and matrix types here
+
+ SLANG_UNEXPECTED("inc/dec type");
+ return nullptr;
}
LoweredValInfo emitPreOp(
@@ -396,9 +404,9 @@ LoweredValInfo emitPreOp(
auto preVal = builder->emitLoad(argPtr);
- IRInst* oneVal = getOneValOfType(context, type);
+ IRValue* oneVal = getOneValOfType(context, type);
- IRInst* innerArgs[] = { preVal, oneVal };
+ IRValue* innerArgs[] = { preVal, oneVal };
auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
builder->emitStore(argPtr, innerOp);
@@ -420,9 +428,9 @@ LoweredValInfo emitPostOp(
auto preVal = builder->emitLoad(argPtr);
- IRInst* oneVal = getOneValOfType(context, type);
+ IRValue* oneVal = getOneValOfType(context, type);
- IRInst* innerArgs[] = { preVal, oneVal };
+ IRValue* innerArgs[] = { preVal, oneVal };
auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
builder->emitStore(argPtr, innerOp);
@@ -647,10 +655,7 @@ struct LoweredTypeInfo
Simple,
};
- union
- {
- IRType* type;
- };
+ RefPtr<IRType> type;
Flavor flavor;
LoweredTypeInfo()
@@ -743,6 +748,35 @@ LoweredValInfo lowerDecl(
DeclBase* decl,
Layout* layout);
+IRType* getIntType(
+ IRGenContext* context)
+{
+ return context->getSession()->getBuiltinType(BaseType::Int);
+}
+
+// Get a pointer type to the given element type
+RefPtr<PtrType> getPtrType(
+ IRGenContext* context,
+ IRType* valueType)
+{
+ return context->getSession()->getPtrType(valueType);
+}
+
+RefPtr<IRFuncType> getFuncType(
+ IRGenContext* context,
+ UInt paramCount,
+ RefPtr<IRType> const* paramTypes,
+ IRType* resultType)
+{
+ RefPtr<FuncType> funcType = new FuncType();
+ funcType->resultType = resultType;
+ for (UInt pp = 0; pp < paramCount; ++pp)
+ {
+ funcType->paramTypes.Add(paramTypes[pp]);
+ }
+ return funcType;
+}
+
//
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredTypeInfo>
@@ -761,7 +795,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// TODO: it is a bit messy here that the `ConstantIntVal` representation
// has no notion of a *type* associated with the value...
- auto type = getBuilder()->getBaseType(BaseType::Int);
+ auto type = getIntType(context);
return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value));
}
@@ -772,16 +806,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredTypeInfo visitFuncType(FuncType* type)
{
- LoweredValInfo loweredFunc = ensureDecl(context, type->declRef);
- auto loweredFuncVal = getSimpleVal(context, loweredFunc);
-
- // HACK: deal with the case where the decl might not
- // lower to anything, and so we don't have a type to
- // work with.
- if (!loweredFuncVal)
- return LoweredTypeInfo();
-
- return loweredFuncVal->getType();
+ return LoweredTypeInfo(type);
}
void addGenericArgs(List<IRValue*>* ioArgs, DeclRefBase declRef)
@@ -799,12 +824,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredTypeInfo visitDeclRefType(DeclRefType* type)
{
+#if 1
+ // TODO: is there actually anything to be done at this point?
+ return LoweredTypeInfo(type);
+#else
// We need to detect builtin/intrinsic types here, since they should map to custom modifiers
// We need to catch builtin/intrinsic types here
if( auto intrinsicTypeMod = type->declRef.getDecl()->FindModifier<IntrinsicTypeModifier>() )
{
auto builder = getBuilder();
- auto intType = builder->getBaseType(BaseType::Int);
+ auto intType = getIntType(context);
//
List<IRValue*> irArgs;
for( auto val : intrinsicTypeMod->irOperands )
@@ -831,61 +860,32 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
default:
SLANG_UNIMPLEMENTED_X("type lowering");
}
-
+#endif
}
LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type)
{
- return getBuilder()->getBaseType(type->baseType);
+ return LoweredTypeInfo(type);
}
LoweredTypeInfo visitVectorExpressionType(VectorExpressionType* type)
{
- auto irElementType = lowerSimpleType(context, type->elementType);
- auto irElementCount = lowerSimpleVal(context, type->elementCount);
-
- return getBuilder()->getVectorType(irElementType, irElementCount);
+ return LoweredTypeInfo(type);
}
LoweredTypeInfo visitMatrixExpressionType(MatrixExpressionType* type)
{
- auto irElementType = lowerSimpleType(context, type->getElementType());
- auto irRowCount = lowerSimpleVal(context, type->getRowCount());
- auto irColumnCount = lowerSimpleVal(context, type->getColumnCount());
-
- return getBuilder()->getMatrixType(irElementType, irRowCount, irColumnCount);
+ return LoweredTypeInfo(type);
}
- LoweredTypeInfo getArrayType(
- LoweredTypeInfo const& loweredElementType,
- IRValue* irElementCount)
+ LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type)
{
- switch (loweredElementType.flavor)
- {
- case LoweredTypeInfo::Flavor::Simple:
- return getBuilder()->getArrayType(
- loweredElementType.type,
- irElementCount);
- break;
-
- default:
- SLANG_UNEXPECTED("array element type");
- break;
- }
+ return LoweredTypeInfo(type);
}
- LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type)
+ LoweredTypeInfo visitIRBasicBlockType(IRBasicBlockType* type)
{
- auto loweredElementType = lowerType(context, type->baseType);
- if (auto elementCount = type->ArrayLength)
- {
- auto irElementCount = lowerSimpleVal(context, elementCount);
- return getArrayType(loweredElementType, irElementCount);
- }
- else
- {
- return getArrayType(loweredElementType, nullptr);
- }
+ return LoweredTypeInfo(type);
}
};
@@ -1017,9 +1017,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
if (auto fieldDeclRef = declRef.As<StructField>())
{
// Okay, easy enough: we have a reference to a field of a struct type...
-
- auto loweredField = ensureDecl(context, fieldDeclRef);
- return extractField(loweredType, loweredBase, loweredField);
+ return extractField(loweredType, loweredBase, fieldDeclRef);
}
else if (auto callableDeclRef = declRef.As<CallableDecl>())
{
@@ -1045,14 +1043,12 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// in order for a dereference to make senese, so we just
// need to extract the value type from that pointer here.
//
- auto loweredBaseVal = getSimpleVal(context, loweredBase);
- auto loweredBaseType = loweredBaseVal->getType();
- switch( loweredBaseType->op )
- {
- case kIROp_PtrType:
- // TODO: should we enumerate these explicitly?
- case kIROp_ConstantBufferType:
- case kIROp_TextureBufferType:
+ IRValue* loweredBaseVal = getSimpleVal(context, loweredBase);
+ RefPtr<Type> loweredBaseType = loweredBaseVal->getType();
+
+ if (loweredBaseType->As<PointerLikeType>()
+ || loweredBaseType->As<PtrType>())
+ {
// Note that we do *not* perform an actual `load` operation
// here, but rather just use the pointer value to construct
// an appropriate `LoweredValInfo` representing the underlying
@@ -1064,8 +1060,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// and is just a bit of pointer math.
//
return LoweredValInfo::ptr(loweredBaseVal);
-
- default:
+ }
+ else
+ {
SLANG_UNIMPLEMENTED_X("codegen for deref expression");
return LoweredValInfo();
}
@@ -1472,7 +1469,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
case LoweredValInfo::Flavor::Ptr:
return LoweredValInfo::ptr(
builder->emitElementAddress(
- builder->getPtrType(getSimpleType(type)),
+ getPtrType(context, getSimpleType(type)),
baseVal.val,
indexVal));
@@ -1484,9 +1481,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
LoweredValInfo extractField(
- LoweredTypeInfo fieldType,
- LoweredValInfo base,
- LoweredValInfo field)
+ LoweredTypeInfo fieldType,
+ LoweredValInfo base,
+ DeclRef<StructField> field)
{
switch (base.flavor)
{
@@ -1497,7 +1494,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
getBuilder()->emitFieldExtract(
getSimpleType(fieldType),
irBase,
- (IRStructField*) getSimpleVal(context, field)));
+ getBuilder()->getDeclRefVal(field)));
}
break;
@@ -1509,9 +1506,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
IRValue* irBasePtr = base.val;
return LoweredValInfo::ptr(
getBuilder()->emitFieldAddress(
- getBuilder()->getPtrType(getSimpleType(fieldType)),
+ getPtrType(context, getSimpleType(fieldType)),
irBasePtr,
- (IRStructField*) getSimpleVal(context, field)));
+ getBuilder()->getDeclRefVal(field)));
}
break;
}
@@ -1598,7 +1595,7 @@ struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVis
auto builder = getBuilder();
- auto irIntType = builder->getBaseType(BaseType::Int);
+ auto irIntType = getIntType(context);
UInt elementCount = (UInt)expr->elementCount;
IRValue* irElementIndices[4];
@@ -1661,34 +1658,22 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
void insertBlock(IRBlock* block)
{
auto builder = getBuilder();
- auto parent = builder->parentInst;
- IRBlock* prevBlock = nullptr;
- IRFunc* parentFunc = nullptr;
-
- switch (parent->op)
- {
- case kIROp_Block:
- prevBlock = (IRBlock*)parent;
- parentFunc = prevBlock->getParent();
- break;
-
- default:
- SLANG_UNEXPECTED("bad parent kind for block");
- return;
- }
+ auto prevBlock = builder->block;
+ auto parentFunc = prevBlock->parentFunc;
// If the previous block doesn't already have
// a terminator instruction, then be sure to
// emit a branch to the new block.
- if (!isTerminatorInst(prevBlock->lastChild))
+ if (!isTerminatorInst(prevBlock->lastInst))
{
builder->emitBranch(block);
}
- builder->parentInst = parentFunc;
- builder->addInst(block);
- builder->parentInst = block;
+ parentFunc->addBlock(block);
+
+ builder->func = parentFunc;
+ builder->block = block;
}
// Start a new block at the current location.
@@ -2025,8 +2010,8 @@ top:
auto loweredBase = swizzleInfo->base;
// Load from the base value:
- IRInst* irLeftVal = getSimpleVal(context, loweredBase);
- auto irRightVal = getSimpleVal(context, right);
+ IRValue* irLeftVal = getSimpleVal(context, loweredBase);
+ IRValue* irRightVal = getSimpleVal(context, right);
// Now apply the swizzle
IRInst* irSwizzled = builder->emitSwizzleSet(
@@ -2066,7 +2051,7 @@ top:
emitCallToDeclRef(
context,
- builder->getVoidType(),
+ context->getSession()->getVoidType(),
setterDeclRef,
allArgs,
subscriptInfo->genericArgCount);
@@ -2153,8 +2138,67 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo();
}
+ bool isGlobalVarDecl(VarDeclBase* decl)
+ {
+ auto parent = decl->ParentDecl;
+ if (dynamic_cast<ModuleDecl*>(parent))
+ {
+ // Variable declared at global scope? -> Global.
+ return true;
+ }
+
+ return false;
+ }
+
+ LoweredValInfo lowerGlobalVarDecl(VarDeclBase* decl)
+ {
+ auto varType = lowerSimpleType(context, decl->getType());
+
+ IRAddressSpace addressSpace = kIRAddressSpace_Default;
+ if (decl->HasModifier<HLSLGroupSharedModifier>())
+ {
+ addressSpace = kIRAddressSpace_GroupShared;
+ }
+
+ auto builder = getBuilder();
+ auto irGlobal = builder->createGlobalVar(varType);
+
+ if (decl)
+ {
+ builder->addHighLevelDeclDecoration(irGlobal, decl);
+ }
+
+ if (auto layout = getLayout())
+ {
+ builder->addLayoutDecoration(irGlobal, layout);
+ }
+
+ // A global variable's SSA value is a *pointer* to
+ // the underlying storage.
+ auto globalVal = LoweredValInfo::ptr(irGlobal);
+ context->shared->declValues.Add(
+ DeclRef<VarDeclBase>(decl, nullptr),
+ globalVal);
+
+ if( auto initExpr = decl->initExpr )
+ {
+ // TODO: need to handle global with initializer!
+ }
+
+ getBuilder()->getModule()->globalValues.Add(irGlobal);
+
+ return globalVal;
+ }
+
LoweredValInfo visitVarDeclBase(VarDeclBase* decl)
{
+ // Detect global (or effectively global) variables
+ // and handle them differently.
+ if (isGlobalVarDecl(decl))
+ {
+ return lowerGlobalVarDecl(decl);
+ }
+
// A user-defined variable declaration will usually turn into
// an `alloca` operation for the variable's storage,
// plus some code to initialize it and then store to the variable.
@@ -2219,6 +2263,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl)
{
+#if 0
// User-defined aggregate type: need to translate into
// a corresponding IR aggregate type.
@@ -2257,6 +2302,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
builder->addInst(irStruct);
return LoweredValInfo::simple(irStruct);
+#else
+ // TODO: What is there to do with a `struct` type?
+ return LoweredValInfo();
+#endif
}
// Sometimes we need to refer to a declaration the way that it would be specialized
@@ -2592,7 +2641,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
void trySetMangledName(
- IRInst* inst,
+ IRFunc* irFunc,
Decl* decl)
{
// We want to generate a mangled name for the given declaration and attach
@@ -2605,8 +2654,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
String mangledName = getMangledName(decl);
- auto decoration = getBuilder()->addDecoration<IRMangledNameDecoration>(inst);
- decoration->mangledName = mangledName;
+ irFunc->mangledName = mangledName;
}
@@ -2649,11 +2697,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// need to create an IR function here
IRFunc* irFunc = subBuilder->createFunc();
- subBuilder->parentInst = irFunc;
+ subBuilder->func = irFunc;
trySetMangledName(irFunc, decl);
- List<IRType*> paramTypes;
+ List<RefPtr<Type>> paramTypes;
// We first need to walk the generic parameters (if any)
// because these will influence the declared type of
@@ -2662,6 +2710,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
for( auto genericParamDecl : parameterLists.genericParams )
{
UInt genericParamIndex = genericParamCounter++;
+#if 0
if( auto genericTypeParamDecl = dynamic_cast<GenericTypeParamDecl*>(genericParamDecl) )
{
// In the logical type for the function, a generic
@@ -2675,10 +2724,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// to the appropriate generic parameter position.
IRType* irParameterType = context->irBuilder->getGenericParameterType(genericParamIndex);
- LoweredValInfo LoweredValInfo = LoweredValInfo::simple(irParameterType);
+ LoweredValInfo LoweredValInfo = LoweredValInfo::type(irParameterType);
subContext->shared->declValues[makeDeclRef(genericTypeParamDecl)] = LoweredValInfo;
}
else
+#endif
{
// TODO: handle the other cases here.
SLANG_UNEXPECTED("generic parameter kind");
@@ -2702,7 +2752,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
// TODO: Is this the best representation we can use?
- auto irPtrType = subBuilder->getPtrType(irParamType);
+ auto irPtrType = getPtrType(context, irParamType);
paramTypes.Add(irPtrType);
}
}
@@ -2726,14 +2776,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Instead, a setter always returns `void`
//
- irResultType = getBuilder()->getVoidType();
+ irResultType = context->getSession()->getVoidType();
}
- auto irFuncType = getBuilder()->getFuncType(
+ auto irFuncType = getFuncType(
+ context,
paramTypes.Count(),
paramTypes.Buffer(),
irResultType);
- irFunc->type.init(irFunc, irFuncType);
+ irFunc->type = irFuncType;
if (!decl->Body)
{
@@ -2753,7 +2804,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// This is a function definition, so we need to actually
// construct IR for the body...
IRBlock* entryBlock = subBuilder->emitBlock();
- subBuilder->parentInst = entryBlock;
+ subBuilder->block = entryBlock;
UInt paramTypeIndex = 0;
for( auto paramInfo : parameterLists.params )
@@ -2771,7 +2822,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
// TODO: Is this the best representation we can use?
- auto irPtrType = (IRPtrType*)irParamType;
+ auto irPtrType = irParamType.As<PtrType>();
IRParam* irParamPtr = subBuilder->emitParam(irPtrType);
if(auto paramDecl = paramInfo.decl)
@@ -2829,14 +2880,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// We need to carefully add a terminator instruction to the end
// of the body, in case the user didn't do so.
- if (!isTerminatorInst(subContext->irBuilder->parentInst->lastChild))
+ if (!isTerminatorInst(subContext->irBuilder->block->lastInst))
{
- if (irResultType->op == kIROp_VoidType)
+ if (irResultType->Equals(context->getSession()->getVoidType()))
{
+ // `void`-returning function can get an implicit
+ // return on exit of the body statement.
subContext->irBuilder->emitReturn();
}
else
{
+ // Value-returning function is expected to `return`
+ // on every control-flow path. We need to enforce
+ // this by putting an `unreachable` terminator here,
+ // and then emit a dataflow error if this block
+ // can't be eliminated.
SLANG_UNEXPECTED("Needed a return here");
subContext->irBuilder->emitReturn();
}
@@ -2845,7 +2903,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addHighLevelDeclDecoration(irFunc, decl);
- getBuilder()->addInst(irFunc);
+ getBuilder()->getModule()->globalValues.Add(irFunc);
return LoweredValInfo::simple(irFunc);
}
@@ -2878,7 +2936,6 @@ LoweredValInfo ensureDecl(
IRBuilder subIRBuilder;
subIRBuilder.shared = context->irBuilder->shared;
- subIRBuilder.parentInst = subIRBuilder.shared->module;
IRGenContext subContext = *context;
@@ -2997,15 +3054,14 @@ IRModule* lowerEntryPointToIR(
SharedIRBuilder sharedBuilderStorage;
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->module = nullptr;
+ sharedBuilder->session = entryPoint->compileRequest->mSession;
IRBuilder builderStorage;
IRBuilder* builder = &builderStorage;
builder->shared = sharedBuilder;
- builder->parentInst = nullptr;
IRModule* module = builder->createModule();
sharedBuilder->module = module;
- builder->parentInst = module;
context->irBuilder = builder;