• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindersec / minder / 13123856909

03 Feb 2025 10:12PM UTC coverage: 57.48% (-0.03%) from 57.505%
13123856909

Pull #5387

github

web-flow
Merge c75ff4c9d into f5f00edca
Pull Request #5387: Adds check to ensure config file exists. Fixes: #4513

18132 of 31545 relevant lines covered (57.48%)

37.69 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.5
/pkg/config/utils.go
1
// SPDX-FileCopyrightText: Copyright 2023 The Minder Authors
2
// SPDX-License-Identifier: Apache-2.0
3

4
// Package config contains the configuration for the minder cli and server
5
package config
6

7
import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"time"
	"unicode"

	"github.com/spf13/pflag"
	"github.com/spf13/viper"
	"gopkg.in/yaml.v3"
)
22

23
type (
	// FlagInst creates a flag called name with default value and help text
	// usage, returning a pointer to where the parsed value is stored. Its
	// signature matches pflag.FlagSet methods such as String or Int.
	FlagInst[V any] func(name string, value V, usage string) *V

	// FlagInstShort is FlagInst for flags that also have a shorthand name.
	// Its signature matches pflag.FlagSet methods such as StringP or IntP.
	FlagInstShort[V any] func(name, shorthand string, value V, usage string) *V
)
28

29
// BindConfigFlag is a helper function that binds a configuration value to a flag.
30
//
31
// Parameters:
32
// - v: The viper.Viper object used to retrieve the configuration value.
33
// - flags: The pflag.FlagSet object used to retrieve the flag value.
34
// - viperPath: The path used to retrieve the configuration value from Viper.
35
// - cmdLineArg: The flag name used to check if the flag has been set and to retrieve its value.
36
// - help: The help text for the flag.
37
// - defaultValue: A default value used to determine the type of the flag (string, int, etc.).
38
// - binder: A function that creates a flag and returns a pointer to the value.
39
func BindConfigFlag[V any](
40
        v *viper.Viper,
41
        flags *pflag.FlagSet,
42
        viperPath string,
43
        cmdLineArg string,
44
        defaultValue V,
45
        help string,
46
        binder FlagInst[V],
47
) error {
82✔
48
        binder(cmdLineArg, defaultValue, help)
82✔
49
        return doViperBind[V](v, flags, viperPath, cmdLineArg, defaultValue)
82✔
50
}
82✔
51

52
// BindConfigFlagWithShort is a helper function that binds a configuration value to a flag.
53
//
54
// Parameters:
55
// - v: The viper.Viper object used to retrieve the configuration value.
56
// - flags: The pflag.FlagSet object used to retrieve the flag value.
57
// - viperPath: The path used to retrieve the configuration value from Viper.
58
// - cmdLineArg: The flag name used to check if the flag has been set and to retrieve its value.
59
// - short: The short name for the flag.
60
// - help: The help text for the flag.
61
// - defaultValue: A default value used to determine the type of the flag (string, int, etc.).
62
// - binder: A function that creates a flag and returns a pointer to the value.
63
func BindConfigFlagWithShort[V any](
64
        v *viper.Viper,
65
        flags *pflag.FlagSet,
66
        viperPath string,
67
        cmdLineArg string,
68
        short string,
69
        defaultValue V,
70
        help string,
71
        binder FlagInstShort[V],
72
) error {
22✔
73
        binder(cmdLineArg, short, defaultValue, help)
22✔
74
        return doViperBind[V](v, flags, viperPath, cmdLineArg, defaultValue)
22✔
75
}
22✔
76

77
func doViperBind[V any](
78
        v *viper.Viper,
79
        flags *pflag.FlagSet,
80
        viperPath string,
81
        cmdLineArg string,
82
        defaultValue V,
83
) error {
104✔
84
        v.SetDefault(viperPath, defaultValue)
104✔
85
        if err := v.BindPFlag(viperPath, flags.Lookup(cmdLineArg)); err != nil {
104✔
86
                return fmt.Errorf("failed to bind flag %s to viper path %s: %w", cmdLineArg, viperPath, err)
×
87
        }
×
88

89
        return nil
104✔
90
}
91

92
// GetConfigFileData returns the data from the given configuration file.
93
func GetConfigFileData(cfgFilePath string) (interface{}, error) {
2✔
94
        cfgFileBytes, err := os.ReadFile(filepath.Clean(cfgFilePath))
2✔
95
        if err != nil {
2✔
96
                return nil, err
×
97
        }
×
98

99
        var cfgFileData interface{}
2✔
100
        err = yaml.Unmarshal(cfgFileBytes, &cfgFileData)
2✔
101
        if err != nil {
2✔
102
                return nil, err
×
103
        }
×
104

105
        return cfgFileData, nil
2✔
106
}
107

