diff --git a/persistence/bitmagnet.go b/persistence/bitmagnet.go index e6a9b672..2d002e15 100644 --- a/persistence/bitmagnet.go +++ b/persistence/bitmagnet.go @@ -125,6 +125,10 @@ func (b *bitmagnet) GetNumberOfTorrents() (uint, error) { return 0, nil } +func (b *bitmagnet) GetNumberOfQueryTorrents(query string, epoch int64) (uint64, error) { + return 0, nil +} + func (b *bitmagnet) QueryTorrents(query string, epoch int64, orderBy OrderingCriteria, ascending bool, limit uint64, lastOrderedValue *float64, lastID *uint64) ([]TorrentMetadata, error) { return nil, errors.New("query not supported") } diff --git a/persistence/interface.go b/persistence/interface.go index 64126c03..81684fa7 100644 --- a/persistence/interface.go +++ b/persistence/interface.go @@ -17,6 +17,8 @@ type Database interface { // GetNumberOfTorrents returns the number of torrents saved in the database. Might be an // approximation. GetNumberOfTorrents() (uint, error) + // GetNumberOfQueryTorrents returns the total number of data records in a fuzzy query. + GetNumberOfQueryTorrents(query string, epoch int64) (uint64, error) // QueryTorrents returns @pageSize amount of torrents, // * that are discovered before @discoveredOnBefore // * that match the @query if it's not empty, else all torrents diff --git a/persistence/postgres.go b/persistence/postgres.go index ef5e0fd9..1a0b3288 100644 --- a/persistence/postgres.go +++ b/persistence/postgres.go @@ -167,6 +167,38 @@ func (db *postgresDatabase) GetNumberOfTorrents() (uint, error) { } } +func (db *postgresDatabase) GetNumberOfQueryTorrents(query string, epoch int64) (uint64, error) { + + var querySkeleton = `SELECT COUNT(*) + FROM torrents + WHERE + name ILIKE CONCAT('%',$1::text,'%') AND + discovered_on <= $2; + ` + + rows, err := db.conn.Query(querySkeleton, query, epoch) + if err != nil { + return 0, err + } + defer rows.Close() + + if !rows.Next() { + return 0, errors.New("no rows returned from `SELECT COUNT(*) FROM torrents WHERE name ILIKE CONCAT('%%',$1::text,'%%') AND discovered_on <= $2;`") + } + + var n *int64 + if err = rows.Scan(&n); err != nil { + return 0, err + } + + // If the database is empty (i.e. 0 entries in 'torrents') then the query will return nil. + if n == nil || *n < 0 { + return 0, nil + } else { + return uint64(*n), nil + } +} + func (db *postgresDatabase) QueryTorrents( query string, epoch int64, diff --git a/persistence/postgres_test.go b/persistence/postgres_test.go index 89f0be0e..e660db93 100644 --- a/persistence/postgres_test.go +++ b/persistence/postgres_test.go @@ -155,6 +155,79 @@ func TestPostgresDatabase_GetNumberOfTorrents(t *testing.T) { } } +func TestPostgresDatabase_GetNumberOfQueryTorrents(t *testing.T) { + t.Parallel() + + conn, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("An error '%s' was not expected when opening a stub database connection", err) + } + defer conn.Close() + + pgDb := &postgresDatabase{conn: conn} + + query := "test-query" + epoch := int64(1609459200) // 2021-01-01 00:00:00 UTC + + rows := sqlmock.NewRows([]string{"count"}).AddRow(int64(10)) + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM torrents WHERE name ILIKE CONCAT\('%',\$1::text,'%'\) AND discovered_on <= \$2;`). + WithArgs(query, epoch). + WillReturnRows(rows) + + result, err := pgDb.GetNumberOfQueryTorrents(query, epoch) + + if err != nil { + t.Errorf("Expected no error, but got %v", err) + } + + if result != uint64(10) { + t.Errorf("Expected result to be 10, but got %d", result) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("There were unmet expectations: %s", err) + } + + rows = sqlmock.NewRows([]string{"count"}) + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM torrents WHERE name ILIKE CONCAT\('%',\$1::text,'%'\) AND discovered_on <= \$2;`). + WithArgs(query, epoch). + WillReturnRows(rows) + + result, err = pgDb.GetNumberOfQueryTorrents(query, epoch) + + if err == nil { + t.Error("Expected an error, but got none") + } + + if result != uint64(0) { + t.Errorf("Expected result to be 0, but got %d", result) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("There were unmet expectations: %s", err) + } + + rows = sqlmock.NewRows([]string{"count"}).AddRow(nil) + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM torrents WHERE name ILIKE CONCAT\('%',\$1::text,'%'\) AND discovered_on <= \$2;`). + WithArgs(query, epoch). + WillReturnRows(rows) + + result, err = pgDb.GetNumberOfQueryTorrents(query, epoch) + + if err != nil { + t.Errorf("Expected no error, but got %v", err) + } + + if result != uint64(0) { + t.Errorf("Expected result to be 0, but got %d", result) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("There were unmet expectations: %s", err) + } + +} + func TestPostgresDatabase_Close(t *testing.T) { t.Parallel() diff --git a/persistence/rabbitmq.go b/persistence/rabbitmq.go index 8402a7d8..60b6507c 100644 --- a/persistence/rabbitmq.go +++ b/persistence/rabbitmq.go @@ -136,6 +136,10 @@ func (r *rabbitMQ) GetNumberOfTorrents() (uint, error) { return 0, nil } +func (r *rabbitMQ) GetNumberOfQueryTorrents(query string, epoch int64) (uint64, error) { + return 0, nil +} + func (r *rabbitMQ) QueryTorrents(query string, epoch int64, orderBy OrderingCriteria, ascending bool, limit uint64, lastOrderedValue *float64, lastID *uint64) ([]TorrentMetadata, error) { return nil, errors.New("query not supported") } diff --git a/persistence/sqlite3.go b/persistence/sqlite3.go index 719b9116..166297a0 100644 --- a/persistence/sqlite3.go +++ b/persistence/sqlite3.go @@ -228,6 +228,36 @@ func (db *sqlite3Database) GetNumberOfTorrents() (uint, error) { } } +func (db *sqlite3Database) GetNumberOfQueryTorrents(query string, epoch int64) (uint64, error) { + var querySkeleton = `SELECT COUNT(*) + FROM torrents + WHERE + LOWER(name) LIKE '%' || LOWER($1) || '%' AND + discovered_on <= $2; + ` + rows, err := db.conn.Query(querySkeleton, query, epoch) + if err != nil { + return 0, err + } + defer rows.Close() + + if !rows.Next() { + return 0, fmt.Errorf("no rows returned from `SELECT COUNT(*) FROM torrents WHERE LOWER(name) LIKE '%%' || LOWER($1) || '%%' AND discovered_on <= $2;`") + } + + var n *uint + if err = rows.Scan(&n); err != nil { + return 0, err + } + + // If the database is empty (i.e. 0 entries in 'torrents') then the query will return nil. + if n == nil { + return 0, nil + } else { + return uint64(*n), nil + } +} + func (db *sqlite3Database) QueryTorrents( query string, epoch int64, diff --git a/persistence/sqlite3_test.go b/persistence/sqlite3_test.go index 4a51c90a..06100255 100644 --- a/persistence/sqlite3_test.go +++ b/persistence/sqlite3_test.go @@ -73,6 +73,69 @@ func Test_sqlite3Database_GetNumberOfTorrents(t *testing.T) { } } +func TestSqlite3Database_GetNumberOfQueryTorrents(t *testing.T) { + t.Parallel() + db := newDb(t) + + // The database is empty, so the number of torrents for any query should be 0. + tests := []struct { + name string + query string + epoch int64 + want uint64 + wantErr bool + }{ + { + name: "Test Empty Query", + query: "", + epoch: 0, + want: 0, + wantErr: false, + }, + { + name: "Test Simple Query", + query: "test", + epoch: 0, + want: 0, + wantErr: false, + }, + { + name: "Test Query with Special Characters", + query: "test!@#$%^&*()", + epoch: 0, + want: 0, + wantErr: false, + }, + { + name: "Test Query with Future Epoch", + query: "test", + epoch: 32503680000, // January 1, 3000 + want: 0, + wantErr: false, + }, + { + name: "Test Query with Past Epoch", + query: "test", + epoch: 1000000000, // September 9, 2001 + want: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := db.GetNumberOfQueryTorrents(tt.query, tt.epoch) + if (err != nil) != tt.wantErr { + t.Errorf("sqlite3Database.GetNumberOfQueryTorrents() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("sqlite3Database.GetNumberOfQueryTorrents() = %v, want %v", got, tt.want) + } + }) + } +} + func Test_sqlite3Database_AddNewTorrent(t *testing.T) { t.Parallel() db := newDb(t) diff --git a/persistence/zeromq.go b/persistence/zeromq.go index 10d6606a..2c8fafda 100644 --- a/persistence/zeromq.go +++ b/persistence/zeromq.go @@ -90,6 +90,10 @@ func (instance *zeromq) GetNumberOfTorrents() (uint, error) { return 0, nil } +func (instance *zeromq) GetNumberOfQueryTorrents(query string, epoch int64) (uint64, error) { + return 0, nil +} + func (instance *zeromq) QueryTorrents( query string, epoch int64, diff --git a/persistence/zeromq_mock.go b/persistence/zeromq_mock.go index b983f400..baf6f5a0 100644 --- a/persistence/zeromq_mock.go +++ b/persistence/zeromq_mock.go @@ -34,6 +34,10 @@ func (instance *zeromq) GetNumberOfTorrents() (uint, error) { return 0, nil } +func (instance *zeromq) GetNumberOfQueryTorrents(query string, epoch int64) (uint64, error) { + return 0, nil +} + func (instance *zeromq) QueryTorrents( query string, epoch int64, diff --git a/web/router.go b/web/router.go index faddc54a..8bc15b25 100644 --- a/web/router.go +++ b/web/router.go @@ -57,6 +57,7 @@ func makeRouter() *http.ServeMux { router.HandleFunc("/api/v0.1/statistics", BasicAuth(apiStatistics)) router.HandleFunc("/api/v0.1/torrents", BasicAuth(apiTorrents)) + router.HandleFunc("/api/v0.1/torrentstotal", BasicAuth(apiTorrentsTotal)) router.HandleFunc("/api/v0.1/torrents/{infohash}", BasicAuth(infohashMiddleware(apiTorrent))) router.HandleFunc("/api/v0.1/torrents/{infohash}/filelist", BasicAuth(infohashMiddleware(apiFileList))) diff --git a/web/torrents.go b/web/torrents.go index df634f07..f478faf8 100644 --- a/web/torrents.go +++ b/web/torrents.go @@ -198,6 +198,51 @@ func apiTorrents(w http.ResponseWriter, r *http.Request) { } } +func apiTorrentsTotal(w http.ResponseWriter, r *http.Request) { + // @lastOrderedValue AND @lastID are either both supplied or neither of them should be supplied + // at all; and if that is NOT the case, then return an error. + if q := r.URL.Query(); !((q.Get("lastOrderedValue") != "" && q.Get("lastID") != "") || + (q.Get("lastOrderedValue") == "" && q.Get("lastID") == "")) { + http.Error(w, "`lastOrderedValue`, `lastID` must be supplied altogether, if supplied.", http.StatusBadRequest) + return + } + + var tq struct { + Epoch int64 `schema:"epoch"` + Query string `schema:"query"` + } + + err := r.ParseForm() + if err != nil { + http.Error(w, "error while parsing the URL: "+err.Error(), http.StatusBadRequest) + return + } + + if r.Form.Has("epoch") { + tq.Epoch, err = strconv.ParseInt(r.Form.Get("epoch"), 10, 64) + if err != nil { + http.Error(w, "error while parsing the URL: "+err.Error(), http.StatusBadRequest) + return + } + } else { + http.Error(w, "lack required parameters while parsing the URL: `epoch`", http.StatusBadRequest) + return + } + + tq.Query = r.Form.Get("query") + + torrentsTotal, err := database.GetNumberOfQueryTorrents(tq.Query, tq.Epoch) + if err != nil { + http.Error(w, "GetNumberOfQueryTorrents: "+err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set(ContentType, ContentTypeJson) + if err = json.NewEncoder(w).Encode(torrentsTotal); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } +} + func parseOrderBy(s string) (persistence.OrderingCriteria, error) { switch s { case "RELEVANCE": diff --git a/web/torrents_test.go b/web/torrents_test.go index 2f28b3fa..43d3133d 100644 --- a/web/torrents_test.go +++ b/web/torrents_test.go @@ -154,3 +154,77 @@ func TestParseOrderBy(t *testing.T) { }) } } + +func TestApiTorrentsTotal(t *testing.T) { + t.Parallel() + + initDb() + + tests := []struct { + name string + queryParams string + expectedStatus int + expectedError string + }{ + { + name: "missing required epoch parameter", + queryParams: "", + expectedStatus: http.StatusBadRequest, + expectedError: "lack required parameters while parsing the URL: `epoch`", + }, + { + name: "invalid epoch parameter", + queryParams: "epoch=abc", + expectedStatus: http.StatusBadRequest, + expectedError: "error while parsing the URL: strconv.ParseInt: parsing \"abc\": invalid syntax", + }, + { + name: "valid request with epoch", + queryParams: "epoch=1234567890", + expectedStatus: http.StatusOK, + }, + { + name: "invalid request with only lastOrderedValue", + queryParams: "epoch=1234567890&lastOrderedValue=123.45", + expectedStatus: http.StatusBadRequest, + expectedError: "`lastOrderedValue`, `lastID` must be supplied altogether, if supplied.", + }, + { + name: "invalid request with only lastID", + queryParams: "epoch=1234567890&lastID=123", + expectedStatus: http.StatusBadRequest, + expectedError: "`lastOrderedValue`, `lastID` must be supplied altogether, if supplied.", + }, + { + name: "valid request with both lastOrderedValue and lastID", + queryParams: "epoch=1234567890&lastOrderedValue=123.45&lastID=123", + expectedStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := http.NewRequest("GET", "/api/torrents/total?"+tt.queryParams, nil) + if err != nil { + t.Fatalf("could not create request: %v", err) + } + + rec := httptest.NewRecorder() + handler := http.HandlerFunc(apiTorrentsTotal) + handler.ServeHTTP(rec, req) + + res := rec.Result() + defer res.Body.Close() + + if res.StatusCode != tt.expectedStatus { + t.Errorf("expected status %v; got %v", tt.expectedStatus, res.StatusCode) + } + + if tt.expectedError != "" { + if !strings.Contains(rec.Body.String(), tt.expectedError) { + t.Errorf("expected error %q; got %q", tt.expectedError, rec.Body.String()) + } + } + }) + } +}