Skip to content

Commit 31c21fe

Browse files
committed
v0.1.17, replacing window smoothing by exponential smoothing.
1 parent 25a266a commit 31c21fe

5 files changed

Lines changed: 166 additions & 20 deletions

File tree

outputs/wandb_plot.pdf

6.9 KB
Binary file not shown.

plotify/smoothing.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
Adapted from cherry:
5+
https://github.com/learnables/cherry/blob/master/cherry/plot.py
6+
7+
"""
8+
9+
import numpy as np
10+
11+
12+
def _one_sided_smoothing(x_before, y_before, smoothing_temperature=1.0):
13+
"""
14+
[[Source]](https://github.com/seba-1511/cherry/blob/master/cherry/plot.py)
15+
**Decription**
16+
One side (regular) exponential moving average for smoothing a curve
17+
It evenly resamples points baesd on x-axis and then averages y values with
18+
weighting factor decreasing exponentially.
19+
**Arguments**
20+
* **x_before** (ndarray) - x values. Required to be in accending order.
21+
* **y_before** (ndarray) - y values. Required to have same size as x_before.
22+
* **smoothing_temperature** (float, *optional*, default=1.0) - the number of previous
23+
steps trusted. Used to calculate the decay factor.
24+
**Return**
25+
* **x_after** (ndarray) - x values after resampling.
26+
* **y_after** (ndarray) - y values after smoothing.
27+
* **y_count** (ndarray) - decay values at each steps.
28+
**Credit**
29+
Adapted from OpenAI's baselines implementation.
30+
**Example**
31+
~~~python
32+
from cherry.plot import _one_sided_smoothing as osmooth
33+
x_smoothed, y_smoothed, y_counts = osmooth(x_original,
34+
y_original,
35+
smoothing_temperature=1.0)
36+
~~~
37+
"""
38+
39+
if x_before is None:
40+
x_before = np.arange(len(y_before))
41+
42+
assert len(x_before) == len(y_before), \
43+
'x_before and y_before must have equal length.'
44+
assert all(x_before[i] <= x_before[i+1] for i in range(len(x_before)-1)), \
45+
'x_before needs to be sorted in ascending order.'
46+
47+
# Resampling
48+
size = len(x_before)
49+
x_after = np.linspace(x_before[0], x_before[-1], size)
50+
y_after = np.zeros(size, dtype=float)
51+
y_count = np.zeros(size, dtype=float)
52+
53+
# Weighting factor for data of previous steps
54+
alpha = np.exp(-1./smoothing_temperature)
55+
x_before_length = x_before[-1] - x_before[0]
56+
x_before_index = 0
57+
decay_period = x_before_length/(size-1)*smoothing_temperature
58+
59+
for i in range(len(x_after)):
60+
# Compute current EMA value based on the value of previous time step
61+
if(i != 0):
62+
y_after[i] = alpha * y_after[i-1]
63+
y_count[i] = alpha * y_count[i-1]
64+
65+
# Compute current EMA value by adding weighted average of old points
66+
# covered by the new point
67+
while x_before_index < size:
68+
if x_after[i] >= x_before[x_before_index]:
69+
difference = x_after[i] - x_before[x_before_index]
70+
# Weighting factor for y value of each old points
71+
beta = np.exp(-(difference/decay_period))
72+
y_after[i] += y_before[x_before_index] * beta
73+
y_count[i] += beta
74+
x_before_index += 1
75+
else:
76+
break
77+
78+
y_after = y_after/y_count
79+
return x_after, y_after, y_count
80+
81+
82+
def exponential_smoothing(x, y=None, temperature=1.0):
83+
"""
84+
[[Source]](https://github.com/seba-1511/cherry/blob/master/cherry/plot.py)
85+
**Decription**
86+
Two-sided exponential moving average for smoothing a curve.
87+
It performs regular exponential moving average twice from two different
88+
sides and then combines the results together.
89+
**Arguments**
90+
* **x** (ndarray/tensor/list) - x values, in accending order.
91+
* **y** (ndarray/tensor/list) - y values.
92+
* **temperature** (float, *optional*, default=1.0) - The higher,
93+
the smoother.
94+
**Return**
95+
* ndarray - x values after resampling.
96+
* ndarray - y values after smoothing.
97+
**Credit**
98+
Adapted from OpenAI's baselines implementation.
99+
**Example**
100+
~~~python
101+
from cherry.plot import exponential_smoothing
102+
x_smoothed, y_smoothed, _ = exponential_smoothing(x_original,
103+
y_original,
104+
temperature=3.)
105+
~~~
106+
"""
107+
108+
if y is None:
109+
y = x
110+
x = np.arange(0, len(y))
111+
112+
if isinstance(y, list):
113+
y = np.array(y)
114+
115+
if isinstance(x, list):
116+
x = np.array(x)
117+
118+
assert x.shape == y.shape
119+
assert len(x.shape) == 1
120+
x_after1, y_after1, y_count1 = _one_sided_smoothing(x,
121+
y,
122+
temperature)
123+
x_after2, y_after2, y_count2 = _one_sided_smoothing(-x[::-1],
124+
y[::-1],
125+
temperature)
126+
127+
y_after2 = y_after2[::-1]
128+
y_count2 = y_count2[::-1]
129+
130+
y_after = y_after1 * y_count1 + y_after2 * y_count2
131+
y_after /= (y_count1 + y_count2)
132+
return x_after1.tolist(), y_after.tolist()
133+
134+
135+
def smooth(x, y=None, temperature=1.0):
136+
# Not officially supported.
137+
result = exponential_smoothing(x=x, y=y, temperature=temperature)
138+
if y is None:
139+
return result[1]
140+
return result

plotify/wandb_plots.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import plotify as pl
66
import wandb
77

8+
from .smoothing import smooth
9+
810

911
def wandb_plot(config):
1012
"""
@@ -100,15 +102,13 @@ def wandb_plot(config):
100102
run_ys = run_ys[:cutoff]
101103

102104
# smooth each run
103-
if 'smooth_window' in result:
104-
"""
105-
From:
106-
https://stackoverflow.com/questions/11352047/finding-moving-average-from-data-points-in-python/34387987#34387987
107-
"""
108-
smooth_window = result.get('smooth_window')
109-
y_cumsum = np.cumsum(run_ys)
110-
run_ys = (y_cumsum[smooth_window:] - y_cumsum[:-smooth_window]) / smooth_window
111-
run_xs = run_xs[:-smooth_window]
105+
if 'smooth_temperature' in result:
106+
smooth_temperature = result.get('smooth_temperature')
107+
run_xs, run_ys = smooth(
108+
x=run_xs,
109+
y=run_ys,
110+
temperature=smooth_temperature,
111+
)
112112

113113
# average y values that have the same x values
114114
xs_increasing = np.diff(run_xs)
@@ -184,6 +184,7 @@ def wandb_plot(config):
184184
alpha=0.5,
185185
color=color,
186186
linewidth=0.0,
187+
step='mid',
187188
)
188189

189190
return plot

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
find_packages,
66
)
77

