diff --git a/examples/demo_mandelbrot.py b/examples/demo_mandelbrot.py index 1c04da612..69b30a1c5 100644 --- a/examples/demo_mandelbrot.py +++ b/examples/demo_mandelbrot.py @@ -38,7 +38,7 @@ def calc_fractal_opencl(q, maxiter): ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) - output = np.empty(q.shape, dtype=np.uint16) + output = np.empty(q.shape, dtype=np.uint32) mf = cl.mem_flags q_opencl = cl.Buffer(ctx, mf.READ_ONLY | mf.COPY_HOST_PTR, hostbuf=q) @@ -48,12 +48,13 @@ def calc_fractal_opencl(q, maxiter): ctx, """ #pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable - __kernel void mandelbrot(__global float2 *q, - __global ushort *output, ushort const maxiter) + #pragma OPENCL EXTENSION cl_khr_fp64 : enable + __kernel void mandelbrot(__global double2 *q, + __global uint *output, uint const maxiter) { int gid = get_global_id(0); - float nreal, real = 0; - float imag = 0; + double nreal, real = 0; + double imag = 0; output[gid] = 0; @@ -70,7 +71,7 @@ def calc_fractal_opencl(q, maxiter): ).build() prg.mandelbrot( - queue, output.shape, None, q_opencl, output_opencl, np.uint16(maxiter) + queue, output.shape, None, q_opencl, output_opencl, np.uint32(maxiter) ) cl.enqueue_copy(queue, output, output_opencl).wait() @@ -82,7 +83,7 @@ def calc_fractal_serial(q, maxiter): # calculate z using pure python on a numpy array # note that, unlike the other two implementations, # the number of iterations per point is NOT constant - z = np.zeros(q.shape, complex) + z = np.zeros(q.shape, np.complex128) output = np.resize( np.array( 0, @@ -107,7 +108,7 @@ def calc_fractal_numpy(q, maxiter): ), q.shape, ) - z = np.zeros(q.shape, np.complex64) + z = np.zeros(q.shape, np.complex128) for it in range(maxiter): z = z * z + q @@ -129,7 +130,7 @@ def draw(self, x1, x2, y1, y2, maxiter=30): # draw the Mandelbrot set, from numpy example xx = np.arange(x1, x2, (x2 - x1) / w) yy = np.arange(y2, y1, (y1 - y2) / h) * 1j - q = np.ravel(xx + yy[:, np.newaxis]).astype(np.complex64) + q = np.ravel(xx + yy[:, np.newaxis]).astype(np.complex128) start_main = time.time() output = calc_fractal(q, maxiter)