• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

processone / ejabberd / 956

20 Mar 2025 11:52AM UTC coverage: 33.773% (-0.01%) from 33.785%
956

push

github

badlop
Update moved or broken URLs in documentation

15146 of 44847 relevant lines covered (33.77%)

624.22 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

44.69
/src/ejabberd_sql.erl
1
%%%----------------------------------------------------------------------
2
%%% File    : ejabberd_sql.erl
3
%%% Author  : Alexey Shchepin <alexey@process-one.net>
4
%%% Purpose : Serve SQL connection
5
%%% Created :  8 Dec 2004 by Alexey Shchepin <alexey@process-one.net>
6
%%%
7
%%%
8
%%% ejabberd, Copyright (C) 2002-2025   ProcessOne
9
%%%
10
%%% This program is free software; you can redistribute it and/or
11
%%% modify it under the terms of the GNU General Public License as
12
%%% published by the Free Software Foundation; either version 2 of the
13
%%% License, or (at your option) any later version.
14
%%%
15
%%% This program is distributed in the hope that it will be useful,
16
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
17
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18
%%% General Public License for more details.
19
%%%
20
%%% You should have received a copy of the GNU General Public License along
21
%%% with this program; if not, write to the Free Software Foundation, Inc.,
22
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
23
%%%
24
%%%----------------------------------------------------------------------
25

26
-module(ejabberd_sql).
27

28
-author('alexey@process-one.net').
29

30
-behaviour(p1_fsm).
31

32
%% External exports
33
-export([start_link/2,
34
         sql_query/2,
35
         sql_query/3,
36
         sql_query_t/1,
37
         sql_transaction/2,
38
         sql_transaction/4,
39
         sql_bloc/2,
40
         sql_bloc/3,
41
         abort/1,
42
         restart/1,
43
         use_new_schema/0,
44
         sql_query_to_iolist/1,
45
         sql_query_to_iolist/2,
46
         escape/1,
47
         standard_escape/1,
48
         escape_like/1,
49
         escape_like_arg/1,
50
         escape_like_arg_circumflex/1,
51
         to_string_literal/2,
52
         to_string_literal_t/1,
53
         to_bool/1,
54
         sqlite_db/1,
55
         sqlite_file/1,
56
         encode_term/1,
57
         decode_term/1,
58
         odbcinst_config/0,
59
         init_mssql/1,
60
         keep_alive/2,
61
         to_list/2,
62
         to_array/2,
63
         parse_mysql_version/2]).
64

65
%% gen_fsm callbacks
66
-export([init/1, handle_event/3, handle_sync_event/4,
67
         handle_info/3, terminate/3, print_state/1,
68
         code_change/4]).
69

70
-export([connecting/2, connecting/3,
71
         session_established/2, session_established/3]).
72

73
%% Types of the connection handle kept in #state.db_ref.  Which definitions
%% are used depends on the OTP release; OTP_BELOW_28 / OTP_BELOW_26 are
%% presumably defined by the build configuration -- TODO confirm.
-ifdef(OTP_BELOW_28).
-ifdef(OTP_BELOW_26).
%% OTP 25 or lower
-type(odbc_connection_reference() ::  pid()).
-type(db_ref_pid() :: pid()).
-else.
%% OTP 26 or 27
-type(odbc_connection_reference() ::  odbc:connection_reference()).
-type(db_ref_pid() :: pid()).
-endif.
-else.
%% OTP 28 or higher: the same types declared with the -nominal attribute.
-nominal(odbc_connection_reference() :: odbc:connection_reference()).
-nominal(db_ref_pid() :: pid()).
%% NOTE(review): presumably silences Dialyzer warnings about the union of
%% these types in #state.db_ref -- confirm.
-dialyzer([no_opaque_union]).
-endif.
89

90
-include("logger.hrl").
91
-include("ejabberd_sql_pt.hrl").
92
-include("ejabberd_stacktrace.hrl").
93

94
%% State of one SQL worker process.
-record(state,
        %% Connection handle returned by the driver connect function
        %% (stored by connecting/2 after a successful connect).
        {db_ref               :: undefined | db_ref_pid() | odbc_connection_reference(),
         %% Database backend: the first element returned by db_opts/1.
         db_type = odbc       :: pgsql | mysql | sqlite | odbc | mssql,
         %% Server version, filled in by get_db_version/1 after connecting.
         db_version           :: undefined | non_neg_integer() | {non_neg_integer(), atom(), non_neg_integer()},
         %% Failed connection attempts since the last successful connect:
         %% reset to 0 in connecting/2, incremented by handle_reconnect/2.
         reconnect_count = 0  :: non_neg_integer(),
         host                 :: binary(),
         %% Requests queued while not connected; replayed to this process
         %% as events once the connection is up.
         pending_requests     :: p1_queue:queue(),
         %% Used by report_overload/1 (defined elsewhere); presumably the
         %% time of the last overload report -- TODO confirm.
         overload_reported    :: undefined | integer(),
         %% Query timeout (milliseconds, as given to p1_fsm calls).  For each
         %% command run_sql_cmd/4 lowers it to the time left before the
         %% caller's deadline.
         timeout              :: pos_integer()}).

%% Process-dictionary key holding the #state{} of the command being run.
-define(STATE_KEY, ejabberd_sql_state).
%% Process-dictionary key holding the current transaction nesting level.
-define(NESTING_KEY, ejabberd_sql_nesting_level).
%% Nesting level outside of any transaction.
-define(TOP_LEVEL_TXN, 0).
%% Default number of retries used by sql_transaction/2.
-define(MAX_TRANSACTION_RESTARTS, 10).
%% Query sent by keep_alive/2; the expected answer is a single "1".
-define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]).
%% Tag of process-dictionary keys {?PREPARE_KEY, QueryHash} that record
%% whether a pgsql query has been prepared (prepared) or not (ignore).
-define(PREPARE_KEY, ejabberd_sql_prepare).
%% Uncomment to trace the FSM with sys debug options.
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.

-type state() :: #state{}.
%% A single query: a list of queries/raw text, raw query text, a
%% #sql_query{} record (see ejabberd_sql_pt.hrl), or a fun evaluated by
%% execute_fun/1 (arity 0, or arity 2 receiving db type and version).
-type sql_query_simple(T) :: [sql_query(T) | binary()] | binary() |
                             #sql_query{} |
                             fun(() -> T) | fun((atom(), _) -> T).
%% A simple query, or a list of {Selector, Query} alternatives of which
%% select_sql_query/2 (defined elsewhere) picks one for the connected DB.
-type sql_query(T) :: sql_query_simple(T) |
                      [{atom() | {atom(), any()}, sql_query_simple(T)}].
%% Result of a query; T covers whatever a query fun returns.
-type sql_query_result(T) :: {updated, non_neg_integer()} |
                             {error, binary() | atom()} |
                             {selected, [binary()], [[binary()]]} |
                             {selected, [any()]} |
                             T.
128

129
%%%----------------------------------------------------------------------
130
%%% API
131
%%%----------------------------------------------------------------------
132
%% Start SQL worker number I for Host, registered locally under the atom
%% form of get_worker_name(Host, I).
-spec start_link(binary(), pos_integer()) -> {ok, pid()} | {error, term()}.
start_link(Host, I) ->
    RegName = binary_to_atom(get_worker_name(Host, I), utf8),
    Options = fsm_limit_opts() ++ ?FSMOPTS,
    p1_fsm:start_link({local, RegName}, ?MODULE, [Host], Options).
137

138
%% Run Query on one of Host's SQL workers and wait at most Timeout
%% milliseconds for the result.
-spec sql_query(binary(), sql_query(T), pos_integer()) -> sql_query_result(T).
sql_query(Host, Query, Timeout) ->
    Request = {sql_query, Query},
    sql_call(Host, Request, Timeout).

%% Same as sql_query/3 using Host's configured query timeout.
-spec sql_query(binary(), sql_query(T)) -> sql_query_result(T).
sql_query(Host, Query) ->
    Timeout = query_timeout(Host),
    sql_query(Host, Query, Timeout).
16,618✔
145

146
%% Run an SQL transaction on one of Host's workers.
%%
%% Queries is either a list of queries, each run with sql_query_t/1 so that
%% an error aborts (and possibly restarts) the transaction, or a fun run
%% inside the transaction.  The caller waits at most Timeout milliseconds;
%% Restarts bounds how often the transaction is retried after an abort.
%% A reply that is neither {atomic, _} nor {aborted, _} (for instance an
%% {error, _} when no worker could be reached) is returned as
%% {aborted, Reply}.
-spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T), pos_integer(), pos_integer()) ->
                             {atomic, T} |
                             {aborted, any()}.
sql_transaction(Host, Queries, Timeout, Restarts)
  when is_list(Queries) ->
    RunAll = fun() ->
                     lists:foreach(fun sql_query_t/1, Queries)
             end,
    sql_transaction(Host, RunAll, Timeout, Restarts);
sql_transaction(Host, F, Timeout, Restarts) when is_function(F) ->
    Reply = sql_call(Host, {sql_transaction, F, Restarts}, Timeout),
    case Reply of
        {atomic, _} -> Reply;
        {aborted, _} -> Reply;
        Other -> {aborted, Other}
    end.

%% Same as sql_transaction/4 with Host's query timeout and the default
%% number of restarts.
-spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T)) ->
    {atomic, T} |
    {aborted, any()}.
sql_transaction(Host, Queries) ->
    Timeout = query_timeout(Host),
    sql_transaction(Host, Queries, Timeout, ?MAX_TRANSACTION_RESTARTS).
2,726✔
171

172
%% Run the fun F as a "bloc" on one of Host's workers: every query F issues
%% goes to the same worker, but no BEGIN/COMMIT is sent (see
%% execute_bloc/1).  The caller waits at most Timeout milliseconds.
sql_bloc(Host, F, Timeout) ->
    Request = {sql_bloc, F},
    sql_call(Host, Request, Timeout).

%% Same as sql_bloc/3 using Host's configured query timeout.
sql_bloc(Host, F) ->
    Timeout = query_timeout(Host),
    sql_bloc(Host, F, Timeout).
2,664✔
178

179
%% Route Msg to a worker.  Outside a worker (no state in the process
%% dictionary) the request is sent to one of Host's workers together with
%% its absolute deadline.  Inside a worker (e.g. a query issued from a
%% transaction fun) it is executed directly by nested_op/1.
sql_call(Host, Msg, Timeout) ->
    case get(?STATE_KEY) of
        undefined ->
            Deadline = current_time() + Timeout,
            sync_send_event(Host, {sql_cmd, Msg, Deadline}, Timeout);
        _InsideWorker ->
            nested_op(Msg)
    end.
188

189
%% Periodic connection check, scheduled by init/1 with
%% timer:apply_interval/4 when sql_keepalive_interval is set.
%%
%% Runs ?KEEPALIVE_QUERY on worker Proc.  If the reply is not the expected
%% single <<"1">>, the failure is logged and the worker is told to stop
%% (session_established/2 answers force_timeout with {stop, timeout, _}).
%%
%% force_timeout must be sent asynchronously: it is only handled by the
%% StateName/2 callback.  A synchronous event is dispatched to
%% session_established/3 / connecting/3, whose catch-all clauses just log an
%% unexpected call and never reply.  The connection would then never be
%% closed and this process would block until the call timed out.
keep_alive(Host, Proc) ->
    Timeout = query_timeout(Host),
    case sync_send_event(
           Proc,
           {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}, current_time() + Timeout},
           Timeout) of
        {selected, _, [[<<"1">>]]} ->
            ok;
        Err ->
            ?ERROR_MSG("Keep alive query failed, closing connection: ~p", [Err]),
            p1_fsm:send_event(Proc, force_timeout)
    end.
201

202
%% Synchronously send Msg to an SQL worker.
%% Given a host (binary), first make sure the host's SQL pool is started,
%% then pick one of its workers.  Given a worker (pid or name), call it
%% directly.  An exit raised by the p1_fsm call is returned as
%% {error, Reason}.
sync_send_event(Host, Msg, Timeout) when is_binary(Host) ->
    case ejabberd_sql_sup:start(Host) of
        {error, _} = Err ->
            Err;
        ok ->
            sync_send_event(get_worker(Host), Msg, Timeout)
    end;
sync_send_event(Proc, Msg, Timeout) ->
    try
        p1_fsm:sync_send_event(Proc, Msg, Timeout)
    catch
        _Class:{Reason, {p1_fsm, _Fun, _Args}} ->
            {error, Reason}
    end.
215

216
%% Run Query from inside an sql_transaction.
%% An error result aborts the transaction through restart/1, which throws
%% {aborted, Reason} so that outer_transaction/3 may retry.  This applies
%% to a top-level {error, Reason} and to the first error-tagged element of
%% a list result.  Any other result is returned unchanged.
-spec sql_query_t(sql_query(T)) -> sql_query_result(T).
sql_query_t(Query) ->
    QRes = sql_query_internal(Query),
    case QRes of
        {error, Reason} ->
            restart(Reason);
        Rs when is_list(Rs) ->
            %% keyfind/3 replaces the obsolete keysearch/3.  As before, only
            %% a 2-tuple {error, Reason} as the first error-keyed element
            %% triggers a restart.
            case lists:keyfind(error, 1, Rs) of
                {error, Reason} -> restart(Reason);
                _ -> QRes
            end;
        _ ->
            QRes
    end.
229

230
%% Exit the calling process with Reason.  Inside outer_transaction/3 this
%% exception is caught and the transaction is rolled back without a retry.
abort(Reason) ->
    erlang:exit(Reason).

%% Throw {aborted, Reason}.  Inside outer_transaction/3 this is caught and
%% the transaction may be retried.
restart(Reason) ->
    Abort = {aborted, Reason},
    erlang:throw(Abort).
×
235

236
%% Escape one byte for a quoted SQL string literal.  NUL, newline, tab,
%% backspace, carriage return, double quote and backslash get a backslash
%% escape.  A single quote is doubled.  Any other byte is returned as a
%% one-byte binary.
-spec escape_char(char()) -> binary().
escape_char(C) ->
    case C of
        $\000 -> <<"\\0">>;
        $\n   -> <<"\\n">>;
        $\t   -> <<"\\t">>;
        $\b   -> <<"\\b">>;
        $\r   -> <<"\\r">>;
        $'    -> <<"''">>;
        $"    -> <<"\\\"">>;
        $\\   -> <<"\\\\">>;
        _     -> <<C>>
    end.

%% Escape every byte of S with escape_char/1.
-spec escape(binary()) -> binary().
escape(S) ->
    << <<(escape_char(Byte))/binary>> || <<Byte>> <= S >>.

%% Escape characters that would confuse an SQL engine.
%% Percent and underscore only need escaping when the text is used in a
%% LIKE pattern.  A backslash becomes four backslashes; other bytes go
%% through escape_char/1.
escape_like(S) when is_binary(S) ->
    << <<(escape_like(Byte))/binary>> || <<Byte>> <= S >>;
escape_like(C) when is_integer(C), C >= 0, C =< 255 ->
    case C of
        $%  -> <<"\\%">>;
        $_  -> <<"\\_">>;
        $\\ -> <<"\\\\\\\\">>;
        _   -> escape_char(C)
    end.
×
260

261
%% Escape an argument for a LIKE pattern that uses backslash as the escape
%% character.  %, _, \ and the brackets [ and ] (needed for MSSQL) get a
%% leading backslash; every other byte is kept as is.
escape_like_arg(S) when is_binary(S) ->
    << <<(escape_like_arg(Byte))/binary>> || <<Byte>> <= S >>;
escape_like_arg(C) when is_integer(C), C >= 0, C =< 255 ->
    case lists:member(C, "%_\\[]") of
        true -> <<$\\, C>>;
        false -> <<C>>
    end.
29,134✔
269

270
%% Like escape_like_arg/1, but with ^ as the LIKE escape character:
%% %, _, ^, [ and ] (the brackets for MSSQL) get a leading ^.
escape_like_arg_circumflex(S) when is_binary(S) ->
    << <<(escape_like_arg_circumflex(Byte))/binary>> || <<Byte>> <= S >>;
escape_like_arg_circumflex(C) when is_integer(C), C >= 0, C =< 255 ->
    case lists:member(C, "%_^[]") of
        true -> <<$^, C>>;
        false -> <<C>>
    end.
×
278

279
%% Interpret a database value as a boolean.  <<"t">>, <<"true">>, <<"1">>,
%% the atom true and the integer 1 are true; everything else is false.
to_bool(Value) ->
    lists:member(Value, [<<"t">>, <<"true">>, <<"1">>, true, 1]).
2,277✔
285

286
%% Escape every element of Val with EscapeFun and return them as a
%% parenthesised, comma-separated iolist: "(a,b,...)".
to_list(EscapeFun, Val) ->
    Items = [EscapeFun(V) || V <- Val],
    [<<"(">>, lists:join(<<",">>, Items), <<")">>].

%% Escape every element of Val with EscapeFun and return them as a flat
%% list forming an array literal in curly braces: "{a,b,...}".
to_array(EscapeFun, Val) ->
    Items = [EscapeFun(V) || V <- Val],
    lists:flatten([<<"{">>, lists:join(<<",">>, Items), <<"}">>]).
×
293

294
%% Quote S as an SQL string literal for the given database type.
%% Generic ODBC and MySQL escape with escape/1; MSSQL and SQLite use
%% standard_escape/1; PostgreSQL escapes with escape/1 inside an E'...'
%% escape-string literal.
to_string_literal(DBType, S) when DBType =:= odbc; DBType =:= mysql ->
    <<"'", (escape(S))/binary, "'">>;
to_string_literal(DBType, S) when DBType =:= mssql; DBType =:= sqlite ->
    <<"'", (standard_escape(S))/binary, "'">>;
to_string_literal(pgsql, S) ->
    <<"E'", (escape(S))/binary, "'">>.

%% Same as to_string_literal/2, for the database type of the worker whose
%% state is stored in the calling process's dictionary.
to_string_literal_t(S) ->
    DBType = (get(?STATE_KEY))#state.db_type,
    to_string_literal(DBType, S).
15✔
308

309
%% Render Term as Erlang source text that decode_term/1 can parse back, and
%% escape it with escape/1.  The very large paper/ribbon widths presumably
%% keep the output on a single line -- TODO confirm.
encode_term(Term) ->
    Syntax = erl_syntax:abstract(Term),
    Text = erl_prettypr:format(Syntax, [{paper, 65535}, {ribbon, 65535}]),
    escape(list_to_binary(Text)).
313

314
%% Convert Bin back into a term.  Bin is the textual form of an Erlang term
%% without its terminating dot, such as the text produced by
%% encode_term/1.  A scanner or parser failure is logged together with the
%% offending text and then raised as error(badarg).
decode_term(Bin) ->
    Text = binary_to_list(<<Bin/binary, ".">>),
    try
        {ok, Tokens, _EndLocation} = erl_scan:string(Text),
        {ok, Term} = erl_parse:parse_term(Tokens),
        Term
    catch
        _:{badmatch, {error, {ScanLine, ScanMod, ScanErr}, _}} ->
            ?ERROR_MSG("Corrupted Erlang term in SQL database:~n"
                       "** Scanner error: at line ~B: ~ts~n"
                       "** Term: ~ts",
                       [ScanLine, ScanMod:format_error(ScanErr), Bin]),
            erlang:error(badarg);
        _:{badmatch, {error, {ParseLine, ParseMod, ParseErr}}} ->
            ?ERROR_MSG("Corrupted Erlang term in SQL database:~n"
                       "** Parser error: at line ~B: ~ts~n"
                       "** Term: ~ts",
                       [ParseLine, ParseMod:format_error(ParseErr), Bin]),
            erlang:error(badarg)
    end.
333

334
%% Atom naming the SQLite database handle of Host: ejabberd_sqlite_<Host>.
%% The host bytes are taken as Latin-1 characters.
-spec sqlite_db(binary()) -> atom().
sqlite_db(Host) ->
    binary_to_atom(<<"ejabberd_sqlite_", Host/binary>>, latin1).
18,409✔
337

338
%% Path of Host's SQLite database file.
%% If the sql_database option is set, that value is used.  Otherwise the
%% path is sqlite/<node>/<Host>/ejabberd.db under the current working
%% directory, or that relative path when the directory cannot be read.
-spec sqlite_file(binary()) -> string().
sqlite_file(Host) ->
    case ejabberd_option:sql_database(Host) of
        undefined ->
            Parts = ["sqlite", atom_to_list(node()),
                     binary_to_list(Host), "ejabberd.db"],
            case file:get_cwd() of
                {error, Reason} ->
                    ?ERROR_MSG("Failed to get current directory: ~ts",
                               [file:format_error(Reason)]),
                    filename:join(Parts);
                {ok, Cwd} ->
                    filename:join([Cwd | Parts])
            end;
        File ->
            binary_to_list(File)
    end.
355

356
%% Value of the new_sql_schema option.
use_new_schema() ->
    ejabberd_option:new_sql_schema().
17,397✔
358

359
%% Pick one of Host's sql_pool_size workers in round-robin order and return
%% its registered name.  The name must already exist as an atom, i.e. the
%% worker was started by start_link/2.
-spec get_worker(binary()) -> atom().
get_worker(Host) ->
    Index = 1 + p1_rand:round_robin(ejabberd_option:sql_pool_size(Host)),
    binary_to_existing_atom(get_worker_name(Host, Index), utf8).

%% Registered name of worker I of Host: ejabberd_sql_<Host>_<I>.
-spec get_worker_name(binary(), pos_integer()) -> binary().
get_worker_name(Host, I) ->
    Suffix = integer_to_binary(I),
    <<"ejabberd_sql_", Host/binary, "_", Suffix/binary>>.
21,969✔
368

369
%%%----------------------------------------------------------------------
370
%%% Callback functions from gen_fsm
371
%%%----------------------------------------------------------------------
372
%% p1_fsm init callback for a worker serving Host.
%% Traps exits, so that a dying linked connection arrives as an 'EXIT'
%% message.  Schedules keep_alive/2 every sql_keepalive_interval when that
%% option is set.  Queues a 'connect' event to itself and starts in the
%% 'connecting' state with an empty pending-request queue.
init([Host]) ->
    process_flag(trap_exit, true),
    case ejabberd_option:sql_keepalive_interval(Host) of
        undefined ->
            ok;
        Interval ->
            timer:apply_interval(Interval, ?MODULE, keep_alive,
                                 [Host, self()])
    end,
    [DBType | _] = db_opts(Host),
    p1_fsm:send_event(self(), connect),
    Pending = p1_queue:new(ejabberd_option:sql_queue_type(Host),
                           max_fsm_queue()),
    InitialState = #state{db_type = DBType,
                          host = Host,
                          pending_requests = Pending,
                          timeout = query_timeout(Host)},
    {ok, connecting, InitialState}.
388

389
%% 'connecting' state, asynchronous events.
%%
%% On 'connect', open a connection with the driver selected by db_opts/1.
%% The first element names the database type; the remaining elements are
%% the arguments of the driver-specific connect function.  On success:
%%   - the connection is linked;
%%   - prepared-statement markers are forgotten;
%%   - every queued request is re-sent to this process as an event;
%%   - the server version is fetched;
%%   - the FSM moves to 'session_established' with reconnect_count reset.
%% Any failure goes to handle_reconnect/2, which schedules another
%% 'connect'.
connecting(connect, #state{host = Host} = State) ->
    ConnectRes = case db_opts(Host) of
                     [mysql | Args] -> apply(fun mysql_connect/8, Args);
                     [pgsql | Args] -> apply(fun pgsql_connect/8, Args);
                     [sqlite | Args] -> apply(fun sqlite_connect/1, Args);
                     [mssql | Args] -> apply(fun odbc_connect/2, Args);
                     [odbc | Args] -> apply(fun odbc_connect/2, Args)
                 end,
    case ConnectRes of
        {ok, Ref} ->
            %% Only link/1 is protected by this try; exceptions raised in
            %% the 'of' body below are not caught here.
            try link(Ref) of
                _ ->
                    %% Prepared statements belonged to the previous
                    %% connection: drop every {?PREPARE_KEY, _} entry so
                    %% queries are prepared again.
                    lists:foreach(
                      fun({{?PREPARE_KEY, _} = Key, _}) ->
                              erase(Key);
                         (_) ->
                              ok
                      end, get()),
                    %% Replay requests queued while disconnected.  They are
                    %% {sql_cmd, Command, From, Timestamp} events, handled by
                    %% session_established/2 after the state change below.
                    PendingRequests =
                        p1_queue:dropwhile(
                          fun(Req) ->
                                  p1_fsm:send_event(self(), Req),
                                  true
                          end, State#state.pending_requests),
                    State1 = State#state{db_ref = Ref,
                                         pending_requests = PendingRequests},
                    State2 = get_db_version(State1),
                    {next_state, session_established, State2#state{reconnect_count = 0}}
            catch _:Reason ->
                    handle_reconnect(Reason, State)
            end;
        {error, Reason} ->
            handle_reconnect(Reason, State)
    end;
connecting(Event, State) ->
    ?WARNING_MSG("Unexpected event in 'connecting': ~p",
                 [Event]),
    {next_state, connecting, State}.
×
427

428
%% 'connecting' state, synchronous calls.
%%
%% Keep-alive queries are answered at once with an error instead of being
%% queued.  Other SQL commands are stored, with their caller, in the
%% pending queue and replayed by connecting/2 once connected; no reply is
%% sent here for them.  If the queue is full, every queued request is
%% answered with an error and dropped before the new request is queued.
connecting({sql_cmd, {sql_query, ?KEEPALIVE_QUERY}, Timestamp},
           From, State) ->
    reply(From, {error, <<"SQL connection failed">>}, Timestamp),
    {next_state, connecting, State};
connecting({sql_cmd, Command, Timestamp} = Req, From,
           State) ->
    ?DEBUG("Queuing pending request while connecting:~n\t~p",
           [Req]),
    PendingRequests =
        try p1_queue:in({sql_cmd, Command, From, Timestamp},
                        State#state.pending_requests)
        catch error:full ->
                %% Queue overflow: bounce everything already queued, then
                %% queue the current request into the emptied queue.
                Err = <<"SQL request queue is overfilled">>,
                ?ERROR_MSG("~ts, bouncing all pending requests", [Err]),
                Q = p1_queue:dropwhile(
                      fun({sql_cmd, _, To, TS}) ->
                              reply(To, {error, Err}, TS),
                              true
                      end, State#state.pending_requests),
                p1_queue:in({sql_cmd, Command, From, Timestamp}, Q)
        end,
    {next_state, connecting,
     State#state{pending_requests = PendingRequests}};
connecting(Request, {Who, _Ref}, State) ->
    ?WARNING_MSG("Unexpected call ~p from ~p in 'connecting'",
                 [Request, Who]),
    {next_state, connecting, State}.
×
455

456
%% 'session_established' state, synchronous calls.
%% SQL commands run on the open connection through run_sql_cmd/4.  Any
%% other call is logged and left unanswered.
session_established({sql_cmd, Command, Timestamp}, From, State) ->
    run_sql_cmd(Command, From, State, Timestamp);
session_established(Request, {Who, _Ref}, State) ->
    ?WARNING_MSG("Unexpected call ~p from ~p in 'session_established'",
                 [Request, Who]),
    {next_state, session_established, State}.
×
463

464
%% 'session_established' state, asynchronous events.
%%   {sql_cmd, Command, From, Timestamp} - a request replayed from the
%%                                         pending queue by connecting/2;
%%   force_timeout                       - stop the worker (keep-alive failed);
%%   anything else                       - logged and ignored.
session_established({sql_cmd, Command, From, Timestamp}, State) ->
    run_sql_cmd(Command, From, State, Timestamp);
session_established(force_timeout, State) ->
    {stop, timeout, State};
session_established(Event, State) ->
    ?WARNING_MSG("Unexpected event in 'session_established': ~p",
                 [Event]),
    {next_state, session_established, State}.
×
473

474
%% All-state events are ignored.
handle_event(_Event, StateName, State) ->
    {next_state, StateName, State}.

%% All-state synchronous events are answered with {error, badarg}.
handle_sync_event(_Event, _From, StateName, State) ->
    Reply = {error, badarg},
    {reply, Reply, StateName, State}.

%% Code upgrade: the state is kept unchanged.
code_change(_OldVsn, StateName, State, _Extra) ->
    {ok, StateName, State}.
×
482

483
%% Raw messages.  Exit signals (exits are trapped, see init/1) are ignored
%% while connecting, presumably because a connect attempt is already
%% pending; in any other state they trigger a reconnect.  Any other
%% message is logged and ignored.
handle_info({'EXIT', _Pid, Reason}, StateName, State) ->
    case StateName of
        connecting -> {next_state, connecting, State};
        _ -> handle_reconnect(Reason, State)
    end;
handle_info(Info, StateName, State) ->
    ?WARNING_MSG("Unexpected info in ~p: ~p",
                 [StateName, Info]),
    {next_state, StateName, State}.
×
491

492
%% Best-effort connection cleanup on shutdown.  MySQL connections are
%% stopped and the SQLite database is closed; errors are ignored.  Other
%% backends need nothing here.
terminate(_Reason, _StateName, State) ->
    DBType = State#state.db_type,
    if
        DBType =:= mysql ->
            catch p1_mysql_conn:stop(State#state.db_ref);
        DBType =:= sqlite ->
            catch sqlite3:close(sqlite_db(State#state.host));
        true ->
            ok
    end,
    ok.
3✔
499

500
%% @doc State representation used when the state is printed in the error
%% log; it is returned unchanged.
print_state(State) ->
    State.
×
506

507
%%%----------------------------------------------------------------------
508
%%% Internal functions
509
%%%----------------------------------------------------------------------
510
%% React to a lost or failed connection.
%% Logs the failure, closes what is left of the current connection
%% (errors ignored) and schedules a new 'connect' event.  The first attempt
%% of a series (reconnect_count = 0) waits at most 5 seconds; later
%% attempts wait the full sql_start_interval.  The query timeout is re-read
%% from the configuration.
handle_reconnect(Reason, #state{host = Host, reconnect_count = RC} = State) ->
    Configured = ejabberd_option:sql_start_interval(Host),
    Delay = if
                RC =:= 0 -> erlang:min(5000, Configured);
                true -> Configured
            end,
    DBType = State#state.db_type,
    ?WARNING_MSG("~p connection failed:~n"
                 "** Reason: ~p~n"
                 "** Retry after: ~B seconds",
                 [DBType, Reason,
                  Delay div 1000]),
    case DBType of
        mysql -> catch p1_mysql_conn:stop(State#state.db_ref);
        sqlite -> catch sqlite3:close(sqlite_db(Host));
        pgsql -> catch pgsql:terminate(State#state.db_ref);
        _ -> ok
    end,
    p1_fsm:send_event_after(Delay, connect),
    NewState = State#state{reconnect_count = RC + 1,
                           timeout = query_timeout(Host)},
    {next_state, connecting, NewState}.
530

531
%% Execute one SQL command for caller From while the session is established.
%%
%% Timestamp is the caller's deadline (computed in sql_call/3).
%%   - If it has already passed, overload is reported and no reply is sent
%%     from this function.
%%   - Otherwise, a pending 'EXIT' message (e.g. from the linked connection)
%%     is checked first without blocking.  If there is one, the command is
%%     put back on the pending queue and a reconnect is started.
%%   - If there is none, the query timeout is capped at the remaining time,
%%     the nesting level and the worker state are stored in the process
%%     dictionary for the query helpers, and the command runs via
%%     outer_op/1.
run_sql_cmd(Command, From, State, Timestamp) ->
    CT = current_time(),
    case CT >= Timestamp of
        true ->
            State1 = report_overload(State),
            {next_state, session_established, State1};
        false ->
            receive
                {'EXIT', _Pid, Reason} ->
                    PR = p1_queue:in({sql_cmd, Command, From, Timestamp},
                                     State#state.pending_requests),
                    handle_reconnect(Reason, State#state{pending_requests = PR})
            after 0 ->
                Timeout = min(query_timeout(State#state.host), Timestamp - CT),
                put(?NESTING_KEY, ?TOP_LEVEL_TXN),
                put(?STATE_KEY, State#state{timeout = Timeout}),
                abort_on_driver_error(outer_op(Command), From, Timestamp)
            end
    end.
550

551
%% @doc Run a top-level operation requested by a caller.
%% Called from run_sql_cmd/4 after the nesting level has been reset to
%% ?TOP_LEVEL_TXN and the worker state stored in the process dictionary.
%%   {sql_query, Query}       - run a single query;
%%   {sql_transaction, F, N}  - run F in a transaction, retried up to N times;
%%   {sql_bloc, F}            - run F without an explicit transaction.
-spec outer_op({sql_query, sql_query(any())} |
               {sql_transaction, fun(), non_neg_integer()} |
               {sql_bloc, fun()}) ->
    sql_query_result(any()) | {atomic, any()} | {aborted, any()}.
outer_op({sql_query, Query}) ->
    sql_query_internal(Query);
outer_op({sql_transaction, F, Restarts}) ->
    outer_transaction(F, Restarts, <<"">>);
outer_op({sql_bloc, F}) ->
    execute_bloc(F).
2,637✔
559

560
%% Run an operation issued from code that is already executing inside a
%% worker (via sql_query/sql_transaction/sql_bloc).  A transaction becomes
%% a real (outer) transaction only at the top nesting level; deeper down it
%% simply runs inside the enclosing one.
nested_op({sql_query, Query}) ->
    sql_query_internal(Query);
nested_op({sql_transaction, F, Restarts}) ->
    case get(?NESTING_KEY) of
        ?TOP_LEVEL_TXN -> outer_transaction(F, Restarts, <<"">>);
        _ -> inner_transaction(F)
    end;
nested_op({sql_bloc, F}) ->
    execute_bloc(F).
27✔
571

572
%% Run F inside an already open transaction: no BEGIN/COMMIT and never
%% retried (only outer transactions are).  The nesting level is raised
%% while F runs.  A result already tagged {aborted, _}, {'EXIT', _} or
%% {atomic, _} (including caught exceptions) is returned as is; any other
%% value V becomes {atomic, V}.  Being called at the top level is an
%% implementation error and exits the process.
inner_transaction(F) ->
    Level = get(?NESTING_KEY),
    case Level of
        ?TOP_LEVEL_TXN ->
            {backtrace, Trace} = process_info(self(), backtrace),
            ?ERROR_MSG("Inner transaction called at outer txn "
                       "level. Trace: ~ts",
                       [Trace]),
            erlang:exit(implementation_faulty);
        _ ->
            ok
    end,
    put(?NESTING_KEY, Level + 1),
    Result = (catch F()),
    put(?NESTING_KEY, Level),
    case Result of
        {aborted, _} -> Result;
        {'EXIT', _} -> Result;
        {atomic, _} -> Result;
        Value -> {atomic, Value}
    end.
    end.
593

594
%% Run F as a top-level SQL transaction: BEGIN, evaluate F with the nesting
%% level raised by one (so sql_transaction calls made by F become inner
%% transactions), then COMMIT.
%%
%% NRestarts is the number of retries still allowed.  The following are
%% passed to maybe_restart_transaction/4, which rolls back when needed and
%% retries while NRestarts > 0:
%%   - a failing BEGIN;
%%   - an {aborted, Reason} thrown by F (see restart/1);
%%   - a failing COMMIT.
%% Any other exception rolls back and aborts without a retry.  The third
%% argument (the previous abort reason) is unused.
%%
%% Returns {atomic, Result} on success, otherwise the result of
%% maybe_restart_transaction/4.  Calling this at an inner nesting level is
%% an implementation error and exits the process.
outer_transaction(F, NRestarts, _Reason) ->
    PreviousNestingLevel = get(?NESTING_KEY),
    case PreviousNestingLevel of
        ?TOP_LEVEL_TXN -> ok;
        _N ->
            {backtrace, T} = process_info(self(), backtrace),
            ?ERROR_MSG("Outer transaction called at inner txn "
                       "level. Trace: ~ts",
                       [T]),
            erlang:exit(implementation_faulty)
    end,
    case sql_begin() of
        {error, Reason} ->
            maybe_restart_transaction(F, NRestarts, Reason, false);
        _ ->
            put(?NESTING_KEY, PreviousNestingLevel + 1),
            try F() of
                Res ->
                    %% The 'of' section is not protected by the catch below,
                    %% so a COMMIT failure is handled here directly.  A
                    %% throw (restart/1) here would escape this function.
                    case sql_commit() of
                        {error, Reason} ->
                            maybe_restart_transaction(F, NRestarts, Reason, true);
                        _ ->
                            {atomic, Res}
                    end
            catch
                ?EX_RULE(throw, {aborted, Reason}, _) when NRestarts > 0 ->
                    maybe_restart_transaction(F, NRestarts, Reason, true);
                ?EX_RULE(throw, {aborted, Reason}, Stack) when NRestarts =:= 0 ->
                    StackTrace = ?EX_STACK(Stack),
                    ?ERROR_MSG("SQL transaction restarts exceeded~n** "
                               "Restarts: ~p~n** Last abort reason: "
                               "~p~n** Stacktrace: ~p~n** When State "
                               "== ~p",
                               [?MAX_TRANSACTION_RESTARTS, Reason,
                                StackTrace, get(?STATE_KEY)]),
                    maybe_restart_transaction(F, NRestarts, Reason, true);
                ?EX_RULE(_, Reason, _) ->
                    maybe_restart_transaction(F, 0, Reason, true)
            after
                %% Leave the nesting level as we found it, so a later
                %% sql_transaction issued from the same bloc opens a real
                %% transaction instead of being treated as an inner one.
                put(?NESTING_KEY, PreviousNestingLevel)
            end
    end.
634

635
%% Decide what happens after a transaction attempt failed with Reason.
%%
%%   - If driver_restart_required(Reason) is true, return {aborted, Reason}
%%     immediately.
%%   - Otherwise, when DoRollback is true, a ROLLBACK is issued.  If the
%%     ROLLBACK fails with an error that requires a driver restart, the
%%     result is {aborted, RollbackReason}.
%%   - In the remaining ("continue") cases the transaction is run again with
%%     one restart fewer, as long as NRestarts > 0.  The nesting level is
%%     reset to the top level first.
%%   - Once no restarts are left, the result is {aborted, Reason}.
maybe_restart_transaction(F, NRestarts, Reason, DoRollback) ->
    Res = case driver_restart_required(Reason) of
              true ->
                  {aborted, Reason};
              _ when DoRollback ->
                  case sql_rollback() of
                      {error, Reason2} ->
                          case driver_restart_required(Reason2) of
                              true ->
                                  {aborted, Reason2};
                              _ ->
                                  continue
                          end;
                      _ ->
                          continue
                  end;
              _ ->
                  continue
    end,
    case Res of
        continue when NRestarts > 0 ->
            %% outer_transaction/3 requires the top nesting level.
            put(?NESTING_KEY, ?TOP_LEVEL_TXN),
            outer_transaction(F, NRestarts - 1, Reason);
        continue ->
            {aborted, Reason};
        Other ->
            Other
    end.
663

664
%% Run F without an explicit transaction.
%%   - An exit or error (caught as {'EXIT', Why}) becomes {aborted, Why}.
%%   - A value {aborted, _}, thrown or returned, is passed through.
%%   - Any other value V (including other thrown terms) becomes {atomic, V}.
execute_bloc(F) ->
    Outcome = (catch F()),
    case Outcome of
        {'EXIT', Why} -> {aborted, Why};
        {aborted, _} -> Outcome;
        _ -> {atomic, Outcome}
    end.
670

671
%% Evaluate a query fun.  A 2-arity fun receives the database type and
%% version of the worker whose state is in the process dictionary; a
%% 0-arity fun is simply called.
execute_fun(F) when is_function(F, 2) ->
    State = get(?STATE_KEY),
    DBType = State#state.db_type,
    DBVersion = State#state.db_version,
    F(DBType, DBVersion);
execute_fun(F) when is_function(F, 0) ->
    F().
3,672✔
676

677
%% Run a query on the connection of the current worker (its state is read
%% from the process dictionary).  Accepted forms:
%%   - a list of {Selector, Query} alternatives: select_sql_query/2 (defined
%%     elsewhere) picks the one matching the current database;
%%   - a #sql_query{} record: dispatched to the backend-specific helper;
%%   - a fun: evaluated by execute_fun/1;
%%   - anything else: sent to the driver as plain query text.
%% Results of the record and plain-text forms pass through check_error/2.
sql_query_internal([{_, _} | _] = Queries) ->
    State = get(?STATE_KEY),
    case select_sql_query(Queries, State) of
        undefined ->
            {error, <<"no matching query for the current DBMS found">>};
        Query ->
            sql_query_internal(Query)
    end;
sql_query_internal(#sql_query{} = Query) ->
    State = get(?STATE_KEY),
    Res =
        try
            case State#state.db_type of
                odbc ->
                    generic_sql_query(Query);
                mssql ->
                    mssql_sql_query(Query);
                pgsql ->
                    %% Per-query cache in the process dictionary, keyed by
                    %% the query hash: 'prepared' means the query was
                    %% prepared on this connection, 'ignore' means run it
                    %% unprepared.  connecting/2 clears these keys on
                    %% reconnect.
                    Key = {?PREPARE_KEY, Query#sql_query.hash},
                    case get(Key) of
                        undefined ->
                            Host = State#state.host,
                            PreparedStatements =
                                ejabberd_option:sql_prepared_statements(Host),
                            case PreparedStatements of
                                false ->
                                    put(Key, ignore);
                                true ->
                                    case pgsql_prepare(Query, State) of
                                        {ok, _, _, _} ->
                                            put(Key, prepared);
                                        {error, Error} ->
                                            ?ERROR_MSG(
                                               "PREPARE failed for SQL query "
                                               "at ~p: ~p",
                                               [Query#sql_query.loc, Error]),
                                            put(Key, ignore)
                                    end
                            end;
                        _ ->
                            ok
                    end,
                    case get(Key) of
                        prepared ->
                            pgsql_execute_sql_query(Query, State);
                        _ ->
                            pgsql_sql_query(Query)
                    end;
                mysql ->
                    %% NOTE(review): flags =:= 1 is set by the SQL parse
                    %% transform; presumably it marks queries that must not
                    %% be prepared -- confirm.
                    case {Query#sql_query.flags, ejabberd_option:sql_prepared_statements(State#state.host)} of
                        {1, _} ->
                            generic_sql_query(Query);
                        {_, false} ->
                            generic_sql_query(Query);
                        _ ->
                            mysql_prepared_execute(Query, State)
                    end;
                sqlite ->
                    sqlite_sql_query(Query)
            end
        %% Known exit reasons from the driver are turned into short error
        %% binaries; anything else is logged and reported as an internal
        %% error.
        catch exit:{timeout, _} ->
                {error, <<"timed out">>};
              exit:{killed, _} ->
                {error, <<"killed">>};
              exit:{normal, _} ->
                {error, <<"terminated unexpectedly">>};
              exit:{shutdown, _} ->
                {error, <<"shutdown">>};
              ?EX_RULE(Class, Reason, Stack) ->
                StackTrace = ?EX_STACK(Stack),
                ?ERROR_MSG("Internal error while processing SQL query:~n** ~ts",
                           [misc:format_exception(2, Class, Reason, StackTrace)]),
                {error, <<"internal error">>}
        end,
    check_error(Res, Query);
sql_query_internal(F) when is_function(F) ->
    %% Aborts and exceptions raised by the fun become {error, Reason}.
    case catch execute_fun(F) of
        {aborted, Reason} -> {error, Reason};
        {'EXIT', Reason} -> {error, Reason};
        Res -> Res
    end;
sql_query_internal(Query) ->
    State = get(?STATE_KEY),
    ?DEBUG("SQL: \"~ts\"", [Query]),
    %% The driver timeout is one second shorter than the worker's query
    %% timeout, presumably so the driver gives up before the caller's
    %% deadline -- TODO confirm.
    QueryTimeout = State#state.timeout,
    Res = case State#state.db_type of
            odbc ->
                to_odbc(odbc:sql_query(State#state.db_ref, [Query],
                                       QueryTimeout - 1000));
            mssql ->
                to_odbc(odbc:sql_query(State#state.db_ref, [Query],
                                       QueryTimeout - 1000));
            pgsql ->
                pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query,
                                           QueryTimeout - 1000));
            mysql ->
                mysql_to_odbc(p1_mysql_conn:squery(State#state.db_ref,
                                                   [Query], self(),
                                                   [{timeout, QueryTimeout - 1000},
                                                    {result_type, binary}]));
            sqlite ->
                Host = State#state.host,
                sqlite_to_odbc(Host, sqlite3:sql_exec(sqlite_db(Host), Query))
          end,
    check_error(Res, Query).
17,107✔
782

783
%% Choose, from a list of per-dialect query variants, the one suited to the
%% connection described by State (its db_type and db_version).  No variant
%% is preselected as a fallback.
select_sql_query(Queries, State) ->
    DbType = State#state.db_type,
    DbVersion = State#state.db_version,
    select_sql_query(Queries, DbType, DbVersion, undefined).
786

787
%% Walk the candidate list and return the first query usable for Type at
%% Version:
%%   * `{any, Q}` matches every database type;
%%   * `{Type, Q}` matches this database type regardless of version;
%%   * `{{Type, MinVersion}, Q}` matches when Version >= MinVersion.  While
%%     the version is unknown (`undefined`) such entries only become the
%%     fallback, so the last one seen wins if nothing else matches.
%% Fallback (possibly `undefined`) is returned once the list is exhausted.
select_sql_query([], _Type, _Version, Fallback) ->
    Fallback;
select_sql_query([{any, Query} | _], _Type, _Version, _Fallback) ->
    Query;
select_sql_query([{Type, Query} | _], Type, _Version, _Fallback) ->
    Query;
select_sql_query([{{Type, _MinVersion}, Candidate} | Tail], Type, undefined, _Fallback) ->
    %% Must precede the guarded clause below: in Erlang term order the atom
    %% `undefined` compares greater than an integer MinVersion.
    select_sql_query(Tail, Type, undefined, Candidate);
select_sql_query([{{Type, MinVersion}, Candidate} | _], Type, Version, _Fallback)
  when Version >= MinVersion ->
    Candidate;
select_sql_query([{_, _} | Tail], Type, Version, Fallback) ->
    select_sql_query(Tail, Type, Version, Fallback).
806

807
%% Execute SQLQuery as plain SQL text rendered with the generic escape set,
%% then run the driver result through the query's row formatter.
generic_sql_query(SQLQuery) ->
    Text = generic_sql_query_format(SQLQuery),
    RawRes = sql_query_internal(Text),
    sql_query_format_res(RawRes, SQLQuery).
811

812
%% Render SQLQuery to SQL text.  The query's argument closure receives the
%% generic escape set, and its template closure splices in the escaped values.
generic_sql_query_format(SQLQuery) ->
    ArgsFun = SQLQuery#sql_query.args,
    Escaped = ArgsFun(generic_escape()),
    FormatFun = SQLQuery#sql_query.format_query,
    FormatFun(Escaped).
815

816
%% Escape set for generic SQL text queries:
%%   strings / array strings -> single-quoted, escaped with escape/1;
%%   integers                -> misc:i2l/1;
%%   booleans                -> 1 / 0;
%%   LIKE escape clause      -> empty.
generic_escape() ->
    Quote = fun(Str) -> <<"'", (escape(Str))/binary, "'">> end,
    #sql_escape{string = Quote,
                integer = fun misc:i2l/1,
                boolean = fun(true) -> <<"1">>;
                             (false) -> <<"0">>
                          end,
                in_array_string = Quote,
                like_escape = fun() -> <<"">> end}.
825

826
%% Execute SQLQuery as SQL text rendered with PostgreSQL escaping (used when
%% the statement could not be prepared), then format the returned rows.
pgsql_sql_query(SQLQuery) ->
    Text = pgsql_sql_query_format(SQLQuery),
    RawRes = sql_query_internal(Text),
    sql_query_format_res(RawRes, SQLQuery).
830

831
%% Render SQLQuery to PostgreSQL SQL text using pgsql_escape/0 for its
%% arguments.
pgsql_sql_query_format(SQLQuery) ->
    ArgsFun = SQLQuery#sql_query.args,
    Escaped = ArgsFun(pgsql_escape()),
    FormatFun = SQLQuery#sql_query.format_query,
    FormatFun(Escaped).
834

835
%% Escape set for PostgreSQL text queries:
%%   strings / array strings -> E'...' literals escaped with escape/1;
%%   integers                -> misc:i2l/1;
%%   booleans                -> 't' / 'f';
%%   LIKE escape clause      -> ESCAPE E'\\' (backslash as escape char).
pgsql_escape() ->
    QuoteE = fun(Str) -> <<"E'", (escape(Str))/binary, "'">> end,
    #sql_escape{string = QuoteE,
                integer = fun misc:i2l/1,
                boolean = fun(true) -> <<"'t'">>;
                             (false) -> <<"'f'">>
                          end,
                in_array_string = QuoteE,
                like_escape = fun() -> <<"ESCAPE E'\\\\'">> end}.
844

845
%% Execute SQLQuery as SQL text rendered with SQLite escaping, then format
%% the returned rows with the query's format_res callback.
sqlite_sql_query(SQLQuery) ->
    Text = sqlite_sql_query_format(SQLQuery),
    RawRes = sql_query_internal(Text),
    sql_query_format_res(RawRes, SQLQuery).
849

850
%% Render SQLQuery to SQLite SQL text using sqlite_escape/0 for its
%% arguments.
sqlite_sql_query_format(SQLQuery) ->
    ArgsFun = SQLQuery#sql_query.args,
    Escaped = ArgsFun(sqlite_escape()),
    FormatFun = SQLQuery#sql_query.format_query,
    FormatFun(Escaped).
853

854
%% Escape set for SQLite text queries:
%%   strings / array strings -> single-quoted, quotes doubled by
%%                              standard_escape/1;
%%   integers                -> misc:i2l/1;
%%   booleans                -> 1 / 0;
%%   LIKE escape clause      -> ESCAPE '\' (backslash as escape char).
sqlite_escape() ->
    Quote = fun(Str) -> <<"'", (standard_escape(Str))/binary, "'">> end,
    #sql_escape{string = Quote,
                integer = fun misc:i2l/1,
                boolean = fun(true) -> <<"1">>;
                             (false) -> <<"0">>
                          end,
                in_array_string = Quote,
                like_escape = fun() -> <<"ESCAPE '\\'">> end}.
863

864
%% SQL-standard string-literal escaping: every single quote in the binary S
%% is doubled; all other bytes are copied unchanged.
standard_escape(S) ->
    binary:replace(S, <<"'">>, <<"''">>, [global]).
869

870
%% MS SQL text queries are rendered and executed exactly like SQLite ones
%% (the sqlite escape set, which doubles single quotes).
mssql_sql_query(Query) ->
    sqlite_sql_query(Query).
872

873
%% Prepare SQLQuery on the PostgreSQL connection as a named statement whose
%% name is the query hash.  Each argument slot becomes a positional
%% placeholder ($1, $2, ...).  A LIKE escape slot takes no placeholder
%% number; it is rendered inline as ESCAPE E'\\'.
%% Returns whatever pgsql:prepare/3 returns.
pgsql_prepare(SQLQuery, State) ->
    %% Tag the slots: ordinary arguments become `arg` and LIKE escape slots
    %% become `escape`.
    Escape = #sql_escape{_ = fun(_) -> arg end,
                         like_escape = fun() -> escape end},
    {RevPlaceholders, _NextIndex} =
        lists:foldl(
          fun(escape, {Acc, I}) ->
                  {[<<"ESCAPE E'\\\\'">> | Acc], I};
             (Slot, {Acc, I}) when Slot =:= arg; is_list(Slot) ->
                  %% List-valued slots also get a single placeholder.
                  {[<<$$, (integer_to_binary(I))/binary>> | Acc], I + 1}
          end, {[], 1}, (SQLQuery#sql_query.args)(Escape)),
    Placeholders = lists:reverse(RevPlaceholders),
    Query = (SQLQuery#sql_query.format_query)(Placeholders),
    pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query).
890

891
%% Escape set used when binding parameters of a prepared PostgreSQL
%% statement.  Values are not spliced into SQL text:
%%   strings         -> passed through untouched;
%%   integers        -> one-element list holding misc:i2l/1 output;
%%   booleans        -> "1" / "0";
%%   array strings   -> double-quoted, escaped with escape/1;
%%   LIKE escape     -> `ignore` (dropped before the statement is executed).
pgsql_execute_escape() ->
    #sql_escape{string = fun(Value) -> Value end,
                integer = fun(Int) -> [misc:i2l(Int)] end,
                boolean = fun(true) -> "1";
                             (false) -> "0"
                          end,
                in_array_string = fun(Str) -> <<"\"", (escape(Str))/binary, "\"">> end,
                like_escape = fun() -> ignore end}.
900

901
%% Execute the prepared statement named after SQLQuery's hash.  Its
%% arguments are bound after escaping with pgsql_execute_escape/0, and the
%% driver reply is converted to the ODBC-style result shape and formatted.
pgsql_execute_sql_query(SQLQuery, State) ->
    Escaped = (SQLQuery#sql_query.args)(pgsql_execute_escape()),
    %% LIKE escape slots have no placeholder in the prepared statement
    %% (pgsql_prepare/2 inlines them), so they must not be bound.
    Params = [Param || Param <- Escaped, Param =/= ignore],
    ExecuteRes = pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Params),
    sql_query_format_res(pgsql_execute_to_odbc(ExecuteRes), SQLQuery).
911

912
%% Run Query through the MySQL driver's prepared-statement API, keyed by the
%% query hash.  The argument closure is evaluated with three escape sets:
%%   * the raw bound values;
%%   * their MySQL types;
%%   * inside a fun handed to the driver, the SQL text with `?` placeholders.
%% The driver can therefore build the text on demand, presumably only when
%% the statement is not yet prepared.
%% LIKE escape slots are dropped from values and types.
mysql_prepared_execute(#sql_query{hash = Hash} = Query, State) ->
    NotIgnored = fun(X) -> X /= ignore end,
    ValueEsc = #sql_escape{like_escape = fun() -> ignore end,
                           _ = fun(X) -> X end},
    TypeEsc = #sql_escape{string = fun(_) -> string end,
                          integer = fun(_) -> integer end,
                          boolean = fun(_) -> bool end,
                          in_array_string = fun(_) -> string end,
                          like_escape = fun() -> ignore end},
    Values = lists:filter(NotIgnored, (Query#sql_query.args)(ValueEsc)),
    Types = lists:filter(NotIgnored, (Query#sql_query.args)(TypeEsc)),
    BuildSQL =
        fun() ->
                PlaceholderEsc = #sql_escape{like_escape = fun() -> <<>> end,
                                             _ = fun(_) -> <<"?">> end},
                (Query#sql_query.format_query)((Query#sql_query.args)(PlaceholderEsc))
        end,
    %% Leave the driver a one-second margin below the configured timeout.
    DriverTimeout = query_timeout(State#state.host) - 1000,
    Reply = p1_mysql_conn:prepared_query(State#state.db_ref, BuildSQL, Hash,
                                         Values, Types, self(),
                                         [{timeout, DriverTimeout}]),
    sql_query_format_res(mysql_to_odbc(Reply), Query).
930

931
%% Post-process a driver result.  For `{selected, Columns, Rows}`, each row
%% is passed through the query's format_res callback.  Rows whose callback
%% raises are logged and dropped, and the column list is discarded, giving
%% `{selected, FormattedRows}`.  Any other result is returned unchanged.
sql_query_format_res({selected, _Columns, Rows}, SQLQuery) ->
    TryFormat =
        fun(Row) ->
                try
                    [(SQLQuery#sql_query.format_res)(Row)]
                catch
                    ?EX_RULE(Class, Reason, Stack) ->
                        StackTrace = ?EX_STACK(Stack),
                        ?ERROR_MSG("Error while processing SQL query result:~n"
                                   "** Row: ~p~n** ~ts",
                                   [Row,
                                    misc:format_exception(2, Class, Reason, StackTrace)]),
                        []
                end
        end,
    {selected, lists:append([TryFormat(Row) || Row <- Rows])};
sql_query_format_res(Other, _SQLQuery) ->
    Other.
950

951
%% Render a query record to SQL text using the generic escape set.
sql_query_to_iolist(Query) ->
    generic_sql_query_format(Query).
953

954
%% Render a query record to SQL text for DbType.  SQLite uses its own
%% quote-doubling escape set; every other type uses the generic one.
sql_query_to_iolist(sqlite, Query) ->
    sqlite_sql_query_format(Query);
sql_query_to_iolist(_OtherType, Query) ->
    generic_sql_query_format(Query).
958

959
%% Open a transaction: "begin transaction;" for MS SQL, "begin;" otherwise.
sql_begin() ->
    Variants = [{mssql, [<<"begin transaction;">>]},
                {any, [<<"begin;">>]}],
    sql_query_internal(Variants).
963

964
%% Commit the current transaction: "commit transaction;" for MS SQL,
%% "commit;" otherwise.
sql_commit() ->
    Variants = [{mssql, [<<"commit transaction;">>]},
                {any, [<<"commit;">>]}],
    sql_query_internal(Variants).
968

969
%% Roll back the current transaction: "rollback transaction;" for MS SQL,
%% "rollback;" otherwise.
sql_rollback() ->
    Variants = [{mssql, [<<"rollback transaction;">>]},
                {any, [<<"rollback;">>]}],
    sql_query_internal(Variants).
973

974
%% True when a driver error message means the connection is unusable and
%% must be re-established.  Matching is on exact messages, plus the
%% "Failed sending data on socket" prefix.
driver_restart_required(<<"Failed sending data on socket", _/binary>>) ->
    true;
driver_restart_required(Msg) ->
    lists:member(Msg, [<<"query timed out">>,
                       <<"connection closed">>,
                       <<"SQL connection failed">>,
                       <<"Communication link failure">>]).
980

981
%% Generate the OTP callback return tuple depending on the driver result.
%% The caller is answered first via reply/3, which skips callers whose
%% deadline has passed.  An {error, Msg} or {aborted, Msg} result whose Msg
%% denotes a broken connection (see driver_restart_required/1) leads to
%% handle_reconnect/2.  Everything else keeps the FSM in
%% session_established, using the state stored under ?STATE_KEY in the
%% process dictionary.
abort_on_driver_error({Tag, Msg} = Reply, From, Timestamp)
  when Tag =:= error; Tag =:= aborted ->
    reply(From, Reply, Timestamp),
    NeedsRestart = driver_restart_required(Msg),
    State = get(?STATE_KEY),
    if
        NeedsRestart -> handle_reconnect(Msg, State);
        true -> {next_state, session_established, State}
    end;
abort_on_driver_error(Reply, From, Timestamp) ->
    reply(From, Reply, Timestamp),
    {next_state, session_established, get(?STATE_KEY)}.
993

994
%% Log that the connection pool is overloaded, at most once every 30
%% seconds.  Returns State with overload_reported set to the logging time
%% when a message was emitted, or State unchanged otherwise.
-spec report_overload(state()) -> state().
report_overload(#state{overload_reported = PrevTime} = State) ->
    Now = current_time(),
    case PrevTime =:= undefined orelse (Now - PrevTime) > timer:seconds(30) of
        true ->
            ?ERROR_MSG("SQL connection pool is overloaded, "
                       "discarding stale requests", []),
            %% Store the same instant used in the rate-limit check above,
            %% instead of taking a second clock reading.
            State#state{overload_reported = Now};
        false ->
            State
    end.
1005

1006
%% Send Reply to From unless the caller's deadline Timestamp (monotonic
%% milliseconds, as returned by current_time/0) has already been reached.
%% In that case the reply is dropped and ok is returned; presumably the
%% caller has stopped waiting by then.
-spec reply({pid(), term()}, term(), integer()) -> term().
reply(From, Reply, Timestamp) ->
    Expired = current_time() >= Timestamp,
    if
        Expired -> ok;
        true -> p1_fsm:reply(From, Reply)
    end.
1012

1013
%% == pure ODBC code
1014

1015
%% part of init/1
%% Open an ODBC database connection.  The odbc application is started
%% first, and SQLServer is passed to odbc:connect/2 as the connection
%% string.  Rows come back as lists (tuple_row off) with binary strings.
%% extended_errors is enabled because extended_error/1 expects that error
%% format.
odbc_connect(SQLServer, Timeout) ->
    ejabberd:start_app(odbc),
    ConnStr = binary_to_list(SQLServer),
    Options = [{scrollable_cursors, off},
               {extended_errors, on},
               {tuple_row, off},
               {timeout, Timeout},
               {binary_strings, on}],
    odbc:connect(ConnStr, Options).
1025

1026
%% == Native SQLite code
1027

1028
%% part of init/1
%% Open a database connection to SQLite.
%% The database file path comes from sqlite_file/1, and its directory is
%% created if missing.  A freshly opened connection gets
%% "pragma foreign_keys = on"; the result of that statement is not checked.
%% Returns {ok, Ref}, also when the sqlite3 instance is already running, or
%% the error from filelib:ensure_dir/1 or sqlite3:open/2.
sqlite_connect(Host) ->
    File = sqlite_file(Host),
    case filelib:ensure_dir(File) of
        ok ->
            OpenRes = sqlite3:open(sqlite_db(Host), [{file, File}]),
            case OpenRes of
                {ok, Ref} ->
                    sqlite3:sql_exec(sqlite_db(Host), "pragma foreign_keys = on"),
                    {ok, Ref};
                {error, {already_started, Ref}} ->
                    {ok, Ref};
                {error, _Reason} = OpenErr ->
                    OpenErr
            end;
        DirErr ->
            DirErr
    end.
1048

1049
%% Convert SQLite query result to Erlang ODBC result formalism.
%%   ok | {rowid, _}                    -> {updated, sqlite3:changes/1 count}
%%   [{columns, Cols}, {rows, Tuples}]  -> {selected, BinCols, Rows}, where
%%                                         each row tuple becomes a list and
%%                                         integer cells become binaries
%%   {error, Code, Reason}              -> {error, Reason}
%%   anything else                      -> {updated, undefined}
sqlite_to_odbc(Host, ok) ->
    {updated, sqlite3:changes(sqlite_db(Host))};
sqlite_to_odbc(Host, {rowid, _}) ->
    {updated, sqlite3:changes(sqlite_db(Host))};
sqlite_to_odbc(_Host, [{columns, Columns}, {rows, TRows}]) ->
    CellToBin = fun(I) when is_integer(I) -> integer_to_binary(I);
                   (Other) -> Other
                end,
    Rows = [[CellToBin(Cell) || Cell <- tuple_to_list(Row)] || Row <- TRows],
    {selected, [list_to_binary(C) || C <- Columns], Rows};
sqlite_to_odbc(_Host, {error, _Code, Reason}) ->
    {error, Reason};
sqlite_to_odbc(_Host, _Other) ->
    {updated, undefined}.
1066

1067
%% == Native PostgreSQL code
1068

1069
%% part of init/1
%% Open a database connection to PostgreSQL.
%% Results are requested as binaries (as_binary).  SSLOpts, possibly empty,
%% are appended to the option list given to pgsql:connect/1.
pgsql_connect(Server, Port, DB, Username, Password, ConnectTimeout,
              Transport, SSLOpts) ->
    BaseOpts = [{host, Server},
                {database, DB},
                {user, Username},
                {password, Password},
                {port, Port},
                {transport, Transport},
                {connect_timeout, ConnectTimeout},
                {as_binary, true}],
    pgsql:connect(BaseOpts ++ SSLOpts).
1081

1082
%% Convert PostgreSQL query result to Erlang ODBC result formalism.
%% A single-statement result yields one converted item, and a
%% multi-statement result yields a list of converted items.  An
%% {error, Reason} reply is returned unchanged instead of crashing with
%% function_clause, matching pgsql_execute_to_odbc/1.  The caller
%% (check_error/2) then logs and normalises it.
pgsql_to_odbc({ok, [Item]}) ->
    pgsql_item_to_odbc(Item);
pgsql_to_odbc({ok, Items}) ->
    [pgsql_item_to_odbc(Item) || Item <- Items];
pgsql_to_odbc({error, _Reason} = Err) ->
    Err.
1088

1089
%% Convert one item of a PostgreSQL simple-query result:
%%   {<<"SELECT...">>|<<"FETCH...">>, ColDescs, Recs}
%%       -> {selected, ColumnNames, Recs}  (name = element 1 of each desc)
%%   <<"INSERT Oid N">> | <<"DELETE N">> | <<"UPDATE N">>
%%       -> {updated, N}
%%   {error, Error}  -> {error, Error}
%%   anything else   -> {updated, undefined}
pgsql_item_to_odbc({<<"SELECT", _/binary>>, ColumnDescs, Recs}) ->
    {selected, lists:map(fun(Desc) -> element(1, Desc) end, ColumnDescs), Recs};
pgsql_item_to_odbc({<<"FETCH", _/binary>>, ColumnDescs, Recs}) ->
    {selected, lists:map(fun(Desc) -> element(1, Desc) end, ColumnDescs), Recs};
pgsql_item_to_odbc(<<"INSERT ", OidAndCount/binary>>) ->
    [_Oid, Count] = str:tokens(OidAndCount, <<" ">>),
    {updated, binary_to_integer(Count)};
pgsql_item_to_odbc(<<Verb:6/binary, " ", Count/binary>>)
  when Verb =:= <<"DELETE">>; Verb =:= <<"UPDATE">> ->
    {updated, binary_to_integer(Count)};
pgsql_item_to_odbc({error, _Error} = Err) ->
    Err;
pgsql_item_to_odbc(_Other) ->
    {updated, undefined}.
1104

1105
%% Convert the reply of executing a prepared PostgreSQL statement:
%%   {ok, {<<"SELECT...">>, Rows}}       -> {selected, [], ValueRows}, with
%%                                          each row reduced to the second
%%                                          element of its 2-tuples
%%   {ok, {'INSERT'|'DELETE'|'UPDATE', N}} -> {updated, N}
%%   {error, Error}                      -> {error, Error}
%%   anything else                       -> {updated, undefined}
pgsql_execute_to_odbc({ok, {<<"SELECT", _/binary>>, Rows}}) ->
    {selected, [], [[Value || {_, Value} <- Row] || Row <- Rows]};
pgsql_execute_to_odbc({ok, {Verb, N}})
  when Verb =:= 'INSERT'; Verb =:= 'DELETE'; Verb =:= 'UPDATE' ->
    {updated, N};
pgsql_execute_to_odbc({error, _Error} = Err) ->
    Err;
pgsql_execute_to_odbc(_Other) ->
    {updated, undefined}.
1115

1116

1117
%% == Native MySQL code
1118

1119
%% part of init/1
%% Open a database connection to MySQL.
%% SSLOpts0 is used only when Transport is ssl, and then it is prefixed
%% with ssl_required.  After connecting, the session is switched to
%% utf8mb4 with the utf8mb4_bin collation; that statement's result is not
%% checked.  Driver log messages are routed through log/3.
mysql_connect(Server, Port, DB, Username, Password, ConnectTimeout, Transport, SSLOpts0) ->
    SSLOpts = if
                  Transport =:= ssl -> [ssl_required | SSLOpts0];
                  true -> []
              end,
    Started = p1_mysql_conn:start(binary_to_list(Server), Port,
                                  binary_to_list(Username),
                                  binary_to_list(Password),
                                  binary_to_list(DB),
                                  ConnectTimeout, fun log/3, SSLOpts),
    case Started of
        {ok, Ref} ->
            p1_mysql_conn:fetch(
              Ref, [<<"set names 'utf8mb4' collate 'utf8mb4_bin';">>], self()),
            {ok, Ref};
        Err ->
            Err
    end.
1140

1141
%% Convert MySQL query result to Erlang ODBC result formalism.
%%   {updated, R} -> {updated, AffectedRows}
%%   {data, R}    -> {selected, ColumnNames, Rows}
%%   {error, Msg} -> {error, BinaryMsg}  (a driver result record is first
%%                   reduced to its reason)
%%   ok           -> ok
mysql_to_odbc({updated, MySQLRes}) ->
    {updated, p1_mysql:get_result_affected_rows(MySQLRes)};
mysql_to_odbc({data, MySQLRes}) ->
    FieldInfo = p1_mysql:get_result_field_info(MySQLRes),
    Rows = p1_mysql:get_result_rows(MySQLRes),
    mysql_item_to_odbc(FieldInfo, Rows);
mysql_to_odbc({error, Reason}) when is_binary(Reason); is_list(Reason) ->
    {error, iolist_to_binary(Reason)};
mysql_to_odbc({error, MySQLRes}) ->
    mysql_to_odbc({error, p1_mysql:get_result_reason(MySQLRes)});
mysql_to_odbc(ok) ->
    ok.
1157

1158

1159
%% When tabular data is returned, convert it to the ODBC formalism.
%% Column names are the second element of each field-info tuple; rows pass
%% through unchanged.
mysql_item_to_odbc(Columns, Recs) ->
    Names = lists:map(fun(Field) -> element(2, Field) end, Columns),
    {selected, Names, Recs}.
1162

1163
%% Normalise an Erlang odbc result.  Rows (tuples or lists) become lists
%% with integer cells rendered as binaries, and column names become
%% binaries.  String error reasons become binaries; any other result passes
%% through unchanged.
to_odbc({selected, Columns, Rows}) ->
    AsList = fun(Row) ->
                     if
                         is_tuple(Row) -> tuple_to_list(Row);
                         is_list(Row) -> Row
                     end
             end,
    CellToBin = fun(I) when is_integer(I) -> integer_to_binary(I);
                   (Other) -> Other
                end,
    NormRows = [[CellToBin(Cell) || Cell <- AsList(Row)] || Row <- Rows],
    {selected, [list_to_binary(C) || C <- Columns], NormRows};
to_odbc({error, Reason}) when is_list(Reason) ->
    {error, list_to_binary(Reason)};
to_odbc(Other) ->
    Other.
1179

1180
%% Parse a MySQL/MariaDB version string (e.g. "8.0.36-0ubuntu0.22.04.1")
%% into {ok, {Version, Type, Flags}}, or return `error` when no
%% "Major.Minor" prefix is found.
%%   Version - Major*1000000 + Minor*1000 + Patch; a missing patch counts
%%             as 0.
%%   Type    - the text after the first '-' (up to the next '-') as an atom,
%%             or `unknown` when there is no suffix.
%%   Flags   - 1 for MySQL >= 5.7.26 and < 8.0.0, or >= 8.0.20.  Otherwise
%%             DefaultUpsert, which always applies to 'MariaDB'.  The value
%%             has the same meaning as the mysql_alternative_upsert
%%             sql_flag used to compute DefaultUpsert in get_db_version/1.
parse_mysql_version(SVersion, DefaultUpsert) ->
    case re:run(SVersion, <<"(\\d+)\\.(\\d+)(?:\\.(\\d+))?(?:-([^-]*))?">>,
                [{capture, all_but_first, binary}]) of
        {match, [V1, V2 | Rest]} ->
            %% Trailing optional groups that did not participate may be
            %% absent from the capture list, so accept 0, 1 or 2 of them.
            {V3, TypeA} = case Rest of
                              [] -> {<<>>, unknown};
                              [Patch] -> {Patch, unknown};
                              [Patch, Type] -> {Patch, binary_to_atom(Type, utf8)}
                          end,
            V = ((bin_to_int(V1) * 1000) + bin_to_int(V2)) * 1000 + bin_to_int(V3),
            Flags = if
                        TypeA =:= 'MariaDB' -> DefaultUpsert;
                        V >= 5007026 andalso V < 8000000 -> 1;
                        V >= 8000020 -> 1;
                        true -> DefaultUpsert
                    end,
            {ok, {V, TypeA, Flags}};
        _ ->
            error
    end.
1204

1205
%% Query the server version after connecting and store it in
%% State#state.db_version:
%%   pgsql - the integer value of the server_version_num setting;
%%   mysql - the {Version, Type, Flags} tuple from parse_mysql_version/2,
%%           seeded with the mysql_alternative_upsert sql_flag.
%% On failure a warning is logged and State is returned unchanged.  Other
%% database types are returned untouched.
get_db_version(#state{db_type = pgsql} = State) ->
    Reply = pgsql:squery(State#state.db_ref,
                         <<"select current_setting('server_version_num')">>),
    case Reply of
        {ok, [{_, _, [[SVersion]]}]} ->
            case catch binary_to_integer(SVersion) of
                Version when is_integer(Version) ->
                    State#state{db_version = Version};
                Error ->
                    ?WARNING_MSG("Error getting pgsql version: ~p", [Error]),
                    State
            end;
        Res ->
            ?WARNING_MSG("Error getting pgsql version: ~p", [Res]),
            State
    end;
get_db_version(#state{db_type = mysql, host = Host} = State) ->
    AltUpsert = lists:member(mysql_alternative_upsert, ejabberd_option:sql_flags(Host)),
    DefaultUpsert = if
                        AltUpsert -> 1;
                        true -> 0
                    end,
    Reply = p1_mysql_conn:squery(State#state.db_ref,
                                 [<<"select version();">>], self(),
                                 [{timeout, 5000},
                                  {result_type, binary}]),
    case mysql_to_odbc(Reply) of
        {selected, _, [SVersion]} ->
            case parse_mysql_version(SVersion, DefaultUpsert) of
                {ok, V} ->
                    State#state{db_version = V};
                error ->
                    ?WARNING_MSG("Error parsing mysql version: ~p", [SVersion]),
                    State
            end;
        Res ->
            ?WARNING_MSG("Error getting mysql version: ~p", [Res]),
            State
    end;
get_db_version(State) ->
    State.
1243

1244
%% Convert a captured decimal digit string to an integer.  An empty capture,
%% such as an optional regex group that did not match, counts as 0.
bin_to_int(Bin) ->
    case Bin of
        <<>> -> 0;
        _ -> binary_to_integer(Bin)
    end.
1246

1247
%% Logging callback handed to p1_mysql_conn:start/8.  It maps the driver's
%% levels (debug, info, normal, error) onto ejabberd's logging macros.
log(debug, Format, Args) ->
    ?DEBUG(Format, Args);
log(info, Format, Args) ->
    ?INFO_MSG(Format, Args);
log(normal, Format, Args) ->
    ?INFO_MSG(Format, Args);
log(error, Format, Args) ->
    ?ERROR_MSG(Format, Args).
1254

1255
%% Build the argument list for opening Host's SQL connection from its sql_*
%% options:
%%   odbc   -> [odbc, Server, ConnectTimeout]
%%   sqlite -> [sqlite, Host]
%%   mssql  -> [mssql, ConnString, ConnectTimeout].  ConnString is
%%             sql_server itself when it already contains '='; otherwise it
%%             is assembled from server, database, credentials, port and
%%             SSL.
%%   other  -> [Type, Server, Port, DB, User, Pass, ConnectTimeout,
%%              Transport, SSLOpts]
%% A warning is logged when sql_ssl is set for a type without SSL support.
%% NOTE(review): User/Pass are spliced verbatim into the MS SQL connection
%% string; a ';' in them would corrupt it.
db_opts(Host) ->
    Type = ejabberd_option:sql_type(Host),
    Server = ejabberd_option:sql_server(Host),
    Timeout = ejabberd_option:sql_connect_timeout(Host),
    Transport = case ejabberd_option:sql_ssl(Host) of
                    true -> ssl;
                    false -> tcp
                end,
    warn_if_ssl_unsupported(Transport, Type),
    case Type of
        odbc ->
            [odbc, Server, Timeout];
        sqlite ->
            [sqlite, Host];
        _ ->
            Port = ejabberd_option:sql_port(Host),
            DB = case ejabberd_option:sql_database(Host) of
                     undefined -> <<"ejabberd">>;
                     Name -> Name
                 end,
            User = ejabberd_option:sql_username(Host),
            Pass = ejabberd_option:sql_password(Host),
            SSLOpts = get_ssl_opts(Transport, Host),
            case Type of
                mssql ->
                    IsConnString = odbc_server_is_connstring(Server),
                    if
                        IsConnString ->
                            [mssql, Server, Timeout];
                        true ->
                            Encryption = case Transport of
                                             tcp -> <<"">>;
                                             ssl -> <<";ENCRYPTION=require;ENCRYPT=yes">>
                                         end,
                            PortBin = integer_to_binary(Port),
                            ConnStr = <<"DRIVER=ODBC;SERVER=", Server/binary,
                                        ";DATABASE=", DB/binary,
                                        ";UID=", User/binary, ";PWD=", Pass/binary,
                                        ";PORT=", PortBin/binary, Encryption/binary,
                                        ";CLIENT_CHARSET=UTF-8;">>,
                            [mssql, ConnStr, Timeout]
                    end;
                _ ->
                    [Type, Server, Port, DB, User, Pass, Timeout, Transport, SSLOpts]
            end
    end.
1297

1298
%% Log a warning when SSL was requested for a database type other than
%% pgsql, mssql or mysql.  Returns ok when there is nothing to warn about.
warn_if_ssl_unsupported(tcp, _Type) ->
    ok;
warn_if_ssl_unsupported(ssl, Type) ->
    case lists:member(Type, [pgsql, mssql, mysql]) of
        true -> ok;
        false -> ?WARNING_MSG("SSL connection is not supported for ~ts", [Type])
    end.
1308

1309
%% Build the ssl options for an SSL SQL connection from sql_ssl_certfile,
%% sql_ssl_cafile and sql_ssl_verify.  Peer verification (verify_peer) is
%% requested only when sql_ssl_verify is true and a CA file is configured;
%% in every other SSL case verify_none is set explicitly.  Plain TCP gets
%% no options.
get_ssl_opts(ssl, Host) ->
    Opts1 = case ejabberd_option:sql_ssl_certfile(Host) of
                undefined -> [];
                CertFile -> [{certfile, CertFile}]
            end,
    Opts2 = case ejabberd_option:sql_ssl_cafile(Host) of
                undefined -> Opts1;
                CAFile -> [{cacertfile, CAFile} | Opts1]
            end,
    case ejabberd_option:sql_ssl_verify(Host) of
        true ->
            case lists:keymember(cacertfile, 1, Opts2) of
                true ->
                    [{verify, verify_peer} | Opts2];
                false ->
                    ?WARNING_MSG("SSL verification is enabled for "
                                 "SQL connection, but option "
                                 "'sql_ssl_cafile' is not set; "
                                 "verification will be disabled", []),
                    %% Disable verification explicitly, as the warning
                    %% promises.  Since OTP 26 the ssl client defaults to
                    %% verify_peer, so omitting the option would not disable
                    %% it (assuming these options reach ssl:connect).
                    [{verify, verify_none} | Opts2]
            end;
        false ->
            [{verify, verify_none} | Opts2]
    end;
get_ssl_opts(tcp, _) ->
    [].
1335

1336
%% Write odbcinst.ini (odbcinst_config/0) declaring the configured
%% sql_odbc_driver under [ODBC].  The file is written only if it does not
%% exist yet; ODBCSYSINI is then pointed at tmp_dir().  Returns ok, or the
%% logged {error, Reason} when the directory or file cannot be created.
init_mssql_odbcinst(Host) ->
    Driver = ejabberd_option:sql_odbc_driver(Host),
    ODBCINST = io_lib:fwrite("[ODBC]~n"
                             "Driver = ~s~n", [Driver]),
    ?DEBUG("~ts:~n~ts", [odbcinst_config(), ODBCINST]),
    case filelib:ensure_dir(odbcinst_config()) of
        {error, DirReason} = DirErr ->
            ?ERROR_MSG("Failed to create temporary directory ~ts: ~ts",
                       [tmp_dir(), file:format_error(DirReason)]),
            DirErr;
        ok ->
            case write_file_if_new(odbcinst_config(), ODBCINST) of
                ok ->
                    os:putenv("ODBCSYSINI", tmp_dir()),
                    ok;
                {error, FileReason} = FileErr ->
                    ?ERROR_MSG("Failed to create temporary files in ~ts: ~ts",
                               [tmp_dir(), file:format_error(FileReason)]),
                    FileErr
            end
    end.
1357

1358
%% Prepare the ODBC environment for MS SQL.  Nothing is needed when
%% sql_server is already a full connection string; otherwise odbcinst.ini
%% is written.
init_mssql(Host) ->
    Server = ejabberd_option:sql_server(Host),
    IsConnString = odbc_server_is_connstring(Server),
    if
        IsConnString -> ok;
        true -> init_mssql_odbcinst(Host)
    end.
1364

1365
%% Treat sql_server as a ready-made ODBC connection string when it contains
%% '=' (e.g. "DSN=...;UID=..."); otherwise it is used as the server name.
odbc_server_is_connstring(Server) ->
    binary:match(Server, <<"=">>) =/= nomatch.
1370

1371
%% Write Payload to File unless something already exists at that path, in
%% which case it is left untouched.  Returns ok or file:write_file/2's error.
write_file_if_new(File, Payload) ->
    Exists = filelib:is_file(File),
    if
        Exists -> ok;
        true -> file:write_file(File, Payload)
    end.
1376

1377
%% Directory for generated ODBC configuration: $HOME/conf on Windows,
%% /tmp/ejabberd elsewhere.
tmp_dir() ->
    Parts = case os:type() of
                {win32, _} -> [os:getenv("HOME"), "conf"];
                _ -> ["/tmp", "ejabberd"]
            end,
    filename:join(Parts).
1382

1383
%% Path of the generated odbcinst.ini inside tmp_dir().
odbcinst_config() ->
    filename:join([tmp_dir(), "odbcinst.ini"]).
1385

1386
%% Maximum FSM message-queue length from the fsm limit options, or
%% `unlimited` when no max_queue entry is present.
max_fsm_queue() ->
    Opts = fsm_limit_opts(),
    proplists:get_value(max_queue, Opts, unlimited).
1388

1389
%% FSM limit options from ejabberd_config:fsm_limit_opts/1, requested with
%% an empty option list.
fsm_limit_opts() ->
    ExtraOpts = [],
    ejabberd_config:fsm_limit_opts(ExtraOpts).
1391

1392
%% The sql_query_timeout option for the given host.  Callers in this module
%% subtract 1000 from it before handing it to a driver, which suggests
%% milliseconds (NOTE(review): confirm unit).
query_timeout(Host) ->
    ejabberd_option:sql_query_timeout(Host).
1394

1395
%% Monotonic clock reading in milliseconds.  It is only meaningful for
%% intervals and deadlines (reply/3, report_overload/1), never as wall-clock
%% time.
current_time() ->
    Unit = millisecond,
    erlang:monotonic_time(Unit).
1397

1398
%% ***IMPORTANT*** This error format requires extended_errors turned on.
%% Map an ODBC extended error {SQLState, NativeCode, Reason} to a binary
%% message:
%%   08S01, IMC01, IMC06 -> "Communication link failure"
%%   08001               -> "SQL connection failed"
%%   any other state     -> Reason as a binary
%% The first two messages are among those driver_restart_required/1
%% recognises.  Errors that are not 3-tuples are returned unchanged.
extended_error({SqlState, _, Reason})
  when SqlState =:= "08S01"; SqlState =:= "IMC01"; SqlState =:= "IMC06" ->
    %% 08S01: TCP Provider: the specified network name is no longer available.
    %% IMC01/IMC06: the connection is broken and recovery is not possible.
    ?DEBUG("ODBC Link Failure: ~ts", [Reason]),
    <<"Communication link failure">>;
extended_error({"08001", _, Reason}) ->
    %% Login timeout expired.
    ?DEBUG("ODBC Connect Timeout: ~ts", [Reason]),
    <<"SQL connection failed">>;
extended_error({Code, _, Reason}) ->
    ?DEBUG("ODBC Error ~ts: ~ts", [Code, Reason]),
    iolist_to_binary(Reason);
extended_error(Error) ->
    Error.
1420

1421
%% Log a failed query and normalise its reason with extended_error/1.
%%   {error, killed}          -> returned as is, not logged;
%%   {error, _} + #sql_query{} -> logged with the query hash and location;
%%   {error, _} + raw SQL     -> logged with the SQL text, or with the
%%                               printed term if it is not valid iodata.
%% Successful results pass through unchanged.
check_error({error, killed} = Err, _Query) ->
    Err;
check_error({error, Why}, #sql_query{hash = Hash, loc = Loc}) ->
    Err = extended_error(Why),
    ?ERROR_MSG("SQL query '~ts' at ~p failed: ~p", [Hash, Loc, Err]),
    {error, Err};
check_error({error, Why}, Query) ->
    Err = extended_error(Why),
    try iolist_to_binary(Query) of
        SQuery ->
            ?ERROR_MSG("SQL query '~ts' failed: ~p", [SQuery, Err])
    catch
        error:_ ->
            ?ERROR_MSG("SQL query ~p failed: ~p", [Query, Err])
    end,
    {error, Err};
check_error(Result, _Query) ->
    Result.
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc