• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

lunarmodules / copas / 3724285554

pending completion
3724285554

push

github

Thijs Schreijer
release 4.5.0

1 of 1 new or added line in 1 file covered. (100.0%)

1237 of 1454 relevant lines covered (85.08%)

6953.37 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.3
/src/copas.lua
1
-------------------------------------------------------------------------------
2
-- Copas - Coroutine Oriented Portable Asynchronous Services
3
--
4
-- A dispatcher based on coroutines that can be used by TCP/IP servers.
5
-- Uses LuaSocket as the interface with the TCP/IP stack.
6
--
7
-- Authors: Andre Carregal, Javier Guerra, and Fabio Mascarenhas
8
-- Contributors: Diego Nehab, Mike Pall, David Burgess, Leonardo Godinho,
9
--               Thomas Harning Jr., and Gary NG
10
--
11
-- Copyright 2005-2021 - Kepler Project (www.keplerproject.org)
12
--
13
-- $Id: copas.lua,v 1.37 2009/04/07 22:09:52 carregal Exp $
14
-------------------------------------------------------------------------------
15

16
-- obsolete: only for Lua 5.1 compatibility
-- On Lua 5.1 the http modules must not have been loaded before copas.
if _VERSION == "Lua 5.1" then
  if package.loaded["socket.http"] then
    error("you must require copas before require'ing socket.http")
  end
  if package.loaded["copas.http"] then
    error("you must require copas before require'ing copas.http")
  end
end
22

23

24
local socket = require "socket"
28✔
25
local binaryheap = require "binaryheap"
28✔
26
local gettime = socket.gettime
28✔
27
local ssl -- only loaded upon demand
28

29
local WATCH_DOG_TIMEOUT = 120
28✔
30
local UDP_DATAGRAM_MAX = socket._DATAGRAMSIZE or 8192
28✔
31
local TIMEOUT_PRECISION = 0.1  -- 100ms
28✔
32
local fnil = function() end
23,765✔
33

34

35
local coroutine_create = coroutine.create
28✔
36
local coroutine_running = coroutine.running
28✔
37
local coroutine_yield = coroutine.yield
28✔
38
local coroutine_resume = coroutine.resume
28✔
39
local coroutine_status = coroutine.status
28✔
40

41

42
-- nil-safe versions for pack/unpack
local _unpack = unpack or table.unpack

-- Unpacks `t` from index `i` (default 1) up to `j`. When `j` is omitted the
-- recorded count `t.n` is used if present, otherwise the length of `t`, so
-- trailing nils stored by `pack` survive the round trip.
local function unpack(t, i, j)
  local first = i or 1
  local last = j or t.n or #t
  return _unpack(t, first, last)
end

-- Packs all arguments in a table, recording the true argument count
-- (including nils) in field `n`.
local function pack(...)
  local packed = { ... }
  packed.n = select("#", ...)
  return packed
end
46

47

48
local pcall = pcall
28✔
49
if _VERSION=="Lua 5.1" and not jit then     -- obsolete: only for Lua 5.1 compatibility
28✔
50
  pcall = require("coxpcall").pcall
×
51
  coroutine_running = require("coxpcall").running
×
52
end
53

54

55
do
  -- Redefines LuaSocket functions with coroutine safe versions (pure Lua)
  -- (this allows the use of socket.http from within copas)

  -- Metatable marking the wrapper table thrown by the functions created with
  -- `socket.newtry`. The wrapped error value is stored at index 1.
  local err_mt = {
    __tostring = function (self)
      -- stringify the payload BEFORE concatenating; it is not necessarily a
      -- string, and concatenating e.g. a table or nil would raise an error
      return "Copas 'try' error intermediate table: '"..tostring(self[1]).."'"
    end,
  }

  -- Post-processes the results of the `pcall` done by a protected function.
  -- @param status the first `pcall` result (true on success)
  -- @param ... the remaining `pcall` results
  -- @return on success all results; for an error thrown by a 'try' function
  -- `nil, err`; any other error is re-raised
  local function statusHandler(status, ...)
    if status then return ... end
    local err = (...)
    if type(err) == "table" and getmetatable(err) == err_mt then
      return nil, err[1]
    else
      -- level 0: re-raise the error as it was caught, without prepending
      -- this file's position to a message that already carries its origin
      error(err, 0)
    end
  end

  -- Wraps `func` such that errors thrown by a 'try' function are returned as
  -- `nil, err` instead of being propagated.
  function socket.protect(func)
    return function (...)
            return statusHandler(pcall(func, ...))
          end
  end

  -- Creates a 'try' function. When called with a falsy first argument it
  -- calls the (optional) `finalizer` with the remaining arguments, in a
  -- `pcall`, and then throws the second argument wrapped in a table marked
  -- with `err_mt`. Otherwise it returns all its arguments.
  function socket.newtry(finalizer)
    return function (...)
            local status = (...)
            if not status then
              pcall(finalizer or fnil, select(2, ...))
              error(setmetatable({ (select(2, ...)) }, err_mt), 0)
            end
            return ...
          end
  end

  socket.try = socket.newtry()
end
93

94

95
-- The module table. Calling the module itself, `copas(...)`, forwards all
-- arguments to `copas.loop(...)`.
local copas = {}
setmetatable(copas, {
  __call = function(module, ...)
    return module.loop(...)
  end,
})
100

101
-- Meta information is public even if beginning with an "_"
102
copas._COPYRIGHT   = "Copyright (C) 2005-2022 Kepler Project"
28✔
103
copas._DESCRIPTION = "Coroutine Oriented Portable Asynchronous Services"
28✔
104
copas._VERSION     = "Copas 4.5.0"
28✔
105

106
-- Close the socket associated with the current connection after the handler finishes
107
copas.autoclose = true
28✔
108

109
-- indicator for the loop running
110
copas.running = false
28✔
111

112

113
-------------------------------------------------------------------------------
114
-- Object names, to track names of thread/coroutines and sockets
115
-------------------------------------------------------------------------------
116
-- Names by object. Looking up an object without a name yields its
-- `tostring()` value, which is then stored as its name.
local object_names = setmetatable({}, {
  __mode = "k",   -- weak keys
  __index = function(names, obj)
    local default_name = tostring(obj)
    if obj == nil then
      -- nil cannot be used as a table key, so there is nothing to store
      return default_name
    end
    rawset(names, obj, default_name)
    return default_name
  end
})
126

127
-------------------------------------------------------------------------------
128
-- Simple set implementation
129
-- adds a FIFO queue for each socket in the set
130
-------------------------------------------------------------------------------
131

132
-- Creates a new socket set. The sockets are stored in the array part of the
-- returned table, the methods live on the same table.
-- @return the new set
local function newsocketset()
  local set = {}

  do  -- set implementation
    local positions = {}  -- index in the array part, by socket

    -- Adds a socket to the set, does nothing if it exists
    -- @return skt if added, or nil if it existed
    function set:insert(skt)
      if positions[skt] then
        return
      end
      local idx = #self + 1
      self[idx] = skt
      positions[skt] = idx
      return skt
    end

    -- Removes socket from the set, does nothing if not found.
    -- The last entry is moved into the freed slot, so order is not preserved.
    -- @return skt if removed, or nil if it wasn't in the set
    function set:remove(skt)
      local idx = positions[skt]
      if not idx then
        return
      end
      positions[skt] = nil

      local last = #self
      local moved = self[last]
      self[last] = nil
      if moved ~= skt then
        positions[moved] = idx
        self[idx] = moved
      end
      return skt
    end

  end

  do  -- queues implementation
    local fifo_queues = setmetatable({}, {
      __mode = "k",                    -- auto collect queue if socket is gone
      __index = function(queues, skt)  -- auto create fifo queue if not found
        local fresh = {}
        queues[skt] = fresh
        return fresh
      end,
    })

    -- pushes an item in the fifo queue for the socket.
    function set:push(skt, itm)
      local q = fifo_queues[skt]
      q[#q + 1] = itm
    end

    -- pops an item from the fifo queue for the socket
    function set:pop(skt)
      return table.remove(fifo_queues[skt], 1)
    end

  end

  return set
end
192

193

194

195
-- Threads immediately resumable
local _resumable = {} do
  local resumelist = {}  -- array of coroutines, in the order they were pushed

  -- Appends a coroutine to the list.
  function _resumable:push(co)
    resumelist[#resumelist + 1] = co
  end

  -- Hands over the current list and starts a new, empty one.
  -- @return the list (array) of coroutines collected so far
  function _resumable:clear_resumelist()
    local lst = resumelist
    resumelist = {}
    return lst
  end

  -- @return true if the list is empty
  function _resumable:done()
    return resumelist[1] == nil
  end

  -- @return the number of coroutines in the list
  function _resumable:count()
    -- `_resumable` itself only holds these methods (no array part), so its
    -- length is always 0; only the list length is relevant
    return #resumelist
  end

end
218

219

220

221
-- Similar to the socket set above, but tailored for the use of
-- sleeping threads
local _sleeping = {} do

  local heap = binaryheap.minUnique()
  -- list of coroutines sleeping without a wakeup time
  local lethargy = setmetatable({}, { __mode = "k" })


  -- Required base implementation
  -----------------------------------------
  _sleeping.insert = fnil
  _sleeping.remove = fnil

  -- push a new timer on the heap
  -- negative `sleeptime`: sleep without a wakeup time (lethargy),
  -- zero: immediately resumable, positive: wake up after `sleeptime`
  function _sleeping:push(sleeptime, co)
    if sleeptime < 0 then
      lethargy[co] = true
      return
    end
    if sleeptime == 0 then
      _resumable:push(co)
      return
    end
    heap:insert(gettime() + sleeptime, co)
  end

  -- find the thread that should wake up to the time, if any
  function _sleeping:pop(time)
    local next_wakeup = heap:peekValue() or math.huge
    if not (time < next_wakeup) then
      return heap:pop()
    end
  end

  -- additional methods for time management
  -----------------------------------------

  -- returns delay until next sleep expires, or nil if there is none
  function _sleeping:getnext()
    local wake_at = heap:peekValue()
    if not wake_at then
      return
    end
    -- never report less than 0, because select() might block
    local delay = wake_at - gettime()
    return math.max(delay, 0)
  end

  -- Makes a sleeping coroutine resumable; does nothing if it is neither in
  -- the lethargy nor on the heap.
  function _sleeping:wakeup(co)
    if lethargy[co] then
      lethargy[co] = nil
    elseif not heap:remove(co) then
      return
    end
    _resumable:push(co)
  end

  -- @param tos number of timeouts running
  function _sleeping:done(tos)
    -- return true if we have nothing more to do
    -- the timeout task doesn't qualify as work (fallbacks only),
    -- the lethargy also doesn't qualify as work ('dead' tasks),
    -- but the combination of a timeout + a lethargy can be work
    if heap:size() ~= 1 then  -- 1 means only the timeout-timer task is running
      return false
    end
    return not (tos > 0 and next(lethargy))
  end

  -- gets number of threads in binaryheap and lethargy
  function _sleeping:status()
    local lethargic = 0
    for _ in pairs(lethargy) do
      lethargic = lethargic + 1
    end
    return heap:size(), lethargic
  end

end   -- _sleeping
293

294

295

296
-------------------------------------------------------------------------------
297
-- Tracking coroutines and sockets
298
-------------------------------------------------------------------------------
299

300
local _servers = newsocketset() -- servers being handled
28✔
301
local _threads = setmetatable({}, {__mode = "k"})  -- registered threads added with addthread()
28✔
302
local _canceled = setmetatable({}, {__mode = "k"}) -- threads that are canceled and pending removal
28✔
303
local _autoclose = setmetatable({}, {__mode = "kv"}) -- sockets (value) to close when a thread (key) exits
28✔
304
local _autoclose_r = setmetatable({}, {__mode = "kv"}) -- reverse: sockets (key) to close when a thread (value) exits
28✔
305

306

307
-- for each socket we log the last read and last write times to enable the
308
-- watchdog to follow up if it takes too long.
309
-- tables contain the time, indexed by the socket
310
local _reading_log = {}
28✔
311
local _writing_log = {}
28✔
312

313
local _closed = {} -- track sockets that have been closed (list/array)
28✔
314

315
local _reading = newsocketset() -- sockets currently being read
28✔
316
local _writing = newsocketset() -- sockets currently being written
28✔
317
local _isSocketTimeout = { -- set of errors indicating a socket-timeout
28✔
318
  ["timeout"] = true,      -- default LuaSocket timeout
319
  ["wantread"] = true,     -- LuaSec specific timeout
320
  ["wantwrite"] = true,    -- LuaSec specific timeout
321
}
322

323
-------------------------------------------------------------------------------
324
-- Coroutine based socket timeouts.
325
-------------------------------------------------------------------------------
326
-- User-set timeouts by socket (weak keys), one table per operation type.
local user_timeouts_connect, user_timeouts_send, user_timeouts_receive
do
  -- a 10 year timeout as substitute for the default "blocking" should do
  local DEFAULT_TIMEOUT = 10*365*24*60*60

  local timeout_mt = { __mode = "k" }

  -- if there is no timeout found, we insert one automatically
  function timeout_mt.__index(registry, skt)
    registry[skt] = DEFAULT_TIMEOUT
    return DEFAULT_TIMEOUT
  end

  user_timeouts_connect = setmetatable({}, timeout_mt)
  user_timeouts_send = setmetatable({}, timeout_mt)
  user_timeouts_receive = setmetatable({}, timeout_mt)
end
344

345
local useSocketTimeoutErrors = setmetatable({},{ __mode = "k" })
28✔
346

347

348
-- sto = socket-time-out
local sto_timeout, sto_timed_out, sto_change_queue, sto_error do

  -- all registries below are indexed by coroutine (weak keys)
  local socket_register = setmetatable({}, { __mode = "k" })    -- socket by coroutine
  local operation_register = setmetatable({}, { __mode = "k" }) -- operation "read"/"write" by coroutine
  local timeout_flags = setmetatable({}, { __mode = "k" })      -- true if timedout, by coroutine


  -- Callback passed to `copas.timeout`; called with the coroutine for which
  -- the timeout was set.
  local function socket_callback(co)
    local skt = socket_register[co]
    local queue = operation_register[co]

    -- flag the timeout and resume the coroutine
    timeout_flags[co] = true
    _resumable:push(co)

    -- clear the socket from the current queue
    local socketset
    if queue == "read" then
      socketset = _reading
    elseif queue == "write" then
      socketset = _writing
    else
      error("bad queue name; expected 'read'/'write', got: "..tostring(queue))
    end
    socketset:remove(skt)
  end


  -- Sets a socket timeout for the current coroutine.
  -- Calling it as `sto_timeout()` will cancel the timeout.
  -- @param skt (socket) the socket on which to operate
  -- @param queue (string) the queue the socket is currently in, must be either "read" or "write"
  -- @param use_connect_to (bool) timeout to use is determined based on queue (read/write) or if this
  -- is truthy, it is the connect timeout.
  -- @return true
  function sto_timeout(skt, queue, use_connect_to)
    local current = coroutine_running()
    socket_register[current] = skt
    operation_register[current] = queue
    timeout_flags[current] = nil

    if not skt then
      copas.timeout(0)
      return true
    end

    local duration = (use_connect_to and user_timeouts_connect[skt]) or
                     (queue == "read" and user_timeouts_receive[skt]) or
                     user_timeouts_send[skt]
    copas.timeout(duration, socket_callback)
    return true
  end


  -- Changes the timeout to a different queue (read/write).
  -- Only useful with ssl-handshakes and "wantread", "wantwrite" errors, when
  -- the queue has to be changed, so the timeout handler knows where to find the socket.
  -- @param queue (string) the new queue the socket is in, must be either "read" or "write"
  -- @return true
  function sto_change_queue(queue)
    local current = coroutine_running()
    operation_register[current] = queue
    return true
  end


  -- Responds with `true` if the operation timed-out.
  function sto_timed_out()
    local current = coroutine_running()
    return timeout_flags[current]
  end


  -- Returns the proper timeout error; `err` itself if socket timeout errors
  -- are enabled for the current coroutine, otherwise "timeout".
  function sto_error(err)
    local current = coroutine_running()
    return useSocketTimeoutErrors[current] and err or "timeout"
  end
end
421

422

423

424
-------------------------------------------------------------------------------
425
-- Coroutine based socket I/O functions.
426
-------------------------------------------------------------------------------
427

428
-- Returns "tcp" for plain TCP and "ssl" for ssl-wrapped sockets, so truthy
-- for tcp based, and falsy for udp based.
-- @param skt the socket to check
-- @return "tcp", "ssl", or nil
local isTCP do
  -- result by the first 3 characters of the socket's `tostring()` value
  local lookup = {
    tcp = "tcp",
    SSL = "ssl",
  }

  -- the parameter is deliberately not named `socket`, to not shadow the
  -- LuaSocket module
  function isTCP(skt)
    return lookup[tostring(skt):sub(1,3)]
  end
end
440

441
-- Records the socket in the `_closed` list and then closes it.
-- @param skt the socket to close
-- @param ... passed on to the socket's own `close` method
-- @return whatever the socket's `close` method returns
function copas.close(skt, ...)
  table.insert(_closed, skt)
  return skt:close(...)
end
445

446

447

448
-- Sets the connect, send and receive timeouts of a socket to the same value.
-- nil or negative is indefinitely
-- @return the result of `copas.settimeouts`, or nil+error if `timeout` is not a number
function copas.settimeout(skt, timeout)
  local value = timeout or -1
  if type(value) == "number" then
    return copas.settimeouts(skt, value, value, value)
  end
  return nil, "timeout must be 'nil' or a number"
end
457

458
do
  -- Stores a single timeout; `value` must be nil or a number.
  local function store_timeout(registry, skt, value)
    if value == nil then
      return  -- nil means do not change
    end
    if value < 0 then
      -- clearing the entry makes the registry fall back to its default
      value = nil
    end
    registry[skt] = value
  end

  -- Sets the individual timeouts of a socket.
  -- negative is indefinitely, nil means do not change
  -- @param skt the socket to set the timeouts for
  -- @param connect (optional number) the connect timeout
  -- @param send (optional number) the send timeout
  -- @param read (optional number) the receive timeout
  -- @return true, or nil+error if any of the values is not nil or a number,
  -- in which case none of the timeouts is changed
  function copas.settimeouts(skt, connect, send, read)
    -- validate everything up front, so a bad argument cannot leave the
    -- socket with only some of its timeouts updated
    if connect ~= nil and type(connect) ~= "number" then
      return nil, "connect timeout must be 'nil' or a number"
    end
    if send ~= nil and type(send) ~= "number" then
      return nil, "send timeout must be 'nil' or a number"
    end
    if read ~= nil and type(read) ~= "number" then
      return nil, "read timeout must be 'nil' or a number"
    end

    store_timeout(user_timeouts_connect, skt, connect)
    store_timeout(user_timeouts_send, skt, send)
    store_timeout(user_timeouts_receive, skt, read)

    return true
  end
end
496

497
-- Reads a pattern from a client, yielding to the reading set (or the writing
-- set on "wantwrite") for as long as the socket reports a timeout.
-- UDP: a UDP socket expects a second argument to be a number, so it MUST
-- be provided as the 'pattern' below defaults to a string. Will throw a
-- 'bad argument' error if omitted.
-- @return the 3 results of the socket's `receive`, or upon a timeout of the
-- operation: nil, the timeout error, and the partial result
function copas.receive(client, pattern, part)
  local s, err
  pattern = pattern or "*l"
  local current_log = _reading_log
  sto_timeout(client, "read")

  while true do
    s, err, part = client:receive(pattern, part)

    -- guarantees that high throughput doesn't take other threads to starvation
    if math.random(100) > 90 then
      copas.pause()
    end

    if s or not _isSocketTimeout[err] then
      -- finished; either successfully, or with an error that is not a timeout
      current_log[client] = nil
      sto_timeout()
      return s, err, part
    end

    if sto_timed_out() then
      current_log[client] = nil
      return nil, sto_error(err), part
    end

    -- not ready yet; log the time and wait in the proper queue
    local queue_name, queue
    if err == "wantwrite" then -- wantwrite may be returned during SSL renegotiations
      current_log, queue_name, queue = _writing_log, "write", _writing
    else
      current_log, queue_name, queue = _reading_log, "read", _reading
    end
    current_log[client] = gettime()
    sto_change_queue(queue_name)
    coroutine_yield(client, queue)
  end
end
543

544
-- receives data from a client over UDP. Not available for TCP.
-- (this is a copy of receive() method, adapted for receivefrom() use)
-- @param size (optional) passed to the socket's `receivefrom`, defaults to `UDP_DATAGRAM_MAX`
-- @return the 3 results of the socket's `receivefrom`, or upon a timeout of
-- the operation: nil, the timeout error, and the 3rd result
function copas.receivefrom(client, size)
  local s, err, port
  size = size or UDP_DATAGRAM_MAX
  sto_timeout(client, "read")

  while true do
    s, err, port = client:receivefrom(size) -- upon success err holds ip address

    -- guarantees that high throughput doesn't take other threads to starvation
    if math.random(100) > 90 then
      copas.pause()
    end

    if s or err ~= "timeout" then
      -- finished; either successfully, or with an error that is not a timeout
      _reading_log[client] = nil
      sto_timeout()
      return s, err, port
    end

    if sto_timed_out() then
      _reading_log[client] = nil
      return nil, sto_error(err), port
    end

    -- not ready yet; log the time and wait in the reading queue
    _reading_log[client] = gettime()
    coroutine_yield(client, _reading)
  end
end
578

579
-- Same as `copas.receive` but with special treatment when reading chunks,
-- unblocks on any data received; it also returns as soon as the partial
-- result has grown beyond the size of the `part` passed in.
function copas.receivepartial(client, pattern, part)
  local s, err
  pattern = pattern or "*l"
  local orig_size = #(part or "")
  local current_log = _reading_log
  sto_timeout(client, "read")

  while true do
    s, err, part = client:receive(pattern, part)

    -- guarantees that high throughput doesn't take other threads to starvation
    if math.random(100) > 90 then
      copas.pause()
    end

    local got_data = s or (type(part) == "string" and #part > orig_size)
    if got_data or not _isSocketTimeout[err] then
      -- finished; with (partial) data, or with an error that is not a timeout
      current_log[client] = nil
      sto_timeout()
      return s, err, part
    end

    if sto_timed_out() then
      current_log[client] = nil
      return nil, sto_error(err), part
    end

    -- not ready yet; log the time and wait in the proper queue
    local queue_name, queue
    if err == "wantwrite" then
      current_log, queue_name, queue = _writing_log, "write", _writing
    else
      current_log, queue_name, queue = _reading_log, "read", _reading
    end
    current_log[client] = gettime()
    sto_change_queue(queue_name)
    coroutine_yield(client, queue)
  end
end
copas.receivePartial = copas.receivepartial  -- compat: receivePartial is deprecated
625

626
-- Sends data to a client, yielding to the writing set (or the reading set on
-- "wantread") for as long as the socket reports a timeout.
-- Note: from and to parameters will be ignored by/for UDP sockets
-- @return the 3 results of the socket's `send`, or upon a timeout of the
-- operation: nil, the timeout error, and the last index reported by `send`
function copas.send(client, data, from, to)
  local s, err
  from = from or 1
  local lastIndex = from - 1
  local current_log = _writing_log
  sto_timeout(client, "write")

  while true do
    -- continue right after the last index reported
    s, err, lastIndex = client:send(data, lastIndex + 1, to)

    -- guarantees that high throughput doesn't take other threads to starvation
    if math.random(100) > 90 then
      copas.pause()
    end

    if s or not _isSocketTimeout[err] then
      -- finished; either successfully, or with an error that is not a timeout
      current_log[client] = nil
      sto_timeout()
      return s, err, lastIndex
    end

    if sto_timed_out() then
      current_log[client] = nil
      return nil, sto_error(err), lastIndex
    end

    -- not ready yet; log the time and wait in the proper queue
    local queue_name, queue
    if err == "wantread" then
      current_log, queue_name, queue = _reading_log, "read", _reading
    else
      current_log, queue_name, queue = _writing_log, "write", _writing
    end
    current_log[client] = gettime()
    sto_change_queue(queue_name)
    coroutine_yield(client, queue)
  end
end

-- deprecated; for backward compatibility only, since UDP doesn't block on sending
-- Calls the socket's own `sendto` directly and returns its results.
function copas.sendto(client, data, ip, port)
  return client:sendto(data, ip, port)
end
677

678
-- Connects a socket, and waits (yielding) until the connection is completed.
-- The socket is set to non-blocking, and the connect timeout is used.
-- @return the results of the socket's `connect`, or nil + the timeout error
-- if the operation timed out
function copas.connect(skt, host, port)
  skt:settimeout(0)
  local ret, err
  local tried_more_than_once
  sto_timeout(skt, "write", true)

  while true do
    ret, err = skt:connect(host, port)

    -- non-blocking connect on Windows results in error "Operation already
    -- in progress" to indicate that it is completing the request async. So essentially
    -- it is the same as "timeout"
    local in_progress = (err == "timeout" or err == "Operation already in progress")

    if ret or not in_progress then
      _writing_log[skt] = nil
      sto_timeout()
      -- Once the async connect completes, Windows returns the error "already connected"
      -- to indicate it is done, so that error should be ignored. Except when it is the
      -- first call to connect, then it was already connected to something else and the
      -- error should be returned
      if (not ret) and err == "already connected" and tried_more_than_once then
        return 1
      end
      return ret, err
    end

    if sto_timed_out() then
      _writing_log[skt] = nil
      return nil, sto_error(err)
    end

    -- still connecting; log the time and wait in the writing queue
    tried_more_than_once = true
    _writing_log[skt] = gettime()
    coroutine_yield(skt, _writing)
  end
end
712

713

714
-- Wraps a tcp socket in an ssl socket and configures it. If the socket was
-- already wrapped, it does nothing and returns the socket.
-- The copas timeouts, the autoclose registration, and a custom name (if any)
-- are carried over from the original socket to the wrapped one.
-- @param skt the socket to wrap
-- @param wrap_params the parameters for the ssl-context
-- @return wrapped socket, or throws an error
local function ssl_wrap(skt, wrap_params)
  if isTCP(skt) == "ssl" then
    return skt  -- was already wrapped
  end
  if not wrap_params then
    error("cannot wrap socket into a secure socket (using 'ssl.wrap()') without parameters/context")
  end

  ssl = ssl or require("ssl")
  -- assert, because we do not want to silently ignore this one!!
  local wrapped = assert(ssl.wrap(skt, wrap_params))

  wrapped:settimeout(0)  -- non-blocking on the ssl-socket

  -- copy copas user-timeout to newly wrapped one
  local to_connect = user_timeouts_connect[skt]
  local to_send = user_timeouts_send[skt]
  local to_receive = user_timeouts_receive[skt]
  copas.settimeouts(wrapped, to_connect, to_send, to_receive)

  local owner = _autoclose_r[skt]
  if owner then
    -- socket registered for autoclose, move registration to wrapped one
    _autoclose[owner] = wrapped
    _autoclose_r[skt] = nil
    _autoclose_r[wrapped] = owner
  end

  local name = object_names[skt]
  if name ~= tostring(skt) then
    -- socket had a custom name, so copy it over
    object_names[wrapped] = name
  end
  return wrapped
end
746

747

748
-- Normalizes the ssl parameters into a table with a fixed structure.
-- For each luasec method we have a subtable, allows for future extension.
-- Resulting structure (fields are `false` if not provided):
-- {
--   wrap = ... -- parameter to 'wrap()'; the ssl parameter table, or the context object
--   sni = {                  -- parameters to 'sni()'
--     names = string | table -- 1st parameter
--     strict = bool          -- 2nd parameter
--   }
-- }
-- @param sslt nil, an ssl-params table (has a `mode` or `protocol` field), a
-- table with the structure above, or an ssl-context object (userdata)
-- @return the normalized table; accessing any key other than `wrap`/`sni`
-- on it throws an error. Throws an error if `sslt` is of any other type.
local function normalize_sslt(sslt)
  local t = type(sslt)
  local r = setmetatable({}, {
    __index = function(self, key)
      -- a bug if this happens, here as a sanity check, just being careful since
      -- this is security stuff
      error("accessing unknown 'ssl_params' table key: "..tostring(key))
    end,
  })
  if t == "nil" then
    r.wrap = false
    r.sni = false

  elseif t == "table" then
    if sslt.mode or sslt.protocol then
      -- has the mandatory fields for the ssl-params table for handshake
      -- backward compatibility
      r.wrap = sslt
      r.sni = false
    else
      -- has the target definition, copy our known keys
      r.wrap = sslt.wrap or false -- 'or false' because we do not want nils
      r.sni = sslt.sni or false -- 'or false' because we do not want nils
    end

  elseif t == "userdata" then
    -- it's an ssl-context object for the handshake
    -- backward compatibility
    r.wrap = sslt
    r.sni = false

  else
    -- report the offending type, as the message says, not the value
    error("ssl parameters; did not expect type "..t)
  end

  return r
end
794

795

796
---
-- Performs an (async) ssl handshake on a connected TCP client socket.
-- NOTE: if not ssl-wrapped already, then replace all previous socket references, with the returned new ssl wrapped socket
-- Throws error and does not return nil+error when the handshake fails, as
-- that might silently fail in code like this;
--   copas.addserver(s1, function(skt)
--       skt = copas.wrap(skt, sparams)
--       skt:dohandshake()   --> without explicit error checking, this fails silently and
--       skt:send(body)      --> continues unencrypted
-- Only when the operation times out (the connect timeout is used), nil+error
-- is returned.
-- @param skt Regular LuaSocket CLIENT socket object
-- @param wrap_params Table with ssl parameters
-- @return wrapped ssl socket, or nil + the timeout error, or throws an error
function copas.dohandshake(skt, wrap_params)
  ssl = ssl or require("ssl")

  local nskt = ssl_wrap(skt, wrap_params)

  sto_timeout(nskt, "write", true)

  while true do
    local success, err = nskt:dohandshake()

    if success then
      sto_timeout()
      return nskt
    end

    if not _isSocketTimeout[err] then
      sto_timeout()
      error("TLS/SSL handshake failed: " .. tostring(err))
    end

    if sto_timed_out() then
      return nil, sto_error(err)
    end

    -- handshake is still in progress, wait in the queue it asked for
    local queue
    if err == "wantwrite" then
      sto_change_queue("write")
      queue = _writing
    elseif err == "wantread" then
      sto_change_queue("read")
      queue = _reading
    else
      error("TLS/SSL handshake failed: " .. tostring(err))
    end

    coroutine_yield(nskt, queue)
  end
end
845

846
-- flushes a client write buffer (deprecated)
-- This is a no-op; it takes no parameters and returns nothing.
function copas.flush()
  return
end
849

850
-- Metatable for Copas-wrapped TCP sockets. The wrapper object holds the real
-- socket in `self.socket` and the normalized ssl parameters in `self.ssl_params`.
-- I/O methods are routed through the Copas functions; the remaining methods
-- are passed straight through to the underlying socket.
local _skt_mt_tcp do
  local methods = {}

  function methods:send(data, from, to)
    return copas.send(self.socket, data, from, to)
  end

  function methods:receive(pattern, prefix)
    -- a receive timeout of 0 is served by the partial-receive variant
    if user_timeouts_receive[self.socket] == 0 then
      return copas.receivepartial(self.socket, pattern, prefix)
    end
    return copas.receive(self.socket, pattern, prefix)
  end

  function methods:receivepartial(pattern, prefix)
    return copas.receivepartial(self.socket, pattern, prefix)
  end

  function methods:flush()
    return copas.flush(self.socket)
  end

  function methods:settimeout(time)
    return copas.settimeout(self.socket, time)
  end

  function methods:settimeouts(connect, send, receive)
    return copas.settimeouts(self.socket, connect, send, receive)
  end

  -- TODO: socket.connect is a shortcut, and must be provided with an alternative
  -- if ssl parameters are available, it will also include a handshake
  function methods:connect(...)
    local res, err = copas.connect(self.socket, ...)
    if res then
      local sslp = self.ssl_params
      if sslp.sni then self:sni() end
      if sslp.wrap then res, err = self:dohandshake() end
    end
    return res, err
  end

  function methods:close(...)
    return copas.close(self.socket, ...)
  end

  -- TODO: socket.bind is a shortcut, and must be provided with an alternative
  function methods:bind(...) return self.socket:bind(...) end

  -- TODO: is this DNS related? hence blocking?
  function methods:getsockname(...) return self.socket:getsockname(...) end

  function methods:getstats(...) return self.socket:getstats(...) end

  function methods:setstats(...) return self.socket:setstats(...) end

  function methods:listen(...) return self.socket:listen(...) end

  function methods:accept(...) return self.socket:accept(...) end

  function methods:setoption(...) return self.socket:setoption(...) end

  -- TODO: is this DNS related? hence blocking?
  function methods:getpeername(...) return self.socket:getpeername(...) end

  function methods:shutdown(...) return self.socket:shutdown(...) end

  -- ssl-wraps the internal socket and sets the SNI names on it; when `names`
  -- is nil, names/strict are taken from the stored ssl parameters.
  function methods:sni(names, strict)
    local sslp = self.ssl_params
    self.socket = ssl_wrap(self.socket, sslp.wrap)
    if names == nil then
      names = sslp.sni.names
      strict = sslp.sni.strict
    end
    return self.socket:sni(names, strict)
  end

  -- performs the handshake; returns the wrapper itself on success
  function methods:dohandshake(wrap_params)
    local nskt, err = copas.dohandshake(self.socket, wrap_params or self.ssl_params.wrap)
    if not nskt then return nskt, err end
    self.socket = nskt  -- replace internal socket with the newly wrapped ssl one
    return self
  end

  _skt_mt_tcp = {
    __tostring = function(self)
      return tostring(self.socket).." (copas wrapped)"
    end,

    __index = methods,
  }
end
939

940
-- Metatable for Copas-wrapped UDP sockets: starts as a copy of the TCP one,
-- after which the UDP specific methods are overridden/added.
local _skt_mt_udp = { __index = {} } do
  local udp_methods = _skt_mt_udp.__index

  -- copy the metamethods, but keep what is already set (the fresh __index table)
  for key, value in pairs(_skt_mt_tcp) do
    if not _skt_mt_udp[key] then
      _skt_mt_udp[key] = value
    end
  end

  -- copy all TCP methods
  for key, value in pairs(_skt_mt_tcp.__index) do
    udp_methods[key] = value
  end

  function udp_methods:send(...) return self.socket:send(...) end

  function udp_methods:sendto(...) return self.socket:sendto(...) end

  -- receive a datagram; size defaults to UDP_DATAGRAM_MAX
  function udp_methods:receive(size)
    return copas.receive (self.socket, (size or UDP_DATAGRAM_MAX))
  end

  function udp_methods:receivefrom(size)
    return copas.receivefrom (self.socket, (size or UDP_DATAGRAM_MAX))
  end

  -- TODO: is this DNS related? hence blocking?
  function udp_methods:setpeername(...) return self.socket:setpeername(...) end

  function udp_methods:setsockname(...) return self.socket:setsockname(...) end

  -- do not close client, as it is also the server for udp.
  function udp_methods:close(...) return true end

  function udp_methods:settimeouts(connect, send, receive)
    return copas.settimeouts(self.socket, connect, send, receive)
  end
end
969

970

971

972
---
-- Wraps a LuaSocket socket object in an async Copas based socket object.
-- A socket that already carries one of the Copas wrapper metatables is
-- returned unchanged. Otherwise the socket timeout is set to 0 first.
-- @param skt The socket to wrap
-- @param sslt (optional) Table with ssl parameters, use an empty table to use ssl with defaults
-- (only used for TCP sockets; the UDP wrapper gets no `ssl_params` field)
-- @return wrapped socket object
function copas.wrap (skt, sslt)
  local mt = getmetatable(skt)
  if mt == _skt_mt_tcp or mt == _skt_mt_udp then
    return skt -- already wrapped
  end

  skt:settimeout(0)

  if not isTCP(skt) then
    return setmetatable ({socket = skt}, _skt_mt_udp)
  end
  return setmetatable ({socket = skt, ssl_params = normalize_sslt(sslt)}, _skt_mt_tcp)
end
990

991
--- Wraps a handler in a function that deals with wrapping the socket and doing the
-- optional ssl handshake.
-- If the handshake does not complete (the wrapper's `dohandshake` returns
-- nil+error) an error is thrown, so the handler never runs on a socket that
-- was supposed to be encrypted but is not.
-- @param handler function called as `handler(wrapped_socket, ...)`
-- @param sslparams (optional) ssl parameters, passed on to `copas.wrap`
-- @return the wrapping function, to be used as the actual handler
function copas.handler(handler, sslparams)
  -- TODO: pass a timeout value to set, and use during handshake
  return function (skt, ...)
    skt = copas.wrap(skt, sslparams) -- this call will normalize the sslparams table
    local sslp = skt.ssl_params
    -- UDP wrappers carry no ssl_params, in which case there is nothing to set up
    if sslp then
      if sslp.sni then skt:sni(sslp.sni.names, sslp.sni.strict) end
      if sslp.wrap then
        local ok, err = skt:dohandshake(sslp.wrap)
        if not ok then
          -- do not silently continue unencrypted
          error("TLS/SSL handshake failed: " .. tostring(err))
        end
      end
    end
    return handler(skt, ...)
  end
end
1003

1004

1005
--------------------------------------------------
1006
-- Error handling
1007
--------------------------------------------------
1008

1009
local _errhandlers = setmetatable({}, { __mode = "k" })   -- error handler per coroutine
28✔
1010

1011

1012
--- Creates a traceback, annotated with the coroutine and socket names.
-- @param msg the error message (any type, converted with `tostring`; nil becomes "")
-- @param co the coroutine that failed, or nil
-- @param skt the socket involved, or nil
-- @return the traceback string as generated by `debug.traceback`
function copas.gettraceback(msg, co, skt)
  local co_str = co == nil and "nil" or copas.getthreadname(co)
  local skt_str = skt == nil and "nil" or copas.getsocketname(skt)
  local msg_str = msg == nil and "" or tostring(msg)
  -- tostring() guards against a missing name; "%s" rejects nil on Lua 5.1
  co_str, skt_str = tostring(co_str), tostring(skt_str)
  if msg_str == "" then
    -- no message; only two placeholders here, so only pass the two names
    msg_str = ("(coroutine: %s, socket: %s)"):format(co_str, skt_str)
  else
    msg_str = ("%s (coroutine: %s, socket: %s)"):format(msg_str, co_str, skt_str)
  end

  if type(co) == "thread" then
    -- regular Copas coroutine
    return debug.traceback(co, msg_str)
  end
  -- not a coroutine, but the main thread, this happens if a timeout callback
  -- (see `copas.timeout`) causes an error (those callbacks run on the main thread).
  return debug.traceback(msg_str, 2)
end
1030

1031

1032
-- Default error handler: prints the annotated traceback to stdout.
local function _deferror(msg, co, skt)
  local trace = copas.gettraceback(msg, co, skt)
  print(trace)
end
1035

1036

1037
--- Sets the error handler, either for the current coroutine or the default one.
-- @param err the handler function (nil is only accepted for the per-coroutine handler)
-- @param default if truthy, `err` replaces the default handler instead of
-- being registered for the currently running coroutine
function copas.seterrorhandler(err, default)
  assert(err == nil or type(err) == "function", "Expected the handler to be a function, or nil")

  if not default then
    _errhandlers[coroutine_running()] = err
    return
  end

  assert(err ~= nil, "Expected the handler to be a function when setting the default")
  _deferror = err
end
1046
copas.setErrorHandler = copas.seterrorhandler  -- deprecated; old casing
28✔
1047

1048

1049
--- Returns the error handler in effect for a coroutine.
-- @param co (optional) the coroutine to look up, defaults to the running one
-- @return the coroutine specific handler if set, otherwise the default handler
function copas.geterrorhandler(co)
  local specific = _errhandlers[co or coroutine_running()]
  return specific or _deferror
end
1053

1054

1055
-- Sets, for the currently running coroutine, which error is reported on timeouts.
-- If `bool` is truthy, then the original socket errors will be returned in case
-- of timeouts; `timeout, wantread, wantwrite, Operation already in progress`.
-- If falsy, it will always return `timeout`.
function copas.useSocketTimeoutErrors(bool)
  local flag = bool and true or false -- force to a boolean
  useSocketTimeoutErrors[coroutine_running()] = flag
end
1061

1062
-------------------------------------------------------------------------------
1063
-- Thread handling
1064
-------------------------------------------------------------------------------
1065

1066
-- Resumes coroutine `co` once, passing `skt` and any extra arguments, and
-- handles the outcome: either the coroutine yielded to one of the Copas
-- queues and gets (re)scheduled there, or it ended (normally, by error, or
-- by an unexpected yield) and gets cleaned up.
local function _doTick (co, skt, ...)
  if not co then
    return
  end

  -- if a coroutine was canceled/removed, don't resume it
  if _canceled[co] then
    _canceled[co] = nil -- also clean up the registry
    _threads[co] = nil
    return
  end

  -- result: the socket (being read/write on) or the time to sleep
  -- target_q: either _writing, _reading, or _sleeping
  -- local time_before = gettime()
  local success, result, target_q = coroutine_resume(co, skt, ...)
  -- local duration = gettime() - time_before
  -- if duration > 1 then
  --   duration = math.floor(duration * 1000)
  --   pcall(_errhandlers[co] or _deferror, "task ran for "..tostring(duration).." milliseconds.", co, skt)
  -- end

  local rescheduled = target_q == _reading or target_q == _writing or target_q == _sleeping
  if rescheduled then
    -- we're yielding to a new queue
    target_q:insert (result)
    target_q:push (result, co)
    return
  end

  -- coroutine is terminating

  if success and coroutine_status(co) ~= "dead" then
    -- it called coroutine.yield from a non-Copas function which is unexpected
    success = false
    result = "coroutine.yield was called without a resume first, user-code cannot yield to Copas"
  end

  if not success then
    -- report through the coroutine's own handler, or else the default one
    local errhandler = _errhandlers[co] or _deferror
    local handled, handler_err = pcall(errhandler, result, co, skt)
    if not handled then
      print("Failed executing error handler: " .. tostring(handler_err))
    end
  end

  -- close the socket registered for auto-closing with this coroutine, if any
  local skt_to_close = _autoclose[co]
  if skt_to_close then
    skt_to_close:close()
    _autoclose[co] = nil
    _autoclose_r[skt_to_close] = nil
  end

  _errhandlers[co] = nil
end
1117

1118

1119
local _accept do
  -- number of clients accepted so far, per server socket (weak keys)
  local client_counters = setmetatable({}, { __mode = "k" })

  -- Accepts a connection on `server_skt` and starts `handler` in a new
  -- coroutine for it. Does nothing if `accept` yields no client socket.
  function _accept(server_skt, handler)
    local client_skt = server_skt:accept()
    if not client_skt then
      return
    end

    local seq = (client_counters[server_skt] or 0) + 1
    client_counters[server_skt] = seq
    object_names[client_skt] = object_names[server_skt] .. ":client_" .. seq

    client_skt:settimeout(0)
    -- copy server socket timeout settings
    copas.settimeouts(client_skt,
      user_timeouts_connect[server_skt],
      user_timeouts_send[server_skt],
      user_timeouts_receive[server_skt])

    local co = coroutine_create(handler)
    object_names[co] = object_names[server_skt] .. ":handler_" .. seq

    if copas.autoclose then
      -- register both directions: coroutine -> socket and socket -> coroutine
      _autoclose[co] = client_skt
      _autoclose_r[client_skt] = co
    end

    _doTick(co, client_skt)
  end
end
1146

1147
-------------------------------------------------------------------------------
1148
-- Adds a server/handler pair to Copas dispatcher
1149
-------------------------------------------------------------------------------
1150

1151
do
  -- Registers a TCP server socket: the handler is stored in `_servers` and
  -- the socket is added to the read queue.
  local function register_tcp(server, handler, timeout, name)
    server:settimeout(0)
    if name then
      object_names[server] = name
    end
    _servers[server] = handler
    _reading:insert(server)
    if timeout then
      copas.settimeout(server, timeout)
    end
  end

  -- Registers a UDP server socket: a single coroutine is created from the
  -- handler and started right away with the server socket as its argument.
  local function register_udp(server, handler, timeout, name)
    server:settimeout(0)
    local co = coroutine_create(handler)
    if name then
      object_names[server] = name
    end
    object_names[co] = object_names[server]..":handler"
    _reading:insert(server)
    if timeout then
      copas.settimeout(server, timeout)
    end
    _doTick(co, server)
  end


  --- Adds a server/handler pair to the dispatcher.
  -- @param server the server socket (TCP or UDP)
  -- @param handler the handler function
  -- @param timeout (optional) timeout to set on the server socket
  -- @param name (optional) name for the server socket
  function copas.addserver(server, handler, timeout, name)
    local register = register_udp
    if isTCP(server) then
      register = register_tcp
    end
    register(server, handler, timeout, name)
  end
end
1187

1188

1189
--- Removes a server socket from the dispatcher.
-- @param server the server socket, either Copas-wrapped or a plain socket
-- @param keep_open if truthy the socket is not closed
-- @return true when kept open, otherwise the results of `server:close()`
function copas.removeserver(server, keep_open)
  -- the registries are keyed by the underlying socket, so unwrap if needed
  local raw_skt = server
  local mt = getmetatable(server)
  if mt == _skt_mt_tcp or mt == _skt_mt_udp then
    raw_skt = server.socket
  end

  _servers:remove(raw_skt)
  _reading:remove(raw_skt)

  if not keep_open then
    return server:close()
  end
  return true
end
1204

1205

1206

1207
-------------------------------------------------------------------------------
1208
-- Adds a new coroutine thread to Copas dispatcher
1209
-------------------------------------------------------------------------------
1210
--- Creates a new named coroutine and hands it to the dispatcher.
-- @param name the name for the coroutine (nil for no explicit name)
-- @param handler the function to run, called with the extra arguments
-- @param ... arguments passed on to `handler`
-- @return the new coroutine
function copas.addnamedthread(name, handler, ...)
  if type(name) == "function" and type(handler) == "string" then
    -- old call, flip args for compatibility
    handler, name = name, handler
  end

  -- the coroutine body skips the first argument, which is always the socket
  -- passed by the scheduler, but `nil` in case of a task/thread
  local function body(_, ...)
    copas.pause()
    return handler(...)
  end

  local thread = coroutine_create(body)
  if name then
    object_names[thread] = name
  end

  _threads[thread] = true -- register this thread so it can be removed
  _doTick (thread, nil, ...)
  return thread
end
1230

1231

1232
--- Creates a new coroutine without an explicit name; see `copas.addnamedthread`.
-- @param handler the function to run
-- @param ... arguments passed on to `handler`
-- @return the new coroutine
function copas.addthread(handler, ...)
  local no_name = nil
  return copas.addnamedthread(no_name, handler, ...)
end
1235

1236

1237
--- Marks a coroutine as canceled, so it will not be resumed again.
-- @param thread the coroutine to remove; nil is ignored
function copas.removethread(thread)
  if thread == nil then
    -- nothing to cancel; a nil table index would raise an error below
    return
  end
  -- if the specified coroutine is registered, add it to the canceled table so
  -- that next time it tries to resume it exits.
  _canceled[thread] = _threads[thread]
end
1242

1243

1244

1245
-------------------------------------------------------------------------------
1246
-- Sleep/pause management functions
1247
-------------------------------------------------------------------------------
1248

1249
-- Yields the current coroutine to the sleeping queue with `sleeptime` as the
-- value (0 when omitted), to be woken after 'sleeptime' seconds.
-- If sleeptime < 0 then it sleeps until explicitly woken up using 'wakeup'
-- TODO: deprecated, remove in next major
function copas.sleep(sleeptime)
  local delay = sleeptime or 0
  coroutine_yield(delay, _sleeping)
end
1255

1256

1257
-- Yields the current coroutine and wakes it after 'sleeptime' seconds.
-- If sleeptime is omitted or not greater than 0, then it sleeps 0 seconds.
function copas.pause(sleeptime)
  local delay = 0
  if sleeptime and sleeptime > 0 then
    delay = sleeptime
  end
  coroutine_yield(delay, _sleeping)
end
1266

1267

1268
-- Yields the current coroutine until explicitly woken up using 'wakeup'
-- (it yields to the sleeping queue with a sleep time of -1).
function copas.pauseforever()
  local forever = -1
  coroutine_yield(forever, _sleeping)
end
1272

1273

1274
-- Wakes up a sleeping coroutine 'co' (delegates to the sleeping queue).
-- Returns nothing.
function copas.wakeup(co)
  local queue = _sleeping
  queue:wakeup(co)
end
1278

1279

1280

1281
-------------------------------------------------------------------------------
1282
-- Timeout management
1283
-------------------------------------------------------------------------------
1284

1285
do
  -- the active timer per coroutine (weak keys)
  local timeout_register = setmetatable({}, { __mode = "k" })
  local time_out_thread
  local timerwheel = require("timerwheel").new({
      precision = TIMEOUT_PRECISION,                -- timeout precision 100ms
      ringsize = math.floor(60/TIMEOUT_PRECISION),  -- ring size 1 minute
      err_handler = function(err)
        -- errors in timeout callbacks are reported through the default handler
        return _deferror(err, time_out_thread)
      end,
    })

  -- the core timer thread; advances the timerwheel once every TIMEOUT_PRECISION seconds
  time_out_thread = copas.addnamedthread("copas_core_timer", function()
    while true do
      copas.pause(TIMEOUT_PRECISION)
      timerwheel:step()
    end
  end)

  -- get the number of timeouts running
  function copas.gettimeouts()
    return timerwheel:count()
  end

  --- Sets the timeout for the current coroutine.
  -- Any timer previously set for this coroutine is canceled first.
  -- @param delay delay (seconds), use 0 to cancel the timeout
  -- @param callback function with signature: `function(coroutine)` where coroutine is the routine that timed-out
  -- @return true, or throws an error if `delay` is negative
  function copas.timeout(delay, callback)
    local co = coroutine_running()
    local existing_timer = timeout_register[co]

    if existing_timer then
      timerwheel:cancel(existing_timer)
    end

    if delay > 0 then
      timeout_register[co] = timerwheel:set(delay, callback, co)
    elseif delay == 0 then
      timeout_register[co] = nil
    else
      error("timeout value must be greater than or equal to 0, got: "..tostring(delay))
    end

    return true
  end

end
1332

1333

1334
-------------------------------------------------------------------------------
1335
-- main tasks: manage readable and writable socket sets
1336
-------------------------------------------------------------------------------
1337
-- a task is an object with a required method `step()` that deals with a
1338
-- single step for that task.
1339

1340
-- the list of tasks, in the order they were added
local _tasks = {} do
  -- appends a task to the end of the list
  function _tasks:add(tsk)
    local slot = #_tasks + 1
    _tasks[slot] = tsk
  end
end
1345

1346

1347
-- a task to check ready to read events
local _readable_task = {} do

  -- handles one readable socket: a registered server socket accepts a new
  -- client, any other socket has its waiting coroutine resumed
  local function handle_readable(skt)
    local server_handler = _servers[skt]
    if server_handler then
      _accept(skt, server_handler)
      return
    end
    _reading:remove(skt)
    _doTick(_reading:pop(skt), skt)
  end

  -- processes the sockets listed in `self._events`
  function _readable_task:step()
    for _, skt in ipairs(self._events) do
      handle_readable(skt)
    end
  end

  _tasks:add(_readable_task)
end
1368

1369

1370
-- a task to check ready to write events
local _writable_task = {} do

  -- processes the sockets listed in `self._events`: each one is taken off
  -- the write queue and the coroutine waiting on it is resumed
  function _writable_task:step()
    for _, skt in ipairs(self._events) do
      _writing:remove(skt)
      _doTick(_writing:pop(skt), skt)
    end
  end

  _tasks:add(_writable_task)
end
1386

1387

1388

1389
-- sleeping threads task
local _sleeping_task = {} do

  -- moves every coroutine that `_sleeping:pop(now)` hands out to `_resumable`
  function _sleeping_task:step()
    local now = gettime()

    while true do
      local co = _sleeping:pop(now)
      if not co then
        break
      end
      -- we're pushing them to _resumable, since that list will be replaced before
      -- executing. This prevents tasks running twice in a row with pause(0) for example.
      -- So here we won't execute, but at _resumable step which is next
      _resumable:push(co)
    end
  end

  _tasks:add(_sleeping_task)
end
1407

1408

1409

1410
-- resumable threads task
local _resumable_task = {} do

  -- resumes every coroutine on the current resume list
  function _resumable_task:step()
    -- replace the resume list before iterating, so items placed in there
    -- will indeed end up in the next copas step, not in this one, and not
    -- create a loop
    local pending = _resumable:clear_resumelist()

    for _, co in ipairs(pending) do
      _doTick(co)
    end
  end

  _tasks:add(_resumable_task)
end
1426

1427

1428
-------------------------------------------------------------------------------
1429
-- Checks for reads and writes on sockets
1430
-------------------------------------------------------------------------------
1431
local _select_plain do

  local last_cleansing = 0

  -- Watchdog helper: every socket in `wait_log` that is not in `events` and
  -- has been logged for longer than WATCH_DOG_TIMEOUT is removed from the
  -- log and appended to `events` (also indexed by socket).
  local function add_overdue(wait_log, events, now)
    for skt, since in pairs(wait_log) do
      if not events[skt] and now - since > WATCH_DOG_TIMEOUT then
        wait_log[skt] = nil
        events[#events + 1] = skt
        events[skt] = #events
      end
    end
  end

  -- Calls `socket.select` on the read/write queues and stores the results as
  -- the `_events` of the readable/writable tasks.
  -- @param timeout timeout passed to `socket.select`
  -- @return the error from `socket.select`, except that "timeout" becomes nil
  -- when there are events to handle after all
  _select_plain = function(timeout)
    local err
    local now = gettime()

    -- remove any closed sockets to prevent select from hanging on them
    if _closed[1] then
      for i, skt in ipairs(_closed) do
        _closed[i] = { _reading:remove(skt), _writing:remove(skt) }
      end
    end

    _readable_task._events, _writable_task._events, err = socket.select(_reading, _writing, timeout)
    local r_events, w_events = _readable_task._events, _writable_task._events

    -- inject closed sockets in readable/writeable task so they can error out properly
    if _closed[1] then
      for i, removed in ipairs(_closed) do
        _closed[i] = nil
        r_events[#r_events+1] = removed[1]
        w_events[#w_events+1] = removed[2]
      end
    end

    if now - last_cleansing > WATCH_DOG_TIMEOUT then
      last_cleansing = now

      -- Check all sockets selected for reading, and check how long they have been waiting
      -- for data already, without select returning them as readable.
      -- Those that timed out are moved in the readable list to try and read
      -- anyway, despite not having been returned by select
      add_overdue(_reading_log, r_events, now)

      -- Do the same for writing
      add_overdue(_writing_log, w_events, now)
    end

    if err == "timeout" and #r_events + #w_events > 0 then
      return nil
    end
    return err
  end
end
1492

1493

1494

1495
-------------------------------------------------------------------------------
1496
-- Dispatcher loop step.
1497
-- Listens to client requests and handles them
1498
-- Returns false if no socket-data was handled, or true if there was data
1499
-- handled (or nil + error message)
1500
-------------------------------------------------------------------------------
1501

1502
local copas_stats
1503
local min_ever, max_ever
1504

1505
local _select = _select_plain
28✔
1506

1507
-- instrumented version of _select() to collect stats
-- On each call the time passed since the previous call's `step_start` is
-- added to the running stats; the first call only creates the stats table.
local _select_instrumented = function(timeout)
  local stats = copas_stats
  if not stats then
    -- first call (or stats were reset); start a fresh collection
    copas_stats = {
      duration_max = -1,
      duration_min = 999999,
      duration_tot = 0,
      steps = 0,
    }
  else
    local step_duration = gettime() - stats.step_start
    stats.duration_max = math.max(stats.duration_max, step_duration)
    stats.duration_min = math.min(stats.duration_min, step_duration)
    stats.duration_tot = stats.duration_tot + step_duration
    stats.steps = stats.steps + 1
  end

  local err = _select_plain(timeout)

  local now = gettime()
  copas_stats.time_start = copas_stats.time_start or now
  copas_stats.step_start = now

  return err
end
1532

1533

1534
--- Executes a single dispatcher step: select on the sockets, then run all tasks.
-- @param timeout (optional) maximum time (seconds) to wait in the select call
-- @return true if the select call reported no error, false if it timed
-- out, or nil + error message for any other error
function copas.step(timeout)
  -- Need to wake up the select call in time for the next sleeping event
  if _resumable:done() then
    timeout = math.min(_sleeping:getnext(), timeout or math.huge)
  else
    timeout = 0
  end

  local err = _select(timeout)

  for _, tsk in ipairs(_tasks) do
    tsk:step()
  end

  if not err then
    return true
  end

  if err ~= "timeout" then
    return nil, err
  end

  if timeout + 0.01 > TIMEOUT_PRECISION and math.random(100) > 90 then
    -- we were idle, so occasionally do a GC sweep to ensure lingering
    -- sockets are closed, and we don't accidentally block the loop from
    -- exiting
    collectgarbage()
  end
  return false
end
1563

1564

1565
-------------------------------------------------------------------------------
1566
-- Check whether there is something to do.
1567
-- returns false if there are no sockets for read/write nor tasks scheduled
1568
-- (which means Copas is in an empty spin)
1569
-------------------------------------------------------------------------------
1570
--- Checks whether the dispatcher has nothing left to do.
-- @return truthy if the read and write queues are empty, the resumable list
-- is done, and the sleeping queue reports done; falsy otherwise
function copas.finished()
  if #_reading ~= 0 or #_writing ~= 0 then
    return false
  end

  local resumable_done = _resumable:done()
  if not resumable_done then
    return resumable_done
  end

  -- parentheses: return a single value only
  return (_sleeping:done(copas.gettimeouts()))
end
1573

1574
-- `_getstats` points to either the plain or the instrumented implementation
-- below; each one switches `_select` and `_getstats` over to the other mode
-- when asked to via its `enable` argument.
local _getstats do
  local _getstats_instrumented, _getstats_plain


  -- Stats collection is off; returns an empty table.
  -- @param enable pass `true` to switch instrumentation on
  function _getstats_plain(enable)
    -- this function gets hit if turned off, so turn on if true
    if enable == true then
      _select = _select_instrumented
      _getstats = _getstats_instrumented
      -- reset stats
      min_ever = nil
      max_ever = nil
      copas_stats = nil
    end
    return {}
  end


  -- convert from seconds to millisecs, with microsec precision
  local function useconds(t)
    return math.floor((t * 1000000) + 0.5) / 1000
  end
  -- convert from seconds to seconds, with millisec precision
  local function mseconds(t)
    return math.floor((t * 1000) + 0.5) / 1000
  end


  -- Stats collection is on; returns the stats collected since the previous
  -- call and starts a new collection period. Returns an empty table if no
  -- step has been measured yet.
  -- @param enable pass `false` to switch instrumentation off
  function _getstats_instrumented(enable)
    if enable == false then
      _select = _select_plain
      _getstats = _getstats_plain
      -- instrumentation disabled, so switch to the plain implementation
      return _getstats(enable)
    end
    -- the counter field is `steps`; without any measured step the averages
    -- below would divide by zero and min/max would still hold their
    -- initial sentinel values
    if (not copas_stats) or (copas_stats.steps == 0) then
      return {}
    end
    local stats = copas_stats
    copas_stats = nil
    min_ever = math.min(min_ever or 9999999, stats.duration_min)
    max_ever = math.max(max_ever or 0, stats.duration_max)
    stats.duration_min_ever = min_ever
    stats.duration_max_ever = max_ever
    stats.duration_avg = stats.duration_tot / stats.steps
    stats.step_start = nil
    stats.time_end = gettime()
    stats.time_tot = stats.time_end - stats.time_start
    stats.time_avg = stats.time_tot / stats.steps

    -- durations are reported in milliseconds, timestamps/totals in seconds
    stats.duration_avg = useconds(stats.duration_avg)
    stats.duration_max = useconds(stats.duration_max)
    stats.duration_max_ever = useconds(stats.duration_max_ever)
    stats.duration_min = useconds(stats.duration_min)
    stats.duration_min_ever = useconds(stats.duration_min_ever)
    stats.duration_tot = useconds(stats.duration_tot)
    stats.time_avg = useconds(stats.time_avg)
    stats.time_start = mseconds(stats.time_start)
    stats.time_end = mseconds(stats.time_end)
    stats.time_tot = mseconds(stats.time_tot)
    return stats
  end

  _getstats = _getstats_plain
end
1639

1640

1641
--- Returns a table describing the current state of the dispatcher.
-- @param enable_stats passed on to the stats collector (`true` switches
-- instrumentation on, `false` switches it off)
-- @return table with the stats (if any) plus the fields `running`, `timeout`,
-- `timer`, `inactive`, `read`, `write` and `active`
function copas.status(enable_stats)
  local report = _getstats(enable_stats)

  report.running = copas.running and true or false
  report.timeout = copas.gettimeouts()

  local timer, inactive = _sleeping:status()
  report.timer = timer
  report.inactive = inactive

  report.read = #_reading
  report.write = #_writing
  report.active = _resumable:count()
  return report
end
1651

1652

1653
-------------------------------------------------------------------------------
1654
-- Dispatcher endless loop.
1655
-- Listens to client requests and handles them forever
1656
-------------------------------------------------------------------------------
1657
--- Runs the dispatcher until `copas.finished()` reports there is nothing left to do.
-- @param initializer (optional) function to run as the first thread; if it
-- is not a function, then a truthy value here is used as the timeout instead
-- @param timeout (optional) timeout (seconds) passed to each `copas.step` call
function copas.loop(initializer, timeout)
  if type(initializer) ~= "function" then
    timeout = initializer or timeout
  else
    copas.addnamedthread("copas_initializer", initializer)
  end

  copas.running = true
  while true do
    if copas.finished() then
      break
    end
    copas.step(timeout)
  end
  copas.running = false
end
1668

1669

1670
-------------------------------------------------------------------------------
1671
-- Naming sockets and coroutines.
1672
-------------------------------------------------------------------------------
1673
do
  -- returns the underlying socket of a Copas wrapped socket, anything else
  -- is returned as is
  local function unwrap(skt)
    local mt = getmetatable(skt)
    if mt ~= _skt_mt_tcp and mt ~= _skt_mt_udp then
      return skt
    end
    return skt.socket
  end


  --- Sets the name of a socket.
  -- @param name the name (string)
  -- @param skt the socket, wrapped or plain
  function copas.setsocketname(name, skt)
    assert(type(name) == "string", "expected arg #1 to be a string")
    local target = assert(unwrap(skt), "expected arg #2 to be a socket")
    object_names[target] = name
  end


  --- Gets the name of a socket.
  -- @param skt the socket, wrapped or plain
  -- @return the name registered for the (underlying) socket
  function copas.getsocketname(skt)
    local target = assert(unwrap(skt), "expected arg #1 to be a socket")
    return object_names[target]
  end
end
1696

1697

1698
--- Sets the name of a coroutine.
-- @param name the name (string)
-- @param coro (optional) the coroutine, defaults to the running one
function copas.setthreadname(name, coro)
  assert(type(name) == "string", "expected arg #1 to be a string")
  local target = coro or coroutine_running()
  assert(type(target) == "thread", "expected arg #2 to be a coroutine or nil")
  object_names[target] = name
end
1704

1705

1706
-- Returns the name stored for a coroutine in `object_names` (nil if none).
-- @param coro (optional) the coroutine to look up, defaults to the result of
--   `coroutine_running()`.
function copas.getthreadname(coro)
  if not coro then
    coro = coroutine_running()
  end
  assert(type(coro) == "thread", "expected arg #1 to be a coroutine or nil")
  return object_names[coro]
end
1711

1712
-------------------------------------------------------------------------------
1713
-- Debug functionality.
1714
-------------------------------------------------------------------------------
1715
do
1716
  copas.debug = {}
28✔
1717

1718
  local log_core    -- if truthy, the core-timer will also be logged
1719
  local debug_log   -- function used as logger
1720

1721

1722
  -- Logging stand-in for `coroutine.yield`, installed by `copas.debug.start`.
  -- Reports which queue the running coroutine yields to, then yields with the
  -- same `skt` and `queue` values. For the `_sleeping` queue the `skt` value
  -- is logged as a number of seconds.
  local function debug_yield(skt, queue)
    local name = object_names[coroutine_running()]
    -- the core timer is only reported when `log_core` is set
    local muted = (not log_core) and name == "copas_core_timer"

    if not muted then
      if queue == _sleeping then
        debug_log("yielding '", name, "' to SLEEP for ", skt," seconds")
      elseif queue == _writing then
        debug_log("yielding '", name, "' to WRITE on '", object_names[skt], "'")
      elseif queue == _reading then
        debug_log("yielding '", name, "' to READ on '", object_names[skt], "'")
      else
        debug_log("thread '", name, "' yielding to unexpected queue; ", tostring(queue), " (", type(queue), ")", debug.traceback())
      end
    end

    return coroutine.yield(skt, queue)
  end
1742

1743

1744
  -- Logging stand-in for `coroutine.resume`, installed by `copas.debug.start`.
  -- Reports which coroutine is being resumed (and for which socket, if any),
  -- then resumes it with the same arguments.
  local function debug_resume(coro, skt, ...)
    local coro_name = object_names[coro]

    if skt then
      debug_log("resuming '", coro_name, "' for socket '", object_names[skt], "'")
    elseif log_core or coro_name ~= "copas_core_timer" then
      -- without a socket, the core timer is only reported when `log_core` is set
      debug_log("resuming '", coro_name, "'")
    end

    return coroutine.resume(coro, skt, ...)
  end
1756

1757

1758
  -- Logging stand-in for `coroutine.create`, installed by `copas.debug.start`.
  -- The created coroutine runs `f` and logs an "exiting" message once `f`
  -- has returned, passing the results of `f` through.
  local function debug_create(f)
    return coroutine.create(function(...)
      local returned = pack(f(...))
      debug_log("exiting '", object_names[coroutine_running()], "'")
      return unpack(returned)
    end)
  end
1767

1768

1769
  debug_log = fnil
28✔
1770

1771

1772
  -- Enables debug output for all coroutine operations, by swapping the
  -- yield/resume/create functions for their logging counterparts.
  -- @param logger (optional) function called with the parts of each log
  --   message, defaults to `print`.
  -- @param core (optional) if truthy, the core-timer will also be logged.
  function copas.debug.start(logger, core)
    log_core = core
    if logger then
      debug_log = logger
    else
      debug_log = print
    end
    coroutine_yield, coroutine_resume, coroutine_create =
      debug_yield, debug_resume, debug_create
  end
1780

1781

1782
  -- Disables debug output for coroutine operations: the logger becomes a
  -- no-op and the plain `coroutine` library functions are put back.
  function copas.debug.stop()
    debug_log = fnil
    coroutine_yield, coroutine_resume, coroutine_create =
      coroutine.yield, coroutine.resume, coroutine.create
  end
1789

1790
  do
1791
    local call_id = 0
28✔
1792

1793
    -- Description table of socket functions for debug output.
    -- Every socket function name has TWO entries; '<name>_in' and '<name>_out',
    -- arrays holding the names/descriptions of respectively the input
    -- parameters and the return values.
    -- If either table has a 'callback' key, then that is a function that will
    -- be called with the parameters/return-values for further inspection, e.g.:
    --   callback = function(...)
    --     print(debug.traceback("called from:", 4))
    --   end,
    local args = {
      settimeout_in  = { "socket ", "seconds", "mode   " },
      settimeout_out = { "success", "error  " },
      connect_in     = { "socket ", "address", "port   " },
      connect_out    = { "success", "error  " },
      getfd_in       = { "socket " },
      getfd_out      = { "fd" },
      send_in        = { "socket   ", "data     ", "idx-start", "idx-end  " },
      send_out       = { "last-idx-send    ", "error            ", "err-last-idx-send" },
      receive_in     = { "socket ", "pattern", "prefix " },
      receive_out    = { "received    ", "error       ", "partial data" },
      dirty_in       = { "socket" },
      dirty_out      = { "data in read-buffer" },
      close_in       = { "socket" },
      close_out      = { "success", "error" },
    }
1868
    -- Prints `msg`, followed by one line per value in `...`.
    -- Each value is labelled with its description from `args[func]`, or with
    -- its position when no description exists. Strings are shown truncated to
    -- 30 characters, together with their full length. Afterwards the optional
    -- `callback` of the description table is invoked with the same values.
    local function print_call(func, msg, ...)
      print(msg)
      local values = pack(...)
      local desc = args[func] or {}
      local count = math.max(values.n, #desc)

      for i = 1, count do
        local value = values[i]
        local vtype = type(value)
        local label = "\t"..(desc[i] or i)..": '"

        if vtype == "string" then
          local shown = value:sub(1,30)
          if shown ~= value then
            shown = shown .."(...truncated)"
          end
          print(label..shown.."' ("..vtype.." #"..#value..")")
        else
          print(label..tostring(value).."' ("..vtype..")")
        end
      end

      local callback = desc.callback
      if callback then
        callback(...)
      end
    end
1888

1889
    -- Metatable for the debug proxies created by `copas.debug.socket`.
    -- Field lookups are forwarded to the wrapped socket, stored in the proxy
    -- under `__original_socket`. Function values are returned wrapped, such
    -- that every call prints its arguments and its results via `print_call`.
    local debug_mt = {
      __index = function(self, key)
        local value = self.__original_socket[key]
        if type(value) ~= "function" then
          return value
        end

        return function(self2, ...)
            local my_id = call_id + 1
            call_id = my_id

            -- When invoked as a method on the proxy, the wrapped socket is
            -- substituted for the proxy. In any other case the caller's own
            -- first argument is passed along untouched (it used to be dropped
            -- and replaced by the proxy table).
            local first = self2
            if self2 == self then
              first = self.__original_socket
            end

            print_call(tostring(key).."_in", my_id .. "-calling '" .. tostring(key) .. "' with; ", first, ...)
            local results = pack(value(first, ...))
            print_call(tostring(key).."_out", my_id .. "-results '"..tostring(key) .. "' returned; ", unpack(results))
            return unpack(results)
          end
      end,
      __tostring = function(self)
        return tostring(self.__original_socket)
      end
    }
1916

1917

1918
    -- Wraps a socket (copas or luasocket) in a debug version printing all calls
    -- and their parameters/return values. Extremely noisy!
    -- For a Copas socket the proxy is placed around its `socket` field and the
    -- same Copas socket is returned; otherwise the new proxy is returned.
    -- NOTE: only for plain sockets, will not support TLS
    function copas.debug.socket(original_skt)
      local meta = getmetatable(original_skt)

      if meta == _skt_mt_tcp or meta == _skt_mt_udp then
        -- already wrapped as Copas socket, so recurse with the original luasocket one
        original_skt.socket = copas.debug.socket(original_skt.socket)
        return original_skt
      end

      return setmetatable({ __original_socket = original_skt }, debug_mt)
    end
1935
  end
1936
end
1937

1938

1939
return copas
28✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc