• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

umputun / tg-spam / 15672803049

16 Jun 2025 05:42AM UTC coverage: 79.352% (-2.1%) from 81.499%
15672803049

Pull #294

github

umputun
Add CLI override functionality for auth credentials in database mode

- Created applyCLIOverrides function to handle selective CLI parameter overrides
- Currently handles --server.auth and --server.auth-hash overrides
- Only applies overrides when values differ from defaults
- Auth hash takes precedence over password when both are provided
- Added comprehensive unit tests covering all override scenarios
- Function is extensible for future CLI override needs (documented in comments)

This fixes the issue where users couldn't change auth credentials when using
database configuration mode (--confdb), as the save-config command would
overwrite all settings rather than just the auth credentials.
Pull Request #294: Implement database configuration support

891 of 1298 new or added lines in 9 files covered. (68.64%)

174 existing lines in 4 files now uncovered.

5734 of 7226 relevant lines covered (79.35%)

57.45 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.71
/lib/tgspam/detector.go
1
package tgspam
2

3
import (
4
        "bufio"
5
        "bytes"
6
        "context"
7
        "encoding/json"
8
        "fmt"
9
        "io"
10
        "iter"
11
        "log"
12
        "math"
13
        "net/http"
14
        "sort"
15
        "strconv"
16
        "strings"
17
        "sync"
18
        "time"
19
        "unicode"
20

21
        "github.com/forPelevin/gomoji"
22

23
        "github.com/umputun/tg-spam/lib/approved"
24
        "github.com/umputun/tg-spam/lib/spamcheck"
25
        "github.com/umputun/tg-spam/lib/tgspam/plugin"
26
)
27

28
//go:generate moq --out mocks/sample_updater.go --pkg mocks --skip-ensure --with-resets . SampleUpdater
29
//go:generate moq --out mocks/http_client.go --pkg mocks --skip-ensure --with-resets . HTTPClient
30
//go:generate moq --out mocks/user_storage.go --pkg mocks --skip-ensure --with-resets . UserStorage
31
//go:generate moq --out mocks/lua_plugin_engine.go --pkg mocks --skip-ensure --with-resets . LuaPluginEngine
32

33
// Detector is a spam detector, thread-safe.
34
// It uses a set of checks to determine if a message is spam, and also keeps a list of approved users.
// Mutable state (samples, stop words, excluded tokens, approved users) is guarded by lock.
35
type Detector struct {
36
        Config // embedded detector parameters; fields are accessed directly, e.g. d.MinMsgLen
37
        classifier     classifier // classifier trained by LoadSamples and updated by updateSample/removeSample
38
        openaiChecker  *openAIChecker // optional, set by WithOpenAIChecker; nil disables the openai check
39
        metaChecks     []MetaCheck // checks added by WithMetaChecks, run on every Check call
40
        luaChecks      []plugin.Check // separate field for Lua plugin checks
41
        tokenizedSpam  []map[string]int // token frequencies of spam samples, used by the similarity check
42
        approvedUsers  map[string]approved.UserInfo // approved users keyed by user ID
43
        stopWords      []string // lower-cased stop words (phrases) loaded by LoadStopWords
44
        excludedTokens map[string]struct{} // lower-cased tokens ignored by tokenize
45
        luaEngine      LuaPluginEngine // set by WithLuaEngine, closed and cleared by Reset
46

47
        spamSamplesUpd SampleUpdater // set by WithSpamUpdater; nil makes UpdateSpam/RemoveSpam a no-op
48
        hamSamplesUpd  SampleUpdater // set by WithHamUpdater; nil makes UpdateHam/RemoveHam a no-op
49
        userStorage    UserStorage // optional storage of approved users, set by WithUserStorage
50

51
        // history of recent messages to keep in memory
52
        // can be passed to checkers supporting history
53
        hamHistory  *spamcheck.LastRequests
54
        spamHistory *spamcheck.LastRequests
55

56
        lock sync.RWMutex // guards the mutable fields above
57
}
58

59
// Config is a set of parameters for Detector.
// It is embedded into Detector; NewDetector may adjust FirstMessageOnly and FirstMessagesCount
// on the embedded copy to keep them consistent with each other.
60
type Config struct {
61
        SimilarityThreshold float64       // threshold for spam similarity, 0.0 - 1.0; the similarity check runs only if > 0
62
        MinMsgLen           int           // minimum message length (in runes); shorter messages skip the similarity and classifier checks
63
        MaxAllowedEmoji     int           // maximum number of emojis allowed in a message; the emoji check runs only if >= 0
64
        CasAPI              string        // CAS API URL; the CAS check runs only if non-empty
65
        CasUserAgent        string        // CAS API User-Agent header value, set only if non-empty
66
        FirstMessageOnly    bool          // if true, only the first message from a user is checked; forced to true if FirstMessagesCount > 0
67
        FirstMessagesCount  int           // number of first messages to check for spam; set to 1 if FirstMessageOnly is true and this is 0
68
        HTTPClient          HTTPClient    // http client to use for requests
69
        MinSpamProbability  float64       // minimum spam probability to consider a message spam with classifier, if 0 - ignored
70
        OpenAIVeto          bool          // if true, openai will be used to veto spam messages, otherwise it will be used to veto ham messages
71
        OpenAIHistorySize   int           // number of recent ham messages passed to openai; used only if HistorySize is also > 0
72
        MultiLangWords      int           // if > 0, enables the check for multi-lingual words
73
        StorageTimeout      time.Duration // timeout for storage operations, if not set - no timeout
74

75
        LuaPlugins struct {
76
                Enabled        bool     // if true, enable Lua plugins; plugins are loaded only if PluginsDir is also set
77
                PluginsDir     string   // directory with Lua plugins
78
                EnabledPlugins []string // list of enabled plugins (by name, without .lua extension); if empty, all loaded plugins are used
79
                DynamicReload  bool     // if true, enable dynamic reloading of Lua plugins when files change
80
        }
81

82
        AbnormalSpacing struct {
83
                Enabled                 bool    // if true, enable check for abnormal spacing
84
                MinWordsCount           int     // the minimum number of words in the message to be considered
85
                ShortWordLen            int     // the length of the word to be considered short (in rune characters)
86
                ShortWordRatioThreshold float64 // the ratio of short words to all words in the message
87
                SpaceRatioThreshold     float64 // the ratio of spaces to all characters in the message
88
        }
89
        HistorySize int // history of recent messages to keep in memory, used for both ham and spam histories
90
}
91

