• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

comp-physics / Quantum-HRF-Tomography / 17144060545

22 Aug 2025 02:07AM UTC coverage: 95.967% (-1.4%) from 97.368%
17144060545

push

github

kaminotesf
Upgrade to 0.2.0 with bug fixes

## Bug Fixes:
• Fix memory leak in tree visualization matplotlib figure cleanup
• Fix zero norm statevector handling with proper ValueError
• Fix division by zero in stabilizer entropy with α=1 validation
• Fix potential index error in get_signs with array length checks
• Fix return type mismatch in swap_test function signature
• Fix warning suppression side effects with targeted decorator

## Performance:
• Add multiprocessing for 4-8x tree generation speedup
• Optimize majority voting with pre-allocated arrays for 2-3x speedup

52 of 99 new or added lines in 3 files covered. (52.53%)

1 existing line in 1 file now uncovered.

1499 of 1562 relevant lines covered (95.97%)

0.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.6
/tests/test_sample.py
1
import unittest
1✔
2
import warnings
1✔
3
import numpy as np
1✔
4
from unittest.mock import Mock, MagicMock, patch
1✔
5
import pytest
1✔
6
from qiskit import QuantumCircuit
1✔
7
from qiskit.circuit.random import random_circuit
1✔
8
from qiskit_aer import AerSimulator
1✔
9
from qiskit.circuit.library import real_amplitudes, efficient_su2
1✔
10
from qiskit_ibm_runtime.fake_provider import FakeFez
1✔
11
from qiskit.providers import JobStatus
1✔
12
from qiskit.result import Result
1✔
13
from hadamard_random_forest.sample import (
1✔
14
    get_statevector,
15
    get_circuits,
16
    get_samples_noisy,
17
    get_circuits_hardware,
18
    get_samples_hardware
19
)
20
from hadamard_random_forest.random_forest import fix_random_seed
1✔
21

22
class TestSample(unittest.TestCase):
    """Tests for circuit construction, sampling and state reconstruction in
    ``hadamard_random_forest.sample``."""

    def setUp(self):
        """Set up test setting and common test data."""
        # Seed NumPy's global RNG so that randomly drawn parameters, samples and
        # mock counts are identical on every run; failures become reproducible.
        np.random.seed(1234)

        self.backend_sim = AerSimulator()
        self.fake_backend = FakeFez()

        # Common test parameters
        self.test_qubit_counts = [2, 3, 4]
        self.test_shots = 1024

        # Create sample circuits for testing, keyed by qubit count
        self.simple_circuits = {}
        self.complex_circuits = {}

        for num_qubits in self.test_qubit_counts:
            # Simple circuit (real_amplitudes)
            self.simple_circuits[num_qubits] = real_amplitudes(
                num_qubits, reps=1, insert_barriers=False
            )

            # More complex circuit (efficient_su2)
            self.complex_circuits[num_qubits] = efficient_su2(
                num_qubits, reps=2, insert_barriers=False
            )

    def _create_test_samples(self, num_qubits: int, normalize: bool = True) -> list:
        """Create ``num_qubits + 1`` random sample arrays of length ``2**num_qubits``.

        Args:
            num_qubits: Number of qubits; fixes both list length and array length.
            normalize: If True, each array is divided by its sum so it sums to 1.

        Returns:
            List of 1-D NumPy arrays with non-negative entries.
        """
        samples = [np.random.rand(2**num_qubits) for _ in range(num_qubits + 1)]
        if normalize:
            # Normalize to valid probability distributions
            samples = [s / np.sum(s) for s in samples]
        return samples

    def _create_mock_counts(self, num_qubits: int, shots: int = 1024) -> dict:
        """Create a random counts dictionary whose values sum to ``shots``.

        Keys are zero-padded bitstrings of width ``num_qubits``. Basis states
        that receive zero shots are omitted from the dictionary.
        """
        counts = {}
        n_states = 2**num_qubits
        # Distribute shots randomly across computational basis states; each draw
        # is bounded by remaining_shots so the running total never goes negative.
        remaining_shots = shots
        for i in range(n_states - 1):
            count = np.random.randint(0, remaining_shots // (n_states - i) + 1)
            if count > 0:
                counts[format(i, f'0{num_qubits}b')] = count
                remaining_shots -= count

        # Assign remaining shots to last state so the total equals `shots`
        if remaining_shots > 0:
            counts[format(n_states - 1, f'0{num_qubits}b')] = remaining_shots

        return counts

    def test_get_circuits_basic(self):
        """Test basic functionality of get_circuits."""
        for num_qubits in self.test_qubit_counts:
            with self.subTest(num_qubits=num_qubits):
                base_circuit = self.simple_circuits[num_qubits]
                circuits = get_circuits(num_qubits, base_circuit)

                # Verify structure
                self.assertIsInstance(circuits, list)
                self.assertEqual(len(circuits), num_qubits + 1)  # Base + H variants

                # Verify all circuits have correct qubit count
                for circuit in circuits:
                    self.assertIsNotNone(circuit)
                    self.assertEqual(circuit.num_qubits, num_qubits)
                    # Verify measurements are present
                    self.assertTrue(any(op.operation.name == 'measure' for op in circuit.data))

    def test_get_circuits_structure_validation(self):
        """Test that each H variant adds exactly one H gate to the base circuit."""
        num_qubits = 3
        base_circuit = real_amplitudes(num_qubits, reps=2)
        circuits = get_circuits(num_qubits, base_circuit)

        def count_ops(circuit, name):
            return sum(1 for op in circuit.data if op.operation.name == name)

        # circuits[0] is the measured base circuit; use it as the H-gate baseline
        # rather than assuming the ansatz itself contains no H gates.
        base_h_count = count_ops(circuits[0], 'h')

        # Subsequent circuits should have exactly one additional H gate
        for i in range(1, len(circuits)):
            with self.subTest(circuit_index=i):
                circuit = circuits[i]
                self.assertEqual(count_ops(circuit, 'h'), base_h_count + 1)

                # Verify every qubit is measured exactly once
                self.assertEqual(count_ops(circuit, 'measure'), num_qubits)

    def test_get_circuits_different_ansatz_types(self):
        """Test get_circuits with different circuit types."""
        test_cases = [
            ("real_amplitudes", self.simple_circuits),
            ("efficient_su2", self.complex_circuits)
        ]

        for ansatz_name, circuit_dict in test_cases:
            for num_qubits in self.test_qubit_counts:
                with self.subTest(ansatz=ansatz_name, num_qubits=num_qubits):
                    base_circuit = circuit_dict[num_qubits]
                    circuits = get_circuits(num_qubits, base_circuit)

                    self.assertEqual(len(circuits), num_qubits + 1)

                    # Adding H gates and measurements must not add or drop
                    # free parameters: callers bind one parameter vector of
                    # length base_circuit.num_parameters to every circuit.
                    for circuit in circuits:
                        self.assertEqual(circuit.num_parameters, base_circuit.num_parameters)

    def test_get_circuits_parameter_preservation(self):
        """Test that circuit parameters are preserved during get_circuits."""
        num_qubits = 3
        base_circuit = real_amplitudes(num_qubits, reps=2)
        original_params = base_circuit.num_parameters

        circuits = get_circuits(num_qubits, base_circuit)

        # All circuits should preserve the original parameters
        for circuit in circuits:
            self.assertEqual(circuit.num_parameters, original_params)

    def test_get_samples_noisy_without_mitigation(self):
        """Test get_samples_noisy without error mitigation."""
        for num_qubits in [2, 3]:  # Use smaller systems for faster testing
            with self.subTest(num_qubits=num_qubits):
                base_circuit = self.simple_circuits[num_qubits]
                circuits = get_circuits(num_qubits, base_circuit)
                parameters = np.random.rand(base_circuit.num_parameters)

                samples = get_samples_noisy(
                    num_qubits=num_qubits,
                    circuits=circuits,
                    shots=self.test_shots,
                    parameters=parameters,
                    backend_sim=self.backend_sim,
                    error_mitigation=False
                )

                # Verify output structure
                self.assertIsInstance(samples, list)
                self.assertEqual(len(samples), num_qubits + 1)

                # Verify each sample array
                for sample in samples:
                    self.assertIsInstance(sample, np.ndarray)
                    self.assertEqual(sample.shape, (2**num_qubits,))
                    self.assertTrue(np.all(sample >= 0))  # Probabilities non-negative
                    self.assertAlmostEqual(np.sum(sample), 1.0, places=6)  # Normalized

    @patch('hadamard_random_forest.sample.M3Mitigation')
    @patch('hadamard_random_forest.sample.mthree_utils.final_measurement_mapping')
    def test_get_samples_noisy_with_mitigation(self, mock_mapping, mock_m3):
        """Test get_samples_noisy with error mitigation using mocks."""
        num_qubits = 2
        base_circuit = self.simple_circuits[num_qubits]
        circuits = get_circuits(num_qubits, base_circuit)
        parameters = np.random.rand(base_circuit.num_parameters)

        # Mock the M3 mitigation objects
        mock_mit_instance = MagicMock()
        mock_m3.return_value = mock_mit_instance

        # Mock mapping function (identity layout, as in the hardware tests)
        mock_mapping.return_value = list(range(num_qubits))

        # Mock the mitigation correction to yield a uniform distribution
        mock_quasi = MagicMock()
        mock_quasi.nearest_probability_distribution.return_value = {
            '00': 0.25, '01': 0.25, '10': 0.25, '11': 0.25
        }
        mock_mit_instance.apply_correction.return_value = mock_quasi

        # Mock the backend to return predictable counts
        mock_job = MagicMock()
        mock_result = MagicMock()
        mock_job.result.return_value = mock_result
        mock_result.get_counts.return_value = self._create_mock_counts(num_qubits, self.test_shots)

        with patch.object(self.backend_sim, 'run', return_value=mock_job):
            samples = get_samples_noisy(
                num_qubits=num_qubits,
                circuits=circuits,
                shots=self.test_shots,
                parameters=parameters,
                backend_sim=self.backend_sim,
                error_mitigation=True
            )

        # Verify mitigation was called
        self.assertTrue(mock_m3.called)
        self.assertTrue(mock_mit_instance.apply_correction.called)

        # Verify output structure
        self.assertIsInstance(samples, list)
        self.assertEqual(len(samples), num_qubits + 1)

        for sample in samples:
            self.assertIsInstance(sample, np.ndarray)
            self.assertEqual(sample.shape, (2**num_qubits,))

    def test_get_samples_noisy_different_shot_counts(self):
        """Test get_samples_noisy with different shot counts."""
        num_qubits = 2
        base_circuit = self.simple_circuits[num_qubits]
        circuits = get_circuits(num_qubits, base_circuit)
        parameters = np.random.rand(base_circuit.num_parameters)

        shot_counts = [100, 1000, 10000]

        for shots in shot_counts:
            with self.subTest(shots=shots):
                samples = get_samples_noisy(
                    num_qubits=num_qubits,
                    circuits=circuits,
                    shots=shots,
                    parameters=parameters,
                    backend_sim=self.backend_sim,
                    error_mitigation=False
                )

                # Output must be a valid distribution regardless of shot count
                # (precision versus shot count is not checked here).
                self.assertEqual(len(samples), num_qubits + 1)
                for sample in samples:
                    self.assertEqual(sample.shape, (2**num_qubits,))
                    self.assertTrue(np.all(sample >= 0))
                    self.assertAlmostEqual(np.sum(sample), 1.0, places=6)

    def test_get_samples_noisy_parameter_assignment(self):
        """Test that a range of parameter vectors can be bound and sampled."""
        num_qubits = 3
        base_circuit = self.simple_circuits[num_qubits]
        circuits = get_circuits(num_qubits, base_circuit)

        # Test with different parameter values: all zeros, constant, random angles
        param_sets = [
            np.zeros(base_circuit.num_parameters),
            np.ones(base_circuit.num_parameters) * 0.5,
            np.random.rand(base_circuit.num_parameters) * 2 * np.pi
        ]

        for i, parameters in enumerate(param_sets):
            with self.subTest(param_set=i):
                samples = get_samples_noisy(
                    num_qubits=num_qubits,
                    circuits=circuits,
                    shots=self.test_shots,
                    parameters=parameters,
                    backend_sim=self.backend_sim,
                    error_mitigation=False
                )

                # Binding must succeed and yield finite outputs for every set
                self.assertEqual(len(samples), num_qubits + 1)
                for sample in samples:
                    self.assertTrue(np.isfinite(sample).all())

    def test_get_circuits_hardware_basic(self):
        """Test basic functionality of get_circuits_hardware."""
        for num_qubits in [2, 3]:  # Use smaller systems for transpilation
            with self.subTest(num_qubits=num_qubits):
                base_circuit = self.simple_circuits[num_qubits]

                circuits = get_circuits_hardware(
                    num_qubits=num_qubits,
                    base_circuit=base_circuit,
                    device=self.fake_backend
                )

                # Verify structure
                self.assertIsInstance(circuits, list)
                self.assertEqual(len(circuits), num_qubits + 1)

                for circuit in circuits:
                    self.assertIsNotNone(circuit)
                    # Transpiled circuits are laid out on the device, so they
                    # may span more qubits than the logical circuit
                    self.assertGreaterEqual(circuit.num_qubits, num_qubits)
                    # Verify measurements are present
                    self.assertTrue(any(op.operation.name == 'measure' for op in circuit.data))

    def test_get_circuits_hardware_transpilation(self):
        """Test that transpilation occurs correctly."""
        num_qubits = 3
        base_circuit = self.simple_circuits[num_qubits]

        # Compare transpiled vs non-transpiled
        regular_circuits = get_circuits(num_qubits, base_circuit)
        hardware_circuits = get_circuits_hardware(
            num_qubits=num_qubits,
            base_circuit=base_circuit,
            device=self.fake_backend
        )

        # Same number of circuits
        self.assertEqual(len(regular_circuits), len(hardware_circuits))

        for i, (reg_circuit, hw_circuit) in enumerate(zip(regular_circuits, hardware_circuits)):
            with self.subTest(circuit_index=i):
                # A transpiled circuit must contain at least its measurements
                self.assertGreater(len(hw_circuit.data), 0)
                # Both should have measurements
                reg_has_measure = any(op.operation.name == 'measure' for op in reg_circuit.data)
                hw_has_measure = any(op.operation.name == 'measure' for op in hw_circuit.data)
                self.assertTrue(reg_has_measure)
                self.assertTrue(hw_has_measure)

    def test_get_circuits_hardware_different_backends(self):
        """Test get_circuits_hardware with different backend types."""
        num_qubits = 2
        base_circuit = self.simple_circuits[num_qubits]

        # Test with a fake device and an Aer simulator built from it
        backends = [
            self.fake_backend,
            AerSimulator.from_backend(self.fake_backend)
        ]

        for i, backend in enumerate(backends):
            with self.subTest(backend_type=i):
                circuits = get_circuits_hardware(
                    num_qubits=num_qubits,
                    base_circuit=base_circuit,
                    device=backend
                )

                self.assertEqual(len(circuits), num_qubits + 1)
                for circuit in circuits:
                    self.assertIsNotNone(circuit)
                    # Verify the circuit is non-empty
                    self.assertGreater(len(circuit.data), 0)

    @patch('hadamard_random_forest.sample.generate_preset_pass_manager')
    def test_get_circuits_hardware_pass_manager_usage(self, mock_pm):
        """Test that a single pass manager is built and run once per circuit."""
        num_qubits = 2
        base_circuit = self.simple_circuits[num_qubits]

        # Mock the pass manager
        mock_pm_instance = MagicMock()
        mock_pm.return_value = mock_pm_instance
        mock_pm_instance.run.side_effect = lambda x: x  # Return circuit unchanged

        circuits = get_circuits_hardware(
            num_qubits=num_qubits,
            base_circuit=base_circuit,
            device=self.fake_backend
        )

        # Verify pass manager was created and used
        mock_pm.assert_called_once()
        # Should be called once for each circuit (num_qubits + 1)
        self.assertEqual(mock_pm_instance.run.call_count, num_qubits + 1)
        # Verify we got the expected number of circuits
        self.assertEqual(len(circuits), num_qubits + 1)

    @patch('hadamard_random_forest.sample.Sampler')
    @patch('hadamard_random_forest.sample.mthree_utils.final_measurement_mapping')
    def test_get_samples_hardware_basic(self, mock_mapping, mock_sampler_class):
        """Test basic functionality of get_samples_hardware."""
        num_qubits = 2
        base_circuit = self.simple_circuits[num_qubits]
        circuits = get_circuits_hardware(num_qubits, base_circuit, self.fake_backend)
        parameters = np.random.rand(base_circuit.num_parameters)
        shots = 1024

        # Mock the sampler and its results
        mock_sampler = MagicMock()
        mock_sampler_class.return_value = mock_sampler

        # Mock job and result
        mock_job = MagicMock()
        mock_result = MagicMock()
        mock_data = MagicMock()
        mock_meas = MagicMock()

        # Set up the mock chain: run() -> job.result()[0].data.meas.get_counts()
        mock_sampler.run.return_value = mock_job
        mock_job.result.return_value = [mock_result]
        mock_result.data = mock_data
        mock_data.meas = mock_meas
        mock_meas.get_counts.return_value = self._create_mock_counts(num_qubits, shots)
        mock_job.job_id.return_value = f"job_test_{np.random.randint(1000)}"
        mock_job.usage_estimation = {'quantum_seconds': 1.23}

        # Mock mapping
        mock_mapping.return_value = list(range(num_qubits))

        result = get_samples_hardware(
            num_qubits=num_qubits,
            shots=shots,
            circuits=circuits,
            parameters=parameters,
            device=self.fake_backend,
            error_mitigation=False
        )

        # Verify return structure
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 4)

        mitigated_samples, raw_samples, job_ids, quantum_times = result

        # Verify samples
        self.assertIsInstance(mitigated_samples, list)
        self.assertIsInstance(raw_samples, list)
        self.assertEqual(len(mitigated_samples), num_qubits + 1)
        self.assertEqual(len(raw_samples), num_qubits + 1)

        for sample in mitigated_samples + raw_samples:
            self.assertIsInstance(sample, np.ndarray)
            self.assertEqual(sample.shape, (2**num_qubits,))
            self.assertTrue(np.all(sample >= 0))

        # Verify job metadata
        self.assertIsInstance(job_ids, list)
        self.assertIsInstance(quantum_times, list)
        self.assertEqual(len(job_ids), num_qubits + 1)
        self.assertEqual(len(quantum_times), num_qubits + 1)

    @patch('hadamard_random_forest.sample.Sampler')
    @patch('hadamard_random_forest.sample.mthree.M3Mitigation')
    @patch('hadamard_random_forest.sample.mthree_utils.final_measurement_mapping')
    def test_get_samples_hardware_with_mitigation(self, mock_mapping, mock_m3, mock_sampler_class):
        """Test get_samples_hardware with error mitigation."""
        num_qubits = 2
        base_circuit = self.simple_circuits[num_qubits]
        circuits = get_circuits_hardware(num_qubits, base_circuit, self.fake_backend)
        parameters = np.random.rand(base_circuit.num_parameters)
        shots = 1024

        # Mock sampler
        mock_sampler = MagicMock()
        mock_sampler_class.return_value = mock_sampler

        # Mock job results
        mock_job = MagicMock()
        mock_result = MagicMock()
        mock_data = MagicMock()
        mock_meas = MagicMock()

        mock_sampler.run.return_value = mock_job
        mock_job.result.return_value = [mock_result]
        mock_result.data = mock_data
        mock_data.meas = mock_meas
        mock_meas.get_counts.return_value = self._create_mock_counts(num_qubits, shots)
        mock_job.job_id.return_value = "test_job_with_mitigation"
        mock_job.usage_estimation = {'quantum_seconds': 2.45}

        # Mock M3 mitigation to yield a uniform distribution
        mock_mit_instance = MagicMock()
        mock_m3.return_value = mock_mit_instance
        mock_quasi = MagicMock()
        mock_quasi.nearest_probability_distribution.return_value = {
            '00': 0.25, '01': 0.25, '10': 0.25, '11': 0.25
        }
        mock_mit_instance.apply_correction.return_value = mock_quasi

        # Mock mapping
        mock_mapping.return_value = list(range(num_qubits))

        result = get_samples_hardware(
            num_qubits=num_qubits,
            shots=shots,
            circuits=circuits,
            parameters=parameters,
            device=self.fake_backend,
            error_mitigation=True
        )

        # Verify mitigation was used
        self.assertTrue(mock_m3.called)
        self.assertTrue(mock_mit_instance.apply_correction.called)

        # Verify structure
        mitigated_samples, raw_samples, job_ids, quantum_times = result
        self.assertEqual(len(mitigated_samples), num_qubits + 1)
        self.assertEqual(len(raw_samples), num_qubits + 1)
        self.assertEqual(len(job_ids), num_qubits + 1)
        self.assertEqual(len(quantum_times), num_qubits + 1)

        # Mitigated and raw samples should have matching shapes (values may differ)
        for mitigated, raw in zip(mitigated_samples, raw_samples):
            self.assertEqual(mitigated.shape, raw.shape)

    @patch('hadamard_random_forest.sample.Sampler')
    def test_get_samples_hardware_job_tracking(self, mock_sampler_class):
        """Test that job IDs and quantum times are correctly tracked."""
        num_qubits = 2
        base_circuit = self.simple_circuits[num_qubits]
        circuits = get_circuits_hardware(num_qubits, base_circuit, self.fake_backend)
        parameters = np.random.rand(base_circuit.num_parameters)

        # Mock sampler with unique job IDs and times
        mock_sampler = MagicMock()
        mock_sampler_class.return_value = mock_sampler

        # Create unique job metadata, one entry per circuit
        job_ids = [f"job_{i}" for i in range(num_qubits + 1)]
        quantum_times = [i * 0.5 + 1.0 for i in range(num_qubits + 1)]

        def create_mock_job(job_id, q_time):
            """Build a mock job whose result chain yields random counts."""
            mock_job = MagicMock()
            mock_result = MagicMock()
            mock_data = MagicMock()
            mock_meas = MagicMock()

            mock_job.result.return_value = [mock_result]
            mock_result.data = mock_data
            mock_data.meas = mock_meas
            mock_meas.get_counts.return_value = self._create_mock_counts(num_qubits, 1024)
            mock_job.job_id.return_value = job_id
            mock_job.usage_estimation = {'quantum_seconds': q_time}

            return mock_job

        # Successive sampler.run() calls return successive jobs, in order
        mock_jobs = [create_mock_job(jid, qt) for jid, qt in zip(job_ids, quantum_times)]
        mock_sampler.run.side_effect = mock_jobs

        result = get_samples_hardware(
            num_qubits=num_qubits,
            shots=1024,
            circuits=circuits,
            parameters=parameters,
            device=self.fake_backend,
            error_mitigation=False
        )

        _, _, returned_job_ids, returned_quantum_times = result

        # Verify job tracking preserves both values and order
        self.assertEqual(returned_job_ids, job_ids)
        self.assertEqual(returned_quantum_times, quantum_times)

    def test_get_statevector_basic(self):
        """Test basic functionality of get_statevector."""
        for num_qubits in [2, 3]:
            with self.subTest(num_qubits=num_qubits):
                num_trees = 5
                samples = self._create_test_samples(num_qubits, normalize=True)

                statevector = get_statevector(
                    num_qubits=num_qubits,
                    num_trees=num_trees,
                    samples=samples,
                    save_tree=False,
                    show_tree=False
                )

                # Verify structure
                self.assertIsInstance(statevector, np.ndarray)
                self.assertEqual(statevector.shape, (2**num_qubits,))

                # Verify properties
                self.assertTrue(np.all(np.isfinite(statevector)))
                self.assertGreater(np.linalg.norm(statevector), 0)
                # Should be normalized
                self.assertAlmostEqual(np.linalg.norm(statevector), 1.0, places=6)

    def test_get_statevector_negative_probabilities_warning(self):
        """Test that get_statevector handles negative probabilities with warning."""
        num_qubits = 2
        num_trees = 3

        # Create samples with some negative values
        samples = self._create_test_samples(num_qubits, normalize=False)
        samples[0][1] = -0.1  # Add negative value to first sample

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            statevector = get_statevector(
                num_qubits=num_qubits,
                num_trees=num_trees,
                samples=samples,
                save_tree=False,
                show_tree=False
            )

        # Count only the warning under test, so unrelated warnings (e.g.
        # deprecation notices from dependencies) cannot make this test fail.
        negative_warnings = [
            record for record in caught
            if issubclass(record.category, UserWarning)
            and "Negative sample probabilities" in str(record.message)
        ]
        self.assertEqual(len(negative_warnings), 1)

        # Result should still be valid
        self.assertIsInstance(statevector, np.ndarray)
        self.assertTrue(np.all(np.isfinite(statevector)))

    def test_get_statevector_different_tree_counts(self):
        """Test get_statevector with different numbers of trees."""
        num_qubits = 2
        samples = self._create_test_samples(num_qubits, normalize=True)

        tree_counts = [1, 3, 5, 11]  # Odd counts avoid ties in majority voting

        for num_trees in tree_counts:
            with self.subTest(num_trees=num_trees):
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", UserWarning)
                    statevector = get_statevector(
                        num_qubits=num_qubits,
                        num_trees=num_trees,
                        samples=samples,
                        save_tree=False,
                        show_tree=False
                    )

                self.assertEqual(statevector.shape, (2**num_qubits,))
                self.assertAlmostEqual(np.linalg.norm(statevector), 1.0, places=6)

    def test_get_statevector_normalization(self):
        """Test that get_statevector properly normalizes output."""
        num_qubits = 3
        num_trees = 5
        samples = self._create_test_samples(num_qubits, normalize=True)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            statevector = get_statevector(
                num_qubits=num_qubits,
                num_trees=num_trees,
                samples=samples,
                save_tree=False,
                show_tree=False
            )

        # Test normalization
        norm = np.linalg.norm(statevector)
        self.assertAlmostEqual(norm, 1.0, places=6)

        # A unit-norm vector cannot have any component with magnitude above 1
        self.assertTrue(np.all(np.abs(statevector) <= 1.0))

    def test_get_statevector_reproducibility(self):
        """Test that get_statevector gives reproducible results with same input."""
        num_qubits = 2
        num_trees = 5
        samples = self._create_test_samples(num_qubits, normalize=True)

        # Run twice with same samples
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)

            # Fix the random seed before each call. Copy each array (not just
            # the list) so any in-place mutation by the first call cannot leak
            # into the second call's input.
            fix_random_seed(42)
            statevector1 = get_statevector(
                num_qubits=num_qubits,
                num_trees=num_trees,
                samples=[s.copy() for s in samples],
                save_tree=False,
                show_tree=False
            )

            fix_random_seed(42)
            statevector2 = get_statevector(
                num_qubits=num_qubits,
                num_trees=num_trees,
                samples=[s.copy() for s in samples],
                save_tree=False,
                show_tree=False
            )

        # Results should be identical with same seed
        np.testing.assert_array_almost_equal(statevector1, statevector2, decimal=10)
692

693

694
class TestSampleErrorHandling(unittest.TestCase):
    """Test error handling and input validation for sample functions.

    Some tests here accept either a raised exception or a well-formed result.
    In those tests only the call under test sits in the ``try`` body, and the
    assertions live in the ``else`` clause. ``AssertionError`` is a subclass
    of ``Exception``, so asserting inside the ``try`` would let
    ``except Exception`` swallow a genuine failure, and the test could never
    fail.
    """

    def setUp(self):
        """Set up test fixtures."""
        self.backend_sim = AerSimulator()
        self.fake_backend = FakeFez()

    def test_get_circuits_invalid_inputs(self):
        """Test get_circuits with edge case inputs."""
        # Negative qubit count: either rejected with an exception, or no
        # Hadamard variants are produced and only the base circuit is returned.
        try:
            circuits = get_circuits(-1, QuantumCircuit(1))
        except Exception:
            # Negative qubits might raise an exception, which is also acceptable
            pass
        else:
            # Should return just the base circuit with measurements
            self.assertEqual(len(circuits), 1)

        # Mismatched qubit count: the call is not rejected and still returns
        # num_qubits + 1 circuits.
        circuit_2q = QuantumCircuit(2)
        circuits = get_circuits(3, circuit_2q)  # Circuit has 2 qubits, asking for 3
        self.assertEqual(len(circuits), 4)

    def test_get_samples_noisy_invalid_parameters(self):
        """Test get_samples_noisy with invalid parameter arrays."""
        num_qubits = 2
        base_circuit = real_amplitudes(num_qubits, reps=1)
        circuits = get_circuits(num_qubits, base_circuit)

        # One parameter more than the ansatz declares
        wrong_params = np.random.rand(base_circuit.num_parameters + 1)

        with self.assertRaises(Exception):
            get_samples_noisy(
                num_qubits=num_qubits,
                circuits=circuits,
                shots=1024,
                parameters=wrong_params,
                backend_sim=self.backend_sim,
                error_mitigation=False
            )

    def test_get_samples_noisy_zero_shots(self):
        """Test get_samples_noisy with zero shots."""
        num_qubits = 2
        base_circuit = real_amplitudes(num_qubits, reps=1)
        circuits = get_circuits(num_qubits, base_circuit)
        parameters = np.random.rand(base_circuit.num_parameters)

        # Zero shots must be rejected with an exception
        with self.assertRaises(Exception):
            get_samples_noisy(
                num_qubits=num_qubits,
                circuits=circuits,
                shots=0,
                parameters=parameters,
                backend_sim=self.backend_sim,
                error_mitigation=False
            )

    def test_get_statevector_invalid_samples(self):
        """Test get_statevector with invalid sample inputs."""
        num_qubits = 2
        num_trees = 3

        # Test with wrong number of samples
        wrong_samples = [np.random.rand(4) for _ in range(2)]  # Should be 3 samples for 2 qubits

        with self.assertRaises(Exception):
            get_statevector(num_qubits, num_trees, wrong_samples, save_tree=False)

        # Test with wrong sample dimensions
        wrong_dim_samples = [np.random.rand(8) for _ in range(3)]  # Should be 4 elements for 2 qubits

        with self.assertRaises(Exception):
            get_statevector(num_qubits, num_trees, wrong_dim_samples, save_tree=False)

    def test_get_statevector_empty_samples(self):
        """Test get_statevector with empty or null samples."""
        num_qubits = 2
        num_trees = 3

        # Test with empty list
        with self.assertRaises(Exception):
            get_statevector(num_qubits, num_trees, [], save_tree=False)

        # Test with samples containing zeros
        zero_samples = [np.zeros(4) for _ in range(3)]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            try:
                result = get_statevector(num_qubits, num_trees, zero_samples, save_tree=False)
            except Exception:
                # Zero samples might legitimately cause errors
                pass
            else:
                # If no error is raised, the result must still be an array
                self.assertIsInstance(result, np.ndarray)

    def test_get_circuits_hardware_invalid_backend(self):
        """Test get_circuits_hardware with invalid backend."""
        num_qubits = 2
        base_circuit = real_amplitudes(num_qubits)

        # Test with None backend - may raise exception when passed to generate_preset_pass_manager
        try:
            circuits = get_circuits_hardware(num_qubits, base_circuit, None)
        except Exception:
            # None backend should raise an exception in generate_preset_pass_manager
            pass
        else:
            # If it doesn't raise an exception, it must still return a list
            self.assertIsInstance(circuits, list)

    @patch('hadamard_random_forest.sample.Sampler')
    def test_get_samples_hardware_failed_jobs(self, mock_sampler_class):
        """Test get_samples_hardware handling of failed jobs."""
        num_qubits = 2
        base_circuit = real_amplitudes(num_qubits)

        # Suppress mthree deprecation warnings from external library
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning, module="mthree.utils")
            circuits = get_circuits_hardware(num_qubits, base_circuit, self.fake_backend)

        parameters = np.random.rand(base_circuit.num_parameters)

        # Mock sampler that raises an exception
        mock_sampler = MagicMock()
        mock_sampler_class.return_value = mock_sampler
        mock_sampler.run.side_effect = Exception("Job submission failed")

        with self.assertRaises(Exception):
            get_samples_hardware(
                num_qubits=num_qubits,
                shots=1024,
                circuits=circuits,
                parameters=parameters,
                device=self.fake_backend,
                error_mitigation=False
            )

    def test_parameter_dimension_mismatch(self):
        """Test functions with parameter dimension mismatches."""
        num_qubits = 3
        base_circuit = real_amplitudes(num_qubits, reps=2)
        circuits = get_circuits(num_qubits, base_circuit)

        # Create parameters with wrong dimensions
        wrong_params_1d = np.random.rand(5)  # Wrong count
        wrong_params_2d = np.random.rand(2, 2)  # Wrong shape

        test_params = [wrong_params_1d, wrong_params_2d]

        for params in test_params:
            with self.subTest(params_shape=params.shape):
                with self.assertRaises(Exception):
                    get_samples_noisy(
                        num_qubits=num_qubits,
                        circuits=circuits,
                        shots=1024,
                        parameters=params,
                        backend_sim=self.backend_sim,
                        error_mitigation=False
                    )

    def test_extreme_values(self):
        """Test functions with extreme input values."""
        # Test with very large qubit count (should be handled gracefully)
        large_qubits = 20

        # This should either work or fail gracefully, not crash.
        # MemoryError is already a subclass of Exception.
        try:
            simple_circuit = QuantumCircuit(large_qubits)
            circuits = get_circuits(large_qubits, simple_circuit)
        except Exception:
            # Large systems might legitimately fail due to memory constraints
            pass
        else:
            self.assertEqual(len(circuits), large_qubits + 1)

        # Test with very small valid inputs
        minimal_circuit = QuantumCircuit(1)
        minimal_circuits = get_circuits(1, minimal_circuit)
        self.assertEqual(len(minimal_circuits), 2)  # Base + 1 H variant
879

880

881
class TestSampleIntegration(unittest.TestCase):
    """Integration tests for the sample module."""

    def setUp(self):
        """Set up test fixtures."""
        self.backend_sim = AerSimulator()

    def test_integration_basic_workflow(self):
        """Test the complete basic workflow integration."""
        num_qubits = 2  # Smaller for faster testing
        num_trees = 3
        base_circuit = real_amplitudes(num_qubits)
        parameters = np.random.rand(base_circuit.num_parameters)
        shots = 1024

        # Complete workflow
        circuits = get_circuits(num_qubits, base_circuit)
        samples = get_samples_noisy(
            num_qubits, circuits, shots, parameters,
            self.backend_sim, error_mitigation=False
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            statevector = get_statevector(num_qubits, num_trees, samples, save_tree=False)

        # Validate end-to-end
        self.assertEqual(statevector.shape, (2**num_qubits,))
        self.assertTrue(np.all(np.isfinite(statevector)))
        self.assertAlmostEqual(np.linalg.norm(statevector), 1.0, places=6)

    def test_integration_with_error_mitigation(self):
        """Test integration workflow with error mitigation."""
        num_qubits = 2
        num_trees = 3
        base_circuit = real_amplitudes(num_qubits)
        parameters = np.random.rand(base_circuit.num_parameters)
        shots = 1024

        # Use mock to avoid expensive M3 calibration
        with patch('hadamard_random_forest.sample.M3Mitigation') as mock_m3:
            mock_mit_instance = MagicMock()
            mock_m3.return_value = mock_mit_instance
            mock_quasi = MagicMock()
            mock_quasi.nearest_probability_distribution.return_value = {
                '00': 0.25, '01': 0.25, '10': 0.25, '11': 0.25
            }
            mock_mit_instance.apply_correction.return_value = mock_quasi

            with patch('hadamard_random_forest.sample.mthree_utils.final_measurement_mapping') as mock_mapping:
                mock_mapping.return_value = [0, 1]

                circuits = get_circuits(num_qubits, base_circuit)
                samples = get_samples_noisy(
                    num_qubits, circuits, shots, parameters,
                    self.backend_sim, error_mitigation=True
                )

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", UserWarning)
                    statevector = get_statevector(num_qubits, num_trees, samples, save_tree=False)

                # Validate workflow with mitigation
                self.assertEqual(statevector.shape, (2**num_qubits,))
                self.assertTrue(np.all(np.isfinite(statevector)))

    def test_integration_different_ansatz_types(self):
        """Test integration with different circuit ansatz types."""
        num_qubits = 2
        num_trees = 3
        shots = 1024

        ansatz_types = [
            real_amplitudes(num_qubits, reps=1),
            efficient_su2(num_qubits, reps=1)
        ]

        for i, base_circuit in enumerate(ansatz_types):
            with self.subTest(ansatz_type=i):
                parameters = np.random.rand(base_circuit.num_parameters)

                circuits = get_circuits(num_qubits, base_circuit)
                samples = get_samples_noisy(
                    num_qubits, circuits, shots, parameters,
                    self.backend_sim, error_mitigation=False
                )

                with warnings.catch_warnings():
                    warnings.simplefilter("ignore", UserWarning)
                    statevector = get_statevector(num_qubits, num_trees, samples, save_tree=False)

                self.assertEqual(statevector.shape, (2**num_qubits,))
                self.assertTrue(np.all(np.isfinite(statevector)))

    @patch('hadamard_random_forest.sample.Sampler')
    def test_integration_hardware_workflow(self, mock_sampler_class):
        """Test integration of hardware workflow with mocks."""
        num_qubits = 2
        num_trees = 3
        base_circuit = real_amplitudes(num_qubits)
        parameters = np.random.rand(base_circuit.num_parameters)
        shots = 1024

        # Mock hardware workflow
        mock_sampler = MagicMock()
        mock_sampler_class.return_value = mock_sampler

        mock_job = MagicMock()
        mock_result = MagicMock()
        mock_data = MagicMock()
        mock_meas = MagicMock()

        mock_sampler.run.return_value = mock_job
        mock_job.result.return_value = [mock_result]
        mock_result.data = mock_data
        mock_data.meas = mock_meas
        mock_meas.get_counts.return_value = {'00': 256, '01': 256, '10': 256, '11': 256}
        mock_job.job_id.return_value = "test_job"
        mock_job.usage_estimation = {'quantum_seconds': 1.5}

        with patch('hadamard_random_forest.sample.mthree_utils.final_measurement_mapping') as mock_mapping:
            mock_mapping.return_value = [0, 1]

            # Complete hardware workflow
            circuits = get_circuits_hardware(num_qubits, base_circuit, self.backend_sim)
            mitigated_samples, raw_samples, job_ids, quantum_times = get_samples_hardware(
                num_qubits, shots, circuits, parameters,
                self.backend_sim, error_mitigation=False
            )

            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                statevector = get_statevector(num_qubits, num_trees, mitigated_samples, save_tree=False)

            # Validate hardware workflow
            self.assertEqual(len(mitigated_samples), num_qubits + 1)
            self.assertEqual(len(raw_samples), num_qubits + 1)
            self.assertEqual(len(job_ids), num_qubits + 1)
            self.assertEqual(len(quantum_times), num_qubits + 1)
            self.assertEqual(statevector.shape, (2**num_qubits,))

    def test_performance_scaling(self):
        """Test performance scaling across different system sizes."""
        import time

        results = {}
        max_qubits = 4  # Keep reasonable for testing

        for num_qubits in range(2, max_qubits + 1):
            base_circuit = real_amplitudes(num_qubits, reps=1)
            parameters = np.random.rand(base_circuit.num_parameters)

            # perf_counter is monotonic and high-resolution; time.time() follows
            # the wall clock and can jump if the system clock is adjusted
            # mid-measurement.
            start_time = time.perf_counter()

            # Time the circuit generation
            circuits = get_circuits(num_qubits, base_circuit)

            # Time the sampling (with reduced shots for speed)
            samples = get_samples_noisy(
                num_qubits, circuits, 512, parameters,  # Reduced shots
                self.backend_sim, error_mitigation=False
            )

            # Time the reconstruction (with fewer trees)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", UserWarning)
                statevector = get_statevector(num_qubits, 3, samples, save_tree=False)

            elapsed_time = time.perf_counter() - start_time
            results[num_qubits] = elapsed_time

            # Validate result
            self.assertEqual(statevector.shape, (2**num_qubits,))
            self.assertTrue(np.all(np.isfinite(statevector)))

        # Basic scaling check - should be reasonable
        for num_qubits, time_taken in results.items():
            with self.subTest(num_qubits=num_qubits):
                self.assertLess(time_taken, 60.0)  # Should complete within 60 seconds

        # Check that scaling is not exponential in the small regime.
        # dict preserves insertion order, so times are in increasing-qubit order.
        # Time shouldn't increase by more than factor of 10 per qubit for small systems.
        times = list(results.values())
        for prev_time, next_time in zip(times, times[1:]):
            self.assertLess(next_time / prev_time, 10.0)

    def test_memory_usage_reasonable(self):
        """Test that memory usage is reasonable for moderate system sizes.

        Skipped when the optional ``psutil`` package is not installed.
        """
        import os
        try:
            import psutil
        except ImportError:
            self.skipTest("psutil is not installed")

        process = psutil.Process(os.getpid())
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        # Test with moderate system size
        num_qubits = 4
        num_trees = 5
        base_circuit = real_amplitudes(num_qubits, reps=2)
        parameters = np.random.rand(base_circuit.num_parameters)

        circuits = get_circuits(num_qubits, base_circuit)
        samples = get_samples_noisy(
            num_qubits, circuits, 1024, parameters,
            self.backend_sim, error_mitigation=False
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            statevector = get_statevector(num_qubits, num_trees, samples, save_tree=False)

        final_memory = process.memory_info().rss / 1024 / 1024  # MB
        memory_increase = final_memory - initial_memory

        # Memory increase should be reasonable (less than 500MB for 4 qubits)
        self.assertLess(memory_increase, 500.0)

        # Result should still be valid
        self.assertEqual(statevector.shape, (2**num_qubits,))
        self.assertTrue(np.all(np.isfinite(statevector)))
1100

1101
if __name__ == '__main__':
    # Running this file directly (rather than via a test runner) executes
    # every TestCase defined in the module through unittest's CLI entry point.
    unittest.main()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc