diff --git a/api_backend.go b/api_backend.go index 2fd0bc96..f54bc4e9 100644 --- a/api_backend.go +++ b/api_backend.go @@ -32,6 +32,7 @@ import ( "net/http" "net/url" "regexp" + "slices" "strings" "time" ) @@ -432,10 +433,12 @@ type TurnCredentials struct { // Information on a backend in the etcd cluster. type BackendInformationEtcd struct { - parsedUrl *url.URL + // Compat setting. + Url string `json:"url,omitempty"` - Url string `json:"url"` - Secret string `json:"secret"` + Urls []string `json:"urls,omitempty"` + parsedUrls []*url.URL + Secret string `json:"secret"` MaxStreamBitrate int `json:"maxstreambitrate,omitempty"` MaxScreenBitrate int `json:"maxscreenbitrate,omitempty"` @@ -444,23 +447,40 @@ type BackendInformationEtcd struct { } func (p *BackendInformationEtcd) CheckValid() error { - if p.Url == "" { - return fmt.Errorf("url missing") - } if p.Secret == "" { return fmt.Errorf("secret missing") } - parsedUrl, err := url.Parse(p.Url) - if err != nil { - return fmt.Errorf("invalid url: %w", err) - } + if len(p.Urls) > 0 { + slices.Sort(p.Urls) + p.Urls = slices.Compact(p.Urls) + for idx, u := range p.Urls { + parsedUrl, err := url.Parse(u) + if err != nil { + return fmt.Errorf("invalid url %s: %w", u, err) + } + if strings.Contains(parsedUrl.Host, ":") && hasStandardPort(parsedUrl) { + parsedUrl.Host = parsedUrl.Hostname() + p.Urls[idx] = parsedUrl.String() + } + + p.parsedUrls = append(p.parsedUrls, parsedUrl) + } + } else if p.Url != "" { + parsedUrl, err := url.Parse(p.Url) + if err != nil { + return fmt.Errorf("invalid url: %w", err) + } + if strings.Contains(parsedUrl.Host, ":") && hasStandardPort(parsedUrl) { + parsedUrl.Host = parsedUrl.Hostname() + p.Url = parsedUrl.String() + } - if strings.Contains(parsedUrl.Host, ":") && hasStandardPort(parsedUrl) { - parsedUrl.Host = parsedUrl.Hostname() - p.Url = parsedUrl.String() + p.Urls = append(p.Urls, p.Url) + p.parsedUrls = append(p.parsedUrls, parsedUrl) + } else { + return fmt.Errorf("urls missing") } - p.parsedUrl = parsedUrl return nil } diff --git a/api_signaling.go b/api_signaling.go index 7242da36..ebd20177 100644 --- a/api_signaling.go +++ b/api_signaling.go @@ -486,6 +486,10 @@ func (m *HelloClientMessage) CheckValid() error { if m.Auth.Url[len(m.Auth.Url)-1] != '/' { m.Auth.Url += "/" } + if pos := strings.Index(m.Auth.Url, "ocs/v2.php/apps/spreed/"); pos != -1 { + m.Auth.Url = m.Auth.Url[:pos] + } + if u, err := url.ParseRequestURI(m.Auth.Url); err != nil { return err } else { diff --git a/backend_configuration.go b/backend_configuration.go index 33d5eb8b..32d10017 100644 --- a/backend_configuration.go +++ b/backend_configuration.go @@ -22,8 +22,10 @@ package signaling import ( + "bytes" "fmt" "net/url" + "slices" "strings" "sync" @@ -68,6 +70,22 @@ func (b *Backend) IsCompat() bool { return len(b.urls) == 0 } +func (b *Backend) Equal(other *Backend) bool { + if b == other { + return true + } else if b == nil || other == nil { + return false + } + + return b.id == other.id && + b.allowHttp == other.allowHttp && + b.maxStreamBitrate == other.maxStreamBitrate && + b.maxScreenBitrate == other.maxScreenBitrate && + b.sessionLimit == other.sessionLimit && + bytes.Equal(b.secret, other.secret) && + slices.Equal(b.urls, other.urls) +} + func (b *Backend) IsUrlAllowed(u *url.URL) bool { switch u.Scheme { case "https": diff --git a/backend_configuration_test.go b/backend_configuration_test.go index 0e612a20..ec619182 100644 --- a/backend_configuration_test.go +++ b/backend_configuration_test.go @@ -461,7 +461,7 @@ func mustParse(s string) *url.URL { return p } -func TestBackendConfiguration_Etcd(t *testing.T) { +func TestBackendConfiguration_EtcdCompat(t *testing.T) { t.Parallel() CatchLogForTest(t) require := require.New(t) diff --git a/backend_storage_etcd.go b/backend_storage_etcd.go index 64a35718..f63d5e21 100644 --- a/backend_storage_etcd.go +++ b/backend_storage_etcd.go @@ -178,53 +178,62 @@ func (s *backendStorageEtcd) EtcdKeyUpdated(client *EtcdClient, key string, data return } + allowHttp := false + for _, u := range info.parsedUrls { + if u.Scheme == "http" { + allowHttp = true + break + } + } + backend := &Backend{ id: key, - urls: []string{info.Url}, + urls: info.Urls, secret: []byte(info.Secret), - allowHttp: info.parsedUrl.Scheme == "http", + allowHttp: allowHttp, maxStreamBitrate: info.MaxStreamBitrate, maxScreenBitrate: info.MaxScreenBitrate, sessionLimit: info.SessionLimit, } - host := info.parsedUrl.Host - s.mu.Lock() defer s.mu.Unlock() s.keyInfos[key] = &info - entries, found := s.backends[host] - if !found { - // Simple case, first backend for this host - log.Printf("Added backend %s (from %s)", info.Url, key) - s.backends[host] = []*Backend{backend} - updateBackendStats(backend) - statsBackendsCurrent.Inc() - s.wakeupForTesting() - return - } - - // Was the backend changed? - replaced := false - for idx, entry := range entries { - if entry.id == key { - log.Printf("Updated backend %s (from %s)", info.Url, key) + for idx, u := range info.parsedUrls { + host := u.Host + entries, found := s.backends[host] + if !found { + // Simple case, first backend for this host + log.Printf("Added backend %s (from %s)", info.Urls[idx], key) + s.backends[host] = []*Backend{backend} updateBackendStats(backend) - entries[idx] = backend - replaced = true - break + statsBackendsCurrent.Inc() + s.wakeupForTesting() + continue } - } - if !replaced { - // New backend, add to list. - log.Printf("Added backend %s (from %s)", info.Url, key) - s.backends[host] = append(entries, backend) - updateBackendStats(backend) - statsBackendsCurrent.Inc() + // Was the backend changed? + replaced := false + for idx, entry := range entries { + if entry.id == key { + log.Printf("Updated backend %s (from %s)", info.Urls[idx], key) + updateBackendStats(backend) + entries[idx] = backend + replaced = true + break + } + } + + if !replaced { + // New backend, add to list. + log.Printf("Added backend %s (from %s)", info.Urls[idx], key) + s.backends[host] = append(entries, backend) + updateBackendStats(backend) + statsBackendsCurrent.Inc() + } } s.wakeupForTesting() } @@ -239,27 +248,29 @@ func (s *backendStorageEtcd) EtcdKeyDeleted(client *EtcdClient, key string, prev } delete(s.keyInfos, key) - host := info.parsedUrl.Host - entries, found := s.backends[host] - if !found { - return - } - - log.Printf("Removing backend %s (from %s)", info.Url, key) - newEntries := make([]*Backend, 0, len(entries)-1) - for _, entry := range entries { - if entry.id == key { - updateBackendStats(entry) - statsBackendsCurrent.Dec() + for idx, u := range info.parsedUrls { + host := u.Host + entries, found := s.backends[host] + if !found { continue } - newEntries = append(newEntries, entry) - } - if len(newEntries) > 0 { - s.backends[host] = newEntries - } else { - delete(s.backends, host) + log.Printf("Removing backend %s (from %s)", info.Urls[idx], key) + newEntries := make([]*Backend, 0, len(entries)-1) + for _, entry := range entries { + if entry.id == key { + updateBackendStats(entry) + statsBackendsCurrent.Dec() + continue + } + + newEntries = append(newEntries, entry) + } + if len(newEntries) > 0 { + s.backends[host] = newEntries + } else { + delete(s.backends, host) + } } s.wakeupForTesting() } diff --git a/backend_storage_static.go b/backend_storage_static.go index bc85524d..0e8199ce 100644 --- a/backend_storage_static.go +++ b/backend_storage_static.go @@ -24,7 +24,7 @@ package signaling import ( "log" "net/url" - "reflect" + "slices" "strings" "github.com/dlintw/goconf" @@ -151,7 +151,7 @@ func (s *backendStorageStatic) UpsertHost(host string, backends []*Backend) { found := false index := 0 for _, newBackend := range backends { - if reflect.DeepEqual(existingBackend, newBackend) { // otherwise we could manually compare the struct members here + if existingBackend.Equal(newBackend) { found = true backends = append(backends[:index], backends[index+1:]...) break @@ -201,35 +201,24 @@ func getConfiguredBackendIDs(backendIds string) (ids []string) { return ids } +func Map[T any](s []T, f func(T) T) []T { + var result []T + for _, v := range s { + result = append(result, f(v)) + } + return result +} + func getConfiguredHosts(backendIds string, config *goconf.ConfigFile, commonSecret string) (hosts map[string][]*Backend) { hosts = make(map[string][]*Backend) + seenUrls := make(map[string]string) for _, id := range getConfiguredBackendIDs(backendIds) { - u, _ := config.GetString(id, "url") - if u == "" { - log.Printf("Backend %s is missing or incomplete, skipping", id) - continue - } - - if u[len(u)-1] != '/' { - u += "/" - } - parsed, err := url.Parse(u) - if err != nil { - log.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err) - continue - } - - if strings.Contains(parsed.Host, ":") && hasStandardPort(parsed) { - parsed.Host = parsed.Hostname() - u = parsed.String() - } - secret, _ := config.GetString(id, "secret") if secret == "" && commonSecret != "" { log.Printf("Backend %s has no own shared secret set, using common shared secret", id) secret = commonSecret } - if u == "" || secret == "" { + if secret == "" { log.Printf("Backend %s is missing or incomplete, skipping", id) continue } @@ -251,18 +240,71 @@ func getConfiguredHosts(backendIds string, config *goconf.ConfigFile, commonSecr maxScreenBitrate = 0 } - hosts[parsed.Host] = append(hosts[parsed.Host], &Backend{ + var urls []string + if u, _ := config.GetString(id, "urls"); u != "" { + urls = strings.Split(u, ",") + urls = Map(urls, func(s string) string { + return strings.TrimSpace(s) + }) + urls = slices.DeleteFunc(urls, func(s string) bool { + return s == "" + }) + slices.Sort(urls) + urls = slices.Compact(urls) + } else if u, _ := config.GetString(id, "url"); u != "" { + if u = strings.TrimSpace(u); u != "" { + urls = []string{u} + } + } + + if len(urls) == 0 { + log.Printf("Backend %s is missing or incomplete, skipping", id) + continue + } + + backend := &Backend{ id: id, - urls: []string{u}, secret: []byte(secret), - allowHttp: parsed.Scheme == "http", - maxStreamBitrate: maxStreamBitrate, maxScreenBitrate: maxScreenBitrate, sessionLimit: uint64(sessionLimit), - }) + } + + added := make(map[string]bool) + for _, u := range urls { + if u[len(u)-1] != '/' { + u += "/" + } + + parsed, err := url.Parse(u) + if err != nil { + log.Printf("Backend %s has an invalid url %s configured (%s), skipping", id, u, err) + continue + } + + if strings.Contains(parsed.Host, ":") && hasStandardPort(parsed) { + parsed.Host = parsed.Hostname() + u = parsed.String() + } + + if prev, found := seenUrls[u]; found { + log.Printf("Url %s in backend %s was already used in backend %s, skipping", u, id, prev) + continue + } + + seenUrls[u] = id + backend.urls = append(backend.urls, u) + if parsed.Scheme == "http" { + backend.allowHttp = true + } + + if !added[parsed.Host] { + hosts[parsed.Host] = append(hosts[parsed.Host], backend) + added[parsed.Host] = true + } + } } return hosts