summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/TensorGpuViews.h
blob: ef26473ebe074c22bbd6dd03ec28e8d7079560e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#pragma once
#include <stdint.h>
#include "../D3D/device.h"

namespace DirectCompute
{
	// Reference-counted pair of Direct3D 11 views: a shader resource view used
	// for reading and an optional unordered access view used for writing.
	// Both are held through CComPtr, so the views are released when this object
	// is destroyed or when they are replaced/cleared.
	class TensorGpuViews
	{
	protected:
		// Read view; CComPtr default-constructs to null, so it stays empty until set.
		CComPtr<ID3D11ShaderResourceView> srv;
		// Write view; may remain null when only read access was requested.
		CComPtr<ID3D11UnorderedAccessView> uav;

	public:

		// Implicit conversions so the object can be passed wherever a raw view
		// pointer is expected. No extra reference is taken: the returned pointer
		// is only valid while this object still holds the view.
		operator ID3D11ShaderResourceView* ( ) const
		{
			return srv;
		}

		operator ID3D11UnorderedAccessView* ( ) const
		{
			return uav;
		}

		// Defined out of line (not visible here).
		// NOTE(review): presumably creates an SRV over `buffer` covering
		// `countElements` elements of `format`, plus a UAV when `makeUav` is
		// true — confirm against the implementation.
		HRESULT create( ID3D11Buffer* buffer, DXGI_FORMAT format, size_t countElements, bool makeUav );

		// Replace both held views. CComPtr assignment takes a reference on the
		// incoming views and releases the previously held ones. When `write` is
		// omitted, any existing UAV is released and the slot becomes null.
		void setGpuViews( ID3D11ShaderResourceView* read, ID3D11UnorderedAccessView* write = nullptr )
		{
			srv = read;
			uav = write;
		}

		// Release both views (SRV first, then UAV), leaving the object empty.
		void clear()
		{
			setGpuViews( nullptr, nullptr );
		}
	};
}