How to estimate the mean of a truncated dataset using Python?

Published: May 09, 2020

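Examples of how to estimate the mean of a truncated dataset with Python, for data generated from a normal distribution in which every value above a threshold has been replaced by the threshold itself (right-censored data):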

1 -- Create a dataset of random numbers from a normal distribution

Draw 100,000 random numbers from a normal distribution with mean `mu_0 = 2.0` and standard deviation `srd_0 = 4.0`:

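```python
import scipy.stats
import numpy as np
import matplotlib.pyplot as plt

# True parameters of the normal distribution
mu_0 = 2.0   # mean
srd_0 = 4.0  # standard deviation

# Standard normal samples, rescaled to N(mu_0, srd_0**2)
data = np.random.randn(100000)
data = data * srd_0 + mu_0

# Store the sample as a column vector
data = data.reshape(-1, 1)
```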
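As an optional variant (not what the code above uses), NumPy's newer `Generator` API can draw the same kind of sample with a fixed seed, which makes the numbers printed in this post reproducible from run to run. The seed value `42` is an arbitrary assumption:

```python
import numpy as np

mu_0 = 2.0   # true mean
srd_0 = 4.0  # true standard deviation

# Seeded generator: rerunning the script gives the same sample (seed 42 is arbitrary)
rng = np.random.default_rng(42)

# Same distribution as randn(...) * srd_0 + mu_0, stored as a column vector
data = rng.normal(loc=mu_0, scale=srd_0, size=100000).reshape(-1, 1)

print(data.shape)  # (100000, 1)
```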

2 -- Calculate the mean for a complete dataset

```python
# Normalized histogram of the complete dataset
# hx: bin heights (densities), hy: bin edges
hx, hy, _ = plt.hist(data, bins=50, density=True, color="lightblue")

plt.ylim(0.0, max(hx) + 0.05)
plt.title('Mean estimation from a censored dataset')
plt.grid()

plt.xlim(-4 * srd_0, 4 * srd_0)

plt.savefig("censored_data_01.png", bbox_inches='tight')
plt.show()
```

```python
print('Mean (Complete Data): ', np.mean(data))
print('Std (Complete Data): ', np.std(data))
```

returns

```
Mean (Complete Data):  2.0104953814107076
Std (Complete Data):  3.9865860565580946
```
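For complete (uncensored) normal data, the maximum-likelihood estimates of the mean and standard deviation are exactly the sample mean and the standard deviation with `ddof=0` (the default of `np.std`). A minimal sketch checking this with `scipy.stats.norm.fit`, which returns the maximum-likelihood `(loc, scale)`, on a freshly generated sample (the seed and the name `sample` are illustrative assumptions):

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(0)
sample = rng.normal(loc=2.0, scale=4.0, size=100000)

# Maximum-likelihood fit of a normal distribution
loc_hat, scale_hat = scipy.stats.norm.fit(sample)

# For the normal distribution the MLE coincides with the sample moments (ddof=0)
print(loc_hat, np.mean(sample))
print(scale_hat, np.std(sample))
```

This is why `np.mean` and `np.std` are the right estimators for the complete dataset, and why they stop being right once part of the data is censored.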

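3 -- Calculate the mean for an incomplete dataset

Censor the data at `max_x = 5`: every value above 5 is replaced by 5, so for those points we only know that the true value exceeds 5: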

```python
# Censoring threshold: every value above max_x is recorded as max_x
max_x = 5

data_trunc_2 = np.copy(data)
data_trunc_2[data_trunc_2 > max_x] = max_x

data_trunc_2 = data_trunc_2.reshape(-1, 1)

# Histogram of the censored data (note the spike at max_x)
hx, hy, _ = plt.hist(data_trunc_2, density=True, bins=50, color="lightblue")

plt.ylim(0.0, max(hx) + 0.05)
plt.title('Mean estimation from a censored dataset')
plt.grid()

plt.xlim(-4 * srd_0, 4 * srd_0)

plt.savefig("censored_data_02.png", bbox_inches='tight')
plt.show()
```
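Strictly speaking, replacing values above `max_x` by `max_x` is right-censoring: no point is dropped, but a fraction P(X > max_x) = 1 - Phi((max_x - mu_0) / srd_0) of the sample piles up at the threshold. Truncation would instead discard those points. A small sketch, assuming the same parameters as above and an arbitrary seed, comparing the observed censored fraction with `scipy.stats.norm.sf` (the survival function, 1 - CDF) and using `np.minimum` as an equivalent way to censor:

```python
import numpy as np
import scipy.stats

mu_0, srd_0, max_x = 2.0, 4.0, 5.0

rng = np.random.default_rng(1)
data = rng.normal(mu_0, srd_0, size=100000)

# Right-censoring: values above max_x are recorded as max_x (nothing is dropped)
data_cens = np.minimum(data, max_x)

observed_fraction = np.mean(data_cens == max_x)
expected_fraction = scipy.stats.norm.sf(max_x, loc=mu_0, scale=srd_0)

print(len(data_cens))                        # 100000: no points removed
print(observed_fraction, expected_fraction)  # both close to 0.227

# Truncation, by contrast, would discard the points above max_x
data_trunc = data[data <= max_x]
print(len(data_trunc))  # fewer than 100000
```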

```python
# Split the sample: exactly observed values vs. values censored at max_x
data_trunc_2_not_cens = data_trunc_2[data_trunc_2 < max_x]
data_trunc_2_cens = data_trunc_2[data_trunc_2 == max_x]

# Grid of candidate means
x = np.linspace(-10, 10, 1000, endpoint=True)

# Log-likelihood of each candidate mean; the standard deviation is assumed known (srd_0)
y = []
for i in x:
    # Uncensored points contribute their log-density
    p1 = scipy.stats.norm.logpdf(data_trunc_2_not_cens, i, srd_0).sum()
    # Censored points contribute log P(X > max_x) = log(1 - CDF)
    p2 = scipy.stats.norm.logsf(data_trunc_2_cens, i, srd_0).sum()
    y.append(p1 + p2)

plt.plot(x, y)
plt.xlabel('Candidate mean')
plt.ylabel('Log-likelihood')
plt.title('Mean estimation from a censored dataset')

plt.savefig("censored_data_03.png", bbox_inches='tight')
plt.show()

# Naive estimates that ignore the censoring
print('Mean (Censored Data): ', np.mean(data_trunc_2))
print('Std (Censored Data): ', np.std(data_trunc_2))
```

returns

```
Mean (Censored Data):  1.4878815869535522
Std (Censored Data):  3.2348435672740012
```
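The loop in the previous code block evaluates, for each candidate mean on the grid and with the standard deviation fixed at `srd_0`, the right-censored normal log-likelihood. Writing c for `max_x`, n_c for the number of censored points, and phi and Phi for the standard normal PDF and CDF, the uncensored points give the term `p1` and the censored points give the term `p2`:

$$
\ell(\mu) = \sum_{i:\,x_i < c} \log\left[\frac{1}{\sigma}\,\varphi\!\left(\frac{x_i-\mu}{\sigma}\right)\right] + n_c\,\log\left[1-\Phi\!\left(\frac{c-\mu}{\sigma}\right)\right],
\qquad \hat{\mu} = \arg\max_{\mu}\,\ell(\mu)
$$

The naive mean printed above (about 1.49 instead of 2.0) is biased low because each censored value is counted as c even though its true value is larger; the second term of the log-likelihood accounts for that missing information.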

```python
# Candidate mean with the highest log-likelihood
i_max = y.index(max(y))
print('Mean (max log likelihood): ', x[i_max])
```

returns

```
Mean (max log likelihood):  2.0120120120120113
```
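The grid search can only return one of its 1,000 grid points. As an alternative (a swapped-in technique, not the method used above), the same censored log-likelihood can be maximized numerically with `scipy.optimize.minimize_scalar`, and the standard deviation can be estimated jointly with the mean via `scipy.optimize.minimize` instead of being fixed at 4. A sketch under those assumptions; the helper `neg_log_lik` and the seed are hypothetical:

```python
import numpy as np
import scipy.optimize
import scipy.stats

mu_0, srd_0, max_x = 2.0, 4.0, 5.0

rng = np.random.default_rng(2)
data = np.minimum(rng.normal(mu_0, srd_0, size=100000), max_x)  # right-censored sample

not_cens = data[data < max_x]   # exactly observed values
n_cens = np.sum(data == max_x)  # number of values censored at max_x


def neg_log_lik(mu, sigma):
    """Negative right-censored normal log-likelihood (hypothetical helper)."""
    p1 = scipy.stats.norm.logpdf(not_cens, mu, sigma).sum()
    p2 = n_cens * scipy.stats.norm.logsf(max_x, mu, sigma)
    return -(p1 + p2)


# 1) sigma known, as in the grid search: bounded 1-D minimization over mu
res_mu = scipy.optimize.minimize_scalar(
    lambda mu: neg_log_lik(mu, srd_0), bounds=(-10, 10), method="bounded"
)
print('Mean (sigma known): ', res_mu.x)

# 2) sigma unknown: optimize over (mu, log sigma) so that sigma stays positive
res = scipy.optimize.minimize(
    lambda p: neg_log_lik(p[0], np.exp(p[1])),
    x0=[np.mean(data), np.log(np.std(data))],  # naive estimates as a starting point
    method="Nelder-Mead",
)
mu_hat, sigma_hat = res.x[0], np.exp(res.x[1])
print('Mean, Std (both estimated): ', mu_hat, sigma_hat)
```

Optimizing over log sigma rather than sigma is a common design choice that keeps the scale parameter positive without needing a constrained optimizer.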
