5 차원 배열 CUDA

2658 단어 CUDA
#include <cassert>
#include <cstdlib>
#include <iostream>

/*
 * Allocate an a x b x c x d x e array of T in CUDA managed memory, indexable
 * from host and device code as ary[i][j][k][l][m].
 *
 * Layout:
 *  - All a*b*c*d*e elements live in ONE contiguous row-major buffer (the last
 *    index varies fastest), so ary[0][0][0][0] points at element 0.
 *  - A second managed allocation holds the four pointer levels back to back:
 *    a level-1 pointers, then a*b, a*b*c and a*b*c*d deeper-level pointers.
 *
 * Preconditions: every extent must be > 0.
 * On invalid extents or allocation failure a diagnostic is written to
 * std::cerr and the process aborts (unlike assert, this survives NDEBUG).
 * Release the result with free_5d_flat.
 */
template <typename T>
T***** create_5d_flat(int a, int b, int c, int d, int e) {
	if (a <= 0 || b <= 0 || c <= 0 || d <= 0 || e <= 0) {
		std::cerr << "create_5d_flat: all extents must be positive" << std::endl;
		std::abort();
	}

	// Element/pointer counts per level, in size_t: the product a*b*c*d*e
	// evaluated in int overflows for fairly modest extents.
	const size_t n1 = static_cast<size_t>(a);
	const size_t n2 = n1 * static_cast<size_t>(b);
	const size_t n3 = n2 * static_cast<size_t>(c);
	const size_t n4 = n3 * static_cast<size_t>(d);
	const size_t n5 = n4 * static_cast<size_t>(e);

	T *base = nullptr;
	cudaError_t err = cudaMallocManaged(&base, n5 * sizeof(T));
	if (err != cudaSuccess) {
		std::cerr << "create_5d_flat: element allocation failed: "
		          << cudaGetErrorString(err) << std::endl;
		std::abort();
	}

	// One table for all four pointer levels; every level is sized as sizeof(T*)
	// (same assumption as before: all object-pointer types have equal size).
	T *****ary = nullptr;
	err = cudaMallocManaged(&ary, (n1 + n2 + n3 + n4) * sizeof(T*));
	if (err != cudaSuccess) {
		std::cerr << "create_5d_flat: pointer-table allocation failed: "
		          << cudaGetErrorString(err) << std::endl;
		cudaFree(base);
		std::abort();
	}

	// Start of each deeper level inside the table.
	T ****lvl2 = reinterpret_cast<T****>(ary + n1);
	T ***lvl3 = reinterpret_cast<T***>(ary + n1 + n2);
	T **lvl4 = reinterpret_cast<T**>(ary + n1 + n2 + n3);

	// Entry x of a level points at the x-th row of the next level, so
	// ary[i][j][k][l] == base + (((i*b + j)*c + k)*d + l)*e.
	for (size_t x = 0; x < n1; ++x) ary[x] = lvl2 + x * static_cast<size_t>(b);
	for (size_t x = 0; x < n2; ++x) lvl2[x] = lvl3 + x * static_cast<size_t>(c);
	for (size_t x = 0; x < n3; ++x) lvl3[x] = lvl4 + x * static_cast<size_t>(d);
	for (size_t x = 0; x < n4; ++x) lvl4[x] = base + x * static_cast<size_t>(e);

	return ary;
}

/*
 * Release an array obtained from create_5d_flat: first the contiguous element
 * buffer (reached through ary[0][0][0][0]), then the pointer table itself.
 * A null argument is a no-op. The caller must ensure no kernel is still using
 * the memory. cudaFree failures are reported on std::cerr but not fatal.
 */
template <typename T>
void free_5d_flat(T***** ary) {
	// Check ary before walking the pointer tree; the previous version
	// dereferenced it first and crashed on nullptr.
	if (ary == nullptr)
		return;

	T *base = ary[0][0][0][0];  // start of the element buffer
	cudaError_t err = cudaFree(base);  // cudaFree(nullptr) is a documented no-op
	if (err != cudaSuccess)
		std::cerr << "free_5d_flat: cudaFree(elements) failed: " << cudaGetErrorString(err) << std::endl;

	err = cudaFree(ary);
	if (err != cudaSuccess)
		std::cerr << "free_5d_flat: cudaFree(pointers) failed: " << cudaGetErrorString(err) << std::endl;
}


/*
 * Write each element's row-major flat index into a 5-D pointer-tree array:
 * data[i][j][k][l][m] = (((i*b + j)*c + k)*d + l)*e + m, converted to T.
 *
 * Launch: any 1-D grid/block shape. A grid-stride loop covers all a*b*c*d*e
 * elements, so <<<1,1>>> works (serially) and larger launches split the work
 * without two threads writing the same element. No shared memory is used.
 * `data` and every pointer level must be device-accessible (e.g. managed
 * memory from create_5d_flat). Non-positive extents write nothing.
 */
template <typename T>
__global__ void fill(T***** data, int a, int b, int c, int d, int e) {
	if (a <= 0 || b <= 0 || c <= 0 || d <= 0 || e <= 0)
		return;

	typedef unsigned long long int index_t;
	const index_t total = static_cast<index_t>(a) * b * c * d * e;
	const index_t stride = static_cast<index_t>(gridDim.x) * blockDim.x;

	for (index_t idx = static_cast<index_t>(blockIdx.x) * blockDim.x + threadIdx.x;
	     idx < total; idx += stride) {
		// Peel indices off the flat position, innermost (fastest) first.
		index_t rest = idx;
		const int m = static_cast<int>(rest % e); rest /= e;
		const int l = static_cast<int>(rest % d); rest /= d;
		const int k = static_cast<int>(rest % c); rest /= c;
		const int j = static_cast<int>(rest % b); rest /= b;
		const int i = static_cast<int>(rest);
		data[i][j][k][l][m] = idx;  // same value the serial counter produced
	}
}

/*
 * Print the current device's free and total memory (bytes) to std::cout as
 * "Free = <n> Total = <n>". If cudaMemGetInfo fails, a diagnostic goes to
 * std::cerr instead and nothing is written to std::cout.
 */
void report_gpu_mem()
{
	// Named to avoid shadowing ::free; zero-initialised so nothing
	// indeterminate can ever be printed.
	size_t freeBytes = 0, totalBytes = 0;
	const cudaError_t err = cudaMemGetInfo(&freeBytes, &totalBytes);
	if (err != cudaSuccess) {
		std::cerr << "cudaMemGetInfo failed: " << cudaGetErrorString(err) << std::endl;
		// Clear the (non-sticky) error so a later cudaGetLastError() check
		// is not blamed for this failure.
		(void)cudaGetLastError();
		return;
	}
	std::cout << "Free = " << freeBytes << " Total = " << totalBytes << std::endl;
}

/*
 * Demo/self-test: allocate a 6x9x3x5x4 managed array, fill it on the GPU with
 * each element's row-major flat index, dump every element, then verify the
 * contiguous backing buffer holds 0, 1, 2, ... in order.
 * Returns 0 on success and -1 on kernel failure or a value mismatch.
 */
int main() {
	// Array extents: used for allocation, the launch and validation alike.
	const int A = 6, B = 9, C = 3, D = 5, E = 4;
	typedef unsigned long long int Elem;

	report_gpu_mem();

	std::cout << "allocating..." << std::endl;
	Elem *****data2 = create_5d_flat<Elem>(A, B, C, D, E);

	report_gpu_mem();

	fill<<<1, 1>>>(data2, A, B, C, D, E);
	cudaError_t err = cudaGetLastError();  // launch-configuration errors
	if (err == cudaSuccess)
		err = cudaDeviceSynchronize();     // errors raised while the kernel ran
	if (err != cudaSuccess) {
		// The context may be unusable after a kernel fault, so do not touch the
		// managed memory again; process exit reclaims it.
		std::cerr << "fill kernel failed: " << cudaGetErrorString(err) << std::endl;
		return -1;
	}

	std::cout << "validating..." << std::endl;

	for (int i = 0; i < A; i++)
		for (int j = 0; j < B; j++)
			for (int k = 0; k < C; k++)
				for (int l = 0; l < D; l++)
					for (int m = 0; m < E; m++)
						std::cout << data2[i][j][k][l][m] << std::endl;

	// The elements are one contiguous row-major buffer starting at [0][0][0][0].
	const Elem *flat = data2[0][0][0][0];
	const Elem count = static_cast<Elem>(A) * B * C * D * E;
	for (Elem n = 0; n < count; n++)
		if (flat[n] != n)
		{
			std::cout << "mismatch at " << n << " was " << flat[n] << std::endl;
			free_5d_flat(data2);  // previously leaked on this path
			return -1;
		}

	free_5d_flat(data2);
	return 0;
}

좋은 웹페이지 즐겨찾기