92
// SampleUpdater is an interface for updating spam/ham samples on the fly.
// Implementations are attached with WithSpamUpdater and WithHamUpdater.
93
type SampleUpdater interface {
94
        Append(msg string) error        // append a message to the samples storage
95
        Remove(msg string) error        // remove a message from the samples storage
96
        Reader() (io.ReadCloser, error) // return a reader for the samples storage
97
}
98

99
// UserStorage is an interface for approved users storage.
// It is attached with WithUserStorage, which also loads the stored users into memory.
// All calls receive a context created by ctxWithStoreTimeout.
100
type UserStorage interface {
101
        Read(ctx context.Context) ([]approved.UserInfo, error) // read approved users from storage
102
        Write(ctx context.Context, au approved.UserInfo) error // write approved user to storage
103
        Delete(ctx context.Context, id string) error           // delete approved user from storage
104
}
105

106
// HTTPClient is an interface for http client, satisfied by http.Client.
// It is the type of Config.HTTPClient, which lets callers substitute a mock in tests.
107
type HTTPClient interface {
108
        Do(req *http.Request) (*http.Response, error) // sends the request and returns the response
109
}
110

111
// LuaPluginEngine defines an interface for the Lua plugin system.
// WithLuaEngine uses LoadDirectory, GetCheck and GetAllChecks to register plugin checks;
// Reset calls Close on the attached engine.
112
type LuaPluginEngine interface {
113
        LoadScript(path string) error               // loads a single Lua script
114
        ReloadScript(path string) error             // reloads a single Lua script
115
        LoadDirectory(dir string) error             // loads all Lua scripts from a directory
116
        GetCheck(name string) (plugin.Check, error) // returns a specific named plugin check
117
        GetAllChecks() map[string]plugin.Check      // returns all loaded plugin checks, keyed by plugin name
118
        Close()                                     // cleans up resources
119
}
120

121
// LoadResult is a result of loading samples.
// LoadSamples fills ExcludedTokens, SpamSamples and HamSamples; LoadStopWords fills only StopWords.
122
type LoadResult struct {
123
        ExcludedTokens int // number of excluded tokens
124
        SpamSamples    int // number of spam samples
125
        HamSamples     int // number of ham samples
126
        StopWords      int // number of stop words (phrases)
127
}
128

129
// NewDetector makes a new Detector with the given config.
130
func NewDetector(p Config) *Detector {
60✔
131
        res := &Detector{
60✔
132
                Config:        p,
60✔
133
                classifier:    newClassifier(),
60✔
134
                approvedUsers: make(map[string]approved.UserInfo),
60✔
135
                tokenizedSpam: []map[string]int{},
60✔
136
                metaChecks:    []MetaCheck{},
60✔
137
                luaChecks:     []plugin.Check{},
60✔
138
                hamHistory:    spamcheck.NewLastRequests(p.HistorySize),
60✔
139
                spamHistory:   spamcheck.NewLastRequests(p.HistorySize),
60✔
140
                luaEngine:     nil, // will be set with WithLuaEngine if needed
60✔
141
        }
60✔
142
        // if FirstMessagesCount is set, FirstMessageOnly enforced to true.
60✔
143
        // this is to avoid confusion when FirstMessagesCount is set but FirstMessageOnly is false.
60✔
144
        // the reason for the redundant FirstMessageOnly flag is to avoid breaking api compatibility.
60✔
145
        if p.FirstMessagesCount > 0 {
65✔
146
                res.FirstMessageOnly = true
5✔
147
        }
5✔
148
        if p.FirstMessageOnly && p.FirstMessagesCount == 0 {
81✔
149
                res.FirstMessagesCount = 1 // default value for FirstMessagesCount if FirstMessageOnly is set
21✔
150
        }
21✔
151
        return res
60✔
152
}
153

