Numba Examples

Before trying these examples, update Numba to the latest release with `conda update numba`.
A Simple Function
Suppose we want to write an image-processing function in Python: a 2-D convolution of an image with a small filter kernel. Here is how it might look:
import numpy
def filter2d(image, filt):
    """Convolve a 2-D ``image`` with a 2-D kernel ``filt`` using explicit loops.

    The kernel is flipped on both axes (``filt[Mf-1-ii, Nf-1-jj]``), so this is
    a true convolution rather than a correlation. Only output positions where
    the whole kernel window lies inside the image are computed; all other
    (border) positions are left at zero. If the kernel is larger than the image,
    the result is all zeros.

    Parameters
    ----------
    image : 2-D array
        Input image.
    filt : 2-D array
        Convolution kernel; may have odd or even sizes on either axis.

    Returns
    -------
    result : array
        Same shape and dtype as ``image`` (``numpy.zeros_like``); the float
        accumulator is cast to that dtype on assignment.
    """
    M, N = image.shape
    Mf, Nf = filt.shape
    # Offset from the window's top-left corner to the output pixel it writes.
    Mf2 = Mf // 2
    Nf2 = Nf // 2
    # Number of window rows/cols lying after that pixel. This equals Mf2 for odd
    # sizes but Mf2 - 1 for even sizes. Bounding the loop with Mf2 on both sides
    # would skip the last row/column where an even-sized kernel still fits.
    Mtail = Mf - 1 - Mf2
    Ntail = Nf - 1 - Nf2
    result = numpy.zeros_like(image)
    for i in range(Mf2, M - Mtail):
        for j in range(Nf2, N - Ntail):
            num = 0.0
            for ii in range(Mf):
                for jj in range(Nf):
                    num += (filt[Mf - 1 - ii, Nf - 1 - jj]
                            * image[i - Mf2 + ii, j - Nf2 + jj])
            result[i, j] = num
    return result
# Four nested pure-Python loops like those in filter2d run slowly in the
# interpreter. Numba can compile the function through LLVM down to native
# machine code. The explicit signature below says it takes two 2-D arrays of
# doubles and returns one.
from numba import double, jit

fastfilter_2d = jit(
    double[:, :](double[:, :], double[:, :])
)(filter2d)

# The compiled function is called exactly like the Python one, but runs at
# roughly the speed of a hand-written C version wrapped for Python.
image, filt = numpy.random.random((100, 100)), numpy.random.random((10, 10))
res = fastfilter_2d(image, filt)
When Numba compiles `filter2d`, it actually produces two functions. The first is the low-level compiled version of `filter2d`. The second is a Python wrapper around that low-level function, which allows it to be called from Python. Other Numba-compiled functions can call the low-level version directly, eliminating the Python function-call overhead entirely.
Objects
# -*- coding: utf-8 -*-
from __future__ import print_function, division, absolute_import
from numba import jit
class MyClass(object):
    """Plain Python class whose method is invoked from a Numba-jitted function."""

    def mymethod(self, arg):
        """Return ``arg * 2``: doubles a number, repeats a string or list."""
        doubled = arg * 2
        return doubled
@jit(forceobj=True)
def call_method(obj):
    """Call methods of an arbitrary Python object from a Numba-jitted function.

    ``obj`` is an ordinary Python instance (a ``MyClass`` below), which Numba's
    nopython mode cannot type. Plain ``@jit`` defaults to nopython mode in
    recent Numba releases (0.59+), where this function would fail to compile on
    its first call. ``forceobj=True`` compiles it in object mode instead.
    """
    # In object mode every value here, including the float returned by the
    # second call and the product below, is handled as a Python object.
    print(obj.mymethod("hello"))
    mydouble = obj.mymethod(10.2)
    print(mydouble * 2)


call_method(MyClass())
UFuncs
from numba import vectorize
from numba import autojit, double, jit
import math
import numpy as np
@vectorize(["f8(f8)", "f4(f4)"])
def sinc(x):
    """Normalized sinc ufunc: sin(pi*x) / (pi*x), with sinc(0) defined as 1."""
    if x != 0:
        scaled = x * math.pi
        return math.sin(scaled) / scaled
    # Removable singularity at zero.
    return 1.0
@vectorize([
    "int8(int8,int8)",
    "int16(int16,int16)",
    "int32(int32,int32)",
    "int64(int64,int64)",
    "f4(f4,f4)",
    "f8(f8,f8)",
])
def add(x, y):
    """Element-wise ``x + y`` ufunc with one compiled loop per listed signature."""
    total = x + y
    return total
@vectorize(["f8(f8)", "f4(f4)"])
def logit(x):
    """Log-odds ufunc: log(x / (1 - x)), the inverse of the logistic sigmoid."""
    odds = x / (1 - x)
    return math.log(odds)
@vectorize(['f8(f8)','f4(f4)'])
def expit(x):
    """Logistic sigmoid ufunc 1 / (1 + exp(-x)), evaluated without overflow.

    Each branch only calls ``exp`` on a non-positive argument, so the
    exponential lies in (0, 1] and never overflows. Evaluating
    exp(x) / (1 + exp(x)) for large positive x would overflow to inf / inf
    and yield nan instead of 1.
    """
    if x > 0:
        return 1 / (1 + math.exp(-x))
    else:
        z = math.exp(x)
        return z / (1 + z)
@jit('f8(f8,f8[:])')
def polevl(x, coef):
    """Evaluate a polynomial at ``x`` with Horner's rule.

    ``coef`` lists coefficients from the highest power down to the constant:
    coef[0]*x**(n-1) + ... + coef[n-1], where n = len(coef). ``coef`` must be
    non-empty, since coef[0] is read unconditionally.
    """
    acc = coef[0]
    for k in range(1, len(coef)):
        acc = acc * x + coef[k]
    return acc
@jit('f8(f8,f8[:])')
def p1evl(x, coef):
    """Evaluate a monic polynomial (implied leading coefficient 1) at ``x``.

    Computes x**n + coef[0]*x**(n-1) + ... + coef[n-1], where n = len(coef),
    with Horner's rule. ``coef`` must be non-empty, since coef[0] is read
    unconditionally.
    """
    # Horner's first step with the implicit leading 1: 1*x + coef[0].
    acc = x + coef[0]
    for k in range(1, len(coef)):
        acc = acc * x + coef[k]
    return acc
PP