forked from hhuhhu/trade
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrading_dates_mixin.py
More file actions
90 lines (74 loc) · 3.15 KB
/
trading_dates_mixin.py
File metadata and controls
90 lines (74 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -*- coding: utf-8 -*-
#
# Copyright 2017 Ricequant, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import pandas as pd
from utils.py2 import lru_cache
class TradingDatesMixin(object):
    """Trading-calendar queries backed by an index of trading dates.

    ``dates`` is searched with ``searchsorted``, so it is assumed to be an
    ascending pandas ``DatetimeIndex`` whose entries are midnight timestamps
    -- TODO confirm against the data source that supplies it.

    Every public method reduces its date argument to midnight (via
    ``_to_date``) before searching, so any time-of-day component of the
    argument is ignored.
    """

    def __init__(self, dates):
        self._dates = dates

    @staticmethod
    def _to_date(value):
        """Return ``value`` as a ``pd.Timestamp`` truncated to midnight.

        Unlike ``replace(hour=0, minute=0, second=0)``, ``normalize()`` also
        clears microseconds/nanoseconds. Otherwise a value such as
        ``09:30:00.000500`` would truncate to a timestamp that still compares
        greater than its own day's midnight entry in ``self._dates``.
        """
        return pd.Timestamp(value).normalize()

    def get_trading_dates(self, start_date, end_date):
        """Return the trading dates in the closed interval [start_date, end_date].

        Only the date part of each bound is used.
        """
        left = self._dates.searchsorted(self._to_date(start_date))
        right = self._dates.searchsorted(self._to_date(end_date), side='right')
        return self._dates[left:right]

    def get_previous_trading_date(self, date):
        """Return the last trading date strictly before ``date``.

        If no trading date precedes ``date``, the first calendar date is
        returned instead (the lookup is clamped rather than raising).
        """
        return self._get_previous_trading_date(self._to_date(date))

    @lru_cache(None)
    def _get_previous_trading_date(self, date):
        # Memoized lookup; ``date`` has already been normalized by the caller.
        # NOTE(review): maxsize=None makes this cache unbounded, and if
        # utils.py2.lru_cache mirrors functools.lru_cache the key includes
        # ``self``, keeping every instance alive -- confirm that is acceptable.
        pos = self._dates.searchsorted(date)
        if pos > 0:
            return self._dates[pos - 1]
        return self._dates[0]

    def get_next_trading_date(self, date):
        """Return the first trading date strictly after ``date``.

        Raises ``IndexError`` when ``date`` is on or after the last calendar
        date.
        """
        pos = self._dates.searchsorted(self._to_date(date), side='right')
        return self._dates[pos]

    def is_trading_date(self, date):
        """Return True if the date part of ``date`` is in the calendar."""
        date = self._to_date(date)
        pos = self._dates.searchsorted(date)
        return pos < len(self._dates) and self._dates[pos] == date

    @lru_cache(512)
    def _get_future_trading_date(self, dt):
        """Map calendar datetime ``dt`` to the trading date it belongs to.

        ``dt`` is shifted back 4 hours. The calendar date of the shifted value
        must be a trading date, otherwise ``RuntimeError`` is raised.

        If the shifted hour is >= 16 (i.e. ``dt`` falls between 20:00 and
        03:59), the next trading date after it is returned. Otherwise the
        shifted calendar date itself is returned. Presumably this attributes
        futures night sessions to the following trading day -- confirm with
        callers.
        """
        dt1 = dt - datetime.timedelta(hours=4)
        td = pd.Timestamp(dt1.date())
        pos = self._dates.searchsorted(td)
        # pos == len(self._dates) means td lies past the end of the calendar.
        # That is just as invalid as a non-trading date, so report it the
        # same way instead of letting the indexing raise a bare IndexError.
        if pos >= len(self._dates) or self._dates[pos] != td:
            raise RuntimeError('invalid future calendar datetime: {}'.format(dt))
        if dt1.hour >= 16:
            # NOTE: raises IndexError when ``td`` is the last calendar date.
            return self._dates[pos + 1]
        return td

    def get_trading_dt(self, calendar_dt):
        """Return ``calendar_dt``'s time of day placed on its trading date."""
        trading_date = self.get_future_trading_date(calendar_dt)
        return datetime.datetime.combine(trading_date, calendar_dt.time())

    def get_future_trading_date(self, dt):
        """Return the trading date for calendar datetime ``dt``.

        Only the date and hour of ``dt`` affect the result. Minutes, seconds
        and microseconds are therefore zeroed before the cached lookup, so
        that nearby datetimes share a cache entry.
        """
        return self._get_future_trading_date(
            dt.replace(minute=0, second=0, microsecond=0))

    def get_nth_previous_trading_date(self, date, n):
        """Return the ``n``-th trading date before ``date`` (for ``n >= 1``).

        Clamped to the first calendar date when fewer than ``n`` trading dates
        precede ``date``.
        """
        pos = self._dates.searchsorted(self._to_date(date))
        if pos >= n:
            return self._dates[pos - n]
        return self._dates[0]

    def get_n_trading_dates_until(self, dt, n):
        """Return up to ``n`` trading dates strictly before ``dt``'s date.

        ``searchsorted`` uses the default ``side='left'``, so the date of
        ``dt`` itself is excluded. Fewer than ``n`` dates are returned when the
        calendar does not reach back far enough.
        NOTE(review): confirm the exclusive upper bound is intended.
        """
        pos = self._dates.searchsorted(self._to_date(dt))
        if pos >= n:
            return self._dates[pos - n:pos]
        return self._dates[:pos]