• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

umputun / tg-spam / 15672803049

16 Jun 2025 05:42AM UTC coverage: 79.352% (-2.1%) from 81.499%
15672803049

Pull #294

github

umputun
Add CLI override functionality for auth credentials in database mode

- Created applyCLIOverrides function to handle selective CLI parameter overrides
- Currently handles --server.auth and --server.auth-hash overrides
- Only applies overrides when values differ from defaults
- Auth hash takes precedence over password when both are provided
- Added comprehensive unit tests covering all override scenarios
- Function is extensible for future CLI override needs (documented in comments)

This fixes the issue where users couldn't change auth credentials when using
database configuration mode (--confdb), as the save-config command would
overwrite all settings rather than just the auth credentials.
Pull Request #294: Implement database configuration support

891 of 1298 new or added lines in 9 files covered. (68.64%)

174 existing lines in 4 files now uncovered.

5734 of 7226 relevant lines covered (79.35%)

57.45 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.74
/app/webapi/webapi.go
1
// Package webapi provides a web API spam detection service.
2
package webapi
3

4
import (
5
        "bytes"
6
        "compress/gzip"
7
        "context"
8
        "crypto/rand"
9
        "crypto/sha1" //nolint
10
        "embed"
11
        "encoding/json"
12
        "errors"
13
        "fmt"
14
        "html/template"
15
        "io"
16
        "io/fs"
17
        "math/big"
18
        "net/http"
19
        "path"
20
        "strconv"
21
        "strings"
22
        "time"
23

24
        "github.com/didip/tollbooth/v8"
25
        log "github.com/go-pkgz/lgr"
26
        "github.com/go-pkgz/rest"
27
        "github.com/go-pkgz/rest/logger"
28
        "github.com/go-pkgz/routegroup"
29

30
        "github.com/umputun/tg-spam/app/config"
31
        "github.com/umputun/tg-spam/app/storage"
32
        "github.com/umputun/tg-spam/app/storage/engine"
33
        "github.com/umputun/tg-spam/lib/approved"
34
        "github.com/umputun/tg-spam/lib/spamcheck"
35
)
36

37
//go:generate moq --out mocks/detector.go --pkg mocks --with-resets --skip-ensure . Detector
38
//go:generate moq --out mocks/spam_filter.go --pkg mocks --with-resets --skip-ensure . SpamFilter
39
//go:generate moq --out mocks/locator.go --pkg mocks --with-resets --skip-ensure . Locator
40
//go:generate moq --out mocks/detected_spam.go --pkg mocks --with-resets --skip-ensure . DetectedSpam
41
//go:generate moq --out mocks/storage_engine.go --pkg mocks --with-resets --skip-ensure . StorageEngine
42

43
//go:embed assets/* assets/components/*
44
var templateFS embed.FS
45
var tmpl = template.Must(template.ParseFS(templateFS, "assets/*.html", "assets/components/*.html"))
46

47
// startTime tracks when the server started
48
var startTime = time.Now()
49

50
// Server is a web API server.
51
type Server struct {
52
        Config
53
}
54

55
// Config defines  server parameters
56
type Config struct {
57
        Version       string           // version to show in /ping
58
        ListenAddr    string           // listen address
59
        Detector      Detector         // spam detector
60
        SpamFilter    SpamFilter       // spam filter (bot)
61
        DetectedSpam  DetectedSpam     // detected spam accessor
62
        Locator       Locator          // locator for user info
63
        StorageEngine StorageEngine    // database engine access for backups
64
        SettingsStore SettingsStore    // configuration storage interface
65
        AuthUser      string           // basic auth username (default: "tg-spam")
66
        AuthPasswd    string           // basic auth password
67
        AuthHash      string           // basic auth hash. If both AuthPasswd and AuthHash are provided, AuthHash is used
68
        Dbg           bool             // debug mode
69
        AppSettings   *config.Settings // application settings
70
        ConfigDBMode  bool             // indicates if app is running with database config
71
}
72

73
// Settings contains all application settings in a flat, JSON-serializable form.
// Field order is significant for the JSON output and must not be changed.
//
// NOTE(review): nothing in the visible part of this file references this type
// (Config.AppSettings uses config.Settings); confirm it is still needed.
type Settings struct {
	InstanceID              string        `json:"instance_id"`
	PrimaryGroup            string        `json:"primary_group"`
	AdminGroup              string        `json:"admin_group"`
	DisableAdminSpamForward bool          `json:"disable_admin_spam_forward"`
	LoggerEnabled           bool          `json:"logger_enabled"`
	SuperUsers              []string      `json:"super_users"`
	NoSpamReply             bool          `json:"no_spam_reply"`
	CasEnabled              bool          `json:"cas_enabled"`
	MetaEnabled             bool          `json:"meta_enabled"`
	MetaLinksLimit          int           `json:"meta_links_limit"`
	MetaMentionsLimit       int           `json:"meta_mentions_limit"`
	MetaLinksOnly           bool          `json:"meta_links_only"`
	MetaImageOnly           bool          `json:"meta_image_only"`
	MetaVideoOnly           bool          `json:"meta_video_only"`
	MetaAudioOnly           bool          `json:"meta_audio_only"`
	MetaForwarded           bool          `json:"meta_forwarded"`
	MetaKeyboard            bool          `json:"meta_keyboard"`
	MetaUsernameSymbols     string        `json:"meta_username_symbols"`
	MultiLangLimit          int           `json:"multi_lang_limit"`
	OpenAIEnabled           bool          `json:"openai_enabled"`
	LuaPluginsEnabled       bool          `json:"lua_plugins_enabled"`
	LuaPluginsDir           string        `json:"lua_plugins_dir"`
	LuaEnabledPlugins       []string      `json:"lua_enabled_plugins"`
	LuaDynamicReload        bool          `json:"lua_dynamic_reload"`
	LuaAvailablePlugins     []string      `json:"lua_available_plugins"` // the list of all available Lua plugins
	SamplesDataPath         string        `json:"samples_data_path"`
	DynamicDataPath         string        `json:"dynamic_data_path"`
	WatchIntervalSecs       int           `json:"watch_interval_secs"`
	SimilarityThreshold     float64       `json:"similarity_threshold"`
	MinMsgLen               int           `json:"min_msg_len"`
	MaxEmoji                int           `json:"max_emoji"`
	MinSpamProbability      float64       `json:"min_spam_probability"`
	ParanoidMode            bool          `json:"paranoid_mode"`
	FirstMessagesCount      int           `json:"first_messages_count"`
	StartupMessageEnabled   bool          `json:"startup_message_enabled"`
	TrainingEnabled         bool          `json:"training_enabled"`
	StorageTimeout          time.Duration `json:"storage_timeout"` // serialized by encoding/json as integer nanoseconds
	OpenAIVeto              bool          `json:"openai_veto"`
	OpenAIHistorySize       int           `json:"openai_history_size"`
	OpenAIModel             string        `json:"openai_model"`
	SoftBanEnabled          bool          `json:"soft_ban_enabled"`
	AbnormalSpacingEnabled  bool          `json:"abnormal_spacing_enabled"`
	HistorySize             int           `json:"history_size"`
	DebugModeEnabled        bool          `json:"debug_mode_enabled"`
	DryModeEnabled          bool          `json:"dry_mode_enabled"`
	TGDebugModeEnabled      bool          `json:"tg_debug_mode_enabled"`
}
122

123
// Detector is a spam detector interface.
124
type Detector interface {
125
        Check(req spamcheck.Request) (spam bool, cr []spamcheck.Response)
126
        ApprovedUsers() []approved.UserInfo
127
        AddApprovedUser(user approved.UserInfo) error
128
        RemoveApprovedUser(id string) error
129
        GetLuaPluginNames() []string // Returns the list of available Lua plugin names
130
}
131

132
// SpamFilter is a spam filter, bot interface.
type SpamFilter interface {
	// UpdateSpam adds msg to the spam samples.
	UpdateSpam(msg string) error
	// UpdateHam adds msg to the ham samples.
	UpdateHam(msg string) error
	// ReloadSamples reloads samples.
	ReloadSamples() (err error)
	// DynamicSamples returns dynamic spam and ham samples.
	DynamicSamples() (spam, ham []string, err error)
	// RemoveDynamicSpamSample removes sample from the dynamic spam samples.
	RemoveDynamicSpamSample(sample string) error
	// RemoveDynamicHamSample removes sample from the dynamic ham samples.
	RemoveDynamicHamSample(sample string) error
}
141

142
// Locator is a storage interface used to get user id by name and vice versa.
type Locator interface {
	// UserIDByName returns the user id for the given user name.
	// Callers in this file treat a zero result as "user not found".
	UserIDByName(ctx context.Context, userName string) int64
	// UserNameByID returns the user name for the given user id.
	UserNameByID(ctx context.Context, userID int64) string
}
147

148
// DetectedSpam is a storage interface used to get detected spam messages and set added flag.
149
type DetectedSpam interface {
150
        Read(ctx context.Context) ([]storage.DetectedSpamInfo, error)
151
        SetAddedToSamplesFlag(ctx context.Context, id int64) error
152
        FindByUserID(ctx context.Context, userID int64) (*storage.DetectedSpamInfo, error)
153
}
154

155
// StorageEngine provides access to the database engine for operations like backup
156
type StorageEngine interface {
157
        Backup(ctx context.Context, w io.Writer) error
158
        Type() engine.Type
159
        BackupSqliteAsPostgres(ctx context.Context, w io.Writer) error
160
}
161

162
// NewServer creates a new web API server.
163
func NewServer(cfg Config) *Server {
46✔
164
        return &Server{Config: cfg}
46✔
165
}
46✔
166

167
// Run starts server and accepts requests checking for spam messages.
168
func (s *Server) Run(ctx context.Context) error {
3✔
169
        router := routegroup.New(http.NewServeMux())
3✔
170
        router.Use(rest.Recoverer(log.Default()))
3✔
171
        router.Use(logger.New(logger.Log(log.Default()), logger.Prefix("[DEBUG]")).Handler)
3✔
172
        router.Use(rest.Throttle(1000))
3✔
173
        router.Use(rest.AppInfo("tg-spam", "umputun", s.Version), rest.Ping)
3✔
174
        router.Use(tollbooth.HTTPMiddleware(tollbooth.NewLimiter(50, nil)))
3✔
175
        router.Use(rest.SizeLimit(1024 * 1024)) // 1M max request size
3✔
176

3✔
177
        // set default username if not provided
3✔
178
        if s.AuthUser == "" {
6✔
179
                s.AuthUser = "tg-spam" // default username
3✔
180
        }
3✔
181

182
        // hash-based authentication for maximum security
183
        if s.AuthHash != "" {
4✔
184
                log.Printf("[INFO] basic auth enabled for webapi server (user: %s)", s.AuthUser)
1✔
185
                router.Use(rest.BasicAuthWithBcryptHashAndPrompt(s.AuthUser, s.AuthHash))
1✔
186
        } else {
3✔
187
                log.Printf("[WARN] basic auth disabled, access to webapi is not protected")
2✔
188
        }
2✔
189

190
        router = s.routes(router) // setup routes
3✔
191

3✔
192
        srv := &http.Server{Addr: s.ListenAddr, Handler: router, ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second}
3✔
193
        go func() {
6✔
194
                <-ctx.Done()
3✔
195
                if err := srv.Shutdown(ctx); err != nil {
3✔
UNCOV
196
                        log.Printf("[WARN] failed to shutdown webapi server: %v", err)
×
197
                } else {
3✔
198
                        log.Printf("[INFO] webapi server stopped")
3✔
199
                }
3✔
200
        }()
201

202
        log.Printf("[INFO] start webapi server on %s", s.ListenAddr)
3✔
203
        if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
3✔
UNCOV
204
                return fmt.Errorf("failed to run server: %w", err)
×
205
        }
×
206
        return nil
3✔
207
}
208

209
func (s *Server) routes(router *routegroup.Bundle) *routegroup.Bundle {
5✔
210
        // auth api routes
5✔
211
        router.Route(func(authApi *routegroup.Bundle) {
10✔
212
                if s.AuthHash != "" {
6✔
213
                        authApi.Use(s.authMiddleware(rest.BasicAuthWithBcryptHashAndPrompt(s.AuthUser, s.AuthHash)))
1✔
214
                }
1✔
215
                authApi.HandleFunc("POST /check", s.checkMsgHandler)         // check a message for spam
5✔
216
                authApi.HandleFunc("GET /check/{user_id}", s.checkIDHandler) // check user id for spam
5✔
217

5✔
218
                authApi.Mount("/update").Route(func(r *routegroup.Bundle) {
10✔
219
                        // update spam/ham samples
5✔
220
                        r.HandleFunc("POST /spam", s.updateSampleHandler(s.SpamFilter.UpdateSpam)) // update spam samples
5✔
221
                        r.HandleFunc("POST /ham", s.updateSampleHandler(s.SpamFilter.UpdateHam))   // update ham samples
5✔
222
                })
5✔
223

224
                authApi.Mount("/delete").Route(func(r *routegroup.Bundle) {
10✔
225
                        // delete spam/ham samples
5✔
226
                        r.HandleFunc("POST /spam", s.deleteSampleHandler(s.SpamFilter.RemoveDynamicSpamSample))
5✔
227
                        r.HandleFunc("POST /ham", s.deleteSampleHandler(s.SpamFilter.RemoveDynamicHamSample))
5✔
228
                })
5✔
229

230
                authApi.Mount("/download").Route(func(r *routegroup.Bundle) {
10✔
231
                        r.HandleFunc("GET /spam", s.downloadSampleHandler(func(spam, _ []string) ([]string, string) {
5✔
UNCOV
232
                                return spam, "spam.txt"
×
233
                        }))
×
234
                        r.HandleFunc("GET /ham", s.downloadSampleHandler(func(_, ham []string) ([]string, string) {
5✔
UNCOV
235
                                return ham, "ham.txt"
×
236
                        }))
×
237
                        r.HandleFunc("GET /detected_spam", s.downloadDetectedSpamHandler)
5✔
238
                        r.HandleFunc("GET /backup", s.downloadBackupHandler)
5✔
239
                        r.HandleFunc("GET /export-to-postgres", s.downloadExportToPostgresHandler)
5✔
240
                })
241

242
                authApi.HandleFunc("GET /samples", s.getDynamicSamplesHandler)    // get dynamic samples
5✔
243
                authApi.HandleFunc("PUT /samples", s.reloadDynamicSamplesHandler) // reload samples
5✔
244

5✔
245
                authApi.Mount("/users").Route(func(r *routegroup.Bundle) { // manage approved users
10✔
246
                        // add user to the approved list and storage
5✔
247
                        r.HandleFunc("POST /add", s.updateApprovedUsersHandler(s.Detector.AddApprovedUser))
5✔
248
                        // remove user from an approved list and storage
5✔
249
                        r.HandleFunc("POST /delete", s.updateApprovedUsersHandler(s.removeApprovedUser))
5✔
250
                        // get approved users
5✔
251
                        r.HandleFunc("GET /", s.getApprovedUsersHandler)
5✔
252
                })
5✔
253

254
                authApi.HandleFunc("GET /settings", s.getSettingsHandler) // get application settings
5✔
255
        })
256

257
        router.Route(func(webUI *routegroup.Bundle) {
10✔
258
                if s.AuthHash != "" {
6✔
259
                        webUI.Use(s.authMiddleware(rest.BasicAuthWithBcryptHashAndPrompt(s.AuthUser, s.AuthHash)))
1✔
260
                }
1✔
261
                webUI.HandleFunc("GET /", s.htmlSpamCheckHandler)                         // serve template for webUI UI
5✔
262
                webUI.HandleFunc("GET /manage_samples", s.htmlManageSamplesHandler)       // serve manage samples page
5✔
263
                webUI.HandleFunc("GET /manage_users", s.htmlManageUsersHandler)           // serve manage users page
5✔
264
                webUI.HandleFunc("GET /detected_spam", s.htmlDetectedSpamHandler)         // serve detected spam page
5✔
265
                webUI.HandleFunc("GET /list_settings", s.htmlSettingsHandler)             // serve settings
5✔
266
                webUI.HandleFunc("POST /detected_spam/add", s.htmlAddDetectedSpamHandler) // add detected spam to samples
5✔
267

5✔
268
                // configuration management endpoints
5✔
269
                if s.SettingsStore != nil && s.ConfigDBMode {
5✔
NEW
270
                        webUI.Route(func(config *routegroup.Bundle) {
×
NEW
271
                                config.HandleFunc("POST /config", s.saveConfigHandler)     // save current configuration to database
×
NEW
272
                                config.HandleFunc("GET /config", s.loadConfigHandler)      // load configuration from database
×
NEW
273
                                config.HandleFunc("PUT /config", s.updateConfigHandler)    // update configuration
×
NEW
274
                                config.HandleFunc("DELETE /config", s.deleteConfigHandler) // delete configuration
×
NEW
275
                        })
×
276
                }
277

278
                // handle logout - force Basic Auth re-authentication
279
                webUI.HandleFunc("GET /logout", func(w http.ResponseWriter, _ *http.Request) {
5✔
UNCOV
280
                        w.Header().Set("WWW-Authenticate", `Basic realm="tg-spam"`)
×
281
                        w.WriteHeader(http.StatusUnauthorized)
×
282
                        fmt.Fprintln(w, "Logged out successfully")
×
283
                })
×
284

285
                // serve only specific static files at root level
286
                staticFiles := newStaticFS(templateFS,
5✔
287
                        staticFileMapping{urlPath: "styles.css", filesysPath: "assets/styles.css"},
5✔
288
                        staticFileMapping{urlPath: "logo.png", filesysPath: "assets/logo.png"},
5✔
289
                        staticFileMapping{urlPath: "spinner.svg", filesysPath: "assets/spinner.svg"},
5✔
290
                )
5✔
291
                webUI.HandleFiles("/", http.FS(staticFiles))
5✔
292
        })
293

294
        return router
5✔
295
}
296

297
// checkMsgHandler handles POST /check request.
298
// it gets message text and user id from request body and returns spam status and check results.
299
func (s *Server) checkMsgHandler(w http.ResponseWriter, r *http.Request) {
7✔
300
        type CheckResultDisplay struct {
7✔
301
                Spam   bool
7✔
302
                Checks []spamcheck.Response
7✔
303
        }
7✔
304

7✔
305
        isHtmxRequest := r.Header.Get("HX-Request") == "true"
7✔
306

7✔
307
        req := spamcheck.Request{CheckOnly: true}
7✔
308
        if !isHtmxRequest {
13✔
309
                // API request
6✔
310
                if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
7✔
311
                        w.WriteHeader(http.StatusBadRequest)
1✔
312
                        rest.RenderJSON(w, rest.JSON{"error": "can't decode request", "details": err.Error()})
1✔
313
                        log.Printf("[WARN] can't decode request: %v", err)
1✔
314
                        return
1✔
315
                }
1✔
316
        } else {
1✔
317
                // for hx-request (HTMX) we need to get the values from the form
1✔
318
                req.UserID = r.FormValue("user_id")
1✔
319
                req.UserName = r.FormValue("user_name")
1✔
320
                req.Msg = r.FormValue("msg")
1✔
321
        }
1✔
322

323
        spam, cr := s.Detector.Check(req)
6✔
324
        if !isHtmxRequest {
11✔
325
                // for API request return JSON
5✔
326
                rest.RenderJSON(w, rest.JSON{"spam": spam, "checks": cr})
5✔
327
                return
5✔
328
        }
5✔
329

330
        if req.Msg == "" {
1✔
UNCOV
331
                w.Header().Set("HX-Retarget", "#error-message")
×
332
                fmt.Fprintln(w, "<div class='alert alert-danger'>Valid message required.</div>")
×
333
                return
×
334
        }
×
335

336
        // render result for HTMX request
337
        resultDisplay := CheckResultDisplay{
1✔
338
                Spam:   spam,
1✔
339
                Checks: cr,
1✔
340
        }
1✔
341

1✔
342
        if err := tmpl.ExecuteTemplate(w, "check_results", resultDisplay); err != nil {
1✔
UNCOV
343
                log.Printf("[WARN] can't execute result template: %v", err)
×
344
                http.Error(w, "Error rendering result", http.StatusInternalServerError)
×
345
                return
×
346
        }
×
347
}
348

349
// checkIDHandler handles GET /check/{user_id} request.
350
// it returns JSON with the status "spam" or "ham" for a given user id.
351
// if user is spammer, it also returns check results.
352
func (s *Server) checkIDHandler(w http.ResponseWriter, r *http.Request) {
2✔
353
        type info struct {
2✔
354
                UserName  string               `json:"user_name,omitempty"`
2✔
355
                Message   string               `json:"message,omitempty"`
2✔
356
                Timestamp time.Time            `json:"timestamp,omitempty"`
2✔
357
                Checks    []spamcheck.Response `json:"checks,omitempty"`
2✔
358
        }
2✔
359
        resp := struct {
2✔
360
                Status string `json:"status"`
2✔
361
                Info   *info  `json:"info,omitempty"`
2✔
362
        }{
2✔
363
                Status: "ham",
2✔
364
        }
2✔
365

2✔
366
        userID, err := strconv.ParseInt(r.PathValue("user_id"), 10, 64)
2✔
367
        if err != nil {
2✔
UNCOV
368
                w.WriteHeader(http.StatusBadRequest)
×
369
                rest.RenderJSON(w, rest.JSON{"error": "can't parse user id", "details": err.Error()})
×
370
                return
×
371
        }
×
372

373
        si, err := s.DetectedSpam.FindByUserID(r.Context(), userID)
2✔
374
        if err != nil {
2✔
UNCOV
375
                w.WriteHeader(http.StatusInternalServerError)
×
376
                rest.RenderJSON(w, rest.JSON{"error": "can't get user info", "details": err.Error()})
×
377
                return
×
378
        }
×
379
        if si != nil {
3✔
380
                resp.Status = "spam"
1✔
381
                resp.Info = &info{
1✔
382
                        UserName:  si.UserName,
1✔
383
                        Message:   si.Text,
1✔
384
                        Timestamp: si.Timestamp,
1✔
385
                        Checks:    si.Checks,
1✔
386
                }
1✔
387
        }
1✔
388
        rest.RenderJSON(w, resp)
2✔
389
}
390

391
// getDynamicSamplesHandler handles GET /samples request. It returns dynamic samples both for spam and ham.
392
func (s *Server) getDynamicSamplesHandler(w http.ResponseWriter, _ *http.Request) {
2✔
393
        spam, ham, err := s.SpamFilter.DynamicSamples()
2✔
394
        if err != nil {
3✔
395
                w.WriteHeader(http.StatusInternalServerError)
1✔
396
                rest.RenderJSON(w, rest.JSON{"error": "can't get dynamic samples", "details": err.Error()})
1✔
397
                return
1✔
398
        }
1✔
399
        rest.RenderJSON(w, rest.JSON{"spam": spam, "ham": ham})
1✔
400
}
401

402
// downloadSampleHandler handles GET /download/spam|ham request. It returns dynamic samples both for spam and ham.
403
func (s *Server) downloadSampleHandler(pickFn func(spam, ham []string) ([]string, string)) func(w http.ResponseWriter, r *http.Request) {
13✔
404
        return func(w http.ResponseWriter, _ *http.Request) {
16✔
405
                spam, ham, err := s.SpamFilter.DynamicSamples()
3✔
406
                if err != nil {
4✔
407
                        w.WriteHeader(http.StatusInternalServerError)
1✔
408
                        rest.RenderJSON(w, rest.JSON{"error": "can't get dynamic samples", "details": err.Error()})
1✔
409
                        return
1✔
410
                }
1✔
411
                samples, name := pickFn(spam, ham)
2✔
412
                body := strings.Join(samples, "\n")
2✔
413
                w.Header().Set("Content-Type", "text/plain; charset=utf-8")
2✔
414
                w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name))
2✔
415
                w.Header().Set("Content-Length", strconv.Itoa(len(body)))
2✔
416
                w.WriteHeader(http.StatusOK)
2✔
417
                _, _ = w.Write([]byte(body))
2✔
418
        }
419
}
420

421
// updateSampleHandler handles POST /update/spam|ham request. It updates dynamic samples both for spam and ham.
422
func (s *Server) updateSampleHandler(updFn func(msg string) error) func(w http.ResponseWriter, r *http.Request) {
13✔
423
        return func(w http.ResponseWriter, r *http.Request) {
18✔
424
                var req struct {
5✔
425
                        Msg string `json:"msg"`
5✔
426
                }
5✔
427

5✔
428
                isHtmxRequest := r.Header.Get("HX-Request") == "true"
5✔
429

5✔
430
                if isHtmxRequest {
5✔
UNCOV
431
                        req.Msg = r.FormValue("msg")
×
432
                } else {
5✔
433
                        if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
6✔
434
                                w.WriteHeader(http.StatusBadRequest)
1✔
435
                                rest.RenderJSON(w, rest.JSON{"error": "can't decode request", "details": err.Error()})
1✔
436
                                return
1✔
437
                        }
1✔
438
                }
439

440
                err := updFn(req.Msg)
4✔
441
                if err != nil {
5✔
442
                        w.WriteHeader(http.StatusInternalServerError)
1✔
443
                        rest.RenderJSON(w, rest.JSON{"error": "can't update samples", "details": err.Error()})
1✔
444
                        return
1✔
445
                }
1✔
446

447
                if isHtmxRequest {
3✔
UNCOV
448
                        s.renderSamples(w, "samples_list")
×
449
                } else {
3✔
450
                        rest.RenderJSON(w, rest.JSON{"updated": true, "msg": req.Msg})
3✔
451
                }
3✔
452
        }
453
}
454

455
// deleteSampleHandler handles DELETE /samples request. It deletes dynamic samples both for spam and ham.
456
func (s *Server) deleteSampleHandler(delFn func(msg string) error) func(w http.ResponseWriter, r *http.Request) {
13✔
457
        return func(w http.ResponseWriter, r *http.Request) {
18✔
458
                var req struct {
5✔
459
                        Msg string `json:"msg"`
5✔
460
                }
5✔
461
                isHtmxRequest := r.Header.Get("HX-Request") == "true"
5✔
462
                if isHtmxRequest {
6✔
463
                        req.Msg = r.FormValue("msg")
1✔
464
                } else {
5✔
465
                        if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
4✔
UNCOV
466
                                w.WriteHeader(http.StatusBadRequest)
×
467
                                rest.RenderJSON(w, rest.JSON{"error": "can't decode request", "details": err.Error()})
×
468
                                return
×
469
                        }
×
470
                }
471

472
                if err := delFn(req.Msg); err != nil {
6✔
473
                        w.WriteHeader(http.StatusInternalServerError)
1✔
474
                        rest.RenderJSON(w, rest.JSON{"error": "can't delete sample", "details": err.Error()})
1✔
475
                        return
1✔
476
                }
1✔
477

478
                if isHtmxRequest {
5✔
479
                        s.renderSamples(w, "samples_list")
1✔
480
                } else {
4✔
481
                        rest.RenderJSON(w, rest.JSON{"deleted": true, "msg": req.Msg, "count": 1})
3✔
482
                }
3✔
483
        }
484
}
485

486
// reloadDynamicSamplesHandler handles PUT /samples request. It reloads dynamic samples from db storage.
487
func (s *Server) reloadDynamicSamplesHandler(w http.ResponseWriter, _ *http.Request) {
2✔
488
        if err := s.SpamFilter.ReloadSamples(); err != nil {
3✔
489
                w.WriteHeader(http.StatusInternalServerError)
1✔
490
                rest.RenderJSON(w, rest.JSON{"error": "can't reload samples", "details": err.Error()})
1✔
491
                return
1✔
492
        }
1✔
493
        rest.RenderJSON(w, rest.JSON{"reloaded": true})
1✔
494
}
495

496
// updateApprovedUsersHandler handles POST /users/add and /users/delete requests, it adds or removes users from approved list.
497
func (s *Server) updateApprovedUsersHandler(updFn func(ui approved.UserInfo) error) func(w http.ResponseWriter, r *http.Request) {
14✔
498
        return func(w http.ResponseWriter, r *http.Request) {
23✔
499
                req := approved.UserInfo{}
9✔
500
                isHtmxRequest := r.Header.Get("HX-Request") == "true"
9✔
501
                if isHtmxRequest {
10✔
502
                        req.UserID = r.FormValue("user_id")
1✔
503
                        req.UserName = r.FormValue("user_name")
1✔
504
                } else {
9✔
505
                        if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
9✔
506
                                w.WriteHeader(http.StatusBadRequest)
1✔
507
                                rest.RenderJSON(w, rest.JSON{"error": "can't decode request", "details": err.Error()})
1✔
508
                                return
1✔
509
                        }
1✔
510
                }
511

512
                // try to get userID from request and fallback to userName lookup if it's empty
513
                if req.UserID == "" {
12✔
514
                        req.UserID = strconv.FormatInt(s.Locator.UserIDByName(r.Context(), req.UserName), 10)
4✔
515
                }
4✔
516

517
                if req.UserID == "" || req.UserID == "0" {
9✔
518
                        if isHtmxRequest {
1✔
UNCOV
519
                                w.Header().Set("HX-Retarget", "#error-message")
×
520
                                fmt.Fprintln(w, "<div class='alert alert-danger'>Either userid or valid username required.</div>")
×
521
                                return
×
522
                        }
×
523
                        w.WriteHeader(http.StatusBadRequest)
1✔
524
                        rest.RenderJSON(w, rest.JSON{"error": "user ID is required"})
1✔
525
                        return
1✔
526
                }
527

528
                // add or remove user from the approved list of detector
529
                if err := updFn(req); err != nil {
7✔
UNCOV
530
                        w.WriteHeader(http.StatusInternalServerError)
×
531
                        rest.RenderJSON(w, rest.JSON{"error": "can't update approved users", "details": err.Error()})
×
532
                        return
×
533
                }
×
534

535
                if isHtmxRequest {
8✔
536
                        users := s.Detector.ApprovedUsers()
1✔
537
                        tmplData := struct {
1✔
538
                                ApprovedUsers      []approved.UserInfo
1✔
539
                                TotalApprovedUsers int
1✔
540
                        }{
1✔
541
                                ApprovedUsers:      users,
1✔
542
                                TotalApprovedUsers: len(users),
1✔
543
                        }
1✔
544

1✔
545
                        if err := tmpl.ExecuteTemplate(w, "users_list", tmplData); err != nil {
1✔
UNCOV
546
                                http.Error(w, "Error executing template", http.StatusInternalServerError)
×
547
                                return
×
548
                        }
×
549

550
                } else {
6✔
551
                        rest.RenderJSON(w, rest.JSON{"updated": true, "user_id": req.UserID, "user_name": req.UserName})
6✔
552
                }
6✔
553
        }
554
}
555

556
// removeApprovedUser is adopter for updateApprovedUsersHandler updFn
557
func (s *Server) removeApprovedUser(req approved.UserInfo) error {
2✔
558
        if err := s.Detector.RemoveApprovedUser(req.UserID); err != nil {
2✔
UNCOV
559
                return fmt.Errorf("failed to remove approved user %s: %w", req.UserID, err)
×
560
        }
×
561
        return nil
2✔
562
}
563

564
// getApprovedUsersHandler handles GET /users request. It returns list of approved users.
565
func (s *Server) getApprovedUsersHandler(w http.ResponseWriter, _ *http.Request) {
1✔
566
        rest.RenderJSON(w, rest.JSON{"user_ids": s.Detector.ApprovedUsers()})
1✔
567
}
1✔
568

569
// getSettingsHandler returns application settings, including the list of available Lua plugins
570
func (s *Server) getSettingsHandler(w http.ResponseWriter, _ *http.Request) {
3✔
571
        // get available Lua plugins and store them directly in AppSettings
3✔
572
        s.AppSettings.LuaPlugins.EnabledPlugins = s.Detector.GetLuaPluginNames()
3✔
573

3✔
574
        // return the application settings directly - sensitive info is protected by json tags
3✔
575
        rest.RenderJSON(w, s.AppSettings)
3✔
576
}
3✔
577

578
// htmlSpamCheckHandler handles GET / request.
579
// It returns rendered spam_check.html template with all the components.
580
func (s *Server) htmlSpamCheckHandler(w http.ResponseWriter, _ *http.Request) {
3✔
581
        tmplData := struct {
3✔
582
                Version string
3✔
583
        }{
3✔
584
                Version: s.Version,
3✔
585
        }
3✔
586

3✔
587
        if err := tmpl.ExecuteTemplate(w, "spam_check.html", tmplData); err != nil {
4✔
588
                log.Printf("[WARN] can't execute template: %v", err)
1✔
589
                http.Error(w, "Error executing template", http.StatusInternalServerError)
1✔
590
                return
1✔
591
        }
1✔
592
}
593

594
// htmlManageSamplesHandler handles GET /manage_samples request.
595
// It returns rendered manage_samples.html template with all the components.
596
func (s *Server) htmlManageSamplesHandler(w http.ResponseWriter, _ *http.Request) {
1✔
597
        s.renderSamples(w, "manage_samples.html")
1✔
598
}
1✔
599

600
func (s *Server) htmlManageUsersHandler(w http.ResponseWriter, _ *http.Request) {
3✔
601
        users := s.Detector.ApprovedUsers()
3✔
602
        tmplData := struct {
3✔
603
                ApprovedUsers      []approved.UserInfo
3✔
604
                TotalApprovedUsers int
3✔
605
        }{
3✔
606
                ApprovedUsers:      users,
3✔
607
                TotalApprovedUsers: len(users),
3✔
608
        }
3✔
609
        tmplData.TotalApprovedUsers = len(tmplData.ApprovedUsers)
3✔
610

3✔
611
        if err := tmpl.ExecuteTemplate(w, "manage_users.html", tmplData); err != nil {
4✔
612
                log.Printf("[WARN] can't execute template: %v", err)
1✔
613
                http.Error(w, "Error executing template", http.StatusInternalServerError)
1✔
614
                return
1✔
615
        }
1✔
616
}
617

618
func (s *Server) htmlDetectedSpamHandler(w http.ResponseWriter, r *http.Request) {
2✔
619
        ds, err := s.DetectedSpam.Read(r.Context())
2✔
620
        if err != nil {
3✔
621
                log.Printf("[ERROR] Failed to fetch detected spam: %v", err)
1✔
622
                http.Error(w, "Internal Server Error", http.StatusInternalServerError)
1✔
623
                return
1✔
624
        }
1✔
625

626
        // clean up detected spam entries
627
        for i, d := range ds {
3✔
628
                d.Text = strings.ReplaceAll(d.Text, "'", " ")
2✔
629
                d.Text = strings.ReplaceAll(d.Text, "\n", " ")
2✔
630
                d.Text = strings.ReplaceAll(d.Text, "\r", " ")
2✔
631
                d.Text = strings.ReplaceAll(d.Text, "\t", " ")
2✔
632
                d.Text = strings.ReplaceAll(d.Text, "\"", " ")
2✔
633
                d.Text = strings.ReplaceAll(d.Text, "\\", " ")
2✔
634
                ds[i] = d
2✔
635
        }
2✔
636

637
        // get filter from query param, default to "all"
638
        filter := r.URL.Query().Get("filter")
1✔
639
        if filter == "" {
2✔
640
                filter = "all"
1✔
641
        }
1✔
642

643
        // apply filtering
644
        var filteredDS []storage.DetectedSpamInfo
1✔
645
        switch filter {
1✔
UNCOV
646
        case "non-classified":
×
647
                for _, entry := range ds {
×
648
                        hasClassifierHam := false
×
649
                        for _, check := range entry.Checks {
×
650
                                if check.Name == "classifier" && !check.Spam {
×
651
                                        hasClassifierHam = true
×
652
                                        break
×
653
                                }
654
                        }
UNCOV
655
                        if hasClassifierHam {
×
656
                                filteredDS = append(filteredDS, entry)
×
657
                        }
×
658
                }
UNCOV
659
        case "openai":
×
660
                for _, entry := range ds {
×
661
                        hasOpenAI := false
×
662
                        for _, check := range entry.Checks {
×
663
                                if check.Name == "openai" {
×
664
                                        hasOpenAI = true
×
665
                                        break
×
666
                                }
667
                        }
UNCOV
668
                        if hasOpenAI {
×
669
                                filteredDS = append(filteredDS, entry)
×
670
                        }
×
671
                }
672
        default: // "all" or any other value
1✔
673
                filteredDS = ds
1✔
674
        }
675

676
        tmplData := struct {
1✔
677
                DetectedSpamEntries []storage.DetectedSpamInfo
1✔
678
                TotalDetectedSpam   int
1✔
679
                FilteredCount       int
1✔
680
                Filter              string
1✔
681
                OpenAIEnabled       bool
1✔
682
        }{
1✔
683
                DetectedSpamEntries: filteredDS,
1✔
684
                TotalDetectedSpam:   len(ds),
1✔
685
                FilteredCount:       len(filteredDS),
1✔
686
                Filter:              filter,
1✔
687
                OpenAIEnabled:       s.AppSettings != nil && s.AppSettings.IsOpenAIEnabled(),
1✔
688
        }
1✔
689

1✔
690
        // if it's an HTMX request, render both content and count display for OOB swap
1✔
691
        if r.Header.Get("HX-Request") == "true" {
1✔
UNCOV
692
                var buf bytes.Buffer
×
693

×
694
                // first render the content template
×
695
                if err := tmpl.ExecuteTemplate(&buf, "detected_spam_content", tmplData); err != nil {
×
696
                        log.Printf("[WARN] can't execute content template: %v", err)
×
697
                        http.Error(w, "Error executing template", http.StatusInternalServerError)
×
698
                        return
×
699
                }
×
700

701
                // then append OOB swap for the count display
UNCOV
702
                countHTML := ""
×
703
                if filter != "all" {
×
704
                        countHTML = fmt.Sprintf("(%d/%d)", len(filteredDS), len(ds))
×
705
                } else {
×
706
                        countHTML = fmt.Sprintf("(%d)", len(ds))
×
707
                }
×
708

UNCOV
709
                buf.WriteString(fmt.Sprintf(`<span id="count-display" hx-swap-oob="true">%s</span>`, countHTML))
×
710

×
711
                // write the combined response
×
712
                if _, err := buf.WriteTo(w); err != nil {
×
713
                        log.Printf("[WARN] failed to write response: %v", err)
×
714
                }
×
715
                return
×
716
        }
717

718
        // full page render for normal requests
719
        if err := tmpl.ExecuteTemplate(w, "detected_spam.html", tmplData); err != nil {
1✔
UNCOV
720
                log.Printf("[WARN] can't execute template: %v", err)
×
721
                http.Error(w, "Error executing template", http.StatusInternalServerError)
×
722
                return
×
723
        }
×
724
}
725

726
func (s *Server) htmlAddDetectedSpamHandler(w http.ResponseWriter, r *http.Request) {
5✔
727
        reportErr := func(err error, _ int) {
9✔
728
                w.Header().Set("HX-Retarget", "#error-message")
4✔
729
                fmt.Fprintf(w, "<div class='alert alert-danger'>%s</div>", err)
4✔
730
        }
4✔
731
        msg := r.FormValue("msg")
5✔
732

5✔
733
        id, err := strconv.ParseInt(r.FormValue("id"), 10, 64)
5✔
734
        if err != nil || msg == "" {
7✔
735
                log.Printf("[WARN] bad request: %v", err)
2✔
736
                reportErr(fmt.Errorf("bad request: %v", err), http.StatusBadRequest)
2✔
737
                return
2✔
738
        }
2✔
739

740
        if err := s.SpamFilter.UpdateSpam(msg); err != nil {
4✔
741
                log.Printf("[WARN] failed to update spam samples: %v", err)
1✔
742
                reportErr(fmt.Errorf("can't update spam samples: %v", err), http.StatusInternalServerError)
1✔
743
                return
1✔
744

1✔
745
        }
1✔
746
        if err := s.DetectedSpam.SetAddedToSamplesFlag(r.Context(), id); err != nil {
3✔
747
                log.Printf("[WARN] failed to update detected spam: %v", err)
1✔
748
                reportErr(fmt.Errorf("can't update detected spam: %v", err), http.StatusInternalServerError)
1✔
749
                return
1✔
750
        }
1✔
751
        w.WriteHeader(http.StatusOK)
1✔
752
}
753

754
func (s *Server) htmlSettingsHandler(w http.ResponseWriter, r *http.Request) {
4✔
755
        // get database information if StorageEngine is available
4✔
756
        var dbInfo struct {
4✔
757
                DatabaseType   string `json:"database_type"`
4✔
758
                GID            string `json:"gid"`
4✔
759
                DatabaseStatus string `json:"database_status"`
4✔
760
        }
4✔
761

4✔
762
        if s.StorageEngine != nil {
6✔
763
                // try to cast to SQL engine to get type information
2✔
764
                if sqlEngine, ok := s.StorageEngine.(*engine.SQL); ok {
2✔
UNCOV
765
                        dbInfo.DatabaseType = string(sqlEngine.Type())
×
766
                        dbInfo.GID = sqlEngine.GID()
×
767
                        dbInfo.DatabaseStatus = "Connected"
×
768
                } else {
2✔
769
                        dbInfo.DatabaseType = "Unknown"
2✔
770
                        dbInfo.DatabaseStatus = "Connected (unknown type)"
2✔
771
                }
2✔
772
        } else {
2✔
773
                dbInfo.DatabaseStatus = "Not connected"
2✔
774
        }
2✔
775

776
        // get backup information
777
        backupURL := "/download/backup"
4✔
778
        backupFilename := fmt.Sprintf("tg-spam-backup-%s-%s.sql.gz", dbInfo.DatabaseType, time.Now().Format("20060102-150405"))
4✔
779

4✔
780
        // get system info - uptime since server start
4✔
781
        uptime := time.Since(startTime)
4✔
782

4✔
783
        // get the list of available Lua plugins
4✔
784
        luaPlugins := s.Detector.GetLuaPluginNames()
4✔
785

4✔
786
        // get configuration DB status
4✔
787
        configAvailable := false
4✔
788
        var lastUpdated time.Time
4✔
789
        if s.SettingsStore != nil {
4✔
NEW
790
                configAvailable = true
×
NEW
791
                if lu, err := s.SettingsStore.LastUpdated(r.Context()); err == nil {
×
NEW
792
                        lastUpdated = lu
×
NEW
793
                }
×
794
        }
795

796
        data := struct {
4✔
797
                *config.Settings
4✔
798
                LuaAvailablePlugins []string
4✔
799
                Version             string
4✔
800
                Database            struct {
4✔
801
                        Type   string
4✔
802
                        GID    string
4✔
803
                        Status string
4✔
804
                }
4✔
805
                Backup struct {
4✔
806
                        URL      string
4✔
807
                        Filename string
4✔
808
                }
4✔
809
                System struct {
4✔
810
                        Uptime string
4✔
811
                }
4✔
812
                ConfigAvailable bool
4✔
813
                LastUpdated     time.Time
4✔
814
                ConfigDBMode    bool
4✔
815
        }{
4✔
816
                Settings:            s.AppSettings,
4✔
817
                LuaAvailablePlugins: luaPlugins,
4✔
818
                Version:             s.Version,
4✔
819
                Database: struct {
4✔
820
                        Type   string
4✔
821
                        GID    string
4✔
822
                        Status string
4✔
823
                }{
4✔
824
                        Type:   dbInfo.DatabaseType,
4✔
825
                        GID:    dbInfo.GID,
4✔
826
                        Status: dbInfo.DatabaseStatus,
4✔
827
                },
4✔
828
                Backup: struct {
4✔
829
                        URL      string
4✔
830
                        Filename string
4✔
831
                }{
4✔
832
                        URL:      backupURL,
4✔
833
                        Filename: backupFilename,
4✔
834
                },
4✔
835
                System: struct {
4✔
836
                        Uptime string
4✔
837
                }{
4✔
838
                        Uptime: formatDuration(uptime),
4✔
839
                },
4✔
840
                ConfigAvailable: configAvailable,
4✔
841
                LastUpdated:     lastUpdated,
4✔
842
                ConfigDBMode:    s.ConfigDBMode,
4✔
843
        }
4✔
844

4✔
845
        if err := tmpl.ExecuteTemplate(w, "settings.html", data); err != nil {
5✔
846
                log.Printf("[WARN] can't execute template: %v", err)
1✔
847
                http.Error(w, "Error executing template", http.StatusInternalServerError)
1✔
848
                return
1✔
849
        }
1✔
850
}
851

852
// formatDuration formats a duration in a human-readable way
853
func formatDuration(d time.Duration) string {
12✔
854
        days := int(d.Hours() / 24)
12✔
855
        hours := int(d.Hours()) % 24
12✔
856
        minutes := int(d.Minutes()) % 60
12✔
857

12✔
858
        if days > 0 {
15✔
859
                return fmt.Sprintf("%dd %dh %dm", days, hours, minutes)
3✔
860
        }
3✔
861

862
        if hours > 0 {
11✔
863
                return fmt.Sprintf("%dh %dm", hours, minutes)
2✔
864
        }
2✔
865

866
        return fmt.Sprintf("%dm", minutes)
7✔
867
}
868

869
func (s *Server) downloadDetectedSpamHandler(w http.ResponseWriter, r *http.Request) {
3✔
870
        ctx := r.Context()
3✔
871
        spam, err := s.DetectedSpam.Read(ctx)
3✔
872
        if err != nil {
4✔
873
                w.WriteHeader(http.StatusInternalServerError)
1✔
874
                rest.RenderJSON(w, rest.JSON{"error": "can't get detected spam", "details": err.Error()})
1✔
875
                return
1✔
876
        }
1✔
877

878
        type jsonSpamInfo struct {
2✔
879
                ID        int64                `json:"id"`
2✔
880
                GID       string               `json:"gid"`
2✔
881
                Text      string               `json:"text"`
2✔
882
                UserID    int64                `json:"user_id"`
2✔
883
                UserName  string               `json:"user_name"`
2✔
884
                Timestamp time.Time            `json:"timestamp"`
2✔
885
                Added     bool                 `json:"added"`
2✔
886
                Checks    []spamcheck.Response `json:"checks"`
2✔
887
        }
2✔
888

2✔
889
        // convert entries to jsonl format with lowercase fields
2✔
890
        lines := make([]string, 0, len(spam))
2✔
891
        for _, entry := range spam {
5✔
892
                data, err := json.Marshal(jsonSpamInfo{
3✔
893
                        ID:        entry.ID,
3✔
894
                        GID:       entry.GID,
3✔
895
                        Text:      entry.Text,
3✔
896
                        UserID:    entry.UserID,
3✔
897
                        UserName:  entry.UserName,
3✔
898
                        Timestamp: entry.Timestamp,
3✔
899
                        Added:     entry.Added,
3✔
900
                        Checks:    entry.Checks,
3✔
901
                })
3✔
902
                if err != nil {
3✔
UNCOV
903
                        w.WriteHeader(http.StatusInternalServerError)
×
904
                        rest.RenderJSON(w, rest.JSON{"error": "can't marshal entry", "details": err.Error()})
×
905
                        return
×
906
                }
×
907
                lines = append(lines, string(data))
3✔
908
        }
909

910
        body := strings.Join(lines, "\n")
2✔
911
        w.Header().Set("Content-Type", "application/x-jsonlines")
2✔
912
        w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", "detected_spam.jsonl"))
2✔
913
        w.Header().Set("Content-Length", strconv.Itoa(len(body)))
2✔
914
        w.WriteHeader(http.StatusOK)
2✔
915
        _, _ = w.Write([]byte(body))
2✔
916
}
917

918
// downloadBackupHandler streams a database backup as an SQL file with gzip compression
919
// Files are always compressed and always have .gz extension to ensure consistency
920
func (s *Server) downloadBackupHandler(w http.ResponseWriter, r *http.Request) {
2✔
921
        if s.StorageEngine == nil {
3✔
922
                w.WriteHeader(http.StatusInternalServerError)
1✔
923
                rest.RenderJSON(w, rest.JSON{"error": "storage engine not available"})
1✔
924
                return
1✔
925
        }
1✔
926

927
        // set filename based on database type and timestamp
928
        dbType := "db"
1✔
929
        sqlEng, ok := s.StorageEngine.(*engine.SQL)
1✔
930
        if ok {
1✔
UNCOV
931
                dbType = string(sqlEng.Type())
×
932
        }
×
933
        timestamp := time.Now().Format("20060102-150405")
1✔
934

1✔
935
        // always use a .gz extension as the content is always compressed
1✔
936
        filename := fmt.Sprintf("tg-spam-backup-%s-%s.sql.gz", dbType, timestamp)
1✔
937

1✔
938
        // set headers for file download - note we're using application/octet-stream
1✔
939
        // instead of application/sql to prevent browsers from trying to interpret the file
1✔
940
        w.Header().Set("Content-Type", "application/octet-stream")
1✔
941
        w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
1✔
942
        w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
1✔
943
        w.Header().Set("Pragma", "no-cache")
1✔
944
        w.Header().Set("Expires", "0")
1✔
945

1✔
946
        // create a gzip writer that streams to response
1✔
947
        gzipWriter := gzip.NewWriter(w)
1✔
948
        defer func() {
2✔
949
                if err := gzipWriter.Close(); err != nil {
1✔
UNCOV
950
                        log.Printf("[ERROR] failed to close gzip writer: %v", err)
×
951
                }
×
952
        }()
953

954
        // stream backup directly to response through gzip
955
        if err := s.StorageEngine.Backup(r.Context(), gzipWriter); err != nil {
1✔
UNCOV
956
                log.Printf("[ERROR] failed to create backup: %v", err)
×
957
                // we've already started writing the response, so we can't send a proper error response
×
958
                return
×
959
        }
×
960

961
        // flush the gzip writer to ensure all data is written
962
        if err := gzipWriter.Flush(); err != nil {
1✔
UNCOV
963
                log.Printf("[ERROR] failed to flush gzip writer: %v", err)
×
964
        }
×
965
}
966

967
// downloadExportToPostgresHandler streams a PostgreSQL-compatible export from a SQLite database
968
func (s *Server) downloadExportToPostgresHandler(w http.ResponseWriter, r *http.Request) {
3✔
969
        if s.StorageEngine == nil {
4✔
970
                w.WriteHeader(http.StatusInternalServerError)
1✔
971
                rest.RenderJSON(w, rest.JSON{"error": "storage engine not available"})
1✔
972
                return
1✔
973
        }
1✔
974

975
        // check if the database is SQLite
976
        if s.StorageEngine.Type() != engine.Sqlite {
3✔
977
                w.WriteHeader(http.StatusBadRequest)
1✔
978
                rest.RenderJSON(w, rest.JSON{"error": "source database must be SQLite"})
1✔
979
                return
1✔
980
        }
1✔
981

982
        // set filename based on timestamp
983
        timestamp := time.Now().Format("20060102-150405")
1✔
984
        filename := fmt.Sprintf("tg-spam-sqlite-to-postgres-%s.sql.gz", timestamp)
1✔
985

1✔
986
        // set headers for file download
1✔
987
        w.Header().Set("Content-Type", "application/octet-stream")
1✔
988
        w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", filename))
1✔
989
        w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
1✔
990
        w.Header().Set("Pragma", "no-cache")
1✔
991
        w.Header().Set("Expires", "0")
1✔
992

1✔
993
        // create a gzip writer that streams to response
1✔
994
        gzipWriter := gzip.NewWriter(w)
1✔
995
        defer func() {
2✔
996
                if err := gzipWriter.Close(); err != nil {
1✔
UNCOV
997
                        log.Printf("[ERROR] failed to close gzip writer: %v", err)
×
998
                }
×
999
        }()
1000

1001
        // stream export directly to response through gzip
1002
        if err := s.StorageEngine.BackupSqliteAsPostgres(r.Context(), gzipWriter); err != nil {
1✔
UNCOV
1003
                log.Printf("[ERROR] failed to create export: %v", err)
×
1004
                // we've already started writing the response, so we can't send a proper error response
×
1005
                return
×
1006
        }
×
1007

1008
        // flush the gzip writer to ensure all data is written
1009
        if err := gzipWriter.Flush(); err != nil {
1✔
UNCOV
1010
                log.Printf("[ERROR] failed to flush gzip writer: %v", err)
×
1011
        }
×
1012
}
1013

1014
func (s *Server) renderSamples(w http.ResponseWriter, tmplName string) {
6✔
1015
        spam, ham, err := s.SpamFilter.DynamicSamples()
6✔
1016
        if err != nil {
7✔
1017
                w.WriteHeader(http.StatusInternalServerError)
1✔
1018
                rest.RenderJSON(w, rest.JSON{"error": "can't fetch samples", "details": err.Error()})
1✔
1019
                return
1✔
1020
        }
1✔
1021

1022
        spam, ham = s.reverseSamples(spam, ham)
5✔
1023

5✔
1024
        type smpleWithID struct {
5✔
1025
                ID     string
5✔
1026
                Sample string
5✔
1027
        }
5✔
1028

5✔
1029
        makeID := func(s string) string {
19✔
1030
                hash := sha1.New() //nolint
14✔
1031
                if _, err := hash.Write([]byte(s)); err != nil {
14✔
UNCOV
1032
                        return fmt.Sprintf("%x", s)
×
1033
                }
×
1034
                return fmt.Sprintf("%x", hash.Sum(nil))
14✔
1035
        }
1036

1037
        tmplData := struct {
5✔
1038
                SpamSamples      []smpleWithID
5✔
1039
                HamSamples       []smpleWithID
5✔
1040
                TotalHamSamples  int
5✔
1041
                TotalSpamSamples int
5✔
1042
        }{
5✔
1043
                TotalHamSamples:  len(ham),
5✔
1044
                TotalSpamSamples: len(spam),
5✔
1045
        }
5✔
1046
        for _, s := range spam {
12✔
1047
                tmplData.SpamSamples = append(tmplData.SpamSamples, smpleWithID{ID: makeID(s), Sample: s})
7✔
1048
        }
7✔
1049
        for _, h := range ham {
12✔
1050
                tmplData.HamSamples = append(tmplData.HamSamples, smpleWithID{ID: makeID(h), Sample: h})
7✔
1051
        }
7✔
1052

1053
        if err := tmpl.ExecuteTemplate(w, tmplName, tmplData); err != nil {
6✔
1054
                w.WriteHeader(http.StatusInternalServerError)
1✔
1055
                rest.RenderJSON(w, rest.JSON{"error": "can't execute template", "details": err.Error()})
1✔
1056
                return
1✔
1057
        }
1✔
1058
}
1059

1060
func (s *Server) authMiddleware(mw func(next http.Handler) http.Handler) func(next http.Handler) http.Handler {
2✔
1061
        if s.AuthHash == "" {
2✔
NEW
1062
                // if no hash is provided, authentication is disabled
×
UNCOV
1063
                return func(next http.Handler) http.Handler {
×
UNCOV
1064
                        return next
×
UNCOV
1065
                }
×
1066
        }
1067
        return func(next http.Handler) http.Handler {
35✔
1068
                return mw(next)
33✔
1069
        }
33✔
1070
}
1071

1072
// reverseSamples returns reversed lists of spam and ham samples
1073
func (s *Server) reverseSamples(spam, ham []string) (revSpam, revHam []string) {
8✔
1074
        revSpam = make([]string, len(spam))
8✔
1075
        revHam = make([]string, len(ham))
8✔
1076

8✔
1077
        for i, j := 0, len(spam)-1; i < len(spam); i, j = i+1, j-1 {
19✔
1078
                revSpam[i] = spam[j]
11✔
1079
        }
11✔
1080
        for i, j := 0, len(ham)-1; i < len(ham); i, j = i+1, j-1 {
19✔
1081
                revHam[i] = ham[j]
11✔
1082
        }
11✔
1083
        return revSpam, revHam
8✔
1084
}
1085

1086
// staticFS is a filtered filesystem that only exposes specific static files.
// Files are looked up by URL path and opened from the wrapped filesystem by the mapped path.
type staticFS struct {
	fs        fs.FS
	urlToPath map[string]string
}

// staticFileMapping defines a mapping between URL path and filesystem path
type staticFileMapping struct {
	urlPath     string
	filesysPath string
}

// newStaticFS makes a staticFS on top of fsys exposing only the given files.
// If the same urlPath is listed more than once, the last mapping wins.
func newStaticFS(fsys fs.FS, files ...staticFileMapping) *staticFS {
	res := &staticFS{
		fs:        fsys,
		urlToPath: make(map[string]string, len(files)),
	}
	for i := range files {
		res.urlToPath[files[i].urlPath] = files[i].filesysPath
	}
	return res
}

// Open opens the file mapped to the requested name.
// The name is cleaned as a rooted slash-separated path and looked up without the leading slash.
// It returns fs.ErrNotExist for names without a mapping, and a wrapped error if the
// underlying filesystem fails to open the mapped path.
func (sfs *staticFS) Open(name string) (fs.File, error) {
	urlPath := strings.TrimPrefix(path.Clean("/"+name), "/")

	target, mapped := sfs.urlToPath[urlPath]
	if !mapped {
		return nil, fs.ErrNotExist
	}

	opened, err := sfs.fs.Open(target)
	if err != nil {
		return nil, fmt.Errorf("failed to open static file %s: %w", target, err)
	}
	return opened, nil
}
1124

1125
// GenerateRandomPassword generates a random password of a given length.
//
// Every character is picked uniformly from a fixed set of ASCII letters, digits and
// punctuation using crypto/rand. A zero length yields an empty password. A negative
// length, as well as a failure of the random source, results in an error.
func GenerateRandomPassword(length int) (string, error) {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+"

	if length < 0 {
		// make() would panic on a negative size, report it as a regular error instead
		return "", fmt.Errorf("invalid password length %d", length)
	}

	// upper bound for rand.Int is loop-invariant, allocate it once
	charsetSize := big.NewInt(int64(len(charset)))

	result := make([]byte, length)
	for i := range result {
		n, err := rand.Int(rand.Reader, charsetSize)
		if err != nil {
			return "", fmt.Errorf("failed to generate random number: %w", err)
		}
		result[i] = charset[n.Int64()]
	}
	return string(result), nil
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc