forked from amit-sharma/causal-inference-tutorial
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotter.py
More file actions
51 lines (40 loc) · 1.87 KB
/
Copy pathplotter.py
File metadata and controls
51 lines (40 loc) · 1.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import matplotlib.pyplot as plt
from datetime import datetime
# Font sizes, in points, shared by every figure produced by this module.
SMALL_SIZE = 8
MEDIUM_SIZE = 26
BIGGER_SIZE = 30

# Install the sizes as global matplotlib defaults. This runs at import time,
# so it affects every figure created in the process afterwards.
plt.rc('font', size=SMALL_SIZE)  # default text size
plt.rc('axes', titlesize=BIGGER_SIZE, labelsize=MEDIUM_SIZE)  # axes title, x/y labels
plt.rc(('xtick', 'ytick'), labelsize=MEDIUM_SIZE)  # tick labels on both axes
plt.rc('legend', fontsize=MEDIUM_SIZE)  # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)  # figure suptitle
def plot_treatment_outcome(treatment, outcome, time_var):
    """Plot treatment and outcome against time and save the figure as a PNG.

    Both series share one set of axes. Treatment is drawn as circles in the
    default colour and outcome as red triangles, each with a legend entry.
    The 8x6-inch figure is saved to the current working directory as
    ``poster_obs_data<HH-MM-SS>.png``. The figure is left open, so
    interactive/notebook backends can still display it.

    Args:
        treatment: Treatment values, one per entry of ``time_var``.
        outcome: Outcome values, one per entry of ``time_var``.
        time_var: X-axis (time) values.
    """
    fig, ax = plt.subplots()
    # Label both series so the legend below has entries to show.
    ax.plot(time_var, treatment, 'o', label="Treatment")
    ax.plot(time_var, outcome, 'r^', label="Outcome")
    ax.legend(loc="upper left", bbox_to_anchor=(0.4, 1))
    ax.set_xlabel("Time")
    fig.set_size_inches(8, 6)
    # %S is seconds. Lowercase %s is a non-portable glibc extension
    # (seconds since the epoch) and raises ValueError on Windows.
    timestamp = datetime.now().strftime("%H-%M-%S")
    fig.savefig("poster_obs_data" + timestamp + ".png", bbox_inches="tight")
def plot_causal_effect(estimate, treatment, outcome):
    """Plot observed data with the estimated causal line and save it as a PNG.

    The observed ``(treatment, outcome)`` pairs are drawn as a gray scatter.
    A black line starts at ``(0, estimate["intercept"])`` with slope
    ``estimate["value"]`` and runs to ``x = max(treatment)``. The slope,
    rounded to two decimals, is shown in a boxed label in the lower-right
    corner of the axes. The 8x6-inch figure is saved to the current working
    directory as ``poster_effect<HH-MM-SS>.png``.

    Args:
        estimate: Mapping with at least the keys ``"intercept"`` and
            ``"value"`` (the slope of the causal line).
        treatment: Treatment values (x axis).
        outcome: Outcome values (y axis).
    """
    fig, ax = plt.subplots()
    slope = estimate["value"]
    x_min = 0
    x_max = max(treatment)
    y_min = estimate["intercept"]
    y_max = y_min + slope * (x_max - x_min)

    ax.scatter(treatment, outcome, c="gray", marker="o", label="Observed data")
    ax.plot([x_min, x_max], [y_min, y_max], c="black", ls="solid", lw=4,
            label="Causal variation")
    ax.set_ylim(0, max(outcome))
    ax.set_xlim(x_min, x_max)

    bbox_props = dict(boxstyle="round", fc="w", ec="0.5", alpha=0.9)
    # Place the annotation in axes-fraction coordinates, not data coordinates,
    # so it stays in the lower-right corner whatever the data range is.
    ax.text(0.98, 0.02,
            r"DoWhy estimate $\rho$ (slope) = " + str(round(slope, 2)),
            transform=ax.transAxes, ha="right", va="bottom", size=20,
            bbox=bbox_props)
    ax.legend(loc="upper left")
    ax.set_xlabel("Treatment")
    ax.set_ylabel("Outcome")
    fig.set_size_inches(8, 6)
    # %S is seconds. Lowercase %s is a non-portable glibc extension
    # (seconds since the epoch) and raises ValueError on Windows.
    timestamp = datetime.now().strftime("%H-%M-%S")
    fig.savefig("poster_effect" + timestamp + ".png", bbox_inches='tight')