diff options
Diffstat (limited to 'source/slang/lower-to-ir.cpp')
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 320 |
1 files changed, 188 insertions, 132 deletions
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 06ad66bc4..7fce0c385 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -79,7 +79,7 @@ struct SubscriptInfo : ExtendedValueInfo struct BoundSubscriptInfo : ExtendedValueInfo { DeclRef<SubscriptDecl> declRef; - IRType* type; + RefPtr<Type> type; List<IRValue*> args; UInt genericArgCount; }; @@ -218,8 +218,8 @@ struct BoundMemberInfo : ExtendedValueInfo // struct SwizzledLValueInfo : ExtendedValueInfo { - // IR-level The type of the expression. - IRType* type; + // The type of the expression. + RefPtr<Type> type; // The base expression (this should be an l-value) LoweredValInfo base; @@ -355,7 +355,7 @@ LoweredValInfo emitCompoundAssignOp( auto leftVal = builder->emitLoad(leftPtr); - IRInst* innerArgs[] = { leftVal, rightVal }; + IRValue* innerArgs[] = { leftVal, rightVal }; auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); builder->emitStore(leftPtr, innerOp); @@ -363,23 +363,31 @@ LoweredValInfo emitCompoundAssignOp( return LoweredValInfo::ptr(leftPtr); } -IRInst* getOneValOfType( +IRValue* getOneValOfType( IRGenContext* context, IRType* type) { - switch(type->op) + if (auto basicType = dynamic_cast<BasicExpressionType*>(type)) { - case kIROp_Int32Type: - case kIROp_UInt32Type: - return context->irBuilder->getIntValue(type, 1); + switch (basicType->baseType) + { + case BaseType::Int: + case BaseType::UInt: + case BaseType::UInt64: + return context->irBuilder->getIntValue(type, 1); - case kIROp_Float32Type: - return context->irBuilder->getFloatValue(type, 1.0); + case BaseType::Float: + case BaseType::Double: + return context->irBuilder->getFloatValue(type, 1.0); - default: - SLANG_UNEXPECTED("inc/dec type"); - return nullptr; + default: + break; + } } + // TODO: should make sure to handle vector and matrix types here + + SLANG_UNEXPECTED("inc/dec type"); + return nullptr; } LoweredValInfo emitPreOp( @@ -396,9 +404,9 @@ LoweredValInfo emitPreOp( auto preVal = builder->emitLoad(argPtr); - IRInst* oneVal = getOneValOfType(context, type); + IRValue* oneVal = getOneValOfType(context, type); - IRInst* innerArgs[] = { preVal, oneVal }; + IRValue* innerArgs[] = { preVal, oneVal }; auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); builder->emitStore(argPtr, innerOp); @@ -420,9 +428,9 @@ LoweredValInfo emitPostOp( auto preVal = builder->emitLoad(argPtr); - IRInst* oneVal = getOneValOfType(context, type); + IRValue* oneVal = getOneValOfType(context, type); - IRInst* innerArgs[] = { preVal, oneVal }; + IRValue* innerArgs[] = { preVal, oneVal }; auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); builder->emitStore(argPtr, innerOp); @@ -647,10 +655,7 @@ struct LoweredTypeInfo Simple, }; - union - { - IRType* type; - }; + RefPtr<IRType> type; Flavor flavor; LoweredTypeInfo() @@ -743,6 +748,35 @@ LoweredValInfo lowerDecl( DeclBase* decl, Layout* layout); +IRType* getIntType( + IRGenContext* context) +{ + return context->getSession()->getBuiltinType(BaseType::Int); +} + +// Get a pointer type to the given element type +RefPtr<PtrType> getPtrType( + IRGenContext* context, + IRType* valueType) +{ + return context->getSession()->getPtrType(valueType); +} + +RefPtr<IRFuncType> getFuncType( + IRGenContext* context, + UInt paramCount, + RefPtr<IRType> const* paramTypes, + IRType* resultType) +{ + RefPtr<FuncType> funcType = new FuncType(); + funcType->resultType = resultType; + for (UInt pp = 0; pp < paramCount; ++pp) + { + funcType->paramTypes.Add(paramTypes[pp]); + } + return funcType; +} + // struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredTypeInfo> @@ -761,7 +795,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // TODO: it is a bit messy here that the `ConstantIntVal` representation // has no notion of a *type* associated with the value... - auto type = getBuilder()->getBaseType(BaseType::Int); + auto type = getIntType(context); return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value)); } @@ -772,16 +806,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredTypeInfo visitFuncType(FuncType* type) { - LoweredValInfo loweredFunc = ensureDecl(context, type->declRef); - auto loweredFuncVal = getSimpleVal(context, loweredFunc); - - // HACK: deal with the case where the decl might not - // lower to anything, and so we don't have a type to - // work with. - if (!loweredFuncVal) - return LoweredTypeInfo(); - - return loweredFuncVal->getType(); + return LoweredTypeInfo(type); } void addGenericArgs(List<IRValue*>* ioArgs, DeclRefBase declRef) @@ -799,12 +824,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredTypeInfo visitDeclRefType(DeclRefType* type) { +#if 1 + // TODO: is there actually anything to be done at this point? + return LoweredTypeInfo(type); +#else // We need to detect builtin/intrinsic types here, since they should map to custom modifiers // We need to catch builtin/intrinsic types here if( auto intrinsicTypeMod = type->declRef.getDecl()->FindModifier<IntrinsicTypeModifier>() ) { auto builder = getBuilder(); - auto intType = builder->getBaseType(BaseType::Int); + auto intType = getIntType(context); // List<IRValue*> irArgs; for( auto val : intrinsicTypeMod->irOperands ) @@ -831,61 +860,32 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower default: SLANG_UNIMPLEMENTED_X("type lowering"); } - +#endif } LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type) { - return getBuilder()->getBaseType(type->baseType); + return LoweredTypeInfo(type); } LoweredTypeInfo visitVectorExpressionType(VectorExpressionType* type) { - auto irElementType = lowerSimpleType(context, type->elementType); - auto irElementCount = lowerSimpleVal(context, type->elementCount); - - return getBuilder()->getVectorType(irElementType, irElementCount); + return LoweredTypeInfo(type); } LoweredTypeInfo visitMatrixExpressionType(MatrixExpressionType* type) { - auto irElementType = lowerSimpleType(context, type->getElementType()); - auto irRowCount = lowerSimpleVal(context, type->getRowCount()); - auto irColumnCount = lowerSimpleVal(context, type->getColumnCount()); - - return getBuilder()->getMatrixType(irElementType, irRowCount, irColumnCount); + return LoweredTypeInfo(type); } - LoweredTypeInfo getArrayType( - LoweredTypeInfo const& loweredElementType, - IRValue* irElementCount) + LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type) { - switch (loweredElementType.flavor) - { - case LoweredTypeInfo::Flavor::Simple: - return getBuilder()->getArrayType( - loweredElementType.type, - irElementCount); - break; - - default: - SLANG_UNEXPECTED("array element type"); - break; - } + return LoweredTypeInfo(type); } - LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type) + LoweredTypeInfo visitIRBasicBlockType(IRBasicBlockType* type) { - auto loweredElementType = lowerType(context, type->baseType); - if (auto elementCount = type->ArrayLength) - { - auto irElementCount = lowerSimpleVal(context, elementCount); - return getArrayType(loweredElementType, irElementCount); - } - else - { - return getArrayType(loweredElementType, nullptr); - } + return LoweredTypeInfo(type); } }; @@ -1017,9 +1017,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> if (auto fieldDeclRef = declRef.As<StructField>()) { // Okay, easy enough: we have a reference to a field of a struct type... - - auto loweredField = ensureDecl(context, fieldDeclRef); - return extractField(loweredType, loweredBase, loweredField); + return extractField(loweredType, loweredBase, fieldDeclRef); } else if (auto callableDeclRef = declRef.As<CallableDecl>()) { @@ -1045,14 +1043,12 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // in order for a dereference to make senese, so we just // need to extract the value type from that pointer here. // - auto loweredBaseVal = getSimpleVal(context, loweredBase); - auto loweredBaseType = loweredBaseVal->getType(); - switch( loweredBaseType->op ) - { - case kIROp_PtrType: - // TODO: should we enumerate these explicitly? - case kIROp_ConstantBufferType: - case kIROp_TextureBufferType: + IRValue* loweredBaseVal = getSimpleVal(context, loweredBase); + RefPtr<Type> loweredBaseType = loweredBaseVal->getType(); + + if (loweredBaseType->As<PointerLikeType>() + || loweredBaseType->As<PtrType>()) + { // Note that we do *not* perform an actual `load` operation // here, but rather just use the pointer value to construct // an appropriate `LoweredValInfo` representing the underlying @@ -1064,8 +1060,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // and is just a bit of pointer math. // return LoweredValInfo::ptr(loweredBaseVal); - - default: + } + else + { SLANG_UNIMPLEMENTED_X("codegen for deref expression"); return LoweredValInfo(); } @@ -1472,7 +1469,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> case LoweredValInfo::Flavor::Ptr: return LoweredValInfo::ptr( builder->emitElementAddress( - builder->getPtrType(getSimpleType(type)), + getPtrType(context, getSimpleType(type)), baseVal.val, indexVal)); @@ -1484,9 +1481,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> } LoweredValInfo extractField( - LoweredTypeInfo fieldType, - LoweredValInfo base, - LoweredValInfo field) + LoweredTypeInfo fieldType, + LoweredValInfo base, + DeclRef<StructField> field) { switch (base.flavor) { @@ -1497,7 +1494,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> getBuilder()->emitFieldExtract( getSimpleType(fieldType), irBase, - (IRStructField*) getSimpleVal(context, field))); + getBuilder()->getDeclRefVal(field))); } break; @@ -1509,9 +1506,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> IRValue* irBasePtr = base.val; return LoweredValInfo::ptr( getBuilder()->emitFieldAddress( - getBuilder()->getPtrType(getSimpleType(fieldType)), + getPtrType(context, getSimpleType(fieldType)), irBasePtr, - (IRStructField*) getSimpleVal(context, field))); + getBuilder()->getDeclRefVal(field))); } break; } @@ -1598,7 +1595,7 @@ struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVis auto builder = getBuilder(); - auto irIntType = builder->getBaseType(BaseType::Int); + auto irIntType = getIntType(context); UInt elementCount = (UInt)expr->elementCount; IRValue* irElementIndices[4]; @@ -1661,34 +1658,22 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> void insertBlock(IRBlock* block) { auto builder = getBuilder(); - auto parent = builder->parentInst; - IRBlock* prevBlock = nullptr; - IRFunc* parentFunc = nullptr; - - switch (parent->op) - { - case kIROp_Block: - prevBlock = (IRBlock*)parent; - parentFunc = prevBlock->getParent(); - break; - - default: - SLANG_UNEXPECTED("bad parent kind for block"); - return; - } + auto prevBlock = builder->block; + auto parentFunc = prevBlock->parentFunc; // If the previous block doesn't already have // a terminator instruction, then be sure to // emit a branch to the new block. - if (!isTerminatorInst(prevBlock->lastChild)) + if (!isTerminatorInst(prevBlock->lastInst)) { builder->emitBranch(block); } - builder->parentInst = parentFunc; - builder->addInst(block); - builder->parentInst = block; + parentFunc->addBlock(block); + + builder->func = parentFunc; + builder->block = block; } // Start a new block at the current location. @@ -2025,8 +2010,8 @@ top: auto loweredBase = swizzleInfo->base; // Load from the base value: - IRInst* irLeftVal = getSimpleVal(context, loweredBase); - auto irRightVal = getSimpleVal(context, right); + IRValue* irLeftVal = getSimpleVal(context, loweredBase); + IRValue* irRightVal = getSimpleVal(context, right); // Now apply the swizzle IRInst* irSwizzled = builder->emitSwizzleSet( @@ -2066,7 +2051,7 @@ top: emitCallToDeclRef( context, - builder->getVoidType(), + context->getSession()->getVoidType(), setterDeclRef, allArgs, subscriptInfo->genericArgCount); @@ -2153,8 +2138,67 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo(); } + bool isGlobalVarDecl(VarDeclBase* decl) + { + auto parent = decl->ParentDecl; + if (dynamic_cast<ModuleDecl*>(parent)) + { + // Variable declared at global scope? -> Global. + return true; + } + + return false; + } + + LoweredValInfo lowerGlobalVarDecl(VarDeclBase* decl) + { + auto varType = lowerSimpleType(context, decl->getType()); + + IRAddressSpace addressSpace = kIRAddressSpace_Default; + if (decl->HasModifier<HLSLGroupSharedModifier>()) + { + addressSpace = kIRAddressSpace_GroupShared; + } + + auto builder = getBuilder(); + auto irGlobal = builder->createGlobalVar(varType); + + if (decl) + { + builder->addHighLevelDeclDecoration(irGlobal, decl); + } + + if (auto layout = getLayout()) + { + builder->addLayoutDecoration(irGlobal, layout); + } + + // A global variable's SSA value is a *pointer* to + // the underlying storage. + auto globalVal = LoweredValInfo::ptr(irGlobal); + context->shared->declValues.Add( + DeclRef<VarDeclBase>(decl, nullptr), + globalVal); + + if( auto initExpr = decl->initExpr ) + { + // TODO: need to handle global with initializer! + } + + getBuilder()->getModule()->globalValues.Add(irGlobal); + + return globalVal; + } + LoweredValInfo visitVarDeclBase(VarDeclBase* decl) { + // Detect global (or effectively global) variables + // and handle them differently. + if (isGlobalVarDecl(decl)) + { + return lowerGlobalVarDecl(decl); + } + // A user-defined variable declaration will usually turn into // an `alloca` operation for the variable's storage, // plus some code to initialize it and then store to the variable. @@ -2219,6 +2263,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl) { +#if 0 // User-defined aggregate type: need to translate into // a corresponding IR aggregate type. @@ -2257,6 +2302,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> builder->addInst(irStruct); return LoweredValInfo::simple(irStruct); +#else + // TODO: What is there to do with a `struct` type? + return LoweredValInfo(); +#endif } // Sometimes we need to refer to a declaration the way that it would be specialized @@ -2592,7 +2641,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } void trySetMangledName( - IRInst* inst, + IRFunc* irFunc, Decl* decl) { // We want to generate a mangled name for the given declaration and attach @@ -2605,8 +2654,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> String mangledName = getMangledName(decl); - auto decoration = getBuilder()->addDecoration<IRMangledNameDecoration>(inst); - decoration->mangledName = mangledName; + irFunc->mangledName = mangledName; } @@ -2649,11 +2697,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // need to create an IR function here IRFunc* irFunc = subBuilder->createFunc(); - subBuilder->parentInst = irFunc; + subBuilder->func = irFunc; trySetMangledName(irFunc, decl); - List<IRType*> paramTypes; + List<RefPtr<Type>> paramTypes; // We first need to walk the generic parameters (if any) // because these will influence the declared type of @@ -2662,6 +2710,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> for( auto genericParamDecl : parameterLists.genericParams ) { UInt genericParamIndex = genericParamCounter++; +#if 0 if( auto genericTypeParamDecl = dynamic_cast<GenericTypeParamDecl*>(genericParamDecl) ) { // In the logical type for the function, a generic @@ -2675,10 +2724,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // to the appropriate generic parameter position. IRType* irParameterType = context->irBuilder->getGenericParameterType(genericParamIndex); - LoweredValInfo LoweredValInfo = LoweredValInfo::simple(irParameterType); + LoweredValInfo LoweredValInfo = LoweredValInfo::type(irParameterType); subContext->shared->declValues[makeDeclRef(genericTypeParamDecl)] = LoweredValInfo; } else +#endif { // TODO: handle the other cases here. SLANG_UNEXPECTED("generic parameter kind"); @@ -2702,7 +2752,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // // TODO: Is this the best representation we can use? - auto irPtrType = subBuilder->getPtrType(irParamType); + auto irPtrType = getPtrType(context, irParamType); paramTypes.Add(irPtrType); } } @@ -2726,14 +2776,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Instead, a setter always returns `void` // - irResultType = getBuilder()->getVoidType(); + irResultType = context->getSession()->getVoidType(); } - auto irFuncType = getBuilder()->getFuncType( + auto irFuncType = getFuncType( + context, paramTypes.Count(), paramTypes.Buffer(), irResultType); - irFunc->type.init(irFunc, irFuncType); + irFunc->type = irFuncType; if (!decl->Body) { @@ -2753,7 +2804,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // This is a function definition, so we need to actually // construct IR for the body... IRBlock* entryBlock = subBuilder->emitBlock(); - subBuilder->parentInst = entryBlock; + subBuilder->block = entryBlock; UInt paramTypeIndex = 0; for( auto paramInfo : parameterLists.params ) @@ -2771,7 +2822,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // // TODO: Is this the best representation we can use? - auto irPtrType = (IRPtrType*)irParamType; + auto irPtrType = irParamType.As<PtrType>(); IRParam* irParamPtr = subBuilder->emitParam(irPtrType); if(auto paramDecl = paramInfo.decl) @@ -2829,14 +2880,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // We need to carefully add a terminator instruction to the end // of the body, in case the user didn't do so. - if (!isTerminatorInst(subContext->irBuilder->parentInst->lastChild)) + if (!isTerminatorInst(subContext->irBuilder->block->lastInst)) { - if (irResultType->op == kIROp_VoidType) + if (irResultType->Equals(context->getSession()->getVoidType())) { + // `void`-returning function can get an implicit + // return on exit of the body statement. subContext->irBuilder->emitReturn(); } else { + // Value-returning function is expected to `return` + // on every control-flow path. We need to enforce + // this by putting an `unreachable` terminator here, + // and then emit a dataflow error if this block + // can't be eliminated. SLANG_UNEXPECTED("Needed a return here"); subContext->irBuilder->emitReturn(); } @@ -2845,7 +2903,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addHighLevelDeclDecoration(irFunc, decl); - getBuilder()->addInst(irFunc); + getBuilder()->getModule()->globalValues.Add(irFunc); return LoweredValInfo::simple(irFunc); } @@ -2878,7 +2936,6 @@ LoweredValInfo ensureDecl( IRBuilder subIRBuilder; subIRBuilder.shared = context->irBuilder->shared; - subIRBuilder.parentInst = subIRBuilder.shared->module; IRGenContext subContext = *context; @@ -2997,15 +3054,14 @@ IRModule* lowerEntryPointToIR( SharedIRBuilder sharedBuilderStorage; SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->module = nullptr; + sharedBuilder->session = entryPoint->compileRequest->mSession; IRBuilder builderStorage; IRBuilder* builder = &builderStorage; builder->shared = sharedBuilder; - builder->parentInst = nullptr; IRModule* module = builder->createModule(); sharedBuilder->module = module; - builder->parentInst = module; context->irBuilder = builder; |
