# Source code for covid19_inference.plotting

import datetime

import numpy as np
import matplotlib
import matplotlib.pyplot as plt


def get_all_free_RVs_names(model):
    """
        Collects the names of all free parameters of the model.

        Parameters
        ----------
        model: pm.Model instance
            only its ``free_RVs`` attribute is read

        Returns
        -------
        : list of variable names
            one entry per element of ``model.free_RVs``: its string
            representation with every ``'_log__'`` substring removed
    """
    names = []
    for rv in model.free_RVs:
        # Strip the '_log__' marker so the plain variable name remains.
        names.append(str(rv).replace('_log__', ''))
    return names
def get_prior_distribution(model, x, varname):
    """
        Evaluates the prior density of one model variable at the points x.

        Parameters
        ----------
        model: pm.Model instance
            indexed with ``varname`` to look up the variable
        x: list or array
            points at which the prior is evaluated
        varname: string

        Returns
        -------
        : array
            exponential of the evaluated log-probability
            ``model[varname].distribution.logp(x)``
    """
    variable = model[varname]
    log_density = variable.distribution.logp(x)
    # logp returns a symbolic expression; .eval() turns it into numbers.
    return np.exp(log_density.eval())
def plot_hist(model, trace, ax, varname, colors=('tab:blue', 'tab:orange'), bins=50):
    """
        Plots one histogram of the prior and posterior distribution of the variable varname.

        The posterior is drawn as a normalised histogram of ``trace[varname]``;
        the prior is evaluated on 100 points spanning the x-limits of that
        histogram and drawn as a line on top. If the prior cannot be evaluated
        or drawn, a message is printed and only the posterior is shown.

        Parameters
        ----------
        model: pm.Model instance
        trace: trace of the model
        ax: matplotlib.axes instance
        varname: string
        colors: list with 2 colornames
            ``colors[0]`` is used for the prior, ``colors[1]`` for the posterior
        bins: number or array
            passed to ``ax.hist``

        Returns
        -------
        None
    """
    samples = trace[varname]
    if len(samples.shape) >= 2:
        # Only scalar variables can be shown as a single histogram.
        print('Dimension of {} larger than one, skipping'.format(varname))
        ax.set_visible(False)
        return

    ax.hist(samples, bins=bins, density=True, color=colors[1], label='Posterior')

    # Remember the limits chosen for the posterior so the prior line
    # cannot rescale the axis.
    limits = ax.get_xlim()
    x = np.linspace(*limits, num=100)
    try:
        ax.plot(x, get_prior_distribution(model, x, varname),
                label='Prior', color=colors[0], linewidth=3)
    except Exception as err:
        # Best effort: not every variable exposes an evaluable prior.
        # Unlike a bare ``except``, this lets KeyboardInterrupt/SystemExit
        # propagate and tells the user why the prior is missing.
        print('Could not plot prior of {}, showing posterior only ({!r})'.format(varname, err))
    ax.set_xlim(*limits)
    ax.set_ylabel('Density')
    ax.set_xlabel(varname)