154
// Check checks if a given message is spam. Returns true if spam and also returns a list of check results.
155
func (d *Detector) Check(req spamcheck.Request) (spam bool, cr []spamcheck.Response) {
154✔
156

154✔
157
        isSpamDetected := func(cr []spamcheck.Response) bool {
300✔
158
                for _, r := range cr {
311✔
159
                        if r.Spam {
220✔
160
                                return true
55✔
161
                        }
55✔
162
                }
163
                return false
91✔
164
        }
165

166
        cleanMsg := d.cleanText(req.Msg)
154✔
167
        d.lock.RLock()
154✔
168
        defer d.lock.RUnlock()
154✔
169

154✔
170
        // approved user don't need to be checked
154✔
171
        if req.UserID != "" && d.FirstMessageOnly && d.approvedUsers[req.UserID].Count >= d.FirstMessagesCount {
162✔
172
                return false, []spamcheck.Response{{Name: "pre-approved", Spam: false, Details: "user already approved"}}
8✔
173
        }
8✔
174

175
        // all the checks are performed sequentially, so we can collect all the results
176

177
        // check for stop words if any stop words are loaded
178
        if len(d.stopWords) > 0 {
169✔
179
                cr = append(cr, d.isStopWord(cleanMsg, req))
23✔
180
        }
23✔
181

182
        // check for emojis if max allowed emojis is set
183
        if d.MaxAllowedEmoji >= 0 {
191✔
184
                cr = append(cr, d.isManyEmojis(req.Msg))
45✔
185
        }
45✔
186

187
        // check for spam with meta-checks
188
        for _, mc := range d.metaChecks {
164✔
189
                cr = append(cr, mc(req))
18✔
190
        }
18✔
191

192
        // check for spam with Lua plugin checks
193
        for _, lc := range d.luaChecks {
171✔
194
                cr = append(cr, lc(req))
25✔
195
        }
25✔
196

197
        // check for spam with CAS API if CAS API URL is set
198
        if d.CasAPI != "" {
154✔
199
                cr = append(cr, d.isCasSpam(req.UserID))
8✔
200
        }
8✔
201

202
        if d.MultiLangWords > 0 {
161✔
203
                cr = append(cr, d.isMultiLang(req.Msg))
15✔
204
        }
15✔
205

206
        if d.AbnormalSpacing.Enabled {
159✔
207
                cr = append(cr, d.isAbnormalSpacing(req.Msg))
13✔
208
        }
13✔
209

210
        // check for message length exceed the minimum size, if min message length is set.
211
        // the check is done after first simple checks, because stop words and emojis can be triggered by short messages as well.
212
        if len([]rune(req.Msg)) < d.MinMsgLen {
150✔
213
                cr = append(cr, spamcheck.Response{Name: "message length", Spam: false, Details: "too short"})
4✔
214
                if isSpamDetected(cr) {
7✔
215
                        d.spamHistory.Push(req)
3✔
216
                        return true, cr // spam from the checks above
3✔
217
                }
3✔
218
                d.hamHistory.Push(req)
1✔
219
                return false, cr
1✔
220
        }
221

222
        // check for spam similarity if a similarity threshold is set and spam samples are loaded
223
        if d.SimilarityThreshold > 0 && len(d.tokenizedSpam) > 0 {
148✔
224
                cr = append(cr, d.isSpamSimilarityHigh(cleanMsg))
6✔
225
        }
6✔
226

227
        // check for spam with classifier if classifier is loaded
228
        if d.classifier.nAllDocument > 0 && d.classifier.nDocumentByClass["ham"] > 0 && d.classifier.nDocumentByClass["spam"] > 0 {
163✔
229
                cr = append(cr, d.isSpamClassified(cleanMsg))
21✔
230
        }
21✔
231

232
        spamDetected := isSpamDetected(cr)
142✔
233

142✔
234
        // we hit openai in two cases:
142✔
235
        //  - all other checks passed (ham result) and OpenAIVeto is false. In this case, openai primary used to improve false negative rate
142✔
236
        //  - one of the checks failed (spam result) and OpenAIVeto is true. In this case, openai primary used to improve false positive rate
142✔
237
        // FirstMessageOnly or FirstMessagesCount has to be set to use openai, because it's slow and expensive to run on all messages
142✔
238
        if d.openaiChecker != nil && (d.FirstMessageOnly || d.FirstMessagesCount > 0) {
148✔
239
                if !spamDetected && !d.OpenAIVeto || spamDetected && d.OpenAIVeto {
11✔
240
                        var hist []spamcheck.Request // by default, openai doesn't use history
5✔
241
                        if d.OpenAIHistorySize > 0 && d.HistorySize > 0 {
5✔
242
                                // if history size is set, we use the last N messages for openai
×
243
                                hist = d.hamHistory.Last(d.OpenAIHistorySize)
×
244
                        }
×
245
                        spam, details := d.openaiChecker.check(cleanMsg, hist)
5✔
246
                        cr = append(cr, details)
5✔
247
                        if spamDetected && details.Error != nil {
6✔
248
                                // spam detected with other checks, but openai failed. in this case, we still return spam, but log the error
1✔
249
                                log.Printf("[WARN] openai error: %v", details.Error)
1✔
250
                        } else {
5✔
251
                                log.Printf("[DEBUG] openai result: {%s}", details.String())
4✔
252
                                spamDetected = spam
4✔
253
                        }
4✔
254

255
                        // log if veto is enabled, and openai detected no spam for message that was detected as spam by other checks
256
                        if d.OpenAIVeto && !spam {
7✔
257
                                log.Printf("[DEBUG] openai vetoed ham message: %q, checks: %s", req.Msg, spamcheck.ChecksToString(cr))
2✔
258
                        }
2✔
259
                }
260
        }
261

262
        if spamDetected {
195✔
263
                d.spamHistory.Push(req)
53✔
264
                return true, cr
53✔
265
        }
53✔
266

267
        // update approved users only if it's not paranoid mode and not a check-only request
268
        if (d.FirstMessageOnly || d.FirstMessagesCount > 0) && !req.CheckOnly {
102✔
269
                ctx, cancel := d.ctxWithStoreTimeout()
13✔
270
                defer cancel()
13✔
271
                au := approved.UserInfo{
13✔
272
                        Count:     d.approvedUsers[req.UserID].Count + 1,
13✔
273
                        UserID:    req.UserID,
13✔
274
                        UserName:  req.UserName,
13✔
275
                        Timestamp: time.Now(),
13✔
276
                }
13✔
277
                d.approvedUsers[req.UserID] = au // update approved users status in memory
13✔
278
                if d.userStorage != nil {
13✔
279
                        // update approved users status in storage
×
280
                        _ = d.userStorage.Write(ctx, au) // ignore error, failed to write to storage is not critical here
×
281
                }
×
282
        }
283
        d.hamHistory.Push(req)
89✔
284
        return false, cr
89✔
285
}
286

287
// Reset resets spam samples/classifier, excluded tokens, stop words and approved users.
288
func (d *Detector) Reset() {
2✔
289
        d.lock.Lock()
2✔
290
        defer d.lock.Unlock()
2✔
291

2✔
292
        d.tokenizedSpam = []map[string]int{}
2✔
293
        d.excludedTokens = map[string]struct{}{}
2✔
294
        d.classifier.reset()
2✔
295
        d.approvedUsers = make(map[string]approved.UserInfo)
2✔
296
        d.stopWords = []string{}
2✔
297

2✔
298
        // close the Lua engine and reset Lua checks if it exists
2✔
299
        if d.luaEngine != nil {
3✔
300
                d.luaEngine.Close()
1✔
301
                d.luaEngine = nil
1✔
302
                d.luaChecks = nil
1✔
303
        }
1✔
304
}
305

306
// WithOpenAIChecker sets an openAIChecker for spam checking.
307
func (d *Detector) WithOpenAIChecker(client openAIClient, config OpenAIConfig) {
7✔
308
        d.openaiChecker = newOpenAIChecker(client, config)
7✔
309
}
7✔
310

311
// WithLuaEngine sets a Lua plugin engine and loads plugins
312
func (d *Detector) WithLuaEngine(engine LuaPluginEngine) error {
9✔
313
        d.luaEngine = engine
9✔
314

9✔
315
        if !d.LuaPlugins.Enabled || d.LuaPlugins.PluginsDir == "" {
12✔
316
                return nil
3✔
317
        }
3✔
318

319
        // load all plugins from the directory
320
        if err := d.luaEngine.LoadDirectory(d.LuaPlugins.PluginsDir); err != nil {
7✔
321
                return fmt.Errorf("failed to load Lua plugins: %w", err)
1✔
322
        }
1✔
323

324
        // register enabled plugins as Lua checks
325
        if len(d.LuaPlugins.EnabledPlugins) > 0 {
8✔
326
                for _, name := range d.LuaPlugins.EnabledPlugins {
9✔
327
                        pluginCheck, err := d.luaEngine.GetCheck(name)
6✔
328
                        if err != nil {
7✔
329
                                return fmt.Errorf("failed to get Lua check %q: %w", name, err)
1✔
330
                        }
1✔
331
                        // add to luaChecks
332
                        d.luaChecks = append(d.luaChecks, pluginCheck)
5✔
333
                }
334
        } else {
2✔
335
                // if no specific plugins are enabled, load all
2✔
336
                allChecks := d.luaEngine.GetAllChecks()
2✔
337
                for _, pluginCheck := range allChecks {
5✔
338
                        // add to luaChecks
3✔
339
                        d.luaChecks = append(d.luaChecks, pluginCheck)
3✔
340
                }
3✔
341
        }
342

343
        // set up a watcher for dynamic plugin reloading if enabled
344
        if d.LuaPlugins.DynamicReload {
5✔
345
                // we need to cast the luaEngine to a *plugin.Checker to access the watcher methods
1✔
346
                checker, ok := d.luaEngine.(*plugin.Checker)
1✔
347
                if !ok {
1✔
348
                        log.Printf("[WARN] dynamic Lua plugin reloading enabled but engine doesn't support it")
×
349
                        return nil
×
350
                }
×
351

352
                // create a watcher for the plugins directory
353
                watcher, err := plugin.NewWatcher(checker, d.LuaPlugins.PluginsDir)
1✔
354
                if err != nil {
1✔
355
                        return fmt.Errorf("failed to create watcher for Lua plugins: %w", err)
×
356
                }
×
357

358
                // set the watcher on the checker
359
                checker.SetWatcher(watcher)
1✔
360

1✔
361
                // start the watcher
1✔
362
                if err := watcher.Start(); err != nil {
1✔
363
                        return fmt.Errorf("failed to start watcher for Lua plugins: %w", err)
×
364
                }
×
365
        }
366

367
        return nil
4✔
368
}
369

370
// WithUserStorage sets a UserStorage for approved users and loads approved users from it.
371
func (d *Detector) WithUserStorage(storage UserStorage) (count int, err error) {
6✔
372
        d.lock.Lock()
6✔
373
        defer d.lock.Unlock()
6✔
374
        d.approvedUsers = make(map[string]approved.UserInfo) // reset approved users
6✔
375
        d.userStorage = storage
6✔
376

6✔
377
        ctx, cancel := d.ctxWithStoreTimeout()
6✔
378
        defer cancel()
6✔
379

6✔
380
        users, err := d.userStorage.Read(ctx)
6✔
381
        if err != nil {
6✔
382
                return 0, fmt.Errorf("failed to read approved users from storage: %w", err)
×
383
        }
×
384
        for _, user := range users {
18✔
385
                user.Count = d.FirstMessagesCount + 1 // +1 to skip first message check if count is 0
12✔
386
                d.approvedUsers[user.UserID] = user
12✔
387
        }
12✔
388
        return len(users), nil
6✔
389
}
390

391
// WithMetaChecks sets a list of meta-checkers.
392
func (d *Detector) WithMetaChecks(mc ...MetaCheck) {
1✔
393
        d.metaChecks = append(d.metaChecks, mc...)
1✔
394
}
1✔
395

396
// WithSpamUpdater sets a SampleUpdater for spam samples.
397
func (d *Detector) WithSpamUpdater(s SampleUpdater) { d.spamSamplesUpd = s }
3✔
398

399
// WithHamUpdater sets a SampleUpdater for ham samples.
400
func (d *Detector) WithHamUpdater(s SampleUpdater) { d.hamSamplesUpd = s }
2✔
401

402
// ApprovedUsers returns a list of approved users.
403
func (d *Detector) ApprovedUsers() (res []approved.UserInfo) {
1✔
404
        d.lock.RLock()
1✔
405
        defer d.lock.RUnlock()
1✔
406
        res = make([]approved.UserInfo, 0, len(d.approvedUsers))
1✔
407
        for _, info := range d.approvedUsers {
4✔
408
                res = append(res, info)
3✔
409
        }
3✔
410
        sort.Slice(res, func(i, j int) bool {
4✔
411
                return res[i].Timestamp.After(res[j].Timestamp)
3✔
412
        })
3✔
413
        return res
1✔
414
}
415

416
// IsApprovedUser checks if a given user ID is approved.
417
// It uses memory cache for approved users and compares the count of messages sent by the user.
418
func (d *Detector) IsApprovedUser(userID string) bool {
9✔
419
        d.lock.RLock()
9✔
420
        defer d.lock.RUnlock()
9✔
421

9✔
422
        ui, ok := d.approvedUsers[userID]
9✔
423
        if !ok {
12✔
424
                return false
3✔
425
        }
3✔
426
        return ui.Count > d.FirstMessagesCount
6✔
427
}
428

429
// AddApprovedUser adds user IDs to the list of approved users.
430
func (d *Detector) AddApprovedUser(user approved.UserInfo) error {
5✔
431
        d.lock.Lock()
5✔
432
        defer d.lock.Unlock()
5✔
433
        ts := user.Timestamp
5✔
434
        if ts.IsZero() {
10✔
435
                ts = time.Now()
5✔
436
        }
5✔
437
        d.approvedUsers[user.UserID] = approved.UserInfo{
5✔
438
                UserID:    user.UserID,
5✔
439
                UserName:  user.UserName,
5✔
440
                Count:     d.FirstMessagesCount + 1, // +1 to skip first message check if count is 0
5✔
441
                Timestamp: ts,
5✔
442
        }
5✔
443

5✔
444
        if d.userStorage != nil {
8✔
445
                ctx, cancel := d.ctxWithStoreTimeout()
3✔
446
                defer cancel()
3✔
447
                if err := d.userStorage.Write(ctx, user); err != nil {
3✔
448
                        return fmt.Errorf("failed to write approved user %+v to storage: %w", user, err)
×
449
                }
×
450
        }
451
        return nil
5✔
452
}
453

454
// RemoveApprovedUser removes approved user for given IDs
455
func (d *Detector) RemoveApprovedUser(id string) error {
2✔
456
        d.lock.Lock()
2✔
457
        delete(d.approvedUsers, id)
2✔
458
        d.lock.Unlock()
2✔
459

2✔
460
        if d.userStorage != nil {
3✔
461
                ctx, cancel := d.ctxWithStoreTimeout()
1✔
462
                defer cancel()
1✔
463
                if err := d.userStorage.Delete(ctx, id); err != nil {
1✔
464
                        return fmt.Errorf("failed to delete approved user %s from storage: %w", id, err)
×
465
                }
×
466
        }
467
        return nil
2✔
468
}
469

470
// GetLuaPluginNames returns the list of available Lua plugin names.
471
func (d *Detector) GetLuaPluginNames() []string {
5✔
472
        d.lock.RLock()
5✔
473
        defer d.lock.RUnlock()
5✔
474

5✔
475
        if d.luaEngine == nil || !d.LuaPlugins.Enabled {
7✔
476
                return []string{}
2✔
477
        }
2✔
478

479
        allChecks := d.luaEngine.GetAllChecks()
3✔
480
        result := make([]string, 0, len(allChecks))
3✔
481

3✔
482
        for name := range allChecks {
8✔
483
                result = append(result, name)
5✔
484
        }
5✔
485

486
        // sort the result for consistent output
487
        sort.Strings(result)
3✔
488
        return result
3✔
489
}
490

491
// LoadSamples loads spam samples from a reader and updates the classifier.
492
// Reset spam, ham samples/classifier, and excluded tokens.
493
func (d *Detector) LoadSamples(exclReader io.Reader, spamReaders, hamReaders []io.Reader) (LoadResult, error) {
12✔
494
        d.lock.Lock()
12✔
495
        defer d.lock.Unlock()
12✔
496

12✔
497
        d.tokenizedSpam = []map[string]int{}
12✔
498
        d.excludedTokens = map[string]struct{}{}
12✔
499
        d.classifier.reset()
12✔
500

12✔
501
        // excluded tokens should be loaded before spam samples to exclude them from spam tokenization
12✔
502
        for t := range d.readerIterator(exclReader) {
23✔
503
                d.excludedTokens[strings.ToLower(t)] = struct{}{}
11✔
504
        }
11✔
505
        lr := LoadResult{ExcludedTokens: len(d.excludedTokens)}
12✔
506

12✔
507
        // load spam samples and update the classifier with them
12✔
508
        docs := []document{}
12✔
509
        for token := range d.readerIterator(spamReaders...) {
30✔
510
                tokenizedSpam := d.tokenize(token)
18✔
511
                d.tokenizedSpam = append(d.tokenizedSpam, tokenizedSpam) // add to list of samples
18✔
512
                tokens := make([]string, 0, len(tokenizedSpam))
18✔
513
                for token := range tokenizedSpam {
62✔
514
                        tokens = append(tokens, token)
44✔
515
                }
44✔
516
                docs = append(docs, newDocument(ClassSpam, tokens...))
18✔
517
                lr.SpamSamples++
18✔
518
        }
519

520
        // load ham samples and update the classifier with them
521
        for token := range d.readerIterator(hamReaders...) {
38✔
522
                tokenizedSpam := d.tokenize(token)
26✔
523
                tokens := make([]string, 0, len(tokenizedSpam))
26✔
524
                for token := range tokenizedSpam {
93✔
525
                        tokens = append(tokens, token)
67✔
526
                }
67✔
527
                docs = append(docs, document{spamClass: ClassHam, tokens: tokens})
26✔
528
                lr.HamSamples++
26✔
529
        }
530

531
        d.classifier.learn(docs...)
12✔
532
        return lr, nil
12✔
533
}
534

535
// LoadStopWords loads stop words from a reader. Reset stop words list before loading.
536
func (d *Detector) LoadStopWords(readers ...io.Reader) (LoadResult, error) {
16✔
537
        d.lock.Lock()
16✔
538
        defer d.lock.Unlock()
16✔
539

16✔
540
        d.stopWords = []string{}
16✔
541
        for t := range d.readerIterator(readers...) {
45✔
542
                d.stopWords = append(d.stopWords, strings.ToLower(t))
29✔
543
        }
29✔
544
        return LoadResult{StopWords: len(d.stopWords)}, nil
16✔
545
}
546

547
// UpdateSpam appends a message to the spam samples file and updates the classifier
548
func (d *Detector) UpdateSpam(msg string) error {
2✔
549
        return d.updateSample(msg, d.spamSamplesUpd, ClassSpam)
2✔
550
}
2✔
551

552
// UpdateHam appends a message to the ham samples file and updates the classifier
553
func (d *Detector) UpdateHam(msg string) error {
1✔
554
        return d.updateSample(msg, d.hamSamplesUpd, ClassHam)
1✔
555
}
1✔
556

557
// RemoveSpam removes a message from the spam samples file and updates the classifier by unlearning
558
func (d *Detector) RemoveSpam(msg string) error {
3✔
559
        return d.removeSample(msg, d.spamSamplesUpd, ClassSpam)
3✔
560
}
3✔
561

562
// RemoveHam removes a message from the ham samples file and updates the classifier by unlearning
563
func (d *Detector) RemoveHam(msg string) error {
2✔
564
        return d.removeSample(msg, d.hamSamplesUpd, ClassHam)
2✔
565
}
2✔
566

567
// updateSample appends a message to the samples store and updates the classifier
568
// doesn't reset state, update append samples
569
func (d *Detector) updateSample(msg string, upd SampleUpdater, sc spamClass) error {
3✔
570
        d.lock.Lock()
3✔
571
        defer d.lock.Unlock()
3✔
572

3✔
573
        if upd == nil {
3✔
574
                return nil
×
575
        }
×
576

577
        // write to dynamic samples storage
578
        if err := upd.Append(msg); err != nil {
3✔
579
                return fmt.Errorf("can't update %s samples: %w", sc, err)
×
580
        }
×
581

582
        // load samples and update the classifier with them
583
        docs := d.buildDocs(msg, sc)
3✔
584
        d.classifier.learn(docs...)
3✔
585

3✔
586
        // update tokenized spam samples for similarity check
3✔
587
        if sc == ClassSpam {
5✔
588
                tokenizedSpam := d.tokenize(msg)
2✔
589
                d.tokenizedSpam = append(d.tokenizedSpam, tokenizedSpam)
2✔
590
        }
2✔
591

592
        return nil
3✔
593
}
594

595
// removeSample removes a message from the spam samples file and updates the classifier by unlearning
596
func (d *Detector) removeSample(msg string, upd SampleUpdater, sc spamClass) error {
5✔
597
        d.lock.Lock()
5✔
598
        defer d.lock.Unlock()
5✔
599

5✔
600
        if upd == nil {
5✔
601
                return nil
×
602
        }
×
603

604
        // first validate that we can unlearn this sample
605
        docs := d.buildDocs(msg, sc)
5✔
606
        if err := d.classifier.unlearn(docs...); err != nil {
7✔
607
                return fmt.Errorf("can't unlearn %s samples: %w", sc, err)
2✔
608
        }
2✔
609

610
        // if unlearn succeeded, remove from storage
611
        if err := upd.Remove(msg); err != nil {
4✔
612
                // try to relearn since storage update failed
1✔
613
                d.classifier.learn(docs...)
1✔
614
                return fmt.Errorf("can't remove %s samples: %w", sc, err)
1✔
615
        }
1✔
616
        return nil
2✔
617
}
618

619
// buildDocs builds a list of classifier documents from a message
620
func (d *Detector) buildDocs(msg string, sc spamClass) []document {
9✔
621
        docs := []document{}
9✔
622
        for token := range d.readerIterator(bytes.NewBufferString(msg)) {
18✔
623
                tokenizedSample := d.tokenize(token)
9✔
624
                tokens := make([]string, 0, len(tokenizedSample))
9✔
625
                for token := range tokenizedSample {
33✔
626
                        tokens = append(tokens, token)
24✔
627
                }
24✔
628
                docs = append(docs, document{spamClass: sc, tokens: tokens})
9✔
629
        }
630
        return docs
9✔
631
}
632

633
// readerIterator parses readers and returns an iterator of data elements, each line is an element.
634
func (d *Detector) readerIterator(readers ...io.Reader) iter.Seq[string] {
69✔
635
        return func(yield func(string) bool) {
138✔
636
                for _, reader := range readers {
138✔
637
                        scanner := bufio.NewScanner(reader)
69✔
638
                        for scanner.Scan() {
181✔
639
                                line := scanner.Text()
112✔
640
                                // each line with a single element
112✔
641
                                cleanToken := strings.Trim(line, " \n\r\t")
112✔
642
                                if cleanToken != "" {
220✔
643
                                        if !yield(cleanToken) {
108✔
644
                                                return
×
645
                                        }
×
646
                                }
647
                        }
648

649
                        if err := scanner.Err(); err != nil {
69✔
650
                                log.Printf("[WARN] failed to read tokens, error=%v", err)
×
651
                        }
×
652
                }
653
        }
654
}
655

656
// tokenize takes a string and returns a map where the keys are unique words (tokens)
657
// and the values are the frequencies of those words in the string.
658
// exclude tokens representing common words.
659
func (d *Detector) tokenize(inp string) map[string]int {
87✔
660
        isExcludedToken := func(token string) bool {
397✔
661
                if _, ok := d.excludedTokens[strings.ToLower(token)]; ok {
323✔
662
                        return true
13✔
663
                }
13✔
664
                return false
297✔
665
        }
666

667
        tokenFrequency := make(map[string]int)
87✔
668
        tokens := strings.Fields(inp)
87✔
669
        for _, token := range tokens {
397✔
670
                if isExcludedToken(token) {
323✔
671
                        continue
13✔
672
                }
673
                token = cleanEmoji(token)
297✔
674
                token = strings.Trim(token, ".,!?-:;()#")
297✔
675
                token = strings.ToLower(token)
297✔
676
                if len([]rune(token)) < 3 {
319✔
677
                        continue
22✔
678
                }
679
                tokenFrequency[strings.ToLower(token)]++
275✔
680
        }
681
        return tokenFrequency
87✔
682
}
683

684
// isSpam checks if a given message is similar to any of the known bad messages
685
func (d *Detector) isSpamSimilarityHigh(msg string) spamcheck.Response {
6✔
686
        // check for spam similarity
6✔
687
        tokenizedMessage := d.tokenize(msg)
6✔
688
        maxSimilarity := 0.0
6✔
689
        for _, spam := range d.tokenizedSpam {
16✔
690
                similarity := d.cosineSimilarity(tokenizedMessage, spam)
10✔
691
                if similarity > maxSimilarity {
15✔
692
                        maxSimilarity = similarity
5✔
693
                }
5✔
694
                if similarity >= d.SimilarityThreshold {
13✔
695
                        return spamcheck.Response{Spam: true, Name: "similarity",
3✔
696
                                Details: fmt.Sprintf("%0.2f/%0.2f", maxSimilarity, d.SimilarityThreshold)}
3✔
697
                }
3✔
698
        }
699
        return spamcheck.Response{Spam: false, Name: "similarity", Details: fmt.Sprintf("%0.2f/%0.2f", maxSimilarity, d.SimilarityThreshold)}
3✔
700
}
701

702
// cosineSimilarity calculates the cosine similarity between two token frequency maps.
703
func (d *Detector) cosineSimilarity(a, b map[string]int) float64 {
10✔
704
        if len(a) == 0 || len(b) == 0 {
10✔
705
                return 0.0
×
706
        }
×
707

708
        dotProduct := 0      // sum of product of corresponding frequencies
10✔
709
        normA, normB := 0, 0 // square root of sum of squares of frequencies
10✔
710

10✔
711
        for key, val := range a {
44✔
712
                dotProduct += val * b[key]
34✔
713
                normA += val * val
34✔
714
        }
34✔
715
        for _, val := range b {
36✔
716
                normB += val * val
26✔
717
        }
26✔
718

719
        if normA == 0 || normB == 0 {
10✔
720
                return 0.0
×
721
        }
×
722

723
        // cosine similarity formula
724
        return float64(dotProduct) / (math.Sqrt(float64(normA)) * math.Sqrt(float64(normB)))
10✔
725
}
726

727
// isCasSpam checks if a given user ID is a spammer with CAS API.
728
func (d *Detector) isCasSpam(msgID string) spamcheck.Response {
8✔
729
        if msgID == "" {
9✔
730
                return spamcheck.Response{Spam: false, Name: "cas", Details: "check disabled"}
1✔
731
        }
1✔
732
        if _, err := strconv.ParseInt(msgID, 10, 64); err != nil {
7✔
733
                return spamcheck.Response{Spam: false, Name: "cas", Details: fmt.Sprintf("invalid user id %q", msgID)}
×
734
        }
×
735
        reqURL := fmt.Sprintf("%s/check?user_id=%s", d.CasAPI, msgID)
7✔
736
        req, err := http.NewRequest("GET", reqURL, http.NoBody)
7✔
737
        if err != nil {
7✔
738
                return spamcheck.Response{Spam: false, Name: "cas", Details: fmt.Sprintf("failed to make request %s: %v", reqURL, err)}
×
739
        }
×
740

741
        if d.CasUserAgent != "" {
8✔
742
                req.Header.Set("User-Agent", d.CasUserAgent)
1✔
743
        }
1✔
744

745
        resp, err := d.HTTPClient.Do(req)
7✔
746
        if err != nil {
7✔
747
                return spamcheck.Response{Spam: false, Name: "cas", Details: fmt.Sprintf("ffailed to send request %s: %v", reqURL, err)}
×
748
        }
×
749
        defer resp.Body.Close()
7✔
750

7✔
751
        respData := struct {
7✔
752
                OK          bool   `json:"ok"` // ok means user is a spammer
7✔
753
                Description string `json:"description"`
7✔
754
        }{}
7✔
755

7✔
756
        if err := json.NewDecoder(resp.Body).Decode(&respData); err != nil {
7✔
757
                return spamcheck.Response{Spam: false, Name: "cas", Details: fmt.Sprintf("failed to parse response from %s: %v", reqURL, err)}
×
758
        }
×
759
        respData.Description = strings.ToLower(respData.Description)
7✔
760
        respData.Description = strings.TrimSuffix(respData.Description, ".")
7✔
761

7✔
762
        if respData.OK {
9✔
763
                // may return empty description on detected spam
2✔
764
                if respData.Description == "" {
3✔
765
                        respData.Description = "spam detected"
1✔
766
                }
1✔
767
                return spamcheck.Response{Name: "cas", Spam: true, Details: respData.Description}
2✔
768
        }
769
        details := respData.Description
5✔
770
        if details == "" {
5✔
771
                details = "not found"
×
772
        }
×
773
        return spamcheck.Response{Name: "cas", Spam: false, Details: details}
5✔
774
}
775

776
// isSpamClassified classify tokens from a document
777
func (d *Detector) isSpamClassified(msg string) spamcheck.Response {
21✔
778
        tm := d.tokenize(msg)
21✔
779
        tokens := make([]string, 0, len(tm))
21✔
780
        for token := range tm {
121✔
781
                tokens = append(tokens, token)
100✔
782
        }
100✔
783
        class, prob, certain := d.classifier.classify(tokens...)
21✔
784
        isSpam := class == ClassSpam && certain && (d.MinSpamProbability == 0 || prob >= d.MinSpamProbability)
21✔
785

21✔
786
        // handle NaN or infinite probability values
21✔
787
        probStr := "0.00"
21✔
788
        if !math.IsNaN(prob) && !math.IsInf(prob, 0) {
42✔
789
                probStr = fmt.Sprintf("%.2f", prob)
21✔
790
        }
21✔
791

792
        return spamcheck.Response{Name: "classifier", Spam: isSpam,
21✔
793
                Details: fmt.Sprintf("probability of %s: %s%%", class, probStr)}
21✔
794
}
795

796
// isStopWord checks if a given message or username contains any of the stop words.
797
func (d *Detector) isStopWord(msg string, req spamcheck.Request) spamcheck.Response {
23✔
798
        // check message text
23✔
799
        cleanMsg := cleanEmoji(strings.ToLower(msg))
23✔
800
        for _, word := range d.stopWords { // stop words are already lowercased
69✔
801
                if strings.Contains(cleanMsg, strings.ToLower(word)) {
59✔
802
                        return spamcheck.Response{Name: "stopword", Spam: true, Details: word}
13✔
803
                }
13✔
804
        }
805

806
        // check username and user id if they are not empty for stop words
807
        names := []string{}
10✔
808
        if req.UserName != "" {
15✔
809
                names = append(names, req.UserName)
5✔
810
        }
5✔
811
        if req.UserID != "" {
17✔
812
                names = append(names, req.UserID)
7✔
813
        }
7✔
814
        for _, name := range names {
20✔
815
                for _, word := range d.stopWords {
44✔
816
                        if strings.Contains(strings.ToLower(name), strings.ToLower(word)) {
37✔
817
                                return spamcheck.Response{Name: "stopword", Spam: true, Details: word}
3✔
818
                        }
3✔
819
                }
820
        }
821

822
        return spamcheck.Response{Name: "stopword", Spam: false, Details: "not found"}
7✔
823
}
824

825
// isManyEmojis checks if a given message contains more than MaxAllowedEmoji emojis.
826
func (d *Detector) isManyEmojis(msg string) spamcheck.Response {
45✔
827
        count := countEmoji(msg)
45✔
828
        return spamcheck.Response{Name: "emoji", Spam: count > d.MaxAllowedEmoji, Details: fmt.Sprintf("%d/%d", count, d.MaxAllowedEmoji)}
45✔
829
}
45✔
830

831
// isMultiLang checks if a given message contains more than MultiLangWords multi-lingual words.
832
func (d *Detector) isMultiLang(msg string) spamcheck.Response {
15✔
833
        isMultiLingual := func(word string) bool {
173✔
834
                scripts := make(map[string]bool)
158✔
835
                for _, r := range word {
820✔
836
                        if r == 'i' || unicode.IsSpace(r) || unicode.IsNumber(r) { // skip 'i' (common in many langs) and spaces
703✔
837
                                continue
41✔
838
                        }
839

840
                        scriptFound := false
621✔
841
                        for name, table := range unicode.Scripts {
54,643✔
842
                                if unicode.Is(table, r) {
54,643✔
843
                                        if name != "Common" && name != "Inherited" {
1,187✔
844
                                                scripts[name] = true
566✔
845
                                                if len(scripts) > 1 {
624✔
846
                                                        return true
58✔
847
                                                }
58✔
848
                                                scriptFound = true
508✔
849
                                        }
850
                                        break
563✔
851
                                }
852
                        }
853

854
                        // if no specific script was found, it might be a symbol or punctuation
855
                        if !scriptFound {
618✔
856
                                // check for mathematical alphanumeric symbols and letterlike symbols
55✔
857
                                if unicode.In(r, unicode.Other_Math, unicode.Other_Alphabetic) ||
55✔
858
                                        (r >= '\U0001D400' && r <= '\U0001D7FF') || // mathematical Alphanumeric Symbols
55✔
859
                                        (r >= '\u2100' && r <= '\u214F') { // letterlike Symbols
65✔
860
                                        scripts["Mathematical"] = true
10✔
861
                                        if len(scripts) > 1 {
15✔
862
                                                return true
5✔
863
                                        }
5✔
864
                                } else if !unicode.IsPunct(r) && !unicode.IsSymbol(r) {
46✔
865
                                        // if it's not punctuation or a symbol, count it as "Other"
1✔
866
                                        scripts["Other"] = true
1✔
867
                                        if len(scripts) > 1 {
1✔
UNCOV
868
                                                return true
×
UNCOV
869
                                        }
×
870
                                }
871
                        }
872
                }
873
                return false
95✔
874
        }
875

876
        count := 0
15✔
877
        words := strings.FieldsFunc(msg, func(r rune) bool {
1,047✔
878
                return unicode.IsSpace(r) || r == '-'
1,032✔
879
        })
1,032✔
880
        for _, word := range words {
173✔
881
                if isMultiLingual(word) {
221✔
882
                        count++
63✔
883
                }
63✔
884
        }
885
        if count >= d.MultiLangWords {
22✔
886
                return spamcheck.Response{Name: "multi-lingual", Spam: true, Details: fmt.Sprintf("%d/%d", count, d.MultiLangWords)}
7✔
887
        }
7✔
888
        return spamcheck.Response{Name: "multi-lingual", Spam: false, Details: fmt.Sprintf("%d/%d", count, d.MultiLangWords)}
8✔
889
}
890

891
// isAbnormalSpacing detects abnormal spacing patterns used to evade filters
892
// things like this: "w o r d s p a c i n g some thing he re blah blah"
893
func (d *Detector) isAbnormalSpacing(msg string) spamcheck.Response {
13✔
894
        text := strings.ToUpper(msg)
13✔
895

13✔
896
        // quick check for empty or very short text
13✔
897
        if len(text) < 10 {
15✔
898
                return spamcheck.Response{
2✔
899
                        Name:    "word-spacing",
2✔
900
                        Spam:    false,
2✔
901
                        Details: "too short",
2✔
902
                }
2✔
903
        }
2✔
904

905
        words := strings.Fields(text)
11✔
906
        // check for minimum number of words
11✔
907
        if len(words) < d.AbnormalSpacing.MinWordsCount {
12✔
908
                return spamcheck.Response{
1✔
909
                        Name:    "word-spacing",
1✔
910
                        Spam:    false,
1✔
911
                        Details: fmt.Sprintf("too few words (%d)", len(words)),
1✔
912
                }
1✔
913
        }
1✔
914

915
        // count letters and spaces in original text
916
        var totalChars, spaces int
10✔
917
        for _, r := range text {
1,339✔
918
                if unicode.IsLetter(r) {
2,324✔
919
                        totalChars++
995✔
920
                } else if unicode.IsSpace(r) {
1,640✔
921
                        spaces++
311✔
922
                }
311✔
923
        }
924

925
        // look for suspicious word lengths and spacing patterns
926
        shortWords := 0
10✔
927
        if d.AbnormalSpacing.ShortWordLen > 0 { // if ShortWordLen is 0, skip short word detection
19✔
928
                for _, word := range words {
287✔
929
                        wordRunes := []rune(word)
278✔
930
                        if len(wordRunes) <= d.AbnormalSpacing.ShortWordLen && len(wordRunes) > 0 {
443✔
931
                                shortWords++
165✔
932
                        }
165✔
933
                }
934
        }
935

936
        // safety check
937
        if spaces == 0 || totalChars == 0 {
10✔
UNCOV
938
                return spamcheck.Response{
×
UNCOV
939
                        Name:    "word-spacing",
×
UNCOV
940
                        Spam:    false,
×
UNCOV
941
                        Details: "no spaces or letters",
×
UNCOV
942
                }
×
UNCOV
943
        }
×
944

945
        // calculate ratios
946
        spaceRatio := float64(spaces) / float64(totalChars)
10✔
947
        shortWordRatio := float64(shortWords) / float64(len(words))
10✔
948
        if shortWordRatio > d.AbnormalSpacing.ShortWordRatioThreshold || spaceRatio > d.AbnormalSpacing.SpaceRatioThreshold {
16✔
949
                return spamcheck.Response{
6✔
950
                        Name:    "word-spacing",
6✔
951
                        Spam:    true,
6✔
952
                        Details: fmt.Sprintf("abnormal (ratio: %.2f, short: %.0f%%)", spaceRatio, shortWordRatio*100),
6✔
953
                }
6✔
954
        }
6✔
955

956
        return spamcheck.Response{
4✔
957
                Name:    "word-spacing",
4✔
958
                Spam:    false,
4✔
959
                Details: fmt.Sprintf("normal (ratio: %.2f, short: %.0f%%)", spaceRatio, shortWordRatio*100),
4✔
960
        }
4✔
961
}
962

963
// cleanText removes control and format characters from a given text
964
func (d *Detector) cleanText(text string) string {
162✔
965
        var result strings.Builder
162✔
966
        result.Grow(len(text))
162✔
967
        for _, r := range text {
16,584✔
968
                // skip control and format characters
16,422✔
969
                if unicode.Is(unicode.Cc, r) || unicode.Is(unicode.Cf, r) {
16,467✔
970
                        continue
45✔
971
                }
972
                // skip specific ranges of invisible characters
973
                if (r >= 0x200B && r <= 0x200F) || (r >= 0x2060 && r <= 0x206F) {
16,377✔
UNCOV
974
                        continue
×
975
                }
976
                result.WriteRune(r)
16,377✔
977
        }
978
        return result.String()
162✔
979
}
980

981
func (d *Detector) ctxWithStoreTimeout() (context.Context, context.CancelFunc) {
23✔
982
        if d.StorageTimeout == 0 {
46✔
983
                return context.Background(), func() {}
46✔
984
        }
UNCOV
985
        return context.WithTimeout(context.Background(), d.StorageTimeout)
×
986
}
987

988
func cleanEmoji(s string) string {
328✔
989
        return gomoji.RemoveEmojis(s)
328✔
990
}
328✔
991

992
func countEmoji(s string) int {
56✔
993
        return len(gomoji.CollectAll(s))
56✔
994
}
56✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc