diff --git a/internal/config/config.go b/internal/config/config.go index da652e5..9d9e4d6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,11 +15,12 @@ type Config struct { type Params map[string]any var ( - ErrParamNotFound = errors.New("not found") - ErrParamNotString = errors.New("not a string") - ErrParamNotNumber = errors.New("not a number") - ErrParamNotBool = errors.New("not a boolean") - ErrParamNotSlice = errors.New("not a slice") + ErrParamNotFound = errors.New("not found") + ErrParamNotString = errors.New("not a string") + ErrParamNotNumber = errors.New("not a number") + ErrParamNotInteger = errors.New("not an integer") + ErrParamNotBool = errors.New("not a boolean") + ErrParamNotSlice = errors.New("not a slice") ) func (p Params) GetString(key string) (string, error) { @@ -42,14 +43,97 @@ func (p Params) GetInt(key string) (int, error) { } intValue, ok := value.(int) - if !ok { - floatValue, ok := value.(float64) - if !ok { - return 0, ErrParamNotNumber - } - intValue = int(floatValue) + if ok { + return intValue, nil } - return intValue, nil + + uintValue, ok := value.(uint) + if ok { + return int(uintValue), nil + } + + byteValue, ok := value.(byte) + if ok { + return int(byteValue), nil + } + + floatValue, ok := value.(float64) + if ok { + if floatValue != math.Floor(floatValue) { + return 0, ErrParamNotInteger + } + return int(floatValue), nil + } + + return 0, ErrParamNotNumber +} + +func (p Params) GetFloat32(key string) (float32, error) { + value, ok := p[key] + if !ok { + return 0, ErrParamNotFound + } + + float32Value, ok := value.(float32) + if ok { + return float32Value, nil + } + + float64Value, ok := value.(float64) + if ok { + return float32(float64Value), nil + } + + intValue, ok := value.(int) + if ok { + return float32(intValue), nil + } + + uintValue, ok := value.(uint) + if ok { + return float32(uintValue), nil + } + + byteValue, ok := value.(byte) + if ok { + return float32(byteValue), nil + } + + return 0, ErrParamNotNumber +} + +func (p Params) GetFloat64(key string) (float64, error) { + value, ok := p[key] + if !ok { + return 0, ErrParamNotFound + } + + float64Value, ok := value.(float64) + if ok { + return float64Value, nil + } + + float32Value, ok := value.(float32) + if ok { + return float64(float32Value), nil + } + + intValue, ok := value.(int) + if ok { + return float64(intValue), nil + } + + uintValue, ok := value.(uint) + if ok { + return float64(uintValue), nil + } + + byteValue, ok := value.(byte) + if ok { + return float64(byteValue), nil + } + + return 0, ErrParamNotNumber } func (p Params) GetBool(key string) (bool, error) { @@ -93,37 +177,40 @@ func (p Params) GetIntSlice(key string) ([]int, error) { return nil, ErrParamNotFound } - interfaceSlice, ok := value.([]any) - if !ok { + v := reflect.ValueOf(value) + if v.Kind() != reflect.Slice { return nil, ErrParamNotSlice } - intSlice := make([]int, len(interfaceSlice)) - for i, v := range interfaceSlice { - - intValue, ok := v.(int) + result := make([]int, v.Len()) + for i := 0; i < v.Len(); i++ { + elem := v.Index(i).Interface() + byteValue, ok := elem.(byte) if ok { - intSlice[i] = intValue + result[i] = int(byteValue) continue } - - uintValue, ok := v.(uint) + uintValue, ok := elem.(uint) if ok { - intSlice[i] = int(uintValue) + result[i] = int(uintValue) continue } - - floatValue, ok := v.(float64) + intValue, ok := elem.(int) + if ok { + result[i] = int(intValue) + continue + } + floatValue, ok := elem.(float64) if ok { if floatValue != math.Floor(floatValue) { return nil, fmt.Errorf("element at index %d is not an integer", i) } - intSlice[i] = int(floatValue) + result[i] = int(floatValue) continue } return nil, fmt.Errorf("element at index %d is not a number", i) } - return intSlice, nil + return result, nil } func (p Params) GetByteSlice(key string) ([]byte, error) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 7e8f714..bb62941 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -74,6 +74,84 @@ func TestGoodIntParamsJSON(t *testing.T) { } } +func TestGoodFloat32ParamsJSON(t *testing.T) { + testCases := []struct { + name string + paramsJSON string + key string + expected float32 + }{ + { + name: "no decimal param", + paramsJSON: `{"key": 1}`, + key: "key", + expected: 1, + }, + { + name: "float param", + paramsJSON: `{"key": 1.23}`, + key: "key", + expected: 1.23, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + params := config.Params{} + err := json.Unmarshal([]byte(testCase.paramsJSON), ¶ms) + if err != nil { + t.Fatalf("Failed to unmarshal params JSON: %v", err) + } + value, err := params.GetFloat32(testCase.key) + if err != nil { + t.Fatalf("GetFloat32 returned error: %v", err) + } + if value != testCase.expected { + t.Fatalf("GetFloat32 got %f, expected %f", value, testCase.expected) + } + }) + } +} + +func TestGoodFloat64ParamsJSON(t *testing.T) { + testCases := []struct { + name string + paramsJSON string + key string + expected float64 + }{ + { + name: "no decimal param", + paramsJSON: `{"key": 1}`, + key: "key", + expected: 1, + }, + { + name: "float param", + paramsJSON: `{"key": 1.23}`, + key: "key", + expected: 1.23, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + params := config.Params{} + err := json.Unmarshal([]byte(testCase.paramsJSON), ¶ms) + if err != nil { + t.Fatalf("Failed to unmarshal params JSON: %v", err) + } + value, err := params.GetFloat64(testCase.key) + if err != nil { + t.Fatalf("GetFloat64 returned error: %v", err) + } + if value != testCase.expected { + t.Fatalf("GetFloat64 got %f, expected %f", value, testCase.expected) + } + }) + } +} + func TestGoodBoolParamsJSON(t *testing.T) { testCases := []struct { name string @@ -153,6 +231,12 @@ func TestGoodIntSliceParamsJSON(t *testing.T) { key: "key", expected: []int{1, 2, 3}, }, + { + name: "int array with floats", + paramsJSON: `{"key": [1.0, 2.0, 3.0]}`, + key: "key", + expected: []int{1, 2, 3}, + }, } for _, testCase := range testCases {