def plot_cases(trace, new_cases_obs, date_begin_sim, diff_data_sim, start_date_plot=None, end_date_plot=None,
               ylim=None, week_interval=None, colors=('tab:blue', 'tab:orange'), country='Germany'):
    """
        Plots the new cases, the fit, forecast and lambda_t evolution

        Three panels of a 2x2 grid are used (the top-left one is hidden):
        bottom-left shows data and fit on a log scale, bottom-right shows
        data, fit and forecast on a linear scale, and top-right shows the
        effective growth rate ``lambda_t - mu``.

        Parameters
        ----------
        trace : trace returned by model
            the entries 'new_cases', 'lambda_t', 'mu' and 'delay' are read
        new_cases_obs : array
        date_begin_sim : datetime.datetime
        diff_data_sim : float
            Difference in days between the begin of the simulation and the data
        start_date_plot : datetime.datetime
            defaults to ``date_begin_sim`` plus ``diff_data_sim`` days
        end_date_plot : datetime.datetime
            defaults to ``date_begin_sim`` plus the length of 'lambda_t' in days
        ylim : float
            the maximal y value to be plotted in the forecast panel,
            defaults to 1.6 times the maximum of ``new_cases_obs``
        week_interval : int
            the interval in weeks of the major x ticks; if None it is
            chosen separately for the left and right panels
        colors : list with 2 colornames
            ``colors[0]`` is used for the data, ``colors[1]`` for the model
        country : string
            inserted into the y label of the forecast panel

        Returns
        -------
        figure, axes
    """
    def conv_time_to_mpl_dates(arr):
        # Day offsets relative to date_begin_sim -> matplotlib date numbers.
        return matplotlib.dates.date2num(
            [datetime.timedelta(days=float(date)) + date_begin_sim for date in arr])

    def set_weekly_date_ticks(axis, interval):
        # Major tick every `interval` weeks on Sunday, minor tick every day.
        axis.xaxis.set_major_locator(
            matplotlib.dates.WeekdayLocator(interval=interval, byweekday=matplotlib.dates.SU))
        axis.xaxis.set_minor_locator(matplotlib.dates.DayLocator())
        axis.xaxis.set_major_formatter(matplotlib.dates.DateFormatter('%m/%d'))

    new_cases_sim = trace['new_cases']
    len_sim = trace['lambda_t'].shape[1]

    if start_date_plot is None:
        start_date_plot = date_begin_sim + datetime.timedelta(days=diff_data_sim)
    if end_date_plot is None:
        end_date_plot = date_begin_sim + datetime.timedelta(days=len_sim)
    if ylim is None:
        ylim = 1.6 * np.max(new_cases_obs)

    num_days_data = len(new_cases_obs)
    diff_to_0 = num_days_data + diff_data_sim
    date_data_end = date_begin_sim + datetime.timedelta(days=diff_data_sim + num_days_data)
    num_days_future = (end_date_plot - date_data_end).days
    start_date_mpl, end_date_mpl = matplotlib.dates.date2num([start_date_plot, end_date_plot])

    if week_interval is None:
        week_inter_left = int(np.ceil(num_days_data / 7 / 5))
        week_inter_right = int(np.ceil((end_date_mpl - start_date_mpl) / 7 / 6))
    else:
        week_inter_left = week_interval
        week_inter_right = week_interval

    fig, axes = plt.subplots(2, 2, figsize=(9, 5),
                             gridspec_kw={'height_ratios': [1, 3],
                                          'width_ratios': [2, 3]})

    # Dates of the observations and the fit over that period; both bottom
    # panels show the same quantities, so they are computed once.
    time_data = np.arange(-num_days_data, 0)
    mpl_dates_data = conv_time_to_mpl_dates(time_data) + diff_data_sim + num_days_data
    new_cases_past = new_cases_sim[:, :num_days_data]
    past_median = np.median(new_cases_past, axis=0)
    past_low = np.percentile(new_cases_past, q=2.5, axis=0)
    past_high = np.percentile(new_cases_past, q=97.5, axis=0)

    # --- bottom left: data and fit, logarithmic scale ---
    ax = axes[1][0]
    ax.plot(mpl_dates_data, new_cases_obs, 'd', markersize=6,
            label='Data', zorder=5, color=colors[0])
    ax.plot(mpl_dates_data, past_median, color=colors[1], label='Fit (with 95% CI)')
    ax.fill_between(mpl_dates_data, past_low, past_high, alpha=0.3, color=colors[1])
    ax.set_yscale('log')
    ax.set_ylabel('Number of new cases')
    ax.set_xlabel('Date')
    ax.legend()
    set_weekly_date_ticks(ax, week_inter_left)
    ax.set_xlim(start_date_mpl)

    # --- bottom right: data, fit and forecast, linear scale ---
    ax = axes[1][1]
    ax.plot(mpl_dates_data, new_cases_obs, 'd', label='Data', markersize=4,
            color=colors[0], zorder=5)
    ax.plot(mpl_dates_data, past_median, '--', color=colors[1],
            linewidth=1.5, label='Fit with 95% CI')
    ax.fill_between(mpl_dates_data, past_low, past_high, alpha=0.2, color=colors[1])

    time_future = np.arange(0, num_days_future)
    mpl_dates_fut = conv_time_to_mpl_dates(time_future) + diff_data_sim + num_days_data
    # Transposed so that the last axis runs over the samples.
    cases_future = new_cases_sim[:, num_days_data:num_days_data + num_days_future].T
    ax.plot(mpl_dates_fut, np.median(cases_future, axis=-1), color=colors[1],
            linewidth=3, label='forecast with 75% and 95% CI')
    ax.fill_between(mpl_dates_fut,
                    np.percentile(cases_future, q=2.5, axis=-1),
                    np.percentile(cases_future, q=97.5, axis=-1),
                    alpha=0.1, color=colors[1])
    ax.fill_between(mpl_dates_fut,
                    np.percentile(cases_future, q=12.5, axis=-1),
                    np.percentile(cases_future, q=87.5, axis=-1),
                    alpha=0.2, color=colors[1])
    ax.set_xlabel('Date')
    ax.set_ylabel(f'New confirmed cases in {country}')
    ax.legend(loc='upper left')
    ax.set_ylim(0, ylim)
    # Backslash is escaped explicitly: "\," is an invalid escape sequence
    # in a normal string literal. The resulting text is unchanged.
    func_format = lambda num, _: "${:.0f}\\,$k".format(num / 1_000)
    ax.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(func_format))
    ax.set_xlim(start_date_mpl, end_date_mpl)
    set_weekly_date_ticks(ax, week_inter_right)

    # --- top right: effective growth rate lambda_t - mu ---
    ax = axes[0][1]
    time_sim = np.arange(-diff_to_0, -diff_to_0 + len_sim)
    mpl_dates_sim = conv_time_to_mpl_dates(time_sim) + diff_data_sim + num_days_data
    # mu gets a trailing axis so it broadcasts against lambda_t.
    growth = trace['lambda_t'][:, :] - trace['mu'][:, None]
    ax.plot(mpl_dates_sim, np.median(growth, axis=0), color=colors[1], linewidth=2)
    ax.fill_between(mpl_dates_sim,
                    np.percentile(growth, q=2.5, axis=0),
                    np.percentile(growth, q=97.5, axis=0),
                    alpha=0.15, color=colors[1])
    # Escaped backslash for the same reason as above ("\l" is invalid).
    ax.set_ylabel('effective\ngrowth rate $\\lambda_t^*$')

    # Freeze the automatic y-limits before adding the guide lines.
    ylims = ax.get_ylim()
    ax.hlines(0, start_date_mpl, end_date_mpl, linestyles=':')
    # Vertical marker: end of the data minus the 75th percentile of 'delay'.
    delay = matplotlib.dates.date2num(date_data_end) - np.percentile(trace['delay'], q=75)
    ax.vlines(delay, ylims[0], ylims[1], linestyles='-', colors=['tab:red'])
    ax.set_ylim(*ylims)
    # Scalar y position (np.diff would yield a 1-element array here).
    text_y = ylims[1] - 0.04 * (ylims[1] - ylims[0])
    ax.text(delay + 0.5, text_y, 'unconstrained because\nof reporting delay',
            color='tab:red', verticalalignment='top')
    ax.text(delay - 0.5, text_y, 'constrained\nby data',
            color='tab:red', horizontalalignment='right', verticalalignment='top')
    set_weekly_date_ticks(ax, week_inter_right)
    ax.set_xlim(start_date_mpl, end_date_mpl)

    axes[0][0].set_visible(False)
    plt.subplots_adjust(wspace=0.4, hspace=.3)
    return fig, axes