8-
VERSION = '0.1.16'
8+
VERSION = '0.1.17'
99

1010
install(
1111
name='plotify',

tests/plot_wandb.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,37 @@
2626
'arnolds/qmcrl/u3371i4t',
2727
'arnolds/qmcrl/k86a61gi',
2828
'arnolds/qmcrl/61inx3b9',
29-
'arnolds/qmcrl/pq7h1956',
30-
'arnolds/qmcrl/5lnjdhyy',
31-
'arnolds/qmcrl/3bktd5yv',
32-
'arnolds/qmcrl/yrpk7li0',
33-
'arnolds/qmcrl/xpouxt8w',
29+
# 'arnolds/qmcrl/pq7h1956',
30+
# 'arnolds/qmcrl/5lnjdhyy',
31+
# 'arnolds/qmcrl/3bktd5yv',
32+
# 'arnolds/qmcrl/yrpk7li0',
33+
# 'arnolds/qmcrl/xpouxt8w',
3434
],
3535
'x_key': 'iteration',
3636
'y_key': 'test/episode_returns',
3737
'label': 'MC',
3838
'color': pl.Maureen['blue'],
3939
'linewidth': 1.8,
40-
'smooth_window': 1,
40+
'smooth_temperature': 15.0,
4141
'markevery': 1000,
4242
'samples': 4196,
4343
'shade': 'std',
4444
},
4545
{
46-
'wandb_id': 'arnolds/qmcrl/xpouxt8w',
46+
'wandb_id': [
47+
'arnolds/qmcrl/xpouxt8w'
48+
'arnolds/qmcrl/61inx3b9',
49+
'arnolds/qmcrl/pq7h1956',
50+
],
4751
'x_key': 'iteration',
4852
'y_key': 'test/episode_returns',
4953
'label': 'RQMC',
5054
'color': pl.Maureen['orange'],
5155
'linewidth': 1.8,
52-
'smooth_window': 50,
53-
'samples': 10000,
54-
'markevery': 1000,
56+
'smooth_temperature': 5.0,
57+
'samples': 512,
58+
'markevery': 64,
59+
'shade': 'ci95',
5560
},
5661
],
5762
}

0 commit comments

Comments
 (0)