-
Notifications
You must be signed in to change notification settings - Fork 661
/
ROCKET.py
100 lines (82 loc) · 4.25 KB
/
ROCKET.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/053_models.ROCKET.ipynb.
# %% auto 0
# Public names exported by this module (used by `from ... import *`).
__all__ = ['RocketClassifier', 'load_rocket', 'RocketRegressor']
# %% ../../nbs/053_models.ROCKET.ipynb 3
import sklearn
from sklearn.linear_model import RidgeClassifierCV, RidgeCV
from sklearn.metrics import make_scorer
from sklearn.preprocessing import StandardScaler
from ..data.external import *
from ..imports import *
from .layers import *
# %% ../../nbs/053_models.ROCKET.ipynb 4
class RocketClassifier(sklearn.pipeline.Pipeline):
    """Time series classification using ROCKET features and a linear classifier.

    Pipeline steps: sktime ``Rocket`` transform -> optional
    ``StandardScaler(with_mean=False)`` -> ``RidgeClassifierCV``.
    """

    def __init__(self, num_kernels=10_000, normalize_input=True, random_state=None,
                 alphas=np.logspace(-3, 3, 7), normalize_features=True, memory=None, verbose=False,
                 scoring=None, class_weight=None, **kwargs):
        """
        RocketClassifier is recommended for up to 10k time series.
        For a larger dataset, you can use ROCKET (in Pytorch).

        Args:
            num_kernels: int, number of random convolutional kernels (default 10,000).
            normalize_input: bool, whether to normalise each input time series per
                instance; passed to Rocket as ``normalise`` (default True).
            random_state: optional random seed passed to Rocket (default None).
            alphas: regularization strengths searched by RidgeClassifierCV. The default
                array is built once at definition time and shared by all instances;
                it is not mutated here.
            normalize_features: bool, if True a StandardScaler(with_mean=False) is
                inserted between Rocket and the classifier.
            memory, verbose: stored on the instance as sklearn Pipeline attributes.
            scoring: passed to RidgeClassifierCV. None uses RidgeClassifierCV's default
                scoring (originally documented here as accuracy; confirm for your
                sklearn version).
            class_weight: passed to RidgeClassifierCV.
            **kwargs: extra keyword arguments forwarded to RidgeClassifierCV.

        Raises:
            ImportError: if sktime is not installed.
        """
        try:
            from sktime.transformations.panel.rocket import Rocket
        except ImportError as e:
            # Raise a real exception; raising a bare string would itself fail with TypeError.
            raise ImportError("You need to install sktime to be able to use RocketClassifier") from e
        self.steps = [('rocket', Rocket(num_kernels=num_kernels, normalise=normalize_input, random_state=random_state))]
        if normalize_features:
            # Step name 'scalar' (sic) is kept for compatibility with `named_steps` lookups and saved models.
            self.steps += [('scalar', StandardScaler(with_mean=False))]
        self.steps += [('ridgeclassifiercv', RidgeClassifierCV(alphas=alphas, scoring=scoring, class_weight=class_weight, **kwargs))]
        # fastcore: store the named __init__ arguments as attributes (sklearn get_params relies on them).
        store_attr()
        self._validate_steps()

    def __repr__(self):
        return f'Pipeline(steps={self.steps.copy()})'

    def save(self, fname='Rocket', path='./models'):
        """Pickle this pipeline to ``<path>/<fname>.pkl``, creating the directory if needed.

        Args:
            fname: file name without the ``.pkl`` extension.
            path: target directory (created if missing).
        """
        filepath = Path(f'{Path(path) / fname}.pkl')
        # Without this, saving to the default './models' fails when that directory does not exist.
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'wb') as output:
            pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)
# %% ../../nbs/053_models.ROCKET.ipynb 5
def load_rocket(fname='Rocket', path='./models'):
    """Load a pickled model from ``<path>/<fname>.pkl``.

    This is the counterpart of the ``save`` methods of RocketClassifier / RocketRegressor.

    Args:
        fname: file name without the ``.pkl`` extension.
        path: directory containing the file.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: if the file does not exist.

    Warning:
        Uses ``pickle.load``: only load files from trusted sources, because
        unpickling can execute arbitrary code.
    """
    import pickle
    from pathlib import Path
    filename = Path(path) / fname
    # Handle named `f` rather than `input`, which shadowed the builtin.
    with open(f'{filename}.pkl', 'rb') as f:
        return pickle.load(f)
# %% ../../nbs/053_models.ROCKET.ipynb 6
class RocketRegressor(sklearn.pipeline.Pipeline):
    """Time series regression using ROCKET features and a linear regressor.

    Pipeline steps: sktime ``Rocket`` transform -> optional
    ``StandardScaler(with_mean=False)`` -> ``RidgeCV``.
    """

    def __init__(self, num_kernels=10_000, normalize_input=True, random_state=None,
                 alphas=np.logspace(-3, 3, 7), normalize_features=True, memory=None, verbose=False,
                 scoring=None, **kwargs):
        """
        RocketRegressor is recommended for up to 10k time series.
        For a larger dataset, you can use ROCKET (in Pytorch).

        Args:
            num_kernels: int, number of random convolutional kernels (default 10,000).
            normalize_input: bool, whether to normalise each input time series per
                instance; passed to Rocket as ``normalise`` (default True).
            random_state: optional random seed passed to Rocket (default None).
            alphas: regularization strengths searched by RidgeCV. The default array is
                built once at definition time and shared by all instances; it is not
                mutated here.
            normalize_features: bool, if True a StandardScaler(with_mean=False) is
                inserted between Rocket and the regressor.
            memory, verbose: stored on the instance as sklearn Pipeline attributes.
            scoring: passed to RidgeCV. None uses RidgeCV's default scoring; per the
                sklearn docs this is negative MSE under the default leave-one-out CV,
                and r2 when an explicit ``cv`` is passed via kwargs.
            **kwargs: extra keyword arguments forwarded to RidgeCV.

        Raises:
            ImportError: if sktime is not installed.
        """
        try:
            from sktime.transformations.panel.rocket import Rocket
        except ImportError as e:
            # Raise a real exception; raising a bare string would itself fail with TypeError.
            raise ImportError("You need to install sktime to be able to use RocketRegressor") from e
        self.steps = [('rocket', Rocket(num_kernels=num_kernels, normalise=normalize_input, random_state=random_state))]
        if normalize_features:
            # Step name 'scalar' (sic) is kept for compatibility with `named_steps` lookups and saved models.
            self.steps += [('scalar', StandardScaler(with_mean=False))]
        self.steps += [('ridgecv', RidgeCV(alphas=alphas, scoring=scoring, **kwargs))]
        # fastcore: store the named __init__ arguments as attributes (sklearn get_params relies on them).
        store_attr()
        self._validate_steps()

    def __repr__(self):
        return f'Pipeline(steps={self.steps.copy()})'

    def save(self, fname='Rocket', path='./models'):
        """Pickle this pipeline to ``<path>/<fname>.pkl``, creating the directory if needed.

        Args:
            fname: file name without the ``.pkl`` extension.
            path: target directory (created if missing).
        """
        filepath = Path(f'{Path(path) / fname}.pkl')
        # Without this, saving to the default './models' fails when that directory does not exist.
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, 'wb') as output:
            pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)