• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

processone / ejabberd / 628

21 Nov 2023 12:55PM UTC coverage: 32.587% (-0.008%) from 32.595%
628

push

github

prefiks
Update xmpp and make opening bind2 session close other sessions with same tag

6 of 17 new or added lines in 3 files covered. (35.29%)

1 existing line in 1 file now uncovered.

13508 of 41452 relevant lines covered (32.59%)

645.47 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

44.21
/src/ejabberd_sql.erl
1
%%%----------------------------------------------------------------------
2
%%% File    : ejabberd_sql.erl
3
%%% Author  : Alexey Shchepin <alexey@process-one.net>
4
%%% Purpose : Serve SQL connection
5
%%% Created :  8 Dec 2004 by Alexey Shchepin <alexey@process-one.net>
6
%%%
7
%%%
8
%%% ejabberd, Copyright (C) 2002-2023   ProcessOne
9
%%%
10
%%% This program is free software; you can redistribute it and/or
11
%%% modify it under the terms of the GNU General Public License as
12
%%% published by the Free Software Foundation; either version 2 of the
13
%%% License, or (at your option) any later version.
14
%%%
15
%%% This program is distributed in the hope that it will be useful,
16
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
17
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18
%%% General Public License for more details.
19
%%%
20
%%% You should have received a copy of the GNU General Public License along
21
%%% with this program; if not, write to the Free Software Foundation, Inc.,
22
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
23
%%%
24
%%%----------------------------------------------------------------------
25

26
-module(ejabberd_sql).
27

28
-author('alexey@process-one.net').
29

30
-behaviour(p1_fsm).
31

32
%% External exports
33
-export([start_link/2,
34
         sql_query/2,
35
         sql_query_t/1,
36
         sql_transaction/2,
37
         sql_bloc/2,
38
         abort/1,
39
         restart/1,
40
         use_new_schema/0,
41
         sql_query_to_iolist/1,
42
         sql_query_to_iolist/2,
43
         escape/1,
44
         standard_escape/1,
45
         escape_like/1,
46
         escape_like_arg/1,
47
         escape_like_arg_circumflex/1,
48
         to_string_literal/2,
49
         to_string_literal_t/1,
50
         to_bool/1,
51
         sqlite_db/1,
52
         sqlite_file/1,
53
         encode_term/1,
54
         decode_term/1,
55
         odbcinst_config/0,
56
         init_mssql/1,
57
         keep_alive/2,
58
         to_list/2,
59
         to_array/2]).
60

61
%% gen_fsm callbacks
62
-export([init/1, handle_event/3, handle_sync_event/4,
63
         handle_info/3, terminate/3, print_state/1,
64
         code_change/4]).
65

66
-export([connecting/2, connecting/3,
67
         session_established/2, session_established/3]).
68

69
-include("logger.hrl").
70
-include("ejabberd_sql_pt.hrl").
71
-include("ejabberd_stacktrace.hrl").
72

73
-record(state,
74
        {db_ref               :: undefined | pid(),
75
         db_type = odbc       :: pgsql | mysql | sqlite | odbc | mssql,
76
         db_version           :: undefined | non_neg_integer() | {non_neg_integer(), atom(), non_neg_integer()},
77
         reconnect_count = 0  :: non_neg_integer(),
78
         host                 :: binary(),
79
         pending_requests     :: p1_queue:queue(),
80
         overload_reported    :: undefined | integer()}).
81

82
-define(STATE_KEY, ejabberd_sql_state).
83
-define(NESTING_KEY, ejabberd_sql_nesting_level).
84
-define(TOP_LEVEL_TXN, 0).
85
-define(MAX_TRANSACTION_RESTARTS, 10).
86
-define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]).
87
-define(PREPARE_KEY, ejabberd_sql_prepare).
88
%%-define(DBGFSM, true).
89
-ifdef(DBGFSM).
90
-define(FSMOPTS, [{debug, [trace]}]).
91
-else.
92
-define(FSMOPTS, []).
93
-endif.
94

95
-type state() :: #state{}.
96
-type sql_query_simple(T) :: [sql_query(T) | binary()] | binary() |
97
                             #sql_query{} |
98
                             fun(() -> T) | fun((atom(), _) -> T).
99
-type sql_query(T) :: sql_query_simple(T) |
100
                      [{atom() | {atom(), any()}, sql_query_simple(T)}].
101
-type sql_query_result(T) :: {updated, non_neg_integer()} |
102
                             {error, binary() | atom()} |
103
                             {selected, [binary()], [[binary()]]} |
104
                             {selected, [any()]} |
105
                             T.
106

107
%%%----------------------------------------------------------------------
108
%%% API
109
%%%----------------------------------------------------------------------
110
%% Start one SQL worker FSM for Host.
%% The worker is registered locally under the atom derived from
%% get_worker_name(Host, I) and is started with the FSM limit options
%% followed by the optional debug options.
-spec start_link(binary(), pos_integer()) -> {ok, pid()} | {error, term()}.
start_link(Host, I) ->
    WorkerName = get_worker_name(Host, I),
    RegName = binary_to_atom(WorkerName, utf8),
    Opts = fsm_limit_opts() ++ ?FSMOPTS,
    p1_fsm:start_link({local, RegName}, ?MODULE, [Host], Opts).
115

116
%% Run a single query for Host by dispatching a {sql_query, Query}
%% request through sql_call/2.
-spec sql_query(binary(), sql_query(T)) -> sql_query_result(T).
sql_query(Host, Query) ->
    Request = {sql_query, Query},
    sql_call(Host, Request).
15,250✔
119

120
%% SQL transaction given either as a list of queries or as a fun.
%% A list of queries is wrapped into a fun that runs every query, in
%% order, through sql_query_t/1; that fun is then handled by the second
%% clause.
-spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T)) ->
                             {atomic, T} |
                             {aborted, any()}.
sql_transaction(Host, Queries) when is_list(Queries) ->
    RunAll = fun () -> lists:foreach(fun sql_query_t/1, Queries) end,
    sql_transaction(Host, RunAll);
%% SQL transaction based on an Erlang anonymous function (F = fun).
%% Any reply that is neither {atomic, _} nor {aborted, _} is reported to
%% the caller as {aborted, Reply}.
sql_transaction(Host, F) when is_function(F) ->
    case sql_call(Host, {sql_transaction, F}) of
        {Tag, _} = Outcome when Tag =:= atomic; Tag =:= aborted ->
            Outcome;
        Unexpected ->
            {aborted, Unexpected}
    end.
139

140
%% SQL bloc based on an Erlang anonymous function (F = fun).
%% The fun is submitted as a {sql_bloc, F} request through sql_call/2.
sql_bloc(Host, F) ->
    Request = {sql_bloc, F},
    sql_call(Host, Request).
2,664✔
142

143
%% Dispatch a request either to a worker process or inline.
%% When the process dictionary already holds ?STATE_KEY the request is run
%% in the calling process through nested_op/1. Otherwise it is sent as a
%% synchronous event carrying an absolute deadline of
%% current_time() + Timeout.
sql_call(Host, Msg) ->
    Timeout = query_timeout(Host),
    HasState = get(?STATE_KEY) =/= undefined,
    if
        HasState ->
            nested_op(Msg);
        true ->
            Deadline = current_time() + Timeout,
            sync_send_event(Host, {sql_cmd, Msg, Deadline}, Timeout)
    end.
153

154
%% Connection health check for the worker Proc of Host.
%% Sends ?KEEPALIVE_QUERY as a synchronous sql_cmd and expects a single
%% row containing <<"1">>. Any other reply is logged and the worker is
%% asked to stop through the force_timeout event.
%% Returns ok in both cases.
keep_alive(Host, Proc) ->
    Timeout = query_timeout(Host),
    Probe = {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}, current_time() + Timeout},
    case sync_send_event(Proc, Probe, Timeout) of
        {selected, _, [[<<"1">>]]} ->
            ok;
        Failure ->
            ?ERROR_MSG("Keep alive query failed, closing connection: ~p", [Failure]),
            %% force_timeout is handled only by the asynchronous state
            %% callback session_established/2. A synchronous event is
            %% dispatched to session_established/3, whose catch-all clause
            %% logs a warning and never replies, so the worker would keep
            %% running and this call would block until Timeout.
            p1_fsm:send_event(Proc, force_timeout)
    end.
166

167
%% Send a synchronous event to an SQL worker.
%% With a binary Host, ejabberd_sql_sup:start/1 is called for the host
%% first and a worker is picked with get_worker/1; a failure of that
%% start call is returned unchanged. With anything else the first argument
%% is used directly as the target process. An exception whose reason has
%% the shape {Reason, {p1_fsm, _, _}} is converted to {error, Reason}.
sync_send_event(Host, Msg, Timeout) when is_binary(Host) ->
    case ejabberd_sql_sup:start(Host) of
        {error, _} = Failure ->
            Failure;
        ok ->
            Worker = get_worker(Host),
            sync_send_event(Worker, Msg, Timeout)
    end;
sync_send_event(Proc, Msg, Timeout) ->
    try
        p1_fsm:sync_send_event(Proc, Msg, Timeout)
    catch
        _:{Reason, {p1_fsm, _, _}} ->
            {error, Reason}
    end.
180

181
%% This function is intended to be used from inside an sql_transaction.
%% It runs the query with sql_query_internal/1 and, when the result is
%% {error, Reason} or a list containing such a tuple, aborts the
%% surrounding transaction through restart/1. Every other result is
%% returned unchanged.
-spec sql_query_t(sql_query(T)) -> sql_query_result(T).
sql_query_t(Query) ->
    case sql_query_internal(Query) of
        {error, Reason} ->
            restart(Reason);
        Results when is_list(Results) ->
            case lists:keyfind(error, 1, Results) of
                {error, Reason} -> restart(Reason);
                _ -> Results
            end;
        Result ->
            Result
    end.
194

195
%% Leave the current operation by raising an exit with Reason.
abort(Reason) ->
    erlang:exit(Reason).
×
197

198
%% Signal that the current transaction must be aborted by throwing
%% {aborted, Reason}.
restart(Reason) ->
    erlang:throw({aborted, Reason}).
×
200

201
%% Escape a single byte for use inside an SQL string literal.
%% NUL, newline, tab, backspace, carriage return, double quote and
%% backslash get a backslash escape; a single quote is doubled; every other
%% value is emitted as a one-byte binary.
-spec escape_char(char()) -> binary().
escape_char(Char) ->
    case Char of
        $\000 -> <<"\\0">>;
        $\n -> <<"\\n">>;
        $\t -> <<"\\t">>;
        $\b -> <<"\\b">>;
        $\r -> <<"\\r">>;
        $' -> <<"''">>;
        $" -> <<"\\\"">>;
        $\\ -> <<"\\\\">>;
        _ -> <<Char>>
    end.
20,084✔
211

212
%% Escape every byte of S with escape_char/1 and return the concatenation.
-spec escape(binary()) -> binary().
escape(S) ->
    iolist_to_binary([escape_char(Byte) || <<Byte>> <= S]).
820✔
215

216
%% Escape characters that would confuse an SQL engine.
%% Percent and underscore only need to be escaped when the value is used
%% in a LIKE pattern. Accepts either a binary, which is escaped byte by
%% byte, or a single byte value; bytes without a LIKE-specific escape are
%% delegated to escape_char/1.
escape_like(S) when is_binary(S) ->
    iolist_to_binary([escape_like(Byte) || <<Byte>> <= S]);
escape_like(C) when is_integer(C), C >= 0, C =< 255 ->
    case C of
        $% -> <<"\\%">>;
        $_ -> <<"\\_">>;
        $\\ -> <<"\\\\\\\\">>;
        _ -> escape_char(C)
    end.
×
225

226
%% Escape LIKE metacharacters with a backslash.
%% Accepts either a binary, which is processed byte by byte, or a single
%% byte value. Percent, underscore, backslash and square brackets (the
%% brackets are for MSSQL) are prefixed with a backslash; every other byte
%% is passed through unchanged.
escape_like_arg(S) when is_binary(S) ->
    iolist_to_binary([escape_like_arg(Byte) || <<Byte>> <= S]);
escape_like_arg(C) when is_integer(C), C >= 0, C =< 255 ->
    case C of
        $% -> <<"\\%">>;
        $_ -> <<"\\_">>;
        $\\ -> <<"\\\\">>;
        $[ -> <<"\\[">>;     % For MSSQL
        $] -> <<"\\]">>;
        _ -> <<C>>
    end.
29,110✔
234

235
%% Escape LIKE metacharacters using a circumflex as the escape character.
%% Accepts either a binary, which is processed byte by byte, or a single
%% byte value. Percent, underscore, circumflex and square brackets (the
%% brackets are for MSSQL) are prefixed with a circumflex; every other byte
%% is passed through unchanged.
escape_like_arg_circumflex(S) when is_binary(S) ->
    iolist_to_binary([escape_like_arg_circumflex(Byte) || <<Byte>> <= S]);
escape_like_arg_circumflex(C) when is_integer(C), C >= 0, C =< 255 ->
    case C of
        $% -> <<"^%">>;
        $_ -> <<"^_">>;
        $^ -> <<"^^">>;
        $[ -> <<"^[">>;     % For MSSQL
        $] -> <<"^]">>;
        _ -> <<C>>
    end.
×
243

244
%% Convert a value read from the database to a boolean.
%% Only <<"t">>, <<"true">>, <<"1">>, the atom true and the integer 1 are
%% truthy; every other term yields false.
to_bool(Value) ->
    lists:member(Value, [<<"t">>, <<"true">>, <<"1">>, true, 1]).
2,202✔
250

251
%% Render Val as a parenthesised, comma-separated iolist, applying
%% EscapeFun to every element first.
to_list(EscapeFun, Val) ->
    Items = lists:map(EscapeFun, Val),
    Joined = lists:join(<<",">>, Items),
    [<<"(">>, Joined, <<")">>].
×
254

255
%% Render Val as a brace-delimited, comma-separated flat list, applying
%% EscapeFun to every element first.
to_array(EscapeFun, Val) ->
    Items = lists:map(EscapeFun, Val),
    Joined = lists:join(<<",">>, Items),
    lists:flatten([<<"{">>, Joined, <<"}">>]).
×
258

259
%% Quote S as a string literal for the given database type.
%% odbc and mysql use escape/1; mssql and sqlite use standard_escape/1;
%% pgsql uses escape/1 and an E'...' literal.
to_string_literal(DbType, S) when DbType =:= odbc; DbType =:= mysql ->
    <<"'", (escape(S))/binary, "'">>;
to_string_literal(DbType, S) when DbType =:= mssql; DbType =:= sqlite ->
    <<"'", (standard_escape(S))/binary, "'">>;
to_string_literal(pgsql, S) ->
    <<"E'", (escape(S))/binary, "'">>.
410✔
269

270
%% Quote S for the database type recorded in the state stored under
%% ?STATE_KEY in the process dictionary.
to_string_literal_t(S) ->
    WorkerState = get(?STATE_KEY),
    DbType = WorkerState#state.db_type,
    to_string_literal(DbType, S).
15✔
273

274
%% Pretty-print Term as Erlang source text and escape it with escape/1.
%% The paper and ribbon widths are both set to 65535.
encode_term(Term) ->
    Abstract = erl_syntax:abstract(Term),
    Printed = erl_prettypr:format(Abstract, [{paper, 65535}, {ribbon, 65535}]),
    escape(list_to_binary(Printed)).
278

279
%% Parse Bin as the text of a single Erlang term.
%% A terminating full stop is appended before scanning. A scanner or
%% parser error is logged together with the offending text and then
%% raised as error badarg.
decode_term(Bin) ->
    Source = binary_to_list(<<Bin/binary, ".">>),
    try
        {ok, Tokens, _} = erl_scan:string(Source),
        {ok, Decoded} = erl_parse:parse_term(Tokens),
        Decoded
    catch
        %% erl_scan:string/1 failed: the error tuple has three elements.
        _:{badmatch, {error, {Line, Mod, Reason}, _}} ->
            ?ERROR_MSG("Corrupted Erlang term in SQL database:~n"
                       "** Scanner error: at line ~B: ~ts~n"
                       "** Term: ~ts",
                       [Line, Mod:format_error(Reason), Bin]),
            erlang:error(badarg);
        %% erl_parse:parse_term/1 failed: the error tuple has two elements.
        _:{badmatch, {error, {Line, Mod, Reason}}} ->
            ?ERROR_MSG("Corrupted Erlang term in SQL database:~n"
                       "** Parser error: at line ~B: ~ts~n"
                       "** Term: ~ts",
                       [Line, Mod:format_error(Reason), Bin]),
            erlang:error(badarg)
    end.
298

299
%% Name of the SQLite database handle for Host: the atom made of the
%% prefix "ejabberd_sqlite_" followed by the bytes of Host.
-spec sqlite_db(binary()) -> atom().
sqlite_db(Host) ->
    binary_to_atom(<<"ejabberd_sqlite_", Host/binary>>, latin1).
18,235✔
302

303
%% Path of the SQLite database file for Host.
%% The configured sql_database option is used when set. Otherwise the
%% path is sqlite/<node>/<host>/ejabberd.db below the current working
%% directory, or relative when the working directory cannot be read
%% (that failure is logged).
-spec sqlite_file(binary()) -> string().
sqlite_file(Host) ->
    case ejabberd_option:sql_database(Host) of
        undefined ->
            Relative = ["sqlite", atom_to_list(node()),
                        binary_to_list(Host), "ejabberd.db"],
            case file:get_cwd() of
                {error, Reason} ->
                    ?ERROR_MSG("Failed to get current directory: ~ts",
                               [file:format_error(Reason)]),
                    filename:join(Relative);
                {ok, Cwd} ->
                    filename:join([Cwd | Relative])
            end;
        Configured ->
            binary_to_list(Configured)
    end.
320

321
%% Return the value of the new_sql_schema option as reported by
%% ejabberd_option.
use_new_schema() ->
    ejabberd_option:new_sql_schema().
16,273✔
323

324
%% Pick the registered name of one worker of Host.
%% The index comes from p1_rand:round_robin/1 over the configured pool
%% size, shifted to start at 1. The atom must already exist.
-spec get_worker(binary()) -> atom().
get_worker(Host) ->
    PoolSize = ejabberd_option:sql_pool_size(Host),
    Index = p1_rand:round_robin(PoolSize) + 1,
    Name = get_worker_name(Host, Index),
    binary_to_existing_atom(Name, utf8).
20,598✔
329

330
%% Build the worker name "ejabberd_sql_<Host>_<I>" as a binary.
-spec get_worker_name(binary(), pos_integer()) -> binary().
get_worker_name(Host, I) ->
    Prefix = <<"ejabberd_sql_", Host/binary>>,
    Index = integer_to_binary(I),
    <<Prefix/binary, $_, Index/binary>>.
20,601✔
333

334
%%%----------------------------------------------------------------------
335
%%% Callback functions from gen_fsm
336
%%%----------------------------------------------------------------------
337
%% FSM initialisation for the worker of Host.
%% Traps exits, schedules the periodic keep_alive/2 call when a keepalive
%% interval is configured, posts the `connect` event to itself and starts
%% in state `connecting` with an empty pending-request queue.
init([Host]) ->
    process_flag(trap_exit, true),
    Interval = ejabberd_option:sql_keepalive_interval(Host),
    if
        Interval =:= undefined ->
            ok;
        true ->
            timer:apply_interval(Interval, ?MODULE,
                                 keep_alive, [Host, self()])
    end,
    [DBType | _] = db_opts(Host),
    p1_fsm:send_event(self(), connect),
    QueueType = ejabberd_option:sql_queue_type(Host),
    Pending = p1_queue:new(QueueType, max_fsm_queue()),
    InitialState = #state{db_type = DBType, host = Host,
                          pending_requests = Pending},
    {ok, connecting, InitialState}.
352

353
%% Asynchronous events in state `connecting`.
%% On `connect` the backend selected by db_opts/1 is opened. On success
%% the connection is linked, the cached prepared-statement markers are
%% erased from the process dictionary, every pending request is re-posted
%% to this process as an event, the database version is read and the FSM
%% moves to `session_established`. A connect or link failure goes through
%% handle_reconnect/2. Any other event is logged and ignored.
connecting(connect, #state{host = Host} = State) ->
    Outcome = case db_opts(Host) of
                  [mysql | Args] -> apply(fun mysql_connect/8, Args);
                  [pgsql | Args] -> apply(fun pgsql_connect/8, Args);
                  [sqlite | Args] -> apply(fun sqlite_connect/1, Args);
                  [Odbc | Args] when Odbc =:= mssql; Odbc =:= odbc ->
                      apply(fun odbc_connect/2, Args)
              end,
    case Outcome of
        {error, Reason} ->
            handle_reconnect(Reason, State);
        {ok, Ref} ->
            try link(Ref) of
                _ ->
                    %% Prepared statements belong to the old connection.
                    StaleKeys = [Key || {{?PREPARE_KEY, _} = Key, _} <- get()],
                    lists:foreach(fun erlang:erase/1, StaleKeys),
                    Replay = fun(Req) ->
                                     p1_fsm:send_event(self(), Req),
                                     true
                             end,
                    Remaining = p1_queue:dropwhile(
                                  Replay, State#state.pending_requests),
                    Connected = State#state{db_ref = Ref,
                                            pending_requests = Remaining},
                    Versioned = get_db_version(Connected),
                    {next_state, session_established,
                     Versioned#state{reconnect_count = 0}}
            catch _:Reason ->
                    handle_reconnect(Reason, State)
            end
    end;
connecting(Event, State) ->
    ?WARNING_MSG("Unexpected event in 'connecting': ~p",
                 [Event]),
    {next_state, connecting, State}.
×
391

392
%% Synchronous events in state `connecting`.
%% A keepalive query is answered at once with an error. Any other sql_cmd
%% is queued until the connection is up; when the queue is full every
%% queued request is bounced with an error and the new request is queued
%% in the emptied queue. Other calls are logged and left unanswered.
connecting({sql_cmd, {sql_query, ?KEEPALIVE_QUERY}, Timestamp},
           From, State) ->
    reply(From, {error, <<"SQL connection failed">>}, Timestamp),
    {next_state, connecting, State};
connecting({sql_cmd, Command, Timestamp} = Req, From,
           State) ->
    ?DEBUG("Queuing pending request while connecting:~n\t~p",
           [Req]),
    Entry = {sql_cmd, Command, From, Timestamp},
    Queue = State#state.pending_requests,
    NewQueue =
        try p1_queue:in(Entry, Queue)
        catch error:full ->
                Err = <<"SQL request queue is overfilled">>,
                ?ERROR_MSG("~ts, bouncing all pending requests", [Err]),
                Bounce = fun({sql_cmd, _, To, TS}) ->
                                 reply(To, {error, Err}, TS),
                                 true
                         end,
                Emptied = p1_queue:dropwhile(Bounce, Queue),
                p1_queue:in(Entry, Emptied)
        end,
    {next_state, connecting,
     State#state{pending_requests = NewQueue}};
connecting(Request, {Who, _Ref}, State) ->
    ?WARNING_MSG("Unexpected call ~p from ~p in 'connecting'",
                 [Request, Who]),
    {next_state, connecting, State}.
×
419

420
%% Synchronous events in state `session_established`.
%% An sql_cmd is executed through run_sql_cmd/4. Any other call is logged
%% and left unanswered.
session_established({sql_cmd, Command, Timestamp}, From, State) ->
    run_sql_cmd(Command, From, State, Timestamp);
session_established(Request, {Who, _Ref}, State) ->
    ?WARNING_MSG("Unexpected call ~p from ~p in 'session_established'",
                 [Request, Who]),
    {next_state, session_established, State}.
×
427

428
%% Asynchronous events in state `session_established`.
%% A four-element sql_cmd (the shape used for requests queued while
%% connecting) is executed through run_sql_cmd/4. force_timeout stops the
%% FSM with reason timeout. Any other event is logged and ignored.
session_established({sql_cmd, Command, From, Timestamp}, State) ->
    run_sql_cmd(Command, From, State, Timestamp);
session_established(force_timeout, State) ->
    {stop, timeout, State};
session_established(Event, State) ->
    ?WARNING_MSG("Unexpected event in 'session_established': ~p",
                 [Event]),
    {next_state, session_established, State}.
×
437

438
%% All-state asynchronous events are ignored; state name and data are
%% kept as they are.
handle_event(_Event, StateName, StateData) ->
    {next_state, StateName, StateData}.
×
440

441
%% All-state synchronous events are not supported: the caller always gets
%% {error, badarg} and the FSM stays where it is.
handle_sync_event(_Event, _From, StateName, StateData) ->
    Reply = {error, badarg},
    {reply, Reply, StateName, StateData}.
×
443

444
%% Code upgrade callback: state name and data are kept unchanged.
code_change(_OldVsn, StateName, StateData, _Extra) ->
    {ok, StateName, StateData}.
×
446

447
%% Plain messages received by the FSM.
%% An 'EXIT' message is ignored while connecting; in any other state it
%% triggers handle_reconnect/2 with the exit reason. Every other message
%% is logged and ignored.
handle_info({'EXIT', _Pid, Reason}, StateName, State) ->
    case StateName of
        connecting ->
            {next_state, connecting, State};
        _ ->
            handle_reconnect(Reason, State)
    end;
handle_info(Info, StateName, State) ->
    ?WARNING_MSG("Unexpected info in ~p: ~p",
                 [StateName, Info]),
    {next_state, StateName, State}.
×
455

456
%% FSM shutdown: close the MySQL connection or the SQLite database of the
%% host. Failures while closing are swallowed by `catch`. Other database
%% types need no action here. Always returns ok.
terminate(_Reason, _StateName, State) ->
    DbType = State#state.db_type,
    case DbType of
        mysql ->
            catch p1_mysql_conn:stop(State#state.db_ref);
        sqlite ->
            catch sqlite3:close(sqlite_db(State#state.host));
        _ ->
            ok
    end,
    ok.
3✔
463

464
%%----------------------------------------------------------------------
%% Func: print_state/1
%% Purpose: Prepare the state to be printed on error log
%% Returns: The state, unchanged
%%----------------------------------------------------------------------
print_state(StateData) ->
    StateData.
×
470

471
%%%----------------------------------------------------------------------
472
%%% Internal functions
473
%%%----------------------------------------------------------------------
474
%% Handle a lost or failed connection.
%% Logs the failure, closes the backend connection on a best-effort basis,
%% schedules a new `connect` event and returns to state `connecting` with
%% the reconnect counter incremented. The first retry is capped at
%% 5000 ms; later retries use the configured start interval.
handle_reconnect(Reason, #state{host = Host, reconnect_count = RC,
                                db_type = DbType, db_ref = DbRef} = State) ->
    Configured = ejabberd_option:sql_start_interval(Host),
    RetryAfter = if
                     RC =:= 0 -> erlang:min(5000, Configured);
                     true -> Configured
                 end,
    ?WARNING_MSG("~p connection failed:~n"
                 "** Reason: ~p~n"
                 "** Retry after: ~B seconds",
                 [DbType, Reason,
                  RetryAfter div 1000]),
    case DbType of
        mysql -> catch p1_mysql_conn:stop(DbRef);
        sqlite -> catch sqlite3:close(sqlite_db(Host));
        pgsql -> catch pgsql:terminate(DbRef);
        _ -> ok
    end,
    p1_fsm:send_event_after(RetryAfter, connect),
    {next_state, connecting, State#state{reconnect_count = RC + 1}}.
×
493

494
%% Execute one queued command in the worker.
%% A command whose deadline has already passed is dropped and only
%% reported through report_overload/1. Otherwise the mailbox is polled
%% without blocking for an 'EXIT' message: if one is present the command
%% is put back on the pending queue and a reconnect is started; if not,
%% the nesting level and state are stored in the process dictionary and
%% the command is run.
run_sql_cmd(Command, From, State, Timestamp) ->
    Expired = current_time() >= Timestamp,
    if
        Expired ->
            {next_state, session_established, report_overload(State)};
        true ->
            receive
                {'EXIT', _Pid, Reason} ->
                    Requeued = p1_queue:in(
                                 {sql_cmd, Command, From, Timestamp},
                                 State#state.pending_requests),
                    handle_reconnect(
                      Reason, State#state{pending_requests = Requeued})
            after 0 ->
                put(?NESTING_KEY, ?TOP_LEVEL_TXN),
                put(?STATE_KEY, State),
                Result = outer_op(Command),
                abort_on_driver_error(Result, From, Timestamp)
            end
    end.
511

512
%% @doc Run a top-level operation; called from run_sql_cmd/4.
%% {sql_query, Query} is executed directly and returns the query result.
%% {sql_transaction, F} runs F through outer_transaction/3 with the
%% maximum number of restarts. {sql_bloc, F} runs F through
%% execute_bloc/1.
-spec outer_op({sql_query, sql_query(any())} |
               {sql_transaction, fun(() -> any())} |
               {sql_bloc, fun(() -> any())}) ->
    sql_query_result(any()) | {aborted, Reason::any()} | {atomic, Result::any()}.
outer_op({sql_query, Query}) ->
    sql_query_internal(Query);
outer_op({sql_transaction, F}) ->
    outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, <<"">>);
outer_op({sql_bloc, F}) ->
    execute_bloc(F).
2,637✔
520

521
%% Called via sql_query/transaction/bloc from client code that is already
%% running inside a worker operation. A transaction requested at the top
%% nesting level becomes an outer transaction; at any deeper level it
%% becomes an inner transaction.
nested_op({sql_query, Query}) ->
    sql_query_internal(Query);
nested_op({sql_bloc, F}) ->
    execute_bloc(F);
nested_op({sql_transaction, F}) ->
    case get(?NESTING_KEY) of
        ?TOP_LEVEL_TXN ->
            outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, <<"">>);
        _ ->
            inner_transaction(F)
    end.
27✔
532

533
%% Never retry nested transactions - only outer transactions.
%% Runs F one nesting level deeper and restores the previous level
%% afterwards. {aborted, _}, {'EXIT', _} and {atomic, _} results are
%% returned as they are; any other value is wrapped in {atomic, Value}.
%% Being called at the top nesting level is a programming error: it is
%% logged with a backtrace and the process exits.
inner_transaction(F) ->
    Level = get(?NESTING_KEY),
    case Level of
        ?TOP_LEVEL_TXN ->
            {backtrace, T} = process_info(self(), backtrace),
            ?ERROR_MSG("Inner transaction called at outer txn "
                       "level. Trace: ~ts",
                       [T]),
            erlang:exit(implementation_faulty);
        _ ->
            ok
    end,
    put(?NESTING_KEY, Level + 1),
    Outcome = (catch F()),
    put(?NESTING_KEY, Level),
    case Outcome of
        {aborted, _} -> Outcome;
        {'EXIT', _} -> Outcome;
        {atomic, _} -> Outcome;
        Value -> {atomic, Value}
    end.
554

555
%% Run F as a top-level transaction: BEGIN, F(), COMMIT.
%%   F         - zero-arity fun containing the transaction body
%%   NRestarts - how many more times the transaction may be restarted
%%   _Reason   - reason of the previous abort (unused)
%% Returns {atomic, Result} on success, otherwise the result of
%% maybe_restart_transaction/4. Being called at a nesting level other than
%% the top one is a programming error: it is logged with a backtrace and
%% the process exits.
outer_transaction(F, NRestarts, _Reason) ->
    Level = get(?NESTING_KEY),
    case Level of
      ?TOP_LEVEL_TXN -> ok;
      _N ->
          {backtrace, T} = process_info(self(), backtrace),
          ?ERROR_MSG("Outer transaction called at inner txn "
                     "level. Trace: ~ts",
                     [T]),
          erlang:exit(implementation_faulty)
    end,
    case sql_begin() of
        {error, Reason} ->
            maybe_restart_transaction(F, NRestarts, Reason, false);
        _ ->
            put(?NESTING_KEY, Level + 1),
            %% The commit runs inside the protected part of the `try`.
            %% In a `try ... of` form the `of` section is not covered by
            %% the catch clauses, so a restart/1 throw raised there on a
            %% failed COMMIT would escape without rollback or restart.
            try
                Res = F(),
                case sql_commit() of
                    {error, CommitReason} ->
                        restart(CommitReason);
                    _ ->
                        {atomic, Res}
                end
            catch
                ?EX_RULE(throw, {aborted, Reason}, _) when NRestarts > 0 ->
                    maybe_restart_transaction(F, NRestarts, Reason, true);
                ?EX_RULE(throw, {aborted, Reason}, Stack) when NRestarts =:= 0 ->
                    StackTrace = ?EX_STACK(Stack),
                    ?ERROR_MSG("SQL transaction restarts exceeded~n** "
                               "Restarts: ~p~n** Last abort reason: "
                               "~p~n** Stacktrace: ~p~n** When State "
                               "== ~p",
                               [?MAX_TRANSACTION_RESTARTS, Reason,
                                StackTrace, get(?STATE_KEY)]),
                    maybe_restart_transaction(F, NRestarts, Reason, true);
                %% An exit never restarts the transaction.
                ?EX_RULE(exit, Reason, _) ->
                    maybe_restart_transaction(F, 0, Reason, true)
            end
    end.
595

596
%% Decide what to do after a transaction attempt failed with Reason.
%% When the driver has to be restarted (for Reason, or for a failed
%% rollback) the result is {aborted, ThatReason}. Otherwise, after an
%% optional rollback, the transaction is restarted while NRestarts is
%% positive, and reported as {aborted, Reason} once restarts are used up.
maybe_restart_transaction(F, NRestarts, Reason, DoRollback) ->
    Decision =
        case {driver_restart_required(Reason), DoRollback} of
            {true, _} ->
                {aborted, Reason};
            {_, true} ->
                case sql_rollback() of
                    {error, RollbackReason} ->
                        case driver_restart_required(RollbackReason) of
                            true -> {aborted, RollbackReason};
                            _ -> continue
                        end;
                    _ ->
                        continue
                end;
            _ ->
                continue
        end,
    case Decision of
        continue when NRestarts > 0 ->
            %% Back to the top level before the next attempt.
            put(?NESTING_KEY, ?TOP_LEVEL_TXN),
            outer_transaction(F, NRestarts - 1, Reason);
        continue ->
            {aborted, Reason};
        Aborted ->
            Aborted
    end.
624

625
%% Run F once, protected by `catch`.
%% {aborted, Reason} is returned as it is, an exit or error becomes
%% {aborted, Reason}, and any other value is wrapped in {atomic, Value}.
execute_bloc(F) ->
    Outcome = (catch F()),
    case Outcome of
        {aborted, _} -> Outcome;
        {'EXIT', Reason} -> {aborted, Reason};
        _ -> {atomic, Outcome}
    end.
631

632
%% Call a query fun. A zero-arity fun is called as is; a two-arity fun
%% receives the database type and version taken from the state stored
%% under ?STATE_KEY.
execute_fun(F) when is_function(F, 0) ->
    F();
execute_fun(F) when is_function(F, 2) ->
    WorkerState = get(?STATE_KEY),
    DbType = WorkerState#state.db_type,
    DbVersion = WorkerState#state.db_version,
    F(DbType, DbVersion).
3,447✔
637

638
%% Execute one query inside the worker process; the worker state is read
%% from the process dictionary under ?STATE_KEY.
%%
%% Clause 1: a list of {Selector, Query} alternatives. The alternative
%% matching the current database type and version is chosen with
%% select_sql_query/2 and executed.
sql_query_internal([{_, _} | _] = Queries) ->
    case select_sql_query(Queries, get(?STATE_KEY)) of
        undefined ->
            {error, <<"no matching query for the current DBMS found">>};
        Selected ->
            sql_query_internal(Selected)
    end;
%% Clause 2: a compiled #sql_query{} record, dispatched per database type.
%% Exits and other exceptions raised by the driver are converted to
%% {error, Text}.
sql_query_internal(#sql_query{} = Query) ->
    State = get(?STATE_KEY),
    Outcome =
        try
            case State#state.db_type of
                odbc ->
                    generic_sql_query(Query);
                mssql ->
                    mssql_sql_query(Query);
                pgsql ->
                    %% The prepare status of each query is cached in the
                    %% process dictionary, keyed by the query hash.
                    Key = {?PREPARE_KEY, Query#sql_query.hash},
                    Status =
                        case get(Key) of
                            undefined ->
                                Host = State#state.host,
                                Fresh =
                                    case ejabberd_option:sql_prepared_statements(Host) of
                                        false ->
                                            ignore;
                                        true ->
                                            case pgsql_prepare(Query, State) of
                                                {ok, _, _, _} ->
                                                    prepared;
                                                {error, Error} ->
                                                    ?ERROR_MSG(
                                                       "PREPARE failed for SQL query "
                                                       "at ~p: ~p",
                                                       [Query#sql_query.loc, Error]),
                                                    ignore
                                            end
                                    end,
                                put(Key, Fresh),
                                Fresh;
                            Cached ->
                                Cached
                        end,
                    case Status of
                        prepared ->
                            pgsql_execute_sql_query(Query, State);
                        _ ->
                            pgsql_sql_query(Query)
                    end;
                mysql ->
                    Flags = Query#sql_query.flags,
                    UsePrepared =
                        ejabberd_option:sql_prepared_statements(State#state.host),
                    case {Flags, UsePrepared} of
                        {1, _} ->
                            generic_sql_query(Query);
                        {_, false} ->
                            generic_sql_query(Query);
                        _ ->
                            mysql_prepared_execute(Query, State)
                    end;
                sqlite ->
                    sqlite_sql_query(Query)
            end
        catch exit:{timeout, _} ->
                {error, <<"timed out">>};
              exit:{killed, _} ->
                {error, <<"killed">>};
              exit:{normal, _} ->
                {error, <<"terminated unexpectedly">>};
              exit:{shutdown, _} ->
                {error, <<"shutdown">>};
              ?EX_RULE(Class, Reason, Stack) ->
                StackTrace = ?EX_STACK(Stack),
                ?ERROR_MSG("Internal error while processing SQL query:~n** ~ts",
                           [misc:format_exception(2, Class, Reason, StackTrace)]),
                {error, <<"internal error">>}
        end,
    check_error(Outcome, Query);
%% Clause 3: a fun, run through execute_fun/1. An abort, exit or error is
%% turned into {error, Reason}.
sql_query_internal(F) when is_function(F) ->
    case catch execute_fun(F) of
        {aborted, Reason} -> {error, Reason};
        {'EXIT', Reason} -> {error, Reason};
        Value -> Value
    end;
%% Clause 4: raw query text, sent to the driver of the current database
%% type. Drivers that take a timeout receive the query timeout minus
%% 1000 ms.
sql_query_internal(Query) ->
    State = get(?STATE_KEY),
    ?DEBUG("SQL: \"~ts\"", [Query]),
    QueryTimeout = query_timeout(State#state.host),
    DbRef = State#state.db_ref,
    Outcome =
        case State#state.db_type of
            Odbc when Odbc =:= odbc; Odbc =:= mssql ->
                to_odbc(odbc:sql_query(DbRef, [Query],
                                       QueryTimeout - 1000));
            pgsql ->
                pgsql_to_odbc(pgsql:squery(DbRef, Query,
                                           QueryTimeout - 1000));
            mysql ->
                mysql_to_odbc(p1_mysql_conn:squery(DbRef,
                                                   [Query], self(),
                                                   [{timeout, QueryTimeout - 1000},
                                                    {result_type, binary}]));
            sqlite ->
                Host = State#state.host,
                sqlite_to_odbc(Host, sqlite3:sql_exec(sqlite_db(Host), Query))
        end,
    check_error(Outcome, Query).
16,825✔
743

744
%% Choose one query from a list of {Selector, Query} alternatives, using
%% the database type and version stored in the connection state.
select_sql_query(Queries, State) ->
    DbType = State#state.db_type,
    DbVersion = State#state.db_version,
    select_sql_query(Queries, DbType, DbVersion, undefined).
747

748
%% Walk the alternatives in order and return the first usable query.
%%  - {any, Q} and {DbType, Q} match immediately.
%%  - {{DbType, MinVersion}, Q} matches when the known server version is
%%    at least MinVersion; when the server version is `undefined' the
%%    query is only remembered as a fallback and the walk continues.
%%  - Anything else is skipped.
%% When the list is exhausted the remembered fallback (possibly
%% `undefined') is returned.
select_sql_query([], _DbType, _DbVersion, Fallback) ->
    Fallback;
select_sql_query([{any, Q} | _], _DbType, _DbVersion, _Fallback) ->
    Q;
select_sql_query([{DbType, Q} | _], DbType, _DbVersion, _Fallback) ->
    Q;
select_sql_query([{{DbType, _MinVersion}, Q} | Tail], DbType, undefined, _Fallback) ->
    select_sql_query(Tail, DbType, undefined, Q);
select_sql_query([{{DbType, MinVersion}, Q} | _], DbType, DbVersion, _Fallback)
  when DbVersion >= MinVersion ->
    Q;
select_sql_query([{_, _} | Tail], DbType, DbVersion, Fallback) ->
    select_sql_query(Tail, DbType, DbVersion, Fallback).
5,446✔
767

768
%% Render SQLQuery with the generic escaping rules, execute the resulting
%% text and post-process the result with sql_query_format_res/2.
generic_sql_query(SQLQuery) ->
    QueryText = generic_sql_query_format(SQLQuery),
    RawRes = sql_query_internal(QueryText),
    sql_query_format_res(RawRes, SQLQuery).
772

773
%% Produce the query text: the query's `args' fun is applied to the
%% generic escape record and its result is handed to `format_query'.
generic_sql_query_format(SQLQuery) ->
    ArgsFun = SQLQuery#sql_query.args,
    FormatFun = SQLQuery#sql_query.format_query,
    FormatFun(ArgsFun(generic_escape())).
×
776

777
%% Escape callbacks used for backends without a dedicated escaper:
%% strings are passed through escape/1 and wrapped in single quotes,
%% booleans become 1/0 and no LIKE escape clause is emitted.
generic_escape() ->
    Quote = fun(Str) -> <<"'", (escape(Str))/binary, "'">> end,
    Bool = fun(true) -> <<"1">>;
              (false) -> <<"0">>
           end,
    #sql_escape{string = Quote,
                integer = fun misc:i2l/1,
                boolean = Bool,
                in_array_string = Quote,
                like_escape = fun() -> <<"">> end}.
786

787
%% Render SQLQuery with the PostgreSQL escaping rules, execute the text
%% and post-process the result with sql_query_format_res/2.
pgsql_sql_query(SQLQuery) ->
    QueryText = pgsql_sql_query_format(SQLQuery),
    RawRes = sql_query_internal(QueryText),
    sql_query_format_res(RawRes, SQLQuery).
791

792
%% Produce the query text using the PostgreSQL escape record.
pgsql_sql_query_format(SQLQuery) ->
    ArgsFun = SQLQuery#sql_query.args,
    FormatFun = SQLQuery#sql_query.format_query,
    FormatFun(ArgsFun(pgsql_escape())).
×
795

796
%% Escape callbacks for textual PostgreSQL queries: strings are passed
%% through escape/1 and emitted as E'...' literals, booleans as 't'/'f',
%% and LIKE patterns get an explicit ESCAPE clause.
pgsql_escape() ->
    Quote = fun(Str) -> <<"E'", (escape(Str))/binary, "'">> end,
    Bool = fun(true) -> <<"'t'">>;
              (false) -> <<"'f'">>
           end,
    #sql_escape{string = Quote,
                integer = fun misc:i2l/1,
                boolean = Bool,
                in_array_string = Quote,
                like_escape = fun() -> <<"ESCAPE E'\\\\'">> end}.
805

806
%% Render SQLQuery with the SQLite escaping rules, execute the text and
%% post-process the result with sql_query_format_res/2.
sqlite_sql_query(SQLQuery) ->
    QueryText = sqlite_sql_query_format(SQLQuery),
    RawRes = sql_query_internal(QueryText),
    sql_query_format_res(RawRes, SQLQuery).
810

811
%% Produce the query text using the SQLite escape record.
sqlite_sql_query_format(SQLQuery) ->
    ArgsFun = SQLQuery#sql_query.args,
    FormatFun = SQLQuery#sql_query.format_query,
    FormatFun(ArgsFun(sqlite_escape())).
10,067✔
814

815
%% Escape callbacks for SQLite: strings are escaped with
%% standard_escape/1 (single quotes doubled) and wrapped in single
%% quotes, booleans become 1/0, LIKE patterns use backslash as escape.
sqlite_escape() ->
    Quote = fun(Str) -> <<"'", (standard_escape(Str))/binary, "'">> end,
    Bool = fun(true) -> <<"1">>;
              (false) -> <<"0">>
           end,
    #sql_escape{string = Quote,
                integer = fun misc:i2l/1,
                boolean = Bool,
                in_array_string = Quote,
                like_escape = fun() -> <<"ESCAPE '\\'">> end}.
824

825
%% Double every single quote in the binary S; all other bytes are
%% copied unchanged.
standard_escape(S) ->
    << <<(if
              Byte =:= $' -> <<"''">>;
              true -> <<Byte>>
          end)/binary>> || <<Byte>> <= S >>.
23,107✔
830

831
%% MSSQL queries are rendered and executed through the SQLite code path.
mssql_sql_query(SQLQuery) ->
    Res = sqlite_sql_query(SQLQuery),
    Res.
×
833

834
%% Prepare SQLQuery on the PostgreSQL connection under its hash.
%% The query's `args' fun is run with marker callbacks so that each value
%% argument can be replaced by a numbered placeholder ($1, $2, ...),
%% while a LIKE escape marker becomes a literal ESCAPE clause and does
%% not consume a placeholder number.
pgsql_prepare(SQLQuery, State) ->
    Markers = #sql_escape{_ = fun(_) -> arg end,
                          like_escape = fun() -> escape end},
    Placeholder = fun(N) -> <<$$, (integer_to_binary(N))/binary>> end,
    {Args, _NextIndex} =
        lists:mapfoldl(
          fun(arg, N) ->
                  {Placeholder(N), N + 1};
             (escape, N) ->
                  {<<"ESCAPE E'\\\\'">>, N};
             (List, N) when is_list(List) ->
                  {Placeholder(N), N + 1}
          end, 1, (SQLQuery#sql_query.args)(Markers)),
    Query = (SQLQuery#sql_query.format_query)(Args),
    pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query).
131✔
851

852
%% Escape callbacks used when executing a prepared PostgreSQL statement:
%% strings are passed through untouched, integers are converted with
%% misc:i2l/1 and wrapped in a list, booleans become "1"/"0", and the
%% LIKE escape slot yields the atom `ignore'.
pgsql_execute_escape() ->
    Identity = fun(Str) -> Str end,
    Int = fun(I) -> [misc:i2l(I)] end,
    Bool = fun(true) -> "1";
              (false) -> "0"
           end,
    ArrayString = fun(Str) -> <<"\"", (escape(Str))/binary, "\"">> end,
    #sql_escape{string = Identity,
                integer = Int,
                boolean = Bool,
                in_array_string = ArrayString,
                like_escape = fun() -> ignore end}.
861

862
%% Execute a previously prepared PostgreSQL statement (looked up by the
%% query hash). `ignore' entries produced by the LIKE escape slot are
%% dropped before the arguments are passed to the driver.
pgsql_execute_sql_query(SQLQuery, State) ->
    AllArgs = (SQLQuery#sql_query.args)(pgsql_execute_escape()),
    BoundArgs = [Arg || Arg <- AllArgs, Arg =/= ignore],
    DriverRes = pgsql:execute(State#state.db_ref,
                              SQLQuery#sql_query.hash,
                              BoundArgs),
    sql_query_format_res(pgsql_execute_to_odbc(DriverRes), SQLQuery).
9,498✔
872

873
%% Run Query as a MySQL prepared statement identified by its hash.
%% The query's `args' fun is evaluated twice: once to collect the raw
%% values and once to collect their type tags; `ignore' entries (from the
%% LIKE escape slot) are removed from both lists. The statement text
%% with `?' placeholders is built lazily by the fun handed to the driver.
mysql_prepared_execute(#sql_query{hash = Hash} = Query, State) ->
    Skip = fun() -> ignore end,
    ValueEsc = #sql_escape{like_escape = Skip, _ = fun(V) -> V end},
    TypeEsc = #sql_escape{string = fun(_) -> string end,
                          integer = fun(_) -> integer end,
                          boolean = fun(_) -> bool end,
                          in_array_string = fun(_) -> string end,
                          like_escape = Skip},
    DropIgnored = fun(Items) -> [I || I <- Items, I /= ignore] end,
    Values = DropIgnored((Query#sql_query.args)(ValueEsc)),
    Types = DropIgnored((Query#sql_query.args)(TypeEsc)),
    BuildText =
        fun() ->
                PlaceholderEsc = #sql_escape{like_escape = fun() -> <<>> end,
                                             _ = fun(_) -> <<"?">> end},
                (Query#sql_query.format_query)(
                  (Query#sql_query.args)(PlaceholderEsc))
        end,
    Timeout = query_timeout(State#state.host) - 1000,
    DriverRes = p1_mysql_conn:prepared_query(State#state.db_ref, BuildText,
                                             Hash, Values, Types,
                                             self(), [{timeout, Timeout}]),
    sql_query_format_res(mysql_to_odbc(DriverRes), Query).
10,541✔
891

892
%% Post-process a driver result. For {selected, Columns, Rows} every row
%% is converted with the query's `format_res' fun and the column list is
%% dropped; a row whose conversion raises is logged and left out of the
%% result. Any other result is returned unchanged.
sql_query_format_res({selected, _Columns, Rows}, SQLQuery) ->
    Convert =
        fun(Row) ->
                try (SQLQuery#sql_query.format_res)(Row) of
                    Formatted ->
                        {true, Formatted}
                catch
                    ?EX_RULE(Class, Reason, Stack) ->
                        StackTrace = ?EX_STACK(Stack),
                        ?ERROR_MSG("Error while processing SQL query result:~n"
                                   "** Row: ~p~n** ~ts",
                                   [Row,
                                    misc:format_exception(2, Class, Reason, StackTrace)]),
                        false
                end
        end,
    {selected, lists:filtermap(Convert, Rows)};
sql_query_format_res(Other, _SQLQuery) ->
    Other.
11,208✔
911

912
%% Render SQLQuery to text using the generic escaping rules.
sql_query_to_iolist(SQLQuery) ->
    Text = generic_sql_query_format(SQLQuery),
    Text.
×
914

915
%% Render SQLQuery to text for the given database type: SQLite uses its
%% own escaping rules, every other type uses the generic ones.
sql_query_to_iolist(DbType, SQLQuery) ->
    case DbType of
        sqlite -> sqlite_sql_query_format(SQLQuery);
        _ -> generic_sql_query_format(SQLQuery)
    end.
×
919

920
%% Start a transaction, using the MSSQL-specific statement on mssql and
%% the plain `begin;' everywhere else.
sql_begin() ->
    Alternatives = [{mssql, [<<"begin transaction;">>]},
                    {any, [<<"begin;">>]}],
    sql_query_internal(Alternatives).
924

925
%% Commit the current transaction, using the MSSQL-specific statement on
%% mssql and the plain `commit;' everywhere else.
sql_commit() ->
    Alternatives = [{mssql, [<<"commit transaction;">>]},
                    {any, [<<"commit;">>]}],
    sql_query_internal(Alternatives).
929

930
%% Roll back the current transaction, using the MSSQL-specific statement
%% on mssql and the plain `rollback;' everywhere else.
sql_rollback() ->
    Alternatives = [{mssql, [<<"rollback transaction;">>]},
                    {any, [<<"rollback;">>]}],
    sql_query_internal(Alternatives).
934

935
%% Tell whether an error reason is one of the known messages after which
%% the caller should re-establish the connection. Any term is accepted;
%% unknown reasons yield `false'.
driver_restart_required(Reason) ->
    case Reason of
        <<"query timed out">> -> true;
        <<"connection closed">> -> true;
        <<"Failed sending data on socket", _/binary>> -> true;
        <<"SQL connection failed">> -> true;
        <<"Communication link failure">> -> true;
        _ -> false
    end.
×
941

942
%% Send Reply to the caller and build the FSM callback return value.
%% When Reply is an {error, Msg} or {aborted, Msg} tuple whose message
%% indicates a broken connection, the reconnect path is taken; in every
%% other case the FSM stays in `session_established'.
abort_on_driver_error(Reply, From, Timestamp) ->
    reply(From, Reply, Timestamp),
    State = get(?STATE_KEY),
    case Reply of
        {Tag, Msg} when Tag == error; Tag == aborted ->
            case driver_restart_required(Msg) of
                true ->
                    handle_reconnect(Msg, State);
                _ ->
                    {next_state, session_established, State}
            end;
        _ ->
            {next_state, session_established, State}
    end.
20,598✔
954

955
%% Log a pool-overload error, but at most once per 30 seconds.
%% Returns the state with `overload_reported' refreshed when a message
%% was logged, otherwise the state unchanged.
-spec report_overload(state()) -> state().
report_overload(#state{overload_reported = LastReport} = State) ->
    Now = current_time(),
    Due = (LastReport == undefined)
        orelse ((Now - LastReport) > timer:seconds(30)),
    case Due of
        false ->
            State;
        true ->
            ?ERROR_MSG("SQL connection pool is overloaded, "
                       "discarding stale requests", []),
            State#state{overload_reported = current_time()}
    end.
966

967
%% Deliver Reply to the caller unless the current monotonic time has
%% already reached Timestamp, in which case nothing is sent.
-spec reply({pid(), term()}, term(), integer()) -> term().
reply(From, Reply, Timestamp) ->
    Expired = current_time() >= Timestamp,
    if
        Expired -> ok;
        true -> p1_fsm:reply(From, Reply)
    end.
973

974
%% == pure ODBC code
975

976
%% Part of init/1: start the odbc application and open an ODBC
%% connection described by the binary connection string SQLServer.
odbc_connect(SQLServer, Timeout) ->
    ejabberd:start_app(odbc),
    ConnString = binary_to_list(SQLServer),
    Options = [{scrollable_cursors, off},
               {extended_errors, on},
               {tuple_row, off},
               {timeout, Timeout},
               {binary_strings, on}],
    odbc:connect(ConnString, Options).
986

987
%% == Native SQLite code
988

989
%% Part of init/1: open the SQLite database for Host.
%% The directory of the database file is created if needed. A freshly
%% opened database gets foreign-key enforcement switched on; an already
%% started one is reused as is. Errors are returned unchanged.
sqlite_connect(Host) ->
    DbFile = sqlite_file(Host),
    case filelib:ensure_dir(DbFile) of
        ok ->
            case sqlite3:open(sqlite_db(Host), [{file, DbFile}]) of
                {ok, _Ref} = Opened ->
                    sqlite3:sql_exec(
                      sqlite_db(Host), "pragma foreign_keys = on"),
                    Opened;
                {error, {already_started, Running}} ->
                    {ok, Running};
                {error, _Reason} = OpenErr ->
                    OpenErr
            end;
        DirErr ->
            DirErr
    end.
1009

1010
%% Convert an SQLite driver result into the ODBC-style result tuples
%% used by the rest of this module. Integer cells of selected rows are
%% converted to binaries; column names are converted to binaries too.
%% Results of an unrecognised shape map to {updated, undefined}.
sqlite_to_odbc(Host, Result) ->
    case Result of
        ok ->
            {updated, sqlite3:changes(sqlite_db(Host))};
        {rowid, _} ->
            {updated, sqlite3:changes(sqlite_db(Host))};
        [{columns, Columns}, {rows, Tuples}] ->
            Cell = fun(I) when is_integer(I) -> integer_to_binary(I);
                      (Other) -> Other
                   end,
            Rows = [lists:map(Cell, tuple_to_list(T)) || T <- Tuples],
            {selected, [list_to_binary(C) || C <- Columns], Rows};
        {error, _Code, Reason} ->
            {error, Reason};
        _ ->
            {updated, undefined}
    end.
×
1027

1028
%% == Native PostgreSQL code
1029

1030
%% Part of init/1: open a connection to PostgreSQL. SSLOpts is appended
%% to the fixed option list passed to the driver.
pgsql_connect(Server, Port, DB, Username, Password, ConnectTimeout,
              Transport, SSLOpts) ->
    BaseOpts = [{host, Server},
                {database, DB},
                {user, Username},
                {password, Password},
                {port, Port},
                {transport, Transport},
                {connect_timeout, ConnectTimeout},
                {as_binary, true}],
    pgsql:connect(BaseOpts ++ SSLOpts).
1042

1043
%% Convert a PostgreSQL simple-query result into ODBC-style results.
%% A single-statement result yields one converted item; otherwise a list
%% with one converted item per statement is returned.
pgsql_to_odbc({ok, [Single]}) ->
    pgsql_item_to_odbc(Single);
pgsql_to_odbc({ok, Items}) ->
    [pgsql_item_to_odbc(Item) || Item <- Items].
1049

1050
%% Convert one statement result of a PostgreSQL simple query.
%% SELECT/FETCH results become {selected, ColumnNames, Rows}, where each
%% column name is the first element of its column descriptor.
%% INSERT/DELETE/UPDATE command tags become {updated, Count}, errors are
%% passed through, and anything else becomes {updated, undefined}.
pgsql_item_to_odbc(Item) ->
    Names = fun(Cols) -> [element(1, Col) || Col <- Cols] end,
    case Item of
        {<<"SELECT", _/binary>>, Cols, Recs} ->
            {selected, Names(Cols), Recs};
        {<<"FETCH", _/binary>>, Cols, Recs} ->
            {selected, Names(Cols), Recs};
        <<"INSERT ", OidAndCount/binary>> ->
            [_Oid, Count] = str:tokens(OidAndCount, <<" ">>),
            {updated, binary_to_integer(Count)};
        <<"DELETE ", Count/binary>> ->
            {updated, binary_to_integer(Count)};
        <<"UPDATE ", Count/binary>> ->
            {updated, binary_to_integer(Count)};
        {error, Error} ->
            {error, Error};
        _ ->
            {updated, undefined}
    end.
1,816✔
1065

1066
%% Convert the result of executing a prepared PostgreSQL statement.
%% SELECT rows are lists of {_, Field} pairs from which only the field
%% values are kept (no column names are returned). INSERT/DELETE/UPDATE
%% become {updated, Count}, errors are passed through, and anything else
%% becomes {updated, undefined}.
pgsql_execute_to_odbc(Result) ->
    case Result of
        {ok, {<<"SELECT", _/binary>>, Rows}} ->
            {selected, [], [[Field || {_, Field} <- Row] || Row <- Rows]};
        {ok, {Command, Count}} when Command =:= 'INSERT';
                                    Command =:= 'DELETE';
                                    Command =:= 'UPDATE' ->
            {updated, Count};
        {error, Error} ->
            {error, Error};
        _ ->
            {updated, undefined}
    end.
×
1076

1077

1078
%% == Native MySQL code
1079

1080
%% Part of init/1: open a connection to MySQL.
%% With the `ssl' transport the given SSL options are passed on with
%% `ssl_required' prepended; with any other transport no SSL options are
%% passed at all. After a successful start the connection character set
%% is switched to utf8mb4.
mysql_connect(Server, Port, DB, Username, Password, ConnectTimeout, Transport, SSLOpts0) ->
    SSLOpts = if
                  Transport =:= ssl -> [ssl_required | SSLOpts0];
                  true -> []
              end,
    Started = p1_mysql_conn:start(binary_to_list(Server), Port,
                                  binary_to_list(Username),
                                  binary_to_list(Password),
                                  binary_to_list(DB),
                                  ConnectTimeout, fun log/3, SSLOpts),
    case Started of
        {ok, Conn} ->
            p1_mysql_conn:fetch(
              Conn, [<<"set names 'utf8mb4' collate 'utf8mb4_bin';">>], self()),
            Started;
        Failure ->
            Failure
    end.
1101

1102
%% Convert a MySQL driver result into the ODBC-style result tuples used
%% by this module. Error reasons are normalised to binaries: binaries
%% pass through, strings are converted, and any other error term is
%% treated as a driver result record from which the reason is extracted.
mysql_to_odbc({updated, Result}) ->
    {updated, p1_mysql:get_result_affected_rows(Result)};
mysql_to_odbc({data, Result}) ->
    Fields = p1_mysql:get_result_field_info(Result),
    Rows = p1_mysql:get_result_rows(Result),
    mysql_item_to_odbc(Fields, Rows);
mysql_to_odbc({error, Reason} = Err) when is_binary(Reason) ->
    Err;
mysql_to_odbc({error, Reason}) when is_list(Reason) ->
    {error, list_to_binary(Reason)};
mysql_to_odbc({error, Result}) ->
    Reason = p1_mysql:get_result_reason(Result),
    mysql_to_odbc({error, Reason});
mysql_to_odbc(ok) ->
    ok.
×
1118

1119

1120
%% Build an ODBC-style {selected, Names, Rows} result from tabular MySQL
%% data; each name is the second element of its column descriptor.
mysql_item_to_odbc(Columns, Recs) ->
    Names = [element(2, Descriptor) || Descriptor <- Columns],
    {selected, Names, Recs}.
7,441✔
1123

1124
%% Normalise a result coming from the odbc application: integer cells
%% and column names of a `selected' result are converted to binaries,
%% string error reasons are converted to binaries, and every other
%% result is returned unchanged.
to_odbc(Result) ->
    case Result of
        {selected, Columns, Recs} ->
            Cell = fun(I) when is_integer(I) -> integer_to_binary(I);
                      (Other) -> Other
                   end,
            Rows = [lists:map(Cell, Row) || Row <- Recs],
            {selected, [list_to_binary(C) || C <- Columns], Rows};
        {error, Reason} when is_list(Reason) ->
            {error, list_to_binary(Reason)};
        _ ->
            Result
    end.
×
1136

1137
%% Ask the server for its version and record it in `db_version'.
%%  - pgsql: stores the integer value of `server_version_num'.
%%  - mysql: stores {Version, Flavour, Flag}, where Version is
%%    Major*1000000 + Minor*1000 + Patch, Flavour is the text after the
%%    first `-' as an atom (or `unknown'), and Flag is 1 for non-MariaDB
%%    servers in the ranges [5.7.26, 8.0.0) and >= 8.0.20; otherwise it
%%    is 1 or 0 depending on whether `mysql_alternative_upsert' is set in
%%    the host's sql_flags.
%%  - other backends: State is returned unchanged.
%% A failed query or unparsable reply only logs a warning.
get_db_version(#state{db_type = pgsql} = State) ->
    VersionQuery = <<"select current_setting('server_version_num')">>,
    case pgsql:squery(State#state.db_ref, VersionQuery) of
        {ok, [{_, _, [[Raw]]}]} ->
            case catch binary_to_integer(Raw) of
                Number when is_integer(Number) ->
                    State#state{db_version = Number};
                Failure ->
                    ?WARNING_MSG("Error getting pgsql version: ~p", [Failure]),
                    State
            end;
        Unexpected ->
            ?WARNING_MSG("Error getting pgsql version: ~p", [Unexpected]),
            State
    end;
get_db_version(#state{db_type = mysql, host = Host} = State) ->
    SqlFlags = ejabberd_option:sql_flags(Host),
    DefaultUpsert = case lists:member(mysql_alternative_upsert, SqlFlags) of
                        true -> 1;
                        _ -> 0
                    end,
    FlagFor = fun('MariaDB', _Ver) -> DefaultUpsert;
                 (_Flavour, Ver) when Ver >= 5007026 andalso Ver < 8000000 -> 1;
                 (_Flavour, Ver) when Ver >= 8000020 -> 1;
                 (_Flavour, _Ver) -> DefaultUpsert
              end,
    Encode = fun(Major, Minor, Patch) ->
                     ((bin_to_int(Major)*1000)+bin_to_int(Minor))*1000+bin_to_int(Patch)
             end,
    Reply = p1_mysql_conn:squery(State#state.db_ref,
                                 [<<"select version();">>], self(),
                                 [{timeout, 5000},
                                  {result_type, binary}]),
    case mysql_to_odbc(Reply) of
        {selected, _, [SVersion]} ->
            case re:run(SVersion, <<"(\\d+)\\.(\\d+)(?:\\.(\\d+))?(?:-([^-]*))?">>,
                        [{capture, all_but_first, binary}]) of
                {match, [V1, V2, V3, Suffix]} ->
                    V = Encode(V1, V2, V3),
                    Flavour = binary_to_atom(Suffix, utf8),
                    State#state{db_version = {V, Flavour, FlagFor(Flavour, V)}};
                {match, [V1, V2, V3]} ->
                    V = Encode(V1, V2, V3),
                    State#state{db_version = {V, unknown, FlagFor(unknown, V)}};
                _ ->
                    ?WARNING_MSG("Error parsing mysql version: ~p", [SVersion]),
                    State
            end;
        Unexpected ->
            ?WARNING_MSG("Error getting mysql version: ~p", [Unexpected]),
            State
    end;
get_db_version(State) ->
    State.
1✔
1192

1193
%% Convert a decimal binary to an integer; the empty binary counts as 0.
bin_to_int(Bin) ->
    case Bin of
        <<>> -> 0;
        _ -> binary_to_integer(Bin)
    end.
3✔
1195

1196
%% Logging callback handed to the MySQL driver: forwards a message to
%% the ejabberd logging macro matching Level (`info' and `normal' are
%% both logged at info level).
log(Level, Format, Args) ->
    case Level of
        debug -> ?DEBUG(Format, Args);
        error -> ?ERROR_MSG(Format, Args);
        _ when Level =:= info; Level =:= normal -> ?INFO_MSG(Format, Args)
    end.
1203

1204
%% Collect the connection parameters configured for Host as a list whose
%% head is the database type:
%%   [odbc, Server, Timeout]
%%   [sqlite, Host]
%%   [mssql, ConnectionString, Timeout]
%%   [Type, Server, Port, DB, User, Pass, Timeout, Transport, SSLOpts]
%% For mssql the configured server is used verbatim when it already is a
%% connection string; otherwise one is assembled from the individual
%% options. The database name defaults to <<"ejabberd">>.
db_opts(Host) ->
    Type = ejabberd_option:sql_type(Host),
    Server = ejabberd_option:sql_server(Host),
    Timeout = ejabberd_option:sql_connect_timeout(Host),
    Transport = case ejabberd_option:sql_ssl(Host) of
                    true -> ssl;
                    false -> tcp
                end,
    warn_if_ssl_unsupported(Transport, Type),
    if
        Type =:= odbc ->
            [odbc, Server, Timeout];
        Type =:= sqlite ->
            [sqlite, Host];
        true ->
            Port = ejabberd_option:sql_port(Host),
            DB = case ejabberd_option:sql_database(Host) of
                     undefined -> <<"ejabberd">>;
                     Configured -> Configured
                 end,
            User = ejabberd_option:sql_username(Host),
            Pass = ejabberd_option:sql_password(Host),
            SSLOpts = get_ssl_opts(Transport, Host),
            %% Only consulted for mssql, exactly like the nested lookup
            %% it replaces.
            IsConnString = Type =:= mssql
                andalso odbc_server_is_connstring(Server),
            case Type of
                mssql when IsConnString ->
                    [mssql, Server, Timeout];
                mssql ->
                    Encryption = case Transport of
                                     tcp -> <<"">>;
                                     ssl -> <<";ENCRYPTION=require;ENCRYPT=yes">>
                                 end,
                    ConnString =
                        <<"DRIVER=ODBC;SERVER=", Server/binary, ";DATABASE=", DB/binary,
                          ";UID=", User/binary, ";PWD=", Pass/binary,
                          ";PORT=", (integer_to_binary(Port))/binary, Encryption/binary,
                          ";CLIENT_CHARSET=UTF-8;">>,
                    [mssql, ConnString, Timeout];
                _ ->
                    [Type, Server, Port, DB, User, Pass, Timeout, Transport, SSLOpts]
            end
    end.
1246

1247
%% Log a warning when SSL was requested for a database type other than
%% pgsql, mssql or mysql. Plain TCP never warns.
warn_if_ssl_unsupported(tcp, _Type) ->
    ok;
warn_if_ssl_unsupported(ssl, Type) ->
    case lists:member(Type, [pgsql, mssql, mysql]) of
        true ->
            ok;
        false ->
            ?WARNING_MSG("SSL connection is not supported for ~ts", [Type])
    end.
×
1257

1258
%% Build the SSL option list for an SQL connection to Host.
%% For the `tcp' transport the list is empty. For `ssl' it contains the
%% configured certificate and CA files (when set) plus an explicit
%% `verify' option:
%%   - verify_peer when verification is enabled and a CA file is set;
%%   - verify_none when verification is disabled, or when it is enabled
%%     but no CA file is configured (a warning is logged in that case).
get_ssl_opts(ssl, Host) ->
    Opts1 = case ejabberd_option:sql_ssl_certfile(Host) of
                undefined -> [];
                CertFile -> [{certfile, CertFile}]
            end,
    Opts2 = case ejabberd_option:sql_ssl_cafile(Host) of
                undefined -> Opts1;
                CAFile -> [{cacertfile, CAFile}|Opts1]
            end,
    case ejabberd_option:sql_ssl_verify(Host) of
        true ->
            case lists:keymember(cacertfile, 1, Opts2) of
                true ->
                    [{verify, verify_peer}|Opts2];
                false ->
                    ?WARNING_MSG("SSL verification is enabled for "
                                 "SQL connection, but option "
                                 "'sql_ssl_cafile' is not set; "
                                 "verification will be disabled", []),
                    %% Disable verification explicitly: leaving the
                    %% option out would fall back to the ssl
                    %% application's default, which is verify_peer on
                    %% recent OTP releases and contradicts the warning
                    %% above.
                    [{verify, verify_none}|Opts2]
            end;
        false ->
            [{verify, verify_none}|Opts2]
    end;
get_ssl_opts(tcp, _) ->
    [].
4✔
1284

1285
%% Make sure an odbcinst.ini naming the configured ODBC driver exists in
%% the temporary directory and point ODBCSYSINI at that directory.
%% An existing file is left untouched. Returns `ok' or the {error, Reason}
%% of the failing file operation (which is also logged).
init_mssql_odbcinst(Host) ->
    Driver = ejabberd_option:sql_odbc_driver(Host),
    ODBCINST = io_lib:fwrite("[ODBC]~n"
                             "Driver = ~s~n", [Driver]),
    ?DEBUG("~ts:~n~ts", [odbcinst_config(), ODBCINST]),
    case filelib:ensure_dir(odbcinst_config()) of
        ok ->
            case write_file_if_new(odbcinst_config(), ODBCINST) of
                ok ->
                    os:putenv("ODBCSYSINI", tmp_dir()),
                    ok;
                {error, WriteReason} = WriteErr ->
                    ?ERROR_MSG("Failed to create temporary files in ~ts: ~ts",
                               [tmp_dir(), file:format_error(WriteReason)]),
                    WriteErr
            end;
        {error, DirReason} = DirErr ->
            ?ERROR_MSG("Failed to create temporary directory ~ts: ~ts",
                       [tmp_dir(), file:format_error(DirReason)]),
            DirErr
    end.
1306

1307
%% Prepare the environment for an MSSQL connection: nothing to do when
%% the configured server is already a connection string, otherwise the
%% odbcinst.ini file is set up.
init_mssql(Host) ->
    Server = ejabberd_option:sql_server(Host),
    IsConnString = odbc_server_is_connstring(Server),
    if
        IsConnString -> ok;
        true -> init_mssql_odbcinst(Host)
    end.
1313

1314
%% A server value counts as an ODBC connection string when it contains
%% at least one `=' character.
odbc_server_is_connstring(Server) ->
    binary:match(Server, <<"=">>) =/= nomatch.
1319

1320
%% Write Payload to File unless something already exists at that path,
%% in which case `ok' is returned without touching it.
write_file_if_new(File, Payload) ->
    AlreadyThere = filelib:is_file(File),
    if
        AlreadyThere -> ok;
        true -> file:write_file(File, Payload)
    end.
1325

1326
%% Directory used for temporary files: "conf" under $HOME on Windows,
%% "/tmp/ejabberd" on every other OS.
tmp_dir() ->
    Parts = case os:type() of
                {win32, _} -> [os:getenv("HOME"), "conf"];
                _ -> ["/tmp", "ejabberd"]
            end,
    filename:join(Parts).
1331

1332
%% Full path of the odbcinst.ini file inside the temporary directory.
odbcinst_config() ->
    Dir = tmp_dir(),
    filename:join(Dir, "odbcinst.ini").
1✔
1334

1335
%% The `max_queue' entry of the FSM limit options, or `unlimited' when
%% it is not present.
max_fsm_queue() ->
    Limits = fsm_limit_opts(),
    proplists:get_value(max_queue, Limits, unlimited).
3✔
1337

1338
%% FSM limit options as computed by ejabberd_config from an empty
%% option list.
fsm_limit_opts() ->
    NoOpts = [],
    ejabberd_config:fsm_limit_opts(NoOpts).
6✔
1340

1341
%% The `sql_query_timeout' option configured for LServer.
query_timeout(LServer) ->
    Timeout = ejabberd_option:sql_query_timeout(LServer),
    Timeout.
48,006✔
1343

1344
%% Current Erlang monotonic time in milliseconds.
current_time() ->
    Unit = millisecond,
    erlang:monotonic_time(Unit).
61,794✔
1346

1347
%% Translate an ODBC extended error triple {Code, _, Reason} into a
%% binary message. This format is only produced when the connection was
%% opened with extended_errors turned on.
%%   "08S01", "IMC01", "IMC06" -> <<"Communication link failure">>
%%   "08001" (login timeout)   -> <<"SQL connection failed">>
%%   any other code            -> Reason converted to a binary
%% Terms that are not such a triple are returned unchanged.
extended_error({Code, _, Reason}) ->
    case Code of
        _ when Code =:= "08S01"; Code =:= "IMC01"; Code =:= "IMC06" ->
            %% Network name no longer available, or the connection is
            %% broken and recovery is not possible.
            ?DEBUG("ODBC Link Failure: ~ts", [Reason]),
            <<"Communication link failure">>;
        "08001" ->
            %% Login timeout expired.
            ?DEBUG("ODBC Connect Timeout: ~ts", [Reason]),
            <<"SQL connection failed">>;
        _ ->
            ?DEBUG("ODBC Error ~ts: ~ts", [Code, Reason]),
            iolist_to_binary(Reason)
    end;
extended_error(Error) ->
    Error.
×
1369

1370
%% Log failed queries and normalise their error reason.
%% {error, killed} is returned silently. Any other {error, Why} is passed
%% through extended_error/1 and logged: for #sql_query{} records with the
%% query hash and source location, for other queries with the query text
%% (or its ~p rendering when it is not valid iodata). Non-error results
%% are returned unchanged.
check_error({error, killed} = Err, _Query) ->
    Err;
check_error({error, Why}, #sql_query{hash = Hash, loc = Loc}) ->
    Err = extended_error(Why),
    ?ERROR_MSG("SQL query '~ts' at ~p failed: ~p",
               [Hash, Loc, Err]),
    {error, Err};
check_error({error, Why}, Query) ->
    Err = extended_error(Why),
    try iolist_to_binary(Query) of
        Text ->
            ?ERROR_MSG("SQL query '~ts' failed: ~p", [Text, Err])
    catch
        %% Not iodata: fall back to printing the raw term.
        _:_ ->
            ?ERROR_MSG("SQL query ~p failed: ~p", [Query, Err])
    end,
    {error, Err};
check_error(Result, _Query) ->
    Result.
46,931✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc