• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

atlanticwave-sdx / sdx-controller / 19910063817

03 Dec 2025 09:53PM UTC coverage: 54.759% (-0.6%) from 55.33%
19910063817

Pull #498

github

web-flow
Merge 955e689ba into 51b6778fc
Pull Request #498: when restart, sync the topology vlan/bw state from connections out of db

4 of 32 new or added lines in 2 files covered. (12.5%)

90 existing lines in 2 files now uncovered.

1260 of 2301 relevant lines covered (54.76%)

1.1 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

50.0
/sdx_controller/messaging/rpc_queue_consumer.py
1
#!/usr/bin/env python
2
import json
2✔
3
import logging
2✔
4
import os
2✔
5
import threading
2✔
6
import time
2✔
7
import traceback
2✔
8
from queue import Queue
2✔
9

10
import pika
2✔
11
from sdx_datamodel.constants import (
2✔
12
    Constants,
13
    DomainStatus,
14
    MessageQueueNames,
15
    MongoCollections,
16
)
17
from sdx_datamodel.models.topology import SDX_TOPOLOGY_ID_prefix
2✔
18
from sdx_pce.models import ConnectionPath, ConnectionRequest, ConnectionSolution
2✔
19

20
from sdx_controller.handlers.connection_handler import (
2✔
21
    ConnectionHandler,
22
    connection_state_machine,
23
    get_connection_status,
24
    parse_conn_status,
25
)
26
from sdx_controller.handlers.lc_message_handler import LcMessageHandler
2✔
27
from sdx_controller.models import connection
2✔
28
from sdx_controller.utils.parse_helper import ParseHelper
2✔
29

30
# RabbitMQ connection settings, read from the environment at import time.
# MQ_HOST has no fallback; the other three fall back when the variable is
# unset or empty.
MQ_HOST = os.environ.get("MQ_HOST")
MQ_PORT = os.environ.get("MQ_PORT") or 5672
MQ_USER = os.environ.get("MQ_USER") or "guest"
MQ_PASS = os.environ.get("MQ_PASS") or "guest"

# Heartbeat timing used by HeartbeatMonitor.
HEARTBEAT_INTERVAL = int(os.environ.get("HEARTBEAT_INTERVAL", 10))  # seconds
# consecutive missed heartbeats allowed
HEARTBEAT_TOLERANCE = int(os.environ.get("HEARTBEAT_TOLERANCE", 3))


# subscribe to the corresponding queue
SUB_QUEUE = MessageQueueNames.OXP_UPDATE

logger = logging.getLogger(__name__)

# Adds a SOLUTIONS attribute to the imported MongoCollections class.
MongoCollections.SOLUTIONS = "solutions"
2✔
46

47

48
class HeartbeatMonitor:
2✔
49
    def __init__(self, db_instance):
2✔
50
        self.last_heartbeat = {}  # domain -> last heartbeat timestamp
2✔
51
        self.domain_status = {}  # domain -> current status (UP / UNKNOWN)
2✔
52
        self.lock = threading.Lock()
2✔
53
        self.monitoring = False
2✔
54
        self.db_instance = db_instance  # store DB instance
2✔
55

56
    def record_heartbeat(self, domain):
2✔
57
        """Record heartbeat from a domain and mark it as UP if previously UNKNOWN."""
UNCOV
58
        with self.lock:
×
59
            self.last_heartbeat[domain] = time.time()
×
60

UNCOV
61
            previous_status = self.domain_status.get(domain)
×
62
            self.domain_status[domain] = DomainStatus.UP
×
63

64
            # Update DB if status changed from UNKNOWN -> UP
UNCOV
65
            if previous_status == DomainStatus.UNKNOWN:
×
66
                logger.info(
×
67
                    f"[HeartbeatMonitor] Domain {domain} is BACK UP after missed heartbeats."
68
                )
UNCOV
69
                domain_dict_from_db = self.db_instance.get_value_from_db(
×
70
                    MongoCollections.DOMAINS, Constants.DOMAIN_DICT
71
                )
UNCOV
72
                if domain in domain_dict_from_db:
×
73
                    domain_dict_from_db[domain] = DomainStatus.UP
×
74
                    self.db_instance.add_key_value_pair_to_db(
×
75
                        MongoCollections.DOMAINS,
76
                        Constants.DOMAIN_DICT,
77
                        domain_dict_from_db,
78
                    )
79

UNCOV
80
            logger.debug(f"[HeartbeatMonitor] Heartbeat recorded for {domain}")
×
81

82
    def check_status(self):
2✔
83
        """Mark domains as UNKNOWN if heartbeats are missing."""
84
        now = time.time()
2✔
85
        with self.lock:
2✔
86
            for domain, last_time in self.last_heartbeat.items():
2✔
UNCOV
87
                if now - last_time > HEARTBEAT_TOLERANCE * HEARTBEAT_INTERVAL:
×
88
                    if self.domain_status.get(domain) != DomainStatus.UNKNOWN:
×
89
                        logger.warning(
×
90
                            f"[HeartbeatMonitor] Domain {domain} marked UNKNOWN (missed {HEARTBEAT_TOLERANCE} heartbeats)"
91
                        )
UNCOV
92
                        self.domain_status[domain] = DomainStatus.UNKNOWN
×
93

UNCOV
94
                        domain_dict_from_db = self.db_instance.get_value_from_db(
×
95
                            MongoCollections.DOMAINS, Constants.DOMAIN_DICT
96
                        )
UNCOV
97
                        if domain in domain_dict_from_db:
×
98
                            domain_dict_from_db[domain] = DomainStatus.UNKNOWN
×
99
                            self.db_instance.add_key_value_pair_to_db(
×
100
                                MongoCollections.DOMAINS,
101
                                Constants.DOMAIN_DICT,
102
                                domain_dict_from_db,
103
                            )
104

105
    def get_status(self, domain):
2✔
106
        """Return the current status of a domain."""
UNCOV
107
        with self.lock:
×
108
            return self.domain_status.get(domain, "unknown")
×
109

110
    def start_monitoring(self):
2✔
111
        """Start a background thread to monitor heartbeat status."""
112
        if self.monitoring:
2✔
UNCOV
113
            return
×
114
        self.monitoring = True
2✔
115
        logger.info("[HeartbeatMonitor] Started monitoring heartbeats.")
2✔
116

117
        def monitor_loop():
2✔
118
            while self.monitoring:
2✔
119
                self.check_status()
2✔
120
                time.sleep(HEARTBEAT_INTERVAL)
2✔
121

122
        t = threading.Thread(target=monitor_loop, daemon=True)
2✔
123
        t.start()
2✔
124

125

126
class RpcConsumer(object):
    """RabbitMQ consumer that feeds messages from ``SUB_QUEUE`` to the controller.

    ``start_consumer`` blocks on the broker and pushes every received body
    onto a thread queue; ``start_sdx_consumer`` restores saved state from the
    DB and then drains that queue, dispatching heartbeats and local-controller
    messages.
    """

    def __init__(self, thread_queue, exchange_name, te_manager):
        """Open a blocking connection to the broker and declare ``SUB_QUEUE``.

        Args:
            thread_queue: queue that ``on_request`` puts received bodies on.
            exchange_name: exchange used when publishing replies.
            te_manager: TE manager shared with the message handlers.
        """
        self.logger = logging.getLogger(__name__)

        self.logger.info(f"[MQ] Using amqp://{MQ_USER}@{MQ_HOST}:{MQ_PORT}")

        self.connection = pika.BlockingConnection(
            pika.ConnectionParameters(
                host=MQ_HOST,
                port=MQ_PORT,
                credentials=pika.PlainCredentials(username=MQ_USER, password=MQ_PASS),
            )
        )

        self.channel = self.connection.channel()
        self.exchange_name = exchange_name

        self.channel.queue_declare(queue=SUB_QUEUE)
        self._thread_queue = thread_queue

        self.te_manager = te_manager

        self._exit_event = threading.Event()

    def on_request(self, ch, method, props, message_body):
        """Queue a received message and echo it back to the sender.

        The body is put on the thread queue, then published on ``ch`` to
        ``props.reply_to`` with the request's correlation id, and finally
        acknowledged.  Publish/ack failures are logged, not raised.

        No new broker connection is opened here: the reply goes out on the
        channel the message arrived on, and replacing ``self.connection`` /
        ``self.channel`` per message would leak one connection per message
        and drop the reference to the consuming channel.
        """
        self._thread_queue.put(message_body)

        try:
            ch.basic_publish(
                exchange=self.exchange_name,
                routing_key=props.reply_to,
                properties=pika.BasicProperties(correlation_id=props.correlation_id),
                # NOTE(review): if message_body is bytes, str() yields its
                # repr (b'...') -- kept as-is, confirm this is intended.
                body=str(message_body),
            )
            ch.basic_ack(delivery_tag=method.delivery_tag)
        except Exception as err:
            self.logger.error(f"[MQ] encountered error when publishing: {err}")

    def start_consumer(self):
        """Consume from ``SUB_QUEUE`` one unacknowledged message at a time (blocks)."""
        self.channel.basic_qos(prefetch_count=1)
        self.channel.basic_consume(queue=SUB_QUEUE, on_message_callback=self.on_request)

        self.logger.info(" [MQ] Awaiting requests from queue: " + SUB_QUEUE)
        self.channel.start_consuming()

    def _restore_topologies(self, db_instance, domain_dict):
        """Load each saved per-domain topology from the DB into the TE manager."""
        for domain in domain_dict:
            topology = db_instance.get_value_from_db(
                MongoCollections.TOPOLOGIES, SDX_TOPOLOGY_ID_prefix + domain
            )

            if not topology:
                continue

            # Get the actual thing minus the Mongo ObjectID.
            self.te_manager.add_topology(topology)
            self.logger.debug(f"Read {domain}: {topology}")

    def _restore_vlan_assignments(self, db_instance):
        """Re-mark the VLANs used by stored connections in the TE manager's VLAN table.

        A connection whose breakdown is missing or malformed is logged and
        skipped so that the remaining connections are still restored.
        """
        connections = db_instance.get_all_entries_in_collection(
            MongoCollections.CONNECTIONS
        )
        if not connections:
            self.logger.info("No connection was found")
            return

        for conn_entry in connections:
            # The first key of each entry is used as the service id.
            service_id = next(iter(conn_entry))
            status = get_connection_status(db_instance, service_id)
            # Guard against a missing status record.
            reported = status.get(service_id) if status else None
            self.logger.info(
                f"Restart: service_id: {service_id}, status: {reported}"
            )
            # 1. update the vlan tables in pce
            domain_breakdown = db_instance.get_value_from_db(
                MongoCollections.BREAKDOWNS, service_id
            )
            if not domain_breakdown:
                self.logger.warning(f"Could not find breakdown for {service_id}")
                continue
            try:
                vlan_tags_table = self.te_manager.vlan_tags_table
                for domain, segment in domain_breakdown.items():
                    self.logger.debug(f"domain:{domain};segment:{segment}")
                    domain_table = vlan_tags_table.get(domain)
                    for endpoint in ("uni_a", "uni_z"):
                        uni = segment.get(endpoint)
                        vlan_table = domain_table.get(uni.get("port_id"))
                        vlan_table[uni.get("tag").get("value")] = service_id
            except Exception as e:
                err = traceback.format_exc().replace("\n", ", ")
                self.logger.error(
                    f"Error when recovering breakdown vlan assignment: {e} - {err}"
                )
                # Keep going: returning here would leave the message loop
                # in start_sdx_consumer unstarted.
                continue

    def start_sdx_consumer(self, thread_queue, db_instance):
        """Restore saved state, then process queued messages until asked to stop.

        Starts a second ``RpcConsumer`` on a daemon thread to fill
        ``thread_queue``, starts a ``HeartbeatMonitor``, reloads the domain
        dictionary, topologies and VLAN assignments from ``db_instance``, and
        then loops over ``thread_queue``: heartbeat messages go to the
        monitor, everything else to ``LcMessageHandler``.

        Args:
            thread_queue: queue shared with the broker-consuming thread.
            db_instance: DB accessor used for restoring and persisting state.
        """
        rpc = RpcConsumer(thread_queue, "", self.te_manager)
        t1 = threading.Thread(target=rpc.start_consumer, args=(), daemon=True)
        t1.start()

        lc_message_handler = LcMessageHandler(db_instance, self.te_manager)
        parse_helper = ParseHelper()

        heartbeat_monitor = HeartbeatMonitor(db_instance)
        heartbeat_monitor.start_monitoring()

        latest_topo = {}
        domain_dict = {}

        # This part reads from DB when SDX controller initially starts.
        # It looks for domain_dict, if already in DB,
        # Then use the existing ones from DB.
        domain_dict_from_db = db_instance.get_value_from_db(
            MongoCollections.DOMAINS, Constants.DOMAIN_DICT
        )
        latest_topo_from_db = db_instance.get_value_from_db(
            MongoCollections.TOPOLOGIES, Constants.LATEST_TOPOLOGY
        )

        if domain_dict_from_db:
            domain_dict = domain_dict_from_db
            self.logger.debug("Domain list already exists in db: ")
            self.logger.debug(domain_dict)

        if latest_topo_from_db:
            latest_topo = latest_topo_from_db
            self.logger.debug("Topology already exists in db: ")
            self.logger.debug(latest_topo)

        # If topologies already saved in db, use them to initialize te_manager
        if domain_dict:
            self._restore_topologies(db_instance, domain_dict)
            # update topology/pce state in TE Manager
            self.te_manager.generate_graph_te()
            self._restore_vlan_assignments(db_instance)

        # NOTE(review): thread_queue.get() blocks, so a stop request is only
        # noticed after the next message arrives.
        while not self._exit_event.is_set():
            msg = thread_queue.get()
            self.logger.debug("MQ received message:" + str(msg))

            if not parse_helper.is_json(msg):
                self.logger.debug("Non JSON message, ignored")
                continue

            msg_json = json.loads(msg)
            # Valid JSON need not be an object; a bare number or list must
            # not break the loop with a TypeError on the membership test.
            if isinstance(msg_json, dict) and msg_json.get("type") == "Heart Beat":
                domain = msg_json.get("domain")
                heartbeat_monitor.record_heartbeat(domain)
                self.logger.debug(f"Heart beat received from {domain}")
                continue

            try:
                lc_message_handler.process_lc_json_msg(
                    msg,
                    latest_topo,
                    domain_dict,
                )
            except Exception as exc:
                err = traceback.format_exc().replace("\n", ", ")
                self.logger.error(f"Failed to process LC message: {exc} -- {err}")

    def stop_threads(self):
        """
        Signal threads that we're ready to stop.

        Stops consuming on this instance's channel and sets the exit event
        checked by the ``start_sdx_consumer`` loop.

        NOTE(review): the consumer started inside ``start_sdx_consumer`` is a
        separate ``RpcConsumer`` instance; its channel is not stopped here.
        """
        self.logger.info("[MQ] Stopping threads.")
        self.channel.stop_consuming()
        self._exit_event.set()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc