summaryrefslogtreecommitdiffstats
path: root/Whisper/Utils/parallelFor.cpp
blob: c2b324b29c6ae500249c425ae89c0bf1d36e712a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#include "stdafx.h"
#include "parallelFor.h"

namespace
{
	class alignas( 64 ) ParallelForContext
	{
		volatile long threadIndex;
		volatile HRESULT status;

		alignas( 64 ) void* const context;
		const Whisper::pfnParallelForCallback pfn;

		static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work );

	public:

		ParallelForContext( void* ctx, Whisper::pfnParallelForCallback pfn );

		PTP_WORK createWork();

		HRESULT getStatus() const;
	};

	ParallelForContext::ParallelForContext( void* ctx, Whisper::pfnParallelForCallback callback ) :
		threadIndex( 1 ),
		status( S_FALSE ),
		context( ctx ),
		pfn( callback )
	{ }

	PTP_WORK ParallelForContext::createWork()
	{
		return CreateThreadpoolWork( &callbackStatic, this, nullptr );
	}

	void __stdcall ParallelForContext::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work )
	{
		ParallelForContext& context = *(ParallelForContext*)pv;
		int ith = InterlockedIncrement( &context.threadIndex );
		ith--;
		const HRESULT hr = context.pfn( ith, context.context );
		if( SUCCEEDED( hr ) )
			return;
		InterlockedCompareExchange( &context.status, hr, S_FALSE );
	}

	HRESULT ParallelForContext::getStatus() const
	{
		const HRESULT hr = status;
		if( SUCCEEDED( hr ) )
			return S_OK;
		return hr;
	}
}

namespace Whisper
{
	// Runs pfn( ith, ctx ) once for every ith in [ 0, threadsCount ).
	// Index 0 runs on the calling thread; indices 1 .. threadsCount-1 run as callbacks of a
	// thread-pool work object (CreateThreadpoolWork with a null environment).
	// Returns:
	//  - E_BOUNDS when threadsCount < 1, E_POINTER when pfn is null;
	//  - pfn's own result, unchanged, when threadsCount == 1;
	//  - otherwise the failure code of index 0 if it failed, else the first failure recorded by a
	//    pool callback, else S_OK;
	//  - the getLastHr() result if the work object could not be created.
	HRESULT parallelFor( pfnParallelForCallback pfn, int threadsCount, void* ctx )
	{
		if( threadsCount < 1 )
			return E_BOUNDS;
		if( nullptr == pfn )
			return E_POINTER;
		if( threadsCount == 1 )
			return pfn( 0, ctx );

		ParallelForContext context{ ctx, pfn };

		PTP_WORK const pw = context.createWork();
		if( nullptr == pw )
			return getLastHr();

		// The pool callbacks hold a pointer to `context`, which lives in this stack frame.
		// This guard waits for all of them and then releases the work object on every exit path,
		// including stack unwinding if pfn( 0, ctx ) throws, so no callback can outlive `context`.
		struct WorkGuard
		{
			PTP_WORK work;
			~WorkGuard()
			{
				// FALSE: let callbacks that were already submitted run to completion.
				WaitForThreadpoolWorkCallbacks( work, FALSE );
				CloseThreadpoolWork( work );
			}
		};

		HRESULT hr0 = S_OK;
		{
			const WorkGuard guard{ pw };

			for( int i = 1; i < threadsCount; i++ )
				SubmitThreadpoolWork( pw );

			// The calling thread does its share while the pool handles the rest.
			hr0 = pfn( 0, ctx );
		}	// guard waits for every pool callback and closes the work object here

		if( FAILED( hr0 ) )
			return hr0;
		return context.getStatus();
	}
}

using namespace Whisper;

// Releases the thread-pool work object created by create(), if there is one.
// NOTE(review): this does not call WaitForThreadpoolWorkCallbacks. It relies on parallelFor()
// always waiting for its callbacks before returning, so none should be in flight here.
ThreadPoolWork::~ThreadPoolWork()
{
	if( work == nullptr )
		return;

	CloseThreadpoolWork( work );
	work = nullptr;
}

// Creates the thread-pool work object that parallelFor() submits; its callbacks get `this` as context.
// Returns S_OK on success, and HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ) if a work object already exists.
// If CreateThreadpoolWork fails, returns getLastHr() (presumably GetLastError() as an HRESULT — confirm).
HRESULT ThreadPoolWork::create()
{
	if( work != nullptr )
		return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );

	work = CreateThreadpoolWork( &callbackStatic, this, nullptr );
	return ( work == nullptr ) ? getLastHr() : S_OK;
}

// Runs threadPoolCallback( ith ) for ith in [ 0, threadsCount ): index 0 on the calling thread,
// the others as callbacks of the work object made by create().
// Returns OLE_E_BLANK if create() has not produced a work object.
// When threadsCount <= 1, runs only index 0 and returns its result unchanged.
// Otherwise returns index 0's failure if it failed, else the first failure from a pool callback, else S_OK.
HRESULT ThreadPoolWork::parallelFor( int threadsCount ) noexcept
{
	if( work == nullptr )
		return OLE_E_BLANK;

	if( threadsCount <= 1 )
		return threadPoolCallback( 0 );

	// Reset the shared state: pool callbacks claim indices starting at 1, and S_FALSE means "no failure yet".
	threadIndex = 1;
	status = S_FALSE;
	for( int submitted = 1; submitted < threadsCount; submitted++ )
		SubmitThreadpoolWork( work );

	const HRESULT mainResult = threadPoolCallback( 0 );

	// Block until every submitted callback has completed; FALSE keeps queued ones from being cancelled.
	WaitForThreadpoolWorkCallbacks( work, FALSE );

	if( FAILED( mainResult ) )
		return mainResult;

	// All callbacks are finished, so status is no longer being written.
	const HRESULT poolResult = status;
	return SUCCEEDED( poolResult ) ? S_OK : poolResult;
}

// Thread-pool entry point. pv is the ThreadPoolWork passed to CreateThreadpoolWork in create().
// Claims the next thread index, runs threadPoolCallback, and records only the first failure in status.
void __stdcall ThreadPoolWork::callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work )
{
	ThreadPoolWork& self = *static_cast<ThreadPoolWork*>( pv );

	// parallelFor() resets threadIndex to 1 and InterlockedIncrement returns the new value,
	// so pool callbacks receive indices 1, 2, ... (index 0 belongs to the calling thread).
	const int ith = static_cast<int>( InterlockedIncrement( &self.threadIndex ) ) - 1;

	const HRESULT hr = self.threadPoolCallback( ith );
	if( FAILED( hr ) )
		InterlockedCompareExchange( &self.status, hr, S_FALSE );
}