/rqalpha/data/base_data_source/data_source.py
# -*- coding: utf-8 -*-
# Copyright 2020 Shenzhen Ricequant Technology Co., Ltd. ("Ricequant")
#
# This software may not be used except in compliance with the current license.
#
#     * Non-commercial use (use by an individual for non-commercial purposes, or use by a non-profit
#       institution such as a university or research institute for education, research or similar purposes):
#         governed by the Apache License 2.0 (the "Apache 2.0 License");
#         you may obtain a copy of the Apache 2.0 License at http://www.apache.org/licenses/LICENSE-2.0.
#         Unless required by law or agreed to in writing, this software is distributed under the current
#         license on an "AS IS" basis, without conditions of any kind.
#
#     * Commercial use (use by an individual for any commercial purpose, or use by a legal person or other
#       organization for any purpose):
#         Without authorization from Ricequant, no individual may use this software for any commercial
#         purpose (including but not limited to providing, selling, renting, lending or transferring the
#         software, derivatives of the software, or products or services that reference or draw on the
#         software's functionality or source code to third parties), and no legal person or other
#         organization may use this software for any purpose; otherwise Ricequant has the right to pursue
#         liability for the corresponding intellectual-property infringement.
#         Subject to the above, use of this software must also comply with the Apache 2.0 License; where
#         the Apache 2.0 License conflicts with this license, this license prevails.
#         For details of the licensing process, contact public@ricequant.com.
import os
import pickle
from datetime import date, datetime, timedelta
from itertools import groupby
from typing import Dict, Iterable, List, Optional, Sequence, Union

import numpy as np
import pandas as pd
import six

from rqalpha.utils.i18n import gettext as _
from rqalpha.const import INSTRUMENT_TYPE, TRADING_CALENDAR_TYPE
from rqalpha.interface import AbstractDataSource
from rqalpha.model.instrument import Instrument
from rqalpha.utils.datetime_func import (convert_date_to_int, convert_int_to_date, convert_int_to_datetime)
from rqalpha.utils.exception import RQInvalidArgument
from rqalpha.utils.functools import lru_cache
from rqalpha.utils.typing import DateLike
from rqalpha.environment import Environment
from rqalpha.data.base_data_source.adjust import FIELDS_REQUIRE_ADJUSTMENT, adjust_bars
from rqalpha.data.base_data_source.storage_interface import (AbstractCalendarStore, AbstractDateSet,
                                                             AbstractDayBarStore, AbstractDividendStore,
                                                             AbstractInstrumentStore)
from rqalpha.data.base_data_source.storages import (DateSet, DayBarStore, DividendStore,
                                                    ExchangeTradingCalendarStore, FutureDayBarStore,
                                                    FutureInfoStore, InstrumentStore,
                                                    ShareTransformationStore, SimpleFactorStore,
                                                    YieldCurveStore, FuturesTradingParameters)


BAR_RESAMPLE_FIELD_METHODS = {
    "open": "first",
    "close": "last",
    "iopv": "last",
    "high": "max",
    "low": "min",
    "total_turnover": "sum",
    "volume": "sum",
    "num_trades": "sum",
    "acc_net_value": "last",
    "unit_net_value": "last",
    "discount_rate": "last",
    "settlement": "last",
    "prev_settlement": "last",
    "open_interest": "last",
    "basis_spread": "last",
    "contract_multiplier": "last",
    "strike_price": "last",
}
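# For illustration: ``resample_week_bars`` below picks an aggregation method per requested field
# from this mapping and hands the result to pandas, roughly (the field list here is just an example):
#
#     hows = {"open": "first", "high": "max", "low": "min", "close": "last", "volume": "sum"}
#     df_bars.resample("W-Fri").agg(hows)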


class BaseDataSource(AbstractDataSource):
    DEFAULT_INS_TYPES = (
        INSTRUMENT_TYPE.CS, INSTRUMENT_TYPE.FUTURE, INSTRUMENT_TYPE.ETF, INSTRUMENT_TYPE.LOF, INSTRUMENT_TYPE.INDX,
        INSTRUMENT_TYPE.PUBLIC_FUND, INSTRUMENT_TYPE.REITs
    )

    def __init__(self, path: str, custom_future_info: dict, *args, **kwargs) -> None:
        if not os.path.exists(path):
            raise RuntimeError('bundle path {} not exist'.format(os.path.abspath(path)))

        def _p(name):
            return os.path.join(path, name)

        funds_day_bar_store = DayBarStore(_p('funds.h5'))
        self._day_bars = {
            INSTRUMENT_TYPE.CS: DayBarStore(_p('stocks.h5')),
            INSTRUMENT_TYPE.INDX: DayBarStore(_p('indexes.h5')),
            INSTRUMENT_TYPE.FUTURE: FutureDayBarStore(_p('futures.h5')),
            INSTRUMENT_TYPE.ETF: funds_day_bar_store,
            INSTRUMENT_TYPE.LOF: funds_day_bar_store,
            INSTRUMENT_TYPE.REITs: funds_day_bar_store
        }  # type: Dict[INSTRUMENT_TYPE, AbstractDayBarStore]

        self._future_info_store = FutureInfoStore(_p("future_info.json"), custom_future_info)

        self._instruments_stores = {}  # type: Dict[INSTRUMENT_TYPE, AbstractInstrumentStore]
        self._ins_id_or_sym_type_map = {}  # type: Dict[str, INSTRUMENT_TYPE]
        instruments = []

        env = Environment.get_instance()
        with open(_p('instruments.pk'), 'rb') as f:
            for i in pickle.load(f):
                if i["type"] == "Future" and Instrument.is_future_continuous_contract(i["order_book_id"]):
                    i["listed_date"] = datetime(1990, 1, 1)
                instruments.append(Instrument(
                    i,
                    lambda i: self._future_info_store.get_tick_size(i),
                ))
        for ins_type in self.DEFAULT_INS_TYPES:
            self.register_instruments_store(InstrumentStore(instruments, ins_type))
        dividend_store = DividendStore(_p('dividends.h5'))
        self._dividends = {
            INSTRUMENT_TYPE.CS: dividend_store,
            INSTRUMENT_TYPE.ETF: dividend_store,
            INSTRUMENT_TYPE.LOF: dividend_store,
        }

        self._calendar_providers = {
            TRADING_CALENDAR_TYPE.EXCHANGE: ExchangeTradingCalendarStore(_p("trading_dates.npy"))
        }

        self._yield_curve = YieldCurveStore(_p('yield_curve.h5'))

        split_store = SimpleFactorStore(_p('split_factor.h5'))
        self._split_factors = {
            INSTRUMENT_TYPE.CS: split_store,
            INSTRUMENT_TYPE.ETF: split_store,
            INSTRUMENT_TYPE.LOF: split_store,
        }
        self._ex_cum_factor = SimpleFactorStore(_p('ex_cum_factor.h5'))
        self._share_transformation = ShareTransformationStore(_p('share_transformation.json'))

        self._suspend_days = [DateSet(_p('suspended_days.h5'))]  # type: List[AbstractDateSet]
        self._st_stock_days = DateSet(_p('st_stock_days.h5'))

    def register_day_bar_store(self, instrument_type, store):
        # type: (INSTRUMENT_TYPE, AbstractDayBarStore) -> None
        self._day_bars[instrument_type] = store

    def register_instruments_store(self, instruments_store):
        # type: (AbstractInstrumentStore) -> None
        instrument_type = instruments_store.instrument_type
        for id_or_sym in instruments_store.all_id_and_syms:
            self._ins_id_or_sym_type_map[id_or_sym] = instrument_type
        self._instruments_stores[instrument_type] = instruments_store

    def register_dividend_store(self, instrument_type, dividend_store):
        # type: (INSTRUMENT_TYPE, AbstractDividendStore) -> None
        self._dividends[instrument_type] = dividend_store

    def register_split_store(self, instrument_type, split_store):
        self._split_factors[instrument_type] = split_store

    def register_calendar_store(self, calendar_type, calendar_store):
        # type: (TRADING_CALENDAR_TYPE, AbstractCalendarStore) -> None
        self._calendar_providers[calendar_type] = calendar_store

    def append_suspend_date_set(self, date_set):
        # type: (AbstractDateSet) -> None
        self._suspend_days.append(date_set)

    @lru_cache(2048)
    def get_dividend(self, instrument):
        try:
            dividend_store = self._dividends[instrument.type]
        except KeyError:
            return None

        return dividend_store.get_dividend(instrument.order_book_id)

    def get_trading_minutes_for(self, order_book_id, trading_dt):
        raise NotImplementedError

    def get_trading_calendars(self):
        # type: () -> Dict[TRADING_CALENDAR_TYPE, pd.DatetimeIndex]
        return {t: store.get_trading_calendar() for t, store in self._calendar_providers.items()}

    def get_instruments(self, id_or_syms=None, types=None):
        # type: (Optional[Iterable[str]], Optional[Iterable[INSTRUMENT_TYPE]]) -> Iterable[Instrument]
        if id_or_syms is not None:
            ins_type_getter = lambda i: self._ins_id_or_sym_type_map.get(i)
            type_id_iter = groupby(sorted(id_or_syms, key=ins_type_getter), key=ins_type_getter)
        else:
            type_id_iter = ((t, None) for t in types or self._instruments_stores.keys())
        for ins_type, id_or_syms in type_id_iter:
            if ins_type is not None and ins_type in self._instruments_stores:
                yield from self._instruments_stores[ins_type].get_instruments(id_or_syms)

    def get_share_transformation(self, order_book_id):
        return self._share_transformation.get_share_transformation(order_book_id)

    def is_suspended(self, order_book_id, dates):
        # type: (str, Sequence[DateLike]) -> List[bool]
        for date_set in self._suspend_days:
            result = date_set.contains(order_book_id, dates)
            if result is not None:
                return result
        else:
            return [False] * len(dates)

    def is_st_stock(self, order_book_id, dates):
        result = self._st_stock_days.contains(order_book_id, dates)
        return result if result is not None else [False] * len(dates)

    @lru_cache(None)
    def _all_day_bars_of(self, instrument):
        return self._day_bars[instrument.type].get_bars(instrument.order_book_id)

    @lru_cache(None)
    def _filtered_day_bars(self, instrument):
        bars = self._all_day_bars_of(instrument)
        return bars[bars['volume'] > 0]

    def get_bar(self, instrument, dt, frequency):
        # type: (Instrument, Union[datetime, date], str) -> Optional[np.ndarray]
        if frequency != '1d':
            raise NotImplementedError

        bars = self._all_day_bars_of(instrument)
        if len(bars) <= 0:
            return
        dt = np.uint64(convert_date_to_int(dt))
        pos = bars['datetime'].searchsorted(dt)
        if pos >= len(bars) or bars['datetime'][pos] != dt:
            return None

        return bars[pos]

    OPEN_AUCTION_BAR_FIELDS = ["datetime", "open", "limit_up", "limit_down", "volume", "total_turnover"]

    def get_open_auction_bar(self, instrument, dt):
        # type: (Instrument, Union[datetime, date]) -> Dict
        day_bar = self.get_bar(instrument, dt, "1d")
        if day_bar is None:
            bar = dict.fromkeys(self.OPEN_AUCTION_BAR_FIELDS, np.nan)
        else:
            bar = {k: day_bar[k] if k in day_bar.dtype.names else np.nan for k in self.OPEN_AUCTION_BAR_FIELDS}
        bar["last"] = bar["open"]
        return bar

    def get_settle_price(self, instrument, date):
        bar = self.get_bar(instrument, date, '1d')
        if bar is None:
            return np.nan
        return bar['settlement']

    @staticmethod
    def _are_fields_valid(fields, valid_fields):
        if fields is None:
            return True
        if isinstance(fields, six.string_types):
            return fields in valid_fields
        for field in fields:
            if field not in valid_fields:
                return False
        return True

    def get_ex_cum_factor(self, order_book_id):
        return self._ex_cum_factor.get_factors(order_book_id)

    def _update_weekly_trading_date_index(self, idx):
        env = Environment.get_instance()
        if env.data_proxy.is_trading_date(idx):
            return idx
        return env.data_proxy.get_previous_trading_date(idx)

    def resample_week_bars(self, bars, bar_count, fields):
        df_bars = pd.DataFrame(bars)
        df_bars['datetime'] = df_bars.apply(lambda x: convert_int_to_datetime(x['datetime']), axis=1)
        df_bars = df_bars.set_index('datetime')
        need_fields = fields
        if isinstance(need_fields, str):
            need_fields = [need_fields]
        hows = {field: BAR_RESAMPLE_FIELD_METHODS[field] for field in need_fields if field in BAR_RESAMPLE_FIELD_METHODS}
        df_bars = df_bars.resample('W-Fri').agg(hows)
        df_bars.index = df_bars.index.map(self._update_weekly_trading_date_index)
        df_bars = df_bars[~df_bars.index.duplicated(keep='first')]
        df_bars.sort_index(inplace=True)
        df_bars = df_bars[-bar_count:]
        df_bars = df_bars.reset_index()
        df_bars['datetime'] = df_bars.apply(lambda x: np.uint64(convert_date_to_int(x['datetime'].date())), axis=1)
        df_bars = df_bars.set_index('datetime')
        bars = df_bars.to_records()
        return bars

    def history_bars(self, instrument, bar_count, frequency, fields, dt,
                     skip_suspended=True, include_now=False,
                     adjust_type='pre', adjust_orig=None):

        if frequency != '1d' and frequency != '1w':
            raise NotImplementedError

        if skip_suspended and instrument.type == 'CS':
            bars = self._filtered_day_bars(instrument)
        else:
            bars = self._all_day_bars_of(instrument)

        if not self._are_fields_valid(fields, bars.dtype.names):
            raise RQInvalidArgument("invalid fields: {}".format(fields))

        if len(bars) <= 0:
            return bars

        if frequency == '1w':
            if include_now:
                dt = np.uint64(convert_date_to_int(dt))
                i = bars['datetime'].searchsorted(dt, side='right')
            else:
                monday = dt - timedelta(days=dt.weekday())
                monday = np.uint64(convert_date_to_int(monday))
                i = bars['datetime'].searchsorted(monday, side='left')

            left = i - bar_count * 5 if i >= bar_count * 5 else 0
            bars = bars[left:i]

            if adjust_type == 'none' or instrument.type in {'Future', 'INDX'}:
                # futures and index bars need no price adjustment
                week_bars = self.resample_week_bars(bars, bar_count, fields)
                return week_bars if fields is None else week_bars[fields]

            if isinstance(fields, str) and fields not in FIELDS_REQUIRE_ADJUSTMENT:
                week_bars = self.resample_week_bars(bars, bar_count, fields)
                return week_bars if fields is None else week_bars[fields]

            adjust_bars_date = adjust_bars(bars, self.get_ex_cum_factor(instrument.order_book_id),
                                           fields, adjust_type, adjust_orig)
            adjust_week_bars = self.resample_week_bars(adjust_bars_date, bar_count, fields)
            return adjust_week_bars if fields is None else adjust_week_bars[fields]
        dt = np.uint64(convert_date_to_int(dt))
        i = bars['datetime'].searchsorted(dt, side='right')
        left = i - bar_count if i >= bar_count else 0
        bars = bars[left:i]
        if adjust_type == 'none' or instrument.type in {'Future', 'INDX'}:
            # futures and index bars need no price adjustment
            return bars if fields is None else bars[fields]

        if isinstance(fields, str) and fields not in FIELDS_REQUIRE_ADJUSTMENT:
            return bars if fields is None else bars[fields]

        bars = adjust_bars(bars, self.get_ex_cum_factor(instrument.order_book_id),
                           fields, adjust_type, adjust_orig)

        return bars if fields is None else bars[fields]
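    # Illustrative call (assumed names: ``ds`` is a BaseDataSource instance, ``ins`` a stock
    # Instrument from the bundle; ``adjust_type='none'`` skips price adjustment):
    #     closes = ds.history_bars(ins, 20, '1d', 'close', date(2020, 6, 30), adjust_type='none')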

    def current_snapshot(self, instrument, frequency, dt):
        raise NotImplementedError

    @lru_cache(2048)
    def get_split(self, instrument):
        try:
            split_store = self._split_factors[instrument.type]
        except KeyError:
            return None

        return split_store.get_factors(instrument.order_book_id)

    def available_data_range(self, frequency):
        # FIXME
        from rqalpha.const import DEFAULT_ACCOUNT_TYPE
        accounts = Environment.get_instance().config.base.accounts
        if not (DEFAULT_ACCOUNT_TYPE.STOCK in accounts or DEFAULT_ACCOUNT_TYPE.FUTURE in accounts):
            return date.min, date.max
        if frequency in ['tick', '1d']:
            s, e = self._day_bars[INSTRUMENT_TYPE.INDX].get_date_range('000001.XSHG')
            return convert_int_to_date(s).date(), convert_int_to_date(e).date()

    def get_yield_curve(self, start_date, end_date, tenor=None):
        return self._yield_curve.get_yield_curve(start_date, end_date, tenor=tenor)

    @lru_cache(1024)
    def get_futures_trading_parameters(self, instrument: Instrument, dt: datetime.date) -> FuturesTradingParameters:
        return self._future_info_store.get_future_info(instrument.order_book_id, instrument.underlying_symbol)

    def get_merge_ticks(self, order_book_id_list, trading_date, last_dt=None):
        raise NotImplementedError

    def history_ticks(self, instrument, count, dt):
        raise NotImplementedError

    def get_algo_bar(self, id_or_ins, start_min, end_min, dt):
        raise NotImplementedError("open source rqalpha does not support algo orders")

    def get_open_auction_volume(self, instrument, dt):
        # type: (Instrument, datetime.datetime) -> float
        volume = self.get_open_auction_bar(instrument, dt)['volume']
        return volume
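
# Usage sketch (illustrative, with assumed inputs): construct a BaseDataSource against a local data
# bundle and read one daily bar. Assumes the bundle directory (e.g. the default ~/.rqalpha/bundle)
# contains the files referenced in ``__init__`` and that the framework has already created an
# ``Environment`` instance, since the constructor calls ``Environment.get_instance()``.
#
#     ds = BaseDataSource(os.path.expanduser("~/.rqalpha/bundle"), custom_future_info={})
#     ins = next(iter(ds.get_instruments(["000001.XSHE"])))
#     bar = ds.get_bar(ins, date(2020, 1, 3), "1d")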