• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

processone / ejabberd / 747

27 Jun 2024 01:43PM UTC coverage: 32.123% (+0.8%) from 31.276%
747

push

github

badlop
Set version to 24.06

14119 of 43953 relevant lines covered (32.12%)

614.73 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

44.5
/src/ejabberd_sql.erl
1
%%%----------------------------------------------------------------------
2
%%% File    : ejabberd_sql.erl
3
%%% Author  : Alexey Shchepin <alexey@process-one.net>
4
%%% Purpose : Serve SQL connection
5
%%% Created :  8 Dec 2004 by Alexey Shchepin <alexey@process-one.net>
6
%%%
7
%%%
8
%%% ejabberd, Copyright (C) 2002-2024   ProcessOne
9
%%%
10
%%% This program is free software; you can redistribute it and/or
11
%%% modify it under the terms of the GNU General Public License as
12
%%% published by the Free Software Foundation; either version 2 of the
13
%%% License, or (at your option) any later version.
14
%%%
15
%%% This program is distributed in the hope that it will be useful,
16
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
17
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18
%%% General Public License for more details.
19
%%%
20
%%% You should have received a copy of the GNU General Public License along
21
%%% with this program; if not, write to the Free Software Foundation, Inc.,
22
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
23
%%%
24
%%%----------------------------------------------------------------------
25

26
-module(ejabberd_sql).

-author('alexey@process-one.net').

-behaviour(p1_fsm).

%% Connection worker lifecycle
-export([start_link/2, keep_alive/2]).

%% Running queries, transactions and blocs
-export([sql_query/2, sql_query_t/1, sql_transaction/2, sql_bloc/2,
         abort/1, restart/1,
         sql_query_to_iolist/1, sql_query_to_iolist/2]).

%% Escaping and SQL literal construction
-export([escape/1, standard_escape/1,
         escape_like/1, escape_like_arg/1, escape_like_arg_circumflex/1,
         to_string_literal/2, to_string_literal_t/1,
         to_bool/1, to_list/2, to_array/2]).

%% Erlang term (de)serialisation for storage in SQL columns
-export([encode_term/1, decode_term/1]).

%% Backend-specific configuration helpers
-export([use_new_schema/0, sqlite_db/1, sqlite_file/1, odbcinst_config/0,
         init_mssql/1, parse_mysql_version/2]).

%% p1_fsm (gen_fsm-style) callbacks
-export([init/1, handle_event/3, handle_sync_event/4, handle_info/3,
         terminate/3, print_state/1, code_change/4]).

%% p1_fsm state functions
-export([connecting/2, connecting/3,
         session_established/2, session_established/3]).

%% When ODBC_HAS_TYPES is defined, use the odbc application's
%% connection_reference() type; otherwise fall back to pid().
-ifdef(ODBC_HAS_TYPES).
-type odbc_connection_reference() :: odbc:connection_reference().
-else.
-type odbc_connection_reference() :: pid().
-endif.

-include("logger.hrl").
-include("ejabberd_sql_pt.hrl").
-include("ejabberd_stacktrace.hrl").

