wwu-mmll / photonai — build 14280849012

04 Nov 2024 01:43PM UTC — coverage: 91.073% (remained the same)

Trigger: push via GitHub (web-flow)
Commit: Merge pull request #89 from wwu-mmll/develop ("Develop")

93 of 98 new or added lines in 10 files covered (94.9%).
110 existing lines in 7 files are now uncovered.
5815 of 6385 relevant lines covered (91.07%).
0.91 hits per line.
Source file: /photonai/processing/permutation_test.py (78.48% of lines covered)
import numpy as np
import pandas as pd
import dask
import os
from dask.distributed import Client
from datetime import timedelta
from pymodm import connect
from pymodm.errors import DoesNotExist, ConnectionError
from pymongo import DESCENDING

from photonai.base import OutputSettings
from photonai.photonlogger.logger import logger

from photonai.processing.inner_folds import Scorer
from photonai.processing.results_structure import MDBPermutationResults, MDBPermutationMetrics, MDBHyperpipe


class PermutationTest:

    def __init__(self, hyperpipe_constructor, permutation_id: str, n_perms=1000, n_processes=1, random_state=15,
                 verbosity=-1):

        self.hyperpipe_constructor = hyperpipe_constructor
        self.n_perms = n_perms
        self.permutation_id = permutation_id
        self.mother_permutation_id = PermutationTest.get_mother_permutation_id(permutation_id)
        self.n_processes = n_processes
        self.random_state = random_state
        self.verbosity = verbosity
        self.pipe = None
        self.metrics = None

    @staticmethod
    def manage_metrics(metrics, last_element=None, best_config_metric=''):
        metric_dict = dict()
        for metric in metrics:
            metric_dict[metric] = {'name': metric,
                                   'greater_is_better': PermutationTest.set_greater_is_better(metric, last_element)}
        if best_config_metric not in metric_dict.keys():
            metric_dict[best_config_metric] = {'name': best_config_metric,
                                               'greater_is_better': PermutationTest.set_greater_is_better(best_config_metric)}
        return metric_dict

    @staticmethod
    def get_mother_permutation_id(permutation_id):
        m_perm = permutation_id + "_reference"
        return m_perm

    def fit(self, X, y, **kwargs):

        self.pipe = self.hyperpipe_constructor()

        # we need a mongodb to collect the results!
        if not self.pipe.output_settings.mongodb_connect_url:
            raise ValueError("MongoDB connection string must be given for permutation tests")

        # Get all specified metrics
        best_config_metric = self.pipe.optimization.best_config_metric
        self.metrics = PermutationTest.manage_metrics(self.pipe.optimization.metrics, self.pipe.elements[-1], best_config_metric)

        # at first we do a reference optimization
        y_true = y

        # Run with true labels
        connect(self.pipe.output_settings.mongodb_connect_url, alias="photon_core")
        # Check if it already exists in DB
        try:
            existing_reference = MDBHyperpipe.objects.raw({'permutation_id': self.mother_permutation_id,
                                                           'computation_completed': True}).first()
            if not existing_reference.permutation_test:
                existing_reference.permutation_test = MDBPermutationResults(n_perms=self.n_perms)
                existing_reference.save()
            # check if all outer folds exist
            logger.info("Found hyperpipe computation with true targets, skipping the optimization process with true targets")
        except DoesNotExist:
            # if we havent computed the reference value do it:
            logger.info("Calculating Reference Values with true targets.")
            try:
                self.pipe.permutation_id = self.mother_permutation_id
                self.pipe.fit(X, y_true, **kwargs)
                self.pipe.results.computation_completed = True
                self.pipe.results.permutation_test = MDBPermutationResults(n_perms=self.n_perms)
                self.clear_data_and_save(self.pipe)
                existing_reference = self.pipe.results

            except Exception as e:
                if self.pipe.results is not None:
                    self.pipe.results.permutation_failed = str(e)
                    logger.error(e)
                    PermutationTest.clear_data_and_save(self.pipe)
                raise e

        # check for sanity
        if not self.__validate_usability(existing_reference):
            raise RuntimeError("Permutation Test is not adviced because results are not better than dummy. Aborting.")

        # find how many permutations have been computed already:
        existing_permutations = list(MDBHyperpipe.objects.raw({'permutation_id': self.permutation_id,
                                                               'computation_completed': True}).only('name'))
        existing_permutations = [int(perm_run.name.split('_')[-1]) for perm_run in existing_permutations]

        # we do one more permutation that is left in case the last permutation runs broke, one for each parallel
        if len(existing_permutations) > 0:
            perms_todo = set(np.arange(self.n_perms)) - set(existing_permutations)
        else:
            perms_todo = np.arange(self.n_perms)

        logger.info(str(len(perms_todo)) + " permutation runs to do")

        if len(perms_todo) > 0:
            # create permutation labels
            np.random.seed(self.random_state)
            self.permutations = [np.random.permutation(y_true) for _ in range(self.n_perms)]

            # Run parallel pool
            job_list = list()
            if self.n_processes > 1:
                try:

                    my_client = Client(threads_per_worker=1,
                                       n_workers=self.n_processes,
                                       processes=True)

                    for perm_run in perms_todo:
                        del_job = dask.delayed(PermutationTest.run_parallelized_permutation)(self.hyperpipe_constructor, X,
                                                                                             perm_run,
                                                                                             self.permutations[perm_run],
                                                                                             self.permutation_id,
                                                                                             self.verbosity, **kwargs)
                        job_list.append(del_job)

                    dask.compute(*job_list)

                finally:
                    my_client.close()
            else:
                for perm_run in perms_todo:
                    PermutationTest.run_parallelized_permutation(self.hyperpipe_constructor, X, perm_run,
                                                                 self.permutations[perm_run],
                                                                 self.permutation_id, self.verbosity, **kwargs)

        perm_result = self._calculate_results(self.permutation_id,
                                              mongodb_path=self.pipe.output_settings.mongodb_connect_url)

        performance_df = pd.DataFrame(dict([(name, [i]) for name, i in perm_result.p_values.items()]))
        performance_df.to_csv(os.path.join(existing_reference.output_folder, 'permutation_test_results.csv'))
        return self

    @staticmethod
    def clear_data_and_save(perm_pipe):
        perm_pipe.results.outer_folds = list()
        perm_pipe.results.best_config = None
        perm_pipe.results.save()

    @staticmethod
    def run_parallelized_permutation(hyperpipe_constructor, X, perm_run, y_perm, permutation_id, verbosity=-1,
                                     **kwargs):
        # Create new instance of hyperpipe and set all parameters
        perm_pipe = hyperpipe_constructor()
        perm_pipe.verbosity = verbosity
        perm_pipe.name = perm_pipe.name + '_perm_' + str(perm_run)
        perm_pipe.permutation_id = permutation_id

        # print(y_perm)
        po = OutputSettings(mongodb_connect_url=perm_pipe.output_settings.mongodb_connect_url,
                            save_output=False)
        perm_pipe.output_settings = po
        perm_pipe.calculate_metrics_across_folds = False
        try:
            # Fit hyperpipe
            # WE DO PRINT BECAUSE WE HAVE NO COMMON LOGGER!!!
            print('Fitting permutation ' + str(perm_run) + ' ...')
            perm_pipe.fit(X, y_perm, **kwargs)
            perm_pipe.results.computation_completed = True
            perm_pipe.results.permutation_run = perm_run
            PermutationTest.clear_data_and_save(perm_pipe)
            print('Finished permutation ' + str(perm_run) + ' ...')
        except Exception as e:
            if perm_pipe.results is not None:
                perm_pipe.results.permutation_failed = str(e)
                perm_pipe.results.save()
                print('Failed permutation ' + str(perm_run) + ' ...')
        return perm_run

    @staticmethod
    def _calculate_results(permutation_id, save_to_db=True, mongodb_path="mongodb://localhost:27017/photon_results"):

        logger.info("Calculating permutation test results")
        try:
            mother_permutation = PermutationTest.find_reference(mongodb_path, permutation_id)
            if mother_permutation is None:
                raise DoesNotExist
        except DoesNotExist:
            return None
        else:
            all_permutations = list(MDBHyperpipe.objects.raw({'permutation_id': permutation_id,
                                                              'computation_completed': True}).project({'metrics_test': 1}))
            # all_permutations = MDBHyperpipe.objects.raw({'permutation_id': permutation_id,
            #                                              'computation_completed': True}).only('metrics_test')
            number_of_permutations = len(all_permutations)
            print("Found {} permutations.".format(number_of_permutations))

            if number_of_permutations == 0:
                number_of_permutations = 1

            true_performances = mother_permutation.get_test_metric(operation="mean")
            perm_performances = dict()
            metric_list = list(set([m.metric_name for m in mother_permutation.metrics_test]))
            metrics = PermutationTest.manage_metrics(metric_list, None,
                                                     mother_permutation.hyperpipe_info.best_config_metric)

            for _, metric in metrics.items():
                perm_performances[metric["name"]] = [i.get_test_metric(metric["name"], operation="mean")
                                                     for i in all_permutations]

            # Calculate p-value
            p = PermutationTest.calculate_p(true_performance=true_performances, perm_performances=perm_performances,
                                            metrics=metrics, n_perms=number_of_permutations)
            p_text = dict()
            for _, metric in metrics.items():
                if p[metric['name']] == 0:
                    p_text[metric['name']] = "p < {}".format(str(1/number_of_permutations))
                else:
                    p_text[metric['name']] = "p = {}".format(p[metric['name']])

            # Print results
            logger.clean_info("""
            Done with permutations...

            Results Permutation test
            ===============================================
            """)
            for _, metric in metrics.items():
                logger.clean_info("""
                    Metric: {}
                    True Performance: {}
                    p Value: {}

                """.format(metric['name'], true_performances[metric['name']], p_text[metric['name']]))

            if save_to_db:
                # Write results to results object
                if mother_permutation.permutation_test is None:
                    perm_results = MDBPermutationResults(n_perms=number_of_permutations)
                else:
                    perm_results = mother_permutation.permutation_test
                perm_results.n_perms_done = number_of_permutations
                results_all_metrics = list()
                for _, metric in metrics.items():
                    perm_metrics = MDBPermutationMetrics(metric_name=metric['name'], p_value=p[metric['name']],
                                                         metric_value=true_performances[metric['name']])
                    perm_metrics.values_permutations = perm_performances[metric['name']]
                    results_all_metrics.append(perm_metrics)
                perm_results.metrics = results_all_metrics
                mother_permutation.permutation_test = perm_results
                mother_permutation.save()

            if mother_permutation.permutation_test is not None:
                n_perms = mother_permutation.permutation_test.n_perms
            else:
                # we guess?
                n_perms = 1000

            result = PermutationTest.PermutationResult(true_performances, perm_performances,
                                                       p, number_of_permutations, n_perms)

            return result

    class PermutationResult:

        def __init__(self, true_performances: dict = {}, perm_performances: dict = {},
                     p_values: dict = {}, n_perms_done: int = 0, n_perms: int = 0):

            self.true_performances = true_performances
            self.perm_performances = perm_performances
            self.p_values = p_values
            self.n_perms_done = n_perms_done
            self.n_perms = n_perms

    @staticmethod
    def find_reference(mongo_db_connect_url, permutation_id, find_wizard_id=False):
        def _find_mummy(permutation_id):
            if not find_wizard_id:
                return MDBHyperpipe.objects.raw(
                    {'permutation_id': PermutationTest.get_mother_permutation_id(permutation_id),
                     'computation_completed': True}).order_by([('computation_start_time', DESCENDING)]).first()
            else:
                return MDBHyperpipe.objects.raw({'wizard_object_id': permutation_id}).order_by([('computation_start_time', DESCENDING)]).first()

        try:
            # in case we haven't been connected try again
            connect(mongo_db_connect_url, alias="photon_core")
            mother_permutation = _find_mummy(permutation_id)
        except DoesNotExist:
            return None
        except ConnectionError:
            # in case we haven't been connected try again
            connect(mongo_db_connect_url, alias="photon_core")
            try:
                mother_permutation = _find_mummy(permutation_id)
            except DoesNotExist:
                return None
        return mother_permutation

    @staticmethod
    def prepare_for_wizard(permutation_id, wizard_id, mongo_db_connect_url="mongodb://localhost:27017/photon_results"):
        mother_permutation = PermutationTest.find_reference(mongo_db_connect_url, permutation_id=wizard_id,
                                                            find_wizard_id=True)
        mother_permutation.permutation_id = PermutationTest.get_mother_permutation_id(permutation_id)
        mother_permutation.save()
        result = dict()
        if mother_permutation.computation_end_time is not None and mother_permutation.computation_start_time is not None:
            result[
                "estimated_duration"] = mother_permutation.computation_end_time - mother_permutation.computation_start_time
        else:
            result["estimated_duration"] = timedelta(seconds=0)
        result["usability"] = PermutationTest.__validate_usability(mother_permutation)
        return result

    @staticmethod
    def __validate_usability(mother_permutation):
        if mother_permutation is not None:
            if mother_permutation.dummy_estimator:
                best_config_metric = mother_permutation.hyperpipe_info.best_config_metric
                dummy_threshold_to_beat = mother_permutation.dummy_estimator.get_test_metric(name=best_config_metric,
                                                                                             operation="mean")
                if dummy_threshold_to_beat is not None:
                    mother_perm_threshold = mother_permutation.get_test_metric(name=best_config_metric,
                                                                               operation="mean")
                    if mother_permutation.hyperpipe_info.maximize_best_config_metric:
                        if mother_perm_threshold > dummy_threshold_to_beat:
                            return True
                        else:
                            return False
                    else:
                        if mother_perm_threshold < dummy_threshold_to_beat:
                            return True
                        else:
                            return False
                else:
                    # we have no dummy results so we assume it should be okay
                    return True
        else:
            return None

    def collect_results(self, result):
        # This is called whenever foo_pool(i) returns a result.
        # result_list is modified only by the main process, not the pool workers.
        logger.info("Finished Permutation Run" + str(result))

    @staticmethod
    def calculate_p(true_performance, perm_performances, metrics, n_perms):
        p = dict()
        for _, metric in metrics.items():
            if metric['greater_is_better']:
                p[metric['name']] = (np.sum(true_performance[metric['name']] < np.asarray(perm_performances[metric['name']])) + 1)/(n_perms + 1)
            else:
                p[metric['name']] = (np.sum(true_performance[metric['name']] > np.asarray(perm_performances[metric['name']])) + 1)/(n_perms + 1)
        return p

    @staticmethod
    def set_greater_is_better(metric, last_element=None):
        """
        Set greater_is_better for metric
        :param string specifying metric
        """
        if metric == 'score' and last_element is not None:
            # if no specific metric was chosen, use default scoring method
            if hasattr(last_element.base_element, '_estimator_type'):
                greater_is_better = True
            else:
                # Todo: better error checking?
                logger.error('NotImplementedError: ' +
                             'No metric was chosen and last pipeline element does not specify ' +
                             'whether it is a classifier, regressor, transformer or ' +
                             'clusterer.')
                raise NotImplementedError('No metric was chosen and last pipeline element does not specify '
                                          'whether it is a classifier, regressor, transformer or '
                                          'clusterer.')
        else:
            greater_is_better = Scorer.greater_is_better_distinction(metric)
        return greater_is_better
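
A minimal usage sketch for the class listed above (not part of the coverage report): it assumes a MongoDB instance reachable at the placeholder URL below and follows photonai's public API (Hyperpipe, PipelineElement, OutputSettings); the pipeline name, dataset, hyperparameters, fold counts, and permutation count are illustrative only.

import uuid
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import StratifiedKFold
from photonai.base import Hyperpipe, PipelineElement, OutputSettings
from photonai.processing.permutation_test import PermutationTest


def create_hyperpipe():
    # Called once per permutation run, so it must return a fresh, unfitted pipeline.
    # The MongoDB URL is a placeholder; PermutationTest requires one to be set.
    settings = OutputSettings(mongodb_connect_url="mongodb://localhost:27017/photon_results")
    pipe = Hyperpipe('permutation_test_example',
                     optimizer='grid_search',
                     metrics=['accuracy', 'balanced_accuracy'],
                     best_config_metric='accuracy',
                     outer_cv=StratifiedKFold(n_splits=3),
                     inner_cv=StratifiedKFold(n_splits=3),
                     project_folder='./tmp',
                     output_settings=settings)
    pipe += PipelineElement('StandardScaler')
    pipe += PipelineElement('SVC', hyperparameters={'C': [0.1, 1.0]})
    return pipe


X, y = load_breast_cancer(return_X_y=True)
perm_tester = PermutationTest(create_hyperpipe,
                              permutation_id=str(uuid.uuid4()),
                              n_perms=100,      # number of label permutations
                              n_processes=1,    # >1 distributes runs via dask
                              random_state=11)
perm_tester.fit(X, y)
# Per-metric p-values are written to permutation_test_results.csv in the reference
# run's output folder; for greater-is-better metrics they follow
# (count(perm_score > true_score) + 1) / (n_perms + 1).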