#!/usr/bin/python3

import pyopencl
import pyopencl.array
import numpy as np
import sys

# Overall result flag; set to True by any check further down that fails.
test_failed = False

# Look for the device - match on 'beignet' rather than 'Intel' so that any
# future ICR package is not mistaken for it.
beignet_platforms = []
for candidate in pyopencl.get_platforms():
    if 'beignet' in candidate.version.lower():
        beignet_platforms.append(candidate)
if len(beignet_platforms) != 1:
    sys.exit(1)

beignet_devices = beignet_platforms[0].get_devices()
if not beignet_devices:
    # One platform with zero devices is expected if there is no hardware
    # -> skip (exit status 77).
    sys.exit(77)
if any('beignet' not in device.version.lower() for device in beignet_devices):
    sys.exit(1)

# Check that two instances (independent contexts and queues on the same
# device) work at the same time.
ctx1 = pyopencl.Context([beignet_devices[0]])
cq1 = pyopencl.CommandQueue(ctx1)
ctx2 = pyopencl.Context([beignet_devices[0]])
cq2 = pyopencl.CommandQueue(ctx2)

a = np.random.randn(10**6).astype(np.float32)
aCL1 = pyopencl.array.to_device(cq1, a)
aCL2 = pyopencl.array.to_device(cq2, a)
# The output arrays are initialised from `a` only to get the right shape and
# dtype; every element is overwritten by the kernels below.
bCL1 = pyopencl.array.to_device(cq1, a)
bCL2 = pyopencl.array.to_device(cq2, a)

ctype1 = pyopencl.tools.dtype_to_ctype(aCL1.dtype)
ctype2 = pyopencl.tools.dtype_to_ctype(aCL2.dtype)
f1 = pyopencl.elementwise.ElementwiseKernel(
    ctx1, ctype1 + " *a," + ctype1 + " *b",
    "b[i]=cos(a[i])+sin(a[i])+sqrt(a[i])", "cossinsqrt1")
f2 = pyopencl.elementwise.ElementwiseKernel(
    ctx2, ctype2 + " *a," + ctype2 + " *b",
    "b[i]=cos(a[i])+sin(a[i])-sqrt(a[i])", "cossinsqrt2")

# Host-side reference results.  The input contains negative samples, so
# sqrt() is expected to produce NaN there; silence the resulting
# "invalid value" warning instead of letting it clutter the test output.
with np.errstate(invalid='ignore'):
    b1 = np.cos(a) + np.sin(a) + np.sqrt(a)
    b2 = np.cos(a) + np.sin(a) - np.sqrt(a)

# f1 is enqueued without waiting, so it may still be running on the first
# context while f2 is executed and waited for on the second one.
w = f1(aCL1, bCL1)
f2(aCL2, bCL2).wait()
w.wait()

result1 = bCL1.get()
result2 = bCL2.get()
error1 = np.max(np.nan_to_num(np.abs(result1 - b1)))
error2 = np.max(np.nan_to_num(np.abs(result2 - b2)))
print("rounding errors:", error1, error2)
if error1 + error2 > 1e-5:
    test_failed = True

# nan_to_num() above turns every NaN difference into 0, so on its own it
# would accept a device that returns NaN where the host result is finite
# (or the other way round).  Require the NaN positions to agree as well.
nan_pattern_ok = (
    np.array_equal(np.isnan(result1), np.isnan(b1))
    and np.array_equal(np.isnan(result2), np.isnan(b2))
)
if not nan_pattern_ok:
    print("NaN positions differ between device and host results")
    test_failed = True

# Check that invalid input is an error, not a crash.
a = np.random.randn(10**6).astype(np.dtype('float32'))
aCL = pyopencl.array.to_device(cq1, a)
bCL = pyopencl.array.to_device(cq1, a)

ctype = pyopencl.tools.dtype_to_ctype(aCL.dtype)
arg_decl = ctype + " *a," + ctype + " *b"

# (operation, kernel name) for each deliberately broken kernel.
# A further "double" case (b[i]=a[i]+1.0) was disabled in the original list
# and is intentionally not included here.
invalid_kernels = [
    ("b[i]=cos(a[i])+/-", "badsyntax"),
    ("b[i]=cos(aaaaa[i])", "badvar"),
    ("b[i]=cossss(a[i])", "badfunc"),
]

for n, (operation, kernel_name) in enumerate(invalid_kernels):
    try:
        print("running invalid-kernels test ", n)
        # Build the kernel inside the try block: the expected error is then
        # caught whether the OpenCL compiler is invoked when the kernel
        # object is constructed or only when it is first called.
        f = pyopencl.elementwise.ElementwiseKernel(
            ctx1, arg_decl, operation, kernel_name)
        f(aCL, bCL).wait()
        print("unexpected success")
        test_failed = True
    except pyopencl.Error as e:
        print("expected fail:", e)

print("end of test")
# Exit status 1 if any check above failed, 0 otherwise.
sys.exit(1 if test_failed else 0)
