How to calculate and plot the derivative of a function using Matplotlib and Python?


To calculate the derivative of a function f at a given point x, one solution in Python is to use SciPy's derivative function, scipy.misc.derivative. Consider, for example, the function $f(x)=x^2$ at $x=2$.

>>> from scipy import misc 
>>> 
>>> def fonction(x): 
...     return x*x
...
>>> misc.derivative(fonction, 2.0) 
4.0
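A quick hand check of the 4.0 printed above, using the exact derivative and the central difference $(f(x+h)-f(x-h))/(2h)$ with $h=1$ (the default dx of misc.derivative). The central difference is exact for any quadratic, which is why the output has no approximation error here:

$$f'(x) = 2x \quad\Rightarrow\quad f'(2) = 4$$

$$\frac{f(2+1) - f(2-1)}{2 \cdot 1} = \frac{9 - 1}{2} = 4$$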

First, the misc module is imported with "from scipy import misc", then a simple function is defined that returns the value of f for a given x:

def fonction(x): 
    return x*x

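To get the value of the derivative of f at a given x, call misc.derivative(fonction, x). Its default step dx=1.0 is coarse: the result is exact for a quadratic such as x*x, but for other functions pass a smaller step, for example misc.derivative(fonction, x, dx=1e-6). Note that scipy.misc.derivative was deprecated in SciPy 1.10 and is not available in recent SciPy releases, so these examples require an older version.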
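For SciPy installations where misc.derivative is no longer available, the same three-point central difference is easy to write by hand. This is a minimal sketch, not a SciPy API: the helper name central_derivative and its default step h=1e-5 are assumptions chosen for illustration.

```python
import math


def central_derivative(f, x, h=1e-5):
    """Approximate f'(x) with the three-point central difference.

    Hypothetical helper (not part of SciPy). It uses the same formula as
    scipy.misc.derivative with n=1, order=3, but a smaller default step
    than SciPy's dx=1.0.
    """
    return (f(x + h) - f(x - h)) / (2.0 * h)


def fonction(x):
    return x * x


print(central_derivative(fonction, 2.0))   # close to 4.0
print(central_derivative(math.sin, 0.0))   # close to cos(0) = 1.0
```

Because the helper only does arithmetic on x, it also accepts a NumPy array when f is vectorized, so central_derivative(fonction, x) can stand in for misc.derivative(fonction, x) in the plotting example below. A step near 1e-5 is a rough compromise between truncation error, which shrinks like h², and floating-point cancellation, which grows like 1/h.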

This simple example can be extended to plot both a function and its derivative with matplotlib, here for $f(x)=3x^2+2x+1$:

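import numpy as np
import matplotlib.pyplot as plt
from scipy import misc  # misc.derivative is only available in older SciPy releases

def fonction(x):
    return 3*x*x + 2*x + 1

# Sample points from -2 to 2 with a step of 0.01
x = np.arange(-2.0, 2.0, 0.01)

fig, ax = plt.subplots()

# f(x) in red
y = fonction(x)
ax.plot(x, y, 'r-')

# f'(x) in blue; misc.derivative works elementwise on the array x
yp = misc.derivative(fonction, x)
ax.plot(x, yp, 'b-')

ax.grid(True)

# Make the axes cross at the origin and hide the top/right frame lines
ax.spines['left'].set_position('zero')
ax.spines['right'].set_color('none')
ax.spines['bottom'].set_position('zero')
ax.spines['top'].set_color('none')

ax.text(-0.75, 6.0,
        r'$f(x)=3x^2+2x+1$', horizontalalignment='center',
        fontsize=18, color='red')

ax.text(-1.0, -8.0,
        r"$f'(x)=6x+2$", horizontalalignment='center',
        fontsize=18, color='blue')

plt.show()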
Image: plot of f(x)=3x^2+2x+1 (red) and its derivative f'(x)=6x+2 (blue)
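When the function is only known through sampled values, such as the array y in the plotting example, NumPy's np.gradient can estimate the derivative directly from the samples without calling f again. This is a swapped-in alternative to misc.derivative, not what the code above uses; a minimal sketch on the same grid:

```python
import numpy as np


def fonction(x):
    return 3 * x * x + 2 * x + 1


# Same sample grid as the plotting example
x = np.arange(-2.0, 2.0, 0.01)
y = fonction(x)

# np.gradient differentiates sampled data: central differences in the
# interior, one-sided differences at the two ends. edge_order=2 makes the
# end points second-order accurate as well.
yp = np.gradient(y, x, edge_order=2)

# Compare with the exact derivative f'(x) = 6x + 2
max_err = np.max(np.abs(yp - (6 * x + 2)))
print(max_err)
```

Both the interior and the edge_order=2 end-point formulas are exact for a quadratic, so the printed error should only reflect floating-point rounding; for other functions the error depends on the sample spacing.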