From 6f681279d99e72e717bb2b91763b80e570ae725b Mon Sep 17 00:00:00 2001 From: Tim Foley Date: Wed, 20 Dec 2017 17:35:10 -0800 Subject: IR: fixes for subscript accessors (#322) * IR: fixes for subscript accessors Fixes #320 This is a bunch of fixes for handling of `__subscript` operations on builtin types (notably `RWStructuredBuffer` and `StructuredBuffer` at this point). - Automatically add a `GetterDecl` to any subscript decalratio was declithout any accessors. This avoids hitting a null- dereference in the emit logic. - Add a notion of a `RefAccessor` (declared with `ref`) as a peer to getters and setters. The idea is that a `ref` accessor returns a pointer to the element data, so that it can be used for both getting and setting values. This is closer to the behavior of `RWStructuredBuffer` element access in HLSL. - Fixes for dealing with "access chains" where there might be a combination of a subscript (where the is a `get` and `set` but no `ref`) and member access, so that we have to read the base value into a temp, modify it, and then write it back. - This logic is still a bit of a mess, so we will eventually want to take a more consistent pass over this to deal with how we "materialize" values for setters. - Update `RWStructuredBuffer` to have a `ref` accessor, and then fix up the IR tests to handle the new opcode that I added for it. - Note: I didn't handle this as an intrinsic simply because the `tests/ir/*` tests aren't really set up to handle builtins with ugly mangled names. * Fixup: type error in VM for buffer element ref I was using the result type of the op as the element type for computing the element address, but the result type is a pointer to the real element type. This caused test failures on 64-bit platforms, where the stride of the buffer in the `ir/factorial` test needs to be 4. The fix is to assume the result type is a pointer, and extract the pointed-to type out of that. --- source/slang/lower-to-ir.cpp | 263 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 221 insertions(+), 42 deletions(-) (limited to 'source/slang/lower-to-ir.cpp') diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 8c8b34908..8fc1239cd 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -207,6 +207,9 @@ struct BoundMemberInfo : ExtendedValueInfo // The (AST-level) declaration reference. DeclRef declRef; + + // The type of this value + RefPtr type; }; // Represents the result of a swizzle operation in @@ -362,6 +365,7 @@ LoweredValInfo emitCallToVal( switch (funcVal.flavor) { case LoweredValInfo::Flavor::None: + SLANG_UNEXPECTED("null function"); default: return LoweredValInfo::simple( builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args)); @@ -531,25 +535,60 @@ LoweredValInfo emitCallToDeclRef( if (auto subscriptDeclRef = funcDeclRef.As()) { - // A reference to a subscript declaration is potentially a - // special case, if we have more than just a getter. + // A reference to a subscript declaration is a special case, + // because it is not possible to call a subscript directly; + // we must call one of its accessors. + // + // TODO: everything here will also apply to propery declarations + // once we have them, so some of this code might be shared + // some day. DeclRef getterDeclRef; bool justAGetter = true; for (auto accessorDeclRef : getMembersOfType(subscriptDeclRef)) { + // If the subscript declares a `ref` accessor, then we can just + // invoke that directly to get an l-value we can use. + if(auto refAccessorDeclRef = accessorDeclRef.As()) + { + // The `ref` accessor will return a pointer to the value, so + // we need to reflect that in the type of our `call` instruction. + RefPtr ptrType = context->getSession()->getPtrType(type); + + // Rather than call `emitCallToVal` here, we make a recursive call + // to `emitCallToDeclRef` so that it can handle things like intrinsic-op + // modifiers attached to the acecssor. + LoweredValInfo callVal = emitCallToDeclRef( + context, + ptrType, + refAccessorDeclRef, + funcExpr, + argCount, + args); + + // The result from the call needs to be implicitly dereferenced, + // so that it can work as an l-value of the desired result type. + return LoweredValInfo::ptr(getSimpleVal(context, callVal)); + } + + // If we don't find a `ref` accessor, then we want to track whether + // this subscript has any accessors other than `get` (assuming + // that everything except `get` can be used for setting...). + if (auto foundGetterDeclRef = accessorDeclRef.As()) { + // We found a getter. getterDeclRef = foundGetterDeclRef; } else { + // There was something other than a getter, so we can't + // invoke an accessor just now. justAGetter = false; - break; } } - if (!justAGetter) + if (!justAGetter || !getterDeclRef) { // We can't perform an actual call right now, because // this expression might appear in an r-value or l-value @@ -575,8 +614,7 @@ LoweredValInfo emitCallToDeclRef( // Otherwise we are just call the getter, and so that // is what we need to be emitting a call to... - if (getterDeclRef) - funcDeclRef = getterDeclRef; + funcDeclRef = getterDeclRef; } auto funcDecl = funcDeclRef.getDecl(); @@ -660,7 +698,64 @@ LoweredValInfo emitCallToDeclRef( return emitCallToDeclRef(context, type, funcDeclRef, funcExpr, args.Count(), args.Buffer()); } -IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered) +LoweredValInfo extractField( + IRGenContext* context, + Type* fieldType, + LoweredValInfo base, + DeclRef field) +{ + IRBuilder* builder = context->irBuilder; + + switch (base.flavor) + { + default: + { + IRValue* irBase = getSimpleVal(context, base); + return LoweredValInfo::simple( + builder->emitFieldExtract( + fieldType, + irBase, + builder->getDeclRefVal(field))); + } + break; + + case LoweredValInfo::Flavor::BoundMember: + case LoweredValInfo::Flavor::BoundSubscript: + { + // The base value is one that is trying to defer a get-vs-set + // decision, so we will need to do the same. + + RefPtr boundMemberInfo = new BoundMemberInfo(); + boundMemberInfo->type = fieldType; + boundMemberInfo->base = base; + boundMemberInfo->declRef = field; + + context->shared->extValues.Add(boundMemberInfo); + return LoweredValInfo::boundMember(boundMemberInfo); + } + break; + + case LoweredValInfo::Flavor::Ptr: + { + // We are "extracting" a field from an lvalue address, + // which means we should just compute an lvalue + // representing the field address. + IRValue* irBasePtr = base.val; + return LoweredValInfo::ptr( + builder->emitFieldAddress( + context->getSession()->getPtrType(fieldType), + irBasePtr, + builder->getDeclRefVal(field))); + } + break; + } +} + + + +LoweredValInfo materialize( + IRGenContext* context, + LoweredValInfo lowered) { auto builder = context->irBuilder; @@ -668,13 +763,9 @@ top: switch(lowered.flavor) { case LoweredValInfo::Flavor::None: - return nullptr; - case LoweredValInfo::Flavor::Simple: - return lowered.val; - case LoweredValInfo::Flavor::Ptr: - return builder->emitLoad(lowered.val); + return lowered; case LoweredValInfo::Flavor::BoundSubscript: { @@ -693,7 +784,27 @@ top: } SLANG_UNEXPECTED("subscript had no getter"); - UNREACHABLE_RETURN(nullptr); + UNREACHABLE_RETURN(LoweredValInfo()); + } + break; + + case LoweredValInfo::Flavor::BoundMember: + { + auto boundMemberInfo = lowered.getBoundMemberInfo(); + auto base = materialize(context, boundMemberInfo->base); + + auto declRef = boundMemberInfo->declRef; + if( auto fieldDeclRef = declRef.As() ) + { + lowered = extractField(context, boundMemberInfo->type, base, fieldDeclRef); + goto top; + } + else + { + + SLANG_UNEXPECTED("unexpected member flavor"); + UNREACHABLE_RETURN(LoweredValInfo()); + } } break; @@ -701,13 +812,39 @@ top: { auto swizzleInfo = lowered.getSwizzledLValueInfo(); - return builder->emitSwizzle( + return LoweredValInfo::simple(builder->emitSwizzle( swizzleInfo->type, getSimpleVal(context, swizzleInfo->base), swizzleInfo->elementCount, - swizzleInfo->elementIndices); + swizzleInfo->elementIndices)); } + default: + SLANG_UNEXPECTED("unhandled value flavor"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + +} + +IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered) +{ + auto builder = context->irBuilder; + + // First, try to eliminate any "bound" operations along the chain, + // so that we are dealing with an ordinary value, or an l-value pointer. + lowered = materialize(context, lowered); + + switch(lowered.flavor) + { + case LoweredValInfo::Flavor::None: + return nullptr; + + case LoweredValInfo::Flavor::Simple: + return lowered.val; + + case LoweredValInfo::Flavor::Ptr: + return builder->emitLoad(lowered.val); + default: SLANG_UNEXPECTED("unhandled value flavor"); UNREACHABLE_RETURN(nullptr); @@ -1042,6 +1179,7 @@ struct ExprLoweringVisitorBase : ExprVisitor else if (auto callableDeclRef = declRef.As()) { RefPtr boundMemberInfo = new BoundMemberInfo(); + boundMemberInfo->type = nullptr; boundMemberInfo->base = loweredBase; boundMemberInfo->declRef = callableDeclRef; return LoweredValInfo::boundMember(boundMemberInfo); @@ -1543,33 +1681,7 @@ struct ExprLoweringVisitorBase : ExprVisitor LoweredValInfo base, DeclRef field) { - switch (base.flavor) - { - default: - { - IRValue* irBase = getSimpleVal(context, base); - return LoweredValInfo::simple( - getBuilder()->emitFieldExtract( - getSimpleType(fieldType), - irBase, - getBuilder()->getDeclRefVal(field))); - } - break; - - case LoweredValInfo::Flavor::Ptr: - { - // We are "extracting" a field from an lvalue address, - // which means we should just compute an lvalue - // representing the field address. - IRValue* irBasePtr = base.val; - return LoweredValInfo::ptr( - getBuilder()->emitFieldAddress( - context->getSession()->getPtrType(getSimpleType(fieldType)), - irBasePtr, - getBuilder()->getDeclRefVal(field))); - } - break; - } + return Slang::extractField(context, getSimpleType(fieldType), base, field); } LoweredValInfo visitStaticMemberExpr(StaticMemberExpr* expr) @@ -2392,6 +2504,28 @@ void lowerStmt( return visitor.dispatch(stmt); } +static LoweredValInfo maybeMoveMutableTemp( + IRGenContext* context, + LoweredValInfo const& val) +{ + switch(val.flavor) + { + case LoweredValInfo::Flavor::Ptr: + return val; + + default: + { + IRValue* irVal = getSimpleVal(context, val); + auto type = irVal->getType(); + auto var = createVar(context, type); + + assign(context, var, LoweredValInfo::simple(irVal)); + return var; + } + break; + } +} + void assign( IRGenContext* context, LoweredValInfo const& inLeft, @@ -2412,6 +2546,7 @@ top: case LoweredValInfo::Flavor::Ptr: case LoweredValInfo::Flavor::SwizzledLValue: case LoweredValInfo::Flavor::BoundSubscript: + case LoweredValInfo::Flavor::BoundMember: { builder->emitStore( left.val, @@ -2511,6 +2646,43 @@ top: } break; + case LoweredValInfo::Flavor::BoundMember: + { + auto boundMemberInfo = left.getBoundMemberInfo(); + + // If we hit this case, then it means that we are trying to set + // a single field in someting that is not atomically set-able. + // (e.g., an element of a value where the `subscript` operation + // has `get` and `set` but not a `ref` accessor). + // + // We need to read the entire base value out, modify the field + // we care about, and then write it back. + + auto declRef = boundMemberInfo->declRef; + if( auto fieldDeclRef = declRef.As() ) + { + // materialize the base value and move it into + // a mutable temporary if needed + auto baseVal = boundMemberInfo->base; + auto tempVal = maybeMoveMutableTemp(context, materialize(context, baseVal)); + + // extract the field l-value out of the temporary + auto tempFieldVal = extractField(context, boundMemberInfo->type, tempVal, fieldDeclRef); + + // assign to the field of the temporary l-value + assign(context, tempFieldVal, right); + + // write back the modified temporary to the base l-value + assign(context, baseVal, tempVal); + } + else + { + SLANG_UNEXPECTED("handled member flavor"); + } + + } + break; + default: SLANG_UNIMPLEMENTED_X("assignment"); break; @@ -3303,6 +3475,13 @@ struct DeclLoweringVisitor : DeclVisitor irResultType = context->getSession()->getVoidType(); } + if( auto refAccessorDecl = dynamic_cast(decl) ) + { + // A `ref` accessor needs to return a *pointer* to the value + // being accessed, rather than a simple value. + irResultType = context->getSession()->getPtrType(irResultType); + } + auto irFuncType = getFuncType( context, paramTypes.Count(), -- cgit v1.2.3