108
// GetRelevantCfgPath returns the cleaned form of the first entry in paths that
// names an existing file which is not a directory.
//
// Empty entries are skipped, as are entries for which os.Stat fails (missing
// files, permission errors, ...). It returns "" when no entry qualifies.
func GetRelevantCfgPath(paths []string) string {
	for _, candidate := range paths {
		if candidate == "" {
			continue
		}

		cleaned := filepath.Clean(candidate)
		info, statErr := os.Stat(cleaned)
		if statErr != nil || info.IsDir() {
			continue
		}
		return cleaned
	}
	return ""
}
123

124
// GetKeysWithNullValueFromYAML returns a list of paths to null values in the given configuration data.
//
// data is expected to be a generic value produced by a YAML (or JSON) decoder:
// map[interface{}]interface{} or map[string]interface{} for mappings and
// []interface{} for sequences. Any other value is treated as a scalar and
// contributes nothing. currentPath is the path of data itself. Map entries
// extend it with ".key" and sequence items with "[index]". A nil map key is
// rendered as "null", because "X.<nil>" is not a valid path.
//
// Map entries are visited in order of their rendered path, so the result is
// deterministic despite Go's randomized map iteration order. The result is nil
// when no null values are found.
func GetKeysWithNullValueFromYAML(data interface{}, currentPath string) []string {
	switch v := data.(type) {
	// gopkg yaml.v2 unmarshals YAML maps into map[interface{}]interface{}.
	// gopkg yaml.v3 unmarshals YAML maps into map[string]interface{} or map[interface{}]interface{}.
	case map[interface{}]interface{}:
		entries := make([]yamlPathEntry, 0, len(v))
		for key, value := range v {
			path := fmt.Sprintf("%s.null", currentPath) // X.<nil> is not a valid path
			if key != nil {
				path = fmt.Sprintf("%s.%v", currentPath, key)
			}
			entries = append(entries, yamlPathEntry{path: path, value: value})
		}
		return nullPathsInEntries(entries)

	case map[string]interface{}:
		entries := make([]yamlPathEntry, 0, len(v))
		for key, value := range v {
			entries = append(entries, yamlPathEntry{path: fmt.Sprintf("%s.%v", currentPath, key), value: value})
		}
		return nullPathsInEntries(entries)

	case []interface{}:
		var keysWithNullValue []string
		for i, item := range v {
			keysWithNullValue = appendNullPaths(keysWithNullValue, fmt.Sprintf("%s[%d]", currentPath, i), item)
		}
		return keysWithNullValue
	}

	return nil
}

// yamlPathEntry pairs a mapping value with the path it is reported under.
type yamlPathEntry struct {
	path  string
	value interface{}
}

// nullPathsInEntries sorts entries by path and collects the null paths found in them.
func nullPathsInEntries(entries []yamlPathEntry) []string {
	sort.Slice(entries, func(i, j int) bool { return entries[i].path < entries[j].path })

	var keysWithNullValue []string
	for _, e := range entries {
		keysWithNullValue = appendNullPaths(keysWithNullValue, e.path, e.value)
	}
	return keysWithNullValue
}

// appendNullPaths appends path when value is nil; otherwise it appends the null paths nested inside value.
func appendNullPaths(dst []string, path string, value interface{}) []string {
	if value == nil {
		return append(dst, path)
	}
	return append(dst, GetKeysWithNullValueFromYAML(value, path)...)
}
168

169
// ReadConfigFromViper reads the configuration from the given Viper instance.
170
// This will return the already-parsed and validated configuration, or an error.
171
func ReadConfigFromViper[CFG any](v *viper.Viper) (*CFG, error) {
15✔
172
        var cfg CFG
15✔
173
        if err := v.Unmarshal(&cfg); err != nil {
15✔
174
                return nil, err
×
175
        }
×
176
        return &cfg, nil
15✔
177
}
178

179
// SetViperStructDefaults recursively sets the viper default values for the given struct.
180
//
181
// Per https://github.com/spf13/viper/issues/188#issuecomment-255519149, and
182
// https://github.com/spf13/viper/issues/761, we need to call viper.SetDefault() for each
183
// field in the struct to be able to use env var overrides.  This also lets us use the
184
// struct as the source of default values, so yay?
185
func SetViperStructDefaults(v *viper.Viper, prefix string, s any) {
191✔
186
        structType := reflect.TypeOf(s)
191✔
187

191✔
188
        for i := 0; i < structType.NumField(); i++ {
841✔
189
                field := structType.Field(i)
650✔
190
                if unicode.IsLower([]rune(field.Name)[0]) {
650✔
191
                        // Skip private fields
×
192
                        continue
×
193
                }
194
                if field.Tag.Get("mapstructure") == "" {
650✔
195
                        // Error, need a tag
×
196
                        panic(fmt.Sprintf("Untagged config struct field %q", field.Name))
×
197
                }
198

199
                var valueName string
650✔
200
                // Check if the tag is "squash" and if so, don't add the field name to the prefix
650✔
201
                if field.Tag.Get("mapstructure") == ",squash" {
665✔
202
                        if strings.HasSuffix(prefix, ".") {
30✔
203
                                valueName = strings.ToLower(prefix[:len(prefix)-1])
15✔
204
                        } else {
15✔
205
                                valueName = strings.ToLower(prefix)
×
206
                        }
×
207
                } else {
635✔
208
                        valueName = strings.ToLower(prefix + field.Tag.Get("mapstructure"))
635✔
209
                }
635✔
210
                fieldType := field.Type
650✔
211

650✔
212
                // Extract a default value the `default` struct tag
650✔
213
                // we don't support all value types yet, but we can add them as needed
650✔
214
                value := field.Tag.Get("default")
650✔
215

650✔
216
                // Dereference one level of pointers, if present
650✔
217
                if fieldType.Kind() == reflect.Ptr {
668✔
218
                        fieldType = fieldType.Elem()
18✔
219
                }
18✔
220

221
                if fieldType.Kind() == reflect.Struct {
834✔
222
                        SetViperStructDefaults(v, valueName+".", reflect.Zero(fieldType).Interface())
184✔
223
                        if _, ok := field.Tag.Lookup("default"); ok {
191✔
224
                                overrideViperStructDefaults(v, valueName, value)
7✔
225
                        }
7✔
226
                        continue
184✔
227
                }
228

229
                defaultValue := getDefaultValue(field, value, valueName)
466✔
230
                if err := v.BindEnv(strings.ToUpper(valueName)); err != nil {
466✔
231
                        panic(fmt.Sprintf("Failed to bind %q to env var: %v", valueName, err))
×
232
                }
233
                v.SetDefault(valueName, defaultValue)
466✔
234
        }
235
}
236

237
func overrideViperStructDefaults(v *viper.Viper, prefix string, newDefaults string) {
7✔
238
        overrides := map[string]any{}
7✔
239
        if err := json.Unmarshal([]byte(newDefaults), &overrides); err != nil {
7✔
240
                panic(fmt.Sprintf("Failed to parse overrides in %q: %v", prefix, err))
×
241
        }
242

243
        for key, value := range overrides {
14✔
244
                // TODO: we don't do any fancy type checking here, so this could blow up later.
7✔
245
                // I expect it will blow up at config-parse time, which should be earlier enough.
7✔
246
                v.SetDefault(prefix+"."+key, value)
7✔
247
        }
7✔
248
}
249

250
// getDefaultValueForInt64 parses the `default` tag text of a field whose kind is int64.
//
// value is first parsed as an integer with base 0, so prefixes such as "0x"
// are honored, and the result is an int64. If that fails, value is parsed as a
// time.Duration (e.g. "5m"), because time.Duration also has kind int64. If both
// parses fail, the integer parse error is returned.
func getDefaultValueForInt64(value string) (any, error) {
	// bitSize 64 (not 0, the platform int size) so int64 fields accept their
	// full range on 32-bit platforms as well.
	n, intErr := strconv.ParseInt(value, 0, 64)
	if intErr == nil {
		return n, nil
	}

	// Try to parse it as a time.Duration
	if d, durErr := time.ParseDuration(value); durErr == nil {
		return d, nil
	}

	// Return the original error, not time.ParseDuration's error
	return nil, intErr
}
269

270
func getDefaultValue(field reflect.StructField, value string, valueName string) any {
466✔
271
        defaultValue := reflect.Zero(field.Type).Interface()
466✔
272
        var err error // We handle errors at the end of the switch
466✔
273
        //nolint:golint,exhaustive
466✔
274
        switch field.Type.Kind() {
466✔
275
        case reflect.String:
302✔
276
                defaultValue = value
302✔
277
        case reflect.Int64:
51✔
278
                defaultValue, err = getDefaultValueForInt64(value)
51✔
279
        case reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int,
280
                reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint:
41✔
281
                defaultValue, err = strconv.ParseInt(value, 0, 0)
41✔
282
        case reflect.Float64:
3✔
283
                defaultValue, err = strconv.ParseFloat(value, 64)
3✔
284
        case reflect.Bool:
42✔
285
                defaultValue, err = strconv.ParseBool(value)
42✔
286
        case reflect.Slice:
24✔
287
                defaultValue = nil
24✔
288
        case reflect.Map:
3✔
289
                defaultValue = nil
3✔
290
        default:
×
291
                err = fmt.Errorf("unhandled type %s", field.Type)
×
292
        }
293
        if err != nil {
466✔
294
                // This is effectively a compile-time error, so exit early
×
295
                panic(fmt.Sprintf("Bad value for field %q (%s): %q", valueName, field.Type, err))
×
296
        }
297
        return defaultValue
466✔
298
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc