diff --git a/.promu-cgo.yml b/.promu-cgo.yml index 8501906..f86aa62 100644 --- a/.promu-cgo.yml +++ b/.promu-cgo.yml @@ -13,7 +13,7 @@ build: path: ./cmd/ceems_lb tags: all: [osusergo, netgo, static_build] - flags: -a -race + flags: -a ldflags: | -X github.com/prometheus/common/version.Version={{.Version}} -X github.com/prometheus/common/version.Revision={{.Revision}} diff --git a/pkg/api/cli/cli_test.go b/pkg/api/cli/cli_test.go index a12b7d2..088860c 100644 --- a/pkg/api/cli/cli_test.go +++ b/pkg/api/cli/cli_test.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "path/filepath" + "syscall" "testing" "time" @@ -104,7 +105,8 @@ func TestCEEMSServerMain(t *testing.T) { --- ceems_api_server: data: - path: %s` + path: %[1]s + backup_path: %[1]s` configFile := fmt.Sprintf(configFileTmpl, dataDir) configFilePath := makeConfigFile(configFile, tmpDir) @@ -131,4 +133,8 @@ ceems_api_server: t.Errorf("Could not start stats server after %d attempts", i) } } + + // Send INT signal and wait a second to clean up server and DB + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + time.Sleep(1 * time.Second) } diff --git a/pkg/api/http/server_test.go b/pkg/api/http/server_test.go index 2179228..1f69526 100644 --- a/pkg/api/http/server_test.go +++ b/pkg/api/http/server_test.go @@ -7,15 +7,50 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "reflect" + "strings" "testing" "time" "github.com/go-kit/log" "github.com/gorilla/mux" + "github.com/mahendrapaipuri/ceems/pkg/api/base" "github.com/mahendrapaipuri/ceems/pkg/api/models" ) +type testCase struct { + name string + req string + user string + admin bool + handler func(http.ResponseWriter, *http.Request) + code int +} + +var ( + mockServerUnits = []models.Unit{ + {UUID: "1000", ClusterID: "slurm-0", ResourceManager: "slurm", Usr: "foousr"}, + {UUID: "10001", ClusterID: "os-0", ResourceManager: "openstack", Usr: "barusr"}, + } + mockServerUsage = []models.Usage{ + {Project: "foo", ClusterID: "slurm-0", ResourceManager: "slurm"}, + {Project: "bar", ClusterID: "os-0", ResourceManager: "openstack"}, + } + mockServerProjects = []models.Project{ + {Name: "foo", ClusterID: "slurm-0", ResourceManager: "slurm", Users: models.List{"foousr"}}, + {Name: "bar", ClusterID: "os-0", ResourceManager: "openstack", Users: models.List{"barusr"}}, + } + mockServerUsers = []models.User{ + {Name: "foousr", ClusterID: "slurm-0", ResourceManager: "slurm", Projects: models.List{"foo"}}, + {Name: "bar", ClusterID: "os-0", ResourceManager: "openstack", Projects: models.List{"bar"}}, + } + mockServerClusters = []models.Cluster{ + {ID: "slurm-0", Manager: "slurm"}, + {ID: "os-0", Manager: "openstack"}, + } +) + func setupServer() *CEEMSServer { logger := log.NewNopLogger() server, _, _ := NewCEEMSServer(&Config{Logger: logger}) @@ -31,188 +66,397 @@ func setupServer() *CEEMSServer { } func unitQuerier(db *sql.DB, q Query, logger log.Logger) ([]models.Unit, error) { - return []models.Unit{{UUID: "1000", Usr: "user"}, {UUID: "10001", Usr: "user"}}, nil + return mockServerUnits, nil } func usageQuerier(db *sql.DB, q Query, logger log.Logger) ([]models.Usage, error) { - return []models.Usage{{Project: "foo"}, {Project: "bar"}}, nil + return mockServerUsage, nil } func projectQuerier(db *sql.DB, q Query, logger log.Logger) ([]models.Project, error) { - return []models.Project{{Name: "foo"}, {Name: "bar"}}, nil + return mockServerProjects, nil } func userQuerier(db *sql.DB, q Query, logger log.Logger) ([]models.User, error) { - return []models.User{{Name: "foo"}, {Name: "bar"}}, nil + return mockServerUsers, nil } func clusterQuerier(db *sql.DB, q Query, logger log.Logger) ([]models.Cluster, error) { - return []models.Cluster{{ID: "slurm-0", Manager: "slurm"}, {ID: "os-0", Manager: "openstack"}}, nil + return mockServerClusters, nil } func getMockUnits( _ Query, _ log.Logger, ) ([]models.Unit, error) { - return []models.Unit{{UUID: "1000", Usr: "user"}, {UUID: "10001", Usr: "user"}}, nil + return mockServerUnits, nil } -// func getMockAdminUsers(url string, client *http.Client, logger log.Logger) ([]string, error) { -// return []string{"adm1", "adm2"}, nil -// } - -// // Test /api/projects when no user header found -// func TestAccountsHandlerNoUserHeader(t *testing.T) { -// server := setupServer() -// // Create request -// req := httptest.NewRequest(http.MethodGet, "/api/projects", nil) - -// // Start recorder -// w := httptest.NewRecorder() -// server.projects(w, req) -// res := w.Result() -// defer res.Body.Close() - -// // Get body -// data, err := io.ReadAll(res.Body) -// if err != nil { -// t.Errorf("expected error to be nil got %v", err) -// } +// Test users and users admin handlers +func TestUsersHandlers(t *testing.T) { + server := setupServer() + defer server.Shutdown(context.Background()) -// // Unmarshal byte into structs. -// var response Response -// json.Unmarshal(data, &response) + // Test cases + tests := []testCase{ + { + name: "users", + req: "/api/" + base.APIVersion + "/users?field=uuid&field=project", + user: "foousr", + admin: false, + handler: server.users, + code: 200, + }, + { + name: "users admin", + req: "/api/" + base.APIVersion + "/users/admin?project=foo", + user: "foousr", + admin: true, + handler: server.usersAdmin, + code: 200, + }, + } -// if response.Status != "error" { -// t.Errorf("expected error status got %v", response.Status) -// } -// if response.ErrorType != "user_error" { -// t.Errorf("expected user_error type got %v", response.ErrorType) -// } -// if response.Data != nil { -// t.Errorf("expected nil data got %v", response.Data) -// } -// } + for _, test := range tests { + request := httptest.NewRequest("GET", test.req, nil) + request.Header.Set("X-Grafana-User", test.user) + if test.admin { + q := url.Values{} + q.Add("user", "foousr") + request.URL.RawQuery = q.Encode() + } + + // Start recorder + w := httptest.NewRecorder() + test.handler(w, request) + res := w.Result() + defer res.Body.Close() + + // Get body + data, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("expected error to be nil got %v", err) + } + + // Unmarshal byte into structs. + var response Response[models.User] + json.Unmarshal(data, &response) + if w.Code != test.code { + t.Errorf("%s: expected status code %d, got %d", test.name, test.code, w.Code) + } + if response.Status != "success" { + t.Errorf("%s: expected success status got %v", test.name, response.Status) + } + if !reflect.DeepEqual(response.Data, mockServerUsers) { + t.Errorf("%s: expected data %#v got %#v", test.name, mockServerUsers, response.Data) + } + } +} -// Test /projects +// Test projects and projects admin handlers func TestProjectsHandler(t *testing.T) { server := setupServer() defer server.Shutdown(context.Background()) - // Create request - req := httptest.NewRequest(http.MethodGet, "/api/v1/projects", nil) - // Add user header - // req.Header.Set("X-Grafana-User", "foo") - - // Start recorder - w := httptest.NewRecorder() - server.projects(w, req) - res := w.Result() - defer res.Body.Close() + // Test cases + tests := []testCase{ + { + name: "projects", + req: "/api/" + base.APIVersion + "/projects", + user: "foousr", + admin: false, + handler: server.projects, + code: 200, + }, + { + name: "projects admin", + req: "/api/" + base.APIVersion + "/projects/admin", + user: "foousr", + admin: true, + handler: server.projectsAdmin, + code: 200, + }, + } - // Get body - data, err := io.ReadAll(res.Body) - if err != nil { - t.Errorf("expected error to be nil got %v", err) + for _, test := range tests { + request := httptest.NewRequest("GET", test.req, nil) + request.Header.Set("X-Grafana-User", test.user) + if test.admin { + q := url.Values{} + q.Add("project", "foo") + request.URL.RawQuery = q.Encode() + } + + // Start recorder + w := httptest.NewRecorder() + test.handler(w, request) + res := w.Result() + defer res.Body.Close() + + // Get body + data, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("expected error to be nil got %v", err) + } + + // Unmarshal byte into structs. + var response Response[models.Project] + json.Unmarshal(data, &response) + if w.Code != test.code { + t.Errorf("%s: expected status code %d, got %d", test.name, test.code, w.Code) + } + if response.Status != "success" { + t.Errorf("%s: expected success status got %v", test.name, response.Status) + } + if !reflect.DeepEqual(response.Data, mockServerProjects) { + t.Errorf("%s: expected data %#v got %#v", test.name, mockServerProjects, response.Data) + } } +} - // Expected result - expectedAccounts, _ := projectQuerier(server.db, Query{}, server.logger) +// Test units and units admin handlers +func TestUnitsHandler(t *testing.T) { + server := setupServer() + defer server.Shutdown(context.Background()) - // Unmarshal byte into structs. - var response Response[models.Project] - json.Unmarshal(data, &response) - if response.Status != "success" { - t.Errorf("expected success status got %v", response.Status) + // Test cases + tests := []testCase{ + { + name: "units", + req: "/api/" + base.APIVersion + "/units", + user: "foousr", + admin: false, + handler: server.units, + code: 200, + }, + { + name: "units admin", + req: "/api/" + base.APIVersion + "/units/admin", + user: "foousr", + admin: true, + handler: server.unitsAdmin, + code: 200, + }, } - if !reflect.DeepEqual(response.Data, expectedAccounts) { - t.Errorf("expected projects %#v got %#v", expectedAccounts, response.Data) + + for _, test := range tests { + request := httptest.NewRequest("GET", test.req, nil) + request.Header.Set("X-Grafana-User", test.user) + if test.admin { + q := url.Values{} + q.Add("user", "foousr") + request.URL.RawQuery = q.Encode() + } + + // Start recorder + w := httptest.NewRecorder() + test.handler(w, request) + res := w.Result() + defer res.Body.Close() + + // Get body + data, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("expected error to be nil got %v", err) + } + + // Unmarshal byte into structs. + var response Response[models.Unit] + json.Unmarshal(data, &response) + if w.Code != test.code { + t.Errorf("%s: expected status code %d, got %d", test.name, test.code, w.Code) + } + if response.Status != "success" { + t.Errorf("%s: expected success status got %v", test.name, response.Status) + } + if !reflect.DeepEqual(response.Data, mockServerUnits) { + t.Errorf("%s: expected data %#v got %#v", test.name, mockServerUnits, response.Data) + } } } -// Test /users -func TestUsersHandler(t *testing.T) { +// Test usage and usage admin handlers +func TestUsageHandlers(t *testing.T) { server := setupServer() defer server.Shutdown(context.Background()) - // Create request - req := httptest.NewRequest(http.MethodGet, "/api/v1/users", nil) - // Add user header - req.Header.Set("X-Grafana-User", "foo") - - // Start recorder - w := httptest.NewRecorder() - server.users(w, req) - res := w.Result() - defer res.Body.Close() + // Test cases + tests := []testCase{ + { + name: "current usage", + req: "/api/" + base.APIVersion + "/usage/current", + user: "foousr", + admin: false, + handler: server.usage, + code: 200, + }, + { + name: "global usage", + req: "/api/" + base.APIVersion + "/usage/global", + user: "foousr", + admin: false, + handler: server.usage, + code: 200, + }, + { + name: "current usage admin", + req: "/api/" + base.APIVersion + "/usage/current/admin", + user: "foousr", + admin: true, + handler: server.usageAdmin, + code: 200, + }, + { + name: "global usage admin", + req: "/api/" + base.APIVersion + "/usage/global/admin", + user: "foousr", + admin: true, + handler: server.usageAdmin, + code: 200, + }, + } - // Get body - data, err := io.ReadAll(res.Body) - if err != nil { - t.Errorf("expected error to be nil got %v", err) + for _, test := range tests { + request := httptest.NewRequest("GET", test.req, nil) + request.Header.Set("X-Grafana-User", test.user) + if test.admin { + q := url.Values{} + q.Add("user", "foousr") + request.URL.RawQuery = q.Encode() + } + if strings.Contains(test.name, "current") { + request = mux.SetURLVars(request, map[string]string{"mode": "current"}) + } else { + request = mux.SetURLVars(request, map[string]string{"mode": "global"}) + } + + // Start recorder + w := httptest.NewRecorder() + test.handler(w, request) + res := w.Result() + defer res.Body.Close() + + // Get body + data, err := io.ReadAll(res.Body) + if err != nil { + t.Errorf("expected error to be nil got %v", err) + } + + // Unmarshal byte into structs. + var response Response[models.Usage] + json.Unmarshal(data, &response) + if w.Code != test.code { + t.Errorf("%s: expected status code %d, got %d", test.name, test.code, w.Code) + } + if response.Status != "success" { + t.Errorf("%s: expected success status got %v", test.name, response.Status) + } + if !reflect.DeepEqual(response.Data, mockServerUsage) { + t.Errorf("%s: expected data %#v got %#v", test.name, mockServerUsage, response.Data) + } } +} - // Expected result - expectedUsers, _ := userQuerier(server.db, Query{}, server.logger) +// Test verify handler +func TestVerifyHandler(t *testing.T) { + server := setupServer() + defer server.Shutdown(context.Background()) - // Unmarshal byte into structs. - var response Response[models.User] - json.Unmarshal(data, &response) - if response.Status != "success" { - t.Errorf("expected success status got %v", response.Status) - } - if !reflect.DeepEqual(response.Data, expectedUsers) { - t.Errorf("expected users %#v got %#v", expectedUsers, response.Data) + tests := []testCase{ + { + name: "verify bad data", + req: "/api/" + base.APIVersion + "/units/verify", + user: "foousr", + admin: false, + handler: server.verifyUnitsOwnership, + code: 400, + }, + { + name: "verify forbidden", + req: "/api/" + base.APIVersion + "/units/verify?uuid=1234", + user: "foousr", + admin: false, + handler: server.verifyUnitsOwnership, + code: 403, + }, } -} -// // Test /api/units when no user header found -// func TestUnitsHandlerNoUserHeader(t *testing.T) { -// server := setupServer() -// // Create request -// req := httptest.NewRequest(http.MethodGet, "/api/units", nil) + for _, test := range tests { + request := httptest.NewRequest("GET", test.req, nil) + request.Header.Set("X-Grafana-User", test.user) -// // Start recorder -// w := httptest.NewRecorder() -// server.units(w, req) -// res := w.Result() -// defer res.Body.Close() + // Start recorder + w := httptest.NewRecorder() + test.handler(w, request) + res := w.Result() + defer res.Body.Close() -// // Get body -// data, err := io.ReadAll(res.Body) -// if err != nil { -// t.Errorf("expected error to be nil got %v", err) -// } + if w.Code != test.code { + t.Errorf("%s: expected status code %d, got %d", test.name, test.code, w.Code) + } + } +} -// // Unmarshal byte into structs. -// var response Response -// json.Unmarshal(data, &response) +// Test demo handlers +func TestDemoHandlers(t *testing.T) { + server := setupServer() + defer server.Shutdown(context.Background()) -// if response.Status != "error" { -// t.Errorf("expected error status got %v", response.Status) -// } -// if response.ErrorType != "user_error" { -// t.Errorf("expected user_error type got %v", response.ErrorType) -// } -// if response.Data != nil { -// t.Errorf("expected nil data got %v", response.Data) -// } -// } + // Test cases + tests := []testCase{ + { + name: "units demo", + req: "/api/" + base.APIVersion + "/demo/units", + user: "foousr", + admin: false, + handler: server.demo, + code: 200, + }, + { + name: "usage demo", + req: "/api/" + base.APIVersion + "/demo/usage", + user: "foousr", + admin: false, + handler: server.demo, + code: 200, + }, + } -// Test /units -func TestUnitsHandler(t *testing.T) { + for _, test := range tests { + request := httptest.NewRequest("GET", test.req, nil) + request.Header.Set("X-Grafana-User", test.user) + if strings.Contains(test.name, "units") { + request = mux.SetURLVars(request, map[string]string{"resource": "units"}) + } else { + request = mux.SetURLVars(request, map[string]string{"resource": "usage"}) + } + + // Start recorder + w := httptest.NewRecorder() + test.handler(w, request) + res := w.Result() + defer res.Body.Close() + + if w.Code != test.code { + t.Errorf("%s: expected status code %d, got %d", test.name, test.code, w.Code) + } + } +} + +// Test clusters handlers +func TestClustersHandler(t *testing.T) { server := setupServer() defer server.Shutdown(context.Background()) // Create request - req := httptest.NewRequest(http.MethodGet, "/api/v1/units", nil) + req := httptest.NewRequest(http.MethodGet, "/api/"+base.APIVersion+"/clusters/admin", nil) // Add user header currentUser := "foo" req.Header.Set("X-Grafana-User", currentUser) // Start recorder w := httptest.NewRecorder() - server.units(w, req) + server.clustersAdmin(w, req) res := w.Result() defer res.Body.Close() @@ -223,58 +467,21 @@ func TestUnitsHandler(t *testing.T) { } // Expected result - expectedUnits, _ := unitQuerier(server.db, Query{}, server.logger) + expectedClusters, _ := clusterQuerier(server.db, Query{}, server.logger) - // Unmarshal byte into structs. - var response Response[models.Unit] + // Unmarshal byte into structs + var response Response[models.Cluster] json.Unmarshal(data, &response) if response.Status != "success" { t.Errorf("expected success status got %v", response.Status) } - if !reflect.DeepEqual(expectedUnits, response.Data) { - t.Errorf("expected units %d units, got %d", len(expectedUnits), len(response.Data)) + if !reflect.DeepEqual(expectedClusters, response.Data) { + t.Errorf("expected clusters %#v clusters, got %#v", expectedClusters, response.Data) } } -// // Test /api/units when user header and impersonated user header found -// func TestUnitsHandlerWithUserHeaderAndAdmin(t *testing.T) { -// server := setupServer() -// // server.adminUsers = []string{"admin"} -// // Create request -// req := httptest.NewRequest(http.MethodGet, "/api/units", nil) -// // Add user header -// // req.Header.Set("X-Grafana-User", server.adminUsers[0]) -// req.Header.Set("X-Dashboard-User", "foo") - -// // Start recorder -// w := httptest.NewRecorder() -// server.units(w, req) -// res := w.Result() -// defer res.Body.Close() - -// // Get body -// data, err := io.ReadAll(res.Body) -// if err != nil { -// t.Errorf("expected error to be nil got %v", err) -// } - -// // Expected result -// expectedUnits, _ := getMockUnits(Query{}, server.logger) - -// // Unmarshal byte into structs. -// var response Response -// json.Unmarshal(data, &response) - -// if response.Status != "success" { -// t.Errorf("expected success status got %v", response.Status) -// } -// if !reflect.DeepEqual(response.Data, expectedUnits) { -// t.Errorf("expected %v got %v", expectedUnits, response.Data) -// } -// } - // Test /units when from/to query parameters are malformed func TestUnitsHandlerWithMalformedQueryParams(t *testing.T) { server := setupServer() @@ -403,83 +610,44 @@ func TestUnitsHandlerWithUnituuidsQueryParams(t *testing.T) { } } -// Test /usage -func TestUsageHandler(t *testing.T) { - server := setupServer() - defer server.Shutdown(context.Background()) - - // Create request - req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/current", nil) - // Need to set path variables here - req = mux.SetURLVars(req, map[string]string{"mode": "current"}) - - // Add user header - currentUser := "foo" - req.Header.Set("X-Grafana-User", currentUser) - - // Start recorder - w := httptest.NewRecorder() - server.usage(w, req) - res := w.Result() - defer res.Body.Close() - - // Get body - data, err := io.ReadAll(res.Body) - if err != nil { - t.Errorf("expected error to be nil got %v", err) - } - - // Expected result - expectedUsage, _ := usageQuerier(server.db, Query{}, server.logger) - - // Unmarshal byte into structs. - var response Response[models.Usage] - json.Unmarshal(data, &response) - - if response.Status != "success" { - t.Errorf("expected success status got %#v", response) - } - - if !reflect.DeepEqual(expectedUsage, response.Data) { - t.Errorf("expected usage %#v usage, got %#v", expectedUsage, response.Data) - } -} +// // Test /usage +// func TestUsageHandler(t *testing.T) { +// server := setupServer() +// defer server.Shutdown(context.Background()) -// Test /clusters -func TestClustersHandler(t *testing.T) { - server := setupServer() - defer server.Shutdown(context.Background()) +// // Create request +// req := httptest.NewRequest(http.MethodGet, "/api/v1/usage/current", nil) +// // Need to set path variables here +// req = mux.SetURLVars(req, map[string]string{"mode": "current"}) - // Create request - req := httptest.NewRequest(http.MethodGet, "/api/v1/clusters/admin", nil) - // Add user header - currentUser := "foo" - req.Header.Set("X-Grafana-User", currentUser) +// // Add user header +// currentUser := "foo" +// req.Header.Set("X-Grafana-User", currentUser) - // Start recorder - w := httptest.NewRecorder() - server.clustersAdmin(w, req) - res := w.Result() - defer res.Body.Close() +// // Start recorder +// w := httptest.NewRecorder() +// server.usage(w, req) +// res := w.Result() +// defer res.Body.Close() - // Get body - data, err := io.ReadAll(res.Body) - if err != nil { - t.Errorf("expected error to be nil got %v", err) - } +// // Get body +// data, err := io.ReadAll(res.Body) +// if err != nil { +// t.Errorf("expected error to be nil got %v", err) +// } - // Expected result - expectedClusters, _ := clusterQuerier(server.db, Query{}, server.logger) +// // Expected result +// expectedUsage, _ := usageQuerier(server.db, Query{}, server.logger) - // Unmarshal byte into structs - var response Response[models.Cluster] - json.Unmarshal(data, &response) +// // Unmarshal byte into structs. +// var response Response[models.Usage] +// json.Unmarshal(data, &response) - if response.Status != "success" { - t.Errorf("expected success status got %v", response.Status) - } +// if response.Status != "success" { +// t.Errorf("expected success status got %#v", response) +// } - if !reflect.DeepEqual(expectedClusters, response.Data) { - t.Errorf("expected clusters %#v clusters, got %#v", expectedClusters, response.Data) - } -} +// if !reflect.DeepEqual(expectedUsage, response.Data) { +// t.Errorf("expected usage %#v usage, got %#v", expectedUsage, response.Data) +// } +// } diff --git a/pkg/api/http/validation.go b/pkg/api/http/validation.go index 78b3df7..a1da7fa 100644 --- a/pkg/api/http/validation.go +++ b/pkg/api/http/validation.go @@ -9,33 +9,29 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/mahendrapaipuri/ceems/pkg/api/base" - ceems_db "github.com/mahendrapaipuri/ceems/pkg/api/db" "github.com/mahendrapaipuri/ceems/pkg/api/models" ) // adminUsers returns a slice of admin users fetched from DB func adminUsers(dbConn *sql.DB, logger log.Logger) []string { var users []string - for _, source := range ceems_db.AdminUsersSources { - rows, err := dbConn.Query( - fmt.Sprintf("SELECT users FROM %s WHERE source = ?", base.AdminUsersDBTableName), - source, - ) - if err != nil { - level.Error(logger).Log("msg", "Failed to query for admin users", "source", source, "err", err) + rows, err := dbConn.Query( + fmt.Sprintf("SELECT users FROM %s", base.AdminUsersDBTableName), + ) + if err != nil { + level.Error(logger).Log("msg", "Failed to query for admin users", "err", err) + return nil + } + + // Scan users rows + var usersList models.List + for rows.Next() { + if err := rows.Scan(&usersList); err != nil { + level.Error(logger).Log("msg", "Failed to scan row for admin users query", "err", err) continue } - - // Scan users rows - var usersList models.List - for rows.Next() { - if err := rows.Scan(&usersList); err != nil { - level.Error(logger).Log("msg", "Failed to scan row for admin users query", "source", source, "err", err) - continue - } - for _, user := range usersList { - users = append(users, user.(string)) - } + for _, user := range usersList { + users = append(users, user.(string)) } } return users @@ -43,11 +39,6 @@ func adminUsers(dbConn *sql.DB, logger log.Logger) []string { // VerifyOwnership returns true if user is the owner of queried units func VerifyOwnership(user string, clusterIDs []string, uuids []string, db *sql.DB, logger log.Logger) bool { - // If current user is in list of admin users, pass the check - if slices.Contains(adminUsers(db, logger), user) { - return true - } - // If the data is incomplete, forbid the request if db == nil || len(clusterIDs) == 0 || user == "" { level.Debug(logger).Log( @@ -57,6 +48,10 @@ func VerifyOwnership(user string, clusterIDs []string, uuids []string, db *sql.D return false } + // If current user is in list of admin users, pass the check + if slices.Contains(adminUsers(db, logger), user) { + return true + } level.Debug(logger). Log("msg", "UUIDs in query", "user", user, "cluster_id", strings.Join(clusterIDs, ","), "queried_uuids", strings.Join(uuids, ","), diff --git a/pkg/api/resource/default.go b/pkg/api/resource/default.go index 25321ff..54d811d 100644 --- a/pkg/api/resource/default.go +++ b/pkg/api/resource/default.go @@ -30,7 +30,7 @@ func NewDefaultResourceManager(cluster models.Cluster, logger log.Logger) (Fetch // Return empty units response func (d *defaultResourceManager) FetchUnits(start time.Time, end time.Time) ([]models.ClusterUnits, error) { - level.Info(d.logger).Log("msg", "Empty units fetched from default resource manager") + level.Info(d.logger).Log("msg", "Empty units fetched from default NoOp cluster") return []models.ClusterUnits{ { Cluster: models.Cluster{ID: "default"}, @@ -42,7 +42,7 @@ func (d *defaultResourceManager) FetchUnits(start time.Time, end time.Time) ([]m func (d *defaultResourceManager) FetchUsersProjects( currentTime time.Time, ) ([]models.ClusterUsers, []models.ClusterProjects, error) { - level.Info(d.logger).Log("msg", "Empty users and projects fetched from default resource manager") + level.Info(d.logger).Log("msg", "Empty users and projects fetched from default NoOp cluster") return []models.ClusterUsers{ { Cluster: models.Cluster{ID: "default"}, diff --git a/pkg/api/resource/manager.go b/pkg/api/resource/manager.go index 5276c06..5f0184b 100644 --- a/pkg/api/resource/manager.go +++ b/pkg/api/resource/manager.go @@ -125,7 +125,7 @@ func NewManager(logger log.Logger) (*Manager, error) { // Return an instance of default manager if len(fetchers) == 0 { level.Warn(logger).Log( - "msg", "No resource manager enabled. Using a default resource manager", + "msg", "No clusters found in config. Using a default cluster", "available_resource_managers", strings.Join(registeredManagers, ","), ) diff --git a/pkg/api/resource/manager_test.go b/pkg/api/resource/manager_test.go index 2bbded8..0251e45 100644 --- a/pkg/api/resource/manager_test.go +++ b/pkg/api/resource/manager_test.go @@ -5,10 +5,66 @@ import ( "os" "path/filepath" "testing" + "time" + "github.com/go-kit/log" + "github.com/go-kit/log/level" "github.com/mahendrapaipuri/ceems/pkg/api/base" + "github.com/mahendrapaipuri/ceems/pkg/api/models" ) +// mockResourceManager struct +type mockResourceManager struct { + logger log.Logger +} + +// NewMockResourceManager returns a new defaultResourceManager that returns empty compute units +func NewMockResourceManager(cluster models.Cluster, logger log.Logger) (Fetcher, error) { + level.Info(logger).Log("msg", "Default resource manager activated") + return &mockResourceManager{ + logger: logger, + }, nil +} + +// Return empty units response +func (d *mockResourceManager) FetchUnits(start time.Time, end time.Time) ([]models.ClusterUnits, error) { + return []models.ClusterUnits{ + { + Cluster: models.Cluster{ID: "mock"}, + Units: []models.Unit{ + { + UUID: "10000", + }, + }, + }, + }, nil +} + +// Return empty projects response +func (d *mockResourceManager) FetchUsersProjects( + currentTime time.Time, +) ([]models.ClusterUsers, []models.ClusterProjects, error) { + return []models.ClusterUsers{ + { + Cluster: models.Cluster{ID: "mock"}, + Users: []models.User{ + { + Name: "foo", + }, + }, + }, + }, []models.ClusterProjects{ + { + Cluster: models.Cluster{ID: "mock"}, + Projects: []models.Project{ + { + Name: "fooprj", + }, + }, + }, + }, nil +} + func mockConfig(tmpDir string, cfg string, serverURL string) string { var configFileTmpl string switch cfg { @@ -60,6 +116,21 @@ clusters: path: %[1]s web: url: %[2]s` + case "mock_instance": + configFileTmpl = ` +--- +clusters: + - id: default + manager: mock + cli: + path: %[1]s + web: + url: %[2]s` + case "empty_instance": + configFileTmpl = ` +--- +# %[1]s %[2]s +clusters: []` case "unknown_manager": configFileTmpl = ` --- @@ -77,10 +148,10 @@ clusters: web: url: %[2]s` case "malformed_1": - // Missing s in tsbd_instances + // Missing s in clusters configFileTmpl = ` --- -resource_manager: +cluster: - id: default` case "malformed_2": // Missing manager name @@ -194,3 +265,69 @@ func TestMixedClusterConfig(t *testing.T) { t.Errorf("config failed preflight checks to %s", err) } } + +func TestNewManager(t *testing.T) { + // Make mock config + base.ConfigFilePath = mockConfig(t.TempDir(), "mock_instance", "") + + // Register mock manager + RegisterManager("mock", NewMockResourceManager) + + // Create new manager + manager, err := NewManager(log.NewNopLogger()) + if err != nil { + t.Errorf("failed to create new manager: %s", err) + } + + // Fetch units + units, err := manager.FetchUnits(time.Now(), time.Now()) + if err != nil { + t.Errorf("failed to fetch units: %s", err) + } + if len(units[0].Units) != 1 { + t.Errorf("expected only 1 unit got %d", len(units[0].Units)) + } + + // Fetch users and projects + users, projects, err := manager.FetchUsersProjects(time.Now()) + if err != nil { + t.Errorf("failed to fetch users and projects: %s", err) + } + // Index 0 seems to be default manager + if len(users[0].Users) != 1 || len(projects[0].Projects) != 1 { + t.Errorf("expected 1 user and 1 project, got %d, %d", len(users[0].Users), len(projects[0].Projects)) + } +} + +func TestNewManagerWithNoClusters(t *testing.T) { + // Make mock config + base.ConfigFilePath = mockConfig(t.TempDir(), "empty_instance", "") + + // Register mock manager + RegisterManager("mock", NewMockResourceManager) + + // Create new manager + manager, err := NewManager(log.NewNopLogger()) + if err != nil { + t.Errorf("failed to create new manager: %s", err) + } + + // Fetch units + units, err := manager.FetchUnits(time.Now(), time.Now()) + if err != nil { + t.Errorf("failed to fetch units: %s", err) + } + if len(units[0].Units) != 0 { + t.Errorf("expected only 0 units got %d", len(units[0].Units)) + } + + // Fetch users and projects + users, projects, err := manager.FetchUsersProjects(time.Now()) + if err != nil { + t.Errorf("failed to fetch users and projects: %s", err) + } + // Index 0 seems to be default manager + if len(users[0].Users) != 0 || len(projects[0].Projects) != 0 { + t.Errorf("expected 0 users and 0 projects, got %d, %d", len(users[0].Users), len(projects[0].Projects)) + } +} diff --git a/pkg/lb/backend/backend_test.go b/pkg/lb/backend/backend_test.go index c835ad3..53a4bc3 100644 --- a/pkg/lb/backend/backend_test.go +++ b/pkg/lb/backend/backend_test.go @@ -13,7 +13,10 @@ import ( "github.com/mahendrapaipuri/ceems/pkg/tsdb" ) -const testURL = "http://localhost:3333" +const ( + testURL = "http://localhost:3333" + testURLBasicAuth = "http://foo:bar@localhost:3333" // #nosec +) func TestTSDBConfigSuccess(t *testing.T) { // Start test server @@ -42,6 +45,9 @@ func TestTSDBConfigSuccess(t *testing.T) { if !b.IsAlive() { t.Errorf("expected backend to be alive") } + if b.ActiveConnections() != 0 { + t.Errorf("expected zero active connections to backend") + } // Stop dummy server and query for retention period, we should get last updated value server.Close() @@ -112,3 +118,13 @@ func TestTSDBBackendAlive(t *testing.T) { t.Errorf("expected backend to be alive") } } + +func TestTSDBBackendAliveWithBasicAuth(t *testing.T) { + url, _ := url.Parse(testURLBasicAuth) + b := NewTSDBServer(url, httputil.NewSingleHostReverseProxy(url), log.NewNopLogger()) + b.SetAlive(b.IsAlive()) + + if !b.IsAlive() { + t.Errorf("expected backend to be alive") + } +} diff --git a/pkg/lb/frontend/frontend.go b/pkg/lb/frontend/frontend.go index e96e77c..655d01d 100644 --- a/pkg/lb/frontend/frontend.go +++ b/pkg/lb/frontend/frontend.go @@ -142,6 +142,7 @@ func (lb *loadBalancer) ValidateClusterIDs() error { } // If neither CEEMD DB or API server is configured, return + // This means LB is used without any access control configured if lb.amw.ceems.db == nil && lb.amw.ceems.clustersEndpoint() == nil { return nil } diff --git a/pkg/lb/frontend/frontend_test.go b/pkg/lb/frontend/frontend_test.go index 6b8a0fc..d96b40a 100644 --- a/pkg/lb/frontend/frontend_test.go +++ b/pkg/lb/frontend/frontend_test.go @@ -2,22 +2,55 @@ package frontend import ( "context" + "database/sql" "encoding/json" + "fmt" "net/http" "net/http/httptest" "net/http/httputil" "net/url" + "path/filepath" "strings" "testing" "time" "github.com/go-kit/log" + ceems_api_http "github.com/mahendrapaipuri/ceems/pkg/api/http" + "github.com/mahendrapaipuri/ceems/pkg/api/models" "github.com/mahendrapaipuri/ceems/pkg/lb/backend" "github.com/mahendrapaipuri/ceems/pkg/lb/serverpool" "github.com/mahendrapaipuri/ceems/pkg/tsdb" ) -func dummyTSDBServer(retention string, rmID string) *httptest.Server { +func setupClusterIDsDB(d string) (*sql.DB, string) { + dbPath := filepath.Join(d, "ceems.db") + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + fmt.Printf("failed to create DB") + } + + stmts := ` +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE units ( + "id" integer not null primary key, + "cluster_id" text, + "resource_manager" text +); +INSERT INTO units VALUES(1, 'slurm-0', 'slurm'); +INSERT INTO units VALUES(2, 'os-0', 'openstack'); +INSERT INTO units VALUES(3, 'os-1', 'openstack'); +INSERT INTO units VALUES(4, 'slurm-1', 'slurm'); +COMMIT;` + + _, err = db.Exec(stmts) + if err != nil { + fmt.Printf("failed to insert mock data into DB: %s", err) + } + return db, dbPath +} + +func dummyTSDBServer(retention string, clusterID string) *httptest.Server { // Start test server expected := tsdb.Response{ Status: "success", @@ -31,17 +64,17 @@ func dummyTSDBServer(retention string, rmID string) *httptest.Server { w.Write([]byte("KO")) } } else { - w.Write([]byte(rmID)) + w.Write([]byte(clusterID)) } })) return server } func TestNewFrontendSingleGroup(t *testing.T) { - rmID := "default" + clusterID := "default" // Backends - dummyServer1 := dummyTSDBServer("30d", rmID) + dummyServer1 := dummyTSDBServer("30d", clusterID) defer dummyServer1.Close() backend1URL, err := url.Parse(dummyServer1.URL) if err != nil { @@ -56,7 +89,7 @@ func TestNewFrontendSingleGroup(t *testing.T) { if err != nil { t.Fatal(err) } - manager.Add(rmID, backend1) + manager.Add(clusterID, backend1) // make minimal config config := &Config{ @@ -105,7 +138,7 @@ func TestNewFrontendSingleGroup(t *testing.T) { newReq = request.WithContext( context.WithValue( request.Context(), QueryParamsContextKey{}, - &QueryParams{queryPeriod: period, id: rmID}, + &QueryParams{queryPeriod: period, id: clusterID}, ), ) } else { @@ -120,7 +153,7 @@ func TestNewFrontendSingleGroup(t *testing.T) { t.Errorf("%s: expected status %d, got %d", test.name, test.code, responseRecorder.Code) } if test.response { - if strings.TrimSpace(responseRecorder.Body.String()) != rmID { + if strings.TrimSpace(responseRecorder.Body.String()) != clusterID { t.Errorf("%s: expected dummy-response, got %s", test.name, responseRecorder.Body) } } @@ -187,39 +220,44 @@ func TestNewFrontendTwoGroups(t *testing.T) { t.Errorf("failed to create load balancer: %s", err) } + // Validate cluster IDs + if err := lb.ValidateClusterIDs(); err != nil { + t.Errorf("expected validation to pass, got error: %s", err) + } + tests := []struct { - name string - start int64 - rmID string - code int - response bool + name string + start int64 + clusterID string + code int + response bool }{ { - name: "query for rm-0 with params in ctx", - start: time.Now().UTC().Unix(), - rmID: "rm-0", - code: 200, - response: true, + name: "query for rm-0 with params in ctx", + start: time.Now().UTC().Unix(), + clusterID: "rm-0", + code: 200, + response: true, }, { - name: "query for rm-1 with params in ctx", - start: time.Now().UTC().Unix(), - rmID: "rm-1", - code: 200, - response: true, + name: "query for rm-1 with params in ctx", + start: time.Now().UTC().Unix(), + clusterID: "rm-1", + code: 200, + response: true, }, { - name: "query with no rmID params in ctx", + name: "query with no clusterID params in ctx", start: time.Now().UTC().Unix(), code: 503, response: false, }, { - name: "query with params in ctx and start more than retention period", - start: time.Now().UTC().Add(-time.Duration(31 * 24 * time.Hour)).Unix(), - rmID: "rm-0", - code: 503, - response: false, + name: "query with params in ctx and start more than retention period", + start: time.Now().UTC().Add(-time.Duration(31 * 24 * time.Hour)).Unix(), + clusterID: "rm-0", + code: 503, + response: false, }, } @@ -233,7 +271,7 @@ func TestNewFrontendTwoGroups(t *testing.T) { newReq = request.WithContext( context.WithValue( request.Context(), QueryParamsContextKey{}, - &QueryParams{queryPeriod: period, id: test.rmID}, + &QueryParams{queryPeriod: period, id: test.clusterID}, ), ) } else { @@ -248,7 +286,7 @@ func TestNewFrontendTwoGroups(t *testing.T) { t.Errorf("%s: expected status %d, got %d", test.name, test.code, responseRecorder.Code) } if test.response { - if strings.TrimSpace(responseRecorder.Body.String()) != test.rmID { + if strings.TrimSpace(responseRecorder.Body.String()) != test.clusterID { t.Errorf("%s: expected dummy-response, got %s", test.name, responseRecorder.Body) } } @@ -271,3 +309,195 @@ func TestNewFrontendTwoGroups(t *testing.T) { t.Errorf("expected status 503, got %d", responseRecorder.Code) } } + +func TestValidateClusterIDsWithDBPass(t *testing.T) { + tmpDir := t.TempDir() + setupClusterIDsDB(tmpDir) + + // Backends for group 1 + dummyServer := dummyTSDBServer("30d", "slurm-0") + defer dummyServer.Close() + backendURL, err := url.Parse(dummyServer.URL) + if err != nil { + t.Fatal(err) + } + + rp := httputil.NewSingleHostReverseProxy(backendURL) + backend := backend.NewTSDBServer(backendURL, rp, log.NewNopLogger()) + + // Start manager + manager, err := serverpool.NewManager("resource-based", log.NewNopLogger()) + if err != nil { + t.Fatal(err) + } + manager.Add("slurm-0", backend) + manager.Add("os-1", backend) + + // make minimal config + config := &Config{ + Logger: log.NewNopLogger(), + Manager: manager, + } + config.APIServer.Data.Path = tmpDir + + // New load balancer + lb, err := NewLoadBalancer(config) + if err != nil { + t.Errorf("failed to create load balancer: %s", err) + } + + // Validate cluster IDs + if err := lb.ValidateClusterIDs(); err != nil { + t.Errorf("expected validation to pass, got error %s", err) + } +} + +func TestValidateClusterIDsWithDBFail(t *testing.T) { + tmpDir := t.TempDir() + setupClusterIDsDB(tmpDir) + + // Backends for group 1 + dummyServer := dummyTSDBServer("30d", "slurm-0") + defer dummyServer.Close() + backendURL, err := url.Parse(dummyServer.URL) + if err != nil { + t.Fatal(err) + } + + rp := httputil.NewSingleHostReverseProxy(backendURL) + backend := backend.NewTSDBServer(backendURL, rp, log.NewNopLogger()) + + // Start manager + manager, err := serverpool.NewManager("resource-based", log.NewNopLogger()) + if err != nil { + t.Fatal(err) + } + manager.Add("unknown", backend) + manager.Add("os-1", backend) + + // make minimal config + config := &Config{ + Logger: log.NewNopLogger(), + Manager: manager, + } + config.APIServer.Data.Path = tmpDir + + // New load balancer + lb, err := NewLoadBalancer(config) + if err != nil { + t.Errorf("failed to create load balancer: %s", err) + } + + // Validate cluster IDs + if err := lb.ValidateClusterIDs(); err == nil { + t.Errorf("expected validation error, got none") + } +} + +func TestValidateClusterIDsWithAPIPass(t *testing.T) { + // Test CEEMS API server + expected := ceems_api_http.Response[models.Cluster]{ + Status: "success", + Data: []models.Cluster{ + { + ID: "slurm-0", + }, + { + ID: "os-1", + }, + }, + } + ceemsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewEncoder(w).Encode(&expected); err != nil { + w.Write([]byte("KO")) + } + })) + defer ceemsServer.Close() + + // Backends for group 1 + dummyServer := dummyTSDBServer("30d", "slurm-0") + defer dummyServer.Close() + backendURL, err := url.Parse(dummyServer.URL) + if err != nil { + t.Fatal(err) + } + + rp := httputil.NewSingleHostReverseProxy(backendURL) + backend := backend.NewTSDBServer(backendURL, rp, log.NewNopLogger()) + + // Start manager + manager, err := serverpool.NewManager("resource-based", log.NewNopLogger()) + if err != nil { + t.Fatal(err) + } + manager.Add("slurm-0", backend) + manager.Add("os-1", backend) + + // make minimal config + config := &Config{ + Logger: log.NewNopLogger(), + Manager: manager, + } + config.APIServer.Web.URL = ceemsServer.URL + + // New load balancer + lb, err := NewLoadBalancer(config) + if err != nil { + t.Errorf("failed to create load balancer: %s", err) + } + + // Validate cluster IDs + if err := lb.ValidateClusterIDs(); err != nil { + t.Errorf("expected validation to pass, got error %s", err) + } +} + +func TestValidateClusterIDsWithAPIFail(t *testing.T) { + // Test CEEMS API server + expected := ceems_api_http.Response[models.Cluster]{ + Status: "error", + } + ceemsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewEncoder(w).Encode(&expected); err != nil { + w.Write([]byte("KO")) + } + })) + defer ceemsServer.Close() + + // Backends for group 1 + dummyServer := dummyTSDBServer("30d", "slurm-0") + defer dummyServer.Close() + backendURL, err := url.Parse(dummyServer.URL) + if err != nil { + t.Fatal(err) + } + + rp := httputil.NewSingleHostReverseProxy(backendURL) + backend := backend.NewTSDBServer(backendURL, rp, log.NewNopLogger()) + + // Start manager + manager, err := serverpool.NewManager("resource-based", log.NewNopLogger()) + if err != nil { + t.Fatal(err) + } + manager.Add("slurm-0", backend) + manager.Add("os-1", backend) + + // make minimal config + config := &Config{ + Logger: log.NewNopLogger(), + Manager: manager, + } + config.APIServer.Web.URL = ceemsServer.URL + + // New load balancer + lb, err := NewLoadBalancer(config) + if err != nil { + t.Errorf("failed to create load balancer: %s", err) + } + + // Validate cluster IDs + if err := lb.ValidateClusterIDs(); err == nil { + t.Errorf("expected validation error, got none") + } +}