%% State of one connection worker.
-record(state,
        {%% backend connection handle (undefined until the first connect)
         db_ref               :: undefined | pid() | odbc_connection_reference(),
         db_type = odbc       :: pgsql | mysql | sqlite | odbc | mssql,
         %% server version, filled in after a successful connect
         db_version           :: undefined | non_neg_integer() | {non_neg_integer(), atom(), non_neg_integer()},
         %% consecutive failed connection attempts (reset to 0 on success)
         reconnect_count = 0  :: non_neg_integer(),
         host                 :: binary(),
         %% requests received while in the `connecting' state
         pending_requests     :: p1_queue:queue(),
         %% NOTE(review): maintained by report_overload/1 (not shown here)
         overload_reported    :: undefined | integer()}).

%% Process-dictionary keys set while an operation runs on the connection.
-define(STATE_KEY, ejabberd_sql_state).
-define(NESTING_KEY, ejabberd_sql_nesting_level).
%% Nesting level value meaning "no transaction open".
-define(TOP_LEVEL_TXN, 0).
-define(MAX_TRANSACTION_RESTARTS, 10).
%% Probe query sent by keep_alive/2.
-define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]).
%% Process-dictionary key prefix recording prepared-statement status per query.
-define(PREPARE_KEY, ejabberd_sql_prepare).

%% Uncomment the DBGFSM define below to start workers with sys trace debugging.
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
-define(FSMOPTS, [{debug, [trace]}]).
-else.
-define(FSMOPTS, []).
-endif.

-type state() :: #state{}.
-type sql_query_simple(T) :: [sql_query(T) | binary()] | binary() |
                             #sql_query{} |
                             fun(() -> T) | fun((atom(), _) -> T).
-type sql_query(T) :: sql_query_simple(T) |
                      [{atom() | {atom(), any()}, sql_query_simple(T)}].
-type sql_query_result(T) :: {updated, non_neg_integer()} |
                             {error, binary() | atom()} |
                             {selected, [binary()], [[binary()]]} |
                             {selected, [any()]} |
                             T.

%%%----------------------------------------------------------------------
%%% API
%%%----------------------------------------------------------------------
-spec start_link(binary(), pos_integer()) -> {ok, pid()} | {error, term()}.
%% Start the I-th connection worker of Host's pool, registered locally under
%% the atom form of get_worker_name(Host, I).
start_link(Host, I) ->
    RegName = binary_to_atom(get_worker_name(Host, I), utf8),
    StartOpts = fsm_limit_opts() ++ ?FSMOPTS,
    p1_fsm:start_link({local, RegName}, ?MODULE, [Host], StartOpts).

-spec sql_query(binary(), sql_query(T)) -> sql_query_result(T).
%% Run a single query for Host. Outside a connection process this goes to a
%% pooled worker; inside a running operation it executes in place (see
%% sql_call/2).
sql_query(Host, Query) ->
    Request = {sql_query, Query},
    sql_call(Host, Request).

%% Run a group of queries atomically on one of Host's connections.
%%
%% Queries is either a list of queries, each executed with sql_query_t/1 (so
%% any error result aborts the transaction), or a fun called with no
%% arguments inside the transaction. Any reply other than {atomic, _} or
%% {aborted, _} (e.g. an {error, _} from the call path) is returned wrapped
%% as {aborted, Reply}.
-spec sql_transaction(binary(), [sql_query(T)] | fun(() -> T)) ->
                             {atomic, T} |
                             {aborted, any()}.
sql_transaction(Host, Queries) when is_list(Queries) ->
    RunAll = fun() -> lists:foreach(fun sql_query_t/1, Queries) end,
    sql_transaction(Host, RunAll);
sql_transaction(Host, F) when is_function(F) ->
    case sql_call(Host, {sql_transaction, F}) of
        {atomic, _} = Committed -> Committed;
        {aborted, _} = Aborted -> Aborted;
        Other -> {aborted, Other}
    end.

%% Run the fun F on one of Host's connections without opening a transaction.
%% The outcome is normalised by execute_bloc/1 ({atomic, Result} or
%% {aborted, Reason}).
sql_bloc(Host, F) ->
    Request = {sql_bloc, F},
    sql_call(Host, Request).

%% Dispatch Msg for Host.
%% If the calling process is not running an SQL operation (no ?STATE_KEY in
%% its dictionary), Msg is sent to a pooled worker together with a deadline
%% of "now + query timeout". Otherwise the caller is the connection process
%% itself, and the operation is executed in place via nested_op/1.
sql_call(Host, Msg) ->
    Timeout = query_timeout(Host),
    case get(?STATE_KEY) of
        undefined ->
            Deadline = current_time() + Timeout,
            sync_send_event(Host, {sql_cmd, Msg, Deadline}, Timeout);
        _RunningState ->
            nested_op(Msg)
    end.

%% Periodic liveness probe for connection worker Proc. It is scheduled from
%% init/1 with timer:apply_interval/4 when sql_keepalive_interval is set.
%%
%% Runs ?KEEPALIVE_QUERY. If the reply is not the expected single row
%% <<"1">>, the failure is logged and the worker is told to stop with
%% force_timeout.
%%
%% force_timeout must be delivered as an asynchronous event. Only
%% session_established/2 has a clause for it (returning {stop, timeout, _}).
%% A synchronous send is dispatched to session_established/3, which just
%% logs "Unexpected call" and never replies. With a synchronous send the
%% broken connection was never closed, and this process blocked for a full
%% query timeout.
%% Returns ok.
keep_alive(Host, Proc) ->
    Timeout = query_timeout(Host),
    Probe = {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}, current_time() + Timeout},
    case sync_send_event(Proc, Probe, Timeout) of
        {selected, _, [[<<"1">>]]} ->
            ok;
        Err ->
            ?ERROR_MSG("Keep alive query failed, closing connection: ~p", [Err]),
            p1_fsm:send_event(Proc, force_timeout)
    end.

%% Synchronously send Msg to a connection worker.
%% When the target is a host name (binary), the host's SQL pool is started
%% if needed and a worker is picked with get_worker/1. A start failure is
%% returned as-is. Any other target is used directly as the process.
%% A p1_fsm call failure (an exit of the form {Reason, {p1_fsm, _, _}},
%% presumably e.g. a timeout) is returned as {error, Reason}.
sync_send_event(Host, Msg, Timeout) when is_binary(Host) ->
    case ejabberd_sql_sup:start(Host) of
        {error, _} = StartError ->
            StartError;
        ok ->
            sync_send_event(get_worker(Host), Msg, Timeout)
    end;
sync_send_event(Proc, Msg, Timeout) ->
    try
        p1_fsm:sync_send_event(Proc, Msg, Timeout)
    catch
        _:{Reason, {p1_fsm, _, _}} ->
            {error, Reason}
    end.

-spec sql_query_t(sql_query(T)) -> sql_query_result(T).
%% Run Query from inside an sql_transaction (on the connection process).
%% An {error, Reason} result aborts the transaction via restart/1. So does a
%% list result whose first tuple tagged `error' is {error, Reason}. Every
%% other result is returned unchanged.
sql_query_t(Query) ->
    case sql_query_internal(Query) of
        {error, Reason} ->
            restart(Reason);
        Results when is_list(Results) ->
            case lists:keyfind(error, 1, Results) of
                {error, Reason} -> restart(Reason);
                _NoPlainError -> Results
            end;
        Other ->
            Other
    end.

%% Abort the running transaction or bloc by exiting the current operation
%% with Reason. The transaction/bloc wrappers convert the exit into
%% {aborted, _}.
abort(Reason) ->
    erlang:exit(Reason).

%% Abort the running transaction by throwing {aborted, Reason}.
%% outer_transaction/3 catches this throw and retries while restarts remain.
restart(Reason) ->
    erlang:throw({aborted, Reason}).

%% Escape one byte for a backslash-escaping SQL string literal.
%% NUL, newline, tab, backspace, carriage return, double quote and backslash
%% get a backslash escape. A single quote is doubled. Any other byte is
%% returned as a one-byte binary.
-spec escape_char(char()) -> binary().
escape_char(Char) ->
    case Char of
        $\000 -> <<"\\0">>;
        $\n -> <<"\\n">>;
        $\t -> <<"\\t">>;
        $\b -> <<"\\b">>;
        $\r -> <<"\\r">>;
        $' -> <<"''">>;
        $" -> <<"\\\"">>;
        $\\ -> <<"\\\\">>;
        _ -> <<Char>>
    end.

%% Escape every byte of S with escape_char/1. to_string_literal/2 uses this
%% for the odbc, mysql and pgsql backends.
-spec escape(binary()) -> binary().
escape(S) ->
    << <<(escape_char(Byte))/binary>> || <<Byte>> <= S >>.

%% Escape S for use as a LIKE pattern inside a backslash-escaping string
%% literal. The wildcards % and _ are backslash-escaped so they match
%% literally. A backslash becomes four backslashes (presumably one level for
%% the LIKE escape and one for the string literal). All other bytes go
%% through escape_char/1.
%% Accepts a binary, or a single byte 0..255.
escape_like(S) when is_binary(S) ->
    << <<(escape_like(Byte))/binary>> || <<Byte>> <= S >>;
escape_like(C) when is_integer(C), C >= 0, C =< 255 ->
    case C of
        $% -> <<"\\%">>;
        $_ -> <<"\\_">>;
        $\\ -> <<"\\\\\\\\">>;
        _ -> escape_char(C)
    end.

%% Backslash-escape the LIKE metacharacters % and _, the backslash itself,
%% and [ and ] (pattern syntax in MSSQL). No other byte is changed.
%% Accepts a binary, or a single byte 0..255.
escape_like_arg(S) when is_binary(S) ->
    << <<(escape_like_arg(Byte))/binary>> || <<Byte>> <= S >>;
escape_like_arg(C) when is_integer(C), C >= 0, C =< 255 ->
    case lists:member(C, "%_\\[]") of
        true -> <<$\\, C>>;
        false -> <<C>>
    end.

%% Like escape_like_arg/1, but uses ^ as the LIKE escape character: each of
%% % _ ^ [ ] is prefixed with ^. No other byte (including backslash) is
%% changed. Accepts a binary, or a single byte 0..255.
escape_like_arg_circumflex(S) when is_binary(S) ->
    << <<(escape_like_arg_circumflex(Byte))/binary>> || <<Byte>> <= S >>;
escape_like_arg_circumflex(C) when is_integer(C), C >= 0, C =< 255 ->
    case lists:member(C, "%_^[]") of
        true -> <<$^, C>>;
        false -> <<C>>
    end.

%% Interpret a database boolean value. Exactly <<"t">>, <<"true">>,
%% <<"1">>, the atom true and the integer 1 are true. Every other term is
%% false.
to_bool(Value) ->
    lists:member(Value, [<<"t">>, <<"true">>, <<"1">>, true, 1]).

%% Build a parenthesised, comma-separated list "(a,b,...)" from Val,
%% applying EscapeFun to each element. Returns an iolist.
to_list(EscapeFun, Val) ->
    Items = [EscapeFun(Item) || Item <- Val],
    [<<"(">>, lists:join(<<",">>, Items), <<")">>].

%% Build a brace-delimited, comma-separated array literal "{a,b,...}" from
%% Val, applying EscapeFun to each element. Returns the pieces as a flat
%% list.
to_array(EscapeFun, Val) ->
    Items = [EscapeFun(Item) || Item <- Val],
    lists:flatten([<<"{">>, lists:join(<<",">>, Items), <<"}">>]).

%% Quote the binary S as an SQL string literal for the backend DBType.
%% mssql and sqlite use standard_escape/1. odbc and mysql use escape/1.
%% pgsql uses escape/1 inside the E'...' escape-string syntax.
%% Any other backend raises function_clause.
to_string_literal(DBType, S) when DBType =:= mssql; DBType =:= sqlite ->
    <<"'", (standard_escape(S))/binary, "'">>;
to_string_literal(pgsql, S) ->
    <<"E'", (escape(S))/binary, "'">>;
to_string_literal(DBType, S) when DBType =:= odbc; DBType =:= mysql ->
    <<"'", (escape(S))/binary, "'">>.

%% to_string_literal/2 for the backend of the connection the caller is
%% running on. Only valid on the connection process, where ?STATE_KEY is
%% set.
to_string_literal_t(S) ->
    DBType = (get(?STATE_KEY))#state.db_type,
    to_string_literal(DBType, S).

%% Serialise Term as Erlang source text and escape it with escape/1 for
%% embedding in an SQL string literal. Very wide paper/ribbon settings keep
%% erl_prettypr from wrapping lines. decode_term/1 parses such text back.
encode_term(Term) ->
    Syntax = erl_syntax:abstract(Term),
    Text = erl_prettypr:format(Syntax, [{paper, 65535}, {ribbon, 65535}]),
    escape(list_to_binary(Text)).

%% Parse Bin, the Erlang source text of a single term without the
%% terminating dot, back into the term. On a scanner or parser error the
%% problem is logged with the offending text and error(badarg) is raised.
decode_term(Bin) ->
    Source = binary_to_list(<<Bin/binary, ".">>),
    try
        {ok, Tokens, _EndLocation} = erl_scan:string(Source),
        {ok, Term} = erl_parse:parse_term(Tokens),
        Term
    catch
        _:{badmatch, {error, {ScanLine, ScanMod, ScanReason}, _}} ->
            ?ERROR_MSG("Corrupted Erlang term in SQL database:~n"
                       "** Scanner error: at line ~B: ~ts~n"
                       "** Term: ~ts",
                       [ScanLine, ScanMod:format_error(ScanReason), Bin]),
            erlang:error(badarg);
        _:{badmatch, {error, {ParseLine, ParseMod, ParseReason}}} ->
            ?ERROR_MSG("Corrupted Erlang term in SQL database:~n"
                       "** Parser error: at line ~B: ~ts~n"
                       "** Term: ~ts",
                       [ParseLine, ParseMod:format_error(ParseReason), Bin]),
            erlang:error(badarg)
    end.

-spec sqlite_db(binary()) -> atom().
%% Name of the SQLite database handle for Host: 'ejabberd_sqlite_<Host>'.
%% The host bytes are taken as Latin-1 characters, the same as
%% list_to_atom(binary_to_list(...)). A new atom is created for each
%% distinct Host.
sqlite_db(Host) ->
    binary_to_atom(<<"ejabberd_sqlite_", Host/binary>>, latin1).

-spec sqlite_file(binary()) -> string().
%% Path of the SQLite database file for Host. If the sql_database option is
%% set, its value is used. Otherwise the path is
%% <cwd>/sqlite/<node>/<Host>/ejabberd.db. If the current directory cannot
%% be determined, the error is logged and the path is returned relative.
sqlite_file(Host) ->
    case ejabberd_option:sql_database(Host) of
        File when File =/= undefined ->
            binary_to_list(File);
        undefined ->
            Parts = ["sqlite", atom_to_list(node()),
                     binary_to_list(Host), "ejabberd.db"],
            case file:get_cwd() of
                {ok, Dir} ->
                    filename:join([Dir | Parts]);
                {error, Why} ->
                    ?ERROR_MSG("Failed to get current directory: ~ts",
                               [file:format_error(Why)]),
                    filename:join(Parts)
            end
    end.

%% Whether the new_sql_schema option is enabled.
use_new_schema() ->
    ejabberd_option:new_sql_schema().

-spec get_worker(binary()) -> atom().
%% Pick one of Host's pooled workers in round-robin order (index 1..PoolSize).
%% The registered names already exist as atoms (start_link/2 created them),
%% so binary_to_existing_atom/2 is used.
get_worker(Host) ->
    Index = 1 + p1_rand:round_robin(ejabberd_option:sql_pool_size(Host)),
    binary_to_existing_atom(get_worker_name(Host, Index), utf8).

-spec get_worker_name(binary(), pos_integer()) -> binary().
%% Name of worker I of Host, e.g. <<"ejabberd_sql_example.com_2">>.
get_worker_name(Host, I) ->
    Suffix = integer_to_binary(I),
    <<"ejabberd_sql_", Host/binary, "_", Suffix/binary>>.

%%%----------------------------------------------------------------------
%%% Callback functions from gen_fsm
%%%----------------------------------------------------------------------
%% Initialise a connection worker for Host. The worker traps exits. If
%% sql_keepalive_interval is set, keep_alive/2 is scheduled for this process
%% at that interval. The actual connection is not opened here: a `connect'
%% event is sent to self and handled in the `connecting' state.
init([Host]) ->
    process_flag(trap_exit, true),
    case ejabberd_option:sql_keepalive_interval(Host) of
        undefined ->
            ok;
        Interval ->
            timer:apply_interval(Interval, ?MODULE, keep_alive, [Host, self()])
    end,
    [DBType | _] = db_opts(Host),
    p1_fsm:send_event(self(), connect),
    QueueType = ejabberd_option:sql_queue_type(Host),
    Pending = p1_queue:new(QueueType, max_fsm_queue()),
    {ok, connecting,
     #state{db_type = DBType, host = Host, pending_requests = Pending}}.

%% `connecting' state, asynchronous events.
%% A `connect' event (from init/1 or the retry timer of handle_reconnect/2)
%% opens a backend connection. On success:
%%  * the connection process is linked;
%%  * cached prepared-statement markers from the previous connection are
%%    erased;
%%  * requests queued while connecting are re-sent to self as events;
%%  * the server version is fetched;
%%  * the worker moves to session_established.
%% Any failure schedules a reconnect. Other events are logged and ignored.
connecting(connect, #state{host = Host} = State) ->
    ConnectResult =
        case db_opts(Host) of
            [mysql | Args] -> apply(fun mysql_connect/8, Args);
            [pgsql | Args] -> apply(fun pgsql_connect/8, Args);
            [sqlite | Args] -> apply(fun sqlite_connect/1, Args);
            [Type | Args] when Type =:= mssql; Type =:= odbc ->
                apply(fun odbc_connect/2, Args)
        end,
    case ConnectResult of
        {error, Reason} ->
            handle_reconnect(Reason, State);
        {ok, Ref} ->
            try link(Ref) of
                _ ->
                    lists:foreach(fun erlang:erase/1,
                                  [Key || {{?PREPARE_KEY, _} = Key, _} <- get()]),
                    Pending = p1_queue:dropwhile(
                                fun(Req) ->
                                        p1_fsm:send_event(self(), Req),
                                        true
                                end, State#state.pending_requests),
                    NewState = get_db_version(State#state{db_ref = Ref,
                                                          pending_requests = Pending}),
                    {next_state, session_established,
                     NewState#state{reconnect_count = 0}}
            catch
                _:Reason ->
                    handle_reconnect(Reason, State)
            end
    end;
connecting(Event, State) ->
    ?WARNING_MSG("Unexpected event in 'connecting': ~p",
                 [Event]),
    {next_state, connecting, State}.

%% `connecting' state, synchronous calls.
%% A keep-alive probe is answered immediately with an error. Any other SQL
%% command is queued until the connection is up. If the queue is full, all
%% queued requests are answered with an error first and the new request is
%% then queued. Unexpected calls are logged and get no reply.
connecting({sql_cmd, {sql_query, ?KEEPALIVE_QUERY}, Timestamp}, From, State) ->
    reply(From, {error, <<"SQL connection failed">>}, Timestamp),
    {next_state, connecting, State};
connecting({sql_cmd, Command, Timestamp} = Req, From, State) ->
    ?DEBUG("Queuing pending request while connecting:~n\t~p",
           [Req]),
    Entry = {sql_cmd, Command, From, Timestamp},
    Queue = State#state.pending_requests,
    NewQueue =
        try p1_queue:in(Entry, Queue)
        catch error:full ->
                Err = <<"SQL request queue is overfilled">>,
                ?ERROR_MSG("~ts, bouncing all pending requests", [Err]),
                Drained = p1_queue:dropwhile(
                            fun({sql_cmd, _, To, TS}) ->
                                    reply(To, {error, Err}, TS),
                                    true
                            end, Queue),
                p1_queue:in(Entry, Drained)
        end,
    {next_state, connecting, State#state{pending_requests = NewQueue}};
connecting(Request, {Who, _Ref}, State) ->
    ?WARNING_MSG("Unexpected call ~p from ~p in 'connecting'",
                 [Request, Who]),
    {next_state, connecting, State}.

%% `session_established' state, synchronous calls: SQL commands are executed
%% by run_sql_cmd/4. Anything else is logged and gets no reply.
session_established({sql_cmd, Command, Timestamp}, From, State) ->
    run_sql_cmd(Command, From, State, Timestamp);
session_established(Request, {Who, _Ref}, State) ->
    ?WARNING_MSG("Unexpected call ~p from ~p in 'session_established'",
                 [Request, Who]),
    {next_state, session_established, State}.

%% `session_established' state, asynchronous events:
%%  * {sql_cmd, Command, From, Timestamp}: a request that was queued while
%%    connecting and re-sent by connecting/2 (it carries its own From);
%%  * force_timeout: stop the worker with reason `timeout';
%%  * anything else is logged and ignored.
session_established({sql_cmd, Command, From, Timestamp}, State) ->
    run_sql_cmd(Command, From, State, Timestamp);
session_established(force_timeout, State) ->
    {stop, timeout, State};
session_established(Event, State) ->
    ?WARNING_MSG("Unexpected event in 'session_established': ~p",
                 [Event]),
    {next_state, session_established, State}.

%% All-state events are ignored; the FSM stays in its current state.
handle_event(_Ignored, StateName, State) ->
    {next_state, StateName, State}.

%% All-state synchronous events are not supported; reply {error, badarg}.
handle_sync_event(_Ignored, _Caller, StateName, State) ->
    {reply, {error, badarg}, StateName, State}.

%% Hot code upgrade: the state is kept as-is.
code_change(_PreviousVsn, StateName, State, _Extra) ->
    {ok, StateName, State}.

%% Raw messages. An 'EXIT' from a linked process is ignored while
%% connecting. In any other state it means the connection was lost, and
%% handle_reconnect/2 schedules a new attempt. Anything else is logged.
handle_info({'EXIT', _Pid, Reason}, StateName, State) ->
    case StateName of
        connecting -> {next_state, connecting, State};
        _ -> handle_reconnect(Reason, State)
    end;
handle_info(Info, StateName, State) ->
    ?WARNING_MSG("Unexpected info in ~p: ~p",
                 [StateName, Info]),
    {next_state, StateName, State}.

%% Best-effort cleanup when the worker stops. A MySQL connection is stopped
%% and a SQLite database is closed. Other backends are not touched here.
%% Errors are swallowed by catch (e.g. when no connection was established).
%% NOTE(review): unlike handle_reconnect/2, a pgsql connection is not
%% terminated here — confirm that the link alone is enough to bring it down.
terminate(_Reason, _StateName, State) ->
    DBType = State#state.db_type,
    if
        DBType =:= mysql ->
            catch p1_mysql_conn:stop(State#state.db_ref);
        DBType =:= sqlite ->
            catch sqlite3:close(sqlite_db(State#state.host));
        true ->
            ok
    end,
    ok.

%%----------------------------------------------------------------------
%% Func: print_state/1
%% Purpose: p1_fsm callback used when the state is written to the error
%%          log; the state is printed unchanged (nothing is redacted).
%% Returns: the state itself
%%----------------------------------------------------------------------
print_state(StateToPrint) ->
    StateToPrint.

%%%----------------------------------------------------------------------
%%% Internal functions
%%%----------------------------------------------------------------------
%% Tear down the current backend connection after a failure, schedule a new
%% `connect' event and move to `connecting'.
%% The delay is sql_start_interval, capped at 5000 (presumably ms) for the
%% first retry (reconnect_count = 0). The failed-attempt counter is
%% incremented. Cleanup errors are ignored.
handle_reconnect(Reason, #state{host = Host, db_type = DBType, db_ref = DBRef,
                                reconnect_count = Attempts} = State) ->
    Configured = ejabberd_option:sql_start_interval(Host),
    Delay = case Attempts of
                0 -> erlang:min(5000, Configured);
                _ -> Configured
            end,
    ?WARNING_MSG("~p connection failed:~n"
                 "** Reason: ~p~n"
                 "** Retry after: ~B seconds",
                 [DBType, Reason,
                  Delay div 1000]),
    case DBType of
        mysql -> catch p1_mysql_conn:stop(DBRef);
        sqlite -> catch sqlite3:close(sqlite_db(Host));
        pgsql -> catch pgsql:terminate(DBRef);
        _ -> ok
    end,
    p1_fsm:send_event_after(Delay, connect),
    {next_state, connecting, State#state{reconnect_count = Attempts + 1}}.

%% Execute a client command on this connection.
%% If the deadline Timestamp has already passed, nothing is run and no reply
%% is sent; the state goes through report_overload/1. Otherwise a pending
%% 'EXIT' from a linked process is checked first (without blocking). If one
%% is found, the command is re-queued and a reconnect starts. If not, the
%% process-dictionary context is set up and the command runs via
%% outer_op/1.
run_sql_cmd(Command, From, State, Timestamp) ->
    case current_time() < Timestamp of
        false ->
            {next_state, session_established, report_overload(State)};
        true ->
            receive
                {'EXIT', _Pid, Reason} ->
                    Requeued = p1_queue:in({sql_cmd, Command, From, Timestamp},
                                           State#state.pending_requests),
                    handle_reconnect(Reason,
                                     State#state{pending_requests = Requeued})
            after 0 ->
                    put(?NESTING_KEY, ?TOP_LEVEL_TXN),
                    put(?STATE_KEY, State),
                    abort_on_driver_error(outer_op(Command), From, Timestamp)
            end
    end.

%% @doc Run a top-level operation on the connection process. It is called
%% from run_sql_cmd/4 for requests that arrive through the FSM. Nested
%% requests go through nested_op/1 instead.
%% NOTE(review): the spec is narrower than actual use — the operation's
%% payload may be a fun or a query term, not only a binary.
-spec outer_op(Op::{atom(), binary()}) ->
    {error, Reason::binary()} | {aborted, Reason::binary()} | {atomic, Result::any()}.
outer_op({sql_bloc, F}) ->
    execute_bloc(F);
outer_op({sql_transaction, F}) ->
    outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, <<"">>);
outer_op({sql_query, Query}) ->
    sql_query_internal(Query).

%% Run an operation requested (via sql_query/sql_transaction/sql_bloc) by
%% code that is already executing on the connection process. At the top
%% nesting level a transaction request opens a real outer transaction.
%% Deeper requests become nested transactions, which are never retried.
nested_op({sql_query, Query}) ->
    sql_query_internal(Query);
nested_op({sql_transaction, F}) ->
    case get(?NESTING_KEY) of
        ?TOP_LEVEL_TXN ->
            outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, <<"">>);
        _Deeper ->
            inner_transaction(F)
    end;
nested_op({sql_bloc, F}) ->
    execute_bloc(F).

%% Run F as a nested transaction inside an open outer transaction. Nested
%% transactions are never retried; only outer ones are.
%% The nesting level is raised for the duration of F and then restored.
%% A result of {aborted, _}, {'EXIT', _} or {atomic, _} (including one
%% produced by `catch') is returned unchanged. Any other result Res becomes
%% {atomic, Res}. Calling this at the top level is an implementation error:
%% it is logged with a backtrace and the process exits.
inner_transaction(F) ->
    Level = get(?NESTING_KEY),
    case Level of
        ?TOP_LEVEL_TXN ->
            {backtrace, Trace} = process_info(self(), backtrace),
            ?ERROR_MSG("Inner transaction called at outer txn "
                       "level. Trace: ~ts",
                       [Trace]),
            erlang:exit(implementation_faulty);
        _ ->
            ok
    end,
    put(?NESTING_KEY, Level + 1),
    Outcome = (catch F()),
    put(?NESTING_KEY, Level),
    case Outcome of
        {aborted, _} -> Outcome;
        {'EXIT', _} -> Outcome;
        {atomic, _} -> Outcome;
        Res -> {atomic, Res}
    end.

%% Run F inside BEGIN/COMMIT at the top nesting level.
%% A BEGIN failure goes to maybe_restart_transaction/4 without a rollback.
%% A COMMIT failure calls restart/1, which is outside the try protection, so
%% the throw reaches the caller. Inside F:
%%  * a throw of {aborted, Reason} (see restart/1) is retried, with a
%%    rollback, while NRestarts > 0;
%%  * when NRestarts is 0 the same throw is logged before the final
%%    rollback;
%%  * an exit from F is rolled back but never retried.
%% The third argument (the previous abort reason) is unused.
%% Calling this below the top level is an implementation error.
outer_transaction(F, NRestarts, _PreviousReason) ->
    Level = get(?NESTING_KEY),
    case Level of
        ?TOP_LEVEL_TXN ->
            ok;
        _ ->
            {backtrace, Trace} = process_info(self(), backtrace),
            ?ERROR_MSG("Outer transaction called at inner txn "
                       "level. Trace: ~ts",
                       [Trace]),
            erlang:exit(implementation_faulty)
    end,
    case sql_begin() of
        {error, Reason} ->
            maybe_restart_transaction(F, NRestarts, Reason, false);
        _ ->
            put(?NESTING_KEY, Level + 1),
            try F() of
                Res ->
                    case sql_commit() of
                        {error, Reason} -> restart(Reason);
                        _ -> {atomic, Res}
                    end
            catch
                ?EX_RULE(throw, {aborted, Reason}, _) when NRestarts > 0 ->
                    maybe_restart_transaction(F, NRestarts, Reason, true);
                ?EX_RULE(throw, {aborted, Reason}, Stack) when NRestarts =:= 0 ->
                    StackTrace = ?EX_STACK(Stack),
                    ?ERROR_MSG("SQL transaction restarts exceeded~n** "
                               "Restarts: ~p~n** Last abort reason: "
                               "~p~n** Stacktrace: ~p~n** When State "
                               "== ~p",
                               [?MAX_TRANSACTION_RESTARTS, Reason,
                                StackTrace, get(?STATE_KEY)]),
                    maybe_restart_transaction(F, NRestarts, Reason, true);
                ?EX_RULE(exit, Reason, _) ->
                    maybe_restart_transaction(F, 0, Reason, true)
            end
    end.

%% Decide what happens after a failed transaction attempt.
%%  1. If Reason requires a driver restart, give up with {aborted, Reason}.
%%  2. Otherwise roll back if DoRollback is true. A rollback error that
%%     requires a driver restart gives up with {aborted, RollbackReason}.
%%  3. If nothing forced giving up, retry via outer_transaction/3 (at the
%%     top nesting level) while NRestarts > 0, otherwise return
%%     {aborted, Reason}.
maybe_restart_transaction(F, NRestarts, Reason, DoRollback) ->
    Verdict =
        case {driver_restart_required(Reason), DoRollback} of
            {true, _} ->
                {aborted, Reason};
            {_, true} ->
                case sql_rollback() of
                    {error, RollbackReason} ->
                        case driver_restart_required(RollbackReason) of
                            true -> {aborted, RollbackReason};
                            _ -> continue
                        end;
                    _ ->
                        continue
                end;
            _ ->
                continue
        end,
    case Verdict of
        continue when NRestarts > 0 ->
            put(?NESTING_KEY, ?TOP_LEVEL_TXN),
            outer_transaction(F, NRestarts - 1, Reason);
        continue ->
            {aborted, Reason};
        GiveUp ->
            GiveUp
    end.

%% Run F outside any transaction and normalise the outcome. {aborted, _} is
%% passed through. A crash (an {'EXIT', Reason} from catch) becomes
%% {aborted, Reason}. Any other result R becomes {atomic, R}.
execute_bloc(F) ->
    Outcome = (catch F()),
    case Outcome of
        {aborted, _} -> Outcome;
        {'EXIT', Reason} -> {aborted, Reason};
        Result -> {atomic, Result}
    end.

%% Run a query fun. A two-argument fun receives the current connection's
%% db_type and db_version (read from ?STATE_KEY). A zero-argument fun is
%% simply called.
execute_fun(F) when is_function(F, 2) ->
    Current = get(?STATE_KEY),
    F(Current#state.db_type, Current#state.db_version);
execute_fun(F) when is_function(F, 0) ->
    F().

%% Execute a query on this connection process and pass the result through
%% check_error/2. Must run on the connection process (?STATE_KEY set).
%% Accepted query forms:
%%  * [{Selector, Query} | _] - per-backend alternatives; select_sql_query/2
%%    picks the one matching the current db_type/db_version;
%%  * #sql_query{}            - a compiled query. pgsql uses a cached
%%    PREPARE when prepared statements are enabled. mysql uses prepared
%%    execution unless flags =:= 1 or prepared statements are disabled.
%%    Driver exits and internal errors become {error, _};
%%  * a fun                   - run via execute_fun/1. {aborted, R} and
%%    crashes become {error, R};
%%  * anything else           - raw SQL sent directly to the driver.
sql_query_internal([{_, _} | _] = Queries) ->
    case select_sql_query(Queries, get(?STATE_KEY)) of
        undefined ->
            {error, <<"no matching query for the current DBMS found">>};
        Selected ->
            sql_query_internal(Selected)
    end;
sql_query_internal(#sql_query{} = Query) ->
    State = get(?STATE_KEY),
    Result =
        try
            case State#state.db_type of
                odbc ->
                    generic_sql_query(Query);
                mssql ->
                    mssql_sql_query(Query);
                pgsql ->
                    %% The process dictionary remembers, per query hash,
                    %% whether PREPARE succeeded (prepared) or should not be
                    %% attempted (ignore).
                    Key = {?PREPARE_KEY, Query#sql_query.hash},
                    case get(Key) of
                        undefined ->
                            case ejabberd_option:sql_prepared_statements(State#state.host) of
                                false ->
                                    put(Key, ignore);
                                true ->
                                    case pgsql_prepare(Query, State) of
                                        {ok, _, _, _} ->
                                            put(Key, prepared);
                                        {error, Error} ->
                                            ?ERROR_MSG(
                                               "PREPARE failed for SQL query "
                                               "at ~p: ~p",
                                               [Query#sql_query.loc, Error]),
                                            put(Key, ignore)
                                    end
                            end;
                        _AlreadyDecided ->
                            ok
                    end,
                    case get(Key) of
                        prepared -> pgsql_execute_sql_query(Query, State);
                        _ -> pgsql_sql_query(Query)
                    end;
                mysql ->
                    Flags = Query#sql_query.flags,
                    UsePrepared = ejabberd_option:sql_prepared_statements(State#state.host),
                    if
                        Flags =:= 1; UsePrepared =:= false ->
                            generic_sql_query(Query);
                        true ->
                            mysql_prepared_execute(Query, State)
                    end;
                sqlite ->
                    sqlite_sql_query(Query)
            end
        catch exit:{timeout, _} ->
                {error, <<"timed out">>};
              exit:{killed, _} ->
                {error, <<"killed">>};
              exit:{normal, _} ->
                {error, <<"terminated unexpectedly">>};
              exit:{shutdown, _} ->
                {error, <<"shutdown">>};
              ?EX_RULE(Class, Reason, Stack) ->
                StackTrace = ?EX_STACK(Stack),
                ?ERROR_MSG("Internal error while processing SQL query:~n** ~ts",
                           [misc:format_exception(2, Class, Reason, StackTrace)]),
                {error, <<"internal error">>}
        end,
    check_error(Result, Query);
sql_query_internal(F) when is_function(F) ->
    Outcome = (catch execute_fun(F)),
    case Outcome of
        {aborted, Reason} -> {error, Reason};
        {'EXIT', Reason} -> {error, Reason};
        _ -> Outcome
    end;
sql_query_internal(Query) ->
    State = get(?STATE_KEY),
    ?DEBUG("SQL: \"~ts\"", [Query]),
    QueryTimeout = query_timeout(State#state.host),
    Result =
        case State#state.db_type of
            Type when Type =:= odbc; Type =:= mssql ->
                to_odbc(odbc:sql_query(State#state.db_ref, [Query],
                                       QueryTimeout - 1000));
            pgsql ->
                pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query,
                                           QueryTimeout - 1000));
            mysql ->
                mysql_to_odbc(p1_mysql_conn:squery(State#state.db_ref,
                                                   [Query], self(),
                                                   [{timeout, QueryTimeout - 1000},
                                                    {result_type, binary}]));
            sqlite ->
                Host = State#state.host,
                sqlite_to_odbc(Host, sqlite3:sql_exec(sqlite_db(Host), Query))
        end,
    check_error(Result, Query).

%% Choose, from a list of {Selector, Query} alternatives, the variant that
%% fits the database type and version recorded in the connection State.
select_sql_query(Queries, State) ->
    DbType = State#state.db_type,
    DbVersion = State#state.db_version,
    select_sql_query(Queries, DbType, DbVersion, undefined).
754

755
%% Walk the alternatives in order and return the first usable query:
%%   {any, Q}             - always matches
%%   {Type, Q}            - matches the connection's db type
%%   {{Type, MinVer}, Q}  - matches when the server version is >= MinVer;
%%                          when the version is unknown (undefined) Q is only
%%                          remembered as a fallback and the scan continues
%% Entries for other types are skipped. When the list is exhausted the last
%% remembered fallback (initially undefined) is returned.
select_sql_query([], _Type, _Version, Fallback) ->
    Fallback;
select_sql_query([{any, Query} | _], _Type, _Version, _Fallback) ->
    Query;
select_sql_query([{Type, Query} | _], Type, _Version, _Fallback) ->
    Query;
select_sql_query([{{Type, MinVersion}, Query} | Rest], Type, Version, Fallback) ->
    if
        Version == undefined ->
            select_sql_query(Rest, Type, undefined, Query);
        Version >= MinVersion ->
            Query;
        true ->
            select_sql_query(Rest, Type, Version, Fallback)
    end;
select_sql_query([{_, _} | Rest], Type, Version, Fallback) ->
    select_sql_query(Rest, Type, Version, Fallback).
774

775
%% Render a parsed #sql_query{} to literal SQL with the generic escaping
%% rules, run it on the current connection and decode the resulting rows.
generic_sql_query(SQLQuery) ->
    SqlText = generic_sql_query_format(SQLQuery),
    RawResult = sql_query_internal(SqlText),
    sql_query_format_res(RawResult, SQLQuery).
779

780
%% Produce the SQL text of SQLQuery: its argument values are escaped with
%% generic_escape/0 and then spliced in by the query's format_query fun.
generic_sql_query_format(SQLQuery) ->
    ArgsFun = SQLQuery#sql_query.args,
    FormatFun = SQLQuery#sql_query.format_query,
    FormatFun(ArgsFun(generic_escape())).
783

784
%% Database-agnostic literal escaping: strings are single-quoted after
%% escape/1, integers go through misc:i2l/1, booleans become 1/0 and no
%% LIKE escape clause is emitted.
generic_escape() ->
    Quote = fun(Str) -> <<"'", (escape(Str))/binary, "'">> end,
    #sql_escape{string = Quote,
                in_array_string = Quote,
                integer = fun(N) -> misc:i2l(N) end,
                boolean = fun(true) -> <<"1">>;
                             (false) -> <<"0">>
                          end,
                like_escape = fun() -> <<"">> end}.
793

794
%% Render a parsed #sql_query{} to literal PostgreSQL SQL, execute it as a
%% simple query and decode the rows.
pgsql_sql_query(SQLQuery) ->
    SqlText = pgsql_sql_query_format(SQLQuery),
    RawResult = sql_query_internal(SqlText),
    sql_query_format_res(RawResult, SQLQuery).
798

799
%% Produce PostgreSQL SQL text for SQLQuery using pgsql_escape/0.
pgsql_sql_query_format(SQLQuery) ->
    ArgsFun = SQLQuery#sql_query.args,
    FormatFun = SQLQuery#sql_query.format_query,
    FormatFun(ArgsFun(pgsql_escape())).
802

803
%% PostgreSQL literal escaping: strings become E'...' literals (contents via
%% escape/1), booleans 't'/'f', and LIKE patterns get a backslash ESCAPE
%% clause.
pgsql_escape() ->
    EQuote = fun(Str) -> <<"E'", (escape(Str))/binary, "'">> end,
    #sql_escape{string = EQuote,
                in_array_string = EQuote,
                integer = fun(N) -> misc:i2l(N) end,
                boolean = fun(true) -> <<"'t'">>;
                             (false) -> <<"'f'">>
                          end,
                like_escape = fun() -> <<"ESCAPE E'\\\\'">> end}.
812

813
%% Render a parsed #sql_query{} with SQLite escaping, execute it and decode
%% the rows.
sqlite_sql_query(SQLQuery) ->
    RawResult = sql_query_internal(sqlite_sql_query_format(SQLQuery)),
    sql_query_format_res(RawResult, SQLQuery).
817

818
%% Produce SQL text for SQLQuery using sqlite_escape/0.
sqlite_sql_query_format(SQLQuery) ->
    ArgsFun = SQLQuery#sql_query.args,
    FormatFun = SQLQuery#sql_query.format_query,
    FormatFun(ArgsFun(sqlite_escape())).
821

822
%% SQLite literal escaping (also reached by MSSQL through mssql_sql_query/1):
%% strings are single-quoted with embedded quotes doubled by
%% standard_escape/1, booleans become 1/0 and LIKE patterns use a
%% backslash ESCAPE clause.
sqlite_escape() ->
    Quote = fun(Str) -> <<"'", (standard_escape(Str))/binary, "'">> end,
    #sql_escape{string = Quote,
                in_array_string = Quote,
                integer = fun(N) -> misc:i2l(N) end,
                boolean = fun(true) -> <<"1">>;
                             (false) -> <<"0">>
                          end,
                like_escape = fun() -> <<"ESCAPE '\\'">> end}.
831

832
%% Escape a binary for use inside a single-quoted SQL literal by doubling
%% every ' byte; all other bytes are copied unchanged.
standard_escape(S) ->
    << <<(if C =:= $' -> <<"''">>; true -> <<C>> end)/binary>> || <<C>> <= S >>.
837

838
%% MSSQL literal queries use the same '' quoting rules as SQLite, so they are
%% rendered and executed through the SQLite path.
mssql_sql_query(Query) ->
    sqlite_sql_query(Query).
840

841
%% Prepare SQLQuery as a server-side PostgreSQL statement named after the
%% query hash. Each bound argument becomes a positional placeholder
%% ($1, $2, ...); a LIKE escape marker is rendered inline as
%% ESCAPE E'\\' and does not consume a placeholder number.
%% Returns whatever pgsql:prepare/3 returns.
pgsql_prepare(SQLQuery, State) ->
    Escape = #sql_escape{_ = fun(_) -> arg end,
                         like_escape = fun() -> escape end},
    {RArgs, _} =
        lists:foldl(
          fun(escape, {Acc, I}) ->
                  {[<<"ESCAPE E'\\\\'">> | Acc], I};
             (Arg, {Acc, I}) when Arg == arg; is_list(Arg) ->
                  {[<<$$, (integer_to_binary(I))/binary>> | Acc], I + 1}
          end, {[], 1}, (SQLQuery#sql_query.args)(Escape)),
    Args = lists:reverse(RArgs),
    Query = (SQLQuery#sql_query.format_query)(Args),
    pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query).
858

859
%% Argument encoding for executing a prepared PostgreSQL statement: values
%% are sent as parameters, so strings pass through unquoted, integers become
%% iolists, booleans "1"/"0", and the LIKE escape marker turns into the atom
%% ignore, which pgsql_execute_sql_query/2 filters out.
pgsql_execute_escape() ->
    #sql_escape{like_escape = fun() -> ignore end,
                in_array_string = fun(Str) -> <<"\"", (escape(Str))/binary, "\"">> end,
                boolean = fun(true) -> "1";
                             (false) -> "0"
                          end,
                integer = fun(N) -> [misc:i2l(N)] end,
                string = fun(Str) -> Str end}.
868

869
%% Execute the statement previously prepared by pgsql_prepare/2 (named by
%% the query hash) with the query's arguments bound as parameters, then
%% convert and decode the result.
pgsql_execute_sql_query(SQLQuery, State) ->
    Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()),
    %% LIKE escape markers have no placeholder in the prepared statement
    Params = [A || A <- Args, A =/= ignore],
    ExecuteRes = pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Params),
    sql_query_format_res(pgsql_execute_to_odbc(ExecuteRes), SQLQuery).
879

880
%% Run Query as a MySQL prepared statement keyed by its hash. Argument values
%% and their type tags are extracted with two escape records (LIKE escape
%% markers are dropped from both). The SQL text with ? placeholders is built
%% lazily by PrepareFn, which is handed to the driver (presumably only called
%% when the statement is not yet prepared on this connection - verify in
%% p1_mysql_conn).
mysql_prepared_execute(#sql_query{hash = Hash, args = ArgsFun,
                                  format_query = FormatFun} = Query, State) ->
    ValueEsc = #sql_escape{like_escape = fun() -> ignore end,
                           _ = fun(V) -> V end},
    TypeEsc = #sql_escape{string = fun(_) -> string end,
                          integer = fun(_) -> integer end,
                          boolean = fun(_) -> bool end,
                          in_array_string = fun(_) -> string end,
                          like_escape = fun() -> ignore end},
    NotIgnored = fun(ignore) -> false;
                    (_) -> true
                 end,
    Values = lists:filter(NotIgnored, ArgsFun(ValueEsc)),
    Types = lists:filter(NotIgnored, ArgsFun(TypeEsc)),
    PrepareFn =
        fun() ->
                PlaceholderEsc = #sql_escape{like_escape = fun() -> <<>> end,
                                             _ = fun(_) -> <<"?">> end},
                FormatFun(ArgsFun(PlaceholderEsc))
        end,
    Timeout = query_timeout(State#state.host) - 1000,
    RawRes = p1_mysql_conn:prepared_query(State#state.db_ref, PrepareFn, Hash,
                                          Values, Types, self(),
                                          [{timeout, Timeout}]),
    sql_query_format_res(mysql_to_odbc(RawRes), Query).
898

899
%% Decode every row of a {selected, Columns, Rows} result with the query's
%% format_res fun, giving {selected, DecodedRows}. A row whose decoder raises
%% is logged and dropped rather than failing the whole result. Any other
%% result is returned unchanged.
sql_query_format_res({selected, _, Rows}, SQLQuery) ->
    DecodeRow =
        fun(Row) ->
                try [(SQLQuery#sql_query.format_res)(Row)]
                catch
                    ?EX_RULE(Class, Reason, Stack) ->
                        StackTrace = ?EX_STACK(Stack),
                        ?ERROR_MSG("Error while processing SQL query result:~n"
                                   "** Row: ~p~n** ~ts",
                                   [Row,
                                    misc:format_exception(2, Class, Reason, StackTrace)]),
                        []
                end
        end,
    {selected, lists:flatmap(DecodeRow, Rows)};
sql_query_format_res(Other, _SQLQuery) ->
    Other.
918

919
%% Render a parsed query to SQL text using the database-agnostic escaping.
sql_query_to_iolist(Query) ->
    generic_sql_query_format(Query).
921

922
%% Render a parsed query to SQL text for DbType: SQLite gets its own
%% escaping, every other type uses the generic rules.
sql_query_to_iolist(DbType, SQLQuery) ->
    case DbType of
        sqlite -> sqlite_sql_query_format(SQLQuery);
        _Other -> generic_sql_query_format(SQLQuery)
    end.
926

927
%% Start a transaction (MSSQL needs the explicit "transaction" keyword).
sql_begin() ->
    Variants = [{mssql, [<<"begin transaction;">>]},
                {any, [<<"begin;">>]}],
    sql_query_internal(Variants).
931

932
%% Commit the current transaction (MSSQL spelling differs).
sql_commit() ->
    Variants = [{mssql, [<<"commit transaction;">>]},
                {any, [<<"commit;">>]}],
    sql_query_internal(Variants).
936

937
%% Roll back the current transaction (MSSQL spelling differs).
sql_rollback() ->
    Variants = [{mssql, [<<"rollback transaction;">>]},
                {any, [<<"rollback;">>]}],
    sql_query_internal(Variants).
941

942
%% True for driver error messages that indicate the connection itself is no
%% longer usable and must be re-established.
driver_restart_required(Msg) ->
    case Msg of
        <<"query timed out">> -> true;
        <<"connection closed">> -> true;
        <<"Failed sending data on socket", _/binary>> -> true;
        <<"SQL connection failed">> -> true;
        <<"Communication link failure">> -> true;
        _ -> false
    end.
948

949
%% Reply to the caller with the driver result (see reply/3), then build the
%% p1_fsm callback return value: errors whose message means the connection
%% is broken go through handle_reconnect/2; everything else keeps the FSM in
%% session_established with the state stored under ?STATE_KEY.
abort_on_driver_error({Tag, Msg} = Reply, From, Timestamp)
  when Tag == error; Tag == aborted ->
    reply(From, Reply, Timestamp),
    case driver_restart_required(Msg) of
        false -> {next_state, session_established, get(?STATE_KEY)};
        true -> handle_reconnect(Msg, get(?STATE_KEY))
    end;
abort_on_driver_error(Reply, From, Timestamp) ->
    reply(From, Reply, Timestamp),
    {next_state, session_established, get(?STATE_KEY)}.
961

962
%% Log a pool-overload error at most once every 30 seconds.
%% overload_reported holds the current_time/0 value of the last report
%% (undefined if none yet); it is updated only when a report is emitted.
-spec report_overload(state()) -> state().
report_overload(#state{overload_reported = PrevTime} = State) ->
    Now = current_time(),
    case PrevTime == undefined orelse (Now - PrevTime) > timer:seconds(30) of
        true ->
            ?ERROR_MSG("SQL connection pool is overloaded, "
                       "discarding stale requests", []),
            %% reuse the timestamp used for the check instead of sampling again
            State#state{overload_reported = Now};
        false ->
            State
    end.
973

974
%% Send Reply to From unless the deadline Timestamp (in current_time/0 units)
%% has already passed; a late reply is silently dropped (presumably the
%% caller has given up by then) and ok is returned.
-spec reply({pid(), term()}, term(), integer()) -> term().
reply(From, Reply, Timestamp) ->
    Expired = current_time() >= Timestamp,
    if
        Expired -> ok;
        true -> p1_fsm:reply(From, Reply)
    end.
980

981
%% == pure ODBC code
982

983
%% part of init/1
%% Start the odbc application and open a connection to SQLServer (a binary
%% DSN / connection string). Extended errors are enabled because
%% extended_error/1 relies on them; rows come back as lists of binaries.
odbc_connect(SQLServer, Timeout) ->
    ejabberd:start_app(odbc),
    Options = [{scrollable_cursors, off},
               {extended_errors, on},
               {tuple_row, off},
               {timeout, Timeout},
               {binary_strings, on}],
    odbc:connect(binary_to_list(SQLServer), Options).
993

994
%% == Native SQLite code
995

996
%% part of init/1
%% Open the SQLite database registered for Host, creating the parent
%% directory of its file first. If the database process is already running
%% its Ref is reused; foreign-key enforcement is switched on only when this
%% call opens the database. Directory or open errors are returned as-is.
sqlite_connect(Host) ->
    File = sqlite_file(Host),
    case filelib:ensure_dir(File) of
        ok ->
            case sqlite3:open(sqlite_db(Host), [{file, File}]) of
                {ok, Ref} ->
                    %% best effort: the pragma result is not checked
                    sqlite3:sql_exec(sqlite_db(Host), "pragma foreign_keys = on"),
                    {ok, Ref};
                {error, {already_started, Ref}} ->
                    {ok, Ref};
                {error, _} = OpenError ->
                    OpenError
            end;
        DirError ->
            DirError
    end.
1016

1017
%% Convert an sqlite3:sql_exec/2 result to the ODBC-style result format:
%%   ok | {rowid, _}              -> {updated, ChangedRows}
%%   [{columns, C}, {rows, R}]    -> {selected, ColumnBins, RowLists}
%%                                   (integer cells become binaries)
%%   {error, Code, Reason}        -> {error, Reason}
%%   anything else                -> {updated, undefined}
sqlite_to_odbc(Host, Result) ->
    case Result of
        ok ->
            {updated, sqlite3:changes(sqlite_db(Host))};
        {rowid, _} ->
            {updated, sqlite3:changes(sqlite_db(Host))};
        [{columns, Columns}, {rows, TRows}] ->
            CellToBin = fun(I) when is_integer(I) -> integer_to_binary(I);
                           (Cell) -> Cell
                        end,
            Rows = [[CellToBin(V) || V <- tuple_to_list(Row)] || Row <- TRows],
            {selected, [list_to_binary(C) || C <- Columns], Rows};
        {error, _Code, Reason} ->
            {error, Reason};
        _ ->
            {updated, undefined}
    end.
1034

1035
%% == Native PostgreSQL code
1036

1037
%% part of init/1
%% Open a PostgreSQL connection; SSLOpts (possibly empty) are appended to the
%% base connection options and results are requested as binaries.
pgsql_connect(Server, Port, DB, Username, Password, ConnectTimeout,
              Transport, SSLOpts) ->
    BaseOpts = [{host, Server},
                {database, DB},
                {user, Username},
                {password, Password},
                {port, Port},
                {transport, Transport},
                {connect_timeout, ConnectTimeout},
                {as_binary, true}],
    pgsql:connect(BaseOpts ++ SSLOpts).
1049

1050
%% Convert a pgsql:squery result to the ODBC-style format: a single
%% statement result is converted directly, several results become a list of
%% converted items.
pgsql_to_odbc({ok, [Single]}) ->
    pgsql_item_to_odbc(Single);
pgsql_to_odbc({ok, Items}) ->
    [pgsql_item_to_odbc(Item) || Item <- Items].
1056

1057
%% Convert one simple-query result item: SELECT/FETCH tuples yield
%% {selected, ColumnNames, Rows} (name = first element of each column
%% descriptor); INSERT/DELETE/UPDATE command tags yield {updated, Count};
%% errors pass through; anything else is {updated, undefined}.
pgsql_item_to_odbc({<<"SELECT", _/binary>>, Cols, Recs}) ->
    {selected, [element(1, Col) || Col <- Cols], Recs};
pgsql_item_to_odbc({<<"FETCH", _/binary>>, Cols, Recs}) ->
    {selected, [element(1, Col) || Col <- Cols], Recs};
pgsql_item_to_odbc(<<"INSERT ", Rest/binary>>) ->
    %% command tag is "INSERT <oid> <count>"
    [_Oid, Count] = str:tokens(Rest, <<" ">>),
    {updated, binary_to_integer(Count)};
pgsql_item_to_odbc(<<Verb:7/binary, Count/binary>>)
  when Verb == <<"DELETE ">>; Verb == <<"UPDATE ">> ->
    {updated, binary_to_integer(Count)};
pgsql_item_to_odbc({error, _} = Err) ->
    Err;
pgsql_item_to_odbc(_) ->
    {updated, undefined}.
1072

1073
%% Convert a pgsql:execute/3 result: SELECT rows (lists of {Col, Value})
%% become value lists with no column names; INSERT/DELETE/UPDATE give
%% {updated, N}; errors pass through; anything else is {updated, undefined}.
pgsql_execute_to_odbc({ok, {<<"SELECT", _/binary>>, Rows}}) ->
    {selected, [], [[Value || {_, Value} <- Row] || Row <- Rows]};
pgsql_execute_to_odbc({ok, {Cmd, N}})
  when Cmd == 'INSERT'; Cmd == 'DELETE'; Cmd == 'UPDATE' ->
    {updated, N};
pgsql_execute_to_odbc({error, _} = Err) ->
    Err;
pgsql_execute_to_odbc(_) ->
    {updated, undefined}.
1083

1084

1085
%% == Native MySQL code
1086

1087
%% part of init/1
%% Open a MySQL connection through p1_mysql_conn, logging via log/3. SSL
%% options are only passed (prefixed with ssl_required) when Transport is
%% ssl. On success the session charset is set to utf8mb4 with the
%% utf8mb4_bin collation.
mysql_connect(Server, Port, DB, Username, Password, ConnectTimeout, Transport, SSLOpts0) ->
    SSLOpts = if
                  Transport == ssl -> [ssl_required | SSLOpts0];
                  true -> []
              end,
    [HostS, UserS, PassS, DbS] =
        [binary_to_list(B) || B <- [Server, Username, Password, DB]],
    case p1_mysql_conn:start(HostS, Port, UserS, PassS, DbS,
                             ConnectTimeout, fun log/3, SSLOpts) of
        {ok, Ref} ->
            p1_mysql_conn:fetch(
              Ref, [<<"set names 'utf8mb4' collate 'utf8mb4_bin';">>], self()),
            {ok, Ref};
        Err ->
            Err
    end.
1108

1109
%% Convert a p1_mysql_conn result to the ODBC-style format:
%% {updated, _} -> affected row count, {data, _} -> {selected, ...},
%% errors -> {error, Binary} (a result record is unwrapped to its reason),
%% ok stays ok.
mysql_to_odbc({updated, MySQLRes}) ->
    AffectedRows = p1_mysql:get_result_affected_rows(MySQLRes),
    {updated, AffectedRows};
mysql_to_odbc({data, MySQLRes}) ->
    Fields = p1_mysql:get_result_field_info(MySQLRes),
    Rows = p1_mysql:get_result_rows(MySQLRes),
    mysql_item_to_odbc(Fields, Rows);
mysql_to_odbc({error, Reason}) ->
    if
        is_binary(Reason) -> {error, Reason};
        is_list(Reason) -> {error, list_to_binary(Reason)};
        true -> mysql_to_odbc({error, p1_mysql:get_result_reason(Reason)})
    end;
mysql_to_odbc(ok) ->
    ok.
1125

1126

1127
%% Build {selected, ColumnNames, Rows} from MySQL field descriptors, taking
%% the column name from the second element of each descriptor.
mysql_item_to_odbc(FieldInfo, Rows) ->
    Names = [element(2, Field) || Field <- FieldInfo],
    {selected, Names, Rows}.
1130

1131
%% Normalise an Erlang odbc result: rows (tuples or lists) become lists with
%% integer cells turned into binaries and column names into binaries; string
%% error reasons become binaries; anything else is returned unchanged.
to_odbc({selected, Columns, Rows}) ->
    CellToBin = fun(I) when is_integer(I) -> integer_to_binary(I);
                   (Cell) -> Cell
                end,
    RowToList = fun(R) when is_tuple(R) -> tuple_to_list(R);
                   (R) when is_list(R) -> R
                end,
    NewRows = [[CellToBin(Cell) || Cell <- RowToList(Row)] || Row <- Rows],
    {selected, [list_to_binary(C) || C <- Columns], NewRows};
to_odbc({error, Reason}) when is_list(Reason) ->
    {error, list_to_binary(Reason)};
to_odbc(Other) ->
    Other.
1147

1148
%% Parse the output of MySQL/MariaDB "select version()".
%% Returns {ok, {Version, Type, Flags}} or error when no "X.Y" number is
%% found, where:
%%   Version = (Major * 1000 + Minor) * 1000 + Patch (missing patch = 0)
%%   Type    = atom made from the text after the first '-' up to the next
%%             '-', or unknown when there is no suffix
%%   Flags   = 1 for versions in [5.7.26, 8.0.0) or >= 8.0.20, otherwise
%%             DefaultUpsert; always DefaultUpsert when Type is 'MariaDB'
parse_mysql_version(SVersion, DefaultUpsert) ->
    ToInt = fun(<<>>) -> 0;
               (Bin) -> binary_to_integer(Bin)
            end,
    Encode = fun(Major, Minor, Patch) ->
                     ((ToInt(Major) * 1000) + ToInt(Minor)) * 1000 + ToInt(Patch)
             end,
    UpsertFlag = fun(V) when V >= 5007026 andalso V < 8000000 -> 1;
                    (V) when V >= 8000020 -> 1;
                    (_) -> DefaultUpsert
                 end,
    case re:run(SVersion, <<"(\\d+)\\.(\\d+)(?:\\.(\\d+))?(?:-([^-]*))?">>,
                [{capture, all_but_first, binary}]) of
        {match, [V1, V2, V3, Type]} ->
            V = Encode(V1, V2, V3),
            TypeA = binary_to_atom(Type, utf8),
            Flags = case TypeA of
                        'MariaDB' -> DefaultUpsert;
                        _ -> UpsertFlag(V)
                    end,
            {ok, {V, TypeA, Flags}};
        {match, [V1, V2, V3]} ->
            V = Encode(V1, V2, V3),
            {ok, {V, unknown, UpsertFlag(V)}};
        {match, [V1, V2]} ->
            %% "X.Y" without patch level or suffix: re omits trailing unset
            %% groups, so this shape previously fell through to error
            V = Encode(V1, V2, <<>>),
            {ok, {V, unknown, UpsertFlag(V)}};
        _ ->
            error
    end.
1172

1173
%% Query the server version right after connecting and store it in
%% db_version: for pgsql the integer server_version_num, for mysql the
%% {Version, Type, Flags} tuple from parse_mysql_version/2 (Flags defaults to
%% 1 when the mysql_alternative_upsert sql flag is set). Other database types
%% are left untouched. Failures are logged and leave State unchanged.
get_db_version(#state{db_type = pgsql} = State) ->
    case pgsql:squery(State#state.db_ref,
                      <<"select current_setting('server_version_num')">>) of
        {ok, [{_, _, [[SVersion]]}]} ->
            try binary_to_integer(SVersion) of
                Version ->
                    State#state{db_version = Version}
            catch
                error:Reason ->
                    ?WARNING_MSG("Error getting pgsql version: ~p", [Reason]),
                    State
            end;
        Res ->
            ?WARNING_MSG("Error getting pgsql version: ~p", [Res]),
            State
    end;
get_db_version(#state{db_type = mysql, host = Host} = State) ->
    DefaultUpsert = case lists:member(mysql_alternative_upsert,
                                      ejabberd_option:sql_flags(Host)) of
                        true -> 1;
                        _ -> 0
                    end,
    case mysql_to_odbc(p1_mysql_conn:squery(State#state.db_ref,
                                            [<<"select version();">>], self(),
                                            [{timeout, 5000},
                                             {result_type, binary}])) of
        {selected, _, [SVersion]} ->
            case parse_mysql_version(SVersion, DefaultUpsert) of
                {ok, V} ->
                    State#state{db_version = V};
                error ->
                    ?WARNING_MSG("Error parsing mysql version: ~p", [SVersion]),
                    State
            end;
        Res ->
            ?WARNING_MSG("Error getting mysql version: ~p", [Res]),
            State
    end;
get_db_version(State) ->
    State.
1211

1212
%% Integer value of a decimal binary; an empty binary counts as 0.
bin_to_int(Bin) ->
    case Bin of
        <<>> -> 0;
        _ -> binary_to_integer(Bin)
    end.
1214

1215
%% Logging callback passed to p1_mysql_conn:start/8 by mysql_connect/8;
%% maps the driver's levels onto ejabberd's logger (normal is logged as
%% info).
log(debug, Format, Args) -> ?DEBUG(Format, Args);
log(info, Format, Args) -> ?INFO_MSG(Format, Args);
log(normal, Format, Args) -> ?INFO_MSG(Format, Args);
log(error, Format, Args) -> ?ERROR_MSG(Format, Args).
1222

1223
%% Collect connection parameters for Host from its sql_* options. The first
%% element is the SQL type and the shape of the rest depends on it:
%%   odbc   -> [odbc, Server, Timeout]
%%   sqlite -> [sqlite, Host]
%%   mssql  -> [mssql, ConnString, Timeout]; sql_server is used verbatim when
%%             it already is a connection string (contains '=')
%%   other  -> [Type, Server, Port, DB, User, Pass, Timeout, Transport, SSLOpts]
%% The database name defaults to <<"ejabberd">>.
db_opts(Host) ->
    Type = ejabberd_option:sql_type(Host),
    Server = ejabberd_option:sql_server(Host),
    Timeout = ejabberd_option:sql_connect_timeout(Host),
    Transport = case ejabberd_option:sql_ssl(Host) of
                    true -> ssl;
                    false -> tcp
                end,
    warn_if_ssl_unsupported(Transport, Type),
    if
        Type == odbc ->
            [odbc, Server, Timeout];
        Type == sqlite ->
            [sqlite, Host];
        true ->
            Port = ejabberd_option:sql_port(Host),
            DB = case ejabberd_option:sql_database(Host) of
                     undefined -> <<"ejabberd">>;
                     Name -> Name
                 end,
            User = ejabberd_option:sql_username(Host),
            Pass = ejabberd_option:sql_password(Host),
            SSLOpts = get_ssl_opts(Transport, Host),
            if
                Type /= mssql ->
                    [Type, Server, Port, DB, User, Pass, Timeout, Transport, SSLOpts];
                true ->
                    case odbc_server_is_connstring(Server) of
                        true ->
                            [mssql, Server, Timeout];
                        false ->
                            Encryption = case Transport of
                                             ssl -> <<";ENCRYPTION=require;ENCRYPT=yes">>;
                                             tcp -> <<"">>
                                         end,
                            [mssql, <<"DRIVER=ODBC;SERVER=", Server/binary,
                                      ";DATABASE=", DB/binary,
                                      ";UID=", User/binary, ";PWD=", Pass/binary,
                                      ";PORT=", (integer_to_binary(Port))/binary,
                                      Encryption/binary,
                                      ";CLIENT_CHARSET=UTF-8;">>, Timeout]
                    end
            end
    end.
1265

1266
%% Emit a warning when SSL is requested for a database type whose driver
%% here does not support it (only pgsql, mssql and mysql do).
warn_if_ssl_unsupported(tcp, _Type) ->
    ok;
warn_if_ssl_unsupported(ssl, Type) ->
    case lists:member(Type, [pgsql, mssql, mysql]) of
        true -> ok;
        false -> ?WARNING_MSG("SSL connection is not supported for ~ts", [Type])
    end.
1276

1277
%% TLS options for the SQL connection: certfile / cacertfile from the host
%% options when set. With sql_ssl_verify enabled, peer verification is
%% requested only if a CA file is configured (otherwise a warning is logged
%% and no verify option is added); with it disabled, verify_none is set.
%% Plain tcp gets no options.
get_ssl_opts(tcp, _Host) ->
    [];
get_ssl_opts(ssl, Host) ->
    CertOpts = case ejabberd_option:sql_ssl_certfile(Host) of
                   undefined -> [];
                   CertFile -> [{certfile, CertFile}]
               end,
    Opts = case ejabberd_option:sql_ssl_cafile(Host) of
               undefined -> CertOpts;
               CAFile -> [{cacertfile, CAFile} | CertOpts]
           end,
    case ejabberd_option:sql_ssl_verify(Host) of
        false ->
            [{verify, verify_none} | Opts];
        true ->
            case lists:keymember(cacertfile, 1, Opts) of
                false ->
                    ?WARNING_MSG("SSL verification is enabled for "
                                 "SQL connection, but option "
                                 "'sql_ssl_cafile' is not set; "
                                 "verification will be disabled", []),
                    Opts;
                true ->
                    [{verify, verify_peer} | Opts]
            end
    end.
1303

1304
%% Write an odbcinst.ini naming the configured ODBC driver into the
%% temporary config directory (only if the file does not exist yet) and
%% point ODBCSYSINI at that directory. Returns ok, or {error, Reason} after
%% logging when the directory or file cannot be created.
init_mssql_odbcinst(Host) ->
    Driver = ejabberd_option:sql_odbc_driver(Host),
    Config = odbcinst_config(),
    ODBCINST = io_lib:fwrite("[ODBC]~n"
                             "Driver = ~s~n", [Driver]),
    ?DEBUG("~ts:~n~ts", [Config, ODBCINST]),
    case filelib:ensure_dir(Config) of
        {error, Reason} = Err ->
            ?ERROR_MSG("Failed to create temporary directory ~ts: ~ts",
                       [tmp_dir(), file:format_error(Reason)]),
            Err;
        ok ->
            case write_file_if_new(Config, ODBCINST) of
                ok ->
                    os:putenv("ODBCSYSINI", tmp_dir()),
                    ok;
                {error, Reason} = Err ->
                    ?ERROR_MSG("Failed to create temporary files in ~ts: ~ts",
                               [tmp_dir(), file:format_error(Reason)]),
                    Err
            end
    end.
1325

1326
%% MSSQL setup: an odbcinst.ini is only needed when sql_server is a plain
%% host name rather than a full ODBC connection string.
init_mssql(Host) ->
    Server = ejabberd_option:sql_server(Host),
    case odbc_server_is_connstring(Server) of
        false -> init_mssql_odbcinst(Host);
        true -> ok
    end.
1332

1333
%% A server setting containing '=' is treated as an ODBC connection string.
odbc_server_is_connstring(Server) ->
    binary:match(Server, <<"=">>) =/= nomatch.
1338

1339
%% Write Payload to File unless something already exists at that path;
%% returns ok or the file:write_file/2 error.
write_file_if_new(File, Payload) ->
    Exists = filelib:is_file(File),
    if
        Exists -> ok;
        true -> file:write_file(File, Payload)
    end.
1344

1345
%% Directory for generated ODBC configuration: $HOME/conf on Windows,
%% /tmp/ejabberd elsewhere.
tmp_dir() ->
    Parts = case os:type() of
                {win32, _} -> [os:getenv("HOME"), "conf"];
                _ -> ["/tmp", "ejabberd"]
            end,
    filename:join(Parts).
1350

1351
%% Path of the generated odbcinst.ini inside tmp_dir().
odbcinst_config() ->
    Dir = tmp_dir(),
    filename:join(Dir, "odbcinst.ini").
1353

1354
%% Configured max_queue limit for the FSM, or unlimited when not set.
max_fsm_queue() ->
    LimitOpts = fsm_limit_opts(),
    proplists:get_value(max_queue, LimitOpts, unlimited).
1356

1357
%% FSM limit options from ejabberd_config, with no extra options supplied.
fsm_limit_opts() ->
    ExtraOpts = [],
    ejabberd_config:fsm_limit_opts(ExtraOpts).
1359

1360
%% The sql_query_timeout option for the given server; callers subtract
%% 1000 from it before handing it to the drivers.
query_timeout(Host) ->
    ejabberd_option:sql_query_timeout(Host).
1362

1363
%% Monotonic clock in milliseconds; used for request deadlines and overload
%% report throttling (only differences are meaningful).
current_time() ->
    Unit = millisecond,
    erlang:monotonic_time(Unit).
1365

1366
%% ***IMPORTANT*** This error format requires extended_errors turned on.
%% Map an ODBC extended error {SQLState, NativeCode, Reason} to a binary
%% message. SQLSTATEs that mean the link is gone collapse to the messages
%% driver_restart_required/1 recognises; other states yield Reason as a
%% binary. Non-tuple errors are returned unchanged.
%%   08S01 - TCP Provider: the specified network name is no longer available
%%   IMC01, IMC06 - the connection is broken and recovery is not possible
%%   08001 - login timeout expired
extended_error({Code, _, Reason}) when Code == "08S01"; Code == "IMC01"; Code == "IMC06" ->
    ?DEBUG("ODBC Link Failure: ~ts", [Reason]),
    <<"Communication link failure">>;
extended_error({"08001", _, Reason}) ->
    ?DEBUG("ODBC Connect Timeout: ~ts", [Reason]),
    <<"SQL connection failed">>;
extended_error({Code, _, Reason}) ->
    ?DEBUG("ODBC Error ~ts: ~ts", [Code, Reason]),
    iolist_to_binary(Reason);
extended_error(Error) ->
    Error.
1388

1389
%% Inspect a query result. {error, killed} is returned silently. Other
%% errors are normalised with extended_error/1 and logged together with the
%% query: its hash and location for a parsed #sql_query{}, the SQL text when
%% it is valid iodata, or its raw term otherwise. Successful results pass
%% through unchanged.
check_error({error, Why} = Err, _Query) when Why == killed ->
    Err;
check_error({error, Why}, #sql_query{} = Query) ->
    Err = extended_error(Why),
    ?ERROR_MSG("SQL query '~ts' at ~p failed: ~p",
               [Query#sql_query.hash, Query#sql_query.loc, Err]),
    {error, Err};
check_error({error, Why}, Query) ->
    Err = extended_error(Why),
    try iolist_to_binary(Query) of
        SQuery ->
            ?ERROR_MSG("SQL query '~ts' failed: ~p", [SQuery, Err])
    catch
        error:_ ->
            %% not iodata (e.g. a list of per-database alternatives)
            ?ERROR_MSG("SQL query ~p failed: ~p", [Query, Err])
    end,
    {error, Err};
check_error(Result, _Query) ->
    Result.
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc