diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index e8304a45586..d49fe3fc4f2 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,5 +1,7 @@ +- [ ] Public API changes documented in changelogs (optional) + Signed-off-by: diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 4ced4f78125..30255995c56 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -15,12 +15,9 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@nightly with: - toolchain: nightly components: rustfmt - profile: minimal - override: true - name: Run Benchmarks run: cargo bench | tee benchmark-output.txt diff --git a/.github/workflows/bindings_ci.yml b/.github/workflows/bindings_ci.yml index b53a81a8cff..1c6e22fcefb 100644 --- a/.github/workflows/bindings_ci.yml +++ b/.github/workflows/bindings_ci.yml @@ -12,54 +12,28 @@ on: - synchronize - ready_for_review +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always MATRIX_SDK_CRYPTO_NODEJS_PATH: bindings/matrix-sdk-crypto-nodejs MATRIX_SDK_CRYPTO_JS_PATH: bindings/matrix-sdk-crypto-js jobs: - xtask-linux: - runs-on: ubuntu-latest - steps: - - name: Checkout repo - uses: actions/checkout@v2 - - - name: Install Protoc - uses: arduino/setup-protoc@v1 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - - name: Check xtask cache - uses: actions/cache@v3 - id: xtask-cache - with: - path: target/debug/xtask - key: xtask-linux-${{ hashFiles('Cargo.toml', 'xtask/**') }} - - - name: Install rust stable toolchain - if: steps.xtask-cache.outputs.cache-hit != 'true' - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - override: true - - - name: Build - if: steps.xtask-cache.outputs.cache-hit != 'true' - uses: actions-rs/cargo@v1 - with: - command: build - args: -p 
xtask + xtask: + uses: ./.github/workflows/xtask.yml test-uniffi-codegen: name: Test UniFFI bindings generation - needs: xtask-linux + needs: xtask if: github.event_name == 'push' || !github.event.pull_request.draft runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: Install Protoc uses: arduino/setup-protoc@v1 @@ -67,24 +41,49 @@ jobs: repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true + uses: dtolnay/rust-toolchain@stable - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Get xtask - uses: actions/cache@v3 + uses: actions/cache/restore@v3 with: path: target/debug/xtask - key: xtask-linux-${{ hashFiles('Cargo.toml', 'xtask/**') }} + key: "${{ needs.xtask.outputs.cachekey-linux }}" + fail-on-cache-miss: true - name: Build library & generate bindings run: target/debug/xtask ci bindings + lint-js-bindings: + strategy: + fail-fast: true + matrix: + include: + - name: "[m]-crypto-nodejs" + path: "bindings/matrix-sdk-crypto-nodejs" + - name: "[m]-crypto-js" + path: "bindings/matrix-sdk-crypto-js" + + name: lint ${{ matrix.name }} + runs-on: ubuntu-latest + + steps: + - name: Checkout the repo + uses: actions/checkout@v3 + + - name: Install Node.js + uses: actions/setup-node@v3 + + - name: Install NPM dependencies + working-directory: ${{ matrix.path }} + run: npm install + + - name: run lint + working-directory: ${{ matrix.path }} + run: npm run lint + test-matrix-sdk-crypto-nodejs: name: ${{ matrix.os-name }} [m]-crypto-nodejs, v${{ matrix.node-version }} if: github.event_name == 'push' || !github.event.pull_request.draft @@ -111,14 +110,10 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true + uses: dtolnay/rust-toolchain@stable - name: Load cache - uses: 
Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install Node.js uses: actions/setup-node@v3 @@ -169,15 +164,12 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - toolchain: stable - target: wasm32-unknown-unknown - profile: minimal - override: true + targets: wasm32-unknown-unknown - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install Node.js uses: actions/setup-node@v3 @@ -200,72 +192,37 @@ jobs: working-directory: ${{ env.MATRIX_SDK_CRYPTO_JS_PATH }} run: npm run doc - xtask-macos: - runs-on: macos-12 - steps: - - name: Checkout repo - uses: actions/checkout@v2 - - - name: Install Protoc - uses: arduino/setup-protoc@v1 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - - name: Check xtask cache - uses: actions/cache@v3 - id: xtask-cache - with: - path: target/debug/xtask - key: xtask-macos-${{ hashFiles('Cargo.toml', 'xtask/**') }} - - - name: Install rust stable toolchain - if: steps.xtask-cache.outputs.cache-hit != 'true' - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - override: true - - - name: Build - if: steps.xtask-cache.outputs.cache-hit != 'true' - uses: actions-rs/cargo@v1 - with: - command: build - args: -p xtask - test-apple: name: matrix-rust-components-swift - needs: xtask-macos + needs: xtask runs-on: macos-12 if: github.event_name == 'push' || !github.event.pull_request.draft steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v3 + # install protoc in case we end up rebuilding opentelemetry-proto - name: Install Protoc uses: arduino/setup-protoc@v1 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly - profile: minimal - override: true + uses: dtolnay/rust-toolchain@nightly - name: Install aarch64-apple-ios target run: rustup target install aarch64-apple-ios - name: Load cache 
- uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Get xtask - uses: actions/cache@v3 + uses: actions/cache/restore@v3 with: path: target/debug/xtask - key: xtask-macos-${{ hashFiles('Cargo.toml', 'xtask/**') }} + key: "${{ needs.xtask.outputs.cachekey-macos }}" + fail-on-cache-miss: true - name: Build library & bindings run: target/debug/xtask swift build-library @@ -275,4 +232,4 @@ jobs: run: swift test - name: Build Framework - run: cargo xtask swift build-framework --only-target=aarch64-apple-ios + run: target/debug/xtask swift build-framework --only-target=aarch64-apple-ios diff --git a/.github/workflows/cancel_others.yml b/.github/workflows/cancel_others.yml deleted file mode 100644 index 0f1227f0622..00000000000 --- a/.github/workflows/cancel_others.yml +++ /dev/null @@ -1,13 +0,0 @@ -on: - pull_request: - branches: [main] - -jobs: - cancel-others: - runs-on: ubuntu-latest - steps: - - name: Cancel workflows for older commits - uses: styfle/cancel-workflow-action@0.11.0 - with: - workflow_id: all - all_but_latest: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7795a8715eb..01426cc4231 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,42 +12,16 @@ on: - synchronize - ready_for_review +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always jobs: xtask: - runs-on: ubuntu-latest - steps: - - name: Checkout repo - uses: actions/checkout@v2 - - - name: Install Protoc - uses: arduino/setup-protoc@v1 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - - - name: Check xtask cache - uses: actions/cache@v3 - id: xtask-cache - with: - path: target/debug/xtask - key: xtask-${{ hashFiles('xtask/**') }} - - - name: Install rust stable toolchain - if: steps.xtask-cache.outputs.cache-hit != 'true' - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - override: true - - - name: Build - if: 
steps.xtask-cache.outputs.cache-hit != 'true' - uses: actions-rs/cargo@v1 - with: - command: build - args: -p xtask + uses: ./.github/workflows/xtask.yml test-matrix-sdk-features: name: 🐧 [m], ${{ matrix.name }} @@ -70,32 +44,35 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true + uses: dtolnay/rust-toolchain@stable - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 + with: + # use a separate cache for each job to work around + # https://github.com/Swatinem/rust-cache/issues/124 + key: "${{ matrix.name }}" + + # ... but only save the cache on the main branch + # cf https://github.com/Swatinem/rust-cache/issues/95 + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Install nextest uses: taiki-e/install-action@nextest - name: Get xtask - uses: actions/cache@v3 + uses: actions/cache/restore@v3 with: path: target/debug/xtask - key: xtask-${{ hashFiles('xtask/**') }} + key: "${{ needs.xtask.outputs.cachekey-linux }}" + fail-on-cache-miss: true - name: Test - uses: actions-rs/cargo@v1 - with: - command: run - args: -p xtask -- ci test-features ${{ matrix.name }} + run: | + target/debug/xtask ci test-features ${{ matrix.name }} test-matrix-sdk-examples: name: 🐧 [m]-examples @@ -108,29 +85,24 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true + uses: dtolnay/rust-toolchain@stable - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install nextest uses: taiki-e/install-action@nextest - name: Get xtask - uses: actions/cache@v3 + uses: actions/cache/restore@v3 with: path: target/debug/xtask - key: xtask-${{ hashFiles('xtask/**') }} + key: "${{ needs.xtask.outputs.cachekey-linux }}" + fail-on-cache-miss: true - name: Test - uses: actions-rs/cargo@v1 - 
with: - command: run - args: -p xtask -- ci examples + run: | + target/debug/xtask ci examples test-matrix-sdk-crypto: name: 🐧 [m]-crypto @@ -143,29 +115,24 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true + uses: dtolnay/rust-toolchain@stable - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install nextest uses: taiki-e/install-action@nextest - name: Get xtask - uses: actions/cache@v3 + uses: actions/cache/restore@v3 with: path: target/debug/xtask - key: xtask-${{ hashFiles('xtask/**') }} + key: "${{ needs.xtask.outputs.cachekey-linux }}" + fail-on-cache-miss: true - name: Test - uses: actions-rs/cargo@v1 - with: - command: run - args: -p xtask -- ci test-crypto + run: | + target/debug/xtask ci test-crypto test-all-crates: name: ${{ matrix.name }} @@ -190,37 +157,40 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: Install Protoc uses: arduino/setup-protoc@v1 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ matrix.rust }} - profile: minimal - override: true + # Can't use `${{ matrix.* }}` inside uses + - name: Install Rust stable + if: matrix.rust == 'stable' + uses: dtolnay/rust-toolchain@stable + + - name: Install Rust beta + if: matrix.rust == 'beta' + uses: dtolnay/rust-toolchain@beta + + - name: Install Rust nightly + if: matrix.rust == 'nightly' + uses: dtolnay/rust-toolchain@nightly - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install nextest uses: taiki-e/install-action@nextest - name: Test - uses: actions-rs/cargo@v1 - with: - command: nextest - args: run --workspace --exclude matrix-sdk-integration-testing --exclude sliding-sync-integration-test + run: | + cargo nextest run --workspace \ + --exclude matrix-sdk-integration-testing --exclude 
sliding-sync-integration-test - name: Test documentation - uses: actions-rs/cargo@v1 - with: - command: test - args: --doc + run: | + cargo test --doc --features docsrs test-wasm: name: πŸ•ΈοΈ ${{ matrix.name }} @@ -265,13 +235,10 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - toolchain: stable - target: wasm32-unknown-unknown + targets: wasm32-unknown-unknown components: clippy - profile: minimal - override: true - name: Install wasm-pack uses: jetli/wasm-pack-action@v0.4.0 @@ -279,28 +246,33 @@ jobs: version: v0.10.3 - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 + with: + # use a separate cache for each job to work around + # https://github.com/Swatinem/rust-cache/issues/124 + key: "${{ matrix.cmd }}" + + # ... but only save the cache on the main branch + # cf https://github.com/Swatinem/rust-cache/issues/95 + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Install nextest uses: taiki-e/install-action@nextest - name: Get xtask - uses: actions/cache@v3 + uses: actions/cache/restore@v3 with: path: target/debug/xtask - key: xtask-${{ hashFiles('xtask/**') }} + key: "${{ needs.xtask.outputs.cachekey-linux }}" + fail-on-cache-miss: true - name: Rust Check - uses: actions-rs/cargo@v1 - with: - command: run - args: -p xtask -- ci wasm ${{ matrix.cmd }} + run: | + target/debug/xtask ci wasm ${{ matrix.cmd }} - name: Wasm-Pack test - uses: actions-rs/cargo@v1 - with: - command: run - args: -p xtask -- ci wasm-pack ${{ matrix.cmd }} + run: | + target/debug/xtask ci wasm-pack ${{ matrix.cmd }} test-appservice: name: ${{ matrix.os-name }} [m]-appservice @@ -314,38 +286,35 @@ jobs: include: - os: ubuntu-latest os-name: 🐧 + xtask-cachekey: "${{ needs.xtask.outputs.cachekey-linux }}" - os: macos-latest os-name: 🍏 + xtask-cachekey: "${{ needs.xtask.outputs.cachekey-macos }}" steps: - name: Checkout - uses: actions/checkout@v1 + uses: 
actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true + uses: dtolnay/rust-toolchain@stable - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install nextest uses: taiki-e/install-action@nextest - name: Get xtask - uses: actions/cache@v3 + uses: actions/cache/restore@v3 with: path: target/debug/xtask - key: xtask-${{ hashFiles('xtask/**') }} + key: "${{ matrix.xtask-cachekey }}" + fail-on-cache-miss: true - name: Run checks - uses: actions-rs/cargo@v1 - with: - command: run - args: -p xtask -- ci test-appservice + run: | + target/debug/xtask ci test-appservice formatting: name: Check Formatting @@ -357,18 +326,13 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@nightly with: - toolchain: nightly components: rustfmt - profile: minimal - override: true - name: Cargo fmt - uses: actions-rs/cargo@v1 - with: - command: fmt - args: -- --check + run: | + cargo fmt -- --check typos: name: Spell Check with Typos @@ -398,27 +362,23 @@ jobs: repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Install Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@nightly with: - toolchain: nightly components: clippy - profile: minimal - override: true - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Get xtask - uses: actions/cache@v3 + uses: actions/cache/restore@v3 with: path: target/debug/xtask - key: xtask-${{ hashFiles('xtask/**') }} + key: "${{ needs.xtask.outputs.cachekey-linux }}" + fail-on-cache-miss: true - name: Clippy - uses: actions-rs/cargo@v1 - with: - command: run - args: -p xtask -- ci clippy + run: | + target/debug/xtask ci clippy integration-tests: name: Integration test @@ -431,14 +391,10 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - 
override: true + uses: dtolnay/rust-toolchain@stable - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install nextest uses: taiki-e/install-action@nextest @@ -454,22 +410,19 @@ jobs: disableRateLimiting: true - name: Test - uses: actions-rs/cargo@v1 - with: - command: nextest - args: run -p matrix-sdk-integration-testing + run: | + cargo nextest run -p matrix-sdk-integration-testing sliding-sync-integration-tests: name: Sliding Sync Integration test - # disabled until we can figure out the weird docker-not-starting-situation - if: false - # if: github.event_name == 'push' || !github.event.pull_request.draft + if: github.event_name == 'push' || !github.event.pull_request.draft runs-on: ubuntu-latest - # Service containers to run with `runner-job` + # run several docker containers with the same networking stack so the hostname 'postgres' + # maps to the postgres container, etc. services: - # Label used to access the service container + # sliding sync needs a postgres container postgres: # Docker Hub image image: postgres @@ -487,20 +440,37 @@ jobs: ports: # Maps tcp port 5432 on service container to the host - 5432:5432 - + # run sliding sync and point it at the postgres container and synapse container. + # the postgres container needs to be above this to make sure it has started prior to this service. + slidingsync: + image: "ghcr.io/matrix-org/sliding-sync:v0.99.0" + env: + SYNCV3_SERVER: "http://synapse:8008" + SYNCV3_SECRET: "SUPER_CI_SECRET" + SYNCV3_BINDADDR: ":8118" + SYNCV3_DB: "user=postgres password=postgres dbname=syncv3 sslmode=disable host=postgres" + ports: + - 8118:8118 + # tests need a synapse: this is a service and not michaelkaye/setup-matrix-synapse@main as the + # latter does not provide networking for services to communicate with it. 
+ synapse: + # Custom image built from https://github.com/matrix-org/synapse/tree/v1.72.0/docker/complement + # with a dummy /complement/ca set + image: ghcr.io/matrix-org/synapse-service:v1.72.0 + env: + SYNAPSE_COMPLEMENT_DATABASE: sqlite + SERVER_NAME: synapse + ports: + - 8008:8008 steps: - name: Checkout the repo uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true + uses: dtolnay/rust-toolchain@stable - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install nextest uses: taiki-e/install-action@nextest @@ -509,24 +479,11 @@ jobs: with: python-version: 3.8 - # local synapse - - uses: michaelkaye/setup-matrix-synapse@main - with: - uploadLogs: true - httpPort: 8228 - disableRateLimiting: true - - # latest sliding sync proxy - - - uses: addnab/docker-run-action@v3 - with: - registry: gcr.io - image: "matrix-org/sliding-sync:v0.98.0" - docker_network: "host" - options: '-e "SYNCV3_SERVER=http://locahost:8228" -e "SYNCV3_SECRET=SUPER_CI_SECRET" -e "SYNCV3_BINDADDR=:8118" -e "SYNCV3_DB=user=postgres password=postgres dbname=syncv3 sslmode=disable host=postgres" -p 8118:8118' - - name: Test - uses: actions-rs/cargo@v1 - with: - command: nextest - args: run -p sliding-sync-integration-tests + env: + RUST_LOG: "hyper=trace" + HOMESERVER_URL: "http://localhost:8008" + HOMESERVER_DOMAIN: "synapse" + SLIDING_SYNC_PROXY_URL: "http://localhost:8118" + run: | + cargo nextest run -p sliding-sync-integration-test diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index a4103d8b7a8..07355e53e69 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -6,6 +6,10 @@ on: pull_request: branches: [main] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always @@ -22,20 +26,15 @@ jobs: ref: ${{ github.event.pull_request.head.sha }} - name: 
Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true + uses: dtolnay/rust-toolchain@stable - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install tarpaulin - uses: actions-rs/cargo@v1 + uses: taiki-e/install-action@v2 with: - command: install - args: cargo-tarpaulin + tool: cargo-tarpaulin # set up backend for integration tests - uses: actions/setup-python@v4 @@ -50,10 +49,15 @@ jobs: serverName: "matrix-sdk.rs" - name: Run tarpaulin - uses: actions-rs/cargo@v1 - with: - command: tarpaulin - args: --out Xml -e sliding-sync-integration-test + run: | + cargo tarpaulin --out Xml -e sliding-sync-integration-test - name: Upload to codecov.io uses: codecov/codecov-action@v3 + with: + # Work around frequent upload errors, for runs inside the main repo (not PRs from forks). + # Otherwise not required for public repos. + token: ${{ secrets.CODECOV_UPLOAD_TOKEN }} + # The upload sometimes fails due to https://github.com/codecov/codecov-action/issues/837. + # To make sure that the failure gets flagged clearly in the UI, fail the action. 
+ fail_ci_if_error: true diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 8444a87e4eb..b6bad5b95e6 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -5,6 +5,10 @@ on: branches: [main] pull_request: +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: docs: name: All crates @@ -16,11 +20,7 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: nightly - override: true + uses: dtolnay/rust-toolchain@nightly - name: Install Node.js uses: actions/setup-node@v3 @@ -28,18 +28,16 @@ jobs: node-version: 18 - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 # Keep in sync with xtask docs - name: Build rust documentation - uses: actions-rs/cargo@v1 env: # Work around https://github.com/rust-lang/cargo/issues/10744 CARGO_TARGET_APPLIES_TO_HOST: "true" RUSTDOCFLAGS: "--enable-index-page -Zunstable-options --cfg docsrs -Dwarnings" - with: - command: doc - args: --no-deps --features docsrs + run: + cargo doc --no-deps --features docsrs - name: Build `matrix-sdk-crypto-nodejs` doc run: | diff --git a/.github/workflows/release-crypto-nodejs.yml b/.github/workflows/release-crypto-nodejs.yml index 4296708d82e..d682a3e606f 100644 --- a/.github/workflows/release-crypto-nodejs.yml +++ b/.github/workflows/release-crypto-nodejs.yml @@ -76,16 +76,13 @@ jobs: - uses: actions/checkout@v3 if: "${{ !inputs.tag }}" - name: Install Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@nightly with: - toolchain: nightly - profile: minimal - target: ${{ matrix.target }} - override: true + targets: ${{ matrix.target }} - name: Install Node.js uses: actions/setup-node@v3 - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - if: ${{ matrix.apt_install }} run: | sudo apt-get update @@ -117,11 +114,7 @@ jobs: - uses: 
actions/checkout@v3 if: "${{ !inputs.tag }}" - name: Install Rust - uses: actions-rs/toolchain@v1 - with: - toolchain: nightly - profile: minimal - override: true + uses: dtolnay/rust-toolchain@nightly - name: Install Node.js uses: actions/setup-node@v3 - name: Build lib diff --git a/.github/workflows/release_crypto_js.yml b/.github/workflows/release_crypto_js.yml index be60ad3834f..f5b9b0b0955 100644 --- a/.github/workflows/release_crypto_js.yml +++ b/.github/workflows/release_crypto_js.yml @@ -29,15 +29,12 @@ jobs: uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - toolchain: stable - target: wasm32-unknown-unknown - profile: minimal - override: true + targets: wasm32-unknown-unknown - name: Load cache - uses: Swatinem/rust-cache@v1 + uses: Swatinem/rust-cache@v2 - name: Install Node.js uses: actions/setup-node@v3 diff --git a/.github/workflows/xtask.yml b/.github/workflows/xtask.yml new file mode 100644 index 00000000000..aa4303f6dc4 --- /dev/null +++ b/.github/workflows/xtask.yml @@ -0,0 +1,76 @@ +# A reusable github actions workflow that will build xtask, if it is not +# already cached. +# +# It will create a pair of GHA cache entries, if they do not already exist. +# The cache keys take the form `xtask-{os}-{hash}`, where "{os}" is "linux" +# or "macos", and "{hash}" is the hash of the xtask directory. +# +# The cache keys are written to output variables named "cachekey-{os}". 
+# + +name: Build xtask if necessary + +on: + workflow_call: + outputs: + cachekey-linux: + description: "The cache key for the linux build artifact" + value: "${{ jobs.xtask.outputs.cachekey-linux }}" + cachekey-macos: + description: "The cache key for the macos build artifact" + value: "${{ jobs.xtask.outputs.cachekey-macos }}" + +env: + CARGO_TERM_COLOR: always + +jobs: + xtask: + name: "xtask-${{ matrix.os-name }}" + + strategy: + fail-fast: true + matrix: + include: + - os: ubuntu-latest + os-name: 🐧 + cachekey-id: linux + + - os: macos-12 + os-name: 🍏 + cachekey-id: macos + + runs-on: "${{ matrix.os }}" + + steps: + - name: Checkout repo + uses: actions/checkout@v3 + + - name: Calculate cache key + id: cachekey + # set a step output variable "cachekey-{os}" that can be referenced in + # the job outputs below. + run: | + echo "cachekey-${{ matrix.cachekey-id }}=xtask-${{ matrix.cachekey-id }}-${{ hashFiles('Cargo.toml', 'xtask/**') }}" >> $GITHUB_OUTPUT + + - name: Check xtask cache + uses: actions/cache@v3 + id: xtask-cache + with: + path: target/debug/xtask + # use the cache key calculated in the step above. 
Bit of an awkward + # syntax key: | + ${{ steps.cachekey.outputs[format('cachekey-{0}', matrix.cachekey-id)] }} + + - name: Install rust stable toolchain + if: steps.xtask-cache.outputs.cache-hit != 'true' + uses: dtolnay/rust-toolchain@stable + + - name: Build + if: steps.xtask-cache.outputs.cache-hit != 'true' + run: | + cargo build -p xtask + + outputs: + "cachekey-linux": "${{ steps.cachekey.outputs.cachekey-linux }}" + "cachekey-macos": "${{ steps.cachekey.outputs.cachekey-macos }}" diff --git a/.gitignore b/.gitignore index b9a0e4f3854..d202f963ade 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ emsdk-* .idea/ .env .build +.swiftpm /Package.swift ## User settings diff --git a/.typos.toml b/.typos.toml index 0f725247480..e274b4d9a65 100644 --- a/.typos.toml +++ b/.typos.toml @@ -5,7 +5,9 @@ Fo = "Fo" BA = "BA" UE = "UE" Ure = "Ure" +OFO = "OFO" Ot = "Ot" +ket = "ket" # This is the thead html tag, remove this once typos is updated in the github # action. 1.3.1 seems to work correctly, while 1.11.0 on the CI seems to get # this wrong diff --git a/CONVENTIONAL_COMMITS.md b/CONVENTIONAL_COMMITS.md deleted file mode 100644 index f3d3b135834..00000000000 --- a/CONVENTIONAL_COMMITS.md +++ /dev/null @@ -1,124 +0,0 @@ -# Conventional Commits - -This project uses [Conventional -Commits](https://www.conventionalcommits.org/). Read the -[Summary](https://www.conventionalcommits.org/en/v1.0.0/#summary) or -the [Full -Specification](https://www.conventionalcommits.org/en/v1.0.0/#specification) -to learn more. - -## Types - -Conventional Commits defines _type_ (as in `type(scope): -message`). This section aims at listing the types used inside this -project: - -| Type | Definition | -|-|-| -| `feat` | About a new feature. | -| `fix` | About a bug fix. | -| `test` | About a test (suite, case, runner…). | -| `docs` | About a documentation modification. | -| `refactor` | About a refactoring. | -| `ci` | About a Continuous Integration modification. 
| -| `chore` | About some cleanup, or regular tasks. | - -## Scopes - -Conventional Commits defines _scope_ (as in `type(scope): message`). This -section aims at listing all the scopes used inside this project: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
GroupScopeDefinition
CratessdkAbout the matrix-sdk crate.
appserviceAbout the matrix-sdk-appservice crate.
baseAbout the matrix-sdk-base crate.
commonAbout the matrix-sdk-common crate.
cryptoAbout the matrix-sdk-crypto crate.
indexeddbAbout the matrix-sdk-indexeddb crate.
qrcodeAbout the matrix-sdk-qrcode crate.
sledAbout the matrix-sdk-sled crate.
store-encryptionAbout the matrix-sdk-store-encryption crate.
testAbout the matrix-sdk-test and matrix-sdk-test-macros crate.
BindingsappleAbout the matrix-rust-components-swift binding.
crypto-nodejsAbout the matrix-sdk-crypto-nodejs binding.
crypto-jsAbout the matrix-sdk-crypto-js binding.
crypto-ffiAbout the matrix-sdk-crypto-ffi binding.
ffiAbout the matrix-sdk-ffi binding.
Labssled-state-inspectorAbout the sled-state-inspector project.
Continuous IntegrationxtaskAbout the xtask project.
- -## Generating `CHANGELOG.md` - -The [`git-cliff`](https://github.com/orhun/git-cliff) project is used -to generate `CHANGELOG.md` automatically. Hence the various -`cliff.toml` files that are present in this project, or the -`package.metadata.git-cliff` sections in various `Cargo.toml` files. - -Its companion, -[`git-cliff-action`](https://github.com/orhun/git-cliff-action) -project, is used inside Github Action workflows. diff --git a/Cargo.lock b/Cargo.lock index f26a4071516..aa81aadd222 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,9 +57,9 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf6ccdb167abbf410dcb915cabd428929d7f6a04980b54a11f26a39f1c7f7107" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" dependencies = [ "cfg-if", "getrandom 0.2.8", @@ -82,18 +82,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85965b6739a430150bdd138e2374a98af0c3ee0d030b3bb7fc3bddff58d0102e" -[[package]] -name = "android_logger" -version = "0.11.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8619b80c242aa7bd638b5c7ddd952addeecb71f69c75e33f1d47b2804f8f883a" -dependencies = [ - "android_log-sys", - "env_logger", - "log", - "once_cell", -] - [[package]] name = "android_system_properties" version = "0.1.5" @@ -111,9 +99,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anyhow" -version = "1.0.68" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cb2f989d18dd141ab8ae82f64d1a8cdd37e0840f73a406896cf5e99502fab61" +checksum = "224afbd727c3d6e4b90103ece64b8d1b67fbb1973b1046c2281eed3f3803f800" [[package]] name = "anymap2" @@ -123,9 +111,9 @@ checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" [[package]] name = "app_dirs2" -version = 
"2.5.4" +version = "2.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47a8d2d8dbda5fca0a522259fb88e4f55d2b10ad39f5f03adeebf85031eba501" +checksum = "a7e7b35733e3a8c1ccb90385088dd5b6eaa61325cb4d1ad56e683b5224ff352e" dependencies = [ "jni", "ndk-context", @@ -209,7 +197,7 @@ dependencies = [ "quote", "serde", "syn", - "toml", + "toml 0.5.11", ] [[package]] @@ -258,6 +246,55 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-executor" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17adb73da160dfb475c183343c8cccd80721ea5a605d3eb57125f0a7b7a92d0b" +dependencies = [ + "async-lock", + "async-task", + "concurrent-queue", + "fastrand", + "futures-lite", + "slab", +] + +[[package]] +name = "async-global-executor" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1b6f5d7df27bd294849f8eec66ecfc63d11814df7a4f5d74168a2394467b776" +dependencies = [ + "async-channel", + "async-executor", + "async-io", + "async-lock", + "blocking", + "futures-lite", + "once_cell", +] + +[[package]] +name = "async-io" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c374dda1ed3e7d8f0d9ba58715f924862c63eae6849c92d3a18e7fbde9e2794" +dependencies = [ + "async-lock", + "autocfg", + "concurrent-queue", + "futures-lite", + "libc", + "log", + "parking", + "polling", + "slab", + "socket2", + "waker-fn", + "windows-sys 0.42.0", +] + [[package]] name = "async-lock" version = "2.6.0" @@ -270,9 +307,54 @@ dependencies = [ [[package]] name = "async-once-cell" -version = "0.4.2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390a110411bbc7c93b77a736cbd694f64cb06dfa2702173f63169d7a1e1b5298" + +[[package]] +name = "async-process" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6381ead98388605d0d9ff86371043b5aa922a3905824244de40dc263a14fcba4" +dependencies = [ + "async-io", + "async-lock", + "autocfg", + "blocking", + "cfg-if", + "event-listener", + "futures-lite", + "libc", + "signal-hook", + "windows-sys 0.42.0", +] + +[[package]] +name = "async-std" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f61305cacf1d0c5c9d3ee283d22f8f1f8c743a18ceb44a1b102bd53476c141de" +checksum = "62565bb4402e926b29953c785397c6dc0391b7b446e45008b0049eb43cec6f5d" +dependencies = [ + "async-channel", + "async-global-executor", + "async-io", + "async-lock", + "async-process", + "crossbeam-utils", + "futures-channel", + "futures-core", + "futures-io", + "futures-lite", + "gloo-timers", + "kv-log-macro", + "log", + "memchr", + "once_cell", + "pin-project-lite", + "pin-utils", + "slab", + "wasm-bindgen-futures", +] [[package]] name = "async-stream" @@ -295,11 +377,17 @@ dependencies = [ "syn", ] +[[package]] +name = "async-task" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a40729d2133846d9ed0ea60a8b9541bccddab49cd30f0715a1da672fe9a2524" + [[package]] name = "async-trait" -version = "0.1.61" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "705339e0e4a9690e2908d2b3d049d85682cf19fbd5782494498fbf7003a6a282" +checksum = "1cd7fce9ba8c3c042128ce72d8b2ddbf3a05747efb67ea0313c635e10bda47a2" dependencies = [ "proc-macro2", "quote", @@ -315,6 +403,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "atomic-waker" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "debc29dde2e69f9e47506b525f639ed42300fc014a3e007832592448fa8e4599" + [[package]] name = "atty" version = "0.2.14" @@ -334,9 +428,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "axum" -version = "0.6.2" +version = "0.6.6" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "1304eab461cf02bd70b083ed8273388f9724c549b316ba3d1e213ce0e9e7fb7e" +checksum = "4e246206a63c9830e118d12c894f56a82033da1a2361f5544deeee3df85c99d9" dependencies = [ "async-trait", "axum-core", @@ -365,9 +459,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f487e40dc9daee24d8a1779df88522f159a54a980f99cfbe43db0be0bd3444a8" +checksum = "1cae3e661676ffbacb30f1a824089a8c9150e71017f7e1e38f2aa32009188d34" dependencies = [ "async-trait", "bytes", @@ -415,12 +509,6 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" -[[package]] -name = "base64" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5" - [[package]] name = "base64" version = "0.21.0" @@ -457,27 +545,21 @@ dependencies = [ "serde", ] -[[package]] -name = "bit-set" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" -dependencies = [ - "bit-vec", -] - -[[package]] -name = "bit-vec" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" - [[package]] name = "bitflags" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitmaps" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031043d04099746d8db04daf1fa424b2bc8bd69d92b25962dcde24da39ab64a2" +dependencies = [ + "typenum", +] + [[package]] name = "blake3" version = 
"1.3.3" @@ -519,6 +601,20 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blocking" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c67b173a56acffd6d2326fb7ab938ba0b00a71480e14902b2591c87bc5741e8" +dependencies = [ + "async-channel", + "async-lock", + "async-task", + "atomic-waker", + "fastrand", + "futures-lite", +] + [[package]] name = "bs58" version = "0.4.0" @@ -527,15 +623,15 @@ checksum = "771fe0050b883fcc3ea2359b1a96bcfbc090b7116eae7c3c512c7a083fdf23d3" [[package]] name = "bumpalo" -version = "3.11.1" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" +checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" [[package]] name = "bytemuck" -version = "1.12.3" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aaa3a8d9a1ca92e282c96a32d6511b695d7d994d1d102ba85d279f9b2756947f" +checksum = "c041d3eab048880cb0b86b256447da3f18859a163c3b8d8893f4e6368abe6393" [[package]] name = "byteorder" @@ -545,9 +641,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.3.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfb24e866b15a1af2a1b663f10c6b6b8f397a84aadb828f12e5b289ec23a3a3c" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" [[package]] name = "bytesize" @@ -575,9 +671,9 @@ dependencies = [ [[package]] name = "cargo_metadata" -version = "0.15.2" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "982a0cf6a99c350d7246035613882e376d58cebe571785abc5da4f648d53ac0a" +checksum = "08a1ec454bc3eead8719cb56e15dbbfecdbc14e4b3a3ae4936cc6e31f5fc0d07" dependencies = [ "camino", "cargo-platform", @@ -610,9 +706,9 @@ dependencies = [ [[package]] name = 
"cc" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a20104e2335ce8a659d6dd92a51a767a0c062599c73b343fd152cb401e828c3d" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" [[package]] name = "cesu8" @@ -737,13 +833,13 @@ dependencies = [ [[package]] name = "clap" -version = "4.0.32" +version = "4.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7db700bc935f9e43e88d00b0850dae18a63773cfbec6d8e070fccf7fef89a39" +checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" dependencies = [ "bitflags", - "clap_derive 4.0.21", - "clap_lex 0.3.0", + "clap_derive 4.1.8", + "clap_lex 0.3.1", "is-terminal", "once_cell", "strsim", @@ -765,9 +861,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.0.21" +version = "4.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0177313f9f02afc995627906bbd8967e2be069f5261954222dac78290c2b9014" +checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" dependencies = [ "heck", "proc-macro-error", @@ -787,24 +883,13 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d4198f73e42b4936b35b5bb248d81d2b595ecb170da0bac7655c54eedfa8da8" +checksum = "783fe232adfca04f90f56201b26d79682d4cd2625e0bc7290b95123afe558ade" dependencies = [ "os_str_bytes", ] -[[package]] -name = "clipboard-win" -version = "4.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7191c27c2357d9b7ef96baac1773290d4ca63b24205b82a3fd8a0637afcf0362" -dependencies = [ - "error-code", - "str-buf", - "winapi", -] - [[package]] name = "cmake" version = "0.1.49" @@ -842,24 +927,24 @@ dependencies = [ [[package]] name = "concurrent-queue" -version = "2.0.0" +version = "2.1.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd7bef69dc86e3c610e4e7aed41035e2a7ed12e72dd7530f61327a6579a4390b" +checksum = "c278839b831783b70278b14df4d45e1beb1aad306c07bb796637de9a0e323e8e" dependencies = [ "crossbeam-utils", ] [[package]] name = "console" -version = "0.15.4" +version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9b6515d269224923b26b5febea2ed42b2d5f2ce37284a4dd670fedd6cb8347a" +checksum = "c3d79fbe8970a77e3e34151cc13d3b3e248aa0faaecb9f6091fa07ebefe5ad60" dependencies = [ "encode_unicode", "lazy_static", "libc", "unicode-width", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -1107,9 +1192,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.86" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d1075c37807dcf850c379432f0df05ba52cc30f279c5cfc43cc221ce7f8579" +checksum = "86d3488e7665a7a483b57e25bdd90d0aeb2bc7608c8d0346acf2ad3f1caf1d62" dependencies = [ "cc", "cxxbridge-flags", @@ -1119,9 +1204,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = "1.0.86" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5044281f61b27bc598f2f6647d480aed48d2bf52d6eb0b627d84c0361b17aa70" +checksum = "48fcaf066a053a41a81dfb14d57d99738b767febb8b735c3016e469fac5da690" dependencies = [ "cc", "codespan-reporting", @@ -1134,53 +1219,18 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.86" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61b50bc93ba22c27b0d31128d2d130a0a6b3d267ae27ef7e4fae2167dfe8781c" +checksum = "a2ef98b8b717a829ca5603af80e1f9e2e48013ab227b68ef37872ef84ee479bf" [[package]] name = "cxxbridge-macro" -version = "1.0.86" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e61fda7e62115119469c7b3591fd913ecca96fb766cfd3f2e2502ab7bc87a5" -dependencies = [ - "proc-macro2", 
- "quote", - "syn", -] - -[[package]] -name = "darling" -version = "0.14.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0dd3cd20dc6b5a876612a6e5accfe7f3dd883db6d07acfbf14c128f61550dfa" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.14.2" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a784d2ccaf7c98501746bf0be29b2022ba41fd62a2e622af997a03e9f972859f" +checksum = "086c685979a698443656e5cf7856c95c642295a38599f12fb1ff76fb28d19892" dependencies = [ - "fnv", - "ident_case", "proc-macro2", "quote", - "strsim", - "syn", -] - -[[package]] -name = "darling_macro" -version = "0.14.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7618812407e9402654622dd402b0a89dff9ba93badd6540781526117b92aab7e" -dependencies = [ - "darling_core", - "quote", "syn", ] @@ -1194,7 +1244,7 @@ dependencies = [ "hashbrown", "lock_api", "once_cell", - "parking_lot_core 0.9.5", + "parking_lot_core 0.9.7", ] [[package]] @@ -1278,44 +1328,14 @@ dependencies = [ "syn", ] -[[package]] -name = "derive_builder" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d07adf7be193b71cc36b193d0f5fe60b918a3a9db4dad0449f57bcfd519704a3" -dependencies = [ - "derive_builder_macro", -] - -[[package]] -name = "derive_builder_core" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f91d4cfa921f1c05904dc3c57b4a32c38aed3340cce209f3a6fd1478babafc4" -dependencies = [ - "darling", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "derive_builder_macro" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f0314b72bed045f3a68671b3c86328386762c93f82d98c65c3cb5e5f573dd68" -dependencies = [ - "derive_builder_core", - "syn", -] - [[package]] name = "dialoguer" -version = "0.10.2" +version = "0.10.3" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92e7e37ecef6857fdc0c0c5d42fd5b0938e46590c2183cc92dd310a6d078eb1" +checksum = "af3c796f3b0b408d9fd581611b47fa850821fcb84aa640b83a3c1a5be2d691f2" dependencies = [ "console", + "shell-words", "tempfile", "zeroize", ] @@ -1349,16 +1369,6 @@ dependencies = [ "dirs-sys", ] -[[package]] -name = "dirs-next" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" -dependencies = [ - "cfg-if", - "dirs-sys-next", -] - [[package]] name = "dirs-sys" version = "0.3.7" @@ -1370,23 +1380,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "dirs-sys-next" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" -dependencies = [ - "libc", - "redox_users", - "winapi", -] - -[[package]] -name = "discard" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0" - [[package]] name = "displaydoc" version = "0.2.3" @@ -1400,9 +1393,9 @@ dependencies = [ [[package]] name = "ed25519" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9c280362032ea4203659fc489832d0204ef09f247a0506f170dafcac08c369" +checksum = "91cff35c70bba8a626e3185d8cd48cc11b5437e1a5bcd15b9b5fa3c64b6dfee7" dependencies = [ "serde", "signature", @@ -1425,9 +1418,9 @@ dependencies = [ [[package]] name = "either" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" [[package]] name = "encode_unicode" @@ -1437,29 +1430,13 @@ checksum = 
"a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "encoding_rs" -version = "0.8.31" +version = "0.8.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9852635589dc9f9ea1b6fe9f05b50ef208c85c834a562f0c6abb1c475736ec2b" +checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" dependencies = [ "cfg-if", ] -[[package]] -name = "endian-type" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" - -[[package]] -name = "env_logger" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85cdab6a89accf66733ad5a1693a4dcced6aeff64602b634530dd73c1f3ee9f0" -dependencies = [ - "log", - "regex", -] - [[package]] name = "errno" version = "0.2.8" @@ -1481,16 +1458,6 @@ dependencies = [ "libc", ] -[[package]] -name = "error-code" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f18991e7bf11e7ffee451b5318b5c1a73c52d0d0ada6e5a3017c8c1ced6a21" -dependencies = [ - "libc", - "str-buf", -] - [[package]] name = "event-listener" version = "2.5.3" @@ -1557,7 +1524,7 @@ name = "example-emoji-verification" version = "0.1.0" dependencies = [ "anyhow", - "clap 4.0.32", + "clap 4.1.8", "futures", "matrix-sdk", "tokio", @@ -1610,14 +1577,27 @@ dependencies = [ "url", ] +[[package]] +name = "example-persist-session" +version = "0.1.0" +dependencies = [ + "anyhow", + "dirs", + "matrix-sdk", + "rand 0.8.5", + "serde", + "serde_json", + "tokio", + "tracing-subscriber", +] + [[package]] name = "example-timeline" version = "0.1.0" dependencies = [ "anyhow", - "clap 4.0.32", + "clap 4.1.8", "futures", - "futures-signals", "matrix-sdk", "tokio", "tracing-subscriber", @@ -1635,6 +1615,28 @@ dependencies = [ "syn", ] +[[package]] +name = "eyeball" +version = "0.4.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c7be1d67275032c662cadf525a79aef6909469579c5d81c69c148f7257257af" +dependencies = [ + "futures-core", + "readlock", +] + +[[package]] +name = "eyeball-im" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb8a6cfd1f5947d0426dcb753723318d5922c738e905be7af167547565f81d9" +dependencies = [ + "futures-core", + "im", + "tokio", + "tokio-stream", +] + [[package]] name = "eyre" version = "0.6.8" @@ -1657,36 +1659,15 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" -[[package]] -name = "fancy-regex" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6b8560a05112eb52f04b00e5d3790c0dd75d9d980eb8a122fb23b92a623ccf" -dependencies = [ - "bit-set", - "regex", -] - [[package]] name = "fastrand" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a407cfaa3385c4ae6b23e84623d48c2798d06e3e6a1878f7f59f17b3f86499" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" dependencies = [ "instant", ] -[[package]] -name = "fd-lock" -version = "3.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb21c69b9fea5e15dbc1049e4b77145dd0ba1c84019c488102de0dc4ea4b0a27" -dependencies = [ - "cfg-if", - "rustix", - "windows-sys", -] - [[package]] name = "findshlibs" version = "0.10.2" @@ -1763,15 +1744,25 @@ dependencies = [ [[package]] name = "fs_extra" -version = "1.2.0" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "futf" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"2022715d62ab30faffd124d40b76f4134a550a87792276512b18d63272333394" +checksum = "df420e2e84819663797d1ec6544b13c5be84629e7bb00dc960d6917db2987843" +dependencies = [ + "mac", + "new_debug_unreachable", +] [[package]] name = "futures" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38390104763dc37a5145a53c29c63c1290b5d316d6086ec32c293f6736051bb0" +checksum = "13e2792b0ff0340399d58445b88fd9770e3489eff258a4cbc1523418f12abf84" dependencies = [ "futures-channel", "futures-core", @@ -1784,9 +1775,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52ba265a92256105f45b719605a571ffe2d1f0fea3807304b522c1d778f79eed" +checksum = "2e5317663a9089767a1ec00a487df42e0ca174b61b4483213ac24448e4664df5" dependencies = [ "futures-core", "futures-sink", @@ -1794,15 +1785,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04909a7a7e4633ae6c4a9ab280aeb86da1236243a77b694a49eacd659a4bd3ac" +checksum = "ec90ff4d0fe1f57d600049061dc6bb68ed03c7d2fbd697274c41805dcb3f8608" [[package]] name = "futures-executor" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7acc85df6714c176ab5edf386123fafe217be88c0840ec11f199441134a074e2" +checksum = "e8de0a35a6ab97ec8869e32a2473f4b1324459e14c29275d14b10cb1fd19b50e" dependencies = [ "futures-core", "futures-task", @@ -1811,9 +1802,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" +checksum = "bfb8371b6fb2aeb2d280374607aeabfc99d95c72edfe51692e42d3d7f0d08531" [[package]] name = "futures-lite" @@ 
-1832,41 +1823,26 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdfb8ce053d86b91919aad980c220b1fb8401a9394410e1c289ed7e66b61835d" +checksum = "95a73af87da33b5acf53acfebdc339fe592ecf5357ac7c0a7734ab9d8c876a70" dependencies = [ "proc-macro2", "quote", "syn", ] -[[package]] -name = "futures-signals" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3acc659ba666cff13fdf65242d16428f2f11935b688f82e4024ad39667a5132" -dependencies = [ - "discard", - "futures-channel", - "futures-core", - "futures-util", - "log", - "pin-project", - "serde", -] - [[package]] name = "futures-sink" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39c15cf1a4aa79df40f1bb462fb39676d0ad9e366c2a33b590d7c66f4f81fcf9" +checksum = "f310820bb3e8cfd46c80db4d7fb8353e15dfff853a127158425f31e0be6c8364" [[package]] name = "futures-task" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea" +checksum = "dcf79a1bf610b10f42aea489289c5a2c478a786509693b80cd39c44ccd936366" [[package]] name = "futures-timer" @@ -1876,9 +1852,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.25" +version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6" +checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" dependencies = [ "futures-channel", "futures-core", @@ -1949,9 +1925,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.27.0" +version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"dec7af912d60cdbd3677c1af9352ebae6fb8394d165568a2234df0fa00f87793" +checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4" [[package]] name = "glob" @@ -1959,6 +1935,18 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "gloo-timers" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b995a66bb87bebce9a0f4a95aed01daca4872c050bfcb21653361c03bc35e5c" +dependencies = [ + "futures-channel", + "futures-core", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "gloo-utils" version = "0.1.6" @@ -2028,9 +2016,9 @@ dependencies = [ [[package]] name = "heck" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" @@ -2050,6 +2038,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + [[package]] name = "hkdf" version = "0.12.3" @@ -2068,6 +2062,20 @@ dependencies = [ "digest 0.10.6", ] +[[package]] +name = "html5ever" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bea68cab48b8459f17cf1c944c67ddc572d272d9f2b274140f223ecb1da4a3b7" +dependencies = [ + "log", + "mac", + "markup5ever", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "http" version = "0.2.8" @@ -2131,9 +2139,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" -version = "0.14.23" +version = "0.14.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"034711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3c" +checksum = "5e011372fa0b68db8350aa7a248930ecc7839bf46d8485577d69f117a75f164c" dependencies = [ "bytes", "futures-channel", @@ -2215,12 +2223,6 @@ dependencies = [ "cxx-build", ] -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - [[package]] name = "idna" version = "0.3.0" @@ -2231,6 +2233,21 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "im" +version = "15.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0acd33ff0285af998aaf9b57342af478078f53492322fafc47450e09397e0e9" +dependencies = [ + "bitmaps", + "rand_core 0.6.4", + "rand_xoshiro", + "serde", + "sized-chunks", + "typenum", + "version_check", +] + [[package]] name = "image" version = "0.23.14" @@ -2301,9 +2318,9 @@ dependencies = [ [[package]] name = "indoc" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da2d6f23ffea9d7e76c53eee25dfb67bcd8fde7f1198b0855350698c9f07c780" +checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306" [[package]] name = "infer" @@ -2313,13 +2330,13 @@ checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac" [[package]] name = "inferno" -version = "0.11.13" +version = "0.11.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7207d75fcf6c1868f1390fc1c610431fe66328e9ee6813330a041ef6879eca1" +checksum = "2fb7c1b80a1dfa604bb4a649a5c5aeef3d913f7c520cb42b40e534e8a61bcdfc" dependencies = [ - "ahash 0.8.2", - "atty", + "ahash 0.8.3", "indexmap", + "is-terminal", "itoa", "log", "num-format", @@ -2353,12 +2370,12 @@ dependencies = [ [[package]] name = "io-lifetimes" -version = "1.0.3" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"46112a93252b123d31a119a8d1a1ac19deac4fac6e0e8b0df58f0d4e5870e63c" +checksum = "1abeb7a0dd0f8181267ff8adc397075586500b81b28a73e8a0208b00fc170fb3" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -2369,14 +2386,14 @@ checksum = "30e22bd8629359895450b59ea7a776c850561b96a3b1d31321c1949d9e6c9146" [[package]] name = "is-terminal" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28dfb6c8100ccc63462345b67d1bbc3679177c75ee4bf59bf29c8b1d110b8189" +checksum = "22e18b0a45d56fe973d6db23972bf5bc46f988a4a2385deac9cc29572f09daef" dependencies = [ - "hermit-abi 0.2.6", + "hermit-abi 0.3.1", "io-lifetimes", "rustix", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -2409,11 +2426,12 @@ version = "0.2.0" dependencies = [ "app_dirs2", "chrono", - "clap 4.0.32", + "clap 4.1.8", "dialoguer", + "eyeball", + "eyeball-im", "eyre", "futures", - "futures-signals", "log4rs", "matrix-sdk", "matrix-sdk-common", @@ -2431,16 +2449,18 @@ dependencies = [ [[package]] name = "jni" -version = "0.19.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec" +checksum = "19bfb8e36ca99b00e6d368320e0822dec9d81db4ccf122f82091f972c90b9985" dependencies = [ "cesu8", + "cfg-if", "combine", "jni-sys", "log", "thiserror", "walkdir", + "windows-sys 0.45.0", ] [[package]] @@ -2469,9 +2489,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49409df3e3bf0856b916e2ceaca09ee28e6871cf7d9ce97a692cacfdb2a25a47" +checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730" dependencies = [ "wasm-bindgen", ] @@ -2516,6 +2536,15 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"984e109462d46ad18314f10e392c286c3d47bce203088a09012de1015b45b737" +[[package]] +name = "kv-log-macro" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f" +dependencies = [ + "log", +] + [[package]] name = "lazy-regex" version = "2.4.1" @@ -2561,6 +2590,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "libm" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "348108ab3fba42ec82ff6e9564fc4ca0247bdccdc68dd8af9764bbc79c3c8ffb" + [[package]] name = "libsqlite3-sys" version = "0.25.2" @@ -2604,6 +2639,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" dependencies = [ "cfg-if", + "value-bag", ] [[package]] @@ -2640,12 +2676,32 @@ dependencies = [ "thread-id", ] +[[package]] +name = "mac" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c41e0c4fef86961ac6d6f8a82609f55f31b05e4fce149ac5710e439df7619ba4" + [[package]] name = "maplit" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" +[[package]] +name = "markup5ever" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2629bb1404f3d34c2e921f21fd34ba00b206124c81f65c50b43b6aaefeb016" +dependencies = [ + "log", + "phf 0.10.1", + "phf_codegen", + "string_cache", + "string_cache_codegen", + "tendril", +] + [[package]] name = "matchers" version = "0.1.0" @@ -2700,17 +2756,19 @@ dependencies = [ "chrono", "ctor", "dashmap", - "derive_builder", "dirs", "event-listener", + "eyeball", + "eyeball-im", "eyre", "futures", "futures-core", - "futures-signals", "futures-util", "getrandom 0.2.8", + "gloo-timers", "http", "hyper", + "im", "image 0.24.5", "indexmap", 
"matrix-sdk-base", @@ -2719,7 +2777,9 @@ dependencies = [ "matrix-sdk-sled", "matrix-sdk-test", "mime", + "mime_guess", "once_cell", + "pin-project-lite", "rand 0.8.5", "reqwest", "ruma", @@ -2729,13 +2789,12 @@ dependencies = [ "tempfile", "thiserror", "tokio", - "tokio-stream", "tower", "tracing", "tracing-subscriber", "url", + "uuid", "wasm-bindgen-test", - "wasm-timer", "wiremock", "zeroize", ] @@ -2770,15 +2829,15 @@ dependencies = [ name = "matrix-sdk-base" version = "0.6.1" dependencies = [ + "assert_matches", "assign", "async-stream", "async-trait", "ctor", "dashmap", + "eyeball", "futures", - "futures-channel", "futures-core", - "futures-signals", "futures-util", "http", "matrix-sdk-common", @@ -2804,6 +2863,7 @@ dependencies = [ "async-lock", "futures-core", "futures-util", + "gloo-timers", "instant", "matrix-sdk-test", "ruma", @@ -2812,7 +2872,6 @@ dependencies = [ "tokio", "wasm-bindgen-futures", "wasm-bindgen-test", - "wasm-timer", ] [[package]] @@ -2823,22 +2882,25 @@ dependencies = [ "anyhow", "aquamarine", "assert_matches", + "async-std", "async-trait", "atomic", - "base64 0.20.0", + "base64 0.21.0", "bs58", "byteorder", "cfg-if", + "ctor", "ctr", "dashmap", "event-listener", + "eyeball", "futures", "futures-core", - "futures-signals", "futures-util", "hmac", "http", "indoc", + "itertools 0.10.5", "matrix-sdk-common", "matrix-sdk-qrcode", "matrix-sdk-test", @@ -2846,6 +2908,7 @@ dependencies = [ "pbkdf2", "proptest", "rand 0.8.5", + "rmp-serde", "ruma", "serde", "serde_json", @@ -2863,7 +2926,7 @@ name = "matrix-sdk-crypto-ffi" version = "0.1.0" dependencies = [ "anyhow", - "base64 0.20.0", + "base64 0.21.0", "futures-util", "hmac", "http", @@ -2917,6 +2980,7 @@ dependencies = [ "matrix-sdk-common", "matrix-sdk-crypto", "matrix-sdk-sled", + "matrix-sdk-sqlite", "napi", "napi-build", "napi-derive", @@ -2931,12 +2995,12 @@ dependencies = [ name = "matrix-sdk-ffi" version = "0.2.0" dependencies = [ - "android_logger", "anyhow", "base64 0.21.0", 
"extension-trait", + "eyeball", + "eyeball-im", "futures-core", - "futures-signals", "futures-util", "log-panics", "matrix-sdk", @@ -2944,15 +3008,18 @@ dependencies = [ "once_cell", "opentelemetry", "opentelemetry-otlp", + "ruma", "sanitize-filename-reader-friendly", "serde_json", "thiserror", "tokio", "tokio-stream", "tracing", + "tracing-android", "tracing-opentelemetry", "tracing-subscriber", "uniffi", + "url", "zeroize", ] @@ -2961,10 +3028,10 @@ name = "matrix-sdk-indexeddb" version = "0.2.0" dependencies = [ "anyhow", + "assert_matches", "async-trait", - "base64 0.20.0", + "base64 0.21.0", "dashmap", - "derive_builder", "getrandom 0.2.8", "gloo-utils", "indexed_db_futures", @@ -3004,7 +3071,7 @@ dependencies = [ name = "matrix-sdk-qrcode" version = "0.4.0" dependencies = [ - "base64 0.20.0", + "base64 0.21.0", "byteorder", "image 0.23.14", "qrcode", @@ -3020,7 +3087,6 @@ dependencies = [ "async-stream", "async-trait", "dashmap", - "derive_builder", "fs_extra", "futures-core", "futures-util", @@ -3064,12 +3130,12 @@ dependencies = [ "ruma", "rusqlite", "serde", - "serde_json", "tempfile", "thiserror", "tokio", "tracing", "tracing-subscriber", + "vodozemac", ] [[package]] @@ -3084,6 +3150,7 @@ dependencies = [ "hmac", "pbkdf2", "rand 0.8.5", + "rmp-serde", "serde", "serde_json", "sha2 0.10.6", @@ -3190,14 +3257,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5d732bc30207a6423068df043e3d02e0735b155ad7ce1a6f76fe2baa5b158de" +checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" dependencies = [ "libc", "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -3208,9 +3275,9 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" [[package]] name = "napi" -version = "2.10.5" +version = "2.11.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c8ae31209e4268eae6003d37c298135d0f36e721b4d1fa91dd938a52388ccf" +checksum = "2412d19892730f62fd592f8af41606ca6717ea1eca026103cd44b447829f00c1" dependencies = [ "bitflags", "ctor", @@ -3228,9 +3295,9 @@ checksum = "882a73d9ef23e8dc2ebbffb6a6ae2ef467c0f18ac10711e4cc59c5485d41df0e" [[package]] name = "napi-derive" -version = "2.9.3" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af4e44e34e70aa61be9036ae652e27c20db5bca80e006be0f482419f6601352a" +checksum = "03f15c1ac0eac01eca2a24c27905ab47f7411acefd829d0d01fb131dc39befd7" dependencies = [ "convert_case", "napi-derive-backend", @@ -3241,9 +3308,9 @@ dependencies = [ [[package]] name = "napi-derive-backend" -version = "1.0.40" +version = "1.0.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17925fff04b6fa636f8e4b4608cc1a4f1360b64ac8ecbfdb7da1be1dc74f6843" +checksum = "4930d5fa70f5663b9e7d6b4f0816b70d095574ee7f3c865fdb8c43b0f7e6406d" dependencies = [ "convert_case", "once_cell", @@ -3255,9 +3322,9 @@ dependencies = [ [[package]] name = "napi-sys" -version = "2.2.2" +version = "2.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "529671ebfae679f2ce9630b62dd53c72c56b3eb8b2c852e7e2fa91704ff93d67" +checksum = "166b5ef52a3ab5575047a9fe8d4a030cdd0f63c96f071cd6907674453b07bae3" dependencies = [ "libloading", ] @@ -3287,13 +3354,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" [[package]] -name = "nibble_vec" -version = "0.1.0" +name = "new_debug_unreachable" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" -dependencies = [ - "smallvec", -] +checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54" [[package]] name 
= "nix" @@ -3308,14 +3372,23 @@ dependencies = [ [[package]] name = "nom" -version = "7.1.2" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5507769c4919c998e69e49c839d9dc6e693ede4cc4290d6ad8b41d4f09c548c" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" dependencies = [ "memchr", "minimal-lexical", ] +[[package]] +name = "nom8" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae01545c9c7fc4486ab7debaf2aad7003ac19431791868fb2e8066df97fad2f8" +dependencies = [ + "memchr", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -3386,6 +3459,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -3400,9 +3474,9 @@ dependencies = [ [[package]] name = "object" -version = "0.30.1" +version = "0.30.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d864c91689fdc196779b98dba0aceac6118594c2df6ee5d943eb6a8df4d107a" +checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439" dependencies = [ "memchr", ] @@ -3432,9 +3506,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.0" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f61fba1741ea2b3d6a1e3178721804bb716a68a6aeba1149b5d52e3d464ea66" +checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" [[package]] name = "oorandom" @@ -3624,7 +3698,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.5", + "parking_lot_core 0.9.7", ] [[package]] @@ -3643,15 +3717,15 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.5" +version = "0.9.7" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ff9f3fef3968a3ec5945535ed654cb38ff72d7495a25619e2247fb15a2ed9ba" +checksum = "9069cbb9f99e3a5083476ccb29ceb1de18b9118cafa53e90c9551235de2b9521" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] @@ -3691,14 +3765,94 @@ checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" [[package]] name = "petgraph" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5014253a1331579ce62aa67443b4a658c5e7dd03d4bc6d302b94474888143" +checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" dependencies = [ "fixedbitset", "indexmap", ] +[[package]] +name = "phf" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabbf1ead8a5bcbc20f5f8b939ee3f5b0f6f281b6ad3468b84656b658b455259" +dependencies = [ + "phf_shared 0.10.0", +] + +[[package]] +name = "phf" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928c6535de93548188ef63bb7c4036bd415cd8f36ad25af44b9789b2ee72a48c" +dependencies = [ + "phf_macros", + "phf_shared 0.11.1", +] + +[[package]] +name = "phf_codegen" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb1c3a8bc4dd4e5cfce29b44ffc14bedd2ee294559a294e2a4d4c9e9a6a13cd" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", +] + +[[package]] +name = "phf_generator" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5285893bb5eb82e6aaf5d59ee909a06a16737a8970984dd7746ba9283498d6" +dependencies = [ + "phf_shared 0.10.0", + "rand 0.8.5", +] + +[[package]] +name = "phf_generator" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b1181c94580fa345f50f19d738aaa39c0ed30a600d95cb2d3e23f94266f14fbf" +dependencies = [ + "phf_shared 0.11.1", + "rand 0.8.5", +] + +[[package]] +name = "phf_macros" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92aacdc5f16768709a569e913f7451034034178b05bdc8acda226659a3dccc66" +dependencies = [ + "phf_generator 0.11.1", + "phf_shared 0.11.1", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "phf_shared" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" +dependencies = [ + "siphasher", +] + +[[package]] +name = "phf_shared" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1fb5f6f826b772a8d4c0394209441e7d37cbbb967ae9c7e0e8134365c9ee676" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.0.12" @@ -3805,6 +3959,20 @@ dependencies = [ "miniz_oxide 0.6.2", ] +[[package]] +name = "polling" +version = "2.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22122d5ec4f9fe1b3916419b76be1e80bcb93f618d071d2edf841b137b2a2bd6" +dependencies = [ + "autocfg", + "cfg-if", + "libc", + "log", + "wepoll-ffi", + "windows-sys 0.42.0", +] + [[package]] name = "poly1305" version = "0.7.2" @@ -3844,6 +4012,12 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "precomputed-hash" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" + [[package]] name = "prettyplease" version = "0.1.23" @@ -3856,13 +4030,12 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "1.2.1" +version = "1.3.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "eda0fc3b0fb7c975631757e14d9049da17374063edb6ebbcbc54d880d4fe94e9" +checksum = "66618389e4ec1c7afe67d51a9bf34ff9236480f8d51e7489b7d5ab0303c13f34" dependencies = [ "once_cell", - "thiserror", - "toml", + "toml_edit 0.18.1", ] [[package]] @@ -3891,18 +4064,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.49" +version = "1.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5" +checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" dependencies = [ "unicode-ident", ] [[package]] name = "proptest" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e0d9cc07f18492d879586c92b485def06bc850da3118075cd45d50e9c95b0e5" +checksum = "29f1b898011ce9595050a68e60f90bad083ff2987a695a42357134c8381fba70" dependencies = [ "bitflags", "byteorder", @@ -3913,13 +4086,14 @@ dependencies = [ "rand_chacha 0.3.1", "rand_xorshift", "regex-syntax", + "unarray", ] [[package]] name = "prost" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c01db6702aa05baa3f57dec92b8eeeeb4cb19e894e73996b32a4093289e54592" +checksum = "21dc42e00223fc37204bd4aa177e69420c604ca4a183209a8f9de30c6d934698" dependencies = [ "bytes", "prost-derive", @@ -3927,13 +4101,13 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb5320c680de74ba083512704acb90fe00f28f79207286a848e730c45dd73ed6" +checksum = "a3f8ad728fb08fe212df3c05169e940fbb6d9d16a877ddde14644a983ba2012e" dependencies = [ "bytes", "heck", - "itertools", + "itertools 0.10.5", "lazy_static", "log", "multimap", @@ -3949,9 +4123,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.11.5" +version = 
"0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8842bad1a5419bca14eac663ba798f6bc19c413c2fdceb5f3ba3b0932d96720" +checksum = "8bda8c0881ea9f722eb9629376db3d0b903b462477c1aafcb0566610ac28ac5d" dependencies = [ "anyhow", "itertools 0.10.5", @@ -3962,9 +4136,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "017f79637768cde62820bc2d4fe0e45daaa027755c323ad077767c6c5f173091" +checksum = "a5e0526209433e96d83d750dd81a99118edbc55739e7e61a46764fd2ad537788" dependencies = [ "bytes", "prost", @@ -4015,16 +4189,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "radix_trie" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" -dependencies = [ - "endian-type", - "nibble_vec", -] - [[package]] name = "rand" version = "0.7.3" @@ -4105,6 +4269,15 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_xoshiro" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" +dependencies = [ + "rand_core 0.6.4", +] + [[package]] name = "rayon" version = "1.6.1" @@ -4117,9 +4290,9 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.10.1" +version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cac410af5d00ab6884528b4ab69d1e8e146e8d471201800fa1b4524126de6ad3" +checksum = "356a0625f1954f730c0201cdab48611198dc6ce21f4acff55089b5a78e6e835b" dependencies = [ "crossbeam-channel", "crossbeam-deque", @@ -4127,6 +4300,12 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "readlock" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35c8a22130504d1f661d1bc373b424f2d45910fa5319132d903a4074e1527b2e" + 
[[package]] name = "redox_syscall" version = "0.2.16" @@ -4184,12 +4363,12 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.11.13" +version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68cc60575865c7831548863cc02356512e3f1dc2f3f82cb837d7fc4cc8f3c97c" +checksum = "21eed90ec8570952d53b772ecf8f206aa1ec9a3d76b2521c56c42973f2d91ee9" dependencies = [ "async-compression", - "base64 0.13.1", + "base64 0.21.0", "bytes", "encoding_rs", "futures-core", @@ -4235,9 +4414,9 @@ checksum = "4389f1d5789befaf6029ebd9f7dac4af7f7e3d61b69d4f30e2ac02b57e7712b0" [[package]] name = "rgb" -version = "0.8.34" +version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3603b7d71ca82644f79b5a06d1220e9a58ede60bd32255f698cb1af8838b8db3" +checksum = "7495acf66551cdb696b7711408144bcd3194fc78e32f3a09e809bfe7dd4a7ce3" dependencies = [ "bytemuck", ] @@ -4281,8 +4460,8 @@ dependencies = [ [[package]] name = "ruma" -version = "0.7.4" -source = "git+https://github.com/ruma/ruma?rev=00045e559f864eabff08295d603f7b3238288b6f#00045e559f864eabff08295d603f7b3238288b6f" +version = "0.8.2" +source = "git+https://github.com/ruma/ruma?rev=8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5#8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5" dependencies = [ "assign", "js_int", @@ -4291,12 +4470,13 @@ dependencies = [ "ruma-client-api", "ruma-common", "ruma-federation-api", + "ruma-push-gateway-api", ] [[package]] name = "ruma-appservice-api" -version = "0.7.0" -source = "git+https://github.com/ruma/ruma?rev=00045e559f864eabff08295d603f7b3238288b6f#00045e559f864eabff08295d603f7b3238288b6f" +version = "0.8.1" +source = "git+https://github.com/ruma/ruma?rev=8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5#8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5" dependencies = [ "js_int", "ruma-common", @@ -4306,8 +4486,8 @@ dependencies = [ [[package]] name = "ruma-client-api" -version = "0.15.3" -source = 
"git+https://github.com/ruma/ruma?rev=00045e559f864eabff08295d603f7b3238288b6f#00045e559f864eabff08295d603f7b3238288b6f" +version = "0.16.2" +source = "git+https://github.com/ruma/ruma?rev=8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5#8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5" dependencies = [ "assign", "bytes", @@ -4323,13 +4503,14 @@ dependencies = [ [[package]] name = "ruma-common" -version = "0.10.5" -source = "git+https://github.com/ruma/ruma?rev=00045e559f864eabff08295d603f7b3238288b6f#00045e559f864eabff08295d603f7b3238288b6f" +version = "0.11.3" +source = "git+https://github.com/ruma/ruma?rev=8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5#8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5" dependencies = [ - "base64 0.20.0", + "base64 0.21.0", "bytes", "form_urlencoded", "getrandom 0.2.8", + "html5ever", "http", "indexmap", "js-sys", @@ -4337,6 +4518,7 @@ dependencies = [ "js_option", "konst", "percent-encoding", + "phf 0.11.1", "pulldown-cmark", "rand 0.8.5", "regex", @@ -4354,8 +4536,8 @@ dependencies = [ [[package]] name = "ruma-federation-api" -version = "0.6.0" -source = "git+https://github.com/ruma/ruma?rev=00045e559f864eabff08295d603f7b3238288b6f#00045e559f864eabff08295d603f7b3238288b6f" +version = "0.7.1" +source = "git+https://github.com/ruma/ruma?rev=8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5#8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5" dependencies = [ "js_int", "ruma-common", @@ -4365,8 +4547,8 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" -version = "0.9.0" -source = "git+https://github.com/ruma/ruma?rev=00045e559f864eabff08295d603f7b3238288b6f#00045e559f864eabff08295d603f7b3238288b6f" +version = "0.9.1" +source = "git+https://github.com/ruma/ruma?rev=8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5#8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5" dependencies = [ "js_int", "thiserror", @@ -4374,8 +4556,8 @@ dependencies = [ [[package]] name = "ruma-macros" -version = "0.10.5" -source = 
"git+https://github.com/ruma/ruma?rev=00045e559f864eabff08295d603f7b3238288b6f#00045e559f864eabff08295d603f7b3238288b6f" +version = "0.11.3" +source = "git+https://github.com/ruma/ruma?rev=8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5#8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5" dependencies = [ "once_cell", "proc-macro-crate", @@ -4384,7 +4566,18 @@ dependencies = [ "ruma-identifiers-validation", "serde", "syn", - "toml", + "toml 0.7.2", +] + +[[package]] +name = "ruma-push-gateway-api" +version = "0.7.1" +source = "git+https://github.com/ruma/ruma?rev=8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5#8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5" +dependencies = [ + "js_int", + "ruma-common", + "serde", + "serde_json", ] [[package]] @@ -4409,23 +4602,23 @@ checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" [[package]] name = "rustix" -version = "0.36.6" +version = "0.36.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4feacf7db682c6c329c4ede12649cd36ecab0f3be5b7d74e6a20304725db4549" +checksum = "f43abb88211988493c1abb44a70efa56ff0ce98f233b7b276146f1f3f7ba9644" dependencies = [ "bitflags", "errno", "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.45.0", ] [[package]] name = "rustls" -version = "0.20.7" +version = "0.20.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c" +checksum = "fff78fc74d175294f4e83b28343315ffcfb114b156f0185e9741cb5570f50e2f" dependencies = [ "log", "ring", @@ -4435,11 +4628,11 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0864aeff53f8c05aa08d86e5ef839d3dfcf07aeba2db32f12db0ef716e87bd55" +checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64 0.13.1", + "base64 0.21.0", ] [[package]] @@ -4448,40 +4641,6 @@ version 
= "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5583e89e108996506031660fe09baa5011b9dd0341b89029313006d1fb508d70" -[[package]] -name = "rustyline" -version = "10.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d1cd5ae51d3f7bf65d7969d579d502168ef578f289452bd8ccc91de28fda20e" -dependencies = [ - "bitflags", - "cfg-if", - "clipboard-win", - "dirs-next", - "fd-lock", - "libc", - "log", - "memchr", - "nix", - "radix_trie", - "scopeguard", - "unicode-segmentation", - "unicode-width", - "utf8parse", - "winapi", -] - -[[package]] -name = "rustyline-derive" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "107c3d5d7f370ac09efa62a78375f94d94b8a33c61d8c278b96683fb4dbf2d8d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "ryu" version = "1.0.12" @@ -4512,7 +4671,7 @@ version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" dependencies = [ - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -4571,9 +4730,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.7.0" +version = "2.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2bc1bb97804af6631813c55739f771071e0f2ed33ee20b68c86ec505d906356c" +checksum = "a332be01508d814fed64bf28f798a146d73792121129962fdf335bb3c49a4254" dependencies = [ "bitflags", "core-foundation", @@ -4584,9 +4743,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.6.1" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0160a13a177a45bfb43ce71c01580998474f556ad854dcbca936dd2841a5c556" +checksum = "31c9bb296072e961fcbd8853511dd39c2d8be2deb1e17c6860b1d30732b323b4" dependencies = [ "core-foundation-sys", "libc", @@ -4612,9 +4771,9 @@ dependencies = [ [[package]] name = 
"serde_bytes" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "718dc5fff5b36f99093fc49b280cfc96ce6fc824317783bff5a1fed0c7a64819" +checksum = "416bda436f9aab92e02c8e10d49a15ddd339cea90b6e340fe51ed97abb548294" dependencies = [ "serde", ] @@ -4645,9 +4804,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.91" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c235533714907a8c2464236f5c4b2a17262ef1bd71f38f35ea592c8da6883" +checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76" dependencies = [ "itoa", "ryu", @@ -4674,6 +4833,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "serde_spanned" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0efd8caf556a6cebd3b285caf480045fcc1ac04f6bd786b09a6f11af30c4fcf4" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -4688,9 +4856,9 @@ dependencies = [ [[package]] name = "serde_yaml" -version = "0.9.16" +version = "0.9.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92b5b431e8907b50339b51223b97d102db8d987ced36f6e4d03621db9316c834" +checksum = "8fb06d4b6cdaef0e0c51fa881acb721bed3c924cfaa71d9c94a3b771dfdf6567" dependencies = [ "indexmap", "itoa", @@ -4732,11 +4900,17 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" + [[package]] name = "signal-hook" -version = "0.3.14" +version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a253b5e89e2698464fc26b545c9edceb338e18a89effeeecfea192c3025be29d" +checksum = "732768f1176d21d09e076c23a93123d40bba92d50c4058da34d45c8de8e682b9" dependencies = [ "libc", "signal-hook-registry", @@ 
-4755,9 +4929,9 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" dependencies = [ "libc", ] @@ -4774,6 +4948,16 @@ version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" +[[package]] +name = "sized-chunks" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d69225bde7a69b235da73377861095455d298f2b970996eec25ddbb42b3d1e" +dependencies = [ + "bitmaps", + "typenum", +] + [[package]] name = "slab" version = "0.4.7" @@ -4799,29 +4983,14 @@ dependencies = [ "parking_lot 0.11.2", ] -[[package]] -name = "sled-state-inspector" -version = "0.1.0" -dependencies = [ - "atty", - "clap 3.2.23", - "futures", - "matrix-sdk-base", - "matrix-sdk-sled", - "ruma", - "rustyline", - "rustyline-derive", - "serde", - "serde_json", - "syntect", -] - [[package]] name = "sliding-sync-integration-test" version = "0.1.0" dependencies = [ "anyhow", - "ctor", + "assert_matches", + "eyeball", + "eyeball-im", "futures", "matrix-sdk", "matrix-sdk-integration-testing", @@ -4884,18 +5053,38 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" -[[package]] -name = "str-buf" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e08d8363704e6c71fc928674353e6b7c23dcea9d82d7012c8faf2a3a025f8d0" - [[package]] name = "str_stack" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb" +[[package]] +name = "string_cache" +version = 
"0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d69e88b23f23030bf4d0e9ca7b07434f70e1c1f4d3ca7e93ce958b373654d9f" +dependencies = [ + "new_debug_unreachable", + "once_cell", + "parking_lot 0.12.1", + "phf_shared 0.10.0", + "precomputed-hash", + "serde", +] + +[[package]] +name = "string_cache_codegen" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988" +dependencies = [ + "phf_generator 0.10.0", + "phf_shared 0.10.0", + "proc-macro2", + "quote", +] + [[package]] name = "strsim" version = "0.10.0" @@ -4944,9 +5133,9 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "synstructure" @@ -4960,27 +5149,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "syntect" -version = "5.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6c454c27d9d7d9a84c7803aaa3c50cd088d2906fe3c6e42da3209aa623576a8" -dependencies = [ - "bincode", - "bitflags", - "fancy-regex", - "flate2", - "fnv", - "lazy_static", - "once_cell", - "regex-syntax", - "serde", - "serde_derive", - "serde_json", - "thiserror", - "walkdir", -] - [[package]] name = "tempfile" version = "3.3.0" @@ -4995,11 +5163,22 @@ dependencies = [ "winapi", ] +[[package]] +name = "tendril" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d24a120c5fc464a3458240ee02c299ebcb9d67b5249c8848b09d639dca8d7bb0" +dependencies = [ + "futf", + "mac", + "utf-8", +] + [[package]] name = "termcolor" -version = "1.1.3" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"bab24d30b911b2376f3a13cc2cd443142f0c81dda04c118693e35b3835757755" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" dependencies = [ "winapi-util", ] @@ -5054,10 +5233,11 @@ dependencies = [ [[package]] name = "thread_local" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180" +checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" dependencies = [ + "cfg-if", "once_cell", ] @@ -5096,9 +5276,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.17" +version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a561bf4617eebd33bca6434b988f39ed798e527f51a1e797d0ee4f61c0a38376" +checksum = "53250a3b3fed8ff8fd988587d8925d26a83ac3845d9e03b220b37f34c2b8d6c2" dependencies = [ "itoa", "serde", @@ -5114,9 +5294,9 @@ checksum = "2e153e1f1acaef8acc537e68b44906d2db6436e2b35ac2c6b42640fff91f00fd" [[package]] name = "time-macros" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d967f99f534ca7e495c575c62638eebc2898a8c84c119b89e250477bc4ba16b2" +checksum = "a460aeb8de6dcb0f381e1ee05f1cd56fcf5a5f6eb8187ff3d8f0b11078d38b7c" dependencies = [ "time-core", ] @@ -5142,15 +5322,15 @@ dependencies = [ [[package]] name = "tinyvec_macros" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.24.1" +version = "1.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d9f76183f91ecfb55e1d7d5602bd1d979e38a3a522fe900241cf195624d67ae" +checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" dependencies = [ 
"autocfg", "bytes", @@ -5161,7 +5341,7 @@ dependencies = [ "pin-project-lite", "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.42.0", ] [[package]] @@ -5187,9 +5367,9 @@ dependencies = [ [[package]] name = "tokio-native-tls" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d995660bd2b7f8c1568414c1126076c13fbb725c40112dc0120b78eb9b717b" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" dependencies = [ "native-tls", "tokio", @@ -5227,13 +5407,14 @@ dependencies = [ "futures-core", "pin-project-lite", "tokio", + "tokio-util", ] [[package]] name = "tokio-util" -version = "0.7.4" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bb2e075f03b3d66d8d8785356224ba688d2906a371015e225beeb65ca92c740" +checksum = "5427d89453009325de0d8f342c9490009f76e999cb7672d77e46267448f7e6b2" dependencies = [ "bytes", "futures-core", @@ -5245,11 +5426,62 @@ dependencies = [ [[package]] name = "toml" -version = "0.5.10" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" +dependencies = [ + "serde", +] + +[[package]] +name = "toml" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7afcae9e3f0fe2c370fd4657108972cbb2fa9db1b9f84849cefd80741b01cb6" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime 0.6.1", + "toml_edit 0.19.3", +] + +[[package]] +name = "toml_datetime" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4553f467ac8e3d374bc9a177a26801e5d0f9b211aa1673fb137a403afd1c9cf5" + +[[package]] +name = "toml_datetime" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ab8ed2edee10b50132aed5f331333428b011c99402b5a534154ed15746f9622" +dependencies = [ + "serde", +] + 
+[[package]] +name = "toml_edit" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56c59d8dd7d0dcbc6428bf7aa2f0e823e26e43b3c9aca15bbc9475d23e5fa12b" +dependencies = [ + "indexmap", + "nom8", + "toml_datetime 0.5.1", +] + +[[package]] +name = "toml_edit" +version = "0.19.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1333c76748e868a4d9d1017b5ab53171dfd095f70c712fdb4653a406547f598f" +checksum = "5e6a7712b49e1775fb9a7b998de6635b299237f48b404dde71704f2e0e7f37e5" dependencies = [ + "indexmap", + "nom8", "serde", + "serde_spanned", + "toml_datetime 0.6.1", ] [[package]] @@ -5361,6 +5593,17 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-android" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12612be8f868a09c0ceae7113ff26afe79d81a24473a393cb9120ece162e86c0" +dependencies = [ + "android_log-sys", + "tracing", + "tracing-subscriber", +] + [[package]] name = "tracing-attributes" version = "0.1.23" @@ -5441,7 +5684,7 @@ dependencies = [ "sharded-slab", "smallvec", "thread_local", - "time 0.3.17", + "time 0.3.19", "tracing", "tracing-core", "tracing-log", @@ -5523,6 +5766,12 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicase" version = "2.6.0" @@ -5534,9 +5783,9 @@ dependencies = [ [[package]] name = "unicode-bidi" -version = "0.3.8" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992" +checksum = "d54675592c1dbefd78cbd98db9bacd89886e1ca50692a0692baefffdeb92dd58" [[package]] name = 
"unicode-ident" @@ -5565,9 +5814,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fdbf052a0783de01e944a6ce7a8cb939e295b1e7be835a1112c3b9a7f047a5a" +checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" [[package]] name = "unicode-width" @@ -5584,8 +5833,7 @@ checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" [[package]] name = "uniffi" version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f71cc01459bc34cfe43fabf32b39f1228709bc6db1b3a664a92940af3d062376" +source = "git+https://github.com/mozilla/uniffi-rs?rev=58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4#58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" dependencies = [ "anyhow", "camino", @@ -5606,8 +5854,7 @@ dependencies = [ [[package]] name = "uniffi_bindgen" version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbbba5103051c18f10b22f80a74439ddf7100273f217a547005d2735b2498994" +source = "git+https://github.com/mozilla/uniffi-rs?rev=58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4#58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" dependencies = [ "anyhow", "askama", @@ -5621,7 +5868,7 @@ dependencies = [ "paste", "serde", "serde_json", - "toml", + "toml 0.5.11", "uniffi_meta", "uniffi_testing", "weedle2", @@ -5630,8 +5877,7 @@ dependencies = [ [[package]] name = "uniffi_build" version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ee1a28368ff3d83717e3d3e2e15a66269c43488c3f036914131bb68892f29fb" +source = "git+https://github.com/mozilla/uniffi-rs?rev=58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4#58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" dependencies = [ "anyhow", "camino", @@ -5641,8 +5887,7 @@ dependencies = [ [[package]] name = "uniffi_checksum_derive" version = "0.23.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "03de61393a42b4ad4984a3763c0600594ac3e57e5aaa1d05cede933958987c03" +source = "git+https://github.com/mozilla/uniffi-rs?rev=58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4#58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" dependencies = [ "quote", "syn", @@ -5651,8 +5896,7 @@ dependencies = [ [[package]] name = "uniffi_core" version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a2b4852d638d74ca2d70e450475efb6d91fe6d54a7cd8d6bd80ad2ee6cd7daa" +source = "git+https://github.com/mozilla/uniffi-rs?rev=58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4#58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" dependencies = [ "anyhow", "bytes", @@ -5667,8 +5911,7 @@ dependencies = [ [[package]] name = "uniffi_macros" version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa03394de21e759e0022f1ea8d992d2e39290d735b9ed52b1f74b20a684f794e" +source = "git+https://github.com/mozilla/uniffi-rs?rev=58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4#58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" dependencies = [ "bincode", "camino", @@ -5678,7 +5921,7 @@ dependencies = [ "quote", "serde", "syn", - "toml", + "toml 0.5.11", "uniffi_build", "uniffi_meta", ] @@ -5686,8 +5929,7 @@ dependencies = [ [[package]] name = "uniffi_meta" version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66fdab2c436aed7a6391bec64204ec33948bfed9b11b303235740771f85c4ea6" +source = "git+https://github.com/mozilla/uniffi-rs?rev=58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4#58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" dependencies = [ "serde", "siphasher", @@ -5697,8 +5939,7 @@ dependencies = [ [[package]] name = "uniffi_testing" version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92b0570953ec41d97ce23e3b92161ac18231670a1f97523258a6d2ab76d7f76c" +source = 
"git+https://github.com/mozilla/uniffi-rs?rev=58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4#58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" dependencies = [ "anyhow", "camino", @@ -5744,16 +5985,16 @@ dependencies = [ ] [[package]] -name = "utf8parse" -version = "0.2.0" +name = "utf-8" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "936e4b492acfd135421d8dca4b1aa80a7bfc26e702ef3af710e0752684df5372" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" [[package]] name = "uuid" -version = "1.2.2" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "422ee0de9031b5b948b97a8fc04e3aa35230001a722ddd27943e0be31564ce4c" +checksum = "1674845326ee10d37ca60470760d4288a6f80f304007d92e5c53bab78c9cfd79" dependencies = [ "getrandom 0.2.8", "wasm-bindgen", @@ -5765,6 +6006,16 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "value-bag" +version = "1.0.0-alpha.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2209b78d1249f7e6f3293657c9779fe31ced465df091bbd433a1cf88e916ec55" +dependencies = [ + "ctor", + "version_check", +] + [[package]] name = "vcpkg" version = "0.2.15" @@ -5780,7 +6031,7 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "vodozemac" version = "0.3.0" -source = "git+https://github.com/matrix-org/vodozemac?rev=12b24e909107c1fac23245376f294eaf48ba186a#12b24e909107c1fac23245376f294eaf48ba186a" +source = "git+https://github.com/matrix-org/vodozemac?rev=fb609ca1e4df5a7a818490ae86ac694119e41e71#fb609ca1e4df5a7a818490ae86ac694119e41e71" dependencies = [ "aes", "arrayvec", @@ -5849,9 +6100,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.83" +version = "0.2.84" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268" +checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -5859,9 +6110,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c8ffb332579b0557b52d268b91feab8df3615f265d5270fec2a8c95b17c1142" +checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9" dependencies = [ "bumpalo", "log", @@ -5874,9 +6125,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.33" +version = "0.4.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23639446165ca5a5de86ae1d8896b737ae80319560fbaa4c2887b7da6e7ebd7d" +checksum = "f219e0d211ba40266969f6dbdd90636da12f75bee4fc9d6c23d1260dadb51454" dependencies = [ "cfg-if", "js-sys", @@ -5886,9 +6137,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "052be0f94026e6cbc75cdefc9bae13fd6052cdcaf532fa6c45e7ae33a1e6c810" +checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5896,9 +6147,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" +checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6" dependencies = [ "proc-macro2", "quote", @@ -5909,15 +6160,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.83" +version = "0.2.84" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" +checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d" [[package]] name = "wasm-bindgen-test" -version = "0.3.33" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d2fff962180c3fadf677438054b1db62bee4aa32af26a45388af07d1287e1d" +checksum = "6db36fc0f9fb209e88fb3642590ae0205bb5a56216dabd963ba15879fe53a30b" dependencies = [ "console_error_panic_hook", "js-sys", @@ -5929,9 +6180,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.33" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4683da3dfc016f704c9f82cf401520c4f1cb3ee440f7f52b3d6ac29506a49ca7" +checksum = "0734759ae6b3b1717d661fe4f016efcfb9828f5edb4520c18eaee05af3b43be9" dependencies = [ "proc-macro2", "quote", @@ -5950,26 +6201,11 @@ dependencies = [ "web-sys", ] -[[package]] -name = "wasm-timer" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be0ecb0db480561e9a7642b5d3e4187c128914e58aa84330b9493e3eb68c5e7f" -dependencies = [ - "futures", - "js-sys", - "parking_lot 0.11.2", - "pin-utils", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "web-sys" -version = "0.3.60" +version = "0.3.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcda906d8be16e728fd5adc5b729afad4e444e106ab28cd1c7256e54fa61510f" +checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97" dependencies = [ "js-sys", "wasm-bindgen", @@ -5997,8 +6233,7 @@ dependencies = [ [[package]] name = "weedle2" version = "4.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e79c5206e1f43a2306fd64bdb95025ee4228960f2e6c5a8b173f3caaf807741" +source = 
"git+https://github.com/mozilla/uniffi-rs?rev=58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4#58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" dependencies = [ "nom", ] @@ -6009,6 +6244,15 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" +[[package]] +name = "wepoll-ffi" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d743fdedc5c64377b5fc2bc036b01c7fd642205a0d96356034ae3404d49eb7fb" +dependencies = [ + "cc", +] + [[package]] name = "which" version = "4.4.0" @@ -6072,47 +6316,71 @@ dependencies = [ "windows_x86_64_msvc", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.42.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + [[package]] name = "windows_aarch64_gnullvm" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" +checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608" [[package]] name = "windows_aarch64_msvc" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" +checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7" [[package]] name = "windows_i686_gnu" -version = "0.42.0" 
+version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" +checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640" [[package]] name = "windows_i686_msvc" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" +checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605" [[package]] name = "windows_x86_64_gnu" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" +checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45" [[package]] name = "windows_x86_64_gnullvm" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" +checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463" [[package]] name = "windows_x86_64_msvc" -version = "0.42.0" +version = "0.42.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" +checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd" [[package]] name = "winreg" @@ -6125,9 +6393,9 @@ dependencies = [ [[package]] name = "wiremock" -version = "0.5.16" +version = "0.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "631cafe37a030d8453218cf7c650abcc359be8fba4a2fbc5c27fdb9728635406" +checksum = "12316b50eb725e22b2f6b9c4cbede5b7b89984274d113a7440c86e5c3fc6f99b" dependencies = [ "assert-json-diff", "async-trait", @@ -6186,7 +6454,7 @@ name = "xtask" version = "0.1.0" dependencies = [ "camino", - "clap 4.0.32", + "clap 4.1.8", "fs_extra", 
"serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 8e5fad15ca5..f64995365b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,22 +23,25 @@ rust-version = "1.65" anyhow = "1.0.68" async-stream = "0.3.3" async-trait = "0.1.60" -base64 = "0.20.0" +base64 = "0.21.0" byteorder = "1.4.3" ctor = "0.1.26" dashmap = "5.2.0" +eyeball = "0.4.0" +eyeball-im = "0.1.0" +futures-util = { version = "0.3.26", default-features = false, features = ["alloc"] } http = "0.2.6" -ruma = { git = "https://github.com/ruma/ruma", rev = "00045e559f864eabff08295d603f7b3238288b6f", features = ["client-api-c"] } -ruma-common = { git = "https://github.com/ruma/ruma", rev = "00045e559f864eabff08295d603f7b3238288b6f" } +ruma = { git = "https://github.com/ruma/ruma", rev = "8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5", features = ["client-api-c"] } +ruma-common = { git = "https://github.com/ruma/ruma", rev = "8eea3e05490fa9a318f9ed66c3a75272e6ef0ee5" } once_cell = "1.16.0" serde = "1.0.151" serde_html_form = "0.2.0" serde_json = "1.0.91" thiserror = "1.0.38" tracing = { version = "0.1.36", default-features = false, features = ["std"] } -uniffi = "0.23.0" -uniffi_bindgen = "0.23.0" -vodozemac = { git = "https://github.com/matrix-org/vodozemac", rev = "12b24e909107c1fac23245376f294eaf48ba186a" } +uniffi = { git = "https://github.com/mozilla/uniffi-rs", rev = "58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" } +uniffi_bindgen = { git = "https://github.com/mozilla/uniffi-rs", rev = "58758341b72e9e8ff51ecd57a3eb22d6cc41a4b4" } +vodozemac = { git = "https://github.com/matrix-org/vodozemac", rev = "fb609ca1e4df5a7a818490ae86ac694119e41e71" } zeroize = "1.3.0" # Default release profile, select with `--release` diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 976ac8bd480..5465b4299b0 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -15,7 +15,7 @@ matrix-sdk-test = { path = "../testing/matrix-sdk-test", version = "0.6.0"} ruma = { workspace = true } serde_json = { workspace 
= true } tempfile = "3.3.0" -tokio = { version = "1.23.1", default-features = false, features = ["rt-multi-thread"] } +tokio = { version = "1.24.2", default-features = false, features = ["rt-multi-thread"] } [target.'cfg(target_os = "linux")'.dependencies] pprof = { version = "0.11.0", features = ["flamegraph", "criterion"] } diff --git a/bindings/CONTRIBUTING.md b/bindings/CONTRIBUTING.md new file mode 100644 index 00000000000..6f5aae26eaf --- /dev/null +++ b/bindings/CONTRIBUTING.md @@ -0,0 +1,41 @@ +## Introduction +**matrix-rust-sdk** leverages [UniFFI](https://mozilla.github.io/uniffi-rs/) to generate bindings for host languages (eg. Swift and Kotlin). + +Rust code related with bindings live in the [matrix-rust-sdk/bindings](https://github.com/matrix-org/matrix-rust-sdk/tree/main/bindings) folder. + +Developers can expose Rust code to UniFFI using two different approaches: +- Using an `.udl` file. When a crate has one, you find it under the `src` folder (an example is [here](https://github.com/matrix-org/matrix-rust-sdk/blob/main/bindings/matrix-sdk-ffi/src/api.udl)). +- Add UniFFI directivies as Rust attributes. In this case Rust source files (`.rs`) contain attributes related to UniFFI (e.g. `#[uniffi::export]`). Attributes are preferred, where applicable. + + +## Expose Rust definitions to UniFFI + +### Check if the API is already on UniFFI + +First of all check if the Rust definition you are looking for exists on UniFFI already. Most of exposed matrix definitions are collected in the crate [matrix-sdk-ffi](https://github.com/matrix-org/matrix-rust-sdk/tree/main/bindings/matrix-sdk-ffi). +This crate contains mainly small Rust wrappers around the actual Rust SDK (e.g. 
the crate [matrix-sdk](https://github.com/matrix-org/matrix-rust-sdk/tree/main/crates/matrix-sdk)) + +If the Rust definition is on UniFFI already, you either: +- find it in a `.udl` file like [matrix-sdk-ffi/src/api.udl](https://github.com/matrix-org/matrix-rust-sdk/blob/main/bindings/matrix-sdk-ffi/src/api.udl) +- see it marked with a proper UniFFI Rust attribute like this `#[uniffi::export]` + + +### Adding a missing matrix API + +1. Unless you want to contribute on the crypto side, you probably need to add some code in the [matrix-sdk-ffi](https://github.com/matrix-org/matrix-rust-sdk/tree/main/bindings/matrix-sdk-ffi) crate. After you find the crate you need to understand which file is best to contain the new Rust definition. When exposing new matrix API often (but not always) you need to touch the file [client.rs](https://github.com/matrix-org/matrix-rust-sdk/blob/main/bindings/matrix-sdk-ffi/src/client.rs) + +2. Identify the API to expose in the target Rust crate (typically in [matrix-sdk](https://github.com/matrix-org/matrix-rust-sdk/tree/main/crates/matrix-sdk). If you can’t find it, you probably need to touch the actual Rust SDK as well. In this case you typically just need to write some code around [ruma](https://github.com/ruma/ruma) (a Rust SDK’s dependency) which already implements most of the matrix protocol + +3. After you got (by finding or writing) the required Rust code, you need to expose to UniFFI. To do that just write a small Rust wrapper in the related UniFFI crate (most of the time is **matrix-sdk-ffi**) you found on step 1. + +4. When your new (wrapping) Rust definition is ready, remember to expose it to UniFFI. +It’s best to do it using UniFFI Rust attributes (e.g. `#[uniffi::export]`). Otherwise add the new definition in the crate’s `.udl` file. For the **matrix-sdk-ffi** crate the definition file is [api.udl](https://github.com/matrix-org/matrix-rust-sdk/blob/main/bindings/matrix-sdk-ffi/src/api.udl). 
**Remember**: the language inside a `.udl` file isn’t Rust. To learn more about how to map Rust into UDL read [here](https://mozilla.github.io/uniffi-rs/udl_file_spec.html) + +## FAQ + +**Q**: I wrote my Rust code and exposed it to UniFFI. How can I check if the compiler is happy?\ +**A**: Run `cargo build` in the crate you touched (e.g. matrix-sdk-ffi). The compiler will complain if the Rust code and/or the `.udl` is wrong. + + +**Q**: The compiler is happy with my code but the CI is failing on GitHub. How can I fix it?\ +**A**: The CI may fail for different reasons; you need to have a look at the failing GitHub workflow. One common reason though is that the linter ([Clippy](https://github.com/rust-lang/rust-clippy)) isn’t happy with your code. If this is the case, you can run `cargo clippy` in the crate you touched to see what’s wrong and fix it accordingly. \ No newline at end of file diff --git a/bindings/README.md b/bindings/README.md index b4c07977e06..a6bfd8975e0 100644 --- a/bindings/README.md +++ b/bindings/README.md @@ -20,3 +20,6 @@ maintained by the owners of the Matrix Rust SDK project. [`matrix-sdk-crypto`]: ../crates/matrix-sdk-crypto [`matrix-sdk-ffi`]: ./matrix-sdk-ffi [`matrix-sdk`]: ../crates/matrix-sdk + +# Contributing +To contribute read this [guide](./CONTRIBUTING.md). 
\ No newline at end of file diff --git a/bindings/apple/Package.swift b/bindings/apple/Package.swift index 05f4003ab3c..fb04144f081 100644 --- a/bindings/apple/Package.swift +++ b/bindings/apple/Package.swift @@ -5,6 +5,10 @@ import PackageDescription let package = Package( name: "MatrixRustSDK", + platforms: [ + .iOS(.v15), + .macOS(.v12) + ], products: [ .library(name: "MatrixRustSDK", targets: ["MatrixRustSDK"]), diff --git a/bindings/apple/Tests/MatrixRustSDKTests/AuthenticationServiceTests.swift b/bindings/apple/Tests/MatrixRustSDKTests/AuthenticationServiceTests.swift new file mode 100644 index 00000000000..df92b6a32a5 --- /dev/null +++ b/bindings/apple/Tests/MatrixRustSDKTests/AuthenticationServiceTests.swift @@ -0,0 +1,64 @@ +@testable import MatrixRustSDK +import XCTest + +class AuthenticationServiceTests: XCTestCase { + var service: AuthenticationService! + + override func setUp() { + service = AuthenticationService(basePath: FileManager.default.temporaryDirectory.path, + passphrase: nil, + customSlidingSyncProxy: nil) + } + + func testValidServers() { + XCTAssertNoThrow(try service.configureHomeserver(serverNameOrHomeserverUrl: "matrix.org")) + XCTAssertNoThrow(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://matrix.org")) + XCTAssertNoThrow(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://matrix.org/")) + } + + func testInvalidCharacters() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "hello!@$Β£%^world"), + "A server name with invalid characters should not succeed to build.") { error in + guard case AuthenticationError.InvalidServerName = error else { XCTFail("Expected invalid name error."); return } + } + } + + func textNonExistentDomain() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "somesillylinkthatdoesntexist.com"), + "A server name that doesn't exist should not succeed.") { error in + guard case AuthenticationError.Generic = error 
else { XCTFail("Expected generic error."); return } + } + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://somesillylinkthatdoesntexist.com"), + "A server URL that doesn't exist should not succeed.") { error in + guard case AuthenticationError.Generic = error else { XCTFail("Expected generic error."); return } + } + } + + func testValidDomainWithoutServer() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://google.com"), + "Google should not succeed as it doesn't host a homeserver.") { error in + guard case AuthenticationError.Generic = error else { XCTFail("Expected generic error."); return } + } + } + + func testServerWithoutSlidingSync() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "envs.net"), + "Envs should not succeed as it doesn't advertise a sliding sync proxy.") { error in + guard case AuthenticationError.SlidingSyncNotAvailable = error else { XCTFail("Expected sliding sync error."); return } + } + } + + func testHomeserverURL() { + XCTAssertThrowsError(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://matrix-client.matrix.org"), + "Directly using a homeserver should not succeed as a sliding sync proxy won't be found.") { error in + guard case AuthenticationError.SlidingSyncNotAvailable = error else { XCTFail("Expected sliding sync error."); return } + } + } + + func testHomeserverURLWithProxyOverride() { + service = AuthenticationService(basePath: FileManager.default.temporaryDirectory.path, + passphrase: nil, customSlidingSyncProxy: "https://slidingsync.proxy") + XCTAssertNoThrow(try service.configureHomeserver(serverNameOrHomeserverUrl: "https://matrix-client.matrix.org"), + "Directly using a homeserver should succeed what a custom sliding sync proxy has been set.") + } +} diff --git a/bindings/kotlin/README.md b/bindings/kotlin/README.md deleted file mode 100644 index ebb206b7eab..00000000000 --- 
a/bindings/kotlin/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# Matrix rust components kotlin - -This project and build scripts demonstrate how to create an aar and how to import it in your android projects. - -## Prerequisites - -* the Rust toolchain -* cargo-ndk < 2.12.0 `cargo install cargo-ndk --version 2.11.0` -* android targets (e.g. `rustup target add \ - aarch64-linux-android \ - armv7-linux-androideabi \ - x86_64-linux-android \ - i686-linux-android`) - -## Building the SDK - -To build the full sdk and get an aar you can call : -`./bindings/kotlin/scripts/build_sdk.sh /matrix-rust_sdk/bindings/kotlin/sample/libs` -where the parameter is the path for the aar to go - -## License - -[Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) diff --git a/bindings/kotlin/SECURITY.md b/bindings/kotlin/SECURITY.md deleted file mode 100644 index 3126b47a07e..00000000000 --- a/bindings/kotlin/SECURITY.md +++ /dev/null @@ -1,5 +0,0 @@ -# Reporting a Vulnerability - -**If you've found a security vulnerability, please report it to security@matrix.org** - -For more information on our security disclosure policy, visit https://www.matrix.org/security-disclosure-policy/ diff --git a/bindings/kotlin/build.gradle b/bindings/kotlin/build.gradle deleted file mode 100644 index 4924a37c7b1..00000000000 --- a/bindings/kotlin/build.gradle +++ /dev/null @@ -1,23 +0,0 @@ -// Top-level build file where you can add configuration options common to all sub-projects/modules. 
- -apply plugin: 'io.github.gradle-nexus.publish-plugin' -apply from: "${rootDir}/scripts/publish-root.gradle" - -buildscript { - repositories { - maven { url "https://plugins.gradle.org/m2/" } - google() - mavenCentral() - } - - dependencies { - classpath BuildPlugins.android - classpath BuildPlugins.kotlin - classpath BuildPlugins.nexusPublish - } -} - - -task clean(type: Delete) { - delete rootProject.buildDir -} \ No newline at end of file diff --git a/bindings/kotlin/buildSrc/build.gradle.kts b/bindings/kotlin/buildSrc/build.gradle.kts deleted file mode 100644 index 8e88a958d7f..00000000000 --- a/bindings/kotlin/buildSrc/build.gradle.kts +++ /dev/null @@ -1,9 +0,0 @@ -import org.gradle.kotlin.dsl.`kotlin-dsl` - -plugins { - `kotlin-dsl` -} - -repositories { - mavenCentral() -} \ No newline at end of file diff --git a/bindings/kotlin/buildSrc/src/main/java/ConfigurationData.kt b/bindings/kotlin/buildSrc/src/main/java/ConfigurationData.kt deleted file mode 100644 index e32455732cb..00000000000 --- a/bindings/kotlin/buildSrc/src/main/java/ConfigurationData.kt +++ /dev/null @@ -1,11 +0,0 @@ -object ConfigurationData { - const val compileSdk = 31 - const val targetSdk = 31 - const val minSdk = 21 - const val majorVersion = 0 - const val minorVersion = 2 - const val patchVersion = 0 - const val versionName = "$majorVersion.$minorVersion.$patchVersion" - const val snapshotVersionName = "$majorVersion.$minorVersion.${patchVersion + 1}-SNAPSHOT" - const val publishGroupId = "org.matrix.rustcomponents" -} \ No newline at end of file diff --git a/bindings/kotlin/buildSrc/src/main/java/Dependencies.kt b/bindings/kotlin/buildSrc/src/main/java/Dependencies.kt deleted file mode 100644 index f5f6e394ec9..00000000000 --- a/bindings/kotlin/buildSrc/src/main/java/Dependencies.kt +++ /dev/null @@ -1,22 +0,0 @@ -internal object Versions { - const val androidGradlePlugin = "7.1.2" - const val kotlin = "1.6.10" - const val jUnit = "4.12" - const val nexusPublishGradlePlugin = 
"1.1.0" - const val jna = "5.10.0" -} - -internal object BuildPlugins { - const val android = "com.android.tools.build:gradle:${Versions.androidGradlePlugin}" - const val kotlin = "org.jetbrains.kotlin:kotlin-gradle-plugin:${Versions.kotlin}" - const val nexusPublish = "io.github.gradle-nexus:publish-plugin:${Versions.nexusPublishGradlePlugin}" -} - -/** - * To define dependencies - */ -internal object Dependencies { - const val kotlin = "org.jetbrains.kotlin:kotlin-stdlib-jdk7:${Versions.kotlin}" - const val junit = "junit:junit:${Versions.jUnit}" - const val jna = "net.java.dev.jna:jna:${Versions.jna}@aar" -} \ No newline at end of file diff --git a/bindings/kotlin/crypto/crypto-android/.gitignore b/bindings/kotlin/crypto/crypto-android/.gitignore deleted file mode 100644 index 42afabfd2ab..00000000000 --- a/bindings/kotlin/crypto/crypto-android/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/build \ No newline at end of file diff --git a/bindings/kotlin/crypto/crypto-android/build.gradle b/bindings/kotlin/crypto/crypto-android/build.gradle deleted file mode 100644 index 543cd2a603b..00000000000 --- a/bindings/kotlin/crypto/crypto-android/build.gradle +++ /dev/null @@ -1,51 +0,0 @@ -plugins { - id 'com.android.library' - id 'org.jetbrains.kotlin.android' -} - -ext { - PUBLISH_GROUP_ID = ConfigurationData.publishGroupId - PUBLISH_ARTIFACT_ID = 'crypto-android' - PUBLISH_VERSION = rootVersionName - PUBLISH_DESCRIPTION = 'Android Bindings to the Matrix Rust Crypto SDK' -} - -apply from: "${rootDir}/scripts/publish-module.gradle" - -android { - - compileSdk ConfigurationData.compileSdk - - defaultConfig { - minSdk ConfigurationData.minSdk - targetSdk ConfigurationData.targetSdk - versionName ConfigurationData.versionName - - testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" - consumerProguardFiles "consumer-rules.pro" - } - - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 
'proguard-rules.pro' - } - } - compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 - } - kotlinOptions { - jvmTarget = '1.8' - } -} - -android.libraryVariants.all { variant -> - def sourceSet = variant.sourceSets.find { it.name == variant.name } - sourceSet.java.srcDir new File(buildDir, "generated/source/${variant.name}") -} - -dependencies { - implementation Dependencies.jna - testImplementation Dependencies.junit -} \ No newline at end of file diff --git a/bindings/kotlin/crypto/crypto-android/consumer-rules.pro b/bindings/kotlin/crypto/crypto-android/consumer-rules.pro deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/bindings/kotlin/crypto/crypto-android/proguard-rules.pro b/bindings/kotlin/crypto/crypto-android/proguard-rules.pro deleted file mode 100644 index 481bb434814..00000000000 --- a/bindings/kotlin/crypto/crypto-android/proguard-rules.pro +++ /dev/null @@ -1,21 +0,0 @@ -# Add project specific ProGuard rules here. -# You can control the set of applied configuration files using the -# proguardFiles setting in build.gradle. -# -# For more details, see -# http://developer.android.com/guide/developing/tools/proguard.html - -# If your project uses WebView with JS, uncomment the following -# and specify the fully qualified class name to the JavaScript interface -# class: -#-keepclassmembers class fqcn.of.javascript.interface.for.webview { -# public *; -#} - -# Uncomment this to preserve the line number information for -# debugging stack traces. -#-keepattributes SourceFile,LineNumberTable - -# If you keep the line number information, uncomment this to -# hide the original source file name. 
-#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/bindings/kotlin/crypto/crypto-android/src/main/AndroidManifest.xml b/bindings/kotlin/crypto/crypto-android/src/main/AndroidManifest.xml deleted file mode 100644 index 730df2c4c1a..00000000000 --- a/bindings/kotlin/crypto/crypto-android/src/main/AndroidManifest.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - - diff --git a/bindings/kotlin/crypto/crypto-jvm/.gitignore b/bindings/kotlin/crypto/crypto-jvm/.gitignore deleted file mode 100644 index 42afabfd2ab..00000000000 --- a/bindings/kotlin/crypto/crypto-jvm/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/build \ No newline at end of file diff --git a/bindings/kotlin/crypto/crypto-jvm/build.gradle b/bindings/kotlin/crypto/crypto-jvm/build.gradle deleted file mode 100644 index ce669345bf8..00000000000 --- a/bindings/kotlin/crypto/crypto-jvm/build.gradle +++ /dev/null @@ -1,13 +0,0 @@ -plugins { - id 'java-library' - id 'org.jetbrains.kotlin.jvm' -} - -java { - sourceCompatibility = JavaVersion.VERSION_1_7 - targetCompatibility = JavaVersion.VERSION_1_7 -} - -dependencies { - implementation 'net.java.dev.jna:jna:5.10.0@aar' -} \ No newline at end of file diff --git a/bindings/kotlin/gradle.properties b/bindings/kotlin/gradle.properties deleted file mode 100644 index d2c86c8ce02..00000000000 --- a/bindings/kotlin/gradle.properties +++ /dev/null @@ -1,25 +0,0 @@ -# Project-wide Gradle settings. -# IDE (e.g. Android Studio) users: -# Gradle settings configured through the IDE *will override* -# any settings specified in this file. -# For more details on how to configure your build environment visit -# http://www.gradle.org/docs/current/userguide/build_environment.html -# Specifies the JVM arguments used for the daemon process. -# The setting is particularly useful for tweaking memory settings. -org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 -# When configured, Gradle will run in incubating parallel mode. 
-# This option should only be used with decoupled projects. More details, visit -# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects -# org.gradle.parallel=true -# AndroidX package structure to make it clearer which packages are bundled with the -# Android operating system, and which are packaged with your app"s APK -# https://developer.android.com/topic/libraries/support-library/androidx-rn -android.useAndroidX=true -# Automatically convert third-party libraries to use AndroidX -android.enableJetifier=true -# Kotlin code style for this project: "official" or "obsolete": -kotlin.code.style=official -# Enables namespacing of each library's R class so that its R class includes only the -# resources declared in the library itself and none from the library's dependencies, -# thereby reducing the size of the R class for that library -android.nonTransitiveRClass=true diff --git a/bindings/kotlin/gradle/wrapper/gradle-wrapper.jar b/bindings/kotlin/gradle/wrapper/gradle-wrapper.jar deleted file mode 100644 index e708b1c023e..00000000000 Binary files a/bindings/kotlin/gradle/wrapper/gradle-wrapper.jar and /dev/null differ diff --git a/bindings/kotlin/gradle/wrapper/gradle-wrapper.properties b/bindings/kotlin/gradle/wrapper/gradle-wrapper.properties deleted file mode 100644 index a10cc8b8d88..00000000000 --- a/bindings/kotlin/gradle/wrapper/gradle-wrapper.properties +++ /dev/null @@ -1,6 +0,0 @@ -#Mon Feb 28 18:48:31 CET 2022 -distributionBase=GRADLE_USER_HOME -distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip -distributionPath=wrapper/dists -zipStorePath=wrapper/dists -zipStoreBase=GRADLE_USER_HOME diff --git a/bindings/kotlin/gradlew b/bindings/kotlin/gradlew deleted file mode 100755 index 4f906e0c811..00000000000 --- a/bindings/kotlin/gradlew +++ /dev/null @@ -1,185 +0,0 @@ -#!/usr/bin/env sh - -# -# Copyright 2015 the original author or authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -############################################################################## -## -## Gradle start up script for UN*X -## -############################################################################## - -# Attempt to set APP_HOME -# Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi -done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null - -APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` - -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' - -# Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" - -warn () { - echo "$*" -} - -die () { - echo - echo "$*" - echo - exit 1 -} - -# OS specific support (must be 'true' or 'false'). -cygwin=false -msys=false -darwin=false -nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; -esac - -CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar - - -# Determine the Java command to use to start the JVM. 
-if [ -n "$JAVA_HOME" ] ; then - if [ -x "$JAVA_HOME/jre/sh/java" ] ; then - # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" - else - JAVACMD="$JAVA_HOME/bin/java" - fi - if [ ! -x "$JAVACMD" ] ; then - die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME - -Please set the JAVA_HOME variable in your environment to match the -location of your Java installation." - fi -else - JAVACMD="java" - which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. - -Please set the JAVA_HOME variable in your environment to match the -location of your Java installation." -fi - -# Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi -fi - -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi - -# For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - 
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi - # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" - fi - i=`expr $i + 1` - done - case $i in - 0) set -- ;; - 1) set -- "$args0" ;; - 2) set -- "$args0" "$args1" ;; - 3) set -- "$args0" "$args1" "$args2" ;; - 4) set -- "$args0" "$args1" "$args2" "$args3" ;; - 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac -fi - -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=`save "$@"` - -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" - -exec "$JAVACMD" "$@" diff --git a/bindings/kotlin/gradlew.bat b/bindings/kotlin/gradlew.bat deleted file mode 100644 index ac1b06f9382..00000000000 --- a/bindings/kotlin/gradlew.bat +++ /dev/null @@ -1,89 +0,0 @@ -@rem -@rem Copyright 2015 the original author or authors. -@rem -@rem Licensed under the Apache License, Version 2.0 (the "License"); -@rem you may not use this file except in compliance with the License. 
-@rem You may obtain a copy of the License at -@rem -@rem https://www.apache.org/licenses/LICENSE-2.0 -@rem -@rem Unless required by applicable law or agreed to in writing, software -@rem distributed under the License is distributed on an "AS IS" BASIS, -@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -@rem See the License for the specific language governing permissions and -@rem limitations under the License. -@rem - -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Resolve any "." and ".." in APP_HOME to make it shorter. -for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto execute - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto execute - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. 
- -goto fail - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega diff --git a/bindings/kotlin/scripts/build_crypto.sh b/bindings/kotlin/scripts/build_crypto.sh deleted file mode 100755 index d0ee9710227..00000000000 --- a/bindings/kotlin/scripts/build_crypto.sh +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env bash -set -eEu - -cd "$(dirname "$0")" -CURRENT_DIR=$(pwd) - -# FOR DEBUG -#RELEASE_FLAG="" -#RELEASE_TYPE_DIR="debug" -#RELEASE_AAR_NAME="crypto-android-debug" - -# FOR RELEASE -RELEASE_FLAG="--release" -RELEASE_TYPE_DIR="release" -RELEASE_AAR_NAME="crypto-android-release" - -SRC_ROOT=../../.. -# Path to the kotlin root project -KOTLIN_ROOT=.. 
- -BASE_TARGET_DIR="${SRC_ROOT}/target" -SDK_ROOT="${KOTLIN_ROOT}/crypto/crypto-android" -SDK_TARGET_DIR="${SDK_ROOT}/src/main/jniLibs" -BUILD_DIR="${SDK_ROOT}/build" -GENERATED_DIR="${BUILD_DIR}/generated/source/${RELEASE_TYPE_DIR}" -mkdir -p ${GENERATED_DIR} - -TARGET_CRATE=matrix-sdk-crypto-ffi - -AAR_DESTINATION=$1 - -# Build libs for all the different architectures - -echo -e "Building for x86_64-linux-android[1/4]" -cargo ndk --target x86_64-linux-android -o ${SDK_TARGET_DIR}/ build "${RELEASE_FLAG}" -p ${TARGET_CRATE} - -echo -e "Building for aarch64-linux-android[2/4]" -cargo ndk --target aarch64-linux-android -o ${SDK_TARGET_DIR}/ build "${RELEASE_FLAG}" -p ${TARGET_CRATE} - -echo -e "Building for armv7-linux-androideabi[3/4]" -cargo ndk --target armv7-linux-androideabi -o ${SDK_TARGET_DIR}/ build "${RELEASE_FLAG}" -p ${TARGET_CRATE} - -echo -e "Building for i686-linux-android[4/4]" -cargo ndk --target i686-linux-android -o ${SDK_TARGET_DIR}/ build "${RELEASE_FLAG}" -p ${TARGET_CRATE} - -# Generate uniffi files -echo -e "Generate uniffi kotlin file" -cargo uniffi-bindgen generate "${SRC_ROOT}/bindings/${TARGET_CRATE}/src/olm.udl" \ - --language kotlin \ - --config "${SRC_ROOT}/bindings/${TARGET_CRATE}/uniffi.toml" \ - --out-dir ${GENERATED_DIR} \ - --lib-file "${BASE_TARGET_DIR}/x86_64-linux-android/${RELEASE_TYPE_DIR}/libmatrix_sdk_crypto_ffi.a" - -# Create android library -cd "${KOTLIN_ROOT}" -./gradlew :crypto:crypto-android:assemble -cd "${CURRENT_DIR}" - -echo -e "Moving the generated aar file to ${AAR_DESTINATION}/matrix-rust-sdk-crypto.aar" -mv "${BUILD_DIR}/outputs/aar/${RELEASE_AAR_NAME}.aar" "${AAR_DESTINATION}/matrix-rust-sdk-crypto.aar" - -# Clean-up -echo -e "Cleaning up temporary files" - -rm -r "${BUILD_DIR}" -rm -r "${SDK_TARGET_DIR}" diff --git a/bindings/kotlin/scripts/build_sdk.sh b/bindings/kotlin/scripts/build_sdk.sh deleted file mode 100755 index e7f3249f7c0..00000000000 --- a/bindings/kotlin/scripts/build_sdk.sh +++ /dev/null @@ 
-1,63 +0,0 @@ -#!/usr/bin/env bash -set -eEu - -cd "$(dirname "$0")" -CURRENT_DIR=$(pwd) - -# FOR DEBUG -#RELEASE_FLAG="" -#RELEASE_TYPE_DIR="debug" -#RELEASE_AAR_NAME="sdk-android-debug" - -# FOR RELEASE -RELEASE_FLAG="--release" -RELEASE_TYPE_DIR="release" -RELEASE_AAR_NAME="sdk-android-release" - -SRC_ROOT=../../.. -# Path to the kotlin root project -KOTLIN_ROOT=.. - -BASE_TARGET_DIR="${SRC_ROOT}/target" -SDK_ROOT="${KOTLIN_ROOT}/sdk/sdk-android" -SDK_TARGET_DIR="${SDK_ROOT}/src/main/jniLibs" -BUILD_DIR="${SDK_ROOT}/build" -GENERATED_DIR="${BUILD_DIR}/generated/source/${RELEASE_TYPE_DIR}" -mkdir -p ${GENERATED_DIR} - -AAR_DESTINATION=$1 - -# Build libs for all the different architectures - -echo -e "Building for x86_64-linux-android[1/4]" -cargo ndk --target x86_64-linux-android -o ${SDK_TARGET_DIR}/ build "${RELEASE_FLAG}" -p matrix-sdk-ffi - -echo -e "Building for aarch64-linux-android[2/4]" -cargo ndk --target aarch64-linux-android -o ${SDK_TARGET_DIR}/ build "${RELEASE_FLAG}" -p matrix-sdk-ffi - -echo -e "Building for armv7-linux-androideabi[3/4]" -cargo ndk --target armv7-linux-androideabi -o ${SDK_TARGET_DIR}/ build "${RELEASE_FLAG}" -p matrix-sdk-ffi - -echo -e "Building for i686-linux-android[4/4]" -cargo ndk --target i686-linux-android -o ${SDK_TARGET_DIR}/ build "${RELEASE_FLAG}" -p matrix-sdk-ffi - -# Generate uniffi files -echo -e "Generate uniffi kotlin file" -cargo uniffi-bindgen generate "${SRC_ROOT}/bindings/matrix-sdk-ffi/src/api.udl" \ - --language kotlin \ - --out-dir ${GENERATED_DIR} \ - --lib-file "${BASE_TARGET_DIR}/x86_64-linux-android/${RELEASE_TYPE_DIR}/libmatrix_sdk_ffi.a" - -# Create android library -cd "${KOTLIN_ROOT}" -./gradlew :sdk:sdk-android:assemble -cd "${CURRENT_DIR}" - -echo -e "Moving the generated aar file to ${AAR_DESTINATION}/matrix-rust-sdk.aar" -mv "${BUILD_DIR}/outputs/aar/${RELEASE_AAR_NAME}.aar" "${AAR_DESTINATION}/matrix-rust-sdk.aar" - -# Clean-up -echo -e "Cleaning up temporary files" - -rm -r "${BUILD_DIR}" -rm 
-r "${SDK_TARGET_DIR}" diff --git a/bindings/kotlin/scripts/publish-module.gradle b/bindings/kotlin/scripts/publish-module.gradle deleted file mode 100644 index 0b7d3f195c0..00000000000 --- a/bindings/kotlin/scripts/publish-module.gradle +++ /dev/null @@ -1,77 +0,0 @@ -apply plugin: 'maven-publish' -apply plugin: 'signing' - -task androidSourcesJar(type: Jar) { - archiveClassifier.set('sources') - if (project.plugins.findPlugin("com.android.library")) { - // For Android libraries - from android.sourceSets.main.java.srcDirs - from android.sourceSets.main.kotlin.srcDirs - } else { - // For pure Kotlin libraries, in case you have them - from sourceSets.main.java.srcDirs - from sourceSets.main.kotlin.srcDirs - } -} - -artifacts { - archives androidSourcesJar -} - -group = PUBLISH_GROUP_ID -version = rootVersionName - -afterEvaluate { - publishing { - publications { - release(MavenPublication) { - - groupId PUBLISH_GROUP_ID - artifactId PUBLISH_ARTIFACT_ID - version PUBLISH_VERSION - - if (project.plugins.findPlugin("com.android.library")) { - from components.release - } else { - from components.java - } - - artifact androidSourcesJar - - pom { - name = PUBLISH_ARTIFACT_ID - description = PUBLISH_DESCRIPTION - url = 'https://github.com/matrix-org/matrix-rust-components-kotlin' - licenses { - license { - name = 'The Apache Software License, Version 2.0' - url = 'https://www.apache.org/licenses/LICENSE-2.0.txt' - } - } - developers { - developer { - id = 'matrixdev' - name = 'matrixdev' - email = 'android@element.io' - } - } - - scm { - connection = 'scm:git:git://github.com/matrix-org/matrix-rust-components-kotlin.git' - developerConnection = 'scm:git:ssh://git@github.com/matrix-org/matrix-rust-components-kotlin.git' - url = 'https://github.com/matrix-org/matrix-rust-components-kotlin' - } - } - } - } - } -} - -signing { - useInMemoryPgpKeys( - rootProject.ext["signing.keyId"], - rootProject.ext["signing.key"], - rootProject.ext["signing.password"], - ) - sign 
publishing.publications -} \ No newline at end of file diff --git a/bindings/kotlin/scripts/publish-root.gradle b/bindings/kotlin/scripts/publish-root.gradle deleted file mode 100644 index 609cb9a8f6a..00000000000 --- a/bindings/kotlin/scripts/publish-root.gradle +++ /dev/null @@ -1,43 +0,0 @@ -ext["signing.keyId"] = '' -ext["signing.password"] = '' -ext["signing.key"] = '' -ext["ossrhUsername"] = '' -ext["ossrhPassword"] = '' -ext["sonatypeStagingProfileId"] = '' -ext["snapshot"] = '' - -File secretPropsFile = project.rootProject.file('local.properties') -if (secretPropsFile.exists()) { - // Read local.properties file first if it exists - Properties p = new Properties() - new FileInputStream(secretPropsFile).withCloseable { is -> p.load(is) } - p.each { name, value -> ext[name] = value } -} else { - // Use system environment variables - ext["ossrhUsername"] = System.getenv('OSSRH_USERNAME') - ext["ossrhPassword"] = System.getenv('OSSRH_PASSWORD') - ext["sonatypeStagingProfileId"] = System.getenv('SONATYPE_STAGING_PROFILE_ID') - ext["signing.keyId"] = System.getenv('SIGNING_KEY_ID') - ext["signing.password"] = System.getenv('SIGNING_PASSWORD') - ext["signing.key"] = System.getenv('SIGNING_KEY') - ext["snapshot"] = System.getenv('SNAPSHOT') -} - -if (snapshot.toBoolean()) { - ext["rootVersionName"] = ConfigurationData.snapshotVersionName -} else { - ext["rootVersionName"] = ConfigurationData.versionName -} - -nexusPublishing { - repositories { - sonatype { - stagingProfileId = sonatypeStagingProfileId - username = ossrhUsername - password = ossrhPassword - version = rootVersionName - nexusUrl.set(uri("https://s01.oss.sonatype.org/service/local/")) - snapshotRepositoryUrl.set(uri("https://s01.oss.sonatype.org/content/repositories/snapshots/")) - } - } -} \ No newline at end of file diff --git a/bindings/kotlin/sdk/sdk-android/.gitignore b/bindings/kotlin/sdk/sdk-android/.gitignore deleted file mode 100644 index 42afabfd2ab..00000000000 --- 
a/bindings/kotlin/sdk/sdk-android/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/build \ No newline at end of file diff --git a/bindings/kotlin/sdk/sdk-android/build.gradle b/bindings/kotlin/sdk/sdk-android/build.gradle deleted file mode 100644 index 14cb16560a7..00000000000 --- a/bindings/kotlin/sdk/sdk-android/build.gradle +++ /dev/null @@ -1,51 +0,0 @@ -plugins { - id 'com.android.library' - id 'org.jetbrains.kotlin.android' -} - -ext { - PUBLISH_GROUP_ID = ConfigurationData.publishGroupId - PUBLISH_ARTIFACT_ID = 'sdk-android' - PUBLISH_VERSION = rootVersionName - PUBLISH_DESCRIPTION = 'Android Bindings to the Matrix Rust SDK' -} - -apply from: "${rootDir}/scripts/publish-module.gradle" - -android { - - compileSdk ConfigurationData.compileSdk - - defaultConfig { - minSdk ConfigurationData.minSdk - targetSdk ConfigurationData.targetSdk - versionName ConfigurationData.versionName - - testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" - consumerProguardFiles "consumer-rules.pro" - } - - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' - } - } - compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 - } - kotlinOptions { - jvmTarget = '1.8' - } -} - -android.libraryVariants.all { variant -> - def sourceSet = variant.sourceSets.find { it.name == variant.name } - sourceSet.java.srcDir new File(buildDir, "generated/source/${variant.name}") -} - -dependencies { - implementation Dependencies.jna - testImplementation Dependencies.junit -} diff --git a/bindings/kotlin/sdk/sdk-android/consumer-rules.pro b/bindings/kotlin/sdk/sdk-android/consumer-rules.pro deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/bindings/kotlin/sdk/sdk-android/proguard-rules.pro b/bindings/kotlin/sdk/sdk-android/proguard-rules.pro deleted file mode 100644 index 481bb434814..00000000000 --- 
a/bindings/kotlin/sdk/sdk-android/proguard-rules.pro +++ /dev/null @@ -1,21 +0,0 @@ -# Add project specific ProGuard rules here. -# You can control the set of applied configuration files using the -# proguardFiles setting in build.gradle. -# -# For more details, see -# http://developer.android.com/guide/developing/tools/proguard.html - -# If your project uses WebView with JS, uncomment the following -# and specify the fully qualified class name to the JavaScript interface -# class: -#-keepclassmembers class fqcn.of.javascript.interface.for.webview { -# public *; -#} - -# Uncomment this to preserve the line number information for -# debugging stack traces. -#-keepattributes SourceFile,LineNumberTable - -# If you keep the line number information, uncomment this to -# hide the original source file name. -#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/bindings/kotlin/sdk/sdk-android/src/main/AndroidManifest.xml b/bindings/kotlin/sdk/sdk-android/src/main/AndroidManifest.xml deleted file mode 100644 index 1ae1f06b5cd..00000000000 --- a/bindings/kotlin/sdk/sdk-android/src/main/AndroidManifest.xml +++ /dev/null @@ -1,2 +0,0 @@ - - diff --git a/bindings/kotlin/sdk/sdk-jvm/.gitignore b/bindings/kotlin/sdk/sdk-jvm/.gitignore deleted file mode 100644 index 42afabfd2ab..00000000000 --- a/bindings/kotlin/sdk/sdk-jvm/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/build \ No newline at end of file diff --git a/bindings/kotlin/sdk/sdk-jvm/build.gradle b/bindings/kotlin/sdk/sdk-jvm/build.gradle deleted file mode 100644 index ce669345bf8..00000000000 --- a/bindings/kotlin/sdk/sdk-jvm/build.gradle +++ /dev/null @@ -1,13 +0,0 @@ -plugins { - id 'java-library' - id 'org.jetbrains.kotlin.jvm' -} - -java { - sourceCompatibility = JavaVersion.VERSION_1_7 - targetCompatibility = JavaVersion.VERSION_1_7 -} - -dependencies { - implementation 'net.java.dev.jna:jna:5.10.0@aar' -} \ No newline at end of file diff --git a/bindings/kotlin/settings.gradle 
b/bindings/kotlin/settings.gradle deleted file mode 100644 index f1f0f22796d..00000000000 --- a/bindings/kotlin/settings.gradle +++ /dev/null @@ -1,27 +0,0 @@ -pluginManagement { - repositories { - gradlePluginPortal() - google() - mavenCentral() - } - plugins { - id 'com.android.application' version '7.1.0-beta01' - id 'com.android.library' version '7.1.0-beta01' - id 'org.jetbrains.kotlin.android' version '1.5.30' - id 'org.jetbrains.kotlin.jvm' version '1.5.30' - } -} -dependencyResolutionManagement { - repositories { - google() - mavenCentral() - flatDir { - dirs 'libs' - } - } -} -rootProject.name = "MatrixKotlinRustSDK" -include ':crypto:crypto-android' -include ':crypto:crypto-jvm' -include ':sdk:sdk-jvm' -include ':sdk:sdk-android' diff --git a/bindings/matrix-sdk-crypto-ffi/Cargo.toml b/bindings/matrix-sdk-crypto-ffi/Cargo.toml index 0c8c876b7ca..31ed0ecf1db 100644 --- a/bindings/matrix-sdk-crypto-ffi/Cargo.toml +++ b/bindings/matrix-sdk-crypto-ffi/Cargo.toml @@ -52,7 +52,7 @@ default_features = false features = ["crypto-store"] [dependencies.tokio] -version = "1.23.1" +version = "1.24.2" default_features = false features = ["rt-multi-thread"] diff --git a/bindings/matrix-sdk-crypto-ffi/src/backup_recovery_key.rs b/bindings/matrix-sdk-crypto-ffi/src/backup_recovery_key.rs index 5cf6e3974f4..110406fb26e 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/backup_recovery_key.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/backup_recovery_key.rs @@ -27,8 +27,7 @@ pub enum PkDecryptionError { } /// Error type for the decoding and storing of the backup key. -#[derive(Debug, Error, uniffi::Error)] -#[uniffi(flat_error)] +#[derive(Debug, Error)] pub enum DecodeError { /// An error happened while decoding the recovery key. #[error(transparent)] @@ -41,7 +40,7 @@ pub enum DecodeError { /// Struct containing info about the way the backup key got derived from a /// passphrase. 
-#[derive(Debug, Clone, uniffi::Record)] +#[derive(Debug, Clone)] pub struct PassphraseInfo { /// The salt that was used during key derivation. pub private_key_salt: String, diff --git a/bindings/matrix-sdk-crypto-ffi/src/error.rs b/bindings/matrix-sdk-crypto-ffi/src/error.rs index 8017481eea2..6ae377a0191 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/error.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/error.rs @@ -4,6 +4,7 @@ use matrix_sdk_crypto::{ store::CryptoStoreError as InnerStoreError, KeyExportError, MegolmError, OlmError, SecretImportError as RustSecretImportError, SignatureError as InnerSignatureError, }; +use matrix_sdk_sqlite::OpenStoreError; use ruma::{IdParseError, OwnedUserId}; #[derive(Debug, thiserror::Error)] @@ -38,9 +39,10 @@ pub enum SignatureError { UnknownUserIdentity(String), } -#[derive(Debug, thiserror::Error, uniffi::Error)] -#[uniffi(flat_error)] +#[derive(Debug, thiserror::Error)] pub enum CryptoStoreError { + #[error("Failed to open the store")] + OpenStore(#[from] OpenStoreError), #[error(transparent)] CryptoStore(#[from] InnerStoreError), #[error(transparent)] diff --git a/bindings/matrix-sdk-crypto-ffi/src/lib.rs b/bindings/matrix-sdk-crypto-ffi/src/lib.rs index 4527025f3b2..aad28d80562 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/lib.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/lib.rs @@ -28,9 +28,11 @@ pub use error::{ use js_int::UInt; pub use logger::{set_logger, Logger}; pub use machine::{KeyRequestPair, OlmMachine, SignatureVerification}; -use matrix_sdk_common::deserialized_responses::VerificationState; +use matrix_sdk_common::deserialized_responses::ShieldState as RustShieldState; use matrix_sdk_crypto::{ backups::SignatureState, + olm::{IdentityKeys, InboundGroupSession, Session}, + store::{Changes, CryptoStore, RoomSettings as RustRoomSettings}, types::{EventEncryptionAlgorithm as RustEventEncryptionAlgorithm, SigningKey}, EncryptionSettings as RustEncryptionSettings, LocalTrust, }; @@ -41,15 +43,18 @@ pub use 
responses::{ }; use ruma::{ events::room::history_visibility::HistoryVisibility as RustHistoryVisibility, DeviceId, - DeviceKeyAlgorithm, OwnedUserId, RoomId, SecondsSinceUnixEpoch, UserId, + DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, RoomId, SecondsSinceUnixEpoch, UserId, }; use serde::{Deserialize, Serialize}; +use tokio::runtime::Runtime; +use uniffi_api::*; pub use users::UserIdentity; pub use verification::{ CancelInfo, ConfirmVerificationResult, QrCode, QrCodeListener, QrCodeState, RequestVerificationResult, Sas, SasListener, SasState, ScanResult, StartSasResult, Verification, VerificationRequest, VerificationRequestListener, VerificationRequestState, }; +use vodozemac::{Curve25519PublicKey, Ed25519PublicKey}; /// Struct collecting data that is important to migrate to the rust-sdk #[derive(Deserialize, Serialize)] @@ -70,6 +75,26 @@ pub struct MigrationData { cross_signing: CrossSigningKeyExport, /// The list of users that the Rust SDK should track. tracked_users: Vec, + /// Map of room settings + room_settings: HashMap, +} + +/// Struct collecting data that is important to migrate sessions to the rust-sdk +pub struct SessionMigrationData { + /// The user id that the data belongs to. + user_id: String, + /// The device id that the data belongs to. + device_id: String, + /// The Curve25519 public key of the Account that owns this data. + curve25519_key: String, + /// The Ed25519 public key of the Account that owns this data. + ed25519_key: String, + /// The list of pickleds Olm Sessions. + sessions: Vec, + /// The list of pickled Megolm inbound group sessions. + inbound_group_sessions: Vec, + /// The Olm pickle key that was used to pickle all the Olm objects. + pickle_key: Vec, } /// A pickled version of an `Account`. @@ -149,17 +174,17 @@ impl From for MigrationError { } } -/// Migrate a libolm based setup to a vodozemac based setup stored in a Sled +/// Migrate a libolm based setup to a vodozemac based setup stored in a SQLite /// store. 
/// /// # Arguments /// -/// * `data` - The data that should be migrated over to the Sled store. +/// * `data` - The data that should be migrated over to the SQLite store. /// -/// * `path` - The path where the Sled store should be created. +/// * `path` - The path where the SQLite store should be created. /// /// * `passphrase` - The passphrase that should be used to encrypt the data at -/// rest in the Sled store. **Warning**, if no passphrase is given, the store +/// rest in the SQLite store. **Warning**, if no passphrase is given, the store /// and all its data will remain unencrypted. /// /// * `progress_listener` - A callback that can be used to introspect the @@ -170,7 +195,6 @@ pub fn migrate( passphrase: Option, progress_listener: Box, ) -> anyhow::Result<()> { - use tokio::runtime::Runtime; let runtime = Runtime::new()?; runtime.block_on(async move { migrate_data(data, path, passphrase, progress_listener).await?; @@ -184,15 +208,8 @@ async fn migrate_data( passphrase: Option, progress_listener: Box, ) -> anyhow::Result<()> { - use matrix_sdk_crypto::{ - olm::PrivateCrossSigningIdentity, - store::{Changes as RustChanges, CryptoStore, RecoveryKey}, - }; - use vodozemac::{ - megolm::InboundGroupSession, - olm::{Account, Session}, - Curve25519PublicKey, - }; + use matrix_sdk_crypto::{olm::PrivateCrossSigningIdentity, store::RecoveryKey}; + use vodozemac::olm::Account; use zeroize::Zeroize; // The total steps here include all the sessions/inbound group sessions and @@ -238,11 +255,170 @@ async fn migrate_data( processed_steps += 1; listener(processed_steps, total_steps); + let (sessions, inbound_group_sessions) = collect_sessions( + processed_steps, + total_steps, + &listener, + &data.pickle_key, + user_id.clone(), + device_id, + identity_keys, + data.sessions, + data.inbound_group_sessions, + )?; + + let recovery_key = + data.backup_recovery_key.map(|k| RecoveryKey::from_base58(k.as_str())).transpose()?; + + let cross_signing = 
PrivateCrossSigningIdentity::empty((*user_id).into()); + cross_signing + .import_secrets_unchecked( + data.cross_signing.master_key.as_deref(), + data.cross_signing.self_signing_key.as_deref(), + data.cross_signing.user_signing_key.as_deref(), + ) + .await?; + + data.cross_signing.master_key.zeroize(); + data.cross_signing.self_signing_key.zeroize(); + data.cross_signing.user_signing_key.zeroize(); + + processed_steps += 1; + listener(processed_steps, total_steps); + + let tracked_users: Vec<_> = data + .tracked_users + .into_iter() + .map(|u| Ok(((parse_user_id(&u)?), true))) + .collect::>()?; + + let tracked_users: Vec<_> = tracked_users.iter().map(|(u, d)| (&**u, *d)).collect(); + store.save_tracked_users(tracked_users.as_slice()).await?; + + processed_steps += 1; + listener(processed_steps, total_steps); + + let mut room_settings = HashMap::new(); + for (room_id, settings) in data.room_settings { + let room_id = RoomId::parse(room_id)?; + room_settings.insert(room_id, settings.into()); + } + + let changes = Changes { + account: Some(account), + private_identity: Some(cross_signing), + sessions, + inbound_group_sessions, + recovery_key, + backup_version: data.backup_version, + room_settings, + ..Default::default() + }; + + save_changes(processed_steps, total_steps, &listener, changes, &store).await +} + +async fn save_changes( + mut processed_steps: usize, + total_steps: usize, + listener: &dyn Fn(usize, usize), + changes: Changes, + store: &SqliteCryptoStore, +) -> anyhow::Result<()> { + store.save_changes(changes).await?; + + processed_steps += 1; + listener(processed_steps, total_steps); + + Ok(()) +} + +/// Migrate sessions and group sessions of a libolm based setup to a vodozemac +/// based setup stored in a SQLite store. +/// +/// This method allows you to migrate a subset of the data, it should only be +/// used after the [`migrate()`] method has been already used. 
+/// +/// # Arguments +/// +/// * `data` - The data that should be migrated over to the SQLite store. +/// +/// * `path` - The path where the SQLite store should be created. +/// +/// * `passphrase` - The passphrase that should be used to encrypt the data at +/// rest in the SQLite store. **Warning**, if no passphrase is given, the store +/// and all its data will remain unencrypted. +/// +/// * `progress_listener` - A callback that can be used to introspect the +/// progress of the migration. +pub fn migrate_sessions( + data: SessionMigrationData, + path: &str, + passphrase: Option, + progress_listener: Box, +) -> anyhow::Result<()> { + let runtime = Runtime::new()?; + runtime.block_on(migrate_session_data(data, path, passphrase, progress_listener)) +} + +async fn migrate_session_data( + data: SessionMigrationData, + path: &str, + passphrase: Option, + progress_listener: Box, +) -> anyhow::Result<()> { + let store = SqliteCryptoStore::open(path, passphrase.as_deref()).await?; + + let listener = |progress: usize, total: usize| { + progress_listener.on_progress(progress as i32, total as i32) + }; + + let total_steps = 1 + data.sessions.len() + data.inbound_group_sessions.len(); + let processed_steps = 0; + + let user_id = UserId::parse(data.user_id)?.into(); + let device_id: OwnedDeviceId = data.device_id.into(); + + let identity_keys = IdentityKeys { + ed25519: Ed25519PublicKey::from_base64(&data.ed25519_key)?, + curve25519: Curve25519PublicKey::from_base64(&data.curve25519_key)?, + } + .into(); + + let (sessions, inbound_group_sessions) = collect_sessions( + processed_steps, + total_steps, + &listener, + &data.pickle_key, + user_id, + device_id.into(), + identity_keys, + data.sessions, + data.inbound_group_sessions, + )?; + + let changes = Changes { sessions, inbound_group_sessions, ..Default::default() }; + save_changes(processed_steps, total_steps, &listener, changes, &store).await +} + +#[allow(clippy::too_many_arguments)] +fn collect_sessions( + mut 
processed_steps: usize, + total_steps: usize, + listener: &dyn Fn(usize, usize), + pickle_key: &[u8], + user_id: Arc, + device_id: Arc, + identity_keys: Arc, + session_pickles: Vec, + group_session_pickles: Vec, +) -> anyhow::Result<(Vec, Vec)> { let mut sessions = Vec::new(); - for session_pickle in data.sessions { + for session_pickle in session_pickles { let pickle = - Session::from_libolm_pickle(&session_pickle.pickle, &data.pickle_key)?.pickle(); + vodozemac::olm::Session::from_libolm_pickle(&session_pickle.pickle, pickle_key)? + .pickle(); let creation_time = SecondsSinceUnixEpoch(UInt::from_str(&session_pickle.creation_time)?); let last_use_time = SecondsSinceUnixEpoch(UInt::from_str(&session_pickle.last_use_time)?); @@ -255,12 +431,8 @@ async fn migrate_data( last_use_time, }; - let session = matrix_sdk_crypto::olm::Session::from_pickle( - user_id.clone(), - device_id.clone(), - identity_keys.clone(), - pickle, - ); + let session = + Session::from_pickle(user_id.clone(), device_id.clone(), identity_keys.clone(), pickle); sessions.push(session); processed_steps += 1; @@ -269,9 +441,12 @@ async fn migrate_data( let mut inbound_group_sessions = Vec::new(); - for session in data.inbound_group_sessions { - let pickle = - InboundGroupSession::from_libolm_pickle(&session.pickle, &data.pickle_key)?.pickle(); + for session in group_session_pickles { + let pickle = vodozemac::megolm::InboundGroupSession::from_libolm_pickle( + &session.pickle, + pickle_key, + )? 
+ .pickle(); let sender_key = Curve25519PublicKey::from_base64(&session.sender_key)?; @@ -302,52 +477,46 @@ async fn migrate_data( listener(processed_steps, total_steps); } - let recovery_key = - data.backup_recovery_key.map(|k| RecoveryKey::from_base58(k.as_str())).transpose()?; - - let cross_signing = PrivateCrossSigningIdentity::empty((*user_id).into()); - cross_signing - .import_secrets_unchecked( - data.cross_signing.master_key.as_deref(), - data.cross_signing.self_signing_key.as_deref(), - data.cross_signing.user_signing_key.as_deref(), - ) - .await?; - - data.cross_signing.master_key.zeroize(); - data.cross_signing.self_signing_key.zeroize(); - data.cross_signing.user_signing_key.zeroize(); - - processed_steps += 1; - listener(processed_steps, total_steps); - - let tracked_users: Vec<_> = data - .tracked_users - .into_iter() - .map(|u| Ok(((parse_user_id(&u)?), true))) - .collect::>()?; + Ok((sessions, inbound_group_sessions)) +} - let tracked_users: Vec<_> = tracked_users.iter().map(|(u, d)| (&**u, *d)).collect(); - store.save_tracked_users(tracked_users.as_slice()).await?; +/// Migrate room settings, including room algorithm and whether to block +/// untrusted devices from legacy store to Sqlite store. +/// +/// Note that this method should only be used if a client has already migrated +/// account data via [migrate](#method.migrate) method, which did not include +/// room settings. For a brand new migration, the [migrate](#method.migrate) +/// method will take care of room settings automatically, if provided. +/// +/// # Arguments +/// +/// * `room_settings` - Map of room settings +/// +/// * `path` - The path where the Sqlite store should be created. +/// +/// * `passphrase` - The passphrase that should be used to encrypt the data at +/// rest in the Sqlite store. **Warning**, if no passphrase is given, the store +/// and all its data will remain unencrypted. 
+pub fn migrate_room_settings( + room_settings: HashMap, + path: &str, + passphrase: Option, +) -> anyhow::Result<()> { + let runtime = Runtime::new()?; + runtime.block_on(async move { + let store = SqliteCryptoStore::open(path, passphrase.as_deref()).await?; - processed_steps += 1; - listener(processed_steps, total_steps); + let mut rust_settings = HashMap::new(); + for (room_id, settings) in room_settings { + let room_id = RoomId::parse(room_id)?; + rust_settings.insert(room_id, settings.into()); + } - let changes = RustChanges { - account: Some(account), - private_identity: Some(cross_signing), - sessions, - inbound_group_sessions, - recovery_key, - backup_version: data.backup_version, - ..Default::default() - }; - store.save_changes(changes).await?; + let changes = Changes { room_settings: rust_settings, ..Default::default() }; + store.save_changes(changes).await?; - processed_steps += 1; - listener(processed_steps, total_steps); - - Ok(()) + Ok(()) + }) } /// Callback that will be passed over the FFI to report progress @@ -369,6 +538,7 @@ impl ProgressListener for T { } /// An encryption algorithm to be used to encrypt messages sent to a room. +#[derive(Debug, Deserialize, Serialize, PartialEq)] pub enum EventEncryptionAlgorithm { /// Olm version 1 using Curve25519, AES-256, and SHA-256. 
OlmV1Curve25519AesSha2, @@ -379,12 +549,22 @@ pub enum EventEncryptionAlgorithm { impl From for RustEventEncryptionAlgorithm { fn from(a: EventEncryptionAlgorithm) -> Self { match a { - EventEncryptionAlgorithm::OlmV1Curve25519AesSha2 => { - RustEventEncryptionAlgorithm::OlmV1Curve25519AesSha2 - } - EventEncryptionAlgorithm::MegolmV1AesSha2 => { - RustEventEncryptionAlgorithm::MegolmV1AesSha2 + EventEncryptionAlgorithm::OlmV1Curve25519AesSha2 => Self::OlmV1Curve25519AesSha2, + EventEncryptionAlgorithm::MegolmV1AesSha2 => Self::MegolmV1AesSha2, + } + } +} + +impl TryFrom for EventEncryptionAlgorithm { + type Error = serde_json::Error; + + fn try_from(value: RustEventEncryptionAlgorithm) -> Result { + match value { + RustEventEncryptionAlgorithm::OlmV1Curve25519AesSha2 => { + Ok(Self::OlmV1Curve25519AesSha2) } + RustEventEncryptionAlgorithm::MegolmV1AesSha2 => Ok(Self::MegolmV1AesSha2), + _ => Err(serde::de::Error::custom(format!("Unsupported algorithm {value}"))), } } } @@ -419,10 +599,10 @@ pub enum HistoryVisibility { impl From for RustHistoryVisibility { fn from(h: HistoryVisibility) -> Self { match h { - HistoryVisibility::Invited => RustHistoryVisibility::Invited, - HistoryVisibility::Joined => RustHistoryVisibility::Joined, - HistoryVisibility::Shared => RustHistoryVisibility::Shared, - HistoryVisibility::WorldReadable => RustHistoryVisibility::Shared, + HistoryVisibility::Invited => Self::Invited, + HistoryVisibility::Joined => Self::Joined, + HistoryVisibility::Shared => Self::Shared, + HistoryVisibility::WorldReadable => Self::Shared, } } } @@ -432,6 +612,7 @@ impl From for RustHistoryVisibility { /// These settings control which algorithm the room key should use, how long a /// room key should be used and some other important information that determines /// the lifetime of a room key. +#[derive(uniffi::Record)] pub struct EncryptionSettings { /// The encryption algorithm that should be used in the room. 
pub algorithm: EventEncryptionAlgorithm, @@ -462,6 +643,7 @@ impl From for RustEncryptionSettings { } /// An event that was successfully decrypted. +#[derive(uniffi::Record)] pub struct DecryptedEvent { /// The decrypted version of the event. pub clear_event: String, @@ -473,15 +655,53 @@ pub struct DecryptedEvent { /// key to us. Is empty if the key came directly from the sender of the /// event. pub forwarding_curve25519_chain: Vec, - /// The verification state of the device that sent us the event, note this - /// is the state of the device at the time of decryption. It may change in - /// the future if a device gets verified or deleted. - pub verification_state: VerificationState, + /// The shield state (color and message to display to user) for the event, + /// representing the event's authenticity. Computed from the properties of + /// the sender user identity and their Olm device. + /// + /// Note that this is computed at time of decryption, so the value reflects + /// the computed event authenticity at that time. Authenticity-related + /// properties can change later on, such as when a user identity is + /// subsequently verified or a device is deleted. + pub shield_state: ShieldState, +} + +/// Take a look at [`matrix_sdk_common::deserialized_responses::ShieldState`] +/// for more info. +#[allow(missing_docs)] +#[derive(uniffi::Enum)] +pub enum ShieldColor { + Red, + Grey, + None, +} + +/// Take a look at [`matrix_sdk_common::deserialized_responses::ShieldState`] +/// for more info. 
+#[derive(uniffi::Record)] +#[allow(missing_docs)] +pub struct ShieldState { + color: ShieldColor, + message: Option, +} + +impl From for ShieldState { + fn from(value: RustShieldState) -> Self { + match value { + RustShieldState::Red { message } => { + Self { color: ShieldColor::Red, message: Some(message.to_owned()) } + } + RustShieldState::Grey { message } => { + Self { color: ShieldColor::Grey, message: Some(message.to_owned()) } + } + RustShieldState::None => Self { color: ShieldColor::None, message: None }, + } + } } /// Struct representing the state of our private cross signing keys, it shows /// which private cross signing keys we have locally stored. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, uniffi::Record)] pub struct CrossSigningStatus { /// Do we have the master key. pub has_master: bool, @@ -587,6 +807,34 @@ impl From for CrossSigningStatus { } } +/// Room encryption settings which are modified by state events or user options +#[derive(Debug, Deserialize, Serialize, PartialEq)] +pub struct RoomSettings { + /// The encryption algorithm that should be used in the room. + pub algorithm: EventEncryptionAlgorithm, + /// Should untrusted devices receive the room key, or should they be + /// excluded from the conversation. 
+ pub only_allow_trusted_devices: bool, +} + +impl TryFrom for RoomSettings { + type Error = serde_json::Error; + + fn try_from(value: RustRoomSettings) -> Result { + let algorithm = value.algorithm.try_into()?; + Ok(Self { algorithm, only_allow_trusted_devices: value.only_allow_trusted_devices }) + } +} + +impl From for RustRoomSettings { + fn from(value: RoomSettings) -> Self { + Self { + algorithm: value.algorithm.into(), + only_allow_trusted_devices: value.only_allow_trusted_devices, + } + } +} + fn parse_user_id(user_id: &str) -> Result { ruma::UserId::parse(user_id).map_err(|e| CryptoStoreError::InvalidUserId(user_id.to_owned(), e)) } @@ -596,10 +844,14 @@ mod uniffi_types { backup_recovery_key::{ BackupRecoveryKey, DecodeError, MegolmV1BackupKey, PassphraseInfo, PkDecryptionError, }, - error::CryptoStoreError, - machine::OlmMachine, - responses::Request, - BackupKeys, RoomKeyCounts, + error::{CryptoStoreError, DecryptionError, SecretImportError}, + machine::{KeyRequestPair, OlmMachine}, + responses::{BootstrapCrossSigningResult, DeviceLists, Request}, + verification::{ + RequestVerificationResult, StartSasResult, Verification, VerificationRequest, + }, + BackupKeys, CrossSigningKeyExport, CrossSigningStatus, DecryptedEvent, EncryptionSettings, + EventEncryptionAlgorithm, RoomKeyCounts, RoomSettings, ShieldColor, ShieldState, }; } @@ -610,7 +862,7 @@ mod test { use tempfile::tempdir; use super::MigrationData; - use crate::{migrate, OlmMachine}; + use crate::{migrate, EventEncryptionAlgorithm, OlmMachine, RoomSettings}; #[test] fn android_migration() -> Result<()> { @@ -693,7 +945,17 @@ mod test { "@this-is-me:matrix.org", "@Amandine:matrix.org", "@ganfra:matrix.org" - ] + ], + "room_settings": { + "!AZkqtjvtwPAuyNOXEt:matrix.org": { + "algorithm": "OlmV1Curve25519AesSha2", + "only_allow_trusted_devices": true + }, + "!CWLUCoEWXSFyTCOtfL:matrix.org": { + "algorithm": "MegolmV1AesSha2", + "only_allow_trusted_devices": false + }, + } }); let migration_data: 
MigrationData = serde_json::from_value(data)?; @@ -722,6 +984,27 @@ mod test { let backup_keys = machine.get_backup_keys()?; assert!(backup_keys.is_some()); + let settings1 = machine.get_room_settings("!AZkqtjvtwPAuyNOXEt:matrix.org".into())?; + assert_eq!( + Some(RoomSettings { + algorithm: EventEncryptionAlgorithm::OlmV1Curve25519AesSha2, + only_allow_trusted_devices: true + }), + settings1 + ); + + let settings2 = machine.get_room_settings("!CWLUCoEWXSFyTCOtfL:matrix.org".into())?; + assert_eq!( + Some(RoomSettings { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + only_allow_trusted_devices: false + }), + settings2 + ); + + let settings3 = machine.get_room_settings("!XYZ:matrix.org".into())?; + assert!(settings3.is_none()); + Ok(()) } } diff --git a/bindings/matrix-sdk-crypto-ffi/src/logger.rs b/bindings/matrix-sdk-crypto-ffi/src/logger.rs index ce9ea068741..c2c99111953 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/logger.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/logger.rs @@ -44,9 +44,13 @@ pub struct LoggerWrapper { pub fn set_logger(logger: Box) { let logger = LoggerWrapper { inner: Arc::new(Mutex::new(logger)) }; - let filter = EnvFilter::from_default_env().add_directive( - "matrix_sdk_crypto=trace".parse().expect("Can't parse logging filter directive"), - ); + let filter = EnvFilter::from_default_env() + .add_directive( + "matrix_sdk_crypto=trace".parse().expect("Can't parse logging filter directive"), + ) + .add_directive( + "matrix_sdk_sqlite=debug".parse().expect("Can't parse logging filter directive"), + ); let _ = tracing_subscriber::fmt() .with_writer(logger) diff --git a/bindings/matrix-sdk-crypto-ffi/src/machine.rs b/bindings/matrix-sdk-crypto-ffi/src/machine.rs index 97c72c1735e..bd9b476ac90 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/machine.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/machine.rs @@ -16,7 +16,7 @@ use matrix_sdk_crypto::{ }, decrypt_room_key_export, encrypt_room_key_export, olm::ExportedRoomKey, - 
store::RecoveryKey, + store::{Changes, RecoveryKey}, LocalTrust, OlmMachine as InnerMachine, UserIdentities, }; use ruma::{ @@ -53,9 +53,9 @@ use crate::{ responses::{response_from_string, OwnedResponse}, BackupKeys, BackupRecoveryKey, BootstrapCrossSigningResult, CrossSigningKeyExport, CrossSigningStatus, DecodeError, DecryptedEvent, Device, DeviceLists, EncryptionSettings, - KeyImportError, KeysImportResult, MegolmV1BackupKey, ProgressListener, Request, RequestType, - RequestVerificationResult, RoomKeyCounts, Sas, SignatureUploadRequest, StartSasResult, - UserIdentity, Verification, VerificationRequest, + EventEncryptionAlgorithm, KeyImportError, KeysImportResult, MegolmV1BackupKey, + ProgressListener, Request, RequestType, RequestVerificationResult, RoomKeyCounts, RoomSettings, + Sas, SignatureUploadRequest, StartSasResult, UserIdentity, Verification, VerificationRequest, }; /// A high level state machine that handles E2EE for Matrix. @@ -147,6 +147,14 @@ impl OlmMachine { HashMap::from([("ed25519".to_owned(), ed25519_key), ("curve25519".to_owned(), curve_key)]) } + + /// Get the status of the private cross signing keys. + /// + /// This can be used to check which private cross signing keys we have + /// stored locally. + pub fn cross_signing_status(&self) -> CrossSigningStatus { + self.runtime.block_on(self.inner.cross_signing_status()).into() + } } impl OlmMachine { @@ -174,18 +182,7 @@ impl OlmMachine { let runtime = Runtime::new().expect("Couldn't create a tokio runtime"); let store = runtime - .block_on(matrix_sdk_sqlite::SqliteCryptoStore::open(path, passphrase.as_deref())) - .map_err(|e| match e { - // This is a bit of an error in the sled store, the - // CryptoStore returns an `OpenStoreError` which has a - // variant for the state store. Not sure what to do about - // this. 
- matrix_sdk_sqlite::OpenStoreError::Crypto(r) => r.into(), - matrix_sdk_sqlite::OpenStoreError::Sqlite(s) => CryptoStoreError::CryptoStore( - matrix_sdk_crypto::store::CryptoStoreError::backend(s), - ), - _ => unreachable!(), - })?; + .block_on(matrix_sdk_sqlite::SqliteCryptoStore::open(path, passphrase.as_deref()))?; passphrase.zeroize(); @@ -450,7 +447,10 @@ impl OlmMachine { Ok(()) } +} +#[uniffi::export] +impl OlmMachine { /// Let the state machine know about E2EE related sync changes that we /// received from the server. /// @@ -468,12 +468,12 @@ impl OlmMachine { /// * `key_counts` - The map of uploaded one-time key types and counts. pub fn receive_sync_changes( &self, - events: &str, + events: String, device_changes: DeviceLists, key_counts: HashMap, unused_fallback_keys: Option>, ) -> Result { - let to_device: ToDevice = serde_json::from_str(events)?; + let to_device: ToDevice = serde_json::from_str(&events)?; let device_changes: RumaDeviceLists = device_changes.into(); let key_counts: BTreeMap = key_counts .into_iter() @@ -528,8 +528,8 @@ impl OlmMachine { /// /// A user can be marked for tracking using the /// [`OlmMachine::update_tracked_users()`] method. - pub fn is_user_tracked(&self, user_id: &str) -> Result { - let user_id = parse_user_id(user_id)?; + pub fn is_user_tracked(&self, user_id: String) -> Result { + let user_id = parse_user_id(&user_id)?; Ok(self.runtime.block_on(self.inner.tracked_users())?.contains(&user_id)) } @@ -560,6 +560,101 @@ impl OlmMachine { .map(|r| r.into())) } + /// Get the stored room settings, such as the encryption algorithm or + /// whether to encrypt only for trusted devices. + /// + /// These settings can be modified via + /// [set_room_algorithm()](#method.set_room_algorithm) and + /// [set_room_only_allow_trusted_devices()](#method. + /// set_room_only_allow_trusted_devices) methods. 
+ pub fn get_room_settings( + &self, + room_id: String, + ) -> Result, CryptoStoreError> { + let room_id = RoomId::parse(room_id)?; + let settings = self + .runtime + .block_on(self.inner.store().get_room_settings(&room_id))? + .map(|v| v.try_into()) + .transpose()?; + Ok(settings) + } + + /// Set the room algorithm used for encrypting messages to one of the + /// available variants + pub fn set_room_algorithm( + &self, + room_id: String, + algorithm: EventEncryptionAlgorithm, + ) -> Result<(), CryptoStoreError> { + let room_id = RoomId::parse(room_id)?; + self.runtime.block_on(async move { + let mut settings = + self.inner.store().get_room_settings(&room_id).await?.unwrap_or_default(); + settings.algorithm = algorithm.into(); + self.inner + .store() + .save_changes(Changes { + room_settings: HashMap::from([(room_id, settings)]), + ..Default::default() + }) + .await?; + Ok(()) + }) + } + + /// Set flag whether this room should encrypt messages for untrusted + /// devices, or whether they should be excluded from the conversation. + /// + /// Note that per-room setting may be overridden by a global + /// [set_only_allow_trusted_devices()](#method. + /// set_only_allow_trusted_devices) method. + pub fn set_room_only_allow_trusted_devices( + &self, + room_id: String, + only_allow_trusted_devices: bool, + ) -> Result<(), CryptoStoreError> { + let room_id = RoomId::parse(room_id)?; + self.runtime.block_on(async move { + let mut settings = + self.inner.store().get_room_settings(&room_id).await?.unwrap_or_default(); + settings.only_allow_trusted_devices = only_allow_trusted_devices; + self.inner + .store() + .save_changes(Changes { + room_settings: HashMap::from([(room_id, settings)]), + ..Default::default() + }) + .await?; + Ok(()) + }) + } + + /// Check whether there is a global flag to only encrypt messages for + /// trusted devices or for everyone. 
+ /// + /// Note that if the global flag is false, individual rooms may still be + /// encrypting only for trusted devices, depending on the per-room + /// `only_allow_trusted_devices` flag. + pub fn get_only_allow_trusted_devices(&self) -> Result { + let block = self.runtime.block_on(self.inner.store().get_only_allow_trusted_devices())?; + Ok(block) + } + + /// Set global flag whether to encrypt messages for untrusted devices, or + /// whether they should be excluded from the conversation. + /// + /// Note that if enabled, it will override any per-room settings. + pub fn set_only_allow_trusted_devices( + &self, + only_allow_trusted_devices: bool, + ) -> Result<(), CryptoStoreError> { + self.runtime.block_on( + self.inner.store().set_only_allow_trusted_devices(only_allow_trusted_devices), + )?; + Ok(()) + } + /// Share a room key with the given list of users for the given room. /// /// After the request was sent out and a successful response was received @@ -581,7 +676,7 @@ impl OlmMachine { /// * `settings` - The settings that should be used for the room key. pub fn share_room_key( &self, - room_id: &str, + room_id: String, users: Vec, settings: EncryptionSettings, ) -> Result, CryptoStoreError> { @@ -633,16 +728,16 @@ impl OlmMachine { /// * `content` - The serialized content of the event. pub fn encrypt( &self, - room_id: &str, - event_type: &str, - content: &str, + room_id: String, + event_type: String, + content: String, ) -> Result { let room_id = RoomId::parse(room_id)?; - let content: Value = serde_json::from_str(content)?; + let content: Value = serde_json::from_str(&content)?; let encrypted_content = self .runtime - .block_on(self.inner.encrypt_room_event_raw(&room_id, content, event_type)) + .block_on(self.inner.encrypt_room_event_raw(&room_id, content, &event_type)) .expect("Encrypting an event produced an error"); Ok(serde_json::to_string(&encrypted_content)?) 
@@ -655,11 +750,16 @@ impl OlmMachine { /// * `event` - The serialized encrypted version of the event. /// /// * `room_id` - The unique id of the room where the event was sent to. + /// + /// * `strict_shields` - If `true`, messages will be decorated with strict + /// warnings (use `false` to match legacy behaviour where unsafe keys have + /// lower severity warnings and unverified identities are not decorated). pub fn decrypt_room_event( &self, - event: &str, - room_id: &str, + event: String, + room_id: String, handle_verification_events: bool, + strict_shields: bool, ) -> Result { // Element Android wants only the content and the type and will create a // decrypted event with those two itself, this struct makes sure we @@ -672,7 +772,7 @@ impl OlmMachine { content: &'a RawValue, } - let event: Raw<_> = serde_json::from_str(event)?; + let event: Raw<_> = serde_json::from_str(&event)?; let room_id = RoomId::parse(room_id)?; let decrypted = self.runtime.block_on(self.inner.decrypt_room_event(&event, &room_id))?; @@ -710,7 +810,11 @@ impl OlmMachine { .get(&DeviceKeyAlgorithm::Ed25519) .cloned(), forwarding_curve25519_chain: vec![], - verification_state: encryption_info.verification_state, + shield_state: if strict_shields { + encryption_info.verification_state.to_shield_state_strict().into() + } else { + encryption_info.verification_state.to_shield_state_lax().into() + }, } } }) @@ -727,10 +831,10 @@ impl OlmMachine { /// * `room_id` - The id of the room the event was sent to. pub fn request_room_key( &self, - event: &str, - room_id: &str, + event: String, + room_id: String, ) -> Result { - let event: Raw<_> = serde_json::from_str(event)?; + let event: Raw<_> = serde_json::from_str(&event)?; let room_id = RoomId::parse(room_id)?; let (cancel, request) = @@ -753,17 +857,19 @@ impl OlmMachine { /// passphrase into an key. 
pub fn export_room_keys( &self, - passphrase: &str, + passphrase: String, rounds: i32, ) -> Result { let keys = self.runtime.block_on(self.inner.export_room_keys(|_| true))?; - let encrypted = encrypt_room_key_export(&keys, passphrase, rounds as u32) + let encrypted = encrypt_room_key_export(&keys, &passphrase, rounds as u32) .map_err(CryptoStoreError::Serialization)?; Ok(encrypted) } +} +impl OlmMachine { fn import_room_keys_helper( &self, keys: Vec, @@ -838,10 +944,13 @@ impl OlmMachine { self.import_room_keys_helper(keys, true, progress_listener) } +} +#[uniffi::export] +impl OlmMachine { /// Discard the currently active room key for the given room if there is /// one. - pub fn discard_room_key(&self, room_id: &str) -> Result<(), CryptoStoreError> { + pub fn discard_room_key(&self, room_id: String) -> Result<(), CryptoStoreError> { let room_id = RoomId::parse(room_id)?; self.runtime.block_on(self.inner.invalidate_group_session(&room_id))?; @@ -857,8 +966,8 @@ impl OlmMachine { /// **Note**: This has been deprecated. pub fn receive_unencrypted_verification_event( &self, - event: &str, - room_id: &str, + event: String, + room_id: String, ) -> Result<(), CryptoStoreError> { self.receive_verification_event(event, room_id) } @@ -869,11 +978,11 @@ impl OlmMachine { /// in rooms to the `OlmMachine`. The event should be in the decrypted form. pub fn receive_verification_event( &self, - event: &str, - room_id: &str, + event: String, + room_id: String, ) -> Result<(), CryptoStoreError> { let room_id = RoomId::parse(room_id)?; - let event: AnySyncMessageLikeEvent = serde_json::from_str(event)?; + let event: AnySyncMessageLikeEvent = serde_json::from_str(&event)?; let event = event.into_full_event(room_id); @@ -888,7 +997,7 @@ impl OlmMachine { /// /// * `user_id` - The ID of the user for which we would like to fetch the /// verification requests. 
- pub fn get_verification_requests(&self, user_id: &str) -> Vec> { + pub fn get_verification_requests(&self, user_id: String) -> Vec> { let Ok(user_id) = UserId::parse(user_id) else { return vec![]; }; @@ -913,8 +1022,8 @@ impl OlmMachine { /// * `flow_id` - The ID that uniquely identifies the verification flow. pub fn get_verification_request( &self, - user_id: &str, - flow_id: &str, + user_id: String, + flow_id: String, ) -> Option> { let user_id = UserId::parse(user_id).ok()?; @@ -934,10 +1043,10 @@ impl OlmMachine { /// support. pub fn verification_request_content( &self, - user_id: &str, + user_id: String, methods: Vec, ) -> Result, CryptoStoreError> { - let user_id = parse_user_id(user_id)?; + let user_id = parse_user_id(&user_id)?; let identity = self.runtime.block_on(self.inner.get_identity(&user_id, None))?; @@ -974,12 +1083,12 @@ impl OlmMachine { /// [verification_request_content()]: #method.verification_request_content pub fn request_verification( &self, - user_id: &str, - room_id: &str, - event_id: &str, + user_id: String, + room_id: String, + event_id: String, methods: Vec, ) -> Result>, CryptoStoreError> { - let user_id = parse_user_id(user_id)?; + let user_id = parse_user_id(&user_id)?; let event_id = EventId::parse(event_id)?; let room_id = RoomId::parse(room_id)?; @@ -1016,17 +1125,18 @@ impl OlmMachine { /// supported in the `m.key.verification.request` event. pub fn request_verification_with_device( &self, - user_id: &str, - device_id: &str, + user_id: String, + device_id: String, methods: Vec, ) -> Result, CryptoStoreError> { - let user_id = parse_user_id(user_id)?; + let user_id = parse_user_id(&user_id)?; + let device_id = device_id.as_str().into(); let methods = methods.into_iter().map(VerificationMethod::from).collect(); Ok( if let Some(device) = - self.runtime.block_on(self.inner.get_device(&user_id, device_id.into(), None))? + self.runtime.block_on(self.inner.get_device(&user_id, device_id, None))? 
{ let (verification, request) = self.runtime.block_on(device.request_verification_with_methods(methods)); @@ -1085,11 +1195,11 @@ impl OlmMachine { /// verification. /// /// * `flow_id` - The ID that uniquely identifies the verification flow. - pub fn get_verification(&self, user_id: &str, flow_id: &str) -> Option> { + pub fn get_verification(&self, user_id: String, flow_id: String) -> Option> { let user_id = UserId::parse(user_id).ok()?; self.inner - .get_verification(&user_id, flow_id) + .get_verification(&user_id, &flow_id) .map(|v| Verification { inner: v, runtime: self.runtime.handle().to_owned() }.into()) } @@ -1109,14 +1219,15 @@ impl OlmMachine { /// [request_verification_with_device()]: #method.request_verification_with_device pub fn start_sas_with_device( &self, - user_id: &str, - device_id: &str, + user_id: String, + device_id: String, ) -> Result, CryptoStoreError> { - let user_id = parse_user_id(user_id)?; + let user_id = parse_user_id(&user_id)?; + let device_id = device_id.as_str().into(); Ok( if let Some(device) = - self.runtime.block_on(self.inner.get_device(&user_id, device_id.into(), None))? + self.runtime.block_on(self.inner.get_device(&user_id, device_id, None))? { let (sas, request) = self.runtime.block_on(device.start_verification())?; @@ -1136,14 +1247,6 @@ impl OlmMachine { Ok(self.runtime.block_on(self.inner.bootstrap_cross_signing(true))?.into()) } - /// Get the status of the private cross signing keys. - /// - /// This can be used to check which private cross signing keys we have - /// stored locally. - pub fn cross_signing_status(&self) -> CrossSigningStatus { - self.runtime.block_on(self.inner.cross_signing_status()).into() - } - /// Export all our private cross signing keys. /// /// The export will contain the seed for the ed25519 keys as a base64 @@ -1167,10 +1270,7 @@ impl OlmMachine { Ok(()) } -} -#[uniffi::export] -impl OlmMachine { /// Activate the given backup key to be used with the given backup version. 
/// /// **Warning**: The caller needs to make sure that the given `BackupKey` is diff --git a/bindings/matrix-sdk-crypto-ffi/src/olm.udl b/bindings/matrix-sdk-crypto-ffi/src/olm.udl index fc5af203816..a175051e446 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/olm.udl +++ b/bindings/matrix-sdk-crypto-ffi/src/olm.udl @@ -7,11 +7,24 @@ namespace matrix_sdk_crypto_ffi { string? passphrase, ProgressListener progress_listener ); + [Throws=MigrationError] + void migrate_sessions( + SessionMigrationData data, + [ByRef] string path, + string? passphrase, + ProgressListener progress_listener + ); + [Throws=MigrationError] + void migrate_room_settings( + record room_settings, + [ByRef] string path, + string? passphrase + ); }; [Error] interface MigrationError { - Generic(string error_message); + Generic(string error_message); }; callback interface Logger { @@ -47,6 +60,7 @@ enum SecretImportError { [Error] enum CryptoStoreError { + "OpenStore", "CryptoStore", "OlmError", "Serialization", @@ -63,25 +77,12 @@ enum DecryptionError { "Store", }; -dictionary DeviceLists { - sequence changed; - sequence left; -}; - dictionary KeysImportResult { i64 imported; i64 total; record>> keys; }; -dictionary DecryptedEvent { - string clear_event; - string sender_curve25519_key; - string? claimed_ed25519_key; - sequence forwarding_curve25519_chain; - VerificationState verification_state; -}; - dictionary Device { string user_id; string device_id; @@ -109,12 +110,6 @@ interface UserIdentity { ); }; -dictionary CrossSigningStatus { - boolean has_master; - boolean has_self_signing; - boolean has_user_signing; -}; - dictionary CrossSigningKeyExport { string? master_key; string? self_signing_key; @@ -165,12 +160,12 @@ interface Sas { [Enum] interface SasState { - Started(); - Accepted(); - KeysExchanged(sequence? emojis, sequence decimals); - Confirmed(); - Done(); - Cancelled(CancelInfo cancel_info); + Started(); + Accepted(); + KeysExchanged(sequence? 
emojis, sequence decimals); + Confirmed(); + Done(); + Cancelled(CancelInfo cancel_info); }; callback interface SasListener { @@ -205,12 +200,12 @@ interface QrCode { [Enum] interface QrCodeState { - Started(); - Scanned(); - Confirmed(); - Reciprocated(); - Done(); - Cancelled(CancelInfo cancel_info); + Started(); + Scanned(); + Confirmed(); + Reciprocated(); + Done(); + Cancelled(CancelInfo cancel_info); }; callback interface QrCodeListener { @@ -249,10 +244,10 @@ interface VerificationRequest { [Enum] interface VerificationRequestState { - Requested(); - Ready(sequence their_methods, sequence our_methods); - Done(); - Cancelled(CancelInfo cancel_info); + Requested(); + Ready(sequence their_methods, sequence our_methods); + Done(); + Cancelled(CancelInfo cancel_info); }; callback interface VerificationRequestListener { @@ -317,11 +312,6 @@ enum LocalTrust { "Unset", }; -enum VerificationState { - "Trusted", - "Untrusted", - "UnknownDevice", -}; enum EventEncryptionAlgorithm { "OlmV1Curve25519AesSha2", @@ -335,14 +325,6 @@ enum HistoryVisibility { "WorldReadable", }; -dictionary EncryptionSettings { - EventEncryptionAlgorithm algorithm; - u64 rotation_period; - u64 rotation_period_msgs; - HistoryVisibility history_visibility; - boolean only_allow_trusted_devices; -}; - interface OlmMachine { [Throws=CryptoStoreError] constructor( @@ -352,11 +334,6 @@ interface OlmMachine { string? passphrase ); - [Throws=CryptoStoreError] - string receive_sync_changes([ByRef] string events, - DeviceLists device_changes, - record key_counts, - sequence? 
unused_fallback_keys); [Throws=CryptoStoreError] sequence outgoing_requests(); [Throws=CryptoStoreError] @@ -366,11 +343,6 @@ interface OlmMachine { [ByRef] string response ); - [Throws=DecryptionError] - DecryptedEvent decrypt_room_event([ByRef] string event, [ByRef] string room_id, boolean handle_verificaton_events); - [Throws=CryptoStoreError] - string encrypt([ByRef] string room_id, [ByRef] string event_type, [ByRef] string content); - [Throws=CryptoStoreError] UserIdentity? get_identity([ByRef] string user_id, u32 timeout); [Throws=SignatureError] @@ -384,56 +356,6 @@ interface OlmMachine { [Throws=CryptoStoreError] sequence get_user_devices([ByRef] string user_id, u32 timeout); - [Throws=CryptoStoreError] - boolean is_user_tracked([ByRef] string user_id); - [Throws=CryptoStoreError] - void update_tracked_users(sequence users); - [Throws=CryptoStoreError] - Request? get_missing_sessions(sequence users); - [Throws=CryptoStoreError] - sequence share_room_key( - [ByRef] string room_id, - sequence users, - EncryptionSettings settings - ); - - [Throws=CryptoStoreError] - void receive_unencrypted_verification_event([ByRef] string event, [ByRef] string room_id); - [Throws=CryptoStoreError] - void receive_verification_event([ByRef] string event, [ByRef] string room_id); - sequence get_verification_requests([ByRef] string user_id); - VerificationRequest? get_verification_request([ByRef] string user_id, [ByRef] string flow_id); - Verification? get_verification([ByRef] string user_id, [ByRef] string flow_id); - - [Throws=CryptoStoreError] - VerificationRequest? request_verification( - [ByRef] string user_id, - [ByRef] string room_id, - [ByRef] string event_id, - sequence methods - ); - [Throws=CryptoStoreError] - string? verification_request_content( - [ByRef] string user_id, - sequence methods - ); - [Throws=CryptoStoreError] - RequestVerificationResult? request_self_verification(sequence methods); - [Throws=CryptoStoreError] - RequestVerificationResult? 
request_verification_with_device( - [ByRef] string user_id, - [ByRef] string device_id, - sequence methods - ); - - [Throws=CryptoStoreError] - StartSasResult? start_sas_with_device([ByRef] string user_id, [ByRef] string device_id); - - [Throws=DecryptionError] - KeyRequestPair request_room_key([ByRef] string event, [ByRef] string room_id); - - [Throws=CryptoStoreError] - string export_room_keys([ByRef] string passphrase, i32 rounds); [Throws=KeyImportError] KeysImportResult import_room_keys( [ByRef] string keys, @@ -445,15 +367,7 @@ interface OlmMachine { [ByRef] string keys, ProgressListener progress_listener ); - [Throws=CryptoStoreError] - void discard_room_key([ByRef] string room_id); - CrossSigningStatus cross_signing_status(); - [Throws=CryptoStoreError] - BootstrapCrossSigningResult bootstrap_cross_signing(); - CrossSigningKeyExport? export_cross_signing_keys(); - [Throws=SecretImportError] - void import_cross_signing_keys(CrossSigningKeyExport export); [Throws=CryptoStoreError] boolean is_identity_verified([ByRef] string user_id); @@ -508,6 +422,17 @@ dictionary MigrationData { sequence pickle_key; CrossSigningKeyExport cross_signing; sequence tracked_users; + record room_settings; +}; + +dictionary SessionMigrationData { + string user_id; + string device_id; + string curve25519_key; + string ed25519_key; + sequence sessions; + sequence inbound_group_sessions; + sequence pickle_key; }; dictionary PickledAccount { @@ -535,3 +460,8 @@ dictionary PickledInboundGroupSession { boolean imported; boolean backed_up; }; + +dictionary RoomSettings { + EventEncryptionAlgorithm algorithm; + boolean only_allow_trusted_devices; +}; diff --git a/bindings/matrix-sdk-crypto-ffi/src/responses.rs b/bindings/matrix-sdk-crypto-ffi/src/responses.rs index 24a0c2c147f..935c3e2e1e5 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/responses.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/responses.rs @@ -112,7 +112,7 @@ impl From for OutgoingVerificationRequest { } } -#[derive(Debug, 
uniffi::Enum)] +#[derive(Debug)] pub enum Request { ToDevice { request_id: String, event_type: String, body: String }, KeysUpload { request_id: String, body: String }, @@ -231,6 +231,7 @@ pub enum RequestType { RoomMessage, } +#[derive(uniffi::Record)] pub struct DeviceLists { pub changed: Vec, pub left: Vec, diff --git a/bindings/matrix-sdk-crypto-ffi/src/verification.rs b/bindings/matrix-sdk-crypto-ffi/src/verification.rs index 6eafbb88997..fda5a6989b9 100644 --- a/bindings/matrix-sdk-crypto-ffi/src/verification.rs +++ b/bindings/matrix-sdk-crypto-ffi/src/verification.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use base64::{ - alphabet, decode_engine, encode_engine, - engine::fast_portable::{self, FastPortable}, + alphabet, + engine::{general_purpose, GeneralPurpose}, + Engine, }; use futures_util::{Stream, StreamExt}; use matrix_sdk_crypto::{ @@ -16,8 +17,8 @@ use tokio::runtime::Handle; use crate::{CryptoStoreError, OutgoingVerificationRequest, SignatureUploadRequest}; -const STANDARD_NO_PAD: FastPortable = - FastPortable::from(&alphabet::STANDARD, fast_portable::NO_PAD); +const STANDARD_NO_PAD: GeneralPurpose = + GeneralPurpose::new(&alphabet::STANDARD, general_purpose::NO_PAD); /// Listener that will be passed over the FFI to report changes to a SAS /// verification. @@ -407,7 +408,7 @@ impl QrCode { /// decoded on the other side before it can be put through a QR code /// generator. pub fn generate_qr_code(&self) -> Option { - self.inner.to_bytes().map(|data| encode_engine(data, &STANDARD_NO_PAD)).ok() + self.inner.to_bytes().map(|data| STANDARD_NO_PAD.encode(data)).ok() } /// Set a listener for changes in the QrCode verification process. @@ -709,7 +710,7 @@ impl VerificationRequest { /// * `data` - The data that was extracted from the scanned QR code as an /// base64 encoded string, without padding. 
pub fn scan_qr_code(&self, data: &str) -> Option { - let data = decode_engine(data, &STANDARD_NO_PAD).ok()?; + let data = STANDARD_NO_PAD.decode(data).ok()?; let data = QrVerificationData::from_bytes(data).ok()?; if let Some(qr) = self.runtime.block_on(self.inner.scan_qr_code(data)).ok()? { diff --git a/bindings/matrix-sdk-crypto-js/.prettierignore b/bindings/matrix-sdk-crypto-js/.prettierignore new file mode 100644 index 00000000000..fc5bd908320 --- /dev/null +++ b/bindings/matrix-sdk-crypto-js/.prettierignore @@ -0,0 +1 @@ +/pkg diff --git a/bindings/matrix-sdk-crypto-js/.prettierrc.js b/bindings/matrix-sdk-crypto-js/.prettierrc.js new file mode 100644 index 00000000000..f739c10be90 --- /dev/null +++ b/bindings/matrix-sdk-crypto-js/.prettierrc.js @@ -0,0 +1,9 @@ +// prettier configuration: the same as the conventions used throughout Matrix.org +// see: https://github.com/matrix-org/eslint-plugin-matrix-org/blob/main/.prettierrc.js + +module.exports = { + printWidth: 120, + tabWidth: 4, + quoteProps: "consistent", + trailingComma: "all", +}; diff --git a/bindings/matrix-sdk-crypto-js/README.md b/bindings/matrix-sdk-crypto-js/README.md index 640055c4a3b..33ea75ee557 100644 --- a/bindings/matrix-sdk-crypto-js/README.md +++ b/bindings/matrix-sdk-crypto-js/README.md @@ -49,8 +49,6 @@ $ npm run doc The documentation is generated in the `./docs` directory. 
- - [WebAssembly]: https://webassembly.org/ [`matrix-sdk-crypto`]: https://github.com/matrix-org/matrix-rust-sdk/tree/main/crates/matrix-sdk-crypto [`matrix-rust-sdk`]: https://github.com/matrix-org/matrix-rust-sdk diff --git a/bindings/matrix-sdk-crypto-js/cliff.toml b/bindings/matrix-sdk-crypto-js/cliff.toml deleted file mode 100644 index 26f33b838cc..00000000000 --- a/bindings/matrix-sdk-crypto-js/cliff.toml +++ /dev/null @@ -1,61 +0,0 @@ -# configuration file for git-cliff (0.1.0) - -[changelog] -# changelog header -header = """ -# Matrix SDK Crypto JavaScript Changelog\n -All notable changes to this project will be documented in this file.\n -""" -# template for the changelog body -# https://tera.netlify.app/docs/#introduction -body = """ -{% if version %}\ - ## [{{ version | trim_start_matches(pat="v") }}] - {{ timestamp | date(format="%Y-%m-%d") }} -{% else %}\ - ## [unreleased] -{% endif %}\ -{% for group, commits in commits | filter(attribute="scope", value="crypto-js") | group_by(attribute="group") %} - ### {{ group | upper_first }} - {% for commit in commits %} - - {% if commit.breaking %}[**breaking**] {% endif %}{{ commit.message | upper_first }}\ - {% endfor %} -{% endfor %}\n -""" -# remove the leading and trailing whitespace from the template -trim = true -# changelog footer -footer = """ -""" - -[git] -# parse the commits based on https://www.conventionalcommits.org -conventional_commits = true -# filter out the commits that are not conventional -filter_unconventional = true -# regex for preprocessing the commit messages -commit_preprocessors = [ - { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/matrix-org/matrix-rust-sdk/issues/${2}))"}, -] -# regex for parsing and grouping commits -commit_parsers = [ - { message = "^feat", group = "Features"}, - { message = "^fix", group = "Bug Fixes"}, - { message = "^test", group = "Testing"}, - { message = "^doc", group = "Documentation"}, - { message = "^refactor", group = 
"Refactoring"}, - { message = "^ci", group = "Continuous Integration"}, - { message = "^chore", group = "Miscellaneous Tasks"}, - { body = ".*security", group = "Security"}, -] -# filter out the commits that are not matched by commit parsers -filter_commits = false -# glob pattern for matching git tags -tag_pattern = "v[0-9]*" -# regex for skipping tags -skip_tags = "" -# regex for ignoring tags -ignore_tags = "" -# sort the tags chronologically -date_order = false -# sort the commits inside sections by oldest/newest order -sort_commits = "oldest" diff --git a/bindings/matrix-sdk-crypto-js/package.json b/bindings/matrix-sdk-crypto-js/package.json index a3152425900..47c5bb57c5e 100644 --- a/bindings/matrix-sdk-crypto-js/package.json +++ b/bindings/matrix-sdk-crypto-js/package.json @@ -30,6 +30,7 @@ "cross-env": "^7.0.3", "fake-indexeddb": "^4.0", "jest": "^28.1.0", + "prettier": "^2.8.3", "typedoc": "^0.22.17", "wasm-pack": "^0.10.2", "yargs-parser": "~21.0.1" @@ -38,7 +39,9 @@ "node": ">= 10" }, "scripts": { - "build": "./scripts/build.sh", + "lint": "prettier --check .", + "build": "WASM_PACK_ARGS=--release ./scripts/build.sh", + "build:dev": "WASM_PACK_ARGS=--dev ./scripts/build.sh", "test": "jest --verbose", "doc": "typedoc --tsconfig .", "prepack": "npm run build && npm run test" diff --git a/bindings/matrix-sdk-crypto-js/scripts/build.sh b/bindings/matrix-sdk-crypto-js/scripts/build.sh index 6eb149a9697..f782a2b2e04 100755 --- a/bindings/matrix-sdk-crypto-js/scripts/build.sh +++ b/bindings/matrix-sdk-crypto-js/scripts/build.sh @@ -16,7 +16,7 @@ set -e cd $(dirname "$0")/.. -RUSTFLAGS='-C opt-level=z' WASM_BINDGEN_WEAKREF=1 wasm-pack build --release --target nodejs --scope matrix-org --out-dir pkg +RUSTFLAGS='-C opt-level=z' WASM_BINDGEN_WEAKREF=1 wasm-pack build --target nodejs --scope matrix-org --out-dir pkg "${WASM_PACK_ARGS[@]}" # Convert the Wasm into a JS file that exports the base64'ed Wasm. 
echo "module.exports = \`$(base64 pkg/matrix_sdk_crypto_js_bg.wasm)\`;" > pkg/matrix_sdk_crypto_js_bg.wasm.js diff --git a/bindings/matrix-sdk-crypto-js/scripts/epilogue.js b/bindings/matrix-sdk-crypto-js/scripts/epilogue.js index 8db0d3f098a..f294804b22b 100644 --- a/bindings/matrix-sdk-crypto-js/scripts/epilogue.js +++ b/bindings/matrix-sdk-crypto-js/scripts/epilogue.js @@ -2,12 +2,15 @@ // replace 'wasm' with a reference to the exports from the wasm module. // // Ideally this will never get used because the application will call initAsync instead. -wasm = new Proxy({}, { - get: (target, prop, receiver) => __initSync()[prop], -}); +wasm = new Proxy( + {}, + { + get: (target, prop, receiver) => __initSync()[prop], + }, +); let inited = false; -__initSync = function() { +__initSync = function () { if (inited) { return; } @@ -21,7 +24,7 @@ __initSync = function() { wasm.__wbindgen_start(); inited = true; return wasm; -} +}; let initPromise = null; @@ -39,43 +42,47 @@ module.exports.initAsync = function () { if (!initPromise) { initPromise = Promise.resolve() .then(() => require("./matrix_sdk_crypto_js_bg.wasm.js")) - .then(b64 => WebAssembly.instantiate(unbase64(b64), imports)) - .then(result => { + .then((b64) => WebAssembly.instantiate(unbase64(b64), imports)) + .then((result) => { wasm = result.instance.exports; wasm.__wbindgen_start(); inited = true; }); } return initPromise; -} +}; -const b64lookup = new Uint8Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 62, 0, 62, 0, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0, 0, 0, 0, 63, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51]); +const b64lookup = new Uint8Array([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 62, 0, 62, 0, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0, 0, 0, 0, 63, 0, 26, 27, 28, 29, 30, 31, 32, + 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, +]); // base64 decoder, based on the code at https://developer.mozilla.org/en-US/docs/Glossary/Base64#solution_2_%E2%80%93_rewriting_atob_and_btoa_using_typedarrays_and_utf-8 -function unbase64(sBase64) { - const sB64Enc = sBase64.replace(/[^A-Za-z0-9+/]/g, ""); - const nInLen = sB64Enc.length; - const nOutLen = (nInLen * 3 + 1) >> 2; - const taBytes = new Uint8Array(nOutLen); +function unbase64(sBase64) { + const sB64Enc = sBase64.replace(/[^A-Za-z0-9+/]/g, ""); + const nInLen = sB64Enc.length; + const nOutLen = (nInLen * 3 + 1) >> 2; + const taBytes = new Uint8Array(nOutLen); - let nMod3; - let nMod4; - let nUint24 = 0; - let nOutIdx = 0; - for (let nInIdx = 0; nInIdx < nInLen; nInIdx++) { - nMod4 = nInIdx & 3; - nUint24 |= b64lookup[sB64Enc.charCodeAt(nInIdx)] << (6 * (3 - nMod4)); - if (nMod4 === 3 || nInLen - nInIdx === 1) { - nMod3 = 0; - while (nMod3 < 3 && nOutIdx < nOutLen) { - taBytes[nOutIdx] = (nUint24 >>> ((16 >>> nMod3) & 24)) & 255; - nMod3++; - nOutIdx++; - } - nUint24 = 0; + let nMod3; + let nMod4; + let nUint24 = 0; + let nOutIdx = 0; + for (let nInIdx = 0; nInIdx < nInLen; nInIdx++) { + nMod4 = nInIdx & 3; + nUint24 |= b64lookup[sB64Enc.charCodeAt(nInIdx)] << (6 * (3 - nMod4)); + if (nMod4 === 3 || nInLen - nInIdx === 1) { + nMod3 = 0; + while (nMod3 < 3 && nOutIdx < nOutLen) { + taBytes[nOutIdx] = (nUint24 >>> ((16 >>> nMod3) & 24)) & 255; + nMod3++; + nOutIdx++; + } + nUint24 = 0; + } } - } - - return taBytes; -}; + return taBytes; +} diff --git a/bindings/matrix-sdk-crypto-js/src/encryption.rs b/bindings/matrix-sdk-crypto-js/src/encryption.rs index aa4d1c65a73..72cc1e4afc8 100644 --- 
a/bindings/matrix-sdk-crypto-js/src/encryption.rs +++ b/bindings/matrix-sdk-crypto-js/src/encryption.rs @@ -2,6 +2,7 @@ use std::time::Duration; +use matrix_sdk_common::deserialized_responses::ShieldState as RustShieldState; use wasm_bindgen::prelude::*; use crate::events; @@ -108,28 +109,48 @@ impl From for EncryptionAlgo } } -/// The verification state of the device that sent an event to us. +/// Take a look at [`matrix_sdk_common::deserialized_responses::ShieldState`] +/// for more info. #[wasm_bindgen] -#[derive(Debug)] -pub enum VerificationState { - /// The device is trusted. - Trusted, - - /// The device is not trusted. - Untrusted, +#[derive(Debug, Clone, Copy)] +pub enum ShieldColor { + /// Important warning + Red, + /// Low warning + Grey, + /// No warning + None, +} - /// The device is not known to us. - UnknownDevice, +/// Take a look at [`matrix_sdk_common::deserialized_responses::ShieldState`] +/// for more info. +#[wasm_bindgen] +#[derive(Debug, Clone)] +pub struct ShieldState { + /// The shield color + pub color: ShieldColor, + message: Option, } -impl From<&matrix_sdk_common::deserialized_responses::VerificationState> for VerificationState { - fn from(value: &matrix_sdk_common::deserialized_responses::VerificationState) -> Self { - use matrix_sdk_common::deserialized_responses::VerificationState::*; +#[wasm_bindgen] +impl ShieldState { + /// Error message that can be displayed as a tooltip + #[wasm_bindgen(getter)] + pub fn message(&self) -> Option { + self.message.clone() + } +} +impl From for ShieldState { + fn from(value: RustShieldState) -> Self { match value { - Trusted => Self::Trusted, - Untrusted => Self::Untrusted, - UnknownDevice => Self::UnknownDevice, + RustShieldState::Red { message } => { + Self { color: ShieldColor::Red, message: Some(message.to_owned()) } + } + RustShieldState::Grey { message } => { + Self { color: ShieldColor::Grey, message: Some(message.to_owned()) } + } + RustShieldState::None => Self { color: ShieldColor::None, 
message: None }, } } } diff --git a/bindings/matrix-sdk-crypto-js/src/lib.rs b/bindings/matrix-sdk-crypto-js/src/lib.rs index 65a70998ae4..afbacff004d 100644 --- a/bindings/matrix-sdk-crypto-js/src/lib.rs +++ b/bindings/matrix-sdk-crypto-js/src/lib.rs @@ -15,7 +15,8 @@ #![doc = include_str!("../README.md")] #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![warn(missing_docs, missing_debug_implementations)] -#![allow(clippy::drop_non_drop)] // triggered by wasm_bindgen code +// triggered by wasm_bindgen code +#![allow(clippy::drop_non_drop)] pub mod attachment; pub mod device; diff --git a/bindings/matrix-sdk-crypto-js/src/machine.rs b/bindings/matrix-sdk-crypto-js/src/machine.rs index 66b06e04d49..258c24b22dd 100644 --- a/bindings/matrix-sdk-crypto-js/src/machine.rs +++ b/bindings/matrix-sdk-crypto-js/src/machine.rs @@ -1,6 +1,6 @@ //! The crypto specific Olm objects. -use std::collections::BTreeMap; +use std::{collections::BTreeMap, ops::Deref}; use js_sys::{Array, Function, Map, Promise, Set}; use ruma::{serde::Raw, DeviceKeyAlgorithm, OwnedTransactionId, UInt}; @@ -13,7 +13,7 @@ use crate::{ identifiers, identities, js::downcast, olm, requests, - requests::OutgoingRequest, + requests::{OutgoingRequest, ToDeviceRequest}, responses::{self, response_from_string}, store, sync_events, types, verification, vodozemac, }; @@ -70,16 +70,7 @@ impl OlmMachine { future_to_promise(async move { let store = match (store_name, store_passphrase) { - // We need this `#[cfg]` because `IndexeddbCryptoStore` - // implements `CryptoStore` only on `target_arch = - // "wasm32"`. Without that, we could have a compilation - // error when checking the entire workspace. In - // practise, it doesn't impact this crate because it's - // always compiled for `wasm32`. 
- #[cfg(target_arch = "wasm32")] (Some(store_name), Some(mut store_passphrase)) => { - use std::sync::Arc; - use zeroize::Zeroize; let store = Some( @@ -87,8 +78,7 @@ impl OlmMachine { &store_name, &store_passphrase, ) - .await - .map(Arc::new)?, + .await?, ); store_passphrase.zeroize(); @@ -96,15 +86,9 @@ impl OlmMachine { store } - #[cfg(target_arch = "wasm32")] - (Some(store_name), None) => { - use std::sync::Arc; - Some( - matrix_sdk_indexeddb::IndexeddbCryptoStore::open_with_name(&store_name) - .await - .map(Arc::new)?, - ) - } + (Some(store_name), None) => Some( + matrix_sdk_indexeddb::IndexeddbCryptoStore::open_with_name(&store_name).await?, + ), (None, Some(_)) => { return Err(anyhow::Error::msg( @@ -113,11 +97,18 @@ impl OlmMachine { )) } - _ => None, + (None, None) => None, }; Ok(OlmMachine { inner: match store { + // We need this `#[cfg]` because `IndexeddbCryptoStore` + // implements `CryptoStore` only on `target_arch = + // "wasm32"`. Without that, we could have a compilation + // error when checking the entire workspace. In practice, + // it doesn't impact this crate because it's always + // compiled for `wasm32`. + #[cfg(target_arch = "wasm32")] Some(store) => { matrix_sdk_crypto::OlmMachine::with_store( user_id.as_ref(), @@ -126,7 +117,7 @@ impl OlmMachine { ) .await? } - None => { + _ => { matrix_sdk_crypto::OlmMachine::new(user_id.as_ref(), device_id.as_ref()) .await } @@ -491,6 +482,8 @@ impl OlmMachine { /// `room_id` is the room ID. `users` is an array of `UserId` /// objects. `encryption_settings` are an `EncryptionSettings` /// object. + /// + /// Returns an array of `ToDeviceRequest`s. #[wasm_bindgen(js_name = "shareRoomKey")] pub fn share_room_key( &self, @@ -509,10 +502,19 @@ impl OlmMachine { let me = self.inner.clone(); Ok(future_to_promise(async move { - Ok(serde_json::to_string( - &me.share_room_key(&room_id, users.iter().map(AsRef::as_ref), encryption_settings) - .await?, - )?) 
+ let to_device_requests = me + .share_room_key(&room_id, users.iter().map(AsRef::as_ref), encryption_settings) + .await?; + + // convert each request to our own ToDeviceRequest struct, and then wrap it in a + // JsValue. + // + // Then collect the results into a javascript Array, throwing any errors into + // the promise. + Ok(to_device_requests + .into_iter() + .map(|td| ToDeviceRequest::try_from(td.deref()).map(JsValue::from)) + .collect::>()?) })) } diff --git a/bindings/matrix-sdk-crypto-js/src/requests.rs b/bindings/matrix-sdk-crypto-js/src/requests.rs index 00b51731d29..ba50b744000 100644 --- a/bindings/matrix-sdk-crypto-js/src/requests.rs +++ b/bindings/matrix-sdk-crypto-js/src/requests.rs @@ -239,11 +239,11 @@ pub struct RoomMessageRequest { #[wasm_bindgen(readonly)] pub txn_id: JsString, - /// A string representing the type of even from the message's content. + /// A string representing the type of event to be sent. #[wasm_bindgen(readonly)] pub event_type: JsString, - /// A JSON-encoded string containing the message's body. + /// A JSON-encoded string containing the message's content. #[wasm_bindgen(readonly, js_name = "body")] pub content: JsString, } diff --git a/bindings/matrix-sdk-crypto-js/src/responses.rs b/bindings/matrix-sdk-crypto-js/src/responses.rs index 24d51cd48dd..2602a43d47d 100644 --- a/bindings/matrix-sdk-crypto-js/src/responses.rs +++ b/bindings/matrix-sdk-crypto-js/src/responses.rs @@ -1,7 +1,5 @@ //! Types related to responses. -use std::borrow::Borrow; - use js_sys::{Array, JsString}; use matrix_sdk_common::deserialized_responses::{AlgorithmInfo, EncryptionInfo}; use matrix_sdk_crypto::IncomingResponse; @@ -190,9 +188,15 @@ impl DecryptedRoomEvent { /// note this is the state of the device at the time of /// decryption. It may change in the future if a device gets /// verified or deleted. 
- #[wasm_bindgen(getter, js_name = "verificationState")] - pub fn verification_state(&self) -> Option { - Some((self.encryption_info.as_ref()?.verification_state.borrow()).into()) + #[wasm_bindgen(js_name = "shieldState")] + pub fn shield_state(&self, strict: bool) -> Option { + let state = &self.encryption_info.as_ref()?.verification_state; + + if strict { + Some(state.to_shield_state_strict().into()) + } else { + Some(state.to_shield_state_lax().into()) + } } } diff --git a/bindings/matrix-sdk-crypto-js/src/tracing.rs b/bindings/matrix-sdk-crypto-js/src/tracing.rs index 3d5ca7f9f02..b7a516ac1c5 100644 --- a/bindings/matrix-sdk-crypto-js/src/tracing.rs +++ b/bindings/matrix-sdk-crypto-js/src/tracing.rs @@ -145,9 +145,6 @@ mod inner { #[wasm_bindgen] extern "C" { - #[wasm_bindgen(js_namespace = console, js_name = "trace")] - fn log_trace(message: String); - #[wasm_bindgen(js_namespace = console, js_name = "debug")] fn log_debug(message: String); @@ -213,7 +210,7 @@ mod inner { let message = format!("{level} {origin}{recorder}"); match *level { - Level::TRACE => log_trace(message), + Level::TRACE => log_debug(message), Level::DEBUG => log_debug(message), Level::INFO => log_info(message), Level::WARN => log_warn(message), diff --git a/bindings/matrix-sdk-crypto-js/tests/asyncload.test.js b/bindings/matrix-sdk-crypto-js/tests/asyncload.test.js index 6e6d0879720..fe0dea526ca 100644 --- a/bindings/matrix-sdk-crypto-js/tests/asyncload.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/asyncload.test.js @@ -1,10 +1,12 @@ const { UserId, initAsync } = require("../pkg/matrix_sdk_crypto_js"); -test('can instantiate rust objects with async initialiser', async () => { - initUserId = () => new UserId('@foo:bar.org'); +test("can instantiate rust objects with async initialiser", async () => { + initUserId = () => new UserId("@foo:bar.org"); // stub out the synchronous WebAssembly loader with one that raises an error - jest.spyOn(WebAssembly, 'Module').mockImplementation(() => { 
throw new Error('synchronous WebAssembly.Module() not allowed')}); + jest.spyOn(WebAssembly, "Module").mockImplementation(() => { + throw new Error("synchronous WebAssembly.Module() not allowed"); + }); // this should fail expect(initUserId).toThrow(/synchronous/); diff --git a/bindings/matrix-sdk-crypto-js/tests/attachment.test.js b/bindings/matrix-sdk-crypto-js/tests/attachment.test.js index a614ef0dc41..20b4df8b580 100644 --- a/bindings/matrix-sdk-crypto-js/tests/attachment.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/attachment.test.js @@ -1,37 +1,41 @@ -const { Attachment, EncryptedAttachment } = require('../pkg/matrix_sdk_crypto_js'); +const { Attachment, EncryptedAttachment } = require("../pkg/matrix_sdk_crypto_js"); describe(Attachment.name, () => { - const originalData = 'hello'; + const originalData = "hello"; const textEncoder = new TextEncoder(); const textDecoder = new TextDecoder(); let encryptedAttachment; - test('can encrypt data', () => { + test("can encrypt data", () => { encryptedAttachment = Attachment.encrypt(textEncoder.encode(originalData)); const mediaEncryptionInfo = JSON.parse(encryptedAttachment.mediaEncryptionInfo); expect(mediaEncryptionInfo).toMatchObject({ - v: 'v2', + v: "v2", key: { kty: expect.any(String), - key_ops: expect.arrayContaining(['encrypt', 'decrypt']), + key_ops: expect.arrayContaining(["encrypt", "decrypt"]), alg: expect.any(String), k: expect.any(String), ext: expect.any(Boolean), }, iv: expect.stringMatching(/^[A-Za-z0-9\+/]+$/), hashes: { - sha256: expect.stringMatching(/^[A-Za-z0-9\+/]+$/) - } + sha256: expect.stringMatching(/^[A-Za-z0-9\+/]+$/), + }, }); const encryptedData = encryptedAttachment.encryptedData; - expect(encryptedData.every((i) => { i != 0 })).toStrictEqual(false); + expect( + encryptedData.every((i) => { + i != 0; + }), + ).toStrictEqual(false); }); - test('can decrypt data', () => { + test("can decrypt data", () => { 
expect(encryptedAttachment.hasMediaEncryptionInfoBeenConsumed).toStrictEqual(false); const decryptedAttachment = Attachment.decrypt(encryptedAttachment); @@ -40,34 +44,36 @@ describe(Attachment.name, () => { expect(encryptedAttachment.hasMediaEncryptionInfoBeenConsumed).toStrictEqual(true); }); - test('can only decrypt once', () => { + test("can only decrypt once", () => { expect(encryptedAttachment.hasMediaEncryptionInfoBeenConsumed).toStrictEqual(true); - expect(() => { textDecoder.decode(decryptedAttachment) }).toThrow() + expect(() => { + textDecoder.decode(decryptedAttachment); + }).toThrow(); }); }); describe(EncryptedAttachment.name, () => { - const originalData = 'hello'; + const originalData = "hello"; const textDecoder = new TextDecoder(); - test('can be created manually', () => { + test("can be created manually", () => { const encryptedAttachment = new EncryptedAttachment( new Uint8Array([24, 150, 67, 37, 144]), JSON.stringify({ - v: 'v2', + v: "v2", key: { - kty: 'oct', - key_ops: [ 'encrypt', 'decrypt' ], - alg: 'A256CTR', - k: 'QbNXUjuukFyEJ8cQZjJuzN6mMokg0HJIjx0wVMLf5BM', - ext: true + kty: "oct", + key_ops: ["encrypt", "decrypt"], + alg: "A256CTR", + k: "QbNXUjuukFyEJ8cQZjJuzN6mMokg0HJIjx0wVMLf5BM", + ext: true, }, - iv: 'xk2AcWkomiYAAAAAAAAAAA', + iv: "xk2AcWkomiYAAAAAAAAAAA", hashes: { - sha256: 'JsRbDXgOja4xvDiF3DwBuLHdxUzIrVYIuj7W/t3aEok' - } - }) + sha256: "JsRbDXgOja4xvDiF3DwBuLHdxUzIrVYIuj7W/t3aEok", + }, + }), ); expect(encryptedAttachment.hasMediaEncryptionInfoBeenConsumed).toStrictEqual(false); diff --git a/bindings/matrix-sdk-crypto-js/tests/device.test.js b/bindings/matrix-sdk-crypto-js/tests/device.test.js index 14ecc205195..2d19d910c87 100644 --- a/bindings/matrix-sdk-crypto-js/tests/device.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/device.test.js @@ -27,11 +27,11 @@ const { Qr, QrCode, QrCodeScan, -} = require('../pkg/matrix_sdk_crypto_js'); -const { zip, addMachineToMachine } = require('./helper'); +} = 
require("../pkg/matrix_sdk_crypto_js"); +const { zip, addMachineToMachine } = require("./helper"); -describe('LocalTrust', () => { - test('has the correct variant values', () => { +describe("LocalTrust", () => { + test("has the correct variant values", () => { expect(LocalTrust.Verified).toStrictEqual(0); expect(LocalTrust.BlackListed).toStrictEqual(1); expect(LocalTrust.Ignored).toStrictEqual(2); @@ -39,8 +39,8 @@ describe('LocalTrust', () => { }); }); -describe('DeviceKeyName', () => { - test('has the correct variant values', () => { +describe("DeviceKeyName", () => { + test("has the correct variant values", () => { expect(DeviceKeyName.Curve25519).toStrictEqual(0); expect(DeviceKeyName.Ed25519).toStrictEqual(1); expect(DeviceKeyName.Unknown).toStrictEqual(2); @@ -48,28 +48,58 @@ describe('DeviceKeyName', () => { }); describe(OlmMachine.name, () => { - const user = new UserId('@alice:example.org'); - const device = new DeviceId('foobar'); - const room = new RoomId('!baz:matrix.org'); + const user = new UserId("@alice:example.org"); + const device = new DeviceId("foobar"); + const room = new RoomId("!baz:matrix.org"); function machine(new_user, new_device) { return OlmMachine.initialize(new_user || user, new_device || device); } - test('can read user devices', async () => { + test("can read user devices", async () => { const m = await machine(); const userDevices = await m.getUserDevices(user); expect(userDevices).toBeInstanceOf(UserDevices); expect(userDevices.get(device)).toBeInstanceOf(Device); expect(userDevices.isAnyVerified()).toStrictEqual(false); - expect(userDevices.keys().map(device_id => device_id.toString())).toStrictEqual([device.toString()]); - expect(userDevices.devices().map(device => device.deviceId.toString())).toStrictEqual([device.toString()]); + expect(userDevices.keys().map((device_id) => device_id.toString())).toStrictEqual([device.toString()]); + expect(userDevices.devices().map((device) => 
device.deviceId.toString())).toStrictEqual([device.toString()]); }); - test('can read a user device', async () => { + test("can read a user device", async () => { const m = await machine(); - const dev = await m.getDevice(user, device); + + const hypothetical_response = JSON.stringify({ + device_keys: { + "@alice:example.org": { + JLAFKJWSCS: { + algorithms: ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + device_id: "JLAFKJWSCS", + keys: { + "curve25519:JLAFKJWSCS": "wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4", + "ed25519:JLAFKJWSCS": "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM", + }, + signatures: { + "@alice:example.org": { + "ed25519:JLAFKJWSCS": + "m53Wkbh2HXkc3vFApZvCrfXcX3AI51GsDHustMhKwlv3TuOJMj4wistcOTM8q2+e/Ro7rWFUb9ZfnNbwptSUBA", + }, + }, + unsigned: { + device_display_name: "Alice's mobile phone", + }, + user_id: "@alice:example.org", + }, + }, + }, + failures: {}, + }); + // Insert another device into the store + await m.markRequestAsSent("ID", RequestType.KeysQuery, hypothetical_response); + + const secondDeviceId = new DeviceId("JLAFKJWSCS"); + const dev = await m.getDevice(user, secondDeviceId); expect(dev).toBeInstanceOf(Device); expect(dev.isVerified()).toStrictEqual(false); @@ -82,7 +112,7 @@ describe(OlmMachine.name, () => { expect(dev.isLocallyTrusted()).toStrictEqual(true); expect(dev.userId.toString()).toStrictEqual(user.toString()); - expect(dev.deviceId.toString()).toStrictEqual(device.toString()); + expect(dev.deviceId.toString()).toStrictEqual(secondDeviceId.toString()); expect(dev.deviceName).toBeUndefined(); const deviceKey = dev.getKey(DeviceKeyAlgorithmName.Ed25519); @@ -108,18 +138,18 @@ describe(OlmMachine.name, () => { }); }); -describe('Key Verification', () => { - const userId1 = new UserId('@alice:example.org'); - const deviceId1 = new DeviceId('alice_device'); +describe("Key Verification", () => { + const userId1 = new UserId("@alice:example.org"); + const deviceId1 = new DeviceId("alice_device"); - const userId2 
= new UserId('@bob:example.org'); - const deviceId2 = new DeviceId('bob_device'); + const userId2 = new UserId("@bob:example.org"); + const deviceId2 = new DeviceId("bob_device"); function machine(new_user, new_device) { return OlmMachine.initialize(new_user || userId1, new_device || deviceId1); } - describe('SAS', () => { + describe("SAS", () => { // First Olm machine. let m1; @@ -137,7 +167,7 @@ describe('Key Verification', () => { // The flow ID. let flowId; - test('can request verification (`m.key.verification.request`)', async () => { + test("can request verification (`m.key.verification.request`)", async () => { // Make `m1` and `m2` be aware of each other. { await addMachineToMachine(m2, m1); @@ -164,7 +194,9 @@ describe('Key Verification', () => { expect(verificationRequest1.isReady()).toStrictEqual(false); expect(verificationRequest1.timedOut()).toStrictEqual(false); expect(verificationRequest1.theirSupportedMethods).toBeUndefined(); - expect(verificationRequest1.ourSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1])); + expect(verificationRequest1.ourSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1]), + ); expect(verificationRequest1.flowId).toMatch(/^[a-f0-9]+$/); expect(verificationRequest1.isSelfVerification()).toStrictEqual(false); expect(verificationRequest1.weStarted()).toStrictEqual(true); @@ -172,13 +204,17 @@ describe('Key Verification', () => { expect(verificationRequest1.isCancelled()).toStrictEqual(false); expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.request'); + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.request"); - const toDeviceEvents = [{ - sender: userId1.toString(), - type: outgoingVerificationRequest.event_type, - content: 
JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][deviceId2.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId1.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][ + deviceId2.toString() + ], + }, + ]; // Let's send the verification request to `m2`. await m2.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); @@ -189,7 +225,7 @@ describe('Key Verification', () => { // Verification request for `m2`. let verificationRequest2; - test('can fetch received request verification', async () => { + test("can fetch received request verification", async () => { // Oh, a new verification request. verificationRequest2 = m2.getVerificationRequest(userId1, flowId); @@ -203,7 +239,9 @@ describe('Key Verification', () => { expect(verificationRequest2.isPassive()).toStrictEqual(false); expect(verificationRequest2.isReady()).toStrictEqual(false); expect(verificationRequest2.timedOut()).toStrictEqual(false); - expect(verificationRequest2.theirSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1])); + expect(verificationRequest2.theirSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1]), + ); expect(verificationRequest2.ourSupportedMethods).toBeUndefined(); expect(verificationRequest2.flowId).toStrictEqual(flowId); expect(verificationRequest2.isSelfVerification()).toStrictEqual(false); @@ -216,40 +254,52 @@ describe('Key Verification', () => { expect(verificationRequests[0].flowId).toStrictEqual(verificationRequest2.flowId); // there are the same }); - test('can accept a verification request (`m.key.verification.ready`)', async () => { + test("can accept a verification request (`m.key.verification.ready`)", async () => { // Accept the verification request. 
let outgoingVerificationRequest = verificationRequest2.accept(); expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); // The request verification is ready. - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.ready'); + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.ready"); - const toDeviceEvents = [{ - sender: userId2.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][deviceId1.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId2.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][ + deviceId1.toString() + ], + }, + ]; // Let's send the verification ready to `m1`. await m1.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); }); - test('verification requests are synchronized and automatically updated', () => { + test("verification requests are synchronized and automatically updated", () => { expect(verificationRequest1.isReady()).toStrictEqual(true); expect(verificationRequest2.isReady()).toStrictEqual(true); - expect(verificationRequest1.theirSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1])); - expect(verificationRequest1.ourSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1])); - - expect(verificationRequest2.theirSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1])); - expect(verificationRequest2.ourSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1])); + expect(verificationRequest1.theirSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1]), + ); + 
expect(verificationRequest1.ourSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1]), + ); + + expect(verificationRequest2.theirSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1]), + ); + expect(verificationRequest2.ourSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.SasV1, VerificationMethod.ReciprocateV1]), + ); }); // SAS verification for the second machine. let sas2; - test('can start a SAS verification (`m.key.verification.start`)', async () => { + test("can start a SAS verification (`m.key.verification.start`)", async () => { // Let's start a SAS verification, from `m2` for example. [sas2, outgoingVerificationRequest] = await verificationRequest2.startSas(); expect(sas2).toBeInstanceOf(Sas); @@ -276,13 +326,17 @@ describe('Key Verification', () => { expect(sas2.decimals()).toBeUndefined(); expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.start'); + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.start"); - const toDeviceEvents = [{ - sender: userId2.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][deviceId1.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId2.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][ + deviceId1.toString() + ], + }, + ]; // Let's send the SAS start to `m1`. await m1.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); @@ -291,7 +345,7 @@ describe('Key Verification', () => { // SAS verification for the second machine. 
let sas1; - test('can fetch and accept an ongoing SAS verification (`m.key.verification.accept`)', async () => { + test("can fetch and accept an ongoing SAS verification (`m.key.verification.accept`)", async () => { // Let's fetch the ongoing SAS verification. sas1 = await m1.getVerification(userId2, flowId); @@ -321,64 +375,72 @@ describe('Key Verification', () => { let outgoingVerificationRequest = sas1.accept(); expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.accept'); + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.accept"); - const toDeviceEvents = [{ - sender: userId1.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][deviceId2.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId1.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][ + deviceId2.toString() + ], + }, + ]; // Let's send the SAS accept to `m2`. await m2.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); }); - test('emojis are supported by both sides', () => { + test("emojis are supported by both sides", () => { expect(sas1.supportsEmoji()).toStrictEqual(true); expect(sas2.supportsEmoji()).toStrictEqual(true); }); - test('one side sends verification key (`m.key.verification.key`)', async () => { + test("one side sends verification key (`m.key.verification.key`)", async () => { // Let's send the verification keys from `m2` to `m1`. 
const outgoingRequests = await m2.outgoingRequests(); let toDeviceRequest = outgoingRequests.find((request) => request.type == RequestType.ToDevice); expect(toDeviceRequest).toBeInstanceOf(ToDeviceRequest); - expect(toDeviceRequest.event_type).toStrictEqual('m.key.verification.key'); + expect(toDeviceRequest.event_type).toStrictEqual("m.key.verification.key"); - const toDeviceEvents = [{ - sender: userId2.toString(), - type: toDeviceRequest.event_type, - content: JSON.parse(toDeviceRequest.body).messages[userId1.toString()][deviceId1.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId2.toString(), + type: toDeviceRequest.event_type, + content: JSON.parse(toDeviceRequest.body).messages[userId1.toString()][deviceId1.toString()], + }, + ]; // Let's send te SAS key to `m1`. await m1.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); - m2.markRequestAsSent(toDeviceRequest.id, toDeviceRequest.type, '{}'); + m2.markRequestAsSent(toDeviceRequest.id, toDeviceRequest.type, "{}"); }); - test('other side sends back verification key (`m.key.verification.key`)', async () => { + test("other side sends back verification key (`m.key.verification.key`)", async () => { // Let's send the verification keys from `m1` to `m2`. 
const outgoingRequests = await m1.outgoingRequests(); let toDeviceRequest = outgoingRequests.find((request) => request.type == RequestType.ToDevice); expect(toDeviceRequest).toBeInstanceOf(ToDeviceRequest); - expect(toDeviceRequest.event_type).toStrictEqual('m.key.verification.key'); + expect(toDeviceRequest.event_type).toStrictEqual("m.key.verification.key"); - const toDeviceEvents = [{ - sender: userId1.toString(), - type: toDeviceRequest.event_type, - content: JSON.parse(toDeviceRequest.body).messages[userId2.toString()][deviceId2.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId1.toString(), + type: toDeviceRequest.event_type, + content: JSON.parse(toDeviceRequest.body).messages[userId2.toString()][deviceId2.toString()], + }, + ]; // Let's send te SAS key to `m2`. await m2.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); - m1.markRequestAsSent(toDeviceRequest.id, toDeviceRequest.type, '{}'); + m1.markRequestAsSent(toDeviceRequest.id, toDeviceRequest.type, "{}"); }); - test('emojis match from both sides', () => { + test("emojis match from both sides", () => { const emojis1 = sas1.emoji(); const emojiIndexes1 = sas1.emojiIndex(); const emojis2 = sas2.emoji(); @@ -389,9 +451,15 @@ describe('Key Verification', () => { expect(emojis2).toHaveLength(emojis1.length); expect(emojiIndexes2).toHaveLength(emojis1.length); - const isEmoji = /(\u00a9|\u00ae|[\u2000-\u3300]|\ud83c[\ud000-\udfff]|\ud83d[\ud000-\udfff]|\ud83e[\ud000-\udfff])/; + const isEmoji = + /(\u00a9|\u00ae|[\u2000-\u3300]|\ud83c[\ud000-\udfff]|\ud83d[\ud000-\udfff]|\ud83e[\ud000-\udfff])/; - for (const [emoji1, emojiIndex1, emoji2, emojiIndex2] of zip(emojis1, emojiIndexes1, emojis2, emojiIndexes2)) { + for (const [emoji1, emojiIndex1, emoji2, emojiIndex2] of zip( + emojis1, + emojiIndexes1, + emojis2, + emojiIndexes2, + )) { expect(emoji1).toBeInstanceOf(Emoji); expect(emoji1.symbol).toMatch(isEmoji); expect(emoji1.description).toBeTruthy(); @@ 
-407,7 +475,7 @@ describe('Key Verification', () => { } }); - test('decimals match from both sides', () => { + test("decimals match from both sides", () => { const decimals1 = sas1.decimals(); const decimals2 = sas2.decimals(); @@ -423,7 +491,7 @@ describe('Key Verification', () => { } }); - test('can confirm keys match (`m.key.verification.mac`)', async () => { + test("can confirm keys match (`m.key.verification.mac`)", async () => { // `m1` confirms. const [outgoingVerificationRequests, signatureUploadRequest] = await sas1.confirm(); @@ -433,19 +501,23 @@ describe('Key Verification', () => { let outgoingVerificationRequest = outgoingVerificationRequests[0]; expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.mac'); + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.mac"); - const toDeviceEvents = [{ - sender: userId1.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][deviceId2.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId1.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][ + deviceId2.toString() + ], + }, + ]; // Let's send te SAS confirmation to `m2`. await m2.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); }); - test('can confirm back keys match (`m.key.verification.done`)', async () => { + test("can confirm back keys match (`m.key.verification.done`)", async () => { // `m2` confirms. 
const [outgoingVerificationRequests, signatureUploadRequest] = await sas2.confirm(); @@ -457,13 +529,17 @@ describe('Key Verification', () => { let outgoingVerificationRequest = outgoingVerificationRequests[0]; expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.mac'); - - const toDeviceEvents = [{ - sender: userId2.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][deviceId1.toString()], - }]; + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.mac"); + + const toDeviceEvents = [ + { + sender: userId2.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][ + deviceId1.toString() + ], + }, + ]; // Let's send te SAS confirmation to `m1`. await m1.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); @@ -474,41 +550,47 @@ describe('Key Verification', () => { let outgoingVerificationRequest = outgoingVerificationRequests[1]; expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.done'); - - const toDeviceEvents = [{ - sender: userId2.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][deviceId1.toString()], - }]; + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.done"); + + const toDeviceEvents = [ + { + sender: userId2.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][ + deviceId1.toString() + ], + }, + ]; // Let's send te SAS done to `m1`. 
await m1.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); } }); - test('can send final done (`m.key.verification.done`)', async () => { + test("can send final done (`m.key.verification.done`)", async () => { const outgoingRequests = await m1.outgoingRequests(); expect(outgoingRequests).toHaveLength(4); let toDeviceRequest = outgoingRequests.find((request) => request.type == RequestType.ToDevice); expect(toDeviceRequest).toBeInstanceOf(ToDeviceRequest); - expect(toDeviceRequest.event_type).toStrictEqual('m.key.verification.done'); + expect(toDeviceRequest.event_type).toStrictEqual("m.key.verification.done"); - const toDeviceEvents = [{ - sender: userId1.toString(), - type: toDeviceRequest.event_type, - content: JSON.parse(toDeviceRequest.body).messages[userId2.toString()][deviceId2.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId1.toString(), + type: toDeviceRequest.event_type, + content: JSON.parse(toDeviceRequest.body).messages[userId2.toString()][deviceId2.toString()], + }, + ]; // Let's send te SAS key to `m2`. await m2.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); - m1.markRequestAsSent(toDeviceRequest.id, toDeviceRequest.type, '{}'); + m1.markRequestAsSent(toDeviceRequest.id, toDeviceRequest.type, "{}"); }); - test('can see if verification is done', () => { + test("can see if verification is done", () => { expect(verificationRequest1.isDone()).toStrictEqual(true); expect(verificationRequest2.isDone()).toStrictEqual(true); @@ -517,10 +599,10 @@ describe('Key Verification', () => { }); }); - describe('QR Code', () => { + describe("QR Code", () => { if (undefined === Qr) { // qrcode supports is not enabled - console.info('qrcode support is disabled, skip the associated test suite'); + console.info("qrcode support is disabled, skip the associated test suite"); return; } @@ -542,7 +624,7 @@ describe('Key Verification', () => { // The flow ID. 
let flowId; - test('can request verification (`m.key.verification.request`)', async () => { + test("can request verification (`m.key.verification.request`)", async () => { // Make `m1` and `m2` be aware of each other. { await addMachineToMachine(m2, m1); @@ -572,7 +654,9 @@ describe('Key Verification', () => { expect(verificationRequest1.isReady()).toStrictEqual(false); expect(verificationRequest1.timedOut()).toStrictEqual(false); expect(verificationRequest1.theirSupportedMethods).toBeUndefined(); - expect(verificationRequest1.ourSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.QrCodeShowV1])); + expect(verificationRequest1.ourSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.QrCodeShowV1]), + ); expect(verificationRequest1.flowId).toMatch(/^[a-f0-9]+$/); expect(verificationRequest1.isSelfVerification()).toStrictEqual(false); expect(verificationRequest1.weStarted()).toStrictEqual(true); @@ -580,13 +664,17 @@ describe('Key Verification', () => { expect(verificationRequest1.isCancelled()).toStrictEqual(false); expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.request'); + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.request"); - const toDeviceEvents = [{ - sender: userId1.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][deviceId2.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId1.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][ + deviceId2.toString() + ], + }, + ]; // Let's send the verification request to `m2`. 
await m2.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); @@ -597,7 +685,7 @@ describe('Key Verification', () => { // Verification request for `m2`. let verificationRequest2; - test('can fetch received request verification', async () => { + test("can fetch received request verification", async () => { // Oh, a new verification request. verificationRequest2 = m2.getVerificationRequest(userId1, flowId); @@ -611,7 +699,9 @@ describe('Key Verification', () => { expect(verificationRequest2.isPassive()).toStrictEqual(false); expect(verificationRequest2.isReady()).toStrictEqual(false); expect(verificationRequest2.timedOut()).toStrictEqual(false); - expect(verificationRequest2.theirSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.QrCodeScanV1, VerificationMethod.QrCodeShowV1])); + expect(verificationRequest2.theirSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.QrCodeScanV1, VerificationMethod.QrCodeShowV1]), + ); expect(verificationRequest2.ourSupportedMethods).toBeUndefined(); expect(verificationRequest2.flowId).toStrictEqual(flowId); expect(verificationRequest2.isSelfVerification()).toStrictEqual(false); @@ -624,7 +714,7 @@ describe('Key Verification', () => { expect(verificationRequests[0].flowId).toStrictEqual(verificationRequest2.flowId); // there are the same }); - test('can accept a verification request with methods (`m.key.verification.ready`)', async () => { + test("can accept a verification request with methods (`m.key.verification.ready`)", async () => { // Accept the verification request. let outgoingVerificationRequest = verificationRequest2.acceptWithMethods([ VerificationMethod.QrCodeScanV1, // by default @@ -634,33 +724,45 @@ describe('Key Verification', () => { expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); // The request verification is ready. 
- expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.ready'); + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.ready"); - const toDeviceEvents = [{ - sender: userId2.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][deviceId1.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId2.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][ + deviceId1.toString() + ], + }, + ]; // Let's send the verification ready to `m1`. await m1.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); }); - test('verification requests are synchronized and automatically updated', () => { + test("verification requests are synchronized and automatically updated", () => { expect(verificationRequest1.isReady()).toStrictEqual(true); expect(verificationRequest2.isReady()).toStrictEqual(true); - expect(verificationRequest1.theirSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.QrCodeScanV1, VerificationMethod.QrCodeShowV1])); - expect(verificationRequest1.ourSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.QrCodeScanV1, VerificationMethod.QrCodeShowV1])); - - expect(verificationRequest2.theirSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.QrCodeScanV1, VerificationMethod.QrCodeShowV1])); - expect(verificationRequest2.ourSupportedMethods).toEqual(expect.arrayContaining([VerificationMethod.QrCodeScanV1, VerificationMethod.QrCodeShowV1])); + expect(verificationRequest1.theirSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.QrCodeScanV1, VerificationMethod.QrCodeShowV1]), + ); + expect(verificationRequest1.ourSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.QrCodeScanV1, 
VerificationMethod.QrCodeShowV1]), + ); + + expect(verificationRequest2.theirSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.QrCodeScanV1, VerificationMethod.QrCodeShowV1]), + ); + expect(verificationRequest2.ourSupportedMethods).toEqual( + expect.arrayContaining([VerificationMethod.QrCodeScanV1, VerificationMethod.QrCodeShowV1]), + ); }); // QR verification for the second machine. let qr2; - test('can generate a QR code', async () => { + test("can generate a QR code", async () => { qr2 = await verificationRequest2.generateQrCode(); expect(qr2).toBeInstanceOf(Qr); @@ -680,17 +782,19 @@ describe('Key Verification', () => { expect(qr2.roomId).toBeUndefined(); }); - test('can read QR code\'s bytes', async () => { - const qrCodeHeader = 'MATRIX'; - const qrCodeVersion = '\x02'; + test("can read QR code's bytes", async () => { + const qrCodeHeader = "MATRIX"; + const qrCodeVersion = "\x02"; const qrCodeBytes = qr2.toBytes(); expect(qrCodeBytes).toHaveLength(122); - expect(Array.from(qrCodeBytes.slice(0, 7))).toEqual([...qrCodeHeader, ...qrCodeVersion].map(char => char.charCodeAt(0))); + expect(Array.from(qrCodeBytes.slice(0, 7))).toEqual( + [...qrCodeHeader, ...qrCodeVersion].map((char) => char.charCodeAt(0)), + ); }); - test('can render QR code', async () => { + test("can render QR code", async () => { const qrCode = qr2.toQrCode(); expect(qrCode).toBeInstanceOf(QrCode); @@ -706,7 +810,7 @@ describe('Key Verification', () => { // 45px ⨉ 45px expect(buffer).toHaveLength(45 * 45); // 0 for a white pixel, 1 for a black pixel. 
- expect(buffer.every(p => p == 0 || p == 1)).toStrictEqual(true); + expect(buffer.every((p) => p == 0 || p == 1)).toStrictEqual(true); /* const { Canvas } = require('canvas'); @@ -763,7 +867,7 @@ describe('Key Verification', () => { let qr1; - test('can scan a QR code from bytes', async () => { + test("can scan a QR code from bytes", async () => { const scan = QrCodeScan.fromBytes(qr2.toBytes()); expect(scan).toBeInstanceOf(QrCodeScan); @@ -787,50 +891,58 @@ describe('Key Verification', () => { expect(qr1.roomId).toBeUndefined(); }); - test('can start a QR verification/reciprocate (`m.key.verification.start`)', async () => { + test("can start a QR verification/reciprocate (`m.key.verification.start`)", async () => { let outgoingVerificationRequest = qr1.reciprocate(); expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.start'); + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.start"); - const toDeviceEvents = [{ - sender: userId1.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][deviceId2.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId1.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId2.toString()][ + deviceId2.toString() + ], + }, + ]; // Let's send the verification request to `m2`. 
await m2.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); }); - test('can confirm QR code has been scanned', () => { + test("can confirm QR code has been scanned", () => { expect(qr2.hasBeenScanned()).toStrictEqual(true); }); - test('can confirm scanning (`m.key.verification.done`)', async () => { + test("can confirm scanning (`m.key.verification.done`)", async () => { let outgoingVerificationRequest = qr2.confirmScanning(); expect(outgoingVerificationRequest).toBeInstanceOf(ToDeviceRequest); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.done'); + expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.done"); - const toDeviceEvents = [{ - sender: userId2.toString(), - type: outgoingVerificationRequest.event_type, - content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][deviceId1.toString()], - }]; + const toDeviceEvents = [ + { + sender: userId2.toString(), + type: outgoingVerificationRequest.event_type, + content: JSON.parse(outgoingVerificationRequest.body).messages[userId1.toString()][ + deviceId1.toString() + ], + }, + ]; // Let's send the verification request to `m2`. 
await m2.receiveSyncChanges(JSON.stringify(toDeviceEvents), new DeviceLists(), new Map(), new Set()); }); - test('can confirm QR code has been confirmed', () => { + test("can confirm QR code has been confirmed", () => { expect(qr2.hasBeenConfirmed()).toStrictEqual(true); }); }); }); -describe('VerificationMethod', () => { - test('has the correct variant values', () => { +describe("VerificationMethod", () => { + test("has the correct variant values", () => { expect(VerificationMethod.SasV1).toStrictEqual(0); expect(VerificationMethod.QrCodeScanV1).toStrictEqual(1); expect(VerificationMethod.QrCodeShowV1).toStrictEqual(2); diff --git a/bindings/matrix-sdk-crypto-js/tests/encryption.test.js b/bindings/matrix-sdk-crypto-js/tests/encryption.test.js index 4307deaf3e8..2edccdbbabc 100644 --- a/bindings/matrix-sdk-crypto-js/tests/encryption.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/encryption.test.js @@ -1,14 +1,19 @@ -const { EncryptionAlgorithm, EncryptionSettings, HistoryVisibility, VerificationState } = require('../pkg/matrix_sdk_crypto_js'); +const { + EncryptionAlgorithm, + EncryptionSettings, + HistoryVisibility, + VerificationState, +} = require("../pkg/matrix_sdk_crypto_js"); -describe('EncryptionAlgorithm', () => { - test('has the correct variant values', () => { +describe("EncryptionAlgorithm", () => { + test("has the correct variant values", () => { expect(EncryptionAlgorithm.OlmV1Curve25519AesSha2).toStrictEqual(0); expect(EncryptionAlgorithm.MegolmV1AesSha2).toStrictEqual(1); }); }); describe(EncryptionSettings.name, () => { - test('can be instantiated with default values', () => { + test("can be instantiated with default values", () => { const es = new EncryptionSettings(); expect(es.algorithm).toStrictEqual(EncryptionAlgorithm.MegolmV1AesSha2); @@ -17,20 +22,14 @@ describe(EncryptionSettings.name, () => { expect(es.historyVisibility).toStrictEqual(HistoryVisibility.Shared); }); - test('checks the history visibility values', () => { + test("checks 
the history visibility values", () => { const es = new EncryptionSettings(); es.historyVisibility = HistoryVisibility.Invited; expect(es.historyVisibility).toStrictEqual(HistoryVisibility.Invited); - expect(() => { es.historyVisibility = 42 }).toThrow(); - }); -}); - -describe('VerificationState', () => { - test('has the correct variant values', () => { - expect(VerificationState.Trusted).toStrictEqual(0); - expect(VerificationState.Untrusted).toStrictEqual(1); - expect(VerificationState.UnknownDevice).toStrictEqual(2); + expect(() => { + es.historyVisibility = 42; + }).toThrow(); }); }); diff --git a/bindings/matrix-sdk-crypto-js/tests/events.test.js b/bindings/matrix-sdk-crypto-js/tests/events.test.js index 75ed2b61037..4c7fdc8934a 100644 --- a/bindings/matrix-sdk-crypto-js/tests/events.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/events.test.js @@ -1,7 +1,7 @@ -const { HistoryVisibility } = require('../pkg/matrix_sdk_crypto_js'); +const { HistoryVisibility } = require("../pkg/matrix_sdk_crypto_js"); -describe('HistoryVisibility', () => { - test('has the correct variant values', () => { +describe("HistoryVisibility", () => { + test("has the correct variant values", () => { expect(HistoryVisibility.Invited).toStrictEqual(0); expect(HistoryVisibility.Joined).toStrictEqual(1); expect(HistoryVisibility.Shared).toStrictEqual(2); diff --git a/bindings/matrix-sdk-crypto-js/tests/helper.js b/bindings/matrix-sdk-crypto-js/tests/helper.js index 0d3136bed4d..eaf58d5b8b1 100644 --- a/bindings/matrix-sdk-crypto-js/tests/helper.js +++ b/bindings/matrix-sdk-crypto-js/tests/helper.js @@ -1,10 +1,10 @@ -const { DeviceLists, RequestType, KeysUploadRequest, KeysQueryRequest } = require('../pkg/matrix_sdk_crypto_js'); +const { DeviceLists, RequestType, KeysUploadRequest, KeysQueryRequest } = require("../pkg/matrix_sdk_crypto_js"); function* zip(...arrays) { const len = Math.min(...arrays.map((array) => array.length)); for (let nth = 0; nth < len; ++nth) { - yield 
[...arrays.map((array) => array.at(nth))] + yield [...arrays.map((array) => array.at(nth))]; } } @@ -16,7 +16,9 @@ async function addMachineToMachine(machineToAdd, machine) { const oneTimeKeyCounts = new Map(); const unusedFallbackKeys = new Set(); - const receiveSyncChanges = JSON.parse(await machineToAdd.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys)); + const receiveSyncChanges = JSON.parse( + await machineToAdd.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys), + ); expect(receiveSyncChanges).toEqual([]); @@ -38,12 +40,16 @@ async function addMachineToMachine(machineToAdd, machine) { // https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3keysupload const hypothetical_response = JSON.stringify({ - "one_time_key_counts": { - "curve25519": 10, - "signed_curve25519": 20 - } + one_time_key_counts: { + curve25519: 10, + signed_curve25519: 20, + }, }); - const marked = await machineToAdd.markRequestAsSent(outgoingRequests[0].id, outgoingRequests[0].type, hypothetical_response); + const marked = await machineToAdd.markRequestAsSent( + outgoingRequests[0].id, + outgoingRequests[0].type, + hypothetical_response, + ); expect(marked).toStrictEqual(true); keysUploadRequest = outgoingRequests[0]; @@ -71,7 +77,11 @@ async function addMachineToMachine(machineToAdd, machine) { keyQueryResponse.self_signing_keys[userId] = keys.self_signing_key; keyQueryResponse.user_signing_keys[userId] = keys.user_signing_key; - const marked = await machine.markRequestAsSent(outgoingRequests[1].id, outgoingRequests[1].type, JSON.stringify(keyQueryResponse)); + const marked = await machine.markRequestAsSent( + outgoingRequests[1].id, + outgoingRequests[1].type, + JSON.stringify(keyQueryResponse), + ); expect(marked).toStrictEqual(true); } } diff --git a/bindings/matrix-sdk-crypto-js/tests/identifiers.test.js b/bindings/matrix-sdk-crypto-js/tests/identifiers.test.js index 714dc64a6b4..5fe37f2d669 
100644 --- a/bindings/matrix-sdk-crypto-js/tests/identifiers.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/identifiers.test.js @@ -7,65 +7,75 @@ const { RoomId, ServerName, UserId, -} = require('../pkg/matrix_sdk_crypto_js'); +} = require("../pkg/matrix_sdk_crypto_js"); describe(UserId.name, () => { - test('cannot be invalid', () => { - expect(() => { new UserId('@foobar') }).toThrow(); + test("cannot be invalid", () => { + expect(() => { + new UserId("@foobar"); + }).toThrow(); }); - const user = new UserId('@foo:bar.org'); + const user = new UserId("@foo:bar.org"); - test('localpart is present', () => { - expect(user.localpart).toStrictEqual('foo'); + test("localpart is present", () => { + expect(user.localpart).toStrictEqual("foo"); }); - test('server name is present', () => { + test("server name is present", () => { expect(user.serverName).toBeInstanceOf(ServerName); }); - test('user ID is not historical', () => { + test("user ID is not historical", () => { expect(user.isHistorical()).toStrictEqual(false); }); - test('can read the user ID as a string', () => { - expect(user.toString()).toStrictEqual('@foo:bar.org'); - }) + test("can read the user ID as a string", () => { + expect(user.toString()).toStrictEqual("@foo:bar.org"); + }); }); describe(DeviceId.name, () => { - const device = new DeviceId('foo'); + const device = new DeviceId("foo"); - test('can read the device ID as a string', () => { - expect(device.toString()).toStrictEqual('foo'); - }) + test("can read the device ID as a string", () => { + expect(device.toString()).toStrictEqual("foo"); + }); }); describe(DeviceKeyId.name, () => { for (const deviceKey of [ - { name: 'ed25519', - id: 'ed25519:foobar', - algorithmName: DeviceKeyAlgorithmName.Ed25519, - algorithm: 'ed25519', - deviceId: 'foobar' }, - - { name: 'curve25519', - id: 'curve25519:foobar', - algorithmName: DeviceKeyAlgorithmName.Curve25519, - algorithm: 'curve25519', - deviceId: 'foobar' }, - - { name: 'signed curve25519', - id: 
'signed_curve25519:foobar', - algorithmName: DeviceKeyAlgorithmName.SignedCurve25519, - algorithm: 'signed_curve25519', - deviceId: 'foobar' }, - - { name: 'unknown', - id: 'hello:foobar', - algorithmName: DeviceKeyAlgorithmName.Unknown, - algorithm: 'hello', - deviceId: 'foobar' }, + { + name: "ed25519", + id: "ed25519:foobar", + algorithmName: DeviceKeyAlgorithmName.Ed25519, + algorithm: "ed25519", + deviceId: "foobar", + }, + + { + name: "curve25519", + id: "curve25519:foobar", + algorithmName: DeviceKeyAlgorithmName.Curve25519, + algorithm: "curve25519", + deviceId: "foobar", + }, + + { + name: "signed curve25519", + id: "signed_curve25519:foobar", + algorithmName: DeviceKeyAlgorithmName.SignedCurve25519, + algorithm: "signed_curve25519", + deviceId: "foobar", + }, + + { + name: "unknown", + id: "hello:foobar", + algorithmName: DeviceKeyAlgorithmName.Unknown, + algorithm: "hello", + deviceId: "foobar", + }, ]) { test(`${deviceKey.name} algorithm`, () => { const dk = new DeviceKeyId(deviceKey.id); @@ -78,8 +88,8 @@ describe(DeviceKeyId.name, () => { } }); -describe('DeviceKeyAlgorithmName', () => { - test('has the correct variants', () => { +describe("DeviceKeyAlgorithmName", () => { + test("has the correct variants", () => { expect(DeviceKeyAlgorithmName.Ed25519).toStrictEqual(0); expect(DeviceKeyAlgorithmName.Curve25519).toStrictEqual(1); expect(DeviceKeyAlgorithmName.SignedCurve25519).toStrictEqual(2); @@ -88,94 +98,100 @@ describe('DeviceKeyAlgorithmName', () => { }); describe(RoomId.name, () => { - test('cannot be invalid', () => { - expect(() => { new RoomId('!foo') }).toThrow(); + test("cannot be invalid", () => { + expect(() => { + new RoomId("!foo"); + }).toThrow(); }); - const room = new RoomId('!foo:bar.org'); + const room = new RoomId("!foo:bar.org"); - test('localpart is present', () => { - expect(room.localpart).toStrictEqual('foo'); + test("localpart is present", () => { + expect(room.localpart).toStrictEqual("foo"); }); - test('server name is 
present', () => { + test("server name is present", () => { expect(room.serverName).toBeInstanceOf(ServerName); }); - test('can read the room ID as string', () => { - expect(room.toString()).toStrictEqual('!foo:bar.org'); + test("can read the room ID as string", () => { + expect(room.toString()).toStrictEqual("!foo:bar.org"); }); }); describe(ServerName.name, () => { - test('cannot be invalid', () => { - expect(() => { new ServerName('@foobar') }).toThrow() + test("cannot be invalid", () => { + expect(() => { + new ServerName("@foobar"); + }).toThrow(); }); - test('host is present', () => { - expect(new ServerName('foo.org').host).toStrictEqual('foo.org'); + test("host is present", () => { + expect(new ServerName("foo.org").host).toStrictEqual("foo.org"); }); - test('port can be optional', () => { - expect(new ServerName('foo.org').port).toStrictEqual(undefined); - expect(new ServerName('foo.org:1234').port).toStrictEqual(1234); + test("port can be optional", () => { + expect(new ServerName("foo.org").port).toStrictEqual(undefined); + expect(new ServerName("foo.org:1234").port).toStrictEqual(1234); }); - test('server is not an IP literal', () => { - expect(new ServerName('foo.org').isIpLiteral()).toStrictEqual(false); + test("server is not an IP literal", () => { + expect(new ServerName("foo.org").isIpLiteral()).toStrictEqual(false); }); }); describe(EventId.name, () => { - test('cannot be invalid', () => { - expect(() => { new EventId('%foo') }).toThrow(); + test("cannot be invalid", () => { + expect(() => { + new EventId("%foo"); + }).toThrow(); }); - describe('Versions 1 & 2', () => { - const room = new EventId('$h29iv0s8:foo.org'); + describe("Versions 1 & 2", () => { + const room = new EventId("$h29iv0s8:foo.org"); - test('localpart is present', () => { - expect(room.localpart).toStrictEqual('h29iv0s8'); + test("localpart is present", () => { + expect(room.localpart).toStrictEqual("h29iv0s8"); }); - test('server name is present', () => { + test("server name is 
present", () => { expect(room.serverName).toBeInstanceOf(ServerName); }); - test('can read the room ID as string', () => { - expect(room.toString()).toStrictEqual('$h29iv0s8:foo.org'); + test("can read the room ID as string", () => { + expect(room.toString()).toStrictEqual("$h29iv0s8:foo.org"); }); }); - describe('Version 3', () => { - const room = new EventId('$acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk'); + describe("Version 3", () => { + const room = new EventId("$acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk"); - test('localpart is present', () => { - expect(room.localpart).toStrictEqual('acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk'); + test("localpart is present", () => { + expect(room.localpart).toStrictEqual("acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk"); }); - test('server name is present', () => { + test("server name is present", () => { expect(room.serverName).toBeUndefined(); }); - test('can read the room ID as string', () => { - expect(room.toString()).toStrictEqual('$acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk'); + test("can read the room ID as string", () => { + expect(room.toString()).toStrictEqual("$acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk"); }); }); - describe('Version 4', () => { - const room = new EventId('$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg'); + describe("Version 4", () => { + const room = new EventId("$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg"); - test('localpart is present', () => { - expect(room.localpart).toStrictEqual('Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg'); + test("localpart is present", () => { + expect(room.localpart).toStrictEqual("Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg"); }); - test('server name is present', () => { + test("server name is present", () => { expect(room.serverName).toBeUndefined(); }); - test('can read the room ID as string', () => { - expect(room.toString()).toStrictEqual('$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg'); + test("can read the room ID as string", () => { + 
expect(room.toString()).toStrictEqual("$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg"); }); }); -}) +}); diff --git a/bindings/matrix-sdk-crypto-js/tests/machine.test.js b/bindings/matrix-sdk-crypto-js/tests/machine.test.js index 1d11eafd4d1..32a0ffa0d90 100644 --- a/bindings/matrix-sdk-crypto-js/tests/machine.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/machine.test.js @@ -21,27 +21,29 @@ const { UserId, UserIdentity, VerificationRequest, - VerificationState, -} = require('../pkg/matrix_sdk_crypto_js'); -const { addMachineToMachine } = require('./helper'); -require('fake-indexeddb/auto'); + ShieldColor, +} = require("../pkg/matrix_sdk_crypto_js"); +const { addMachineToMachine } = require("./helper"); +require("fake-indexeddb/auto"); describe(OlmMachine.name, () => { - test('can be instantiated with the async initializer', async () => { - expect(await OlmMachine.initialize(new UserId('@foo:bar.org'), new DeviceId('baz'))).toBeInstanceOf(OlmMachine); + test("can be instantiated with the async initializer", async () => { + expect(await OlmMachine.initialize(new UserId("@foo:bar.org"), new DeviceId("baz"))).toBeInstanceOf(OlmMachine); }); - test('can be instantiated with a store', async () => { - let store_name = 'hello'; - let store_passphrase = 'world'; + test("can be instantiated with a store", async () => { + let store_name = "hello"; + let store_passphrase = "world"; - const by_store_name = db => db.name.startsWith(store_name); + const by_store_name = (db) => db.name.startsWith(store_name); // No databases. expect((await indexedDB.databases()).filter(by_store_name)).toHaveLength(0); // Creating a new Olm machine. 
- expect(await OlmMachine.initialize(new UserId('@foo:bar.org'), new DeviceId('baz'), store_name, store_passphrase)).toBeInstanceOf(OlmMachine); + expect( + await OlmMachine.initialize(new UserId("@foo:bar.org"), new DeviceId("baz"), store_name, store_passphrase), + ).toBeInstanceOf(OlmMachine); // Oh, there is 2 databases now, prefixed by `store_name`. let databases = (await indexedDB.databases()).filter(by_store_name); @@ -49,25 +51,32 @@ describe(OlmMachine.name, () => { expect(databases).toHaveLength(2); expect(databases).toStrictEqual([ { name: `${store_name}::matrix-sdk-crypto-meta`, version: 1 }, - { name: `${store_name}::matrix-sdk-crypto`, version: 1 }, + { name: `${store_name}::matrix-sdk-crypto`, version: 2 }, ]); // Creating a new Olm machine, with the stored state. - expect(await OlmMachine.initialize(new UserId('@foo:bar.org'), new DeviceId('baz'), store_name, store_passphrase)).toBeInstanceOf(OlmMachine); + expect( + await OlmMachine.initialize(new UserId("@foo:bar.org"), new DeviceId("baz"), store_name, store_passphrase), + ).toBeInstanceOf(OlmMachine); // Same number of databases. 
expect((await indexedDB.databases()).filter(by_store_name)).toHaveLength(2); }); - describe('cannot be instantiated with a store', () => { - test('store name is missing', async () => { + describe("cannot be instantiated with a store", () => { + test("store name is missing", async () => { let store_name = null; - let store_passphrase = 'world'; + let store_passphrase = "world"; let err = null; try { - await OlmMachine.initialize(new UserId('@foo:bar.org'), new DeviceId('baz'), store_name, store_passphrase); + await OlmMachine.initialize( + new UserId("@foo:bar.org"), + new DeviceId("baz"), + store_name, + store_passphrase, + ); } catch (error) { err = error; } @@ -75,14 +84,19 @@ describe(OlmMachine.name, () => { expect(err).toBeDefined(); }); - test('store passphrase is missing', async () => { - let store_name = 'hello'; + test("store passphrase is missing", async () => { + let store_name = "hello"; let store_passphrase = null; let err = null; try { - await OlmMachine.initialize(new UserId('@foo:bar.org'), new DeviceId('baz'), store_name, store_passphrase); + await OlmMachine.initialize( + new UserId("@foo:bar.org"), + new DeviceId("baz"), + store_name, + store_passphrase, + ); } catch (error) { err = error; } @@ -91,30 +105,35 @@ describe(OlmMachine.name, () => { }); }); - const user = new UserId('@alice:example.org'); - const device = new DeviceId('foobar'); - const room = new RoomId('!baz:matrix.org'); + const user = new UserId("@alice:example.org"); + const device = new DeviceId("foobar"); + const room = new RoomId("!baz:matrix.org"); function machine(new_user, new_device) { return OlmMachine.initialize(new_user || user, new_device || device); } - test('can drop/close', async () => { + test("can drop/close", async () => { m = await machine(); m.close(); }); - test('can drop/close with a store', async () => { - let store_name = 'temporary'; - let store_passphrase = 'temporary'; + test("can drop/close with a store", async () => { + let store_name = "temporary"; + 
let store_passphrase = "temporary"; - const by_store_name = db => db.name.startsWith(store_name); + const by_store_name = (db) => db.name.startsWith(store_name); // No databases. expect((await indexedDB.databases()).filter(by_store_name)).toHaveLength(0); // Creating a new Olm machine. - const m = await OlmMachine.initialize(new UserId('@foo:bar.org'), new DeviceId('baz'), store_name, store_passphrase); + const m = await OlmMachine.initialize( + new UserId("@foo:bar.org"), + new DeviceId("baz"), + store_name, + store_passphrase, + ); expect(m).toBeInstanceOf(OlmMachine); // Oh, there is 2 databases now, prefixed by `store_name`. @@ -123,7 +142,7 @@ describe(OlmMachine.name, () => { expect(databases).toHaveLength(2); expect(databases).toStrictEqual([ { name: `${store_name}::matrix-sdk-crypto-meta`, version: 1 }, - { name: `${store_name}::matrix-sdk-crypto`, version: 1 }, + { name: `${store_name}::matrix-sdk-crypto`, version: 2 }, ]); // Let's force to close the `OlmMachine`. @@ -133,31 +152,35 @@ describe(OlmMachine.name, () => { for (const database_name of [`${store_name}::matrix-sdk-crypto`, `${store_name}::matrix-sdk-crypto-meta`]) { const deleting = indexedDB.deleteDatabase(database_name); deleting.onsuccess = () => {}; - deleting.onerror = () => { throw new Error('failed to remove the database (error)') }; - deleting.onblocked = () => { throw new Error('failed to remove the database (blocked)') }; + deleting.onerror = () => { + throw new Error("failed to remove the database (error)"); + }; + deleting.onblocked = () => { + throw new Error("failed to remove the database (blocked)"); + }; } }); - test('can read user ID', async () => { + test("can read user ID", async () => { expect((await machine()).userId.toString()).toStrictEqual(user.toString()); }); - test('can read device ID', async () => { + test("can read device ID", async () => { expect((await machine()).deviceId.toString()).toStrictEqual(device.toString()); }); - test('can read identity keys', async () => 
{ + test("can read identity keys", async () => { const identityKeys = (await machine()).identityKeys; expect(identityKeys.ed25519.toBase64()).toMatch(/^[A-Za-z0-9+/]+$/); expect(identityKeys.curve25519.toBase64()).toMatch(/^[A-Za-z0-9+/]+$/); }); - test('can read display name', async () => { + test("can read display name", async () => { expect(await machine().displayName).toBeUndefined(); }); - test('can read tracked users', async () => { + test("can read tracked users", async () => { const m = await machine(); const trackedUsers = await m.trackedUsers(); @@ -165,32 +188,36 @@ describe(OlmMachine.name, () => { expect(trackedUsers.size).toStrictEqual(0); }); - test('can update tracked users', async () => { + test("can update tracked users", async () => { const m = await machine(); expect(await m.updateTrackedUsers([user])).toStrictEqual(undefined); }); - test('can receive sync changes', async () => { + test("can receive sync changes", async () => { const m = await machine(); const toDeviceEvents = JSON.stringify([]); const changedDevices = new DeviceLists(); const oneTimeKeyCounts = new Map(); const unusedFallbackKeys = new Set(); - const receiveSyncChanges = JSON.parse(await m.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys)); + const receiveSyncChanges = JSON.parse( + await m.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys), + ); expect(receiveSyncChanges).toEqual([]); }); - test('can get the outgoing requests that need to be send out', async () => { + test("can get the outgoing requests that need to be send out", async () => { const m = await machine(); const toDeviceEvents = JSON.stringify([]); const changedDevices = new DeviceLists(); const oneTimeKeyCounts = new Map(); const unusedFallbackKeys = new Set(); - const receiveSyncChanges = JSON.parse(await m.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys)); + const receiveSyncChanges = 
JSON.parse( + await m.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys), + ); expect(receiveSyncChanges).toEqual([]); @@ -222,35 +249,40 @@ describe(OlmMachine.name, () => { } }); - describe('setup workflow to mark requests as sent', () => { + describe("setup workflow to mark requests as sent", () => { let m; let ougoingRequests; beforeAll(async () => { - m = await machine(new UserId('@alice:example.org'), new DeviceId('DEVICEID')); + m = await machine(new UserId("@alice:example.org"), new DeviceId("DEVICEID")); const toDeviceEvents = JSON.stringify([]); const changedDevices = new DeviceLists(); const oneTimeKeyCounts = new Map(); const unusedFallbackKeys = new Set(); - const receiveSyncChanges = await m.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys); + const receiveSyncChanges = await m.receiveSyncChanges( + toDeviceEvents, + changedDevices, + oneTimeKeyCounts, + unusedFallbackKeys, + ); outgoingRequests = await m.outgoingRequests(); expect(outgoingRequests).toHaveLength(2); }); - test('can mark requests as sent', async () => { + test("can mark requests as sent", async () => { { const request = outgoingRequests[0]; expect(request).toBeInstanceOf(KeysUploadRequest); // https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3keysupload const hypothetical_response = JSON.stringify({ - "one_time_key_counts": { - "curve25519": 10, - "signed_curve25519": 20 - } + one_time_key_counts: { + curve25519: 10, + signed_curve25519: 20, + }, }); const marked = await m.markRequestAsSent(request.id, request.type, hypothetical_response); expect(marked).toStrictEqual(true); @@ -262,31 +294,29 @@ describe(OlmMachine.name, () => { // https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3keysquery const hypothetical_response = JSON.stringify({ - "device_keys": { + device_keys: { "@alice:example.org": { - "JLAFKJWSCS": { - "algorithms": [ - "m.olm.v1.curve25519-aes-sha2", - 
"m.megolm.v1.aes-sha2" - ], - "device_id": "JLAFKJWSCS", - "keys": { + JLAFKJWSCS: { + algorithms: ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + device_id: "JLAFKJWSCS", + keys: { "curve25519:JLAFKJWSCS": "wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4", - "ed25519:JLAFKJWSCS": "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM" + "ed25519:JLAFKJWSCS": "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM", }, - "signatures": { + signatures: { "@alice:example.org": { - "ed25519:JLAFKJWSCS": "m53Wkbh2HXkc3vFApZvCrfXcX3AI51GsDHustMhKwlv3TuOJMj4wistcOTM8q2+e/Ro7rWFUb9ZfnNbwptSUBA" - } + "ed25519:JLAFKJWSCS": + "m53Wkbh2HXkc3vFApZvCrfXcX3AI51GsDHustMhKwlv3TuOJMj4wistcOTM8q2+e/Ro7rWFUb9ZfnNbwptSUBA", + }, }, - "unsigned": { - "device_display_name": "Alice's mobile phone" + unsigned: { + device_display_name: "Alice's mobile phone", }, - "user_id": "@alice:example.org" - } - } + user_id: "@alice:example.org", + }, + }, }, - "failures": {} + failures: {}, }); const marked = await m.markRequestAsSent(request.id, request.type, hypothetical_response); expect(marked).toStrictEqual(true); @@ -294,143 +324,145 @@ describe(OlmMachine.name, () => { }); }); - describe('setup workflow to encrypt/decrypt events', () => { + describe("setup workflow to encrypt/decrypt events", () => { let m; - const user = new UserId('@alice:example.org'); - const device = new DeviceId('JLAFKJWSCS'); - const room = new RoomId('!test:localhost'); + const user = new UserId("@alice:example.org"); + const device = new DeviceId("JLAFKJWSCS"); + const room = new RoomId("!test:localhost"); beforeAll(async () => { m = await machine(user, device); }); - test('can pass keysquery and keysclaim requests directly', async () => { + test("can pass keysquery and keysclaim requests directly", async () => { { // derived from https://github.com/matrix-org/matrix-rust-sdk/blob/7f49618d350fab66b7e1dc4eaf64ec25ceafd658/benchmarks/benches/crypto_bench/keys_query.json const hypothetical_response = JSON.stringify({ - 
"device_keys": { + device_keys: { "@example:localhost": { - "AFGUOBTZWM": { - "algorithms": [ - "m.olm.v1.curve25519-aes-sha2", - "m.megolm.v1.aes-sha2" - ], - "device_id": "AFGUOBTZWM", - "keys": { + AFGUOBTZWM: { + algorithms: ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + device_id: "AFGUOBTZWM", + keys: { "curve25519:AFGUOBTZWM": "boYjDpaC+7NkECQEeMh5dC+I1+AfriX0VXG2UV7EUQo", - "ed25519:AFGUOBTZWM": "NayrMQ33ObqMRqz6R9GosmHdT6HQ6b/RX/3QlZ2yiec" + "ed25519:AFGUOBTZWM": "NayrMQ33ObqMRqz6R9GosmHdT6HQ6b/RX/3QlZ2yiec", }, - "signatures": { + signatures: { "@example:localhost": { - "ed25519:AFGUOBTZWM": "RoSWvru1jj6fs2arnTedWsyIyBmKHMdOu7r9gDi0BZ61h9SbCK2zLXzuJ9ZFLao2VvA0yEd7CASCmDHDLYpXCA" - } + "ed25519:AFGUOBTZWM": + "RoSWvru1jj6fs2arnTedWsyIyBmKHMdOu7r9gDi0BZ61h9SbCK2zLXzuJ9ZFLao2VvA0yEd7CASCmDHDLYpXCA", + }, + }, + user_id: "@example:localhost", + unsigned: { + device_display_name: "rust-sdk", }, - "user_id": "@example:localhost", - "unsigned": { - "device_display_name": "rust-sdk" - } }, - } + }, }, - "failures": {}, - "master_keys": { + failures: {}, + master_keys: { "@example:localhost": { - "user_id": "@example:localhost", - "usage": [ - "master" - ], - "keys": { - "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": "n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU" + user_id: "@example:localhost", + usage: ["master"], + keys: { + "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": + "n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU", }, - "signatures": { + signatures: { "@example:localhost": { - "ed25519:TCSJXPWGVS": "+j9G3L41I1fe0++wwusTTQvbboYW0yDtRWUEujhwZz4MAltjLSfJvY0hxhnz+wHHmuEXvQDen39XOpr1p29sAg" - } - } - } + "ed25519:TCSJXPWGVS": + "+j9G3L41I1fe0++wwusTTQvbboYW0yDtRWUEujhwZz4MAltjLSfJvY0hxhnz+wHHmuEXvQDen39XOpr1p29sAg", + }, + }, + }, }, - "self_signing_keys": { + self_signing_keys: { "@example:localhost": { - "user_id": "@example:localhost", - "usage": [ - "self_signing" - ], - "keys": { - 
"ed25519:kQXOuy639Yt47mvNTdrIluoC6DMvfbZLYbxAmwiDyhI": "kQXOuy639Yt47mvNTdrIluoC6DMvfbZLYbxAmwiDyhI" + user_id: "@example:localhost", + usage: ["self_signing"], + keys: { + "ed25519:kQXOuy639Yt47mvNTdrIluoC6DMvfbZLYbxAmwiDyhI": + "kQXOuy639Yt47mvNTdrIluoC6DMvfbZLYbxAmwiDyhI", }, - "signatures": { + signatures: { "@example:localhost": { - "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": "q32ifix/qyRpvmegw2BEJklwoBCAJldDNkcX+fp+lBA4Rpyqtycxge6BA4hcJdxYsy3oV0IHRuugS8rJMMFyAA" - } - } - } + "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": + "q32ifix/qyRpvmegw2BEJklwoBCAJldDNkcX+fp+lBA4Rpyqtycxge6BA4hcJdxYsy3oV0IHRuugS8rJMMFyAA", + }, + }, + }, }, - "user_signing_keys": { + user_signing_keys: { "@example:localhost": { - "user_id": "@example:localhost", - "usage": [ - "user_signing" - ], - "keys": { - "ed25519:g4ED07Fnqf3GzVWNN1pZ0IFrPQVdqQf+PYoJNH4eE0s": "g4ED07Fnqf3GzVWNN1pZ0IFrPQVdqQf+PYoJNH4eE0s" + user_id: "@example:localhost", + usage: ["user_signing"], + keys: { + "ed25519:g4ED07Fnqf3GzVWNN1pZ0IFrPQVdqQf+PYoJNH4eE0s": + "g4ED07Fnqf3GzVWNN1pZ0IFrPQVdqQf+PYoJNH4eE0s", }, - "signatures": { + signatures: { "@example:localhost": { - "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": "nKQu8alQKDefNbZz9luYPcNj+Z+ouQSot4fU/A23ELl1xrI06QVBku/SmDx0sIW1ytso0Cqwy1a+3PzCa1XABg" - } - } - } - } + "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": + "nKQu8alQKDefNbZz9luYPcNj+Z+ouQSot4fU/A23ELl1xrI06QVBku/SmDx0sIW1ytso0Cqwy1a+3PzCa1XABg", + }, + }, + }, + }, }); - const marked = await m.markRequestAsSent('foo', RequestType.KeysQuery, hypothetical_response); + const marked = await m.markRequestAsSent("foo", RequestType.KeysQuery, hypothetical_response); } { // derived from https://github.com/matrix-org/matrix-rust-sdk/blob/7f49618d350fab66b7e1dc4eaf64ec25ceafd658/benchmarks/benches/crypto_bench/keys_claim.json const hypothetical_response = JSON.stringify({ - "one_time_keys": { + one_time_keys: { "@example:localhost": { - "AFGUOBTZWM": { + AFGUOBTZWM: { 
"signed_curve25519:AAAABQ": { - "key": "9IGouMnkB6c6HOd4xUsNv4i3Dulb4IS96TzDordzOws", - "signatures": { + key: "9IGouMnkB6c6HOd4xUsNv4i3Dulb4IS96TzDordzOws", + signatures: { "@example:localhost": { - "ed25519:AFGUOBTZWM": "2bvUbbmJegrV0eVP/vcJKuIWC3kud+V8+C0dZtg4dVovOSJdTP/iF36tQn2bh5+rb9xLlSeztXBdhy4c+LiOAg" - } - } - } + "ed25519:AFGUOBTZWM": + "2bvUbbmJegrV0eVP/vcJKuIWC3kud+V8+C0dZtg4dVovOSJdTP/iF36tQn2bh5+rb9xLlSeztXBdhy4c+LiOAg", + }, + }, + }, }, - } + }, }, - "failures": {} + failures: {}, }); - const marked = await m.markRequestAsSent('bar', RequestType.KeysClaim, hypothetical_response); + const marked = await m.markRequestAsSent("bar", RequestType.KeysClaim, hypothetical_response); } }); - test('can share a room key', async () => { - const other_users = [new UserId('@example:localhost')]; + test("can share a room key", async () => { + const other_users = [new UserId("@example:localhost")]; - const requests = JSON.parse(await m.shareRoomKey(room, other_users, new EncryptionSettings())); + const requests = await m.shareRoomKey(room, other_users, new EncryptionSettings()); expect(requests).toHaveLength(1); - expect(requests[0].event_type).toBeDefined(); + expect(requests[0]).toBeInstanceOf(ToDeviceRequest); + expect(requests[0].event_type).toEqual("m.room.encrypted"); expect(requests[0].txn_id).toBeDefined(); - expect(requests[0].messages).toBeDefined(); - expect(requests[0].messages['@example:localhost']).toBeDefined(); + const content = JSON.parse(requests[0].body); + expect(Object.keys(content.messages)).toEqual(["@example:localhost"]); }); let encrypted; - test('can encrypt an event', async () => { - encrypted = JSON.parse(await m.encryptRoomEvent( - room, - 'm.room.message', - JSON.stringify({ - "msgtype": "m.text", - "body": "Hello, World!" 
- }), - )); + test("can encrypt an event", async () => { + encrypted = JSON.parse( + await m.encryptRoomEvent( + room, + "m.room.message", + JSON.stringify({ + msgtype: "m.text", + body: "Hello, World!", + }), + ), + ); expect(encrypted.algorithm).toBeDefined(); expect(encrypted.ciphertext).toBeDefined(); @@ -439,17 +471,17 @@ describe(OlmMachine.name, () => { expect(encrypted.session_id).toBeDefined(); }); - test('can decrypt an event', async () => { + test("can decrypt an event", async () => { const decrypted = await m.decryptRoomEvent( JSON.stringify({ - "type": "m.room.encrypted", - "event_id": "$xxxxx:example.org", - "origin_server_ts": Date.now(), - "sender": user.toString(), + type: "m.room.encrypted", + event_id: "$xxxxx:example.org", + origin_server_ts: Date.now(), + sender: user.toString(), content: encrypted, unsigned: { - "age": 1234 - } + age: 1234, + }, }), room, ); @@ -465,7 +497,8 @@ describe(OlmMachine.name, () => { expect(decrypted.senderCurve25519Key).toBeDefined(); expect(decrypted.senderClaimedEd25519Key).toBeDefined(); expect(decrypted.forwardingCurve25519KeyChain).toHaveLength(0); - expect(decrypted.verificationState).toStrictEqual(VerificationState.Trusted); + expect(decrypted.shieldState(true).color).toStrictEqual(ShieldColor.Red); + expect(decrypted.shieldState(false).color).toStrictEqual(ShieldColor.Red); }); }); @@ -484,7 +517,7 @@ describe(OlmMachine.name, () => { await expect(() => m.decryptRoomEvent(JSON.stringify(evt), room)).rejects.toThrowError(); }); - test('can read cross-signing status', async () => { + test("can read cross-signing status", async () => { const m = await machine(); const crossSigningStatus = await m.crossSigningStatus(); @@ -494,9 +527,9 @@ describe(OlmMachine.name, () => { expect(crossSigningStatus.hasUserSigning).toStrictEqual(false); }); - test('can sign a message', async () => { + test("can sign a message", async () => { const m = await machine(); - const signatures = await m.sign('foo'); + const signatures = 
await m.sign("foo"); expect(signatures.isEmpty()).toStrictEqual(false); expect(signatures.count).toStrictEqual(1); @@ -507,9 +540,9 @@ describe(OlmMachine.name, () => { { const signature = signatures.get(user); - expect(signature.has('ed25519:foobar')).toStrictEqual(true); + expect(signature.has("ed25519:foobar")).toStrictEqual(true); - const s = signature.get('ed25519:foobar'); + const s = signature.get("ed25519:foobar"); expect(s).toBeInstanceOf(MaybeSignature); @@ -525,18 +558,18 @@ describe(OlmMachine.name, () => { // `getSignature` { - const signature = signatures.getSignature(user, new DeviceKeyId('ed25519:foobar')); + const signature = signatures.getSignature(user, new DeviceKeyId("ed25519:foobar")); expect(signature.toBase64()).toStrictEqual(base64); } // Unknown signatures. { - expect(signatures.get(new UserId('@hello:example.org'))).toBeUndefined(); - expect(signatures.getSignature(user, new DeviceKeyId('world:foobar'))).toBeUndefined(); + expect(signatures.get(new UserId("@hello:example.org"))).toBeUndefined(); + expect(signatures.getSignature(user, new DeviceKeyId("world:foobar"))).toBeUndefined(); } }); - test('can get a user identities', async () => { + test("can get a user identities", async () => { const m = await machine(); let _ = m.bootstrapCrossSigning(true); @@ -556,15 +589,15 @@ describe(OlmMachine.name, () => { expect(isTrusted).toStrictEqual(false); }); - describe('can export/import room keys', () => { + describe("can export/import room keys", () => { let m; let exportedRoomKeys; - test('can export room keys', async () => { + test("can export room keys", async () => { m = await machine(); - await m.shareRoomKey(room, [new UserId('@bob:example.org')], new EncryptionSettings()); + await m.shareRoomKey(room, [new UserId("@bob:example.org")], new EncryptionSettings()); - exportedRoomKeys = await m.exportRoomKeys(session => { + exportedRoomKeys = await m.exportRoomKeys((session) => { expect(session).toBeInstanceOf(InboundGroupSession); 
expect(session.roomId.toString()).toStrictEqual(room.toString()); expect(session.sessionId).toBeDefined(); @@ -589,9 +622,9 @@ describe(OlmMachine.name, () => { }); let encryptedExportedRoomKeys; - let encryptionPassphrase = 'Hello, Matrix!'; + let encryptionPassphrase = "Hello, Matrix!"; - test('can encrypt the exported room keys', () => { + test("can encrypt the exported room keys", () => { encryptedExportedRoomKeys = OlmMachine.encryptExportedRoomKeys( exportedRoomKeys, encryptionPassphrase, @@ -601,7 +634,7 @@ describe(OlmMachine.name, () => { expect(encryptedExportedRoomKeys).toMatch(/^-----BEGIN MEGOLM SESSION DATA-----/); }); - test('can decrypt the exported room keys', () => { + test("can decrypt the exported room keys", () => { const decryptedExportedRoomKeys = OlmMachine.decryptExportedRoomKeys( encryptedExportedRoomKeys, encryptionPassphrase, @@ -610,7 +643,7 @@ describe(OlmMachine.name, () => { expect(decryptedExportedRoomKeys).toStrictEqual(exportedRoomKeys); }); - test('can import room keys', async () => { + test("can import room keys", async () => { const progressListener = (progress, total) => { expect(progress).toBeLessThan(total); @@ -629,169 +662,160 @@ describe(OlmMachine.name, () => { }); }); - describe('can do in-room verification', () => { + describe("can do in-room verification", () => { let m; - const user = new UserId('@alice:example.org'); - const device = new DeviceId('JLAFKJWSCS'); - const room = new RoomId('!test:localhost'); + const user = new UserId("@alice:example.org"); + const device = new DeviceId("JLAFKJWSCS"); + const room = new RoomId("!test:localhost"); beforeAll(async () => { m = await machine(user, device); }); - test('can inject devices from someone else', async () => { + test("can inject devices from someone else", async () => { { const hypothetical_response = JSON.stringify({ - "device_keys": { + device_keys: { "@example:morpheus.localhost": { - "ATRLDCRXAC": { - "algorithms": [ - "m.olm.v1.curve25519-aes-sha2", - 
"m.megolm.v1.aes-sha2" - ], - "device_id": "ATRLDCRXAC", - "keys": { + ATRLDCRXAC: { + algorithms: ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + device_id: "ATRLDCRXAC", + keys: { "curve25519:ATRLDCRXAC": "cAVT5Es3Z3F5pFD+2w3HT7O9+R3PstzYVkzD51X/FWQ", - "ed25519:ATRLDCRXAC": "V2w/T/x7i7AXiCCtS6JldrpbvRliRoef3CqTUNqMRHA" + "ed25519:ATRLDCRXAC": "V2w/T/x7i7AXiCCtS6JldrpbvRliRoef3CqTUNqMRHA", }, - "signatures": { + signatures: { "@example:morpheus.localhost": { - "ed25519:ATRLDCRXAC": "ro2BjO5J6089B/JOANHnFmGrogrC2TIdMlgJbJO00DjOOcGxXfvOezCFIORTwZNHvkHU617YIGl/4keTDIWvBQ" - } + "ed25519:ATRLDCRXAC": + "ro2BjO5J6089B/JOANHnFmGrogrC2TIdMlgJbJO00DjOOcGxXfvOezCFIORTwZNHvkHU617YIGl/4keTDIWvBQ", + }, + }, + user_id: "@example:morpheus.localhost", + unsigned: { + device_display_name: "Element Desktop: Linux", }, - "user_id": "@example:morpheus.localhost", - "unsigned": { - "device_display_name": "Element Desktop: Linux" - } }, - "EYYGYTCTNC": { - "algorithms": [ - "m.olm.v1.curve25519-aes-sha2", - "m.megolm.v1.aes-sha2" - ], - "device_id": "EYYGYTCTNC", - "keys": { + EYYGYTCTNC: { + algorithms: ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + device_id: "EYYGYTCTNC", + keys: { "curve25519:EYYGYTCTNC": "Pqu50fo472wgb6NjKkaUxjuqoAIEAmhln2gw/zSQ7Ek", - "ed25519:EYYGYTCTNC": "Pf/2QPvui8lDty6TCTglVPRVM+irNHYavNNkyv5yFpU" + "ed25519:EYYGYTCTNC": "Pf/2QPvui8lDty6TCTglVPRVM+irNHYavNNkyv5yFpU", }, - "signatures": { + signatures: { "@example:morpheus.localhost": { - "ed25519:EYYGYTCTNC": "pnP5BYLEUUaxDgrvdzCznkjNDbvY1/MFBr1JejdnLiXlcmxRULQpIWZUCO7QTbULsCwMsYQNGn50nfmjBQX3CQ" - } + "ed25519:EYYGYTCTNC": + "pnP5BYLEUUaxDgrvdzCznkjNDbvY1/MFBr1JejdnLiXlcmxRULQpIWZUCO7QTbULsCwMsYQNGn50nfmjBQX3CQ", + }, + }, + user_id: "@example:morpheus.localhost", + unsigned: { + device_display_name: "WeeChat-Matrix-rs", }, - "user_id": "@example:morpheus.localhost", - "unsigned": { - "device_display_name": "WeeChat-Matrix-rs" - } }, - "SUMODVLSIU": { - "algorithms": [ - 
"m.olm.v1.curve25519-aes-sha2", - "m.megolm.v1.aes-sha2" - ], - "device_id": "SUMODVLSIU", - "keys": { + SUMODVLSIU: { + algorithms: ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + device_id: "SUMODVLSIU", + keys: { "curve25519:SUMODVLSIU": "geQXWGWc++gcUHk0JcFmEVSjyzDOnk2mjVsUQwbNqQU", - "ed25519:SUMODVLSIU": "ccktaQ3g+B18E6FwVhTBYie26OlHbvDUzDEtxOQ4Qcs" + "ed25519:SUMODVLSIU": "ccktaQ3g+B18E6FwVhTBYie26OlHbvDUzDEtxOQ4Qcs", }, - "signatures": { + signatures: { "@example:morpheus.localhost": { - "ed25519:SUMODVLSIU": "Yn+AOxHRt1GQpY2xT2Jcqqn8jh5+Vw23ctA7NXyDiWPsLPLNTpjGWHMjZdpUqflQvpiKfhODPICoIa7Pu0iSAg", - "ed25519:rUiMNDjIu6gqsrhJPbj3phyIzuEtuQGrLOEa9mCbtTM": "Cio6k/sq289XNTOvTCWre7Q6zg+A3euzMUe7Uy1T3gPqYFzX+kt7EAxrhbPqx1HyXAEz9zD0D/uw9VEXFCvWBQ" - } + "ed25519:SUMODVLSIU": + "Yn+AOxHRt1GQpY2xT2Jcqqn8jh5+Vw23ctA7NXyDiWPsLPLNTpjGWHMjZdpUqflQvpiKfhODPICoIa7Pu0iSAg", + "ed25519:rUiMNDjIu6gqsrhJPbj3phyIzuEtuQGrLOEa9mCbtTM": + "Cio6k/sq289XNTOvTCWre7Q6zg+A3euzMUe7Uy1T3gPqYFzX+kt7EAxrhbPqx1HyXAEz9zD0D/uw9VEXFCvWBQ", + }, + }, + user_id: "@example:morpheus.localhost", + unsigned: { + device_display_name: "Element Desktop (Linux)", }, - "user_id": "@example:morpheus.localhost", - "unsigned": { - "device_display_name": "Element Desktop (Linux)" - } - } - } + }, + }, }, - "failures": {}, - "master_keys": { + failures: {}, + master_keys: { "@example:morpheus.localhost": { - "user_id": "@example:morpheus.localhost", - "usage": [ - "master" - ], - "keys": { - "ed25519:ZzU4WCyBfOFitdGmfKCq6F39iQCDk/zhNNTsi+tWH7A": "ZzU4WCyBfOFitdGmfKCq6F39iQCDk/zhNNTsi+tWH7A" + user_id: "@example:morpheus.localhost", + usage: ["master"], + keys: { + "ed25519:ZzU4WCyBfOFitdGmfKCq6F39iQCDk/zhNNTsi+tWH7A": + "ZzU4WCyBfOFitdGmfKCq6F39iQCDk/zhNNTsi+tWH7A", }, - "signatures": { + signatures: { "@example:morpheus.localhost": { - "ed25519:SUMODVLSIU": "RL6WOuuzB/mZ+edfUFG/KeEcmKh+NaWpM6m2bUYmDnJrtTCYyoU+pgHJuL2/6nynemmONo18JEHBuqtNcMq2AQ" - } - } - } + "ed25519:SUMODVLSIU": + 
"RL6WOuuzB/mZ+edfUFG/KeEcmKh+NaWpM6m2bUYmDnJrtTCYyoU+pgHJuL2/6nynemmONo18JEHBuqtNcMq2AQ", + }, + }, + }, }, - "self_signing_keys": { + self_signing_keys: { "@example:morpheus.localhost": { - "user_id": "@example:morpheus.localhost", - "usage": [ - "self_signing" - ], - "keys": { - "ed25519:rUiMNDjIu6gqsrhJPbj3phyIzuEtuQGrLOEa9mCbtTM": "rUiMNDjIu6gqsrhJPbj3phyIzuEtuQGrLOEa9mCbtTM" + user_id: "@example:morpheus.localhost", + usage: ["self_signing"], + keys: { + "ed25519:rUiMNDjIu6gqsrhJPbj3phyIzuEtuQGrLOEa9mCbtTM": + "rUiMNDjIu6gqsrhJPbj3phyIzuEtuQGrLOEa9mCbtTM", }, - "signatures": { + signatures: { "@example:morpheus.localhost": { - "ed25519:ZzU4WCyBfOFitdGmfKCq6F39iQCDk/zhNNTsi+tWH7A": "uCBn9rpeg6umY8H97ejN26UMp6QDwNL98869t1DoVGL50J8adLN05OZd8lYk9QzwTr2d56ZTGYSYX8kv28SDDA" - } - } - } + "ed25519:ZzU4WCyBfOFitdGmfKCq6F39iQCDk/zhNNTsi+tWH7A": + "uCBn9rpeg6umY8H97ejN26UMp6QDwNL98869t1DoVGL50J8adLN05OZd8lYk9QzwTr2d56ZTGYSYX8kv28SDDA", + }, + }, + }, }, - "user_signing_keys": { + user_signing_keys: { "@example:morpheus.localhost": { - "user_id": "@example:morpheus.localhost", - "usage": [ - "user_signing" - ], - "keys": { - "ed25519:GLhEKLQ50jnF6IMEPsO2ucpHUNIUEnbBXs5gYbHg4Aw": "GLhEKLQ50jnF6IMEPsO2ucpHUNIUEnbBXs5gYbHg4Aw" + user_id: "@example:morpheus.localhost", + usage: ["user_signing"], + keys: { + "ed25519:GLhEKLQ50jnF6IMEPsO2ucpHUNIUEnbBXs5gYbHg4Aw": + "GLhEKLQ50jnF6IMEPsO2ucpHUNIUEnbBXs5gYbHg4Aw", }, - "signatures": { + signatures: { "@example:morpheus.localhost": { - "ed25519:ZzU4WCyBfOFitdGmfKCq6F39iQCDk/zhNNTsi+tWH7A": "4fIyWlVzuz1pgoegNLZASycORXqKycVS0dNq5vmmwsVEudp1yrPhndnaIJ3fjF8LDHvwzXTvohOid7DiU1j0AA" - } - } - } - } + "ed25519:ZzU4WCyBfOFitdGmfKCq6F39iQCDk/zhNNTsi+tWH7A": + "4fIyWlVzuz1pgoegNLZASycORXqKycVS0dNq5vmmwsVEudp1yrPhndnaIJ3fjF8LDHvwzXTvohOid7DiU1j0AA", + }, + }, + }, + }, }); - const marked = await m.markRequestAsSent('foo', RequestType.KeysQuery, hypothetical_response); + const marked = await m.markRequestAsSent("foo", RequestType.KeysQuery, 
hypothetical_response); } }); - test('can start an in-room SAS verification', async () => { + test("can start an in-room SAS verification", async () => { let _ = m.bootstrapCrossSigning(true); - const identity = await m.getIdentity(new UserId('@example:morpheus.localhost')); + const identity = await m.getIdentity(new UserId("@example:morpheus.localhost")); expect(identity).toBeInstanceOf(UserIdentity); expect(identity.isVerified()).toStrictEqual(false); - const eventId = new EventId('$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg'); + const eventId = new EventId("$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg"); const verificationRequest = await identity.requestVerification(room, eventId); expect(verificationRequest).toBeInstanceOf(VerificationRequest); await m.receiveVerificationEvent( JSON.stringify({ - "sender": "@example:morpheus.localhost", - "type": "m.key.verification.ready", - "event_id": "$QguWmaeMt6Hao7Ea6XHDInvr8ndknev79t9a2eBxlz0", - "origin_server_ts": 1674037263075, - "content": { - "methods": [ - "m.sas.v1", - "m.qr_code.show.v1", - "m.reciprocate.v1" - ], + sender: "@example:morpheus.localhost", + type: "m.key.verification.ready", + event_id: "$QguWmaeMt6Hao7Ea6XHDInvr8ndknev79t9a2eBxlz0", + origin_server_ts: 1674037263075, + content: { + "methods": ["m.sas.v1", "m.qr_code.show.v1", "m.reciprocate.v1"], "from_device": "SUMODVLSIU", "m.relates_to": { - "rel_type": "m.reference", - "event_id": eventId.toString(), - } - } + rel_type: "m.reference", + event_id: eventId.toString(), + }, + }, }), - room + room, ); expect(verificationRequest.roomId.toString()).toStrictEqual(room.toString()); @@ -802,22 +826,22 @@ describe(OlmMachine.name, () => { expect(outgoingVerificationRequest.id).toBeDefined(); expect(outgoingVerificationRequest.room_id).toStrictEqual(room.toString()); expect(outgoingVerificationRequest.txn_id).toBeDefined(); - expect(outgoingVerificationRequest.event_type).toStrictEqual('m.key.verification.start'); + 
expect(outgoingVerificationRequest.event_type).toStrictEqual("m.key.verification.start"); expect(outgoingVerificationRequest.body).toBeDefined(); const body = JSON.parse(outgoingVerificationRequest.body); expect(body).toMatchObject({ - from_device: expect.any(String), - method: 'm.sas.v1', - key_agreement_protocols: [expect.any(String)], - hashes: [expect.any(String)], - message_authentication_codes: [expect.any(String), expect.any(String)], - short_authentication_string: ['decimal', 'emoji'], - 'm.relates_to': { - rel_type: 'm.reference', + "from_device": expect.any(String), + "method": "m.sas.v1", + "key_agreement_protocols": [expect.any(String)], + "hashes": [expect.any(String)], + "message_authentication_codes": [expect.any(String), expect.any(String)], + "short_authentication_string": ["decimal", "emoji"], + "m.relates_to": { + rel_type: "m.reference", event_id: eventId.toString(), - } + }, }); - }) + }); }); }); diff --git a/bindings/matrix-sdk-crypto-js/tests/requests.test.js b/bindings/matrix-sdk-crypto-js/tests/requests.test.js index a69b7b39224..bb216550ec2 100644 --- a/bindings/matrix-sdk-crypto-js/tests/requests.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/requests.test.js @@ -1,7 +1,16 @@ -const { RequestType, KeysUploadRequest, KeysQueryRequest, KeysClaimRequest, ToDeviceRequest, SignatureUploadRequest, RoomMessageRequest, KeysBackupRequest } = require('../pkg/matrix_sdk_crypto_js'); +const { + RequestType, + KeysUploadRequest, + KeysQueryRequest, + KeysClaimRequest, + ToDeviceRequest, + SignatureUploadRequest, + RoomMessageRequest, + KeysBackupRequest, +} = require("../pkg/matrix_sdk_crypto_js"); -describe('RequestType', () => { - test('has the correct variant values', () => { +describe("RequestType", () => { + test("has the correct variant values", () => { expect(RequestType.KeysUpload).toStrictEqual(0); expect(RequestType.KeysQuery).toStrictEqual(1); expect(RequestType.KeysClaim).toStrictEqual(2); diff --git 
a/bindings/matrix-sdk-crypto-js/tests/sync_events.test.js b/bindings/matrix-sdk-crypto-js/tests/sync_events.test.js index 305d22556a5..f3acd127cc1 100644 --- a/bindings/matrix-sdk-crypto-js/tests/sync_events.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/sync_events.test.js @@ -1,7 +1,7 @@ -const { DeviceLists, UserId } = require('../pkg/matrix_sdk_crypto_js'); +const { DeviceLists, UserId } = require("../pkg/matrix_sdk_crypto_js"); describe(DeviceLists.name, () => { - test('can be empty', () => { + test("can be empty", () => { const empty = new DeviceLists(); expect(empty.isEmpty()).toStrictEqual(true); @@ -9,7 +9,7 @@ describe(DeviceLists.name, () => { expect(empty.left).toHaveLength(0); }); - test('can be coerced empty', () => { + test("can be coerced empty", () => { const empty = new DeviceLists([], []); expect(empty.isEmpty()).toStrictEqual(true); @@ -17,15 +17,15 @@ describe(DeviceLists.name, () => { expect(empty.left).toHaveLength(0); }); - test('returns the correct `changed` and `left`', () => { - const list = new DeviceLists([new UserId('@foo:bar.org')], [new UserId('@baz:qux.org')]); + test("returns the correct `changed` and `left`", () => { + const list = new DeviceLists([new UserId("@foo:bar.org")], [new UserId("@baz:qux.org")]); expect(list.isEmpty()).toStrictEqual(false); expect(list.changed).toHaveLength(1); - expect(list.changed[0].toString()).toStrictEqual('@foo:bar.org'); + expect(list.changed[0].toString()).toStrictEqual("@foo:bar.org"); expect(list.left).toHaveLength(1); - expect(list.left[0].toString()).toStrictEqual('@baz:qux.org'); + expect(list.left[0].toString()).toStrictEqual("@baz:qux.org"); }); }); diff --git a/bindings/matrix-sdk-crypto-js/tests/tracing.test.js b/bindings/matrix-sdk-crypto-js/tests/tracing.test.js index a37383f24cb..9c36dfa4043 100644 --- a/bindings/matrix-sdk-crypto-js/tests/tracing.test.js +++ b/bindings/matrix-sdk-crypto-js/tests/tracing.test.js @@ -1,7 +1,7 @@ -const { Tracing, LoggerLevel, OlmMachine, UserId, 
DeviceId } = require('../pkg/matrix_sdk_crypto_js'); +const { Tracing, LoggerLevel, OlmMachine, UserId, DeviceId } = require("../pkg/matrix_sdk_crypto_js"); -describe('LoggerLevel', () => { - test('has the correct variant values', () => { +describe("LoggerLevel", () => { + test("has the correct variant values", () => { expect(LoggerLevel.Trace).toStrictEqual(0); expect(LoggerLevel.Debug).toStrictEqual(1); expect(LoggerLevel.Info).toStrictEqual(2); @@ -14,7 +14,7 @@ describe(Tracing.name, () => { if (Tracing.isAvailable()) { let tracing = new Tracing(LoggerLevel.Debug); - test('can installed several times', () => { + test("can installed several times", () => { new Tracing(LoggerLevel.Debug); new Tracing(LoggerLevel.Warn); new Tracing(LoggerLevel.Debug); @@ -23,27 +23,30 @@ describe(Tracing.name, () => { const originalConsoleDebug = console.debug; for (const [testName, testPreState, testPostState, expectedGotcha] of [ + ["can log something", () => {}, () => {}, true], [ - 'can log something', - () => {}, - () => {}, - true, - ], - [ - 'can change the logger level', - () => { tracing.minLevel = LoggerLevel.Warn }, - () => { tracing.minLevel = LoggerLevel.Debug }, + "can change the logger level", + () => { + tracing.minLevel = LoggerLevel.Warn; + }, + () => { + tracing.minLevel = LoggerLevel.Debug; + }, false, ], [ - 'can be turned off', - () => { tracing.turnOff() }, + "can be turned off", + () => { + tracing.turnOff(); + }, () => {}, false, ], [ - 'can be turned on', - () => { tracing.turnOn() }, + "can be turned on", + () => { + tracing.turnOn(); + }, () => {}, true, ], @@ -51,8 +54,10 @@ describe(Tracing.name, () => { // This one *must* be the last. We are turning tracing off // again for the other tests. [ - 'can be turned off', - () => { tracing.turnOff() }, + "can be turned off", + () => { + tracing.turnOff(); + }, () => {}, false, ], @@ -68,7 +73,7 @@ describe(Tracing.name, () => { }; // Do something that emits a `DEBUG` log. 
- await OlmMachine.initialize(new UserId('@alice:example.org'), new DeviceId('foo')); + await OlmMachine.initialize(new UserId("@alice:example.org"), new DeviceId("foo")); console.debug = originalConsoleDebug; testPostState(); @@ -77,8 +82,10 @@ describe(Tracing.name, () => { }); } } else { - test('cannot be constructed', () => { - expect(() => { new Tracing(LoggerLevel.Error) }).toThrow(); + test("cannot be constructed", () => { + expect(() => { + new Tracing(LoggerLevel.Error); + }).toThrow(); }); } }); diff --git a/bindings/matrix-sdk-crypto-js/tsconfig.json b/bindings/matrix-sdk-crypto-js/tsconfig.json index cca9bfa1968..2bcd353240e 100644 --- a/bindings/matrix-sdk-crypto-js/tsconfig.json +++ b/bindings/matrix-sdk-crypto-js/tsconfig.json @@ -5,6 +5,6 @@ "typedocOptions": { "entryPoints": ["pkg/matrix_sdk_crypto_js.d.ts"], "out": "docs", - "readme": "README.md", + "readme": "README.md" } } diff --git a/bindings/matrix-sdk-crypto-nodejs/.gitignore b/bindings/matrix-sdk-crypto-nodejs/.gitignore index f1fbf26ec7c..4dc1f759280 100644 --- a/bindings/matrix-sdk-crypto-nodejs/.gitignore +++ b/bindings/matrix-sdk-crypto-nodejs/.gitignore @@ -4,4 +4,4 @@ /index.d.ts /matrix-sdk-crypto.*.node /docs/* -*.tgz \ No newline at end of file +*.tgz diff --git a/bindings/matrix-sdk-crypto-nodejs/.npmignore b/bindings/matrix-sdk-crypto-nodejs/.npmignore index 3e3d54025bf..cebcc362e6a 100644 --- a/bindings/matrix-sdk-crypto-nodejs/.npmignore +++ b/bindings/matrix-sdk-crypto-nodejs/.npmignore @@ -5,4 +5,3 @@ build.rs *.node *.tgz tsconfig.json -cliff.toml \ No newline at end of file diff --git a/bindings/matrix-sdk-crypto-nodejs/.prettierignore b/bindings/matrix-sdk-crypto-nodejs/.prettierignore new file mode 100644 index 00000000000..fc5bd908320 --- /dev/null +++ b/bindings/matrix-sdk-crypto-nodejs/.prettierignore @@ -0,0 +1 @@ +/pkg diff --git a/bindings/matrix-sdk-crypto-nodejs/.prettierrc.js b/bindings/matrix-sdk-crypto-nodejs/.prettierrc.js new file mode 100644 index 
00000000000..f739c10be90 --- /dev/null +++ b/bindings/matrix-sdk-crypto-nodejs/.prettierrc.js @@ -0,0 +1,9 @@ +// prettier configuration: the same as the conventions used throughout Matrix.org +// see: https://github.com/matrix-org/eslint-plugin-matrix-org/blob/main/.prettierrc.js + +module.exports = { + printWidth: 120, + tabWidth: 4, + quoteProps: "consistent", + trailingComma: "all", +}; diff --git a/bindings/matrix-sdk-crypto-nodejs/CHANGELOG.md b/bindings/matrix-sdk-crypto-nodejs/CHANGELOG.md index 2d83a8260ac..ad372b52f69 100644 --- a/bindings/matrix-sdk-crypto-nodejs/CHANGELOG.md +++ b/bindings/matrix-sdk-crypto-nodejs/CHANGELOG.md @@ -2,7 +2,7 @@ ## 0.1.0-beta.1 - 2022-07-14 -- Fixing broken download link, [#842](https://github.com/matrix-org/matrix-rust-sdk/issues/842) +- Fixing broken download link, [#842](https://github.com/matrix-org/matrix-rust-sdk/issues/842) ## 0.1.0-beta.0 - 2022-07-12 diff --git a/bindings/matrix-sdk-crypto-nodejs/Cargo.toml b/bindings/matrix-sdk-crypto-nodejs/Cargo.toml index 85e965b3b8a..15d4eca0d20 100644 --- a/bindings/matrix-sdk-crypto-nodejs/Cargo.toml +++ b/bindings/matrix-sdk-crypto-nodejs/Cargo.toml @@ -26,6 +26,7 @@ tracing = ["dep:tracing-subscriber"] matrix-sdk-crypto = { version = "0.6.0", path = "../../crates/matrix-sdk-crypto", features = ["js"] } matrix-sdk-common = { version = "0.6.0", path = "../../crates/matrix-sdk-common", features = ["js"] } matrix-sdk-sled = { version = "0.2.0", path = "../../crates/matrix-sdk-sled", default-features = false, features = ["crypto-store"] } +matrix-sdk-sqlite = { version = "0.1.0", path = "../../crates/matrix-sdk-sqlite", features = ["crypto-store"] } ruma = { workspace = true, features = ["rand", "unstable-msc2677"] } napi = { version = "2.9.1", default-features = false, features = ["napi6", "tokio_rt"] } napi-derive = "2.9.1" diff --git a/bindings/matrix-sdk-crypto-nodejs/README.md b/bindings/matrix-sdk-crypto-nodejs/README.md index 51bed6f2ed5..7e7b80ff1e1 100644 --- 
a/bindings/matrix-sdk-crypto-nodejs/README.md +++ b/bindings/matrix-sdk-crypto-nodejs/README.md @@ -12,6 +12,7 @@ Encryption](https://en.wikipedia.org/wiki/End-to-end_encryption)) for ## Usage Just add the latest release to your `package.json`: + ```sh $ npm install --save @matrix-org/matrix-sdk-crypto-nodejs ``` @@ -112,28 +113,27 @@ generated. At the same level of those files, you can edit a file and try this: ```javascript -const { OlmMachine } = require('./index.js'); +const { OlmMachine } = require("./index.js"); // Let's see what we can do. ``` The `OlmMachine` state machine works in a push/pull manner: -* You push state changes and events retrieved from a Matrix homeserver - `/sync` response, into the state machine, - -* You pull requests that you will need to send back to the homeserver - out of the state machine. - +- You push state changes and events retrieved from a Matrix homeserver + `/sync` response, into the state machine, +- You pull requests that you will need to send back to the homeserver + out of the state machine. + ```javascript -const { OlmMachine, UserId, DeviceId, RoomId, DeviceLists } = require('./index.js'); +const { OlmMachine, UserId, DeviceId, RoomId, DeviceLists } = require("./index.js"); async function main() { // Define a user ID. - const alice = new UserId('@alice:example.org'); + const alice = new UserId("@alice:example.org"); // Define a device ID. - const device = new DeviceId('DEVICEID'); + const device = new DeviceId("DEVICEID"); // Let's create the `OlmMachine` state machine. const machine = await OlmMachine.initialize(alice, device); @@ -198,8 +198,6 @@ $ npm run doc The documentation is generated in the `./docs` directory. 
- - [Node.js]: https://nodejs.org/ [`matrix-sdk-crypto`]: https://github.com/matrix-org/matrix-rust-sdk/tree/main/crates/matrix-sdk-crypto [`matrix-rust-sdk`]: https://github.com/matrix-org/matrix-rust-sdk diff --git a/bindings/matrix-sdk-crypto-nodejs/cliff.toml b/bindings/matrix-sdk-crypto-nodejs/cliff.toml deleted file mode 100644 index f2c005219b2..00000000000 --- a/bindings/matrix-sdk-crypto-nodejs/cliff.toml +++ /dev/null @@ -1,61 +0,0 @@ -# configuration file for git-cliff (0.1.0) - -[changelog] -# changelog header -header = """ -# Matrix SDK Crypto Node.js Changelog\n -All notable changes to this project will be documented in this file.\n -""" -# template for the changelog body -# https://tera.netlify.app/docs/#introduction -body = """ -{% if version %}\ - ## [{{ version | trim_start_matches(pat="v") }}] - {{ timestamp | date(format="%Y-%m-%d") }} -{% else %}\ - ## [unreleased] -{% endif %}\ -{% for group, commits in commits | filter(attribute="scope", value="crypto-nodejs") | group_by(attribute="group") %} - ### {{ group | upper_first }} - {% for commit in commits %} - - {{ commit.id | truncate(length=7, end="") }}{% if commit.breaking %} [**breaking**] {% endif %}: {{ commit.message | upper_first }}\ - {% endfor %} -{% endfor %}\n -""" -# remove the leading and trailing whitespace from the template -trim = true -# changelog footer -footer = """ -""" - -[git] -# parse the commits based on https://www.conventionalcommits.org -conventional_commits = true -# filter out the commits that are not conventional -filter_unconventional = true -# regex for preprocessing the commit messages -commit_preprocessors = [ - { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/matrix-org/matrix-rust-sdk/issues/${2}))"}, -] -# regex for parsing and grouping commits -commit_parsers = [ - { message = "^feat", group = "Features"}, - { message = "^fix", group = "Bug Fixes"}, - { message = "^test", group = "Testing"}, - { message = "^doc", group = 
"Documentation"}, - { message = "^refactor", group = "Refactoring"}, - { message = "^ci", group = "Continuous Integration"}, - { message = "^chore", group = "Miscellaneous Tasks"}, - { body = ".*security", group = "Security"}, -] -# filter out the commits that are not matched by commit parsers -filter_commits = true -# glob pattern for matching git tags -tag_pattern = "v[0-9]*" -# regex for skipping tags -skip_tags = "v0.1.0-beta.1" -# regex for ignoring tags -ignore_tags = "" -# sort the tags chronologically -date_order = false -# sort the commits inside sections by oldest/newest order -sort_commits = "oldest" diff --git a/bindings/matrix-sdk-crypto-nodejs/download-lib.js b/bindings/matrix-sdk-crypto-nodejs/download-lib.js index 6e360db133c..c23719e301d 100644 --- a/bindings/matrix-sdk-crypto-nodejs/download-lib.js +++ b/bindings/matrix-sdk-crypto-nodejs/download-lib.js @@ -1,19 +1,18 @@ -const { HttpsProxyAgent } = require('https-proxy-agent'); -const { DownloaderHelper } = require('node-downloader-helper'); +const { HttpsProxyAgent } = require("https-proxy-agent"); +const { DownloaderHelper } = require("node-downloader-helper"); const { version } = require("./package.json"); -const { platform, arch } = process +const { platform, arch } = process; const DOWNLOADS_BASE_URL = "https://github.com/matrix-org/matrix-rust-sdk/releases/download"; const CURRENT_VERSION = `matrix-sdk-crypto-nodejs-v${version}`; const byteHelper = function (value) { if (value === 0) { - return '0 b'; + return "0 b"; } - const units = ['b', 'kB', 'MB', 'GB', 'TB']; + const units = ["b", "kB", "MB", "GB", "TB"]; const number = Math.floor(Math.log(value) / Math.log(1024)); - return (value / Math.pow(1024, Math.floor(number))).toFixed(1) + ' ' + - units[number]; + return (value / Math.pow(1024, Math.floor(number))).toFixed(1) + " " + units[number]; }; function download_lib(libname) { @@ -33,9 +32,9 @@ function download_lib(libname) { }); } - dl.on('end', () => console.info('Download 
Completed')); - dl.on('error', (err) => console.info('Download Failed', err)); - dl.on('progress', stats => { + dl.on("end", () => console.info("Download Completed")); + dl.on("error", (err) => console.info("Download Failed", err)); + dl.on("progress", (stats) => { const progress = stats.progress.toFixed(1); const speed = byteHelper(stats.speed); const downloaded = byteHelper(stats.downloaded); @@ -49,74 +48,74 @@ function download_lib(libname) { console.info(`${speed}/s - ${progress}% [${downloaded}/${total}]`); } }); - dl.start().catch(err => console.error(err)); + dl.start().catch((err) => console.error(err)); } function isMusl() { - // For Node 10 - if (!process.report || typeof process.report.getReport !== 'function') { - try { - return readFileSync('/usr/bin/ldd', 'utf8').includes('musl') - } catch (e) { - return true + // For Node 10 + if (!process.report || typeof process.report.getReport !== "function") { + try { + return readFileSync("/usr/bin/ldd", "utf8").includes("musl"); + } catch (e) { + return true; + } + } else { + const { glibcVersionRuntime } = process.report.getReport().header; + return !glibcVersionRuntime; } - } else { - const { glibcVersionRuntime } = process.report.getReport().header - return !glibcVersionRuntime - } } switch (platform) { - case 'win32': - switch (arch) { - case 'x64': - download_lib('matrix-sdk-crypto.win32-x64-msvc.node') - break - case 'ia32': - download_lib('matrix-sdk-crypto.win32-ia32-msvc.node') - break - case 'arm64': - download_lib('matrix-sdk-crypto.win32-arm64-msvc.node') - break - default: - throw new Error(`Unsupported architecture on Windows: ${arch}`) - } - break - case 'darwin': - switch (arch) { - case 'x64': - download_lib('matrix-sdk-crypto.darwin-x64.node') - break - case 'arm64': - download_lib('matrix-sdk-crypto.darwin-arm64.node') - break - default: - throw new Error(`Unsupported architecture on macOS: ${arch}`) - } - break - case 'linux': - switch (arch) { - case 'x64': - if (isMusl()) { - 
download_lib('matrix-sdk-crypto.linux-x64-musl.node') - } else { - download_lib('matrix-sdk-crypto.linux-x64-gnu.node') + case "win32": + switch (arch) { + case "x64": + download_lib("matrix-sdk-crypto.win32-x64-msvc.node"); + break; + case "ia32": + download_lib("matrix-sdk-crypto.win32-ia32-msvc.node"); + break; + case "arm64": + download_lib("matrix-sdk-crypto.win32-arm64-msvc.node"); + break; + default: + throw new Error(`Unsupported architecture on Windows: ${arch}`); } - break - case 'arm64': - if (isMusl()) { - throw new Error('Linux for arm64 musl isn\'t support at the moment') - } else { - download_lib('matrix-sdk-crypto.linux-arm64-gnu.node') + break; + case "darwin": + switch (arch) { + case "x64": + download_lib("matrix-sdk-crypto.darwin-x64.node"); + break; + case "arm64": + download_lib("matrix-sdk-crypto.darwin-arm64.node"); + break; + default: + throw new Error(`Unsupported architecture on macOS: ${arch}`); } - break - case 'arm': - download_lib('matrix-sdk-crypto.linux-arm-gnueabihf.node') - break - default: - throw new Error(`Unsupported architecture on Linux: ${arch}`) - } - break - default: - throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`) + break; + case "linux": + switch (arch) { + case "x64": + if (isMusl()) { + download_lib("matrix-sdk-crypto.linux-x64-musl.node"); + } else { + download_lib("matrix-sdk-crypto.linux-x64-gnu.node"); + } + break; + case "arm64": + if (isMusl()) { + throw new Error("Linux for arm64 musl isn't support at the moment"); + } else { + download_lib("matrix-sdk-crypto.linux-arm64-gnu.node"); + } + break; + case "arm": + download_lib("matrix-sdk-crypto.linux-arm-gnueabihf.node"); + break; + default: + throw new Error(`Unsupported architecture on Linux: ${arch}`); + } + break; + default: + throw new Error(`Unsupported OS: ${platform}, architecture: ${arch}`); } diff --git a/bindings/matrix-sdk-crypto-nodejs/package.json b/bindings/matrix-sdk-crypto-nodejs/package.json index d858158e25c..e7e3ed1feb3 
100644 --- a/bindings/matrix-sdk-crypto-nodejs/package.json +++ b/bindings/matrix-sdk-crypto-nodejs/package.json @@ -15,6 +15,7 @@ "devDependencies": { "@napi-rs/cli": "^2.9.0", "jest": "^28.1.0", + "prettier": "^2.8.3", "typedoc": "^0.22.17", "yargs-parser": "~21.0.1" }, @@ -22,6 +23,7 @@ "node": ">= 14" }, "scripts": { + "lint": "prettier --check .", "release-build": "napi build --platform --release --strip", "build": "napi build --platform", "postinstall": "node download-lib.js", diff --git a/bindings/matrix-sdk-crypto-nodejs/src/encryption.rs b/bindings/matrix-sdk-crypto-nodejs/src/encryption.rs index 8b30a3bd6ba..240a217295e 100644 --- a/bindings/matrix-sdk-crypto-nodejs/src/encryption.rs +++ b/bindings/matrix-sdk-crypto-nodejs/src/encryption.rs @@ -1,5 +1,6 @@ use std::time::Duration; +use matrix_sdk_common::deserialized_responses::ShieldState as RustShieldState; use napi::bindgen_prelude::{BigInt, ToNapiValue}; use napi_derive::*; @@ -107,27 +108,33 @@ impl From<&EncryptionSettings> for matrix_sdk_crypto::olm::EncryptionSettings { } } -/// The verification state of the device that sent an event to us. +/// Take a look at [`matrix_sdk_common::deserialized_responses::ShieldState`] +/// for more info. #[napi] -pub enum VerificationState { - /// The device is trusted. - Trusted, - - /// The device is not trusted. - Untrusted, - - /// The device is not known to us. - UnknownDevice, +pub enum ShieldColor { + Red, + Grey, + None, } -impl From<&matrix_sdk_common::deserialized_responses::VerificationState> for VerificationState { - fn from(value: &matrix_sdk_common::deserialized_responses::VerificationState) -> Self { - use matrix_sdk_common::deserialized_responses::VerificationState::*; +/// Take a look at [`matrix_sdk_common::deserialized_responses::ShieldState`] +/// for more info. 
+#[napi] +pub struct ShieldState { + pub color: ShieldColor, + pub message: Option<&'static str>, +} +impl From for ShieldState { + fn from(value: RustShieldState) -> Self { match value { - Trusted => Self::Trusted, - Untrusted => Self::Untrusted, - UnknownDevice => Self::UnknownDevice, + RustShieldState::Red { message } => { + ShieldState { color: ShieldColor::Red, message: Some(message) } + } + RustShieldState::Grey { message } => { + ShieldState { color: ShieldColor::Grey, message: Some(message) } + } + RustShieldState::None => ShieldState { color: ShieldColor::None, message: None }, } } } diff --git a/bindings/matrix-sdk-crypto-nodejs/src/machine.rs b/bindings/matrix-sdk-crypto-nodejs/src/machine.rs index c40deb0243f..52ef0e8d58b 100644 --- a/bindings/matrix-sdk-crypto-nodejs/src/machine.rs +++ b/bindings/matrix-sdk-crypto-nodejs/src/machine.rs @@ -2,10 +2,12 @@ use std::{ collections::{BTreeMap, HashMap}, + mem::ManuallyDrop, + ops::Deref, sync::Arc, }; -use napi::bindgen_prelude::Either7; +use napi::bindgen_prelude::{within_runtime_if_available, Either7, ToNapiValue}; use napi_derive::*; use ruma::{serde::Raw, DeviceKeyAlgorithm, OwnedTransactionId, UInt}; use serde_json::{value::RawValue, Value as JsonValue}; @@ -16,11 +18,61 @@ use crate::{ sync_events, types, vodozemac, }; +/// The value used by the `OlmMachine` JS class. +/// +/// It has 2 states: `Opened` and `Closed`. Why maintaining the state here? +/// Because NodeJS has no way to drop an object explicitly, and we want to be +/// able to β€œclose” the `OlmMachine` to free all associated data. More over, +/// `napi-rs` doesn't allow a function to take the ownership of the type itself +/// (`fn close(self) { … }`). So we manage the state ourselves. +/// +/// Using the `OlmMachine` when its state is `Closed` will panic. 
+enum OlmMachineInner { + Opened(ManuallyDrop), + Closed, +} + +impl Drop for OlmMachineInner { + fn drop(&mut self) { + if let Self::Opened(machine) = self { + // SAFETY: `self` won't be used anymore after this `take`, so it's safe to do it + // here. + let machine = unsafe { ManuallyDrop::take(machine) }; + within_runtime_if_available(move || drop(machine)); + } + } +} + +impl Deref for OlmMachineInner { + type Target = matrix_sdk_crypto::OlmMachine; + + #[inline] + fn deref(&self) -> &Self::Target { + match self { + Self::Opened(machine) => machine, + Self::Closed => panic!("The `OlmMachine` has been closed, cannot use it anymore"), + } + } +} + +/// Represents the type of store an `OlmMachine` can use. +#[derive(Default)] +#[napi] +pub enum StoreType { + /// Use `matrix-sdk-sled`. + #[default] + Sled, + + /// Use `matrix-sdk-sqlite`. + Sqlite, +} + /// State machine implementation of the Olm/Megolm encryption protocol /// used for Matrix end to end encryption. +// #[napi(custom_finalize)] #[napi] pub struct OlmMachine { - inner: matrix_sdk_crypto::OlmMachine, + inner: OlmMachineInner, } #[napi] @@ -59,40 +111,56 @@ impl OlmMachine { device_id: &identifiers::DeviceId, store_path: Option, mut store_passphrase: Option, + store_type: Option, ) -> napi::Result { - let user_id = user_id.clone(); - let device_id = device_id.clone(); - - let store = if let Some(store_path) = store_path { - Some( - matrix_sdk_sled::SledCryptoStore::open(store_path, store_passphrase.as_deref()) - .await - .map(Arc::new) - .map_err(into_err)?, - ) - } else { - None - }; - - store_passphrase.zeroize(); + let user_id = user_id.clone().inner; + let device_id = device_id.clone().inner; + + let user_id = user_id.as_ref(); + let device_id = device_id.as_ref(); Ok(OlmMachine { - inner: match store { - Some(store) => matrix_sdk_crypto::OlmMachine::with_store( - user_id.inner.as_ref(), - device_id.inner.as_ref(), - store, - ) - .await - .map_err(into_err)?, - None => { - 
matrix_sdk_crypto::OlmMachine::new( - user_id.inner.as_ref(), - device_id.inner.as_ref(), - ) - .await + inner: OlmMachineInner::Opened(ManuallyDrop::new(match store_path { + Some(store_path) => { + let machine = match store_type.unwrap_or_default() { + StoreType::Sled => { + matrix_sdk_crypto::OlmMachine::with_store( + user_id, + device_id, + matrix_sdk_sled::SledCryptoStore::open( + store_path, + store_passphrase.as_deref(), + ) + .await + .map(Arc::new) + .map_err(into_err)?, + ) + .await + } + + StoreType::Sqlite => { + matrix_sdk_crypto::OlmMachine::with_store( + user_id, + device_id, + matrix_sdk_sqlite::SqliteCryptoStore::open( + store_path, + store_passphrase.as_deref(), + ) + .await + .map(Arc::new) + .map_err(into_err)?, + ) + .await + } + }; + + store_passphrase.zeroize(); + + machine.map_err(into_err)? } - }, + + None => matrix_sdk_crypto::OlmMachine::new(user_id, device_id).await, + })), }) } @@ -409,4 +477,21 @@ impl OlmMachine { pub async fn sign(&self, message: String) -> types::Signatures { self.inner.sign(message.as_str()).await.into() } + + /// Shut down the `OlmMachine`. + /// + /// The `OlmMachine` cannot be used after this method has been called, + /// otherwise it will panic. + /// + /// All associated resources will be closed too, like the crypto storage + /// connections. + /// + /// # Safety + /// + /// The caller is responsible to **not** use any objects that came from this + /// `OlmMachine` after this `close` method has been called. 
+ #[napi(strict)] + pub fn close(&mut self) { + self.inner = OlmMachineInner::Closed; + } } diff --git a/bindings/matrix-sdk-crypto-nodejs/src/requests.rs b/bindings/matrix-sdk-crypto-nodejs/src/requests.rs index 9d1e6376dd1..03c27cbacb4 100644 --- a/bindings/matrix-sdk-crypto-nodejs/src/requests.rs +++ b/bindings/matrix-sdk-crypto-nodejs/src/requests.rs @@ -8,10 +8,13 @@ use matrix_sdk_crypto::requests::{ }; use napi::bindgen_prelude::{Either7, ToNapiValue}; use napi_derive::*; -use ruma::api::client::keys::{ - claim_keys::v3::Request as RumaKeysClaimRequest, - upload_keys::v3::Request as RumaKeysUploadRequest, - upload_signatures::v3::Request as RumaSignatureUploadRequest, +use ruma::{ + api::client::keys::{ + claim_keys::v3::Request as RumaKeysClaimRequest, + upload_keys::v3::Request as RumaKeysUploadRequest, + upload_signatures::v3::Request as RumaSignatureUploadRequest, + }, + events::EventContent, }; use crate::into_err; @@ -28,11 +31,10 @@ pub struct KeysUploadRequest { #[napi(readonly)] pub id: String, - /// A JSON-encoded object of form: + /// A JSON-encoded string containing the rest of the payload: `device_keys`, + /// `one_time_keys`, `fallback_keys`. /// - /// ```json - /// {"device_keys": …, "one_time_keys": …, "fallback_keys": …} - /// ``` + /// It represents the body of the HTTP request. 
#[napi(readonly)] pub body: String, } @@ -61,7 +63,7 @@ pub struct KeysQueryRequest { /// A JSON-encoded object of form: /// /// ```json - /// {"timeout": …, "device_keys": …, "token": …} + /// {"timeout": …, "one_time_keys": …} /// ``` #[napi(readonly)] pub body: String, @@ -92,7 +94,7 @@ pub struct KeysClaimRequest { /// A JSON-encoded object of form: /// /// ```json - /// {"timeout": …, "one_time_keys": …} + /// {"event_type": …, "txn_id": …, "messages": …} /// ``` #[napi(readonly)] pub body: String, @@ -119,11 +121,18 @@ pub struct ToDeviceRequest { #[napi(readonly)] pub id: String, - /// A JSON-encoded object of form: + /// A string representing the type of event being sent to each devices. + #[napi(readonly)] + pub event_type: String, + + /// A string representing a request identifier unique to the access token + /// used to send the request. + #[napi(readonly)] + pub txn_id: String, + + /// A JSON-encoded string containing the rest of the payload: `messages`. /// - /// ```json - /// {"event_type": …, "txn_id": …, "messages": …} - /// ``` + /// It represents the body of the HTTP request. #[napi(readonly)] pub body: String, } @@ -149,11 +158,9 @@ pub struct SignatureUploadRequest { #[napi(readonly)] pub id: String, - /// A JSON-encoded object of form: + /// A JSON-encoded string containing the rest of the payload: `signed_keys`. /// - /// ```json - /// {"signed_keys": …, "txn_id": …, "messages": …} - /// ``` + /// It represents the body of the HTTP request. #[napi(readonly)] pub body: String, } @@ -177,13 +184,25 @@ pub struct RoomMessageRequest { #[napi(readonly)] pub id: String, - /// A JSON-encoded object of form: + /// A string representing the room to send the event to. + #[napi(readonly)] + pub room_id: String, + + /// A string representing the transaction ID for this event. 
/// - /// ```json - /// {"room_id": …, "txn_id": …, "content": …} - /// ``` + /// Clients should generate an ID unique across requests with the same + /// access token; it will be used by the server to ensure idempotency of + /// requests. #[napi(readonly)] - pub body: String, + pub txn_id: String, + + /// A string representing the type of event to be sent. + #[napi(readonly)] + pub event_type: String, + + /// A JSON-encoded string containing the message's content. + #[napi(readonly, js_name = "body")] + pub content: String, } #[napi] @@ -205,11 +224,9 @@ pub struct KeysBackupRequest { #[napi(readonly)] pub id: String, - /// A JSON-encoded object of form: + /// A JSON-encoded string containing the rest of the payload: `rooms`. /// - /// ```json - /// {"rooms": …} - /// ``` + /// It represents the body of the HTTP request. #[napi(readonly)] pub body: String, } @@ -224,43 +241,89 @@ impl KeysBackupRequest { } macro_rules! request { - ($request:ident from $ruma_request:ident maps fields $( $field:ident $( { $transformation:expr } )? ),+ $(,)? ) => { - impl TryFrom<(String, &$ruma_request)> for $request { + ( + $destination_request:ident from $source_request:ident + $( extracts $( $field_name:ident : $field_type:tt ),+ $(,)? )? + $( $( and )? groups $( $grouped_field_name:ident $( { $grouped_field_transformation:expr } )? ),+ $(,)? )? + ) => { + impl TryFrom<(String, &$source_request)> for $destination_request { type Error = napi::Error; fn try_from( - (request_id, request): (String, &$ruma_request), + (request_id, request): (String, &$source_request), ) -> Result { - let mut map = serde_json::Map::new(); + request!( + @__try_from $destination_request from $source_request + (request_id = request_id.into(), request = request) + $( extracts [ $( $field_name : $field_type, )+ ] )? + $( groups [ $( $grouped_field_name $( { $grouped_field_transformation } )? , )+ ] )? 
+ ) + } + } + }; + + ( + @__try_from $destination_request:ident from $source_request:ident + (request_id = $request_id:expr, request = $request:expr) + $( extracts [ $( $field_name:ident : $field_type:tt ),* $(,)? ] )? + $( groups [ $( $grouped_field_name:ident $( { $grouped_field_transformation:expr } )? ),* $(,)? ] )? + ) => { + { + Ok($destination_request { + id: $request_id, $( - let field = &request.$field; $( - let field = { - let $field = field; - - $transformation - }; - )? - map.insert(stringify!($field).to_owned(), serde_json::to_value(field).map_err(into_err)?); - )+ - let value = serde_json::Value::Object(map); - - Ok($request { - id: request_id, - body: serde_json::to_string(&value).map_err(into_err)?.into(), - }) - } + $field_name: request!(@__field $field_name : $field_type ; request = $request), + )* + )? + $( + body: { + let mut map = serde_json::Map::new(); + $( + + let field = &$request.$grouped_field_name; + $( + let field = { + let $grouped_field_name = field; + + $grouped_field_transformation + }; + )? + map.insert(stringify!($grouped_field_name).to_owned(), serde_json::to_value(field).map_err(into_err)?); + )* + let object = serde_json::Value::Object(map); + + serde_json::to_string(&object).map_err(into_err)?.into() + } + )? 
+ }) } }; + + ( @__field $field_name:ident : $field_type:ident ; request = $request:expr ) => { + request!(@__field_type as $field_type ; request = $request, field_name = $field_name) + }; + + ( @__field_type as string ; request = $request:expr, field_name = $field_name:ident ) => { + $request.$field_name.to_string().into() + }; + + ( @__field_type as json ; request = $request:expr, field_name = $field_name:ident ) => { + serde_json::to_string(&$request.$field_name).map_err(into_err)?.into() + }; + + ( @__field_type as event_type ; request = $request:expr, field_name = $field_name:ident ) => { + $request.content.event_type().to_string().into() + }; } -request!(KeysUploadRequest from RumaKeysUploadRequest maps fields device_keys, one_time_keys, fallback_keys); -request!(KeysQueryRequest from RumaKeysQueryRequest maps fields timeout { timeout.as_ref().map(Duration::as_millis).map(u64::try_from).transpose().map_err(into_err)? }, device_keys, token); -request!(KeysClaimRequest from RumaKeysClaimRequest maps fields timeout { timeout.as_ref().map(Duration::as_millis).map(u64::try_from).transpose().map_err(into_err)? }, one_time_keys); -request!(ToDeviceRequest from RumaToDeviceRequest maps fields event_type, txn_id, messages); -request!(SignatureUploadRequest from RumaSignatureUploadRequest maps fields signed_keys); -request!(RoomMessageRequest from RumaRoomMessageRequest maps fields room_id, txn_id, content); -request!(KeysBackupRequest from RumaKeysBackupRequest maps fields rooms); +request!(KeysUploadRequest from RumaKeysUploadRequest groups device_keys, one_time_keys, fallback_keys); +request!(KeysQueryRequest from RumaKeysQueryRequest groups timeout { timeout.as_ref().map(Duration::as_millis).map(u64::try_from).transpose().map_err(into_err)? }, device_keys, token); +request!(KeysClaimRequest from RumaKeysClaimRequest groups timeout { timeout.as_ref().map(Duration::as_millis).map(u64::try_from).transpose().map_err(into_err)? 
}, one_time_keys); +request!(ToDeviceRequest from RumaToDeviceRequest extracts event_type: string, txn_id: string and groups messages); +request!(SignatureUploadRequest from RumaSignatureUploadRequest groups signed_keys); +request!(RoomMessageRequest from RumaRoomMessageRequest extracts room_id: string, txn_id: string, event_type: event_type, content: json); +request!(KeysBackupRequest from RumaKeysBackupRequest groups rooms); pub type OutgoingRequests = Either7< KeysUploadRequest, diff --git a/bindings/matrix-sdk-crypto-nodejs/src/responses.rs b/bindings/matrix-sdk-crypto-nodejs/src/responses.rs index d6488f52865..14eb94c890e 100644 --- a/bindings/matrix-sdk-crypto-nodejs/src/responses.rs +++ b/bindings/matrix-sdk-crypto-nodejs/src/responses.rs @@ -1,5 +1,3 @@ -use std::borrow::Borrow; - use matrix_sdk_common::deserialized_responses::{AlgorithmInfo, EncryptionInfo}; use matrix_sdk_crypto::IncomingResponse; use napi_derive::*; @@ -186,9 +184,14 @@ impl DecryptedRoomEvent { /// note this is the state of the device at the time of /// decryption. It may change in the future if a device gets /// verified or deleted. 
- #[napi(getter)] - pub fn verification_state(&self) -> Option { - Some(self.encryption_info.as_ref()?.verification_state.borrow().into()) + #[napi] + pub fn shield_state(&self, strict: bool) -> Option { + let state = &self.encryption_info.as_ref()?.verification_state; + if strict { + Some(state.to_shield_state_strict().into()) + } else { + Some(state.to_shield_state_lax().into()) + } } } diff --git a/bindings/matrix-sdk-crypto-nodejs/tests/attachment.test.js b/bindings/matrix-sdk-crypto-nodejs/tests/attachment.test.js index 86e3eaf8b53..99dc632c1f9 100644 --- a/bindings/matrix-sdk-crypto-nodejs/tests/attachment.test.js +++ b/bindings/matrix-sdk-crypto-nodejs/tests/attachment.test.js @@ -1,37 +1,41 @@ -const { Attachment, EncryptedAttachment } = require('../'); +const { Attachment, EncryptedAttachment } = require("../"); describe(Attachment.name, () => { - const originalData = 'hello'; + const originalData = "hello"; const textEncoder = new TextEncoder(); const textDecoder = new TextDecoder(); let encryptedAttachment; - - test('can encrypt data', () => { + + test("can encrypt data", () => { encryptedAttachment = Attachment.encrypt(textEncoder.encode(originalData)); const mediaEncryptionInfo = JSON.parse(encryptedAttachment.mediaEncryptionInfo); expect(mediaEncryptionInfo).toMatchObject({ - v: 'v2', + v: "v2", key: { kty: expect.any(String), - key_ops: expect.arrayContaining(['encrypt', 'decrypt']), + key_ops: expect.arrayContaining(["encrypt", "decrypt"]), alg: expect.any(String), k: expect.any(String), ext: expect.any(Boolean), }, iv: expect.stringMatching(/^[A-Za-z0-9\+/]+$/), hashes: { - sha256: expect.stringMatching(/^[A-Za-z0-9\+/]+$/) - } + sha256: expect.stringMatching(/^[A-Za-z0-9\+/]+$/), + }, }); const encryptedData = encryptedAttachment.encryptedData; - expect(encryptedData.every((i) => { i != 0 })).toStrictEqual(false); + expect( + encryptedData.every((i) => { + i != 0; + }), + ).toStrictEqual(false); }); - test('can decrypt data', () => { + test("can 
decrypt data", () => { expect(encryptedAttachment.hasMediaEncryptionInfoBeenConsumed).toStrictEqual(false); const decryptedAttachment = Attachment.decrypt(encryptedAttachment); @@ -40,34 +44,36 @@ describe(Attachment.name, () => { expect(encryptedAttachment.hasMediaEncryptionInfoBeenConsumed).toStrictEqual(true); }); - test('can only decrypt once', () => { + test("can only decrypt once", () => { expect(encryptedAttachment.hasMediaEncryptionInfoBeenConsumed).toStrictEqual(true); - expect(() => { textDecoder.decode(decryptedAttachment) }).toThrow() + expect(() => { + textDecoder.decode(decryptedAttachment); + }).toThrow(); }); }); describe(EncryptedAttachment.name, () => { - const originalData = 'hello'; + const originalData = "hello"; const textDecoder = new TextDecoder(); - test('can be created manually', () => { + test("can be created manually", () => { const encryptedAttachment = new EncryptedAttachment( new Uint8Array([24, 150, 67, 37, 144]), JSON.stringify({ - v: 'v2', + v: "v2", key: { - kty: 'oct', - key_ops: [ 'encrypt', 'decrypt' ], - alg: 'A256CTR', - k: 'QbNXUjuukFyEJ8cQZjJuzN6mMokg0HJIjx0wVMLf5BM', - ext: true + kty: "oct", + key_ops: ["encrypt", "decrypt"], + alg: "A256CTR", + k: "QbNXUjuukFyEJ8cQZjJuzN6mMokg0HJIjx0wVMLf5BM", + ext: true, }, - iv: 'xk2AcWkomiYAAAAAAAAAAA', + iv: "xk2AcWkomiYAAAAAAAAAAA", hashes: { - sha256: 'JsRbDXgOja4xvDiF3DwBuLHdxUzIrVYIuj7W/t3aEok' - } - }) + sha256: "JsRbDXgOja4xvDiF3DwBuLHdxUzIrVYIuj7W/t3aEok", + }, + }), ); expect(encryptedAttachment.hasMediaEncryptionInfoBeenConsumed).toStrictEqual(false); diff --git a/bindings/matrix-sdk-crypto-nodejs/tests/encryption.test.js b/bindings/matrix-sdk-crypto-nodejs/tests/encryption.test.js index 83241d6b199..aac9731509f 100644 --- a/bindings/matrix-sdk-crypto-nodejs/tests/encryption.test.js +++ b/bindings/matrix-sdk-crypto-nodejs/tests/encryption.test.js @@ -1,14 +1,14 @@ -const { EncryptionAlgorithm, EncryptionSettings, HistoryVisibility, VerificationState } = require('../'); 
+const { EncryptionAlgorithm, EncryptionSettings, HistoryVisibility, VerificationState } = require("../"); -describe('EncryptionAlgorithm', () => { - test('has the correct variant values', () => { +describe("EncryptionAlgorithm", () => { + test("has the correct variant values", () => { expect(EncryptionAlgorithm.OlmV1Curve25519AesSha2).toStrictEqual(0); expect(EncryptionAlgorithm.MegolmV1AesSha2).toStrictEqual(1); }); }); describe(EncryptionSettings.name, () => { - test('can be instantiated with default values', () => { + test("can be instantiated with default values", () => { const es = new EncryptionSettings(); expect(es.algorithm).toStrictEqual(EncryptionAlgorithm.MegolmV1AesSha2); @@ -17,20 +17,14 @@ describe(EncryptionSettings.name, () => { expect(es.historyVisibility).toStrictEqual(HistoryVisibility.Shared); }); - test('checks the history visibility values', () => { + test("checks the history visibility values", () => { const es = new EncryptionSettings(); es.historyVisibility = HistoryVisibility.Invited; expect(es.historyVisibility).toStrictEqual(HistoryVisibility.Invited); - expect(() => { es.historyVisibility = 42 }).toThrow(); - }); -}); - -describe('VerificationState', () => { - test('has the correct variant values', () => { - expect(VerificationState.Trusted).toStrictEqual(0); - expect(VerificationState.Untrusted).toStrictEqual(1); - expect(VerificationState.UnknownDevice).toStrictEqual(2); + expect(() => { + es.historyVisibility = 42; + }).toThrow(); }); }); diff --git a/bindings/matrix-sdk-crypto-nodejs/tests/events.test.js b/bindings/matrix-sdk-crypto-nodejs/tests/events.test.js index 8318c924418..cec4d57e06a 100644 --- a/bindings/matrix-sdk-crypto-nodejs/tests/events.test.js +++ b/bindings/matrix-sdk-crypto-nodejs/tests/events.test.js @@ -1,7 +1,7 @@ -const { HistoryVisibility } = require('../'); +const { HistoryVisibility } = require("../"); -describe('HistoryVisibility', () => { - test('has the correct variant values', () => { 
+describe("HistoryVisibility", () => { + test("has the correct variant values", () => { expect(HistoryVisibility.Invited).toStrictEqual(0); expect(HistoryVisibility.Joined).toStrictEqual(1); expect(HistoryVisibility.Shared).toStrictEqual(2); diff --git a/bindings/matrix-sdk-crypto-nodejs/tests/identifiers.test.js b/bindings/matrix-sdk-crypto-nodejs/tests/identifiers.test.js index ef8a17943ff..f08e67d008e 100644 --- a/bindings/matrix-sdk-crypto-nodejs/tests/identifiers.test.js +++ b/bindings/matrix-sdk-crypto-nodejs/tests/identifiers.test.js @@ -1,62 +1,80 @@ -const { UserId, DeviceId, DeviceKeyId, DeviceKeyAlgorithm, DeviceKeyAlgorithmName, RoomId, ServerName } = require('../'); +const { + UserId, + DeviceId, + DeviceKeyId, + DeviceKeyAlgorithm, + DeviceKeyAlgorithmName, + RoomId, + ServerName, +} = require("../"); describe(UserId.name, () => { - test('cannot be invalid', () => { - expect(() => { new UserId('@foobar') }).toThrow(); + test("cannot be invalid", () => { + expect(() => { + new UserId("@foobar"); + }).toThrow(); }); - const user = new UserId('@foo:bar.org'); + const user = new UserId("@foo:bar.org"); - test('localpart is present', () => { - expect(user.localpart).toStrictEqual('foo'); + test("localpart is present", () => { + expect(user.localpart).toStrictEqual("foo"); }); - test('server name is present', () => { + test("server name is present", () => { expect(user.serverName).toBeInstanceOf(ServerName); }); - test('user ID is not historical', () => { + test("user ID is not historical", () => { expect(user.isHistorical()).toStrictEqual(false); }); - test('can read the user ID as a string', () => { - expect(user.toString()).toStrictEqual('@foo:bar.org'); - }) + test("can read the user ID as a string", () => { + expect(user.toString()).toStrictEqual("@foo:bar.org"); + }); }); describe(DeviceId.name, () => { - const device = new DeviceId('foo'); + const device = new DeviceId("foo"); - test('can read the device ID as a string', () => { - 
expect(device.toString()).toStrictEqual('foo'); - }) + test("can read the device ID as a string", () => { + expect(device.toString()).toStrictEqual("foo"); + }); }); describe(DeviceKeyId.name, () => { for (const deviceKey of [ - { name: 'ed25519', - id: 'ed25519:foobar', - algorithmName: DeviceKeyAlgorithmName.Ed25519, - algorithm: 'ed25519', - deviceId: 'foobar' }, - - { name: 'curve25519', - id: 'curve25519:foobar', - algorithmName: DeviceKeyAlgorithmName.Curve25519, - algorithm: 'curve25519', - deviceId: 'foobar' }, - - { name: 'signed curve25519', - id: 'signed_curve25519:foobar', - algorithmName: DeviceKeyAlgorithmName.SignedCurve25519, - algorithm: 'signed_curve25519', - deviceId: 'foobar' }, - - { name: 'unknown', - id: 'hello:foobar', - algorithmName: DeviceKeyAlgorithmName.Unknown, - algorithm: 'hello', - deviceId: 'foobar' }, + { + name: "ed25519", + id: "ed25519:foobar", + algorithmName: DeviceKeyAlgorithmName.Ed25519, + algorithm: "ed25519", + deviceId: "foobar", + }, + + { + name: "curve25519", + id: "curve25519:foobar", + algorithmName: DeviceKeyAlgorithmName.Curve25519, + algorithm: "curve25519", + deviceId: "foobar", + }, + + { + name: "signed curve25519", + id: "signed_curve25519:foobar", + algorithmName: DeviceKeyAlgorithmName.SignedCurve25519, + algorithm: "signed_curve25519", + deviceId: "foobar", + }, + + { + name: "unknown", + id: "hello:foobar", + algorithmName: DeviceKeyAlgorithmName.Unknown, + algorithm: "hello", + deviceId: "foobar", + }, ]) { test(`${deviceKey.name} algorithm`, () => { const dk = new DeviceKeyId(deviceKey.id); @@ -69,8 +87,8 @@ describe(DeviceKeyId.name, () => { } }); -describe('DeviceKeyAlgorithmName', () => { - test('has the correct variants', () => { +describe("DeviceKeyAlgorithmName", () => { + test("has the correct variants", () => { expect(DeviceKeyAlgorithmName.Ed25519).toStrictEqual(0); expect(DeviceKeyAlgorithmName.Curve25519).toStrictEqual(1); expect(DeviceKeyAlgorithmName.SignedCurve25519).toStrictEqual(2); @@ 
-79,40 +97,44 @@ describe('DeviceKeyAlgorithmName', () => { }); describe(RoomId.name, () => { - test('cannot be invalid', () => { - expect(() => { new RoomId('!foo') }).toThrow(); + test("cannot be invalid", () => { + expect(() => { + new RoomId("!foo"); + }).toThrow(); }); - const room = new RoomId('!foo:bar.org'); + const room = new RoomId("!foo:bar.org"); - test('localpart is present', () => { - expect(room.localpart).toStrictEqual('foo'); + test("localpart is present", () => { + expect(room.localpart).toStrictEqual("foo"); }); - test('server name is present', () => { + test("server name is present", () => { expect(room.serverName).toBeInstanceOf(ServerName); }); - test('can read the room ID as string', () => { - expect(room.toString()).toStrictEqual('!foo:bar.org'); + test("can read the room ID as string", () => { + expect(room.toString()).toStrictEqual("!foo:bar.org"); }); }); describe(ServerName.name, () => { - test('cannot be invalid', () => { - expect(() => { new ServerName('@foobar') }).toThrow() + test("cannot be invalid", () => { + expect(() => { + new ServerName("@foobar"); + }).toThrow(); }); - test('host is present', () => { - expect(new ServerName('foo.org').host).toStrictEqual('foo.org'); + test("host is present", () => { + expect(new ServerName("foo.org").host).toStrictEqual("foo.org"); }); - test('port can be optional', () => { - expect(new ServerName('foo.org').port).toStrictEqual(null); - expect(new ServerName('foo.org:1234').port).toStrictEqual(1234); + test("port can be optional", () => { + expect(new ServerName("foo.org").port).toStrictEqual(null); + expect(new ServerName("foo.org:1234").port).toStrictEqual(1234); }); - test('server is not an IP literal', () => { - expect(new ServerName('foo.org').isIpLiteral()).toStrictEqual(false); + test("server is not an IP literal", () => { + expect(new ServerName("foo.org").isIpLiteral()).toStrictEqual(false); }); }); diff --git a/bindings/matrix-sdk-crypto-nodejs/tests/machine.test.js 
b/bindings/matrix-sdk-crypto-nodejs/tests/machine.test.js index a3b424683c5..90ed77f194d 100644 --- a/bindings/matrix-sdk-crypto-nodejs/tests/machine.test.js +++ b/bindings/matrix-sdk-crypto-nodejs/tests/machine.test.js @@ -1,74 +1,147 @@ -const { OlmMachine, UserId, DeviceId, DeviceKeyId, RoomId, DeviceLists, RequestType, KeysUploadRequest, KeysQueryRequest, KeysClaimRequest, EncryptionSettings, DecryptedRoomEvent, VerificationState, CrossSigningStatus, MaybeSignature } = require('../'); -const path = require('path'); -const os = require('os'); -const fs = require('fs/promises'); +const { + OlmMachine, + UserId, + DeviceId, + DeviceKeyId, + RoomId, + DeviceLists, + RequestType, + KeysUploadRequest, + KeysQueryRequest, + KeysClaimRequest, + EncryptionSettings, + DecryptedRoomEvent, + VerificationState, + CrossSigningStatus, + MaybeSignature, + StoreType, + ShieldColor, +} = require("../"); +const path = require("path"); +const os = require("os"); +const fs = require("fs/promises"); + +describe("StoreType", () => { + test("has the correct variant values", () => { + expect(StoreType.Sled).toStrictEqual(0); + expect(StoreType.Sqlite).toStrictEqual(1); + }); +}); describe(OlmMachine.name, () => { - test('cannot be instantiated with the constructor', () => { - expect(() => { new OlmMachine() }).toThrow(); + test("cannot be instantiated with the constructor", () => { + expect(() => { + new OlmMachine(); + }).toThrow(); }); - test('can be instantiated with the async initializer', async () => { - expect(await OlmMachine.initialize(new UserId('@foo:bar.org'), new DeviceId('baz'))).toBeInstanceOf(OlmMachine); + test("can be instantiated with the async initializer", async () => { + expect(await OlmMachine.initialize(new UserId("@foo:bar.org"), new DeviceId("baz"))).toBeInstanceOf(OlmMachine); }); - describe('can be instantiated with a store', () => { - test('with no passphrase', async () => { - const temp_directory = await fs.mkdtemp(path.join(os.tmpdir(), 
'matrix-sdk-crypto--')); - - expect(await OlmMachine.initialize(new UserId('@foo:bar.org'), new DeviceId('baz'), temp_directory)).toBeInstanceOf(OlmMachine); - }); - - test('with a passphrase', async () => { - const temp_directory = await fs.mkdtemp(path.join(os.tmpdir(), 'matrix-sdk-crypto--')); + describe("can be instantiated with a store", () => { + for (const [store_type, store_name] of [ + [StoreType.Sled, "sled"], + [StoreType.Sqlite, "sqlite"], + [null, "default"], + ]) { + test(`with no passphrase (store: ${store_name})`, async () => { + const temp_directory = await fs.mkdtemp(path.join(os.tmpdir(), "matrix-sdk-crypto--")); + + expect( + await OlmMachine.initialize( + new UserId("@foo:bar.org"), + new DeviceId("baz"), + temp_directory, + null, + store_type, + ), + ).toBeInstanceOf(OlmMachine); + }); - expect(await OlmMachine.initialize(new UserId('@foo:bar.org'), new DeviceId('baz'), temp_directory, 'hello')).toBeInstanceOf(OlmMachine); - }); + test(`with a passphrase (store: ${store_name})`, async () => { + const temp_directory = await fs.mkdtemp(path.join(os.tmpdir(), "matrix-sdk-crypto--")); + + expect( + await OlmMachine.initialize( + new UserId("@foo:bar.org"), + new DeviceId("baz"), + temp_directory, + "hello", + store_type, + ), + ).toBeInstanceOf(OlmMachine); + }); + } }); - - const user = new UserId('@alice:example.org'); - const device = new DeviceId('foobar'); - const room = new RoomId('!baz:matrix.org'); + + const user = new UserId("@alice:example.org"); + const device = new DeviceId("foobar"); + const room = new RoomId("!baz:matrix.org"); function machine(new_user, new_device) { return OlmMachine.initialize(new_user || user, new_device || device); } - test('can read user ID', async () => { + test("can drop/close, and then re-open", async () => { + const temp_directory = await fs.mkdtemp(path.join(os.tmpdir(), "matrix-sdk-crypto--")); + + let m1 = await OlmMachine.initialize( + new UserId("@test:bar.org"), + new DeviceId("device"), + 
temp_directory, + "hello", + ); + m1.close(); + + let m2 = await OlmMachine.initialize( + new UserId("@test:bar.org"), + new DeviceId("device"), + temp_directory, + "hello", + ); + m2.close(); + }); + + test("can read user ID", async () => { expect((await machine()).userId.toString()).toStrictEqual(user.toString()); }); - test('can read device ID', async () => { + test("can read device ID", async () => { expect((await machine()).deviceId.toString()).toStrictEqual(device.toString()); }); - test('can read identity keys', async () => { + test("can read identity keys", async () => { const identityKeys = (await machine()).identityKeys; expect(identityKeys.ed25519.toBase64()).toMatch(/^[A-Za-z0-9+/]+$/); expect(identityKeys.curve25519.toBase64()).toMatch(/^[A-Za-z0-9+/]+$/); }); - test('can receive sync changes', async () => { + test("can receive sync changes", async () => { const m = await machine(); const toDeviceEvents = JSON.stringify([]); const changedDevices = new DeviceLists(); const oneTimeKeyCounts = {}; const unusedFallbackKeys = []; - const receiveSyncChanges = JSON.parse(await m.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys)); + const receiveSyncChanges = JSON.parse( + await m.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys), + ); expect(receiveSyncChanges).toEqual([]); }); - test('can get the outgoing requests that need to be send out', async () => { + test("can get the outgoing requests that need to be send out", async () => { const m = await machine(); const toDeviceEvents = JSON.stringify([]); const changedDevices = new DeviceLists(); const oneTimeKeyCounts = {}; const unusedFallbackKeys = []; - const receiveSyncChanges = JSON.parse(await m.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys)); + const receiveSyncChanges = JSON.parse( + await m.receiveSyncChanges(toDeviceEvents, changedDevices, oneTimeKeyCounts, unusedFallbackKeys), + ); 
expect(receiveSyncChanges).toEqual([]); @@ -98,12 +171,12 @@ describe(OlmMachine.name, () => { } }); - describe('setup workflow to mark requests as sent', () => { + describe("setup workflow to mark requests as sent", () => { let m; let ougoingRequests; beforeAll(async () => { - m = await machine(new UserId('@alice:example.org'), new DeviceId('DEVICEID')); + m = await machine(new UserId("@alice:example.org"), new DeviceId("DEVICEID")); const toDeviceEvents = JSON.stringify([]); const changedDevices = new DeviceLists(); @@ -116,17 +189,17 @@ describe(OlmMachine.name, () => { expect(outgoingRequests).toHaveLength(2); }); - test('can mark requests as sent', async () => { + test("can mark requests as sent", async () => { { const request = outgoingRequests[0]; expect(request).toBeInstanceOf(KeysUploadRequest); // https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3keysupload const hypothetical_response = JSON.stringify({ - "one_time_key_counts": { - "curve25519": 10, - "signed_curve25519": 20 - } + one_time_key_counts: { + curve25519: 10, + signed_curve25519: 20, + }, }); const marked = await m.markRequestAsSent(request.id, request.type, hypothetical_response); expect(marked).toStrictEqual(true); @@ -138,31 +211,29 @@ describe(OlmMachine.name, () => { // https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3keysquery const hypothetical_response = JSON.stringify({ - "device_keys": { + device_keys: { "@alice:example.org": { - "JLAFKJWSCS": { - "algorithms": [ - "m.olm.v1.curve25519-aes-sha2", - "m.megolm.v1.aes-sha2" - ], - "device_id": "JLAFKJWSCS", - "keys": { + JLAFKJWSCS: { + algorithms: ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + device_id: "JLAFKJWSCS", + keys: { "curve25519:JLAFKJWSCS": "wjLpTLRqbqBzLs63aYaEv2Boi6cFEbbM/sSRQ2oAKk4", - "ed25519:JLAFKJWSCS": "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM" + "ed25519:JLAFKJWSCS": "nE6W2fCblxDcOFmeEtCHNl8/l8bXcu7GKyAswA4r3mM", }, - "signatures": { + signatures: { 
"@alice:example.org": { - "ed25519:JLAFKJWSCS": "m53Wkbh2HXkc3vFApZvCrfXcX3AI51GsDHustMhKwlv3TuOJMj4wistcOTM8q2+e/Ro7rWFUb9ZfnNbwptSUBA" - } + "ed25519:JLAFKJWSCS": + "m53Wkbh2HXkc3vFApZvCrfXcX3AI51GsDHustMhKwlv3TuOJMj4wistcOTM8q2+e/Ro7rWFUb9ZfnNbwptSUBA", + }, }, - "unsigned": { - "device_display_name": "Alice's mobile phone" + unsigned: { + device_display_name: "Alice's mobile phone", }, - "user_id": "@alice:example.org" - } - } + user_id: "@alice:example.org", + }, + }, }, - "failures": {} + failures: {}, }); const marked = await m.markRequestAsSent(request.id, request.type, hypothetical_response); expect(marked).toStrictEqual(true); @@ -170,122 +241,121 @@ describe(OlmMachine.name, () => { }); }); - describe('setup workflow to encrypt/decrypt events', () => { + describe("setup workflow to encrypt/decrypt events", () => { let m; - const user = new UserId('@alice:example.org'); - const device = new DeviceId('JLAFKJWSCS'); - const room = new RoomId('!test:localhost'); + const user = new UserId("@alice:example.org"); + const device = new DeviceId("JLAFKJWSCS"); + const room = new RoomId("!test:localhost"); beforeAll(async () => { m = await machine(user, device); }); - - test('can pass keysquery and keysclaim requests directly', async () => { + + test("can pass keysquery and keysclaim requests directly", async () => { { // derived from https://github.com/matrix-org/matrix-rust-sdk/blob/7f49618d350fab66b7e1dc4eaf64ec25ceafd658/benchmarks/benches/crypto_bench/keys_query.json const hypothetical_response = JSON.stringify({ - "device_keys": { + device_keys: { "@example:localhost": { - "AFGUOBTZWM": { - "algorithms": [ - "m.olm.v1.curve25519-aes-sha2", - "m.megolm.v1.aes-sha2" - ], - "device_id": "AFGUOBTZWM", - "keys": { + AFGUOBTZWM: { + algorithms: ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"], + device_id: "AFGUOBTZWM", + keys: { "curve25519:AFGUOBTZWM": "boYjDpaC+7NkECQEeMh5dC+I1+AfriX0VXG2UV7EUQo", - "ed25519:AFGUOBTZWM": 
"NayrMQ33ObqMRqz6R9GosmHdT6HQ6b/RX/3QlZ2yiec" + "ed25519:AFGUOBTZWM": "NayrMQ33ObqMRqz6R9GosmHdT6HQ6b/RX/3QlZ2yiec", }, - "signatures": { + signatures: { "@example:localhost": { - "ed25519:AFGUOBTZWM": "RoSWvru1jj6fs2arnTedWsyIyBmKHMdOu7r9gDi0BZ61h9SbCK2zLXzuJ9ZFLao2VvA0yEd7CASCmDHDLYpXCA" - } + "ed25519:AFGUOBTZWM": + "RoSWvru1jj6fs2arnTedWsyIyBmKHMdOu7r9gDi0BZ61h9SbCK2zLXzuJ9ZFLao2VvA0yEd7CASCmDHDLYpXCA", + }, + }, + user_id: "@example:localhost", + unsigned: { + device_display_name: "rust-sdk", }, - "user_id": "@example:localhost", - "unsigned": { - "device_display_name": "rust-sdk" - } }, - } + }, }, - "failures": {}, - "master_keys": { + failures: {}, + master_keys: { "@example:localhost": { - "user_id": "@example:localhost", - "usage": [ - "master" - ], - "keys": { - "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": "n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU" + user_id: "@example:localhost", + usage: ["master"], + keys: { + "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": + "n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU", }, - "signatures": { + signatures: { "@example:localhost": { - "ed25519:TCSJXPWGVS": "+j9G3L41I1fe0++wwusTTQvbboYW0yDtRWUEujhwZz4MAltjLSfJvY0hxhnz+wHHmuEXvQDen39XOpr1p29sAg" - } - } - } + "ed25519:TCSJXPWGVS": + "+j9G3L41I1fe0++wwusTTQvbboYW0yDtRWUEujhwZz4MAltjLSfJvY0hxhnz+wHHmuEXvQDen39XOpr1p29sAg", + }, + }, + }, }, - "self_signing_keys": { + self_signing_keys: { "@example:localhost": { - "user_id": "@example:localhost", - "usage": [ - "self_signing" - ], - "keys": { - "ed25519:kQXOuy639Yt47mvNTdrIluoC6DMvfbZLYbxAmwiDyhI": "kQXOuy639Yt47mvNTdrIluoC6DMvfbZLYbxAmwiDyhI" + user_id: "@example:localhost", + usage: ["self_signing"], + keys: { + "ed25519:kQXOuy639Yt47mvNTdrIluoC6DMvfbZLYbxAmwiDyhI": + "kQXOuy639Yt47mvNTdrIluoC6DMvfbZLYbxAmwiDyhI", }, - "signatures": { + signatures: { "@example:localhost": { - "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": 
"q32ifix/qyRpvmegw2BEJklwoBCAJldDNkcX+fp+lBA4Rpyqtycxge6BA4hcJdxYsy3oV0IHRuugS8rJMMFyAA" - } - } - } + "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": + "q32ifix/qyRpvmegw2BEJklwoBCAJldDNkcX+fp+lBA4Rpyqtycxge6BA4hcJdxYsy3oV0IHRuugS8rJMMFyAA", + }, + }, + }, }, - "user_signing_keys": { + user_signing_keys: { "@example:localhost": { - "user_id": "@example:localhost", - "usage": [ - "user_signing" - ], - "keys": { - "ed25519:g4ED07Fnqf3GzVWNN1pZ0IFrPQVdqQf+PYoJNH4eE0s": "g4ED07Fnqf3GzVWNN1pZ0IFrPQVdqQf+PYoJNH4eE0s" + user_id: "@example:localhost", + usage: ["user_signing"], + keys: { + "ed25519:g4ED07Fnqf3GzVWNN1pZ0IFrPQVdqQf+PYoJNH4eE0s": + "g4ED07Fnqf3GzVWNN1pZ0IFrPQVdqQf+PYoJNH4eE0s", }, - "signatures": { + signatures: { "@example:localhost": { - "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": "nKQu8alQKDefNbZz9luYPcNj+Z+ouQSot4fU/A23ELl1xrI06QVBku/SmDx0sIW1ytso0Cqwy1a+3PzCa1XABg" - } - } - } - } + "ed25519:n2lpJGx0LiKnuNE1IucZP3QExrD4SeRP0veBHPe3XUU": + "nKQu8alQKDefNbZz9luYPcNj+Z+ouQSot4fU/A23ELl1xrI06QVBku/SmDx0sIW1ytso0Cqwy1a+3PzCa1XABg", + }, + }, + }, + }, }); - const marked = await m.markRequestAsSent('foo', RequestType.KeysQuery, hypothetical_response); + const marked = await m.markRequestAsSent("foo", RequestType.KeysQuery, hypothetical_response); } { // derived from https://github.com/matrix-org/matrix-rust-sdk/blob/7f49618d350fab66b7e1dc4eaf64ec25ceafd658/benchmarks/benches/crypto_bench/keys_claim.json const hypothetical_response = JSON.stringify({ - "one_time_keys": { + one_time_keys: { "@example:localhost": { - "AFGUOBTZWM": { + AFGUOBTZWM: { "signed_curve25519:AAAABQ": { - "key": "9IGouMnkB6c6HOd4xUsNv4i3Dulb4IS96TzDordzOws", - "signatures": { + key: "9IGouMnkB6c6HOd4xUsNv4i3Dulb4IS96TzDordzOws", + signatures: { "@example:localhost": { - "ed25519:AFGUOBTZWM": "2bvUbbmJegrV0eVP/vcJKuIWC3kud+V8+C0dZtg4dVovOSJdTP/iF36tQn2bh5+rb9xLlSeztXBdhy4c+LiOAg" - } - } - } + "ed25519:AFGUOBTZWM": + 
"2bvUbbmJegrV0eVP/vcJKuIWC3kud+V8+C0dZtg4dVovOSJdTP/iF36tQn2bh5+rb9xLlSeztXBdhy4c+LiOAg", + }, + }, + }, }, - } + }, }, - "failures": {} + failures: {}, }); - const marked = await m.markRequestAsSent('bar', RequestType.KeysClaim, hypothetical_response); + const marked = await m.markRequestAsSent("bar", RequestType.KeysClaim, hypothetical_response); } }); - test('can share a room key', async () => { - const other_users = [new UserId('@example:localhost')]; + test("can share a room key", async () => { + const other_users = [new UserId("@example:localhost")]; const requests = JSON.parse(await m.shareRoomKey(room, other_users, new EncryptionSettings())); @@ -293,19 +363,21 @@ describe(OlmMachine.name, () => { expect(requests[0].event_type).toBeDefined(); expect(requests[0].txn_id).toBeDefined(); expect(requests[0].messages).toBeDefined(); - expect(requests[0].messages['@example:localhost']).toBeDefined(); + expect(requests[0].messages["@example:localhost"]).toBeDefined(); }); let encrypted; - test('can encrypt an event', async () => { - encrypted = JSON.parse(await m.encryptRoomEvent( - room, - 'm.room.message', - JSON.stringify({ - "hello": "world" - }), - )); + test("can encrypt an event", async () => { + encrypted = JSON.parse( + await m.encryptRoomEvent( + room, + "m.room.message", + JSON.stringify({ + hello: "world", + }), + ), + ); expect(encrypted.algorithm).toBeDefined(); expect(encrypted.ciphertext).toBeDefined(); @@ -314,17 +386,17 @@ describe(OlmMachine.name, () => { expect(encrypted.session_id).toBeDefined(); }); - test('can decrypt an event', async () => { + test("can decrypt an event", async () => { const decrypted = await m.decryptRoomEvent( JSON.stringify({ - "type": "m.room.encrypted", - "event_id": "$xxxxx:example.org", - "origin_server_ts": Date.now(), - "sender": user.toString(), + type: "m.room.encrypted", + event_id: "$xxxxx:example.org", + origin_server_ts: Date.now(), + sender: user.toString(), content: encrypted, unsigned: { - "age": 1234 - } + 
age: 1234, + }, }), room, ); @@ -339,17 +411,18 @@ describe(OlmMachine.name, () => { expect(decrypted.senderCurve25519Key).toBeDefined(); expect(decrypted.senderClaimedEd25519Key).toBeDefined(); expect(decrypted.forwardingCurve25519KeyChain).toHaveLength(0); - expect(decrypted.verificationState).toStrictEqual(VerificationState.Trusted); + expect(decrypted.shieldState(true).color).toStrictEqual(ShieldColor.Red); + expect(decrypted.shieldState(false).color).toStrictEqual(ShieldColor.Red); }); }); - test('can update tracked users', async () => { + test("can update tracked users", async () => { const m = await machine(); expect(await m.updateTrackedUsers([user])).toStrictEqual(undefined); }); - test('can read cross-signing status', async () => { + test("can read cross-signing status", async () => { const m = await machine(); const crossSigningStatus = await m.crossSigningStatus(); @@ -359,9 +432,9 @@ describe(OlmMachine.name, () => { expect(crossSigningStatus.hasUserSigning).toStrictEqual(false); }); - test('can sign a message', async () => { + test("can sign a message", async () => { const m = await machine(); - const signatures = await m.sign('foo'); + const signatures = await m.sign("foo"); expect(signatures.isEmpty).toStrictEqual(false); expect(signatures.count).toStrictEqual(1n); @@ -375,26 +448,26 @@ describe(OlmMachine.name, () => { expect(signature).toMatchObject({ "ed25519:foobar": expect.any(MaybeSignature), }); - expect(signature['ed25519:foobar'].isValid).toStrictEqual(true); - expect(signature['ed25519:foobar'].isInvalid).toStrictEqual(false); - expect(signature['ed25519:foobar'].invalidSignatureSource).toBeNull(); + expect(signature["ed25519:foobar"].isValid).toStrictEqual(true); + expect(signature["ed25519:foobar"].isInvalid).toStrictEqual(false); + expect(signature["ed25519:foobar"].invalidSignatureSource).toBeNull(); - base64 = signature['ed25519:foobar'].signature.toBase64(); + base64 = signature["ed25519:foobar"].signature.toBase64(); 
expect(base64).toMatch(/^[A-Za-z0-9\+/]+$/); - expect(signature['ed25519:foobar'].signature.ed25519.toBase64()).toStrictEqual(base64); + expect(signature["ed25519:foobar"].signature.ed25519.toBase64()).toStrictEqual(base64); } // `getSignature` { - const signature = signatures.getSignature(user, new DeviceKeyId('ed25519:foobar')); + const signature = signatures.getSignature(user, new DeviceKeyId("ed25519:foobar")); expect(signature.toBase64()).toStrictEqual(base64); } // Unknown signatures. { - expect(signatures.get(new UserId('@hello:example.org'))).toBeNull(); - expect(signatures.getSignature(user, new DeviceKeyId('world:foobar'))).toBeNull(); + expect(signatures.get(new UserId("@hello:example.org"))).toBeNull(); + expect(signatures.getSignature(user, new DeviceKeyId("world:foobar"))).toBeNull(); } }); }); diff --git a/bindings/matrix-sdk-crypto-nodejs/tests/requests.test.js b/bindings/matrix-sdk-crypto-nodejs/tests/requests.test.js index 96cf946b30b..51aeb394441 100644 --- a/bindings/matrix-sdk-crypto-nodejs/tests/requests.test.js +++ b/bindings/matrix-sdk-crypto-nodejs/tests/requests.test.js @@ -1,7 +1,16 @@ -const { RequestType, KeysUploadRequest, KeysQueryRequest, KeysClaimRequest, ToDeviceRequest, SignatureUploadRequest, RoomMessageRequest, KeysBackupRequest } = require('../'); +const { + RequestType, + KeysUploadRequest, + KeysQueryRequest, + KeysClaimRequest, + ToDeviceRequest, + SignatureUploadRequest, + RoomMessageRequest, + KeysBackupRequest, +} = require("../"); -describe('RequestType', () => { - test('has the correct variant values', () => { +describe("RequestType", () => { + test("has the correct variant values", () => { expect(RequestType.KeysUpload).toStrictEqual(0); expect(RequestType.KeysQuery).toStrictEqual(1); expect(RequestType.KeysClaim).toStrictEqual(2); @@ -22,8 +31,10 @@ for (const request of [ KeysBackupRequest, ]) { describe(request.name, () => { - test('cannot be instantiated', () => { - expect(() => { new (request)() }).toThrow(); + 
test("cannot be instantiated", () => { + expect(() => { + new request(); + }).toThrow(); }); - }) + }); } diff --git a/bindings/matrix-sdk-crypto-nodejs/tests/responses.test.js b/bindings/matrix-sdk-crypto-nodejs/tests/responses.test.js index ddd68c828b9..924f00d7b50 100644 --- a/bindings/matrix-sdk-crypto-nodejs/tests/responses.test.js +++ b/bindings/matrix-sdk-crypto-nodejs/tests/responses.test.js @@ -1,7 +1,9 @@ -const { DecryptedRoomEvent } = require('../'); +const { DecryptedRoomEvent } = require("../"); describe(DecryptedRoomEvent.name, () => { - test('cannot be instantiated', () => { - expect(() => { new DecryptedRoomEvent() }).toThrow(); + test("cannot be instantiated", () => { + expect(() => { + new DecryptedRoomEvent(); + }).toThrow(); }); }); diff --git a/bindings/matrix-sdk-crypto-nodejs/tests/sync_events.test.js b/bindings/matrix-sdk-crypto-nodejs/tests/sync_events.test.js index b6db25e702d..c52e180ac35 100644 --- a/bindings/matrix-sdk-crypto-nodejs/tests/sync_events.test.js +++ b/bindings/matrix-sdk-crypto-nodejs/tests/sync_events.test.js @@ -1,7 +1,7 @@ -const { DeviceLists, UserId } = require('../'); +const { DeviceLists, UserId } = require("../"); describe(DeviceLists.name, () => { - test('can be empty', () => { + test("can be empty", () => { const empty = new DeviceLists(); expect(empty.isEmpty()).toStrictEqual(true); @@ -9,7 +9,7 @@ describe(DeviceLists.name, () => { expect(empty.left).toHaveLength(0); }); - test('can be coerced empty', () => { + test("can be coerced empty", () => { const empty = new DeviceLists([], []); expect(empty.isEmpty()).toStrictEqual(true); @@ -17,15 +17,15 @@ describe(DeviceLists.name, () => { expect(empty.left).toHaveLength(0); }); - test('returns the correct `changed` and `left`', () => { - const list = new DeviceLists([new UserId('@foo:bar.org')], [new UserId('@baz:qux.org')]); + test("returns the correct `changed` and `left`", () => { + const list = new DeviceLists([new UserId("@foo:bar.org")], [new 
UserId("@baz:qux.org")]); expect(list.isEmpty()).toStrictEqual(false); expect(list.changed).toHaveLength(1); - expect(list.changed[0].toString()).toStrictEqual('@foo:bar.org'); + expect(list.changed[0].toString()).toStrictEqual("@foo:bar.org"); expect(list.left).toHaveLength(1); - expect(list.left[0].toString()).toStrictEqual('@baz:qux.org'); + expect(list.left[0].toString()).toStrictEqual("@baz:qux.org"); }); }); diff --git a/bindings/matrix-sdk-crypto-nodejs/tsconfig.json b/bindings/matrix-sdk-crypto-nodejs/tsconfig.json index 584b1ddac01..fbbcf41053d 100644 --- a/bindings/matrix-sdk-crypto-nodejs/tsconfig.json +++ b/bindings/matrix-sdk-crypto-nodejs/tsconfig.json @@ -5,6 +5,6 @@ "typedocOptions": { "entryPoints": ["index.d.ts"], "out": "docs", - "readme": "README.md", + "readme": "README.md" } } diff --git a/bindings/matrix-sdk-ffi/Cargo.toml b/bindings/matrix-sdk-ffi/Cargo.toml index 06e5e2fefe9..8dec0adf443 100644 --- a/bindings/matrix-sdk-ffi/Cargo.toml +++ b/bindings/matrix-sdk-ffi/Cargo.toml @@ -18,9 +18,10 @@ uniffi = { workspace = true, features = ["build"] } [dependencies] anyhow = { workspace = true } base64 = "0.21" +eyeball = { workspace = true } +eyeball-im = { workspace = true } extension-trait = "1.0.1" futures-core = "0.3.17" -futures-signals = { version = "0.3.30", default-features = false } futures-util = { version = "0.3.17", default-features = false } mime = "0.3.16" # FIXME: we currently can't feature flag anything in the api.udl, therefore we must enforce experimental-sliding-sync being exposed here.. 
@@ -28,23 +29,51 @@ mime = "0.3.16" once_cell = { workspace = true } opentelemetry = { version = "0.18.0", features = ["rt-tokio"] } opentelemetry-otlp = { version = "0.11.0", features = ["tokio", "reqwest-client", "http-proto"] } +ruma = { workspace = true, features = ["unstable-sanitize", "unstable-unspecified"] } sanitize-filename-reader-friendly = "2.2.1" serde_json = { workspace = true } thiserror = { workspace = true } +tracing = { workspace = true } tracing-opentelemetry = { version = "0.18.0" } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } tokio = { version = "1", features = ["rt-multi-thread", "macros"] } tokio-stream = "0.1.8" uniffi = { workspace = true } +url = "2.2.2" zeroize = { workspace = true } - [target.'cfg(target_os = "android")'.dependencies] -tracing = { version = "0.1.29", default-features = false, features = ["log"] } -android_logger = "0.11" log-panics = { version = "2", features = ["with-backtrace"]} -matrix-sdk = { path = "../../crates/matrix-sdk", default-features = false, features = ["anyhow", "experimental-timeline", "e2e-encryption", "sled", "markdown", "experimental-sliding-sync", "socks", "rustls-tls"], version = "0.6.0" } +tracing-android = "0.2.0" [target.'cfg(not(target_os = "android"))'.dependencies] tracing = { workspace = true } tracing-subscriber = { version = "0.3", features = ["env-filter"] } -matrix-sdk = { path = "../../crates/matrix-sdk", features = ["anyhow", "experimental-timeline", "markdown", "experimental-sliding-sync", "socks"], version = "0.6.0" } + +[target.'cfg(target_os = "android")'.dependencies.matrix-sdk] +path = "../../crates/matrix-sdk" +default-features = false +features = [ + "anyhow", + "experimental-sliding-sync", + "experimental-timeline", + "e2e-encryption", + "markdown", + "sled", + "socks", + "rustls-tls", +] + +[target.'cfg(not(target_os = "android"))'.dependencies.matrix-sdk] +path = "../../crates/matrix-sdk" +default-features = false +features = [ + "anyhow", + 
"experimental-sliding-sync", + "experimental-timeline", + "e2e-encryption", + "markdown", + "native-tls", + "sled", + "socks", +] diff --git a/bindings/matrix-sdk-ffi/src/api.udl b/bindings/matrix-sdk-ffi/src/api.udl index f241b0055d4..a0526d65e8d 100644 --- a/bindings/matrix-sdk-ffi/src/api.udl +++ b/bindings/matrix-sdk-ffi/src/api.udl @@ -1,8 +1,7 @@ namespace matrix_sdk_ffi {}; -/// Cancels on drop -interface StoppableSpawn {}; +interface TaskHandle {}; [Error] interface ClientError { @@ -10,9 +9,7 @@ interface ClientError { }; callback interface ClientDelegate { - void did_receive_sync_update(); void did_receive_auth_error(boolean is_soft_logout); - void did_update_restore_token(); }; dictionary RequiredState { @@ -26,7 +23,7 @@ dictionary RoomSubscription { }; dictionary UpdateSummary { - sequence views; + sequence lists; sequence rooms; }; @@ -54,7 +51,7 @@ enum SlidingSyncMode { "Selective", }; -callback interface SlidingSyncViewStateObserver { +callback interface SlidingSyncListStateObserver { void did_receive_update(SlidingSyncState new_state); }; @@ -66,61 +63,61 @@ interface RoomListEntry { }; [Enum] -interface SlidingSyncViewRoomsListDiff { - Replace(sequence values); - InsertAt( - u32 index, - RoomListEntry value - ); - UpdateAt( - u32 index, - RoomListEntry value - ); - RemoveAt(u32 index); - Move( - u32 old_index, - u32 new_index - ); - Push(RoomListEntry value); - Pop(); +interface SlidingSyncListRoomsListDiff { + Append(sequence values); + Insert(u32 index, RoomListEntry value); + Set(u32 index, RoomListEntry value); + Remove(u32 index); + PushBack(RoomListEntry value); + PushFront(RoomListEntry value); + PopBack(); + PopFront(); Clear(); + Reset(sequence values); }; -callback interface SlidingSyncViewRoomListObserver { - void did_receive_update(SlidingSyncViewRoomsListDiff diff); +callback interface SlidingSyncListRoomListObserver { + void did_receive_update(SlidingSyncListRoomsListDiff diff); }; -callback interface 
SlidingSyncViewRoomsCountObserver { +callback interface SlidingSyncListRoomsCountObserver { void did_receive_update(u32 count); }; -callback interface SlidingSyncViewRoomItemsObserver { +callback interface SlidingSyncListRoomItemsObserver { void did_receive_update(); }; -interface SlidingSyncViewBuilder { +interface SlidingSyncListBuilder { constructor(); [Self=ByArc] - SlidingSyncViewBuilder sync_mode(SlidingSyncMode mode); + SlidingSyncListBuilder sync_mode(SlidingSyncMode mode); [Self=ByArc] - SlidingSyncViewBuilder send_updates_for_items(boolean enable); + SlidingSyncListBuilder send_updates_for_items(boolean enable); [Throws=ClientError, Self=ByArc] - SlidingSyncView build(); + SlidingSyncList build(); }; -interface SlidingSyncView { - StoppableSpawn observe_room_list(SlidingSyncViewRoomListObserver observer); - StoppableSpawn observe_rooms_count(SlidingSyncViewRoomsCountObserver observer); - StoppableSpawn observe_state(SlidingSyncViewStateObserver observer); - StoppableSpawn observe_room_items(SlidingSyncViewRoomItemsObserver observer); +interface SlidingSyncList { + TaskHandle observe_room_list(SlidingSyncListRoomListObserver observer); + TaskHandle observe_rooms_count(SlidingSyncListRoomsCountObserver observer); + TaskHandle observe_state(SlidingSyncListStateObserver observer); + TaskHandle observe_room_items(SlidingSyncListRoomItemsObserver observer); }; interface SlidingSyncRoom { - StoppableSpawn? subscribe_and_add_timeline_listener(TimelineListener listener, RoomSubscription? settings); - StoppableSpawn? add_timeline_listener(TimelineListener listener); + [Throws=ClientError] + SlidingSyncSubscribeResult subscribe_and_add_timeline_listener(TimelineListener listener, RoomSubscription? 
settings); + [Throws=ClientError] + SlidingSyncSubscribeResult add_timeline_listener(TimelineListener listener); +}; + +dictionary SlidingSyncSubscribeResult { + sequence items; + TaskHandle task_handle; }; interface SlidingSync { @@ -152,65 +149,51 @@ interface SlidingSyncBuilder { SlidingSync build(); }; -interface Client { - void set_delegate(ClientDelegate? delegate); - - [Throws=ClientError] - void login(string username, string password, string? initial_device_name, string? device_id); - - [Throws=ClientError] - void restore_session(Session session); - - [Throws=ClientError] - Session session(); - - [Throws=ClientError] - string user_id(); - - [Throws=ClientError] - string display_name(); - - [Throws=ClientError] - void set_display_name(string name); - - [Throws=ClientError] - string avatar_url(); - - [Throws=ClientError] - string device_id(); +dictionary CreateRoomParameters { + string name; + string? topic = null; + boolean is_encrypted; + boolean is_direct = false; + RoomVisibility visibility; + RoomPreset preset; + sequence? invite = null; + string? avatar = null; +}; - [Throws=ClientError] - string? account_data(string event_type); +enum RoomVisibility { + /// Indicates that the room will be shown in the published room list. + "Public", - [Throws=ClientError] - void set_account_data(string event_type, string content); + /// Indicates that the room will not be shown in the published room list. + "Private", +}; - [Throws=ClientError] - string upload_media(string mime_type, sequence content); +enum RoomPreset { + /// `join_rules` is set to `invite` and `history_visibility` is set to + /// `shared`. + "PrivateChat", - [Throws=ClientError] - sequence get_media_content(MediaSource source); + /// `join_rules` is set to `public` and `history_visibility` is set to + /// `shared`. 
+ "PublicChat", - [Throws=ClientError] - sequence get_media_thumbnail(MediaSource source, u64 width, u64 height); + /// Same as `PrivateChat`, but all initial invitees get the same power level + /// as the creator. + "TrustedPrivateChat", +}; - [Throws=ClientError] - SessionVerificationController get_session_verification_controller(); +interface Client { + void set_delegate(ClientDelegate? delegate); [Throws=ClientError] - SlidingSync full_sliding_sync(); + void login(string username, string password, string? initial_device_name, string? device_id); [Throws=ClientError] - void logout(); + MediaFileHandle get_media_file(MediaSource source, string mime_type); }; -dictionary Session { - string access_token; - string? refresh_token; - string user_id; - string device_id; - string homeserver_url; - boolean is_soft_logout; +interface MediaFileHandle { + string path(); }; enum MembershipState { @@ -252,7 +235,7 @@ interface Room { [Throws=ClientError] string? member_display_name(string user_id); - void add_timeline_listener(TimelineListener listener); + sequence add_timeline_listener(TimelineListener listener); // Loads older messages into the timeline. // @@ -277,17 +260,35 @@ interface Room { [Throws=ClientError] void redact(string event_id, string? reason, string? txn_id); + [Throws=ClientError] + void report_content(string event_id, i32? score, string? reason); + [Throws=ClientError] void send_reaction(string event_id, string key); + + [Throws=ClientError] + void leave(); + + [Throws=ClientError] + void reject_invitation(); + + [Throws=ClientError] + void set_topic(string topic); + + [Throws=ClientError] + void upload_avatar(string mime_type, sequence data); + + [Throws=ClientError] + void remove_avatar(); }; callback interface TimelineListener { void on_update(TimelineDiff update); }; -interface TimelineDiff { - MoveData? 
move(); -}; +interface TimelineItem {}; + +interface TimelineDiff {}; dictionary MoveData { u32 old_index; @@ -307,7 +308,22 @@ interface MediaSource { }; interface AuthenticationService { - constructor(string base_path, string? passphrase); + constructor(string base_path, string? passphrase, string? custom_sliding_sync_proxy); +}; + +dictionary NotificationItem { + TimelineItem item; + string title; + string? subtitle; + boolean is_noisy; + string? avatar_url; +}; + +interface NotificationService { + constructor(string base_path, string user_id); + + [Throws=ClientError] + NotificationItem? get_notification_item(string room_id, string event_id); }; interface SessionVerificationEmoji {}; diff --git a/bindings/matrix-sdk-ffi/src/authentication_service.rs b/bindings/matrix-sdk-ffi/src/authentication_service.rs index 6a9ebfe22cc..50ed0570238 100644 --- a/bindings/matrix-sdk-ffi/src/authentication_service.rs +++ b/bindings/matrix-sdk-ffi/src/authentication_service.rs @@ -2,9 +2,10 @@ use std::sync::{Arc, RwLock}; use futures_util::future::join3; use matrix_sdk::{ - ruma::{OwnedDeviceId, UserId}, + ruma::{IdParseError, OwnedDeviceId, UserId}, Session, }; +use url::Url; use zeroize::Zeroize; use super::{client::Client, client_builder::ClientBuilder, RUNTIME}; @@ -14,6 +15,7 @@ pub struct AuthenticationService { passphrase: Option, client: RwLock>>, homeserver_details: RwLock>>, + custom_sliding_sync_proxy: RwLock>, } impl Drop for AuthenticationService { @@ -25,8 +27,12 @@ impl Drop for AuthenticationService { #[derive(Debug, thiserror::Error, uniffi::Error)] #[uniffi(flat_error)] pub enum AuthenticationError { - #[error("A successful call to use_server must be made first.")] + #[error("A successful call to configure_homeserver must be made first.")] ClientMissing, + #[error("{message}")] + InvalidServerName { message: String }, + #[error("The homeserver doesn't provide a trusted a sliding sync proxy in its well-known configuration.")] + SlidingSyncNotAvailable, 
#[error("Login was successful but is missing a valid Session to configure the file store.")] SessionMissing, #[error("An error occurred: {message}")] @@ -39,6 +45,12 @@ impl From for AuthenticationError { } } +impl From for AuthenticationError { + fn from(e: IdParseError) -> AuthenticationError { + AuthenticationError::InvalidServerName { message: e.to_string() } + } +} + #[derive(uniffi::Object)] pub struct HomeserverLoginDetails { url: String, @@ -67,12 +79,17 @@ impl HomeserverLoginDetails { impl AuthenticationService { /// Creates a new service to authenticate a user with. - pub fn new(base_path: String, passphrase: Option) -> Self { + pub fn new( + base_path: String, + passphrase: Option, + custom_sliding_sync_proxy: Option, + ) -> Self { AuthenticationService { base_path, passphrase, client: RwLock::new(None), homeserver_details: RwLock::new(None), + custom_sliding_sync_proxy: RwLock::new(custom_sliding_sync_proxy), } } @@ -90,7 +107,7 @@ impl AuthenticationService { let url = login_details.0; let authentication_issuer = login_details.1; - let supports_password_login = login_details.2.map_err(AuthenticationError::from)?; + let supports_password_login = login_details.2?; Ok(HomeserverLoginDetails { url, authentication_issuer, supports_password_login }) } @@ -104,25 +121,55 @@ impl AuthenticationService { /// Updates the service to authenticate with the homeserver for the /// specified address. - pub fn configure_homeserver(&self, server_name: String) -> Result<(), AuthenticationError> { + pub fn configure_homeserver( + &self, + server_name_or_homeserver_url: String, + ) -> Result<(), AuthenticationError> { let mut builder = Arc::new(ClientBuilder::new()).base_path(self.base_path.clone()); - if server_name.starts_with("http://") || server_name.starts_with("https://") { - builder = builder.homeserver_url(server_name) - } else { - builder = builder.server_name(server_name); + // Attempt discovery as a server name first. 
+ let result = matrix_sdk::sanitize_server_name(&server_name_or_homeserver_url); + match result { + Ok(server_name) => { + builder = builder.server_name(server_name.to_string()); + } + Err(e) => { + // When the input isn't a valid server name check it is a URL. + // If this is the case, build the client with a homeserver URL. + if let Ok(_url) = Url::parse(&server_name_or_homeserver_url) { + builder = builder.homeserver_url(server_name_or_homeserver_url.clone()); + } else { + return Err(e.into()); + } + } } - let client = builder.build().map_err(AuthenticationError::from)?; + let client = builder.build().or_else(|e| { + if !server_name_or_homeserver_url.starts_with("http://") + && !server_name_or_homeserver_url.starts_with("https://") + { + return Err(e); + } + // When discovery fails, fallback to the homeserver URL if supplied. + let mut builder = Arc::new(ClientBuilder::new()).base_path(self.base_path.clone()); + builder = builder.homeserver_url(server_name_or_homeserver_url); + builder.build() + })?; - RUNTIME.block_on(async move { - let details = Arc::new(self.details_from_client(&client).await?); + let details = RUNTIME.block_on(self.details_from_client(&client))?; - *self.client.write().unwrap() = Some(client); - *self.homeserver_details.write().unwrap() = Some(details); + // Now we've verified that it's a valid homeserver, make sure + // there's a sliding sync proxy available one way or another. + if self.custom_sliding_sync_proxy.read().unwrap().is_none() + && client.discovered_sliding_sync_proxy().is_none() + { + return Err(AuthenticationError::SlidingSyncNotAvailable); + } - Ok(()) - }) + *self.client.write().unwrap() = Some(client); + *self.homeserver_details.write().unwrap() = Some(Arc::new(details)); + + Ok(()) } /// Performs a password login using the current homeserver. @@ -139,24 +186,32 @@ impl AuthenticationService { // Login and ask the server for the full user ID as this could be different from // the username that was entered. 
- client - .login(username, password, initial_device_name, device_id) - .map_err(AuthenticationError::from)?; + client.login(username, password, initial_device_name, device_id)?; let whoami = client.whoami()?; // Create a new client to setup the store path now the user ID is known. let homeserver_url = client.homeserver(); let session = client.client.session().ok_or(AuthenticationError::SessionMissing)?; + + let sliding_sync_proxy: Option; + if let Some(custom_proxy) = self.custom_sliding_sync_proxy.read().unwrap().clone() { + sliding_sync_proxy = Some(custom_proxy); + } else if let Some(discovered_proxy) = client.discovered_sliding_sync_proxy() { + sliding_sync_proxy = Some(discovered_proxy); + } else { + sliding_sync_proxy = None; + } + let client = Arc::new(ClientBuilder::new()) .base_path(self.base_path.clone()) .passphrase(self.passphrase.clone()) .homeserver_url(homeserver_url) + .sliding_sync_proxy(sliding_sync_proxy) .username(whoami.user_id.to_string()) - .build() - .map_err(AuthenticationError::from)?; + .build()?; // Restore the client using the session from the login request. - client.restore_session_inner(session).map_err(AuthenticationError::from)?; + client.restore_session_inner(session)?; Ok(client) } @@ -189,7 +244,7 @@ impl AuthenticationService { device_id: device_id.clone(), }; - client.restore_session_inner(discovery_session).map_err(AuthenticationError::from)?; + client.restore_session_inner(discovery_session)?; let whoami = client.whoami()?; // Create the actual client with a store path from the user ID. @@ -205,11 +260,10 @@ impl AuthenticationService { .passphrase(self.passphrase.clone()) .homeserver_url(homeserver_url) .username(whoami.user_id.to_string()) - .build() - .map_err(AuthenticationError::from)?; + .build()?; // Restore the client using the session. 
- client.restore_session_inner(session).map_err(AuthenticationError::from)?; + client.restore_session_inner(session)?; Ok(client) } } diff --git a/bindings/matrix-sdk-ffi/src/client.rs b/bindings/matrix-sdk-ffi/src/client.rs index a3f2fb79910..ef28c730cf6 100644 --- a/bindings/matrix-sdk-ffi/src/client.rs +++ b/bindings/matrix-sdk-ffi/src/client.rs @@ -2,29 +2,94 @@ use std::sync::{Arc, RwLock}; use anyhow::{anyhow, Context}; use matrix_sdk::{ - config::SyncSettings, - media::{MediaFormat, MediaRequest, MediaThumbnailSize}, + media::{MediaFileHandle as SdkMediaFileHandle, MediaFormat, MediaRequest, MediaThumbnailSize}, ruma::{ api::client::{ account::whoami, error::ErrorKind, - filter::{FilterDefinition, LazyLoadOptions, RoomEventFilter, RoomFilter}, media::get_content_thumbnail::v3::Method, + push::{EmailPusherData, PusherIds, PusherInit, PusherKind as RumaPusherKind}, + room::{create_room, Visibility}, session::get_login_types, - sync::sync_events::v3::Filter, }, - events::{room::MediaSource, AnyToDeviceEvent}, + events::{ + room::{ + avatar::RoomAvatarEventContent, encryption::RoomEncryptionEventContent, MediaSource, + }, + AnyInitialStateEvent, AnyToDeviceEvent, InitialStateEvent, + }, serde::Raw, - TransactionId, UInt, + EventEncryptionAlgorithm, TransactionId, UInt, UserId, }, Client as MatrixClient, Error, LoopCtrl, }; -use tokio::sync::broadcast; -use tracing::{debug, warn}; +use ruma::push::{HttpPusherData as RumaHttpPusherData, PushFormat as RumaPushFormat}; +use serde_json::Value; +use tokio::sync::broadcast::{self, error::RecvError}; +use tracing::{debug, error, warn}; + +use super::{room::Room, session_verification::SessionVerificationController, RUNTIME}; +use crate::{client, ClientError}; + +#[derive(Clone, uniffi::Record)] +pub struct PusherIdentifiers { + pub pushkey: String, + pub app_id: String, +} -use super::{ - room::Room, session_verification::SessionVerificationController, ClientState, RUNTIME, -}; +impl From for PusherIds { + fn from(value: 
PusherIdentifiers) -> Self { + Self::new(value.pushkey, value.app_id) + } +} + +#[derive(Clone, uniffi::Record)] +pub struct HttpPusherData { + pub url: String, + pub format: Option, + pub default_payload: Option, +} + +#[derive(Clone, uniffi::Enum)] +pub enum PusherKind { + Http { data: HttpPusherData }, + Email, +} + +impl TryFrom for RumaPusherKind { + type Error = anyhow::Error; + + fn try_from(value: PusherKind) -> anyhow::Result { + match value { + PusherKind::Http { data } => { + let mut ruma_data = RumaHttpPusherData::new(data.url); + if let Some(payload) = data.default_payload { + let json: Value = serde_json::from_str(&payload)?; + ruma_data.default_payload = json; + } + ruma_data.format = data.format.map(Into::into); + Ok(Self::Http(ruma_data)) + } + PusherKind::Email => { + let ruma_data = EmailPusherData::new(); + Ok(Self::Email(ruma_data)) + } + } + } +} + +#[derive(Clone, uniffi::Enum)] +pub enum PushFormat { + EventIdOnly, +} + +impl From for RumaPushFormat { + fn from(value: PushFormat) -> Self { + match value { + client::PushFormat::EventIdOnly => Self::EventIdOnly, + } + } +} impl std::ops::Deref for Client { type Target = MatrixClient; @@ -34,23 +99,24 @@ impl std::ops::Deref for Client { } pub trait ClientDelegate: Sync + Send { - fn did_receive_sync_update(&self); fn did_receive_auth_error(&self, is_soft_logout: bool); - fn did_update_restore_token(&self); } #[derive(Clone)] pub struct Client { pub(crate) client: MatrixClient, - state: Arc>, delegate: Arc>>>, session_verification_controller: Arc>>, + /// The sliding sync proxy that the client is configured to use by default. + /// If this value is `Some`, it will be automatically added to the builder + /// when calling `sliding_sync()`. 
+ pub(crate) sliding_sync_proxy: Arc>>, pub(crate) sliding_sync_reset_broadcast_tx: broadcast::Sender<()>, } impl Client { - pub fn new(client: MatrixClient, state: ClientState) -> Self { + pub fn new(client: MatrixClient) -> Self { let session_verification_controller: Arc< matrix_sdk::locks::RwLock>, > = Default::default(); @@ -69,13 +135,30 @@ impl Client { let (sliding_sync_reset_broadcast_tx, _) = broadcast::channel(1); - Client { + let client = Client { client, - state: Arc::new(RwLock::new(state)), delegate: Arc::new(RwLock::new(None)), session_verification_controller, + sliding_sync_proxy: Arc::new(RwLock::new(None)), sliding_sync_reset_broadcast_tx, - } + }; + + let mut unknown_token_error_receiver = client.subscribe_to_unknown_token_errors(); + let client_clone = client.clone(); + RUNTIME.spawn(async move { + loop { + match unknown_token_error_receiver.recv().await { + Ok(unknown_token) => client_clone.process_unknown_token_error(unknown_token), + Err(receive_error) => { + if let RecvError::Closed = receive_error { + break; + } + } + } + } + }); + + client } /// Login using a username and password. @@ -99,19 +182,44 @@ impl Client { }) } + pub fn get_media_file( + &self, + media_source: Arc, + mime_type: String, + ) -> anyhow::Result> { + let client = self.client.clone(); + let source = (*media_source).clone(); + let mime_type: mime::Mime = mime_type.parse()?; + + RUNTIME.block_on(async move { + let handle = client + .media() + .get_media_file( + &MediaRequest { source, format: MediaFormat::File }, + &mime_type, + true, + ) + .await?; + + Ok(Arc::new(MediaFileHandle { inner: handle })) + }) + } +} + +#[uniffi::export] +impl Client { /// Restores the client from a `Session`. 
- pub fn restore_session(&self, session: Session) -> anyhow::Result<()> { + pub fn restore_session(&self, session: Session) -> Result<(), ClientError> { let Session { access_token, refresh_token, user_id, device_id, homeserver_url: _, - is_soft_logout, + sliding_sync_proxy, } = session; - // update soft logout state - self.state.write().unwrap().is_soft_logout = is_soft_logout; + *self.sliding_sync_proxy.write().unwrap() = sliding_sync_proxy; let session = matrix_sdk::Session { access_token, @@ -119,9 +227,11 @@ impl Client { user_id: user_id.try_into()?, device_id: device_id.into(), }; - self.restore_session_inner(session) + Ok(self.restore_session_inner(session)?) } +} +impl Client { /// Restores the client from a `matrix_sdk::Session`. pub(crate) fn restore_session_inner(&self, session: matrix_sdk::Session) -> anyhow::Result<()> { RUNTIME.block_on(async move { @@ -144,6 +254,18 @@ impl Client { self.client.authentication_issuer().await.map(|server| server.to_string()) } + /// The sliding sync proxy that is trusted by the homeserver. `None` when + /// not configured. + pub fn discovered_sliding_sync_proxy(&self) -> Option { + RUNTIME.block_on(async move { + self.client.sliding_sync_proxy().await.map(|server| server.to_string()) + }) + } + + pub(crate) fn set_sliding_sync_proxy(&self, sliding_sync_proxy: Option) { + *self.sliding_sync_proxy.write().unwrap() = sliding_sync_proxy; + } + /// Whether or not the client's homeserver supports the password login flow. 
pub async fn supports_password_login(&self) -> anyhow::Result { let login_types = self.client.get_login_types().await?; @@ -159,13 +281,16 @@ impl Client { RUNTIME .block_on(async move { self.client.whoami().await.map_err(|e| anyhow!(e.to_string())) }) } +} - pub fn session(&self) -> anyhow::Result { +#[uniffi::export] +impl Client { + pub fn session(&self) -> Result { RUNTIME.block_on(async move { let matrix_sdk::Session { access_token, refresh_token, user_id, device_id } = self.client.session().context("Missing session")?; let homeserver_url = self.client.homeserver().await.into(); - let is_soft_logout = self.state.read().unwrap().is_soft_logout; + let sliding_sync_proxy = self.sliding_sync_proxy.read().unwrap().clone(); Ok(Session { access_token, @@ -173,17 +298,17 @@ impl Client { user_id: user_id.to_string(), device_id: device_id.to_string(), homeserver_url, - is_soft_logout, + sliding_sync_proxy, }) }) } - pub fn user_id(&self) -> anyhow::Result { + pub fn user_id(&self) -> Result { let user_id = self.client.user_id().context("No User ID found")?; Ok(user_id.to_string()) } - pub fn display_name(&self) -> anyhow::Result { + pub fn display_name(&self) -> Result { let l = self.client.clone(); RUNTIME.block_on(async move { let display_name = l.account().get_display_name().await?.context("No User ID found")?; @@ -191,35 +316,53 @@ impl Client { }) } - pub fn set_display_name(&self, name: String) -> anyhow::Result<()> { + pub fn set_display_name(&self, name: String) -> Result<(), ClientError> { let client = self.client.clone(); RUNTIME.block_on(async move { client .account() .set_display_name(Some(name.as_str())) .await - .context("Unable to set display name") + .context("Unable to set display name")?; + Ok(()) + }) + } + + pub fn avatar_url(&self) -> Result, ClientError> { + let l = self.client.clone(); + RUNTIME.block_on(async move { + let avatar_url = l.account().get_avatar_url().await?; + Ok(avatar_url.map(|u| u.to_string())) }) } - pub fn avatar_url(&self) -> 
anyhow::Result { + pub fn cached_avatar_url(&self) -> Result, ClientError> { let l = self.client.clone(); RUNTIME.block_on(async move { - let avatar_url = l.account().get_avatar_url().await?.context("No User ID found")?; - Ok(avatar_url.to_string()) + let url = l.account().get_cached_avatar_url().await?; + Ok(url) }) } - pub fn device_id(&self) -> anyhow::Result { + pub fn device_id(&self) -> Result { let device_id = self.client.device_id().context("No Device ID found")?; Ok(device_id.to_string()) } + pub fn create_room(&self, request: CreateRoomParameters) -> Result { + let client = self.client.clone(); + + RUNTIME.block_on(async move { + let response = client.create_room(request.into()).await?; + Ok(String::from(response.room_id())) + }) + } + /// Get the content of the event of the given type out of the account data /// store. /// /// It will be returned as a JSON string. - pub fn account_data(&self, event_type: String) -> anyhow::Result> { + pub fn account_data(&self, event_type: String) -> Result, ClientError> { RUNTIME.block_on(async move { let event = self.client.account().account_data_raw(event_type.into()).await?; Ok(event.map(|e| e.json().get().to_owned())) @@ -229,7 +372,7 @@ impl Client { /// Set the given account data content for the given event type. /// /// It should be supplied as a JSON string. 
- pub fn set_account_data(&self, event_type: String, content: String) -> anyhow::Result<()> { + pub fn set_account_data(&self, event_type: String, content: String) -> Result<(), ClientError> { RUNTIME.block_on(async move { let raw_content = Raw::from_json_string(content)?; self.client.account().set_account_data_raw(event_type.into(), raw_content).await?; @@ -237,17 +380,20 @@ impl Client { }) } - pub fn upload_media(&self, mime_type: String, data: Vec) -> anyhow::Result { + pub fn upload_media(&self, mime_type: String, data: Vec) -> Result { let l = self.client.clone(); RUNTIME.block_on(async move { - let mime_type: mime::Mime = mime_type.parse()?; + let mime_type: mime::Mime = mime_type.parse().context("Parsing mime type")?; let response = l.media().upload(&mime_type, data).await?; Ok(String::from(response.content_uri)) }) } - pub fn get_media_content(&self, media_source: Arc) -> anyhow::Result> { + pub fn get_media_content( + &self, + media_source: Arc, + ) -> Result, ClientError> { let l = self.client.clone(); let source = (*media_source).clone(); @@ -263,7 +409,7 @@ impl Client { media_source: Arc, width: u64, height: u64, - ) -> anyhow::Result> { + ) -> Result, ClientError> { let l = self.client.clone(); let source = (*media_source).clone(); @@ -286,7 +432,7 @@ impl Client { pub fn get_session_verification_controller( &self, - ) -> anyhow::Result> { + ) -> Result, ClientError> { RUNTIME.block_on(async move { if let Some(session_verification_controller) = &*self.session_verification_controller.read().await @@ -312,27 +458,52 @@ impl Client { } /// Log out the current user - pub fn logout(&self) -> anyhow::Result<()> { + pub fn logout(&self) -> Result<(), ClientError> { + RUNTIME.block_on(self.client.logout())?; + Ok(()) + } + + /// Registers a pusher with given parameters + pub fn set_pusher( + &self, + identifiers: PusherIdentifiers, + kind: PusherKind, + app_display_name: String, + device_display_name: String, + profile_tag: Option, + lang: String, + ) -> 
Result<(), ClientError> { RUNTIME.block_on(async move { - match self.client.logout().await { - Ok(_) => Ok(()), - Err(error) => Err(anyhow!(error.to_string())), - } + let ids = identifiers.into(); + + let pusher_init = PusherInit { + ids, + kind: kind.try_into()?, + app_display_name, + device_display_name, + profile_tag, + lang, + }; + self.client.set_pusher(pusher_init.into()).await?; + Ok(()) }) } + /// The homeserver this client is configured to use. + pub fn homeserver(&self) -> String { + RUNTIME.block_on(self.async_homeserver()) + } + + pub fn rooms(&self) -> Vec> { + self.client.rooms().into_iter().map(|room| Arc::new(Room::new(room))).collect() + } +} + +impl Client { /// Process a sync error and return loop control accordingly pub(crate) fn process_sync_error(&self, sync_error: Error) -> LoopCtrl { let client_api_error_kind = sync_error.client_api_error_kind(); match client_api_error_kind { - Some(ErrorKind::UnknownToken { soft_logout }) => { - self.state.write().unwrap().is_soft_logout = *soft_logout; - if let Some(delegate) = &*self.delegate.read().unwrap() { - delegate.did_update_restore_token(); - delegate.did_receive_auth_error(*soft_logout); - } - LoopCtrl::Break - } Some(ErrorKind::UnknownPos) => { let _ = self.sliding_sync_reset_broadcast_tx.send(()); LoopCtrl::Continue @@ -343,87 +514,109 @@ impl Client { } } } -} -#[uniffi::export] -impl Client { - /// The homeserver this client is configured to use. 
- pub fn homeserver(&self) -> String { - RUNTIME.block_on(async move { self.async_homeserver().await }) + fn process_unknown_token_error(&self, unknown_token: matrix_sdk::UnknownToken) { + if let Some(delegate) = &*self.delegate.read().unwrap() { + delegate.did_receive_auth_error(unknown_token.soft_logout); + } } +} - /// Indication whether we've received a first sync response since - /// establishing the client (in memory) - pub fn has_first_synced(&self) -> bool { - self.state.read().unwrap().has_first_synced - } +pub struct CreateRoomParameters { + pub name: String, + pub topic: Option, + pub is_encrypted: bool, + pub is_direct: bool, + pub visibility: RoomVisibility, + pub preset: RoomPreset, + pub invite: Option>, + pub avatar: Option, +} - /// Indication whether we are currently syncing - pub fn is_syncing(&self) -> bool { - self.state.read().unwrap().is_syncing - } +impl From for create_room::v3::Request { + fn from(value: CreateRoomParameters) -> create_room::v3::Request { + let mut request = create_room::v3::Request::new(); + request.name = Some(value.name); + request.topic = value.topic; + request.is_direct = value.is_direct; + request.visibility = value.visibility.into(); + request.preset = Some(value.preset.into()); + request.invite = match value.invite { + Some(invite) => invite + .iter() + .filter_map(|user_id| match UserId::parse(user_id) { + Ok(id) => Some(id), + Err(e) => { + error!(user_id, "Skipping invalid user ID, error: {e}"); + None + } + }) + .collect(), + None => vec![], + }; - /// Flag indicating whether the session is in soft logout mode - pub fn is_soft_logout(&self) -> bool { - self.state.read().unwrap().is_soft_logout - } + let mut initial_state: Vec> = vec![]; - pub fn rooms(&self) -> Vec> { - self.client.rooms().into_iter().map(|room| Arc::new(Room::new(room))).collect() - } + if value.is_encrypted { + let content = + RoomEncryptionEventContent::new(EventEncryptionAlgorithm::MegolmV1AesSha2); + 
initial_state.push(InitialStateEvent::new(content).to_raw_any()); + } - pub fn start_sync(&self, timeline_limit: Option) { - let client = self.client.clone(); - let state = self.state.clone(); - let delegate = self.delegate.clone(); - let local_self = self.clone(); - RUNTIME.spawn(async move { - let mut filter = FilterDefinition::default(); - let mut room_filter = RoomFilter::default(); - let mut event_filter = RoomEventFilter::default(); - let mut timeline_filter = RoomEventFilter::default(); + if let Some(url) = value.avatar { + let mut content = RoomAvatarEventContent::new(); + content.url = Some(url.into()); + initial_state.push(InitialStateEvent::new(content).to_raw_any()); + } - event_filter.lazy_load_options = - LazyLoadOptions::Enabled { include_redundant_members: false }; - room_filter.state = event_filter; - filter.room = room_filter; + request.initial_state = initial_state; - timeline_filter.limit = timeline_limit.map(|limit| limit.into()); - filter.room.timeline = timeline_filter; + request + } +} - let filter_id = client.get_or_upload_filter("sync", filter).await.unwrap(); +pub enum RoomVisibility { + /// Indicates that the room will be shown in the published room list. + Public, - let sync_settings = SyncSettings::new().filter(Filter::FilterId(filter_id)); + /// Indicates that the room will not be shown in the published room list. 
+ Private, +} - client - .sync_with_result_callback(sync_settings, |result| async { - Ok(if result.is_ok() { - if !state.read().unwrap().has_first_synced { - state.write().unwrap().has_first_synced = true; - } +impl From for Visibility { + fn from(value: RoomVisibility) -> Self { + match value { + RoomVisibility::Public => Self::Public, + RoomVisibility::Private => Self::Private, + } + } +} - if state.read().unwrap().should_stop_syncing { - state.write().unwrap().is_syncing = false; - return Ok(LoopCtrl::Break); - } else if !state.read().unwrap().is_syncing { - state.write().unwrap().is_syncing = true; - } +pub enum RoomPreset { + /// `join_rules` is set to `invite` and `history_visibility` is set to + /// `shared`. + PrivateChat, - if let Some(delegate) = &*delegate.read().unwrap() { - delegate.did_receive_sync_update() - } + /// `join_rules` is set to `public` and `history_visibility` is set to + /// `shared`. + PublicChat, - LoopCtrl::Continue - } else { - local_self.process_sync_error(result.err().unwrap()) - }) - }) - .await - .unwrap(); - }); + /// Same as `PrivateChat`, but all initial invitees get the same power level + /// as the creator. + TrustedPrivateChat, +} + +impl From for create_room::v3::RoomPreset { + fn from(value: RoomPreset) -> Self { + match value { + RoomPreset::PrivateChat => Self::PrivateChat, + RoomPreset::PublicChat => Self::PublicChat, + RoomPreset::TrustedPrivateChat => Self::TrustedPrivateChat, + } } } +#[derive(uniffi::Record)] pub struct Session { // Same fields as the Session type in matrix-sdk, just simpler types /// The access token used for this session. @@ -439,10 +632,23 @@ pub struct Session { // FFI-only fields (for now) pub homeserver_url: String, - pub is_soft_logout: bool, + pub sliding_sync_proxy: Option, } #[uniffi::export] fn gen_transaction_id() -> String { TransactionId::new().to_string() } + +/// A file handle that takes ownership of a media file on disk. 
When the handle +/// is dropped, the file will be removed from the disk. +pub struct MediaFileHandle { + inner: SdkMediaFileHandle, +} + +impl MediaFileHandle { + /// Get the media file's path. + pub fn path(&self) -> String { + self.inner.path().to_str().unwrap().to_owned() + } +} diff --git a/bindings/matrix-sdk-ffi/src/client_builder.rs b/bindings/matrix-sdk-ffi/src/client_builder.rs index cc6d8d29343..3200c7e0315 100644 --- a/bindings/matrix-sdk-ffi/src/client_builder.rs +++ b/bindings/matrix-sdk-ffi/src/client_builder.rs @@ -11,7 +11,7 @@ use matrix_sdk::{ use sanitize_filename_reader_friendly::sanitize; use zeroize::Zeroizing; -use super::{client::Client, ClientState, RUNTIME}; +use super::{client::Client, RUNTIME}; use crate::helpers::unwrap_or_clone_arc; #[derive(Clone)] @@ -23,6 +23,7 @@ pub struct ClientBuilder { server_versions: Option>, passphrase: Zeroizing>, user_agent: Option, + sliding_sync_proxy: Option, inner: MatrixClientBuilder, } @@ -69,6 +70,12 @@ impl ClientBuilder { builder.user_agent = Some(user_agent); Arc::new(builder) } + + pub fn sliding_sync_proxy(self: Arc, sliding_sync_proxy: Option) -> Arc { + let mut builder = unwrap_or_clone_arc(self); + builder.sliding_sync_proxy = sliding_sync_proxy; + Arc::new(builder) + } } impl ClientBuilder { @@ -81,6 +88,7 @@ impl ClientBuilder { server_versions: None, passphrase: Zeroizing::new(None), user_agent: None, + sliding_sync_proxy: None, inner: MatrixClient::builder(), } } @@ -127,7 +135,8 @@ impl ClientBuilder { RUNTIME.block_on(async move { let client = inner_builder.build().await?; - let c = Client::new(client, ClientState::default()); + let c = Client::new(client); + c.set_sliding_sync_proxy(builder.sliding_sync_proxy); Ok(Arc::new(c)) }) } diff --git a/bindings/matrix-sdk-ffi/src/lib.rs b/bindings/matrix-sdk-ffi/src/lib.rs index 094e84a128c..3e73017ec8b 100644 --- a/bindings/matrix-sdk-ffi/src/lib.rs +++ b/bindings/matrix-sdk-ffi/src/lib.rs @@ -1,6 +1,6 @@ // TODO: target-os conditional would 
be good. -#![allow(unused_qualifications)] +#![allow(unused_qualifications, clippy::new_without_default)] macro_rules! unwrap_or_clone_arc_into_variant { ( @@ -26,6 +26,7 @@ pub mod authentication_service; pub mod client; pub mod client_builder; mod helpers; +pub mod notification_service; pub mod room; pub mod session_verification; pub mod sliding_sync; @@ -34,6 +35,7 @@ mod uniffi_api; use client::Client; use client_builder::ClientBuilder; +use matrix_sdk::{encryption::CryptoStoreError, HttpError, IdParseError}; use once_cell::sync::Lazy; use tokio::runtime::Runtime; pub use uniffi_api::*; @@ -47,18 +49,10 @@ pub use matrix_sdk::{ }; pub use self::{ - authentication_service::*, client::*, room::*, session_verification::*, sliding_sync::*, - timeline::*, + authentication_service::*, client::*, notification_service::*, room::*, + session_verification::*, sliding_sync::*, timeline::*, }; -#[derive(Default, Debug)] -pub struct ClientState { - has_first_synced: bool, - is_syncing: bool, - should_stop_syncing: bool, - is_soft_logout: bool, -} - #[derive(thiserror::Error, Debug)] pub enum ClientError { #[error("client error: {msg}")] @@ -71,6 +65,36 @@ impl From for ClientError { } } +impl From for ClientError { + fn from(e: matrix_sdk::Error) -> Self { + anyhow::Error::from(e).into() + } +} + +impl From for ClientError { + fn from(e: CryptoStoreError) -> Self { + anyhow::Error::from(e).into() + } +} + +impl From for ClientError { + fn from(e: HttpError) -> Self { + anyhow::Error::from(e).into() + } +} + +impl From for ClientError { + fn from(e: IdParseError) -> Self { + anyhow::Error::from(e).into() + } +} + +impl From for ClientError { + fn from(e: serde_json::Error) -> Self { + anyhow::Error::from(e).into() + } +} + pub use platform::*; mod uniffi_types { @@ -80,23 +104,27 @@ mod uniffi_types { authentication_service::{ AuthenticationError, AuthenticationService, HomeserverLoginDetails, }, - client::Client, + client::{ + Client, CreateRoomParameters, HttpPusherData, 
PushFormat, PusherIdentifiers, + PusherKind, Session, + }, client_builder::ClientBuilder, room::{Membership, MembershipState, Room, RoomMember}, session_verification::{SessionVerificationController, SessionVerificationEmoji}, sliding_sync::{ - RequiredState, RoomListEntry, SlidingSync, SlidingSyncBuilder, - SlidingSyncRequestListFilters, SlidingSyncRoom, SlidingSyncView, - SlidingSyncViewBuilder, StoppableSpawn, UnreadNotificationsCount, + RequiredState, RoomListEntry, SlidingSync, SlidingSyncBuilder, SlidingSyncList, + SlidingSyncListBuilder, SlidingSyncRequestListFilters, SlidingSyncRoom, TaskHandle, + UnreadNotificationsCount, }, timeline::{ - EmoteMessageContent, EncryptedMessage, EventSendState, EventTimelineItem, FileInfo, - FileMessageContent, FormattedBody, ImageInfo, ImageMessageContent, InsertAtData, - MembershipChange, Message, MessageFormat, MessageType, NoticeMessageContent, - OtherState, ProfileTimelineDetails, Reaction, TextMessageContent, ThumbnailInfo, - TimelineChange, TimelineDiff, TimelineItem, TimelineItemContent, - TimelineItemContentKind, UpdateAtData, VideoInfo, VideoMessageContent, + AudioInfo, AudioMessageContent, EmoteMessageContent, EncryptedMessage, EventSendState, + EventTimelineItem, FileInfo, FileMessageContent, FormattedBody, ImageInfo, + ImageMessageContent, InsertData, MembershipChange, Message, MessageFormat, MessageType, + NoticeMessageContent, OtherState, ProfileTimelineDetails, Reaction, SetData, + TextMessageContent, ThumbnailInfo, TimelineChange, TimelineDiff, TimelineItem, + TimelineItemContent, TimelineItemContentKind, VideoInfo, VideoMessageContent, VirtualTimelineItem, }, + ClientError, }; } diff --git a/bindings/matrix-sdk-ffi/src/notification_service.rs b/bindings/matrix-sdk-ffi/src/notification_service.rs new file mode 100644 index 00000000000..eb935e0980b --- /dev/null +++ b/bindings/matrix-sdk-ffi/src/notification_service.rs @@ -0,0 +1,47 @@ +use std::sync::Arc; + +use crate::TimelineItem; + +#[allow(dead_code)] 
+pub struct NotificationService {
+    base_path: String,
+    user_id: String,
+}
+
+/// Notification item struct.
+pub struct NotificationItem {
+    /// Actual timeline item for the event sent.
+    pub item: Arc,
+    /// Title of the notification. Usually this is the event sender's display name.
+    pub title: String,
+    /// Subtitle of the notification. Usually would be the room name for
+    /// non-direct rooms, and none for direct rooms.
+    pub subtitle: Option,
+    /// Flag indicating the notification should play a sound.
+    pub is_noisy: bool,
+    /// Avatar URL of the room the event was sent to (if any).
+    pub avatar_url: Option,
+}
+
+impl NotificationService {
+    /// Creates a new notification service.
+    ///
+    /// Will be used to fetch an event after receiving a notification.
+    /// Please note that this will be called in a different process than
+    /// the main application context.
+    pub fn new(base_path: String, user_id: String) -> Self {
+        Self { base_path, user_id }
+    }
+
+    /// Get the notification item for a given `room_id` and `event_id`.
+    ///
+    /// Returns `None` if this notification should not be displayed to the user.
+ pub fn get_notification_item( + &self, + _room_id: String, + _event_id: String, + ) -> anyhow::Result> { + // TODO: Implement + Ok(None) + } +} diff --git a/bindings/matrix-sdk-ffi/src/platform.rs b/bindings/matrix-sdk-ffi/src/platform.rs index b506c47441c..cc1fe83249d 100644 --- a/bindings/matrix-sdk-ffi/src/platform.rs +++ b/bindings/matrix-sdk-ffi/src/platform.rs @@ -1,4 +1,6 @@ use std::collections::HashMap; +#[cfg(not(target_os = "android"))] +use std::io; #[cfg(target_os = "android")] use android as platform_impl; @@ -18,28 +20,11 @@ use opentelemetry_otlp::{Protocol, WithExportConfig}; #[cfg(not(any(target_os = "ios", target_os = "android")))] use other as platform_impl; use tokio::runtime::Handle; +#[cfg(not(target_os = "android"))] +use tracing_subscriber::{fmt, prelude::*, EnvFilter}; use crate::RUNTIME; -#[cfg(target_os = "android")] -mod android { - use android_logger::{Config, FilterBuilder}; - use tracing::log::Level; - - pub fn setup_tracing(filter: String) { - std::env::set_var("RUST_BACKTRACE", "1"); - - log_panics::init(); - - let log_config = Config::default() - .with_min_level(Level::Trace) - .with_tag("matrix-rust-sdk") - .with_filter(FilterBuilder::new().parse(&filter).build()); - - android_logger::init_once(log_config); - } -} - #[derive(Clone, Debug)] struct TracingRuntime { runtime: Handle, @@ -109,32 +94,74 @@ pub fn create_otlp_tracer( Ok(tracer) } -#[cfg(target_os = "ios")] -mod ios { - use std::io; +#[cfg(not(target_os = "android"))] +fn setup_tracing_helper(configuration: String, colors: bool) { + tracing_subscriber::registry() + .with(EnvFilter::new(configuration)) + .with(fmt::layer().with_ansi(colors).with_writer(io::stderr)) + .init(); +} + +#[cfg(not(target_os = "android"))] +fn setup_otlp_tracing_helper( + configuration: String, + colors: bool, + client_name: String, + user: String, + password: String, + otlp_endpoint: String, +) -> anyhow::Result<()> { + let otlp_tracer = super::create_otlp_tracer(user, password, 
otlp_endpoint, client_name)?; + let otlp_layer = tracing_opentelemetry::layer().with_tracer(otlp_tracer); + + tracing_subscriber::registry() + .with(EnvFilter::new(configuration)) + .with(fmt::layer().with_ansi(colors).with_writer(io::stderr)) + .with(otlp_layer) + .init(); + + Ok(()) +} + +#[cfg(target_os = "android")] +mod android { + use tracing_subscriber::{prelude::*, EnvFilter}; + + fn log_panics() { + std::env::set_var("RUST_BACKTRACE", "1"); + log_panics::init(); + } - use tracing_subscriber::{fmt, prelude::*, EnvFilter}; pub fn setup_tracing(configuration: String) { + log_panics(); + tracing_subscriber::registry() .with(EnvFilter::new(configuration)) - .with(fmt::layer().with_ansi(false).with_writer(io::stderr)) + .with( + tracing_android::layer("org.matrix.rust.sdk") + .expect("Could not configure the Android tracing layer"), + ) .init(); } pub fn setup_otlp_tracing( configuration: String, + client_name: String, user: String, password: String, otlp_endpoint: String, ) -> anyhow::Result<()> { - let otlp_tracer = - super::create_otlp_tracer(user, password, otlp_endpoint, "element-x-ios".to_owned())?; + log_panics(); + let otlp_tracer = super::create_otlp_tracer(user, password, otlp_endpoint, client_name)?; let otlp_layer = tracing_opentelemetry::layer().with_tracer(otlp_tracer); tracing_subscriber::registry() .with(EnvFilter::new(configuration)) - .with(fmt::layer().with_ansi(false).with_writer(io::stderr)) + .with( + tracing_android::layer("org.matrix.rust.sdk") + .expect("Could not configure the Android tracing layer"), + ) .with(otlp_layer) .init(); @@ -142,17 +169,51 @@ mod ios { } } -#[cfg(not(any(target_os = "ios", target_os = "android")))] -mod other { - use std::io; +#[cfg(target_os = "ios")] +mod ios { + pub fn setup_tracing(configuration: String) { + super::setup_tracing_helper(configuration, false); + } - use tracing_subscriber::{fmt, prelude::*, EnvFilter}; + pub fn setup_otlp_tracing( + configuration: String, + client_name: String, + user: 
String, + password: String, + otlp_endpoint: String, + ) -> anyhow::Result<()> { + super::setup_otlp_tracing_helper( + configuration, + false, + client_name, + user, + password, + otlp_endpoint, + ) + } +} +#[cfg(not(any(target_os = "ios", target_os = "android")))] +mod other { pub fn setup_tracing(configuration: String) { - tracing_subscriber::registry() - .with(EnvFilter::new(configuration)) - .with(fmt::layer().with_ansi(true).with_writer(io::stderr)) - .init(); + super::setup_tracing_helper(configuration, true); + } + + pub fn setup_otlp_tracing( + configuration: String, + client_name: String, + user: String, + password: String, + otlp_endpoint: String, + ) -> anyhow::Result<()> { + super::setup_otlp_tracing_helper( + configuration, + true, + client_name, + user, + password, + otlp_endpoint, + ) } } @@ -161,9 +222,14 @@ pub fn setup_tracing(filter: String) { platform_impl::setup_tracing(filter) } -#[cfg(target_os = "ios")] #[uniffi::export] -pub fn setup_otlp_tracing(filter: String, user: String, password: String, otlp_endpoint: String) { - platform_impl::setup_otlp_tracing(filter, user, password, otlp_endpoint) +pub fn setup_otlp_tracing( + filter: String, + client_name: String, + user: String, + password: String, + otlp_endpoint: String, +) { + platform_impl::setup_otlp_tracing(filter, client_name, user, password, otlp_endpoint) .expect("Couldn't configure the OpenTelemetry tracer") } diff --git a/bindings/matrix-sdk-ffi/src/room.rs b/bindings/matrix-sdk-ffi/src/room.rs index dca5543ff1a..ff0e396748a 100644 --- a/bindings/matrix-sdk-ffi/src/room.rs +++ b/bindings/matrix-sdk-ffi/src/room.rs @@ -4,12 +4,14 @@ use std::{ }; use anyhow::{bail, Context, Result}; -use futures_signals::signal_vec::SignalVecExt; +use futures_util::StreamExt; use matrix_sdk::{ - room::{timeline::Timeline, Room as SdkRoom}, + room::{timeline::Timeline, Receipts, Room as SdkRoom}, ruma::{ + api::client::{receipt::create_receipt::v3::ReceiptType, room::report_content}, events::{ 
reaction::ReactionEventContent, + receipt::ReceiptThread, relation::{Annotation, Replacement}, room::message::{ ForwardThread, MessageType, Relation, RoomMessageEvent, RoomMessageEventContent, @@ -18,10 +20,11 @@ use matrix_sdk::{ EventId, UserId, }, }; +use mime::Mime; use tracing::error; use super::RUNTIME; -use crate::{TimelineDiff, TimelineListener}; +use crate::{TimelineDiff, TimelineItem, TimelineListener}; #[derive(uniffi::Enum)] pub enum Membership { @@ -37,7 +40,7 @@ pub struct Room { timeline: TimelineLock, } -#[derive(Clone, uniffi::Enum)] +#[derive(Clone)] pub enum MembershipState { /// The user is banned. Ban, @@ -238,30 +241,40 @@ impl Room { }) } - pub fn add_timeline_listener(&self, listener: Box) { - let timeline_signal = self + pub fn add_timeline_listener( + &self, + listener: Box, + ) -> Vec> { + let timeline = self .timeline .write() .unwrap() .get_or_insert_with(|| { let room = self.room.clone(); - let timeline = RUNTIME.block_on(async move { room.timeline().await }); + #[allow(unknown_lints, clippy::redundant_async_block)] // false positive + let timeline = RUNTIME.block_on(room.timeline()); Arc::new(timeline) }) - .signal(); - - let listener: Arc = listener.into(); - RUNTIME.spawn(timeline_signal.for_each(move |diff| { - let listener = listener.clone(); - let fut = RUNTIME - .spawn_blocking(move || listener.on_update(Arc::new(TimelineDiff::new(diff)))); + .clone(); - async move { - if let Err(e) = fut.await { - error!("Timeline listener error: {e}"); + RUNTIME.block_on(async move { + let (timeline_items, timeline_stream) = timeline.subscribe().await; + + let listener: Arc = listener.into(); + RUNTIME.spawn(timeline_stream.for_each(move |diff| { + let listener = listener.clone(); + let fut = RUNTIME + .spawn_blocking(move || listener.on_update(Arc::new(TimelineDiff::new(diff)))); + + async move { + if let Err(e) = fut.await { + error!("Timeline listener error: {e}"); + } } - } - })); + })); + + 
timeline_items.into_iter().map(TimelineItem::from_arc).collect() + }) } pub fn paginate_backwards(&self, opts: PaginationOptions) -> Result<()> { @@ -281,7 +294,8 @@ impl Room { let event_id = EventId::parse(event_id)?; RUNTIME.block_on(async move { - room.read_receipt(&event_id).await?; + room.send_single_receipt(ReceiptType::Read, ReceiptThread::Unthreaded, event_id) + .await?; Ok(()) }) } @@ -302,9 +316,11 @@ impl Room { .map(EventId::parse) .transpose() .context("parsing read receipt event ID")?; + let receipts = + Receipts::new().fully_read_marker(fully_read).public_read_receipt(read_receipt); RUNTIME.block_on(async move { - room.read_marker(&fully_read, read_receipt.as_deref()).await?; + room.send_multiple_receipts(receipts).await?; Ok(()) }) } @@ -452,6 +468,123 @@ impl Room { Ok(()) }) } + + /// Reports an event from the room. + /// + /// # Arguments + /// + /// * `event_id` - The ID of the event to report + /// + /// * `reason` - The reason for the event being reported (optional). + /// + /// * `score` - The score to rate this content as where -100 is most + /// offensive and 0 is inoffensive (optional). + pub fn report_content( + &self, + event_id: String, + score: Option, + reason: Option, + ) -> Result<()> { + let int_score = score.map(|value| value.into()); + RUNTIME.block_on(async move { + let event_id = EventId::parse(event_id)?; + self.room + .client() + .send( + report_content::v3::Request::new( + self.room_id().into(), + event_id, + int_score, + reason, + ), + None, + ) + .await?; + Ok(()) + }) + } + + /// Leaves the joined room. + /// + /// Will throw an error if used on an room that isn't in a joined state + pub fn leave(&self) -> Result<()> { + let room = match &self.room { + SdkRoom::Joined(j) => j.clone(), + _ => bail!("Can't leave a room that isn't in joined state"), + }; + + RUNTIME.block_on(async move { + room.leave().await?; + Ok(()) + }) + } + + /// Rejects invitation for the invited room. 
+ /// + /// Will throw an error if used on an room that isn't in an invited state + pub fn reject_invitation(&self) -> Result<()> { + let room = match &self.room { + SdkRoom::Invited(i) => i.clone(), + _ => bail!("Can't reject an invite for a room that isn't in invited state"), + }; + + RUNTIME.block_on(async move { + room.reject_invitation().await?; + Ok(()) + }) + } + + /// Sets a new topic in the room. + pub fn set_topic(&self, topic: String) -> Result<()> { + let room = match &self.room { + SdkRoom::Joined(j) => j.clone(), + _ => bail!("Can't set a topic in a room that isn't in joined state"), + }; + + RUNTIME.block_on(async move { + room.set_room_topic(&topic).await?; + Ok(()) + }) + } + + /// Upload and set the room's avatar. + /// + /// This will upload the data produced by the reader to the homeserver's + /// content repository, and set the room's avatar to the MXC URI for the + /// uploaded file. + /// + /// # Arguments + /// + /// * `mime_type` - The mime description of the avatar, for example + /// image/jpeg + /// * `data` - The raw data that will be uploaded to the homeserver's + /// content repository + pub fn upload_avatar(&self, mime_type: String, data: Vec) -> Result<()> { + let room = match &self.room { + SdkRoom::Joined(j) => j.clone(), + _ => bail!("Can't set a avatar in a room that isn't in joined state"), + }; + + RUNTIME.block_on(async move { + let mime: Mime = mime_type.parse()?; + // TODO: We could add an FFI ImageInfo struct in the future + room.upload_avatar(&mime, data, None).await?; + Ok(()) + }) + } + + /// Removes the current room avatar + pub fn remove_avatar(&self) -> Result<()> { + let room = match &self.room { + SdkRoom::Joined(j) => j.clone(), + _ => bail!("Can't remove a avatar in a room that isn't in joined state"), + }; + + RUNTIME.block_on(async move { + room.remove_avatar().await?; + Ok(()) + }) + } } impl std::ops::Deref for Room { diff --git a/bindings/matrix-sdk-ffi/src/sliding_sync.rs 
b/bindings/matrix-sdk-ffi/src/sliding_sync.rs index 408e93745f2..1b4634dc76c 100644 --- a/bindings/matrix-sdk-ffi/src/sliding_sync.rs +++ b/bindings/matrix-sdk-ffi/src/sliding_sync.rs @@ -1,12 +1,8 @@ -use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, RwLock, -}; +use std::sync::{Arc, RwLock}; -use futures_signals::{ - signal::SignalExt, - signal_vec::{SignalVecExt, VecDiff}, -}; +use anyhow::Context; +use eyeball::unique::Observable; +use eyeball_im::VectorDiff; use futures_util::{future::join, pin_mut, StreamExt}; use matrix_sdk::ruma::{ api::client::sync::sync_events::{ @@ -21,57 +17,52 @@ pub use matrix_sdk::{ SlidingSyncBuilder as MatrixSlidingSyncBuilder, SlidingSyncMode, SlidingSyncState, }; use tokio::{sync::broadcast::error::RecvError, task::JoinHandle}; -use tracing::{debug, error, trace, warn}; +use tracing::{debug, error, warn}; +use url::Url; -use super::{Client, Room, RUNTIME}; use crate::{ - helpers::unwrap_or_clone_arc, room::TimelineLock, EventTimelineItem, TimelineDiff, - TimelineListener, + helpers::unwrap_or_clone_arc, room::TimelineLock, Client, ClientError, EventTimelineItem, Room, + TimelineDiff, TimelineItem, TimelineListener, RUNTIME, }; -type StoppableSpawnCallback = Box; +type TaskHandleFinalizer = Box; -pub struct StoppableSpawn { - handle: Option>, - callback: RwLock>, +pub struct TaskHandle { + handle: JoinHandle<()>, + finalizer: RwLock>, } -impl StoppableSpawn { - fn with_handle(handle: JoinHandle<()>) -> StoppableSpawn { - StoppableSpawn { handle: Some(handle), callback: Default::default() } - } - fn with_callback(callback: StoppableSpawnCallback) -> StoppableSpawn { - StoppableSpawn { handle: Default::default(), callback: RwLock::new(Some(callback)) } - } - - fn set_callback(&mut self, f: StoppableSpawnCallback) { - *self.callback.write().unwrap() = Some(f) +impl TaskHandle { + // Create a new task handle. 
+ fn new(handle: JoinHandle<()>) -> Self { + Self { handle, finalizer: RwLock::new(None) } } -} -impl From> for StoppableSpawn { - fn from(value: JoinHandle<()>) -> Self { - StoppableSpawn::with_handle(value) + /// Define a function that will run after the handle has been aborted. + fn set_finalizer(&mut self, finalizer: TaskHandleFinalizer) { + *self.finalizer.write().unwrap() = Some(finalizer); } } #[uniffi::export] -impl StoppableSpawn { +impl TaskHandle { pub fn cancel(&self) { debug!("stoppable.cancel() called"); - if let Some(handle) = &self.handle { - handle.abort(); - } - if let Some(callback) = self.callback.write().unwrap().take() { - callback(); + + self.handle.abort(); + + if let Some(finalizer) = self.finalizer.write().unwrap().take() { + finalizer(); } } + + /// Check whether the handle is finished. pub fn is_finished(&self) -> bool { - self.handle.as_ref().map(|h| h.is_finished()).unwrap_or_default() + self.handle.is_finished() } } -impl Drop for StoppableSpawn { +impl Drop for TaskHandle { fn drop(&mut self) { self.cancel(); } @@ -129,22 +120,23 @@ impl SlidingSyncRoom { } pub fn is_dm(&self) -> Option { - self.inner.is_dm + self.inner.is_dm() } pub fn is_initial(&self) -> Option { - self.inner.initial + self.inner.is_initial_response() } + pub fn is_loading_more(&self) -> bool { self.inner.is_loading_more() } pub fn has_unread_notifications(&self) -> bool { - !self.inner.unread_notifications.is_empty() + self.inner.has_unread_notifications() } pub fn unread_notifications(&self) -> Arc { - Arc::new(self.inner.unread_notifications.clone().into()) + Arc::new(self.inner.unread_notifications().clone().into()) } pub fn full_room(&self) -> Option> { @@ -164,58 +156,66 @@ impl SlidingSyncRoom { pub fn add_timeline_listener( &self, listener: Box, - ) -> Option> { - Some(Arc::new(self.add_timeline_listener_inner(listener)?)) + ) -> anyhow::Result { + let (items, stoppable_spawn) = self.add_timeline_listener_inner(listener)?; + + Ok(SlidingSyncSubscribeResult 
{ items, task_handle: Arc::new(stoppable_spawn) }) } pub fn subscribe_and_add_timeline_listener( &self, listener: Box, settings: Option, - ) -> Option> { - let mut spawner = self.add_timeline_listener_inner(listener)?; + ) -> anyhow::Result { + let (items, mut stoppable_spawn) = self.add_timeline_listener_inner(listener)?; let room_id = self.inner.room_id().clone(); + self.runner.subscribe(room_id.clone(), settings.map(Into::into)); + let runner = self.runner.clone(); - spawner.set_callback(Box::new(move || runner.unsubscribe(room_id))); - Some(Arc::new(spawner)) + stoppable_spawn.set_finalizer(Box::new(move || runner.unsubscribe(room_id))); + + Ok(SlidingSyncSubscribeResult { items, task_handle: Arc::new(stoppable_spawn) }) } fn add_timeline_listener_inner( &self, listener: Box, - ) -> Option { + ) -> anyhow::Result<(Vec>, TaskHandle)> { let mut timeline_lock = self.timeline.write().unwrap(); let timeline = match &*timeline_lock { Some(timeline) => timeline, None => { - let Some(timeline) = RUNTIME.block_on(self.inner.timeline()) else { - warn!( - room_id = ?self.room_id(), - "Could not set timeline listener: no timeline found." 
- ); - return None; - }; + let timeline = RUNTIME + .block_on(self.inner.timeline()) + .context("Could not set timeline listener: room not found.")?; timeline_lock.insert(Arc::new(timeline)) } }; - let timeline_signal = timeline.signal(); - - let listener: Arc = listener.into(); - let handle_events = timeline_signal.for_each(move |diff| { - let listener = listener.clone(); - let fut = RUNTIME - .spawn_blocking(move || listener.on_update(Arc::new(TimelineDiff::new(diff)))); - async move { - if let Err(e) = fut.await { - error!("Timeline listener error: {e}"); - } - } - }); + #[allow(unknown_lints, clippy::redundant_async_block)] // false positive + let (timeline_items, timeline_stream) = RUNTIME.block_on(timeline.subscribe()); + + let handle_events = async move { + let listener: Arc = listener.into(); + timeline_stream + .for_each(move |diff| { + let listener = listener.clone(); + let fut = RUNTIME.spawn_blocking(move || { + listener.on_update(Arc::new(TimelineDiff::new(diff))) + }); + + async move { + if let Err(e) = fut.await { + error!("Timeline listener error: {e}"); + } + } + }) + .await; + }; let mut reset_broadcast_rx = self.client.sliding_sync_reset_broadcast_tx.subscribe(); - let timeline = timeline.clone(); + let timeline = timeline.to_owned(); let handle_sliding_sync_reset = async move { loop { match reset_broadcast_rx.recv().await { @@ -225,19 +225,26 @@ impl SlidingSyncRoom { } }; - Some(StoppableSpawn::with_handle(RUNTIME.spawn(async move { + let items = timeline_items.into_iter().map(TimelineItem::from_arc).collect(); + let task_handle = TaskHandle::new(RUNTIME.spawn(async move { join(handle_events, handle_sliding_sync_reset).await; - }))) + })); + + Ok((items, task_handle)) } } +pub struct SlidingSyncSubscribeResult { + pub items: Vec>, + pub task_handle: Arc, +} + pub struct UpdateSummary { - /// The views (according to their name), which have seen an update - pub views: Vec, + /// The lists (according to their name), which have seen an update + pub 
lists: Vec, pub rooms: Vec, } -#[derive(uniffi::Record)] pub struct RequiredState { pub key: String, pub value: String, @@ -260,56 +267,54 @@ impl From for RumaRoomSubscription { } impl From for UpdateSummary { - fn from(other: matrix_sdk::UpdateSummary) -> UpdateSummary { - UpdateSummary { - views: other.views, + fn from(other: matrix_sdk::UpdateSummary) -> Self { + Self { + lists: other.lists, rooms: other.rooms.into_iter().map(|r| r.as_str().to_owned()).collect(), } } } -pub enum SlidingSyncViewRoomsListDiff { - Replace { values: Vec }, - InsertAt { index: u32, value: RoomListEntry }, - UpdateAt { index: u32, value: RoomListEntry }, - RemoveAt { index: u32 }, - Move { old_index: u32, new_index: u32 }, - Push { value: RoomListEntry }, - Pop, // removes the last item - Clear, // clears the list +pub enum SlidingSyncListRoomsListDiff { + Append { values: Vec }, + Insert { index: u32, value: RoomListEntry }, + Set { index: u32, value: RoomListEntry }, + Remove { index: u32 }, + PushBack { value: RoomListEntry }, + PushFront { value: RoomListEntry }, + PopBack, + PopFront, + Clear, + Reset { values: Vec }, } -impl From> for SlidingSyncViewRoomsListDiff { - fn from(other: VecDiff) -> Self { +impl From> for SlidingSyncListRoomsListDiff { + fn from(other: VectorDiff) -> Self { match other { - VecDiff::Replace { values } => SlidingSyncViewRoomsListDiff::Replace { - values: values.into_iter().map(|e| (&e).into()).collect(), - }, - VecDiff::InsertAt { index, value } => SlidingSyncViewRoomsListDiff::InsertAt { - index: index as u32, - value: (&value).into(), - }, - VecDiff::UpdateAt { index, value } => SlidingSyncViewRoomsListDiff::UpdateAt { - index: index as u32, - value: (&value).into(), - }, - VecDiff::RemoveAt { index } => { - SlidingSyncViewRoomsListDiff::RemoveAt { index: index as u32 } + VectorDiff::Append { values } => { + Self::Append { values: values.into_iter().map(|e| (&e).into()).collect() } + } + VectorDiff::Insert { index, value } => { + Self::Insert { 
index: index as u32, value: (&value).into() } + } + VectorDiff::Set { index, value } => { + Self::Set { index: index as u32, value: (&value).into() } } - VecDiff::Move { old_index, new_index } => SlidingSyncViewRoomsListDiff::Move { - old_index: old_index as u32, - new_index: new_index as u32, - }, - VecDiff::Push { value } => { - SlidingSyncViewRoomsListDiff::Push { value: (&value).into() } + VectorDiff::Remove { index } => Self::Remove { index: index as u32 }, + VectorDiff::PushBack { value } => Self::PushBack { value: (&value).into() }, + VectorDiff::PushFront { value } => Self::PushFront { value: (&value).into() }, + VectorDiff::PopBack => Self::PopBack, + VectorDiff::PopFront => Self::PopFront, + VectorDiff::Clear => Self::Clear, + VectorDiff::Reset { values } => { + warn!("Room list subscriber lagged behind and was reset"); + Self::Reset { values: values.into_iter().map(|e| (&e).into()).collect() } } - VecDiff::Pop {} => SlidingSyncViewRoomsListDiff::Pop, - VecDiff::Clear {} => SlidingSyncViewRoomsListDiff::Clear, } } } -#[derive(Clone, Debug, uniffi::Enum)] +#[derive(Clone, Debug)] pub enum RoomListEntry { Empty, Invalidated { room_id: String }, @@ -319,33 +324,32 @@ pub enum RoomListEntry { impl From<&MatrixRoomEntry> for RoomListEntry { fn from(other: &MatrixRoomEntry) -> Self { match other { - MatrixRoomEntry::Empty => RoomListEntry::Empty, - MatrixRoomEntry::Filled(b) => RoomListEntry::Filled { room_id: b.to_string() }, - MatrixRoomEntry::Invalidated(b) => { - RoomListEntry::Invalidated { room_id: b.to_string() } - } + MatrixRoomEntry::Empty => Self::Empty, + MatrixRoomEntry::Filled(b) => Self::Filled { room_id: b.to_string() }, + MatrixRoomEntry::Invalidated(b) => Self::Invalidated { room_id: b.to_string() }, } } } -pub trait SlidingSyncViewRoomItemsObserver: Sync + Send { +pub trait SlidingSyncListRoomItemsObserver: Sync + Send { fn did_receive_update(&self); } -pub trait SlidingSyncViewRoomListObserver: Sync + Send { - fn did_receive_update(&self, 
diff: SlidingSyncViewRoomsListDiff); +pub trait SlidingSyncListRoomListObserver: Sync + Send { + fn did_receive_update(&self, diff: SlidingSyncListRoomsListDiff); } -pub trait SlidingSyncViewRoomsCountObserver: Sync + Send { +pub trait SlidingSyncListRoomsCountObserver: Sync + Send { fn did_receive_update(&self, new_count: u32); } -pub trait SlidingSyncViewStateObserver: Sync + Send { +pub trait SlidingSyncListStateObserver: Sync + Send { fn did_receive_update(&self, new_state: SlidingSyncState); } -#[derive(Clone, Default)] -pub struct SlidingSyncViewBuilder { - inner: matrix_sdk::SlidingSyncViewBuilder, + +#[derive(Clone)] +pub struct SlidingSyncListBuilder { + inner: matrix_sdk::SlidingSyncListBuilder, } #[derive(uniffi::Record)] @@ -377,15 +381,16 @@ impl From for SyncRequestListFilters { tags, not_tags, } = value; + assign!(SyncRequestListFilters::default(), { is_dm, spaces, is_encrypted, is_invite, is_tombstoned, room_types, not_room_types, room_name_like, tags, not_tags, }) } } -impl SlidingSyncViewBuilder { +impl SlidingSyncListBuilder { pub fn new() -> Self { - Default::default() + Self { inner: matrix_sdk::SlidingSyncList::builder() } } pub fn sync_mode(self: Arc, mode: SlidingSyncMode) -> Arc { @@ -406,14 +411,14 @@ impl SlidingSyncViewBuilder { Arc::new(builder) } - pub fn build(self: Arc) -> anyhow::Result> { + pub fn build(self: Arc) -> anyhow::Result> { let builder = unwrap_or_clone_arc(self); Ok(Arc::new(builder.inner.build()?.into())) } } #[uniffi::export] -impl SlidingSyncViewBuilder { +impl SlidingSyncListBuilder { pub fn sort(self: Arc, sort: Vec) -> Arc { let mut builder = unwrap_or_clone_arc(self); builder.inner = builder.inner.sort(sort); @@ -490,25 +495,26 @@ impl SlidingSyncViewBuilder { } #[derive(Clone)] -pub struct SlidingSyncView { - inner: matrix_sdk::SlidingSyncView, +pub struct SlidingSyncList { + inner: matrix_sdk::SlidingSyncList, } -impl From for SlidingSyncView { - fn from(inner: matrix_sdk::SlidingSyncView) -> Self { - 
SlidingSyncView { inner } +impl From for SlidingSyncList { + fn from(inner: matrix_sdk::SlidingSyncList) -> Self { + SlidingSyncList { inner } } } -impl SlidingSyncView { +impl SlidingSyncList { pub fn observe_state( &self, - observer: Box, - ) -> Arc { - let mut signal = self.inner.state.signal_cloned().to_stream(); - Arc::new(StoppableSpawn::with_handle(RUNTIME.spawn(async move { + observer: Box, + ) -> Arc { + let mut state_stream = self.inner.state_stream(); + + Arc::new(TaskHandle::new(RUNTIME.spawn(async move { loop { - if let Some(new_state) = signal.next().await { + if let Some(new_state) = state_stream.next().await { observer.did_receive_update(new_state); } } @@ -517,12 +523,13 @@ impl SlidingSyncView { pub fn observe_room_list( &self, - observer: Box, - ) -> Arc { - let mut room_list = self.inner.rooms_list.signal_vec_cloned().to_stream(); - Arc::new(StoppableSpawn::with_handle(RUNTIME.spawn(async move { + observer: Box, + ) -> Arc { + let mut rooms_list_stream = self.inner.rooms_list_stream(); + + Arc::new(TaskHandle::new(RUNTIME.spawn(async move { loop { - if let Some(diff) = room_list.next().await { + if let Some(diff) = rooms_list_stream.next().await { observer.did_receive_update(diff.into()); } } @@ -531,10 +538,11 @@ impl SlidingSyncView { pub fn observe_room_items( &self, - observer: Box, - ) -> Arc { - let mut rooms_updated = self.inner.rooms_updated_broadcaster.signal_cloned().to_stream(); - Arc::new(StoppableSpawn::with_handle(RUNTIME.spawn(async move { + observer: Box, + ) -> Arc { + let mut rooms_updated = + Observable::subscribe(&self.inner.rooms_updated_broadcast.read().unwrap()); + Arc::new(TaskHandle::new(RUNTIME.spawn(async move { loop { if rooms_updated.next().await.is_some() { observer.did_receive_update(); @@ -545,12 +553,13 @@ impl SlidingSyncView { pub fn observe_rooms_count( &self, - observer: Box, - ) -> Arc { - let mut rooms_count = self.inner.rooms_count.signal_cloned().to_stream(); - 
Arc::new(StoppableSpawn::with_handle(RUNTIME.spawn(async move { + observer: Box, + ) -> Arc { + let mut rooms_count_stream = self.inner.rooms_count_stream(); + + Arc::new(TaskHandle::new(RUNTIME.spawn(async move { loop { - if let Some(Some(new)) = rooms_count.next().await { + if let Some(Some(new)) = rooms_count_stream.next().await { observer.did_receive_update(new); } } @@ -559,10 +568,10 @@ impl SlidingSyncView { } #[uniffi::export] -impl SlidingSyncView { +impl SlidingSyncList { /// Get the current list of rooms pub fn current_rooms_list(&self) -> Vec { - self.inner.rooms_list.lock_ref().as_slice().iter().map(|e| e.into()).collect() + self.inner.rooms_list() } /// Reset the ranges to a particular set @@ -588,22 +597,24 @@ impl SlidingSyncView { /// Total of rooms matching the filter pub fn current_room_count(&self) -> Option { - self.inner.rooms_count.get_cloned() + self.inner.rooms_count() } /// The current timeline limit pub fn get_timeline_limit(&self) -> Option { - self.inner.timeline_limit.get_cloned().map(|limit| u32::try_from(limit).unwrap_or_default()) + (**self.inner.timeline_limit.read().unwrap()) + .map(|limit| u32::try_from(limit).unwrap_or_default()) } /// The current timeline limit pub fn set_timeline_limit(&self, value: u32) { - self.inner.timeline_limit.set(Some(UInt::try_from(value).unwrap())) + let value = Some(UInt::try_from(value).unwrap()); + Observable::set(&mut self.inner.timeline_limit.write().unwrap(), value); } /// Unset the current timeline limit pub fn unset_timeline_limit(&self) { - self.inner.timeline_limit.set(None) + Observable::set(&mut self.inner.timeline_limit.write().unwrap(), None); } } @@ -619,7 +630,7 @@ pub struct SlidingSync { impl SlidingSync { fn new(inner: matrix_sdk::SlidingSync, client: Client) -> Self { - SlidingSync { inner, client, observer: Default::default() } + Self { inner, client, observer: Default::default() } } pub fn set_observer(&self, observer: Option>) { @@ -682,63 +693,55 @@ impl SlidingSync { 
#[uniffi::export] impl SlidingSync { #[allow(clippy::significant_drop_in_scrutinee)] - pub fn get_view(&self, name: String) -> Option> { - self.inner.view(&name).map(|inner| Arc::new(SlidingSyncView { inner })) + pub fn get_list(&self, name: String) -> Option> { + self.inner.list(&name).map(|inner| Arc::new(SlidingSyncList { inner })) } - pub fn add_view(&self, view: Arc) -> Option> { - self.inner.add_view(view.inner.clone()).map(|inner| Arc::new(SlidingSyncView { inner })) + pub fn add_list(&self, list: Arc) -> Option> { + self.inner.add_list(list.inner.clone()).map(|inner| Arc::new(SlidingSyncList { inner })) } - pub fn pop_view(&self, name: String) -> Option> { - self.inner.pop_view(&name).map(|inner| Arc::new(SlidingSyncView { inner })) + pub fn pop_list(&self, name: String) -> Option> { + self.inner.pop_list(&name).map(|inner| Arc::new(SlidingSyncList { inner })) } pub fn add_common_extensions(&self) { self.inner.add_common_extensions(); } - pub fn sync(&self) -> Arc { + pub fn sync(&self) -> Arc { let inner = self.inner.clone(); let client = self.client.clone(); let observer = self.observer.clone(); - let stop_loop = Arc::new(AtomicBool::new(false)); - let remote_stopper = stop_loop.clone(); - - let stoppable = Arc::new(StoppableSpawn::with_callback(Box::new(move || { - remote_stopper.store(true, Ordering::Relaxed); - }))); - RUNTIME.spawn(async move { + Arc::new(TaskHandle::new(RUNTIME.spawn(async move { let stream = inner.stream(); pin_mut!(stream); + loop { - let update = match stream.next().await { - Some(Ok(u)) => u, - Some(Err(e)) => { - if client.process_sync_error(e) == LoopCtrl::Break { + let update_summary = match stream.next().await { + Some(Ok(update_summary)) => update_summary, + + Some(Err(error)) => { + if client.process_sync_error(error) == LoopCtrl::Break { warn!("loop was stopped by client error processing"); break; } else { continue; } } + None => { warn!("Inner streaming loop ended unexpectedly"); break; } }; + if let Some(ref observer) = 
*observer.read().unwrap() { - observer.did_receive_sync_update(update.into()); - } - if stop_loop.load(Ordering::Relaxed) { - trace!("stopped sync loop after cancellation"); - break; + observer.did_receive_sync_update(update_summary.into()); } } - }); - - stoppable + }))) } } @@ -765,15 +768,15 @@ impl SlidingSyncBuilder { #[uniffi::export] impl SlidingSyncBuilder { - pub fn add_fullsync_view(self: Arc) -> Arc { + pub fn add_fullsync_list(self: Arc) -> Arc { let mut builder = unwrap_or_clone_arc(self); - builder.inner = builder.inner.add_fullsync_view(); + builder.inner = builder.inner.add_fullsync_list(); Arc::new(builder) } - pub fn no_views(self: Arc) -> Arc { + pub fn no_lists(self: Arc) -> Arc { let mut builder = unwrap_or_clone_arc(self); - builder.inner = builder.inner.no_views(); + builder.inner = builder.inner.no_lists(); Arc::new(builder) } @@ -783,10 +786,10 @@ impl SlidingSyncBuilder { Arc::new(builder) } - pub fn add_view(self: Arc, v: Arc) -> Arc { + pub fn add_list(self: Arc, v: Arc) -> Arc { let mut builder = unwrap_or_clone_arc(self); - let view = unwrap_or_clone_arc(v); - builder.inner = builder.inner.add_view(view.inner); + let list = unwrap_or_clone_arc(v); + builder.inner = builder.inner.add_list(list.inner); Arc::new(builder) } @@ -833,11 +836,12 @@ impl SlidingSyncBuilder { } } +#[uniffi::export] impl Client { - pub fn full_sliding_sync(&self) -> anyhow::Result> { + pub fn full_sliding_sync(&self) -> Result, ClientError> { RUNTIME.block_on(async move { let builder = self.client.sliding_sync().await; - let inner = builder.add_fullsync_view().build().await?; + let inner = builder.add_fullsync_list().build().await?; Ok(Arc::new(SlidingSync::new(inner, self.clone()))) }) } @@ -847,7 +851,16 @@ impl Client { impl Client { pub fn sliding_sync(&self) -> Arc { RUNTIME.block_on(async move { - let inner = self.client.sliding_sync().await; + let mut inner = self.client.sliding_sync().await; + if let Some(sliding_sync_proxy) = self + .sliding_sync_proxy 
+ .read() + .unwrap() + .clone() + .and_then(|p| Url::parse(p.as_str()).ok()) + { + inner = inner.homeserver(sliding_sync_proxy); + } Arc::new(SlidingSyncBuilder { inner, client: self.clone() }) }) } diff --git a/bindings/matrix-sdk-ffi/src/timeline.rs b/bindings/matrix-sdk-ffi/src/timeline.rs index 949885cb2af..12f444f3836 100644 --- a/bindings/matrix-sdk-ffi/src/timeline.rs +++ b/bindings/matrix-sdk-ffi/src/timeline.rs @@ -1,9 +1,12 @@ use std::sync::Arc; use extension_trait::extension_trait; -use futures_signals::signal_vec::VecDiff; +use eyeball_im::VectorDiff; use matrix_sdk::room::timeline::{Profile, TimelineDetails}; pub use matrix_sdk::ruma::events::room::{message::RoomMessageEventContent, MediaSource}; +use tracing::warn; + +use crate::helpers::unwrap_or_clone_arc; #[uniffi::export] pub fn media_source_from_url(url: String) -> Arc { @@ -19,98 +22,135 @@ pub trait TimelineListener: Sync + Send { fn on_update(&self, diff: Arc); } -#[repr(transparent)] #[derive(Clone)] -pub struct TimelineDiff(VecDiff>); +pub enum TimelineDiff { + Append { values: Vec> }, + Clear, + PushFront { value: Arc }, + PushBack { value: Arc }, + PopFront, + PopBack, + Insert { index: usize, value: Arc }, + Set { index: usize, value: Arc }, + Remove { index: usize }, + Reset { values: Vec> }, +} impl TimelineDiff { - pub(crate) fn new(inner: VecDiff>) -> Self { - TimelineDiff(match inner { - // Note: It's _probably_ valid to only transmute here too but not - // as clear, and less important because this only happens - // once when constructing the timeline. 
- VecDiff::Replace { values } => VecDiff::Replace { - values: values.into_iter().map(TimelineItem::from_arc).collect(), - }, - VecDiff::InsertAt { index, value } => { - VecDiff::InsertAt { index, value: TimelineItem::from_arc(value) } + pub(crate) fn new(inner: VectorDiff>) -> Self { + match inner { + VectorDiff::Append { values } => { + Self::Append { values: values.into_iter().map(TimelineItem::from_arc).collect() } } - VecDiff::UpdateAt { index, value } => { - VecDiff::UpdateAt { index, value: TimelineItem::from_arc(value) } + VectorDiff::Clear => Self::Clear, + VectorDiff::Insert { index, value } => { + Self::Insert { index, value: TimelineItem::from_arc(value) } } - VecDiff::RemoveAt { index } => VecDiff::RemoveAt { index }, - VecDiff::Move { old_index, new_index } => VecDiff::Move { old_index, new_index }, - VecDiff::Push { value } => VecDiff::Push { value: TimelineItem::from_arc(value) }, - VecDiff::Pop {} => VecDiff::Pop {}, - VecDiff::Clear {} => VecDiff::Clear {}, - }) + VectorDiff::Set { index, value } => { + Self::Set { index, value: TimelineItem::from_arc(value) } + } + VectorDiff::Remove { index } => Self::Remove { index }, + VectorDiff::PushBack { value } => { + Self::PushBack { value: TimelineItem::from_arc(value) } + } + VectorDiff::PushFront { value } => { + Self::PushFront { value: TimelineItem::from_arc(value) } + } + VectorDiff::PopBack => Self::PopBack, + VectorDiff::PopFront => Self::PopFront, + VectorDiff::Reset { values } => { + warn!("Timeline subscriber lagged behind and was reset"); + Self::Reset { values: values.into_iter().map(TimelineItem::from_arc).collect() } + } + } } } #[uniffi::export] impl TimelineDiff { pub fn change(&self) -> TimelineChange { - match &self.0 { - VecDiff::Replace { .. } => TimelineChange::Replace, - VecDiff::InsertAt { .. } => TimelineChange::InsertAt, - VecDiff::UpdateAt { .. } => TimelineChange::UpdateAt, - VecDiff::RemoveAt { .. } => TimelineChange::RemoveAt, - VecDiff::Move { .. 
} => TimelineChange::Move, - VecDiff::Push { .. } => TimelineChange::Push, - VecDiff::Pop {} => TimelineChange::Pop, - VecDiff::Clear {} => TimelineChange::Clear, + match self { + Self::Append { .. } => TimelineChange::Append, + Self::Insert { .. } => TimelineChange::Insert, + Self::Set { .. } => TimelineChange::Set, + Self::Remove { .. } => TimelineChange::Remove, + Self::PushBack { .. } => TimelineChange::PushBack, + Self::PushFront { .. } => TimelineChange::PushFront, + Self::PopBack => TimelineChange::PopBack, + Self::PopFront => TimelineChange::PopFront, + Self::Clear => TimelineChange::Clear, + Self::Reset { .. } => TimelineChange::Reset, } } - pub fn replace(self: Arc) -> Option>> { - unwrap_or_clone_arc_into_variant!(self, .0, VecDiff::Replace { values } => values) + pub fn append(self: Arc) -> Option>> { + let this = unwrap_or_clone_arc(self); + match this { + Self::Append { values } => Some(values), + _ => None, + } } - pub fn insert_at(self: Arc) -> Option { - unwrap_or_clone_arc_into_variant!(self, .0, VecDiff::InsertAt { index, value } => { - InsertAtData { index: index.try_into().unwrap(), item: value } - }) + pub fn insert(self: Arc) -> Option { + let this = unwrap_or_clone_arc(self); + match this { + Self::Insert { index, value } => { + Some(InsertData { index: index.try_into().unwrap(), item: value }) + } + _ => None, + } } - pub fn update_at(self: Arc) -> Option { - unwrap_or_clone_arc_into_variant!(self, .0, VecDiff::UpdateAt { index, value } => { - UpdateAtData { index: index.try_into().unwrap(), item: value } - }) + pub fn set(self: Arc) -> Option { + let this = unwrap_or_clone_arc(self); + match this { + Self::Set { index, value } => { + Some(SetData { index: index.try_into().unwrap(), item: value }) + } + _ => None, + } } - pub fn remove_at(&self) -> Option { - match &self.0 { - VecDiff::RemoveAt { index } => Some((*index).try_into().unwrap()), + pub fn remove(&self) -> Option { + match self { + Self::Remove { index } => 
Some((*index).try_into().unwrap()), _ => None, } } - pub fn push(self: Arc) -> Option> { - unwrap_or_clone_arc_into_variant!(self, .0, VecDiff::Push { value } => value) + pub fn push_back(self: Arc) -> Option> { + let this = unwrap_or_clone_arc(self); + match this { + Self::PushBack { value } => Some(value), + _ => None, + } } -} -// UniFFI currently chokes on the r# -impl TimelineDiff { - pub fn r#move(&self) -> Option { - match &self.0 { - VecDiff::Move { old_index, new_index } => Some(MoveData { - old_index: (*old_index).try_into().unwrap(), - new_index: (*new_index).try_into().unwrap(), - }), + pub fn push_front(self: Arc) -> Option> { + let this = unwrap_or_clone_arc(self); + match this { + Self::PushFront { value } => Some(value), + _ => None, + } + } + + pub fn reset(self: Arc) -> Option>> { + let this = unwrap_or_clone_arc(self); + match this { + Self::Reset { values } => Some(values), _ => None, } } } #[derive(uniffi::Record)] -pub struct InsertAtData { +pub struct InsertData { pub index: u32, pub item: Arc, } #[derive(uniffi::Record)] -pub struct UpdateAtData { +pub struct SetData { pub index: u32, pub item: Arc, } @@ -122,22 +162,24 @@ pub struct MoveData { #[derive(Clone, Copy, uniffi::Enum)] pub enum TimelineChange { - Replace, - InsertAt, - UpdateAt, - RemoveAt, - Move, - Push, - Pop, + Append, Clear, + Insert, + Set, + Remove, + PushBack, + PushFront, + PopBack, + PopFront, + Reset, } #[repr(transparent)] -#[derive(Clone, uniffi::Object)] -pub struct TimelineItem(matrix_sdk::room::timeline::TimelineItem); +#[derive(Clone)] +pub struct TimelineItem(pub(crate) matrix_sdk::room::timeline::TimelineItem); impl TimelineItem { - fn from_arc(arc: Arc) -> Arc { + pub(crate) fn from_arc(arc: Arc) -> Arc { // SAFETY: This is valid because Self is a repr(transparent) wrapper // around the other Timeline type. 
unsafe { Arc::from_raw(Arc::into_raw(arc) as _) } @@ -270,7 +312,7 @@ impl EventTimelineItem { use matrix_sdk::room::timeline::EventTimelineItem::*; match &self.0 { - Local(local_event) => Some((&local_event.send_state).into()), + Local(local_event) => Some(local_event.send_state().into()), Remote(_) => None, } } @@ -434,6 +476,13 @@ impl Message { info: c.info.as_deref().map(Into::into), }, }), + MTy::Audio(c) => Some(MessageType::Audio { + content: AudioMessageContent { + body: c.body.clone(), + source: Arc::new(c.source.clone()), + info: c.info.as_deref().map(Into::into), + }, + }), MTy::Video(c) => Some(MessageType::Video { content: VideoMessageContent { body: c.body.clone(), @@ -482,6 +531,7 @@ impl Message { pub enum MessageType { Emote { content: EmoteMessageContent }, Image { content: ImageMessageContent }, + Audio { content: AudioMessageContent }, Video { content: VideoMessageContent }, File { content: FileMessageContent }, Notice { content: NoticeMessageContent }, @@ -501,6 +551,13 @@ pub struct ImageMessageContent { pub info: Option, } +#[derive(Clone, uniffi::Record)] +pub struct AudioMessageContent { + pub body: String, + pub source: Arc, + pub info: Option, +} + #[derive(Clone, uniffi::Record)] pub struct VideoMessageContent { pub body: String, @@ -526,6 +583,15 @@ pub struct ImageInfo { pub blurhash: Option, } +#[derive(Clone, uniffi::Record)] +pub struct AudioInfo { + // FIXME: duration should be a std::time::Duration once the UniFFI proc-macro API adds support + // for that + pub duration: Option, + pub size: Option, + pub mimetype: Option, +} + #[derive(Clone, uniffi::Record)] pub struct VideoInfo { pub duration: Option, @@ -611,6 +677,16 @@ impl From<&matrix_sdk::ruma::events::room::ImageInfo> for ImageInfo { } } +impl From<&matrix_sdk::ruma::events::room::message::AudioInfo> for AudioInfo { + fn from(info: &matrix_sdk::ruma::events::room::message::AudioInfo) -> Self { + Self { + duration: info.duration.map(|d| d.as_millis() as u64), + size: 
info.size.map(Into::into), + mimetype: info.mimetype.clone(), + } + } +} + impl From<&matrix_sdk::ruma::events::room::message::VideoInfo> for VideoInfo { fn from(info: &matrix_sdk::ruma::events::room::message::VideoInfo) -> Self { let thumbnail_info = info.thumbnail_info.as_ref().map(|info| ThumbnailInfo { diff --git a/crates/matrix-sdk-appservice/Cargo.toml b/crates/matrix-sdk-appservice/Cargo.toml index ccb41b948a7..42538b351b2 100644 --- a/crates/matrix-sdk-appservice/Cargo.toml +++ b/crates/matrix-sdk-appservice/Cargo.toml @@ -42,7 +42,7 @@ serde = { workspace = true } serde_html_form = { workspace = true } serde_json = { workspace = true } serde_yaml = "0.9.4" -tokio = { version = "1.23.1", default-features = false, features = ["rt-multi-thread"] } +tokio = { version = "1.24.2", default-features = false, features = ["rt-multi-thread"] } thiserror = { workspace = true } tower = { version = "0.4.13", default-features = false } tracing = { workspace = true } @@ -50,6 +50,6 @@ url = "2.2.2" [dev-dependencies] matrix-sdk-test = { version = "0.6.0", path = "../../testing/matrix-sdk-test", features = ["appservice"] } -tokio = { version = "1.23.1", default-features = false, features = ["rt-multi-thread", "macros"] } +tokio = { version = "1.24.2", default-features = false, features = ["rt-multi-thread", "macros"] } tracing-subscriber = "0.3.11" wiremock = "0.5.13" diff --git a/crates/matrix-sdk-base/Cargo.toml b/crates/matrix-sdk-base/Cargo.toml index 452bc27bfd9..ec210d3042d 100644 --- a/crates/matrix-sdk-base/Cargo.toml +++ b/crates/matrix-sdk-base/Cargo.toml @@ -20,23 +20,25 @@ default = [] e2e-encryption = ["dep:matrix-sdk-crypto"] js = ["matrix-sdk-common/js", "matrix-sdk-crypto?/js", "ruma/js", "matrix-sdk-store-encryption/js"] qrcode = ["matrix-sdk-crypto?/qrcode"] +automatic-room-key-forwarding = ["matrix-sdk-crypto?/automatic-room-key-forwarding"] experimental-sliding-sync = ["ruma/unstable-msc3575"] # helpers for testing features build upon this -testing = 
["dep:http"] +testing = ["dep:http", "dep:matrix-sdk-test", "dep:assert_matches"] [dependencies] +assert_matches = { version = "1.5.0", optional = true } async-stream = { workspace = true } async-trait = { workspace = true } dashmap = { workspace = true } -futures-channel = "0.3.21" +eyeball = { workspace = true } futures-core = "0.3.21" -futures-signals = { version = "0.3.30", default-features = false } -futures-util = { version = "0.3.21", default-features = false } +futures-util = { workspace = true } http = { workspace = true, optional = true } matrix-sdk-common = { version = "0.6.0", path = "../matrix-sdk-common" } -matrix-sdk-crypto = { version = "0.6.0", path = "../matrix-sdk-crypto", optional = true } +matrix-sdk-crypto = { version = "0.6.0", path = "../matrix-sdk-crypto", optional = true, default-features = false } matrix-sdk-store-encryption = { version = "0.2.0", path = "../matrix-sdk-store-encryption" } +matrix-sdk-test = { version = "0.6.0", path = "../../testing/matrix-sdk-test", optional = true } once_cell = { workspace = true } ruma = { workspace = true, features = ["canonical-json"] } serde = { workspace = true, features = ["rc"] } @@ -46,6 +48,7 @@ tracing = { workspace = true } zeroize = { workspace = true, features = ["zeroize_derive"] } [dev-dependencies] +assert_matches = "1.5.0" assign = "1.1.1" ctor = { workspace = true } futures = { version = "0.3.21", default-features = false, features = ["executor"] } @@ -54,7 +57,7 @@ matrix-sdk-test = { version = "0.6.0", path = "../../testing/matrix-sdk-test" } tracing-subscriber = { version = "0.3.11", features = ["env-filter"] } [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] -tokio = { version = "1.23.1", default-features = false, features = ["rt-multi-thread", "macros"] } +tokio = { version = "1.24.2", default-features = false, features = ["rt-multi-thread", "macros"] } [target.'cfg(target_arch = "wasm32")'.dev-dependencies] wasm-bindgen-test = "0.3.33" diff --git 
a/crates/matrix-sdk-base/Changelog.md b/crates/matrix-sdk-base/Changelog.md index f8e6777b982..138d3f7d7d8 100644 --- a/crates/matrix-sdk-base/Changelog.md +++ b/crates/matrix-sdk-base/Changelog.md @@ -1,6 +1,9 @@ # Changelog -All notable changes to this crate will be documented in this file. +## unreleased + +- Rename `RoomType` to `RoomState` +- Add `RoomInfo::state` accessor ## 0.5.1 diff --git a/crates/matrix-sdk-base/src/client.rs b/crates/matrix-sdk-base/src/client.rs index c635c097f4b..8be3af22433 100644 --- a/crates/matrix-sdk-base/src/client.rs +++ b/crates/matrix-sdk-base/src/client.rs @@ -21,11 +21,11 @@ use std::{ #[cfg(feature = "e2e-encryption")] use std::{ops::Deref, sync::Arc}; -use futures_signals::signal::ReadOnlyMutable; +use eyeball::Subscriber; use matrix_sdk_common::{instant::Instant, locks::RwLock}; #[cfg(feature = "e2e-encryption")] use matrix_sdk_crypto::{ - store::CryptoStore, EncryptionSettings, OlmError, OlmMachine, ToDeviceRequest, + store::DynCryptoStore, EncryptionSettings, OlmError, OlmMachine, ToDeviceRequest, }; #[cfg(feature = "e2e-encryption")] use once_cell::sync::OnceCell; @@ -60,13 +60,13 @@ use crate::error::Error; use crate::{ deserialized_responses::{AmbiguityChanges, MembersResponse, SyncTimelineEvent}, error::Result, - rooms::{Room, RoomInfo, RoomType}, + rooms::{Room, RoomInfo, RoomState}, store::{ - ambiguity_map::AmbiguityCache, Result as StoreResult, StateChanges, StateStoreExt, Store, - StoreConfig, + ambiguity_map::AmbiguityCache, DynStateStore, Result as StoreResult, StateChanges, + StateStoreDataKey, StateStoreDataValue, StateStoreExt, Store, StoreConfig, }, sync::{JoinedRoom, LeftRoom, Rooms, SyncResponse, Timeline}, - Session, SessionMeta, SessionTokens, StateStore, + Session, SessionMeta, SessionTokens, }; /// A no IO Client implementation. @@ -82,7 +82,7 @@ pub struct BaseClient { /// This field is only meant to be used for `OlmMachine` initialization. /// All operations on it happen inside the `OlmMachine`. 
#[cfg(feature = "e2e-encryption")] - crypto_store: Arc, + crypto_store: Arc, /// The olm-machine that is created once the /// [`SessionMeta`][crate::session::SessionMeta] is set via /// [`BaseClient::set_session_meta`] @@ -133,10 +133,14 @@ impl BaseClient { /// Get the session tokens. /// - /// If the client is currently logged in, this will return a + /// This returns a subscriber object that you can use both to + /// [`get`](Subscriber::get) the current value as well as to react to + /// changes to the tokens. + /// + /// If the client is currently logged in, the inner value is a /// [`SessionTokens`] object which contains the access token and optional - /// refresh token. Otherwise it returns `None`. - pub fn session_tokens(&self) -> ReadOnlyMutable> { + /// refresh token. Otherwise it is `None`. + pub fn session_tokens(&self) -> Subscriber> { self.store.session_tokens() } @@ -163,8 +167,8 @@ impl BaseClient { /// Lookup the Room for the given RoomId, or create one, if it didn't exist /// yet in the store - pub async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room { - self.store.get_or_create_room(room_id, room_type).await + pub async fn get_or_create_room(&self, room_id: &RoomId, room_state: RoomState) -> Room { + self.store.get_or_create_room(room_id, room_state).await } /// Get all the rooms this client knows about. @@ -174,7 +178,7 @@ impl BaseClient { /// Get a reference to the store. 
#[allow(unknown_lints, clippy::explicit_auto_deref)] - pub fn store(&self) -> &dyn StateStore { + pub fn store(&self) -> &DynStateStore { &*self.store } @@ -407,7 +411,7 @@ impl BaseClient { if let Some(context) = &push_context { let actions = push_rules.get_actions(&event.event, context); - if actions.iter().any(|a| matches!(a, Action::Notify)) { + if actions.iter().any(Action::should_notify) { changes.add_notification( room_id, Notification::new( @@ -419,14 +423,7 @@ impl BaseClient { ), ); } - // TODO if there is an - // Action::SetTweak(Tweak::Highlight) we need to store - // its value with the event so a client can show if the - // event is highlighted - // in the UI. - // Requires the possibility to associate custom data - // with events and to - // store them. + event.push_actions = actions.to_owned(); } } Err(e) => { @@ -619,8 +616,8 @@ impl BaseClient { /// /// Update the internal and cached state accordingly. Return the final Room. pub async fn room_joined(&self, room_id: &RoomId) -> Result { - let room = self.store.get_or_create_room(room_id, RoomType::Joined).await; - if room.room_type() != RoomType::Joined { + let room = self.store.get_or_create_room(room_id, RoomState::Joined).await; + if room.state() != RoomState::Joined { let _sync_lock = self.sync_lock().read().await; let mut room_info = room.clone_info(); @@ -640,8 +637,8 @@ impl BaseClient { /// /// Update the internal and cached state accordingly. Return the final Room. 
pub async fn room_left(&self, room_id: &RoomId) -> Result { - let room = self.store.get_or_create_room(room_id, RoomType::Left).await; - if room.room_type() != RoomType::Left { + let room = self.store.get_or_create_room(room_id, RoomState::Left).await; + if room.state() != RoomState::Left { let _sync_lock = self.sync_lock().read().await; let mut room_info = room.clone_info(); @@ -714,7 +711,7 @@ impl BaseClient { let mut new_rooms = Rooms::default(); for (room_id, new_info) in rooms.join { - let room = self.store.get_or_create_room(&room_id, RoomType::Joined).await; + let room = self.store.get_or_create_room(&room_id, RoomState::Joined).await; let mut room_info = room.clone_info(); room_info.mark_as_joined(); @@ -798,7 +795,7 @@ impl BaseClient { } for (room_id, new_info) in rooms.leave { - let room = self.store.get_or_create_room(&room_id, RoomType::Left).await; + let room = self.store.get_or_create_room(&room_id, RoomState::Left).await; let mut room_info = room.clone_info(); room_info.mark_as_left(); room_info.mark_state_partially_synced(); @@ -1024,7 +1021,13 @@ impl BaseClient { filter_name: &str, response: &api::filter::create_filter::v3::Response, ) -> Result<()> { - Ok(self.store.save_filter(filter_name, &response.filter_id).await?) + Ok(self + .store + .set_kv_data( + StateStoreDataKey::Filter(filter_name), + StateStoreDataValue::Filter(response.filter_id.clone()), + ) + .await?) } /// Get the filter id of a previously uploaded filter. @@ -1039,7 +1042,13 @@ impl BaseClient { /// /// [`receive_filter_upload`]: #method.receive_filter_upload pub async fn get_filter(&self, filter_name: &str) -> StoreResult> { - self.store.get_filter(filter_name).await + let filter = self + .store + .get_kv_data(StateStoreDataKey::Filter(filter_name)) + .await? + .map(|d| d.into_filter().expect("State store data not a filter")); + + Ok(filter) } /// Get a to-device request that will share a room key with users in a room. 
@@ -1238,7 +1247,7 @@ mod tests { use serde_json::json; use super::BaseClient; - use crate::{DisplayName, RoomType, SessionMeta}; + use crate::{DisplayName, RoomState, SessionMeta}; #[async_test] async fn invite_after_leaving() { @@ -1272,7 +1281,7 @@ mod tests { )) .build_sync_response(); client.receive_sync_response(response).await.unwrap(); - assert_eq!(client.get_room(room_id).unwrap().room_type(), RoomType::Left); + assert_eq!(client.get_room(room_id).unwrap().state(), RoomState::Left); let response = ev_builder .add_invited_room(InvitedRoomBuilder::new(room_id).add_state_event( @@ -1290,7 +1299,7 @@ mod tests { )) .build_sync_response(); client.receive_sync_response(response).await.unwrap(); - assert_eq!(client.get_room(room_id).unwrap().room_type(), RoomType::Invited); + assert_eq!(client.get_room(room_id).unwrap().state(), RoomState::Invited); } #[async_test] @@ -1381,7 +1390,7 @@ mod tests { client.receive_sync_response(response).await.unwrap(); let room = client.get_room(room_id).expect("Room not found"); - assert_eq!(room.room_type(), RoomType::Invited); + assert_eq!(room.state(), RoomState::Invited); assert_eq!( room.display_name().await.expect("fetching display name failed"), DisplayName::Calculated("Kyra".to_owned()) diff --git a/crates/matrix-sdk-base/src/deserialized_responses.rs b/crates/matrix-sdk-base/src/deserialized_responses.rs index f96ca21fec3..8d09b617f79 100644 --- a/crates/matrix-sdk-base/src/deserialized_responses.rs +++ b/crates/matrix-sdk-base/src/deserialized_responses.rs @@ -25,7 +25,7 @@ use ruma::{ serde::Raw, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, UserId, }; -use serde::{Deserialize, Serialize}; +use serde::Serialize; /// A change in ambiguity of room members that an `m.room.member` event /// triggers. 
@@ -82,8 +82,7 @@ impl RawMemberEvent { } /// Wrapper around both MemberEvent-Types -#[allow(clippy::large_enum_variant)] -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug)] pub enum MemberEvent { /// A member event from a room in joined or left state. Sync(SyncRoomMemberEvent), diff --git a/crates/matrix-sdk-base/src/lib.rs b/crates/matrix-sdk-base/src/lib.rs index 8480ad95e79..c69e4867f23 100644 --- a/crates/matrix-sdk-base/src/lib.rs +++ b/crates/matrix-sdk-base/src/lib.rs @@ -42,8 +42,8 @@ pub use http; #[cfg(feature = "e2e-encryption")] pub use matrix_sdk_crypto as crypto; pub use once_cell; -pub use rooms::{DisplayName, Room, RoomInfo, RoomMember, RoomType}; -pub use store::{StateChanges, StateStore, StoreError}; +pub use rooms::{DisplayName, Room, RoomInfo, RoomMember, RoomState}; +pub use store::{StateChanges, StateStore, StateStoreDataKey, StateStoreDataValue, StoreError}; pub use utils::{ MinimalRoomMemberEvent, MinimalStateEvent, OriginalMinimalStateEvent, RedactedMinimalStateEvent, }; diff --git a/crates/matrix-sdk-base/src/rooms/members.rs b/crates/matrix-sdk-base/src/rooms/members.rs index e1179997a8e..86718a73177 100644 --- a/crates/matrix-sdk-base/src/rooms/members.rs +++ b/crates/matrix-sdk-base/src/rooms/members.rs @@ -17,7 +17,10 @@ use std::sync::Arc; use ruma::{ events::{ presence::PresenceEvent, - room::{member::MembershipState, power_levels::SyncRoomPowerLevelsEvent}, + room::{ + member::MembershipState, + power_levels::{PowerLevelAction, SyncRoomPowerLevelsEvent}, + }, }, MxcUri, UserId, }; @@ -101,6 +104,15 @@ impl RoomMember { .unwrap_or_else(|| if self.is_room_creator { 100 } else { 0 }) } + /// Whether the given user can do the given action based on the power + /// levels. 
+ pub fn can_do(&self, action: PowerLevelAction) -> bool { + (*self.power_levels) + .as_ref() + .map(|e| e.power_levels().user_can_do(self.user_id(), action)) + .unwrap_or_else(|| self.is_room_creator) + } + /// Is the name that the member uses ambiguous in the room. /// /// A name is considered to be ambiguous if at least one other member shares diff --git a/crates/matrix-sdk-base/src/rooms/mod.rs b/crates/matrix-sdk-base/src/rooms/mod.rs index 7ef982238ef..f414fd20ee4 100644 --- a/crates/matrix-sdk-base/src/rooms/mod.rs +++ b/crates/matrix-sdk-base/src/rooms/mod.rs @@ -4,7 +4,7 @@ mod normal; use std::{collections::HashSet, fmt}; pub use members::RoomMember; -pub use normal::{Room, RoomInfo, RoomType}; +pub use normal::{Room, RoomInfo, RoomState}; use ruma::{ assign, events::{ diff --git a/crates/matrix-sdk-base/src/rooms/normal.rs b/crates/matrix-sdk-base/src/rooms/normal.rs index 7fb82ddedfd..3b2fb92ad4c 100644 --- a/crates/matrix-sdk-base/src/rooms/normal.rs +++ b/crates/matrix-sdk-base/src/rooms/normal.rs @@ -21,7 +21,7 @@ use futures_util::stream::{self, StreamExt}; use ruma::{ api::client::sync::sync_events::v3::RoomSummary as RumaSummary, events::{ - receipt::{Receipt, ReceiptType}, + receipt::{Receipt, ReceiptThread, ReceiptType}, room::{ create::RoomCreateEventContent, encryption::RoomEncryptionEventContent, guest_access::GuestAccess, history_visibility::HistoryVisibility, join_rules::JoinRule, @@ -31,7 +31,7 @@ use ruma::{ AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnySyncStateEvent, RoomAccountDataEventType, }, - room::RoomType as CreateRoomType, + room::RoomType, EventId, OwnedEventId, OwnedMxcUri, OwnedRoomAliasId, OwnedUserId, RoomAliasId, RoomId, RoomVersionId, UserId, }; @@ -40,7 +40,7 @@ use tracing::debug; use super::{BaseRoomInfo, DisplayName, RoomMember}; use crate::{ - store::{Result as StoreResult, StateStore, StateStoreExt}, + store::{DynStateStore, Result as StoreResult, StateStoreExt}, sync::UnreadNotificationsCount, 
MinimalStateEvent, }; @@ -52,7 +52,7 @@ pub struct Room { room_id: Arc, own_user_id: Arc, inner: Arc>, - store: Arc, + store: Arc, } /// The room summary containing member counts and members that should be used to @@ -71,7 +71,7 @@ pub struct RoomSummary { /// Enum keeping track in which state the room is, e.g. if our own user is /// joined, invited, or has left the room. #[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)] -pub enum RoomType { +pub enum RoomState { /// The room is in a joined state. Joined, /// The room is in a left state. @@ -83,17 +83,17 @@ pub enum RoomType { impl Room { pub(crate) fn new( own_user_id: &UserId, - store: Arc, + store: Arc, room_id: &RoomId, - room_type: RoomType, + room_state: RoomState, ) -> Self { - let room_info = RoomInfo::new(room_id, room_type); + let room_info = RoomInfo::new(room_id, room_state); Self::restore(own_user_id, store, room_info) } pub(crate) fn restore( own_user_id: &UserId, - store: Arc, + store: Arc, room_info: RoomInfo, ) -> Self { Self { @@ -114,14 +114,14 @@ impl Room { &self.own_user_id } - /// Get the type of the room. - pub fn room_type(&self) -> RoomType { - self.inner.read().unwrap().room_type + /// Get the state of the room. + pub fn state(&self) -> RoomState { + self.inner.read().unwrap().room_state } - /// Whether this room's [`RoomType`](CreateRoomType) is `m.space`. + /// Whether this room's [`RoomType`] is `m.space`. pub fn is_space(&self) -> bool { - self.inner.read().unwrap().room_type().map_or(false, |t| *t == CreateRoomType::Space) + self.inner.read().unwrap().room_type().map_or(false, |t| *t == RoomType::Space) } /// Get the unread notification counts. @@ -381,13 +381,13 @@ impl Room { members? 
}; - let (joined, invited) = match self.room_type() { - RoomType::Invited => { + let (joined, invited) = match self.state() { + RoomState::Invited => { // when we were invited we don't have a proper summary, we have to do best // guessing (members.len() as u64, 1u64) } - RoomType::Joined if summary.joined_member_count == 0 => { + RoomState::Joined if summary.joined_member_count == 0 => { // joined but the summary is not completed yet ( (members.len() as u64) + 1, // we've taken ourselves out of the count @@ -482,22 +482,28 @@ impl Room { } } - /// Get the read receipt as a `EventId` and `Receipt` tuple for the given - /// `user_id` in this room. - pub async fn user_read_receipt( + /// Get the receipt as an `OwnedEventId` and `Receipt` tuple for the given + /// `receipt_type`, `thread` and `user_id` in this room. + pub async fn user_receipt( &self, + receipt_type: ReceiptType, + thread: ReceiptThread, user_id: &UserId, ) -> StoreResult> { - self.store.get_user_room_receipt_event(self.room_id(), ReceiptType::Read, user_id).await + self.store.get_user_room_receipt_event(self.room_id(), receipt_type, thread, user_id).await } - /// Get the read receipts as a list of `UserId` and `Receipt` tuples for the - /// given `event_id` in this room. - pub async fn event_read_receipts( + /// Get the receipts as a list of `OwnedUserId` and `Receipt` tuples for the + /// given `receipt_type`, `thread` and `event_id` in this room. + pub async fn event_receipts( &self, + receipt_type: ReceiptType, + thread: ReceiptThread, event_id: &EventId, ) -> StoreResult> { - self.store.get_event_room_receipt_events(self.room_id(), ReceiptType::Read, event_id).await + self.store + .get_event_room_receipt_events(self.room_id(), receipt_type, thread, event_id) + .await } } @@ -508,22 +514,23 @@ impl Room { pub struct RoomInfo { /// The unique room id of the room. pub(crate) room_id: Arc, - /// The type of the room. - pub(crate) room_type: RoomType, + /// The state of the room. 
+ #[serde(rename = "room_type")] // for backwards compatibility + room_state: RoomState, /// The unread notifications counts. - pub(crate) notification_counts: UnreadNotificationsCount, + notification_counts: UnreadNotificationsCount, /// The summary of this room. - pub(crate) summary: RoomSummary, + summary: RoomSummary, /// Flag remembering if the room members are synced. - pub(crate) members_synced: bool, + members_synced: bool, /// The prev batch of this room we received during the last sync. pub(crate) last_prev_batch: Option, /// How much we know about this room. #[serde(default = "SyncInfo::complete")] // see fn docs for why we use this default - pub(crate) sync_info: SyncInfo, + sync_info: SyncInfo, /// Whether or not the encryption info was been synced. #[serde(default = "encryption_state_default")] // see fn docs for why we use this default - pub(crate) encryption_state_synced: bool, + encryption_state_synced: bool, /// Base room info which holds some basic event contents important for the /// room state. pub(crate) base_info: BaseRoomInfo, @@ -566,10 +573,10 @@ fn encryption_state_default() -> bool { impl RoomInfo { #[doc(hidden)] // used by store tests, otherwise it would be pub(crate) - pub fn new(room_id: &RoomId, room_type: RoomType) -> Self { + pub fn new(room_id: &RoomId, room_state: RoomState) -> Self { Self { room_id: room_id.into(), - room_type, + room_state, notification_counts: Default::default(), summary: Default::default(), members_synced: false, @@ -582,17 +589,17 @@ impl RoomInfo { /// Mark this Room as joined. pub fn mark_as_joined(&mut self) { - self.room_type = RoomType::Joined; + self.room_state = RoomState::Joined; } /// Mark this Room as left. pub fn mark_as_left(&mut self) { - self.room_type = RoomType::Left; + self.room_state = RoomState::Left; } /// Mark this Room as invited. 
pub fn mark_as_invited(&mut self) { - self.room_type = RoomType::Invited; + self.room_state = RoomState::Invited; } /// Mark this Room as having all the members synced. @@ -642,7 +649,12 @@ impl RoomInfo { } } - /// Returns whether this is an encrypted Room. + /// Returns the state this room is in. + pub fn state(&self) -> RoomState { + self.room_state + } + + /// Returns whether this is an encrypted room. pub fn is_encrypted(&self) -> bool { self.base_info.encryption.is_some() } @@ -737,7 +749,7 @@ impl RoomInfo { } /// Get the room type of this room. - pub fn room_type(&self) -> Option<&CreateRoomType> { + pub fn room_type(&self) -> Option<&RoomType> { self.base_info.create.as_ref()?.as_original()?.content.room_type.as_ref() } @@ -805,11 +817,70 @@ mod test { use super::*; use crate::{ - store::{MemoryStore, StateChanges}, + store::{MemoryStore, StateChanges, StateStore}, MinimalStateEvent, OriginalMinimalStateEvent, }; - fn make_room(room_type: RoomType) -> (Arc, Room) { + #[test] + fn room_info_serialization() { + // This test exists to make sure we don't accidentally change the + // serialized format for `RoomInfo`. 
+ + let info = RoomInfo { + room_id: room_id!("!gda78o:server.tld").into(), + room_state: RoomState::Invited, + notification_counts: UnreadNotificationsCount { + highlight_count: 1, + notification_count: 2, + }, + summary: RoomSummary { + heroes: vec!["Somebody".to_owned()], + joined_member_count: 5, + invited_member_count: 0, + }, + members_synced: true, + last_prev_batch: Some("pb".to_owned()), + sync_info: SyncInfo::FullySynced, + encryption_state_synced: true, + base_info: BaseRoomInfo::new(), + }; + + let info_json = json!({ + "room_id": "!gda78o:server.tld", + "room_type": "Invited", + "notification_counts": { + "highlight_count": 1, + "notification_count": 2, + }, + "summary": { + "heroes": ["Somebody"], + "joined_member_count": 5, + "invited_member_count": 0, + }, + "members_synced": true, + "last_prev_batch": "pb", + "sync_info": "FullySynced", + "encryption_state_synced": true, + "base_info": { + "avatar": null, + "canonical_alias": null, + "create": null, + "dm_targets": [], + "encryption": null, + "guest_access": null, + "history_visibility": null, + "join_rules": null, + "max_power_level": 100, + "name": null, + "tombstone": null, + "topic": null, + } + }); + + assert_eq!(serde_json::to_value(info).unwrap(), info_json); + } + + fn make_room(room_type: RoomState) -> (Arc, Room) { let store = Arc::new(MemoryStore::new()); let user_id = user_id!("@me:example.org"); let room_id = room_id!("!test:localhost"); @@ -847,7 +918,7 @@ mod test { #[async_test] async fn test_display_name_default() { - let (_, room) = make_room(RoomType::Joined); + let (_, room) = make_room(RoomState::Joined); assert_eq!(room.display_name().await.unwrap(), DisplayName::Empty); let canonical_alias_event = MinimalStateEvent::Original(OriginalMinimalStateEvent { @@ -870,7 +941,7 @@ mod test { room.inner.write().unwrap().base_info.name = Some(name_event.clone()); assert_eq!(room.display_name().await.unwrap(), DisplayName::Named("Test Room".to_owned())); - let (_, room) = 
make_room(RoomType::Invited); + let (_, room) = make_room(RoomState::Invited); assert_eq!(room.display_name().await.unwrap(), DisplayName::Empty); // has precedence @@ -884,7 +955,7 @@ mod test { #[async_test] async fn test_display_name_dm_invited() { - let (store, room) = make_room(RoomType::Invited); + let (store, room) = make_room(RoomState::Invited); let room_id = room_id!("!test:localhost"); let matthew = user_id!("@matthew:example.org"); let me = user_id!("@me:example.org"); @@ -910,7 +981,7 @@ mod test { #[async_test] async fn test_display_name_dm_invited_no_heroes() { - let (store, room) = make_room(RoomType::Invited); + let (store, room) = make_room(RoomState::Invited); let room_id = room_id!("!test:localhost"); let matthew = user_id!("@matthew:example.org"); let me = user_id!("@me:example.org"); @@ -932,7 +1003,7 @@ mod test { #[async_test] async fn test_display_name_dm_joined() { - let (store, room) = make_room(RoomType::Joined); + let (store, room) = make_room(RoomState::Joined); let room_id = room_id!("!test:localhost"); let matthew = user_id!("@matthew:example.org"); let me = user_id!("@me:example.org"); @@ -963,7 +1034,7 @@ mod test { #[async_test] async fn test_display_name_dm_joined_no_heroes() { - let (store, room) = make_room(RoomType::Joined); + let (store, room) = make_room(RoomState::Joined); let room_id = room_id!("!test:localhost"); let matthew = user_id!("@matthew:example.org"); let me = user_id!("@me:example.org"); @@ -989,7 +1060,7 @@ mod test { #[async_test] async fn test_display_name_dm_alone() { - let (store, room) = make_room(RoomType::Joined); + let (store, room) = make_room(RoomState::Joined); let room_id = room_id!("!test:localhost"); let matthew = user_id!("@matthew:example.org"); let me = user_id!("@me:example.org"); diff --git a/crates/matrix-sdk-base/src/session.rs b/crates/matrix-sdk-base/src/session.rs index 35b58213b77..05222fe3461 100644 --- a/crates/matrix-sdk-base/src/session.rs +++ b/crates/matrix-sdk-base/src/session.rs 
@@ -68,6 +68,7 @@ impl Session { } } +#[cfg(not(tarpaulin_include))] impl fmt::Debug for Session { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Session") diff --git a/crates/matrix-sdk-base/src/sliding_sync.rs b/crates/matrix-sdk-base/src/sliding_sync.rs index 7bc4c9c2a6b..0b1cd684705 100644 --- a/crates/matrix-sdk-base/src/sliding_sync.rs +++ b/crates/matrix-sdk-base/src/sliding_sync.rs @@ -1,16 +1,23 @@ +use std::collections::BTreeMap; #[cfg(feature = "e2e-encryption")] use std::ops::Deref; -use ruma::api::client::sync::sync_events::{v3, v4}; #[cfg(feature = "e2e-encryption")] use ruma::UserId; +use ruma::{ + api::client::sync::sync_events::{ + v3::{self, Ephemeral}, + v4, DeviceLists, + }, + DeviceKeyAlgorithm, UInt, +}; use tracing::{debug, info, instrument}; use super::BaseClient; use crate::{ deserialized_responses::AmbiguityChanges, error::Result, - rooms::RoomType, + rooms::RoomState, store::{ambiguity_map::AmbiguityCache, StateChanges}, sync::{JoinedRoom, Rooms, SyncResponse}, }; @@ -23,7 +30,7 @@ impl BaseClient { /// * `response` - The response that we received after a successful sliding /// sync. #[instrument(skip_all, level = "trace")] - pub async fn process_sliding_sync(&self, response: v4::Response) -> Result { + pub async fn process_sliding_sync(&self, response: &v4::Response) -> Result { #[allow(unused_variables)] let v4::Response { // FIXME not yet supported by sliding sync. see @@ -36,6 +43,7 @@ impl BaseClient { //presence, .. } = response; + info!(rooms = rooms.len(), lists = lists.len(), extensions = !extensions.is_empty()); if rooms.is_empty() && extensions.is_empty() { @@ -44,23 +52,37 @@ impl BaseClient { return Ok(SyncResponse::default()); }; - let v4::Extensions { to_device, e2ee, account_data, .. } = extensions; + let v4::Extensions { to_device, e2ee, account_data, receipts, .. 
} = extensions; - let to_device_events = to_device.map(|v4| v4.events).unwrap_or_default(); + let to_device_events = to_device.as_ref().map(|v4| v4.events.clone()).unwrap_or_default(); // Destructure the single `None` of the E2EE extension into separate objects - // since that's what the OlmMachine API expects. Passing in the default - // empty maps and vecs for this is completely fine, since the OlmMachine + // since that's what the `OlmMachine` API expects. Passing in the default + // empty maps and vecs for this is completely fine, since the `OlmMachine` // assumes empty maps/vecs mean no change in the one-time key counts. + + // We declare default values that can be referenced hereinbelow. When we try to + // extract values from `e2ee`, that would be unfortunate to clone the + // value just to pass them (to remove them `e2ee`) as a reference later. + let device_one_time_keys_count = BTreeMap::::default(); + let device_unused_fallback_key_types = None; + let (device_lists, device_one_time_keys_count, device_unused_fallback_key_types) = e2ee + .as_ref() .map(|e2ee| { ( - e2ee.device_lists, - e2ee.device_one_time_keys_count, - e2ee.device_unused_fallback_key_types, + e2ee.device_lists.clone(), + &e2ee.device_one_time_keys_count, + &e2ee.device_unused_fallback_key_types, ) }) - .unwrap_or_default(); + .unwrap_or_else(|| { + ( + DeviceLists::default(), + &device_one_time_keys_count, + &device_unused_fallback_key_types, + ) + }); info!( to_device_events = to_device_events.len(), @@ -77,7 +99,7 @@ impl BaseClient { self.preprocess_to_device_events( to_device_events, &device_lists, - &device_one_time_keys_count, + device_one_time_keys_count, device_unused_fallback_key_types.as_deref(), ) .await? 
@@ -87,22 +109,22 @@ impl BaseClient { let mut changes = StateChanges::default(); let mut ambiguity_cache = AmbiguityCache::new(store.inner.clone()); - if let Some(global_data) = account_data.as_ref().map(|a| &a.global) { - self.handle_account_data(global_data, &mut changes).await; + if let Some(global_data) = account_data.as_ref() { + self.handle_account_data(&global_data.global, &mut changes).await; } let push_rules = self.get_push_rules(&changes).await?; let mut new_rooms = Rooms::default(); - for (room_id, room_data) in rooms.into_iter() { + for (room_id, room_data) in rooms { if !room_data.invite_state.is_empty() { let invite_states = &room_data.invite_state; - let room = store.get_or_create_stripped_room(&room_id).await; + let room = store.get_or_create_stripped_room(room_id).await; let mut room_info = room.clone_info(); room_info.mark_state_partially_synced(); - if let Some(r) = store.get_room(&room_id) { + if let Some(r) = store.get_room(room_id) { let mut room_info = r.clone_info(); room_info.mark_as_invited(); // FIXME: this might not be accurate room_info.mark_state_partially_synced(); @@ -116,7 +138,7 @@ impl BaseClient { v3::InvitedRoom::from(v3::InviteState::from(invite_states.clone())), ); } else { - let room = store.get_or_create_room(&room_id, RoomType::Joined).await; + let room = store.get_or_create_room(room_id, RoomState::Joined).await; let mut room_info = room.clone_info(); room_info.mark_as_joined(); // FIXME: this might not be accurate room_info.mark_state_partially_synced(); @@ -138,20 +160,9 @@ impl BaseClient { Default::default() }; - // FIXME not yet supported by sliding sync. 
see - // https://github.com/matrix-org/matrix-rust-sdk/issues/1014 - // if let Some(event) = - // room_data.ephemeral.events.iter().find_map(|e| match e.deserialize() { - // Ok(AnySyncEphemeralRoomEvent::Receipt(event)) => Some(event.content), - // _ => None, - // }) - // { - // changes.add_receipts(&room_id, event); - // } - let room_account_data = if let Some(inner_account_data) = &account_data { - if let Some(events) = inner_account_data.rooms.get(&room_id) { - self.handle_room_account_data(&room_id, events, &mut changes).await; + if let Some(events) = inner_account_data.rooms.get(room_id) { + self.handle_room_account_data(room_id, events, &mut changes).await; Some(events.to_vec()) } else { None @@ -168,8 +179,8 @@ impl BaseClient { .handle_timeline( &room, room_data.limited, - room_data.timeline, - room_data.prev_batch, + room_data.timeline.clone(), + room_data.prev_batch.clone(), &push_rules, &mut user_ids, &mut room_info, @@ -185,8 +196,8 @@ impl BaseClient { // The room turned on encryption in this sync, we need // to also get all the existing users and mark them for // tracking. 
- let joined = store.get_joined_user_ids(&room_id).await?; - let invited = store.get_invited_user_ids(&room_id).await?; + let joined = store.get_joined_user_ids(room_id).await?; + let invited = store.get_invited_user_ids(room_id).await?; let user_ids: Vec<&UserId> = joined.iter().chain(&invited).map(Deref::deref).collect(); @@ -207,7 +218,7 @@ impl BaseClient { timeline, v3::State::with_events(room_data.required_state.clone()), room_account_data.unwrap_or_default(), - Default::default(), // room_info.ephemeral, + Ephemeral::default(), notification_count, ), ); @@ -216,12 +227,21 @@ impl BaseClient { } } + // Process receipts now we have rooms + if let Some(receipts) = &receipts { + for (room_id, receipt_edu) in &receipts.rooms { + if let Ok(receipt_edu) = receipt_edu.deserialize() { + changes.add_receipts(room_id, receipt_edu.content); + } + } + } + // TODO remove this, we're processing account data events here again // because we want to have the push rules in place before we process // rooms and their events, but we want to create the rooms before we // process the `m.direct` account data event. - if let Some(global_data) = account_data.as_ref().map(|a| &a.global) { - self.handle_account_data(global_data, &mut changes).await; + if let Some(global_data) = account_data.as_ref() { + self.handle_account_data(&global_data.global, &mut changes).await; } // FIXME not yet supported by sliding sync. @@ -243,7 +263,7 @@ impl BaseClient { debug!("applied changes"); let device_one_time_keys_count = - device_one_time_keys_count.into_iter().map(|(k, v)| (k, v.into())).collect(); + device_one_time_keys_count.iter().map(|(k, v)| (k.clone(), (*v).into())).collect(); Ok(SyncResponse { rooms: new_rooms, @@ -251,7 +271,7 @@ impl BaseClient { notifications: changes.notifications, // FIXME not yet supported by sliding sync. 
presence: Default::default(), - account_data: account_data.map(|a| a.global).unwrap_or_default(), + account_data: account_data.as_ref().map(|a| a.global.clone()).unwrap_or_default(), to_device_events, device_lists, device_one_time_keys_count, diff --git a/crates/matrix-sdk-base/src/store/ambiguity_map.rs b/crates/matrix-sdk-base/src/store/ambiguity_map.rs index bba99c61760..4b52002d92d 100644 --- a/crates/matrix-sdk-base/src/store/ambiguity_map.rs +++ b/crates/matrix-sdk-base/src/store/ambiguity_map.rs @@ -23,15 +23,12 @@ use ruma::{ }; use tracing::trace; -use super::{Result, StateChanges}; -use crate::{ - deserialized_responses::{AmbiguityChange, RawMemberEvent}, - StateStore, -}; +use super::{DynStateStore, Result, StateChanges}; +use crate::deserialized_responses::{AmbiguityChange, RawMemberEvent}; #[derive(Debug)] pub(crate) struct AmbiguityCache { - pub store: Arc, + pub store: Arc, pub cache: BTreeMap>>, pub changes: BTreeMap>, } @@ -72,7 +69,7 @@ impl AmbiguityMap { } impl AmbiguityCache { - pub fn new(store: Arc) -> Self { + pub fn new(store: Arc) -> Self { Self { store, cache: BTreeMap::new(), changes: BTreeMap::new() } } diff --git a/crates/matrix-sdk-base/src/store/integration_tests.rs b/crates/matrix-sdk-base/src/store/integration_tests.rs index 83f76c0645a..05fe6e0490b 100644 --- a/crates/matrix-sdk-base/src/store/integration_tests.rs +++ b/crates/matrix-sdk-base/src/store/integration_tests.rs @@ -1,4 +1,859 @@ -//! Macro of integration tests for StateStore implementations. +//! Trait and macro of integration tests for StateStore implementations. 
+ +use std::collections::{BTreeMap, BTreeSet}; + +use assert_matches::assert_matches; +use async_trait::async_trait; +use matrix_sdk_test::test_json; +use ruma::{ + api::client::media::get_content_thumbnail::v3::Method, + event_id, + events::{ + presence::PresenceEvent, + receipt::{ReceiptThread, ReceiptType}, + room::{ + member::{ + MembershipState, RoomMemberEventContent, StrippedRoomMemberEvent, + SyncRoomMemberEvent, + }, + power_levels::RoomPowerLevelsEventContent, + topic::{OriginalRoomTopicEvent, RedactedRoomTopicEvent, RoomTopicEventContent}, + MediaSource, + }, + AnyEphemeralRoomEventContent, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, + AnyStrippedStateEvent, AnySyncEphemeralRoomEvent, AnySyncStateEvent, + GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, + }, + mxc_uri, room_id, + serde::Raw, + uint, user_id, EventId, OwnedEventId, RoomId, UserId, +}; +use serde_json::{json, value::Value as JsonValue}; + +use super::DynStateStore; +use crate::{ + deserialized_responses::MemberEvent, + media::{MediaFormat, MediaRequest, MediaThumbnailSize}, + store::{Result, StateStoreExt}, + RoomInfo, RoomState, StateChanges, StateStoreDataKey, StateStoreDataValue, +}; + +/// `StateStore` integration tests. +/// +/// This trait is not meant to be used directly, but will be used with the [``] +/// macro. +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +pub trait StateStoreIntegrationTests { + /// Populate the given `StateStore`. + async fn populate(&self) -> Result<()>; + /// Test media content storage. + async fn test_media_content(&self); + /// Test room topic redaction. + async fn test_topic_redaction(&self) -> Result<()>; + /// Test populating the store. + async fn test_populate_store(&self) -> Result<()>; + /// Test room member saving. + async fn test_member_saving(&self); + /// Test filter saving. + async fn test_filter_saving(&self); + /// Test sync token saving. 
+ async fn test_sync_token_saving(&self); + /// Test stripped room member saving. + async fn test_stripped_member_saving(&self); + /// Test room power levels saving. + async fn test_power_level_saving(&self); + /// Test user receipts saving. + async fn test_receipts_saving(&self); + /// Test custom storage. + async fn test_custom_storage(&self) -> Result<()>; + /// Test invited room saving. + async fn test_persist_invited_room(&self) -> Result<()>; + /// Test stripped and non-stripped room member saving. + async fn test_stripped_non_stripped(&self) -> Result<()>; + /// Test room removal. + async fn test_room_removal(&self) -> Result<()>; +} + +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl StateStoreIntegrationTests for DynStateStore { + async fn populate(&self) -> Result<()> { + let mut changes = StateChanges::default(); + + let user_id = user_id(); + let invited_user_id = invited_user_id(); + let room_id = room_id(); + let stripped_room_id = stripped_room_id(); + + changes.sync_token = Some("t392-516_47314_0_7_1_1_1_11444_1".to_owned()); + + let presence_json: &JsonValue = &test_json::PRESENCE; + let presence_raw = + serde_json::from_value::>(presence_json.clone()).unwrap(); + let presence_event = presence_raw.deserialize().unwrap(); + changes.add_presence_event(presence_event, presence_raw); + + let pushrules_json: &JsonValue = &test_json::PUSH_RULES; + let pushrules_raw = + serde_json::from_value::>(pushrules_json.clone()) + .unwrap(); + let pushrules_event = pushrules_raw.deserialize().unwrap(); + changes.add_account_data(pushrules_event, pushrules_raw); + + let mut room = RoomInfo::new(room_id, RoomState::Joined); + room.mark_as_left(); + + let tag_json: &JsonValue = &test_json::TAG; + let tag_raw = + serde_json::from_value::>(tag_json.clone()).unwrap(); + let tag_event = tag_raw.deserialize().unwrap(); + changes.add_room_account_data(room_id, tag_event, tag_raw); + + let name_json: 
&JsonValue = &test_json::NAME; + let name_raw = serde_json::from_value::>(name_json.clone()).unwrap(); + let name_event = name_raw.deserialize().unwrap(); + room.handle_state_event(&name_event); + changes.add_state_event(room_id, name_event, name_raw); + + let topic_json: &JsonValue = &test_json::TOPIC; + let topic_raw = serde_json::from_value::>(topic_json.clone()) + .expect("can create sync-state-event for topic"); + let topic_event = topic_raw.deserialize().expect("can deserialize raw topic"); + room.handle_state_event(&topic_event); + changes.add_state_event(room_id, topic_event, topic_raw); + + let mut room_ambiguity_map = BTreeMap::new(); + let mut room_profiles = BTreeMap::new(); + let mut room_members = BTreeMap::new(); + + let member_json: &JsonValue = &test_json::MEMBER; + let member_event: SyncRoomMemberEvent = + serde_json::from_value(member_json.clone()).unwrap(); + let displayname = member_event.as_original().unwrap().content.displayname.clone().unwrap(); + room_ambiguity_map.insert(displayname.clone(), BTreeSet::from([user_id.to_owned()])); + room_profiles.insert(user_id.to_owned(), (&member_event).into()); + room_members.insert(user_id.to_owned(), Raw::new(&member_json).unwrap().cast()); + + let member_state_raw = + serde_json::from_value::>(member_json.clone()).unwrap(); + let member_state_event = member_state_raw.deserialize().unwrap(); + changes.add_state_event(room_id, member_state_event, member_state_raw); + + let invited_member_json: &JsonValue = &test_json::MEMBER_INVITE; + // FIXME: Should be stripped room member event + let invited_member_event: SyncRoomMemberEvent = + serde_json::from_value(invited_member_json.clone()).unwrap(); + room_ambiguity_map.entry(displayname).or_default().insert(invited_user_id.to_owned()); + room_profiles.insert(invited_user_id.to_owned(), (&invited_member_event).into()); + room_members + .insert(invited_user_id.to_owned(), Raw::new(&invited_member_json).unwrap().cast()); + + let invited_member_state_raw = + 
serde_json::from_value::>(invited_member_json.clone()).unwrap(); + let invited_member_state_event = invited_member_state_raw.deserialize().unwrap(); + changes.add_state_event(room_id, invited_member_state_event, invited_member_state_raw); + + let receipt_json: &JsonValue = &test_json::READ_RECEIPT; + let receipt_event = + serde_json::from_value::(receipt_json.clone()).unwrap(); + let receipt_content = match receipt_event.content() { + AnyEphemeralRoomEventContent::Receipt(content) => content, + _ => panic!(), + }; + changes.add_receipts(room_id, receipt_content); + + changes.ambiguity_maps.insert(room_id.to_owned(), room_ambiguity_map); + changes.profiles.insert(room_id.to_owned(), room_profiles); + changes.members.insert(room_id.to_owned(), room_members); + changes.add_room(room); + + let mut stripped_room = RoomInfo::new(stripped_room_id, RoomState::Invited); + + let stripped_name_json: &JsonValue = &test_json::NAME_STRIPPED; + let stripped_name_raw = + serde_json::from_value::>(stripped_name_json.clone()) + .unwrap(); + let stripped_name_event = stripped_name_raw.deserialize().unwrap(); + stripped_room.handle_stripped_state_event(&stripped_name_event); + changes.stripped_state.insert( + stripped_room_id.to_owned(), + BTreeMap::from([( + stripped_name_event.event_type(), + BTreeMap::from([( + stripped_name_event.state_key().to_owned(), + stripped_name_raw.clone(), + )]), + )]), + ); + + changes.add_stripped_room(stripped_room); + + let stripped_member_json: &JsonValue = &test_json::MEMBER_STRIPPED; + let stripped_member_event = Raw::new(&stripped_member_json.clone()).unwrap().cast(); + changes.add_stripped_member(stripped_room_id, user_id, stripped_member_event); + + self.save_changes(&changes).await?; + + Ok(()) + } + + async fn test_media_content(&self) { + let uri = mxc_uri!("mxc://localhost/media"); + let content: Vec = "somebinarydata".into(); + + let request_file = + MediaRequest { source: MediaSource::Plain(uri.to_owned()), format: MediaFormat::File }; + + 
let request_thumbnail = MediaRequest { + source: MediaSource::Plain(uri.to_owned()), + format: MediaFormat::Thumbnail(MediaThumbnailSize { + method: Method::Crop, + width: uint!(100), + height: uint!(100), + }), + }; + + assert!( + self.get_media_content(&request_file).await.unwrap().is_none(), + "unexpected media found" + ); + assert!( + self.get_media_content(&request_thumbnail).await.unwrap().is_none(), + "media not found" + ); + + self.add_media_content(&request_file, content.clone()).await.expect("adding media failed"); + assert!( + self.get_media_content(&request_file).await.unwrap().is_some(), + "media not found though added" + ); + + self.remove_media_content(&request_file).await.expect("removing media failed"); + assert!( + self.get_media_content(&request_file).await.unwrap().is_none(), + "media still there after removing" + ); + + self.add_media_content(&request_file, content.clone()) + .await + .expect("adding media again failed"); + assert!( + self.get_media_content(&request_file).await.unwrap().is_some(), + "media not found after adding again" + ); + + self.add_media_content(&request_thumbnail, content.clone()) + .await + .expect("adding thumbnail failed"); + assert!( + self.get_media_content(&request_thumbnail).await.unwrap().is_some(), + "thumbnail not found" + ); + + self.remove_media_content_for_uri(uri).await.expect("removing all media for uri failed"); + assert!( + self.get_media_content(&request_file).await.unwrap().is_none(), + "media wasn't removed" + ); + assert!( + self.get_media_content(&request_thumbnail).await.unwrap().is_none(), + "thumbnail wasn't removed" + ); + } + + async fn test_topic_redaction(&self) -> Result<()> { + let room_id = room_id(); + self.populate().await?; + + assert!(self.get_kv_data(StateStoreDataKey::SyncToken).await?.is_some()); + assert_eq!( + self.get_state_event_static::(room_id) + .await? 
+ .expect("room topic found before redaction") + .deserialize_as::() + .expect("can deserialize room topic before redaction") + .content + .topic, + "πŸ˜€" + ); + + let mut changes = StateChanges::default(); + + let redaction_json: &JsonValue = &test_json::TOPIC_REDACTION; + let redaction_evt: Raw<_> = serde_json::from_value(redaction_json.clone()) + .expect("topic redaction event making works"); + let redacted_event_id: OwnedEventId = redaction_evt.get_field("redacts").unwrap().unwrap(); + + changes.add_redaction(room_id, &redacted_event_id, redaction_evt); + self.save_changes(&changes).await?; + + match self + .get_state_event_static::(room_id) + .await? + .expect("room topic found before redaction") + .deserialize_as::() + { + Err(_) => {} // as expected + Ok(_) => panic!("Topic has not been redacted"), + } + + let _ = self + .get_state_event_static::(room_id) + .await? + .expect("room topic found after redaction") + .deserialize_as::() + .expect("can deserialize room topic after redaction"); + + Ok(()) + } + + async fn test_populate_store(&self) -> Result<()> { + let room_id = room_id(); + let user_id = user_id(); + self.populate().await?; + + assert!(self.get_kv_data(StateStoreDataKey::SyncToken).await?.is_some()); + assert!(self.get_presence_event(user_id).await?.is_some()); + assert_eq!(self.get_room_infos().await?.len(), 1, "Expected to find 1 room info"); + assert_eq!( + self.get_stripped_room_infos().await?.len(), + 1, + "Expected to find 1 stripped room info" + ); + assert!(self + .get_account_data_event(GlobalAccountDataEventType::PushRules) + .await? 
+ .is_some()); + + assert!(self.get_state_event(room_id, StateEventType::RoomName, "").await?.is_some()); + assert_eq!( + self.get_state_events(room_id, StateEventType::RoomTopic).await?.len(), + 1, + "Expected to find 1 room topic" + ); + assert!(self.get_profile(room_id, user_id).await?.is_some()); + assert!(self.get_member_event(room_id, user_id).await?.is_some()); + assert_eq!( + self.get_user_ids(room_id).await?.len(), + 2, + "Expected to find 2 members for room" + ); + assert_eq!( + self.get_invited_user_ids(room_id).await?.len(), + 1, + "Expected to find 1 invited user ids" + ); + assert_eq!( + self.get_joined_user_ids(room_id).await?.len(), + 1, + "Expected to find 1 joined user ids" + ); + assert_eq!( + self.get_users_with_display_name(room_id, "example").await?.len(), + 2, + "Expected to find 2 display names for room" + ); + assert!(self + .get_room_account_data_event(room_id, RoomAccountDataEventType::Tag) + .await? + .is_some()); + assert!(self + .get_user_room_receipt_event( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + user_id + ) + .await? + .is_some()); + assert_eq!( + self.get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + first_receipt_event_id() + ) + .await? 
+ .len(), + 1, + "Expected to find 1 read receipt" + ); + Ok(()) + } + + async fn test_member_saving(&self) { + let room_id = room_id!("!test_member_saving:localhost"); + let user_id = user_id(); + + assert!(self.get_member_event(room_id, user_id).await.unwrap().is_none()); + let mut changes = StateChanges::default(); + changes + .members + .entry(room_id.to_owned()) + .or_default() + .insert(user_id.to_owned(), membership_event()); + + self.save_changes(&changes).await.unwrap(); + assert!(self.get_member_event(room_id, user_id).await.unwrap().is_some()); + + let members = self.get_user_ids(room_id).await.unwrap(); + assert!(!members.is_empty(), "We expected to find members for the room") + } + + async fn test_filter_saving(&self) { + let filter_name = "filter_name"; + let filter_id = "filter_id_1234"; + + self.set_kv_data( + StateStoreDataKey::Filter(filter_name), + StateStoreDataValue::Filter(filter_id.to_owned()), + ) + .await + .unwrap(); + let stored_filter_id = assert_matches!( + self.get_kv_data(StateStoreDataKey::Filter(filter_name)).await, + Ok(Some(StateStoreDataValue::Filter(s))) => s + ); + assert_eq!(stored_filter_id, filter_id); + + self.remove_kv_data(StateStoreDataKey::Filter(filter_name)).await.unwrap(); + assert_matches!(self.get_kv_data(StateStoreDataKey::Filter(filter_name)).await, Ok(None)); + } + + async fn test_sync_token_saving(&self) { + let sync_token_1 = "t392-516_47314_0_7_1"; + let sync_token_2 = "t392-516_47314_0_7_2"; + + assert_matches!(self.get_kv_data(StateStoreDataKey::SyncToken).await, Ok(None)); + + let changes = + StateChanges { sync_token: Some(sync_token_1.to_owned()), ..Default::default() }; + self.save_changes(&changes).await.unwrap(); + let stored_sync_token = assert_matches!( + self.get_kv_data(StateStoreDataKey::SyncToken).await, + Ok(Some(StateStoreDataValue::SyncToken(s))) => s + ); + assert_eq!(stored_sync_token, sync_token_1); + + self.set_kv_data( + StateStoreDataKey::SyncToken, + 
StateStoreDataValue::SyncToken(sync_token_2.to_owned()), + ) + .await + .unwrap(); + let stored_sync_token = assert_matches!( + self.get_kv_data(StateStoreDataKey::SyncToken).await, + Ok(Some(StateStoreDataValue::SyncToken(s))) => s + ); + assert_eq!(stored_sync_token, sync_token_2); + + self.remove_kv_data(StateStoreDataKey::SyncToken).await.unwrap(); + assert_matches!(self.get_kv_data(StateStoreDataKey::SyncToken).await, Ok(None)); + } + + async fn test_stripped_member_saving(&self) { + let room_id = room_id!("!test_stripped_member_saving:localhost"); + let user_id = user_id(); + + assert!(self.get_member_event(room_id, user_id).await.unwrap().is_none()); + let mut changes = StateChanges::default(); + changes + .stripped_members + .entry(room_id.to_owned()) + .or_default() + .insert(user_id.to_owned(), stripped_membership_event()); + + self.save_changes(&changes).await.unwrap(); + assert!(self.get_member_event(room_id, user_id).await.unwrap().is_some()); + + let members = self.get_user_ids(room_id).await.unwrap(); + assert!(!members.is_empty(), "We expected to find members for the room") + } + + async fn test_power_level_saving(&self) { + let room_id = room_id!("!test_power_level_saving:localhost"); + + let raw_event = power_level_event(); + let event = raw_event.deserialize().unwrap(); + + assert!(self + .get_state_event(room_id, StateEventType::RoomPowerLevels, "") + .await + .unwrap() + .is_none()); + let mut changes = StateChanges::default(); + changes.add_state_event(room_id, event, raw_event); + + self.save_changes(&changes).await.unwrap(); + assert!(self + .get_state_event(room_id, StateEventType::RoomPowerLevels, "") + .await + .unwrap() + .is_some()); + } + + async fn test_receipts_saving(&self) { + let room_id = room_id!("!test_receipts_saving:localhost"); + + let first_event_id = event_id!("$1435641916114394fHBLK:matrix.org"); + let second_event_id = event_id!("$fHBLK1435641916114394:matrix.org"); + + let first_receipt_ts = uint!(1436451550); + let 
second_receipt_ts = uint!(1436451653); + let third_receipt_ts = uint!(1436474532); + + let first_receipt_event = serde_json::from_value(json!({ + first_event_id: { + "m.read": { + user_id(): { + "ts": first_receipt_ts, + } + } + } + })) + .expect("json creation failed"); + + let second_receipt_event = serde_json::from_value(json!({ + second_event_id: { + "m.read": { + user_id(): { + "ts": second_receipt_ts, + } + } + } + })) + .expect("json creation failed"); + + let third_receipt_event = serde_json::from_value(json!({ + second_event_id: { + "m.read": { + user_id(): { + "ts": third_receipt_ts, + "thread_id": "main", + } + } + } + })) + .expect("json creation failed"); + + assert!(self + .get_user_room_receipt_event( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + user_id() + ) + .await + .expect("failed to read unthreaded user room receipt") + .is_none()); + assert!(self + .get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + first_event_id + ) + .await + .expect("failed to read unthreaded event room receipt for 1") + .is_empty()); + assert!(self + .get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + second_event_id + ) + .await + .expect("failed to read unthreaded event room receipt for 2") + .is_empty()); + + let mut changes = StateChanges::default(); + changes.add_receipts(room_id, first_receipt_event); + + self.save_changes(&changes).await.expect("writing changes fauked"); + let (unthreaded_user_receipt_event_id, unthreaded_user_receipt) = self + .get_user_room_receipt_event( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + user_id(), + ) + .await + .expect("failed to read unthreaded user room receipt after save") + .unwrap(); + assert_eq!(unthreaded_user_receipt_event_id, first_event_id); + assert_eq!(unthreaded_user_receipt.ts.unwrap().0, first_receipt_ts); + let first_event_unthreaded_receipts = self + .get_event_room_receipt_events( + room_id, + 
ReceiptType::Read, + ReceiptThread::Unthreaded, + first_event_id, + ) + .await + .expect("failed to read unthreaded event room receipt for 1 after save"); + assert_eq!( + first_event_unthreaded_receipts.len(), + 1, + "Found a wrong number of unthreaded receipts for 1 after save" + ); + assert_eq!(first_event_unthreaded_receipts[0].0, user_id()); + assert_eq!(first_event_unthreaded_receipts[0].1.ts.unwrap().0, first_receipt_ts); + assert!(self + .get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + second_event_id + ) + .await + .expect("failed to read unthreaded event room receipt for 2 after save") + .is_empty()); + + let mut changes = StateChanges::default(); + changes.add_receipts(room_id, second_receipt_event); + + self.save_changes(&changes).await.expect("Saving works"); + let (unthreaded_user_receipt_event_id, unthreaded_user_receipt) = self + .get_user_room_receipt_event( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + user_id(), + ) + .await + .expect("Getting unthreaded user room receipt after save failed") + .unwrap(); + assert_eq!(unthreaded_user_receipt_event_id, second_event_id); + assert_eq!(unthreaded_user_receipt.ts.unwrap().0, second_receipt_ts); + assert!(self + .get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + first_event_id + ) + .await + .expect("Getting unthreaded event room receipt events for first event failed") + .is_empty()); + let second_event_unthreaded_receipts = self + .get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + second_event_id, + ) + .await + .expect("Getting unthreaded event room receipt events for second event failed"); + assert_eq!( + second_event_unthreaded_receipts.len(), + 1, + "Found a wrong number of unthreaded receipts for second event after save" + ); + assert_eq!(second_event_unthreaded_receipts[0].0, user_id()); + 
assert_eq!(second_event_unthreaded_receipts[0].1.ts.unwrap().0, second_receipt_ts); + + assert!(self + .get_user_room_receipt_event(room_id, ReceiptType::Read, ReceiptThread::Main, user_id()) + .await + .expect("failed to read threaded user room receipt") + .is_none()); + assert!(self + .get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Main, + second_event_id + ) + .await + .expect("Getting threaded event room receipts for 2 failed") + .is_empty()); + + let mut changes = StateChanges::default(); + changes.add_receipts(room_id, third_receipt_event); + + self.save_changes(&changes).await.expect("Saving works"); + // Unthreaded receipts should not have changed. + let (unthreaded_user_receipt_event_id, unthreaded_user_receipt) = self + .get_user_room_receipt_event( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + user_id(), + ) + .await + .expect("Getting unthreaded user room receipt after save failed") + .unwrap(); + assert_eq!(unthreaded_user_receipt_event_id, second_event_id); + assert_eq!(unthreaded_user_receipt.ts.unwrap().0, second_receipt_ts); + let second_event_unthreaded_receipts = self + .get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + second_event_id, + ) + .await + .expect("Getting unthreaded event room receipt events for second event failed"); + assert_eq!( + second_event_unthreaded_receipts.len(), + 1, + "Found a wrong number of unthreaded receipts for second event after save" + ); + assert_eq!(second_event_unthreaded_receipts[0].0, user_id()); + assert_eq!(second_event_unthreaded_receipts[0].1.ts.unwrap().0, second_receipt_ts); + // Threaded receipts should have changed + let (threaded_user_receipt_event_id, threaded_user_receipt) = self + .get_user_room_receipt_event(room_id, ReceiptType::Read, ReceiptThread::Main, user_id()) + .await + .expect("Getting threaded user room receipt after save failed") + .unwrap(); + assert_eq!(threaded_user_receipt_event_id, 
second_event_id); + assert_eq!(threaded_user_receipt.ts.unwrap().0, third_receipt_ts); + let second_event_threaded_receipts = self + .get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Main, + second_event_id, + ) + .await + .expect("Getting threaded event room receipt events for second event failed"); + assert_eq!( + second_event_threaded_receipts.len(), + 1, + "Found a wrong number of threaded receipts for second event after save" + ); + assert_eq!(second_event_threaded_receipts[0].0, user_id()); + assert_eq!(second_event_threaded_receipts[0].1.ts.unwrap().0, third_receipt_ts); + } + + async fn test_custom_storage(&self) -> Result<()> { + let key = "my_key"; + let value = &[0, 1, 2, 3]; + + self.set_custom_value(key.as_bytes(), value.to_vec()).await?; + + let read = self.get_custom_value(key.as_bytes()).await?; + + assert_eq!(Some(value.as_ref()), read.as_deref()); + + Ok(()) + } + + async fn test_persist_invited_room(&self) -> Result<()> { + self.populate().await?; + + assert_eq!(self.get_stripped_room_infos().await?.len(), 1); + + Ok(()) + } + + async fn test_stripped_non_stripped(&self) -> Result<()> { + let room_id = room_id!("!test_stripped_non_stripped:localhost"); + let user_id = user_id(); + + assert!(self.get_member_event(room_id, user_id).await.unwrap().is_none()); + assert_eq!(self.get_room_infos().await.unwrap().len(), 0); + assert_eq!(self.get_stripped_room_infos().await.unwrap().len(), 0); + + let mut changes = StateChanges::default(); + changes + .members + .entry(room_id.to_owned()) + .or_default() + .insert(user_id.to_owned(), membership_event()); + changes.add_room(RoomInfo::new(room_id, RoomState::Left)); + self.save_changes(&changes).await.unwrap(); + + let member_event = + self.get_member_event(room_id, user_id).await.unwrap().unwrap().deserialize().unwrap(); + assert!(matches!(member_event, MemberEvent::Sync(_))); + assert_eq!(self.get_room_infos().await.unwrap().len(), 1); + 
assert_eq!(self.get_stripped_room_infos().await.unwrap().len(), 0); + + let members = self.get_user_ids(room_id).await.unwrap(); + assert_eq!(members, vec![user_id.to_owned()]); + + let mut changes = StateChanges::default(); + changes.add_stripped_member(room_id, user_id, custom_stripped_membership_event(user_id)); + changes.add_stripped_room(RoomInfo::new(room_id, RoomState::Invited)); + self.save_changes(&changes).await.unwrap(); + + let member_event = + self.get_member_event(room_id, user_id).await.unwrap().unwrap().deserialize().unwrap(); + assert!(matches!(member_event, MemberEvent::Stripped(_))); + assert_eq!(self.get_room_infos().await.unwrap().len(), 0); + assert_eq!(self.get_stripped_room_infos().await.unwrap().len(), 1); + + let members = self.get_user_ids(room_id).await.unwrap(); + assert_eq!(members, vec![user_id.to_owned()]); + + Ok(()) + } + + async fn test_room_removal(&self) -> Result<()> { + let room_id = room_id(); + let user_id = user_id(); + let stripped_room_id = stripped_room_id(); + + self.populate().await?; + + self.remove_room(room_id).await?; + + assert!(self.get_room_infos().await?.is_empty(), "room is still there"); + assert_eq!(self.get_stripped_room_infos().await?.len(), 1); + + assert!(self.get_state_event(room_id, StateEventType::RoomName, "").await?.is_none()); + assert!( + self.get_state_events(room_id, StateEventType::RoomTopic).await?.is_empty(), + "still state events found" + ); + assert!(self.get_profile(room_id, user_id).await?.is_none()); + assert!(self.get_member_event(room_id, user_id).await?.is_none()); + assert!(self.get_user_ids(room_id).await?.is_empty(), "still user ids found"); + assert!( + self.get_invited_user_ids(room_id).await?.is_empty(), + "still invited user ids found" + ); + assert!(self.get_joined_user_ids(room_id).await?.is_empty(), "still joined users found"); + assert!( + self.get_users_with_display_name(room_id, "example").await?.is_empty(), + "still display names found" + ); + assert!(self + 
.get_room_account_data_event(room_id, RoomAccountDataEventType::Tag) + .await? + .is_none()); + assert!(self + .get_user_room_receipt_event( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + user_id + ) + .await? + .is_none()); + assert!( + self.get_event_room_receipt_events( + room_id, + ReceiptType::Read, + ReceiptThread::Unthreaded, + first_receipt_event_id() + ) + .await? + .is_empty(), + "still event receipts in the store" + ); + + self.remove_room(stripped_room_id).await?; + + assert!(self.get_room_infos().await?.is_empty(), "still room info found"); + assert!(self.get_stripped_room_infos().await?.is_empty(), "still stripped room info found"); + Ok(()) + } +} /// Macro building to allow your StateStore implementation to run the entire /// tests suite locally. @@ -32,89 +887,10 @@ macro_rules! statestore_integration_tests { mod statestore_integration_tests { $crate::statestore_integration_tests!(@inner); - use ruma::{ - api::client::media::get_content_thumbnail::v3::Method, - events::room::MediaSource, - mxc_uri, uint, - }; - - use $crate::media::{MediaFormat, MediaRequest, MediaThumbnailSize}; - #[async_test] async fn test_media_content() { - let store = get_store().await.unwrap(); - - let uri = mxc_uri!("mxc://localhost/media"); - let content: Vec = "somebinarydata".into(); - - let request_file = MediaRequest { - source: MediaSource::Plain(uri.to_owned()), - format: MediaFormat::File, - }; - - let request_thumbnail = MediaRequest { - source: MediaSource::Plain(uri.to_owned()), - format: MediaFormat::Thumbnail(MediaThumbnailSize { - method: Method::Crop, - width: uint!(100), - height: uint!(100), - }), - }; - - assert!( - store.get_media_content(&request_file).await.unwrap().is_none(), - "unexpected media found" - ); - assert!( - store.get_media_content(&request_thumbnail).await.unwrap().is_none(), - "media not found" - ); - - store - .add_media_content(&request_file, content.clone()) - .await - .expect("adding media failed"); - assert!( - 
store.get_media_content(&request_file).await.unwrap().is_some(), - "media not found though added" - ); - - store.remove_media_content(&request_file).await.expect("removing media failed"); - assert!( - store.get_media_content(&request_file).await.unwrap().is_none(), - "media still there after removing" - ); - - store - .add_media_content(&request_file, content.clone()) - .await - .expect("adding media again failed"); - assert!( - store.get_media_content(&request_file).await.unwrap().is_some(), - "media not found after adding again" - ); - - store - .add_media_content(&request_thumbnail, content.clone()) - .await - .expect("adding thumbnail failed"); - assert!( - store.get_media_content(&request_thumbnail).await.unwrap().is_some(), - "thumbnail not found" - ); - - store - .remove_media_content_for_uri(uri) - .await - .expect("removing all media for uri failed"); - assert!( - store.get_media_content(&request_file).await.unwrap().is_none(), - "media wasn't removed" - ); - assert!( - store.get_media_content(&request_thumbnail).await.unwrap().is_none(), - "thumbnail wasn't removed" - ); + let store = get_store().await.unwrap().into_state_store(); + store.test_media_content().await; } } }; @@ -124,708 +900,149 @@ macro_rules! 
statestore_integration_tests { } }; (@inner) => { - use std::{ - collections::{BTreeMap, BTreeSet}, - sync::Arc, - }; - - use matrix_sdk_test::{async_test, test_json}; - use ruma::{ - event_id, - events::{ - presence::PresenceEvent, - receipt::ReceiptType, - room::{ - member::{ - MembershipState, RoomMemberEventContent, StrippedRoomMemberEvent, - SyncRoomMemberEvent, - }, - power_levels::RoomPowerLevelsEventContent, - topic::{RoomTopicEventContent, OriginalRoomTopicEvent, RedactedRoomTopicEvent}, - }, - AnyEphemeralRoomEventContent, AnyGlobalAccountDataEvent, - AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnySyncEphemeralRoomEvent, - AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, - StateEventType, - }, - room_id, - serde::Raw, - user_id, EventId, OwnedEventId, RoomId, UserId, - }; - use serde_json::{json, Value as JsonValue}; + use matrix_sdk_test::async_test; - use $crate::{ - store::{Result as StoreResult, StateChanges, StateStore, StateStoreExt}, - RoomInfo, RoomType, - }; + use $crate::store::{IntoStateStore, Result as StoreResult, StateStoreIntegrationTests}; use super::get_store; - fn user_id() -> &'static UserId { - user_id!("@example:localhost") - } - pub(crate) fn invited_user_id() -> &'static UserId { - user_id!("@invited:localhost") - } - - pub(crate) fn room_id() -> &'static RoomId { - room_id!("!test:localhost") - } - - pub(crate) fn stripped_room_id() -> &'static RoomId { - room_id!("!stripped:localhost") - } - - pub(crate) fn first_receipt_event_id() -> &'static EventId { - event_id!("$example") - } - - /// Populate the given `StateStore`. 
- pub async fn populate_store(store: Arc) -> StoreResult<()> { - let mut changes = StateChanges::default(); - - let user_id = user_id(); - let invited_user_id = invited_user_id(); - let room_id = room_id(); - let stripped_room_id = stripped_room_id(); - - changes.sync_token = Some("t392-516_47314_0_7_1_1_1_11444_1".to_owned()); - - let presence_json: &JsonValue = &test_json::PRESENCE; - let presence_raw = - serde_json::from_value::>(presence_json.clone()).unwrap(); - let presence_event = presence_raw.deserialize().unwrap(); - changes.add_presence_event(presence_event, presence_raw); - - let pushrules_json: &JsonValue = &test_json::PUSH_RULES; - let pushrules_raw = serde_json::from_value::>( - pushrules_json.clone(), - ) - .unwrap(); - let pushrules_event = pushrules_raw.deserialize().unwrap(); - changes.add_account_data(pushrules_event, pushrules_raw); - - let mut room = RoomInfo::new(room_id, RoomType::Joined); - room.mark_as_left(); - - let tag_json: &JsonValue = &test_json::TAG; - let tag_raw = - serde_json::from_value::>(tag_json.clone()) - .unwrap(); - let tag_event = tag_raw.deserialize().unwrap(); - changes.add_room_account_data(room_id, tag_event, tag_raw); - - let name_json: &JsonValue = &test_json::NAME; - let name_raw = - serde_json::from_value::>(name_json.clone()).unwrap(); - let name_event = name_raw.deserialize().unwrap(); - room.handle_state_event(&name_event); - changes.add_state_event(room_id, name_event, name_raw); - - let topic_json: &JsonValue = &test_json::TOPIC; - let topic_raw = - serde_json::from_value::>(topic_json.clone()).expect("can create sync-state-event for topic"); - let topic_event = topic_raw.deserialize().expect("can deserialize raw topic"); - room.handle_state_event(&topic_event); - changes.add_state_event(room_id, topic_event, topic_raw); - - let mut room_ambiguity_map = BTreeMap::new(); - let mut room_profiles = BTreeMap::new(); - let mut room_members = BTreeMap::new(); - - let member_json: &JsonValue = &test_json::MEMBER; - 
let member_event: SyncRoomMemberEvent = - serde_json::from_value(member_json.clone()).unwrap(); - let displayname = - member_event.as_original().unwrap().content.displayname.clone().unwrap(); - room_ambiguity_map - .insert(displayname.clone(), BTreeSet::from([user_id.to_owned()])); - room_profiles.insert(user_id.to_owned(), (&member_event).into()); - room_members.insert(user_id.to_owned(), Raw::new(&member_json).unwrap().cast()); - - let member_state_raw = - serde_json::from_value::>(member_json.clone()).unwrap(); - let member_state_event = member_state_raw.deserialize().unwrap(); - changes.add_state_event(room_id, member_state_event, member_state_raw); - - let invited_member_json: &JsonValue = &test_json::MEMBER_INVITE; - // FIXME: Should be stripped room member event - let invited_member_event: SyncRoomMemberEvent = - serde_json::from_value(invited_member_json.clone()).unwrap(); - room_ambiguity_map - .entry(displayname) - .or_default() - .insert(invited_user_id.to_owned()); - room_profiles.insert(invited_user_id.to_owned(), (&invited_member_event).into()); - room_members.insert( - invited_user_id.to_owned(), - Raw::new(&invited_member_json).unwrap().cast(), - ); - - let invited_member_state_raw = - serde_json::from_value::>(invited_member_json.clone()) - .unwrap(); - let invited_member_state_event = invited_member_state_raw.deserialize().unwrap(); - changes.add_state_event( - room_id, - invited_member_state_event, - invited_member_state_raw, - ); - - let receipt_json: &JsonValue = &test_json::READ_RECEIPT; - let receipt_event = - serde_json::from_value::(receipt_json.clone()) - .unwrap(); - let receipt_content = match receipt_event.content() { - AnyEphemeralRoomEventContent::Receipt(content) => content, - _ => panic!(), - }; - changes.add_receipts(room_id, receipt_content); - - changes.ambiguity_maps.insert(room_id.to_owned(), room_ambiguity_map); - changes.profiles.insert(room_id.to_owned(), room_profiles); - changes.members.insert(room_id.to_owned(), 
room_members); - changes.add_room(room); - - let mut stripped_room = RoomInfo::new(stripped_room_id, RoomType::Invited); - - let stripped_name_json: &JsonValue = &test_json::NAME_STRIPPED; - let stripped_name_raw = serde_json::from_value::>( - stripped_name_json.clone(), - ) - .unwrap(); - let stripped_name_event = stripped_name_raw.deserialize().unwrap(); - stripped_room.handle_stripped_state_event(&stripped_name_event); - changes.stripped_state.insert( - stripped_room_id.to_owned(), - BTreeMap::from([( - stripped_name_event.event_type(), - BTreeMap::from([( - stripped_name_event.state_key().to_owned(), - stripped_name_raw.clone(), - )]), - )]), - ); - - changes.add_stripped_room(stripped_room); - - let stripped_member_json: &JsonValue = &test_json::MEMBER_STRIPPED; - let stripped_member_event = Raw::new(&stripped_member_json.clone()).unwrap().cast(); - changes.add_stripped_member(stripped_room_id, user_id, stripped_member_event); - - store.save_changes(&changes).await?; - Ok(()) - } - - fn power_level_event() -> Raw { - let content = RoomPowerLevelsEventContent::default(); - - let event = json!({ - "event_id": "$h29iv0s8:example.com", - "content": content, - "sender": user_id(), - "type": "m.room.power_levels", - "origin_server_ts": 0u64, - "state_key": "", - }); - - serde_json::from_value(event).unwrap() - } - - fn stripped_membership_event() -> Raw { - custom_stripped_membership_event(user_id()) - } - - fn custom_stripped_membership_event(user_id: &UserId) -> Raw { - let ev_json = json!({ - "type": "m.room.member", - "content": RoomMemberEventContent::new(MembershipState::Join), - "sender": user_id, - "state_key": user_id, - }); - - Raw::new(&ev_json).unwrap().cast() - } - - fn membership_event() -> Raw { - custom_membership_event(user_id(), event_id!("$h29iv0s8:example.com").to_owned()) - } - - fn custom_membership_event( - user_id: &UserId, - event_id: OwnedEventId, - ) -> Raw { - let ev_json = json!({ - "type": "m.room.member", - "content": 
RoomMemberEventContent::new(MembershipState::Join), - "event_id": event_id, - "origin_server_ts": 198, - "sender": user_id, - "state_key": user_id, - }); - - Raw::new(&ev_json).unwrap().cast() - } - #[async_test] async fn test_topic_redaction() -> StoreResult<()> { - let room_id = room_id(); - let inner_store = get_store().await?; - - let store = Arc::new(inner_store); - populate_store(store.clone()).await?; - - assert!(store.get_sync_token().await?.is_some()); - assert_eq!( - store - .get_state_event_static::(room_id) - .await? - .expect("room topic found before redaction") - .deserialize_as::() - .expect("can deserialize room topic before redaction") - .content - .topic, - "πŸ˜€" - ); - - let mut changes = StateChanges::default(); - - let redaction_json: &JsonValue = &test_json::TOPIC_REDACTION; - let redaction_evt: Raw<_> = serde_json::from_value(redaction_json.clone()).expect("topic redaction event making works"); - let redacted_event_id: OwnedEventId = redaction_evt.get_field("redacts").unwrap().unwrap(); - - changes.add_redaction(room_id, &redacted_event_id, redaction_evt); - store.save_changes(&changes).await?; - - match store - .get_state_event_static::(room_id) - .await? - .expect("room topic found before redaction") - .deserialize_as::() - { - Err(_) => { } // as expected - Ok(_) => panic!("Topic has not been redacted") - } - - let _ = store - .get_state_event_static::(room_id) - .await? 
- .expect("room topic found after redaction") - .deserialize_as::() - .expect("can deserialize room topic after redaction"); - - Ok(()) + let store = get_store().await?.into_state_store(); + store.test_topic_redaction().await } #[async_test] async fn test_populate_store() -> StoreResult<()> { - let room_id = room_id(); - let user_id = user_id(); - let inner_store = get_store().await?; - - let store = Arc::new(inner_store); - populate_store(store.clone()).await?; - - assert!(store.get_sync_token().await?.is_some()); - assert!(store.get_presence_event(user_id).await?.is_some()); - assert_eq!(store.get_room_infos().await?.len(), 1, "Expected to find 1 room info"); - assert_eq!( - store.get_stripped_room_infos().await?.len(), - 1, - "Expected to find 1 stripped room info" - ); - assert!(store - .get_account_data_event(GlobalAccountDataEventType::PushRules) - .await? - .is_some()); - - assert!(store - .get_state_event(room_id, StateEventType::RoomName, "") - .await? - .is_some()); - assert_eq!( - store.get_state_events(room_id, StateEventType::RoomTopic).await?.len(), - 1, - "Expected to find 1 room topic" - ); - assert!(store.get_profile(room_id, user_id).await?.is_some()); - assert!(store.get_member_event(room_id, user_id).await?.is_some()); - assert_eq!( - store.get_user_ids(room_id).await?.len(), - 2, - "Expected to find 2 members for room" - ); - assert_eq!( - store.get_invited_user_ids(room_id).await?.len(), - 1, - "Expected to find 1 invited user ids" - ); - assert_eq!( - store.get_joined_user_ids(room_id).await?.len(), - 1, - "Expected to find 1 joined user ids" - ); - assert_eq!( - store.get_users_with_display_name(room_id, "example").await?.len(), - 2, - "Expected to find 2 display names for room" - ); - assert!(store - .get_room_account_data_event(room_id, RoomAccountDataEventType::Tag) - .await? - .is_some()); - assert!(store - .get_user_room_receipt_event(room_id, ReceiptType::Read, user_id) - .await? 
- .is_some()); - assert_eq!( - store - .get_event_room_receipt_events( - room_id, - ReceiptType::Read, - first_receipt_event_id() - ) - .await? - .len(), - 1, - "Expected to find 1 read receipt" - ); - Ok(()) + let store = get_store().await?.into_state_store(); + store.test_populate_store().await } #[async_test] async fn test_member_saving() { - let store = get_store().await.unwrap(); - let room_id = room_id!("!test_member_saving:localhost"); - let user_id = user_id(); - - assert!(store.get_member_event(room_id, user_id).await.unwrap().is_none()); - let mut changes = StateChanges::default(); - changes - .members - .entry(room_id.to_owned()) - .or_default() - .insert(user_id.to_owned(), membership_event()); - - store.save_changes(&changes).await.unwrap(); - assert!(store.get_member_event(room_id, user_id).await.unwrap().is_some()); - - let members = store.get_user_ids(room_id).await.unwrap(); - assert!(!members.is_empty(), "We expected to find members for the room") + let store = get_store().await.unwrap().into_state_store(); + store.test_member_saving().await } #[async_test] async fn test_filter_saving() { - let store = get_store().await.unwrap(); - let test_name = "filter_name"; - let filter_id = "filter_id_1234"; - assert_eq!(store.get_filter(test_name).await.unwrap(), None); - store.save_filter(test_name, filter_id).await.unwrap(); - assert_eq!(store.get_filter(test_name).await.unwrap(), Some(filter_id.to_owned())); + let store = get_store().await.unwrap().into_state_store(); + store.test_filter_saving().await } #[async_test] async fn test_sync_token_saving() { - let mut changes = StateChanges::default(); - let store = get_store().await.unwrap(); - let sync_token = "t392-516_47314_0_7_1".to_owned(); - - changes.sync_token = Some(sync_token.clone()); - assert_eq!(store.get_sync_token().await.unwrap(), None); - store.save_changes(&changes).await.unwrap(); - assert_eq!(store.get_sync_token().await.unwrap(), Some(sync_token)); + let store = 
get_store().await.unwrap().into_state_store(); + store.test_sync_token_saving().await } #[async_test] async fn test_stripped_member_saving() { - let store = get_store().await.unwrap(); - let room_id = room_id!("!test_stripped_member_saving:localhost"); - let user_id = user_id(); - - assert!(store.get_member_event(room_id, user_id).await.unwrap().is_none()); - let mut changes = StateChanges::default(); - changes - .stripped_members - .entry(room_id.to_owned()) - .or_default() - .insert(user_id.to_owned(), stripped_membership_event()); - - store.save_changes(&changes).await.unwrap(); - assert!(store.get_member_event(room_id, user_id).await.unwrap().is_some()); - - let members = store.get_user_ids(room_id).await.unwrap(); - assert!(!members.is_empty(), "We expected to find members for the room") + let store = get_store().await.unwrap().into_state_store(); + store.test_stripped_member_saving().await } #[async_test] async fn test_power_level_saving() { - let store = get_store().await.unwrap(); - let room_id = room_id!("!test_power_level_saving:localhost"); - - let raw_event = power_level_event(); - let event = raw_event.deserialize().unwrap(); - - assert!(store - .get_state_event(room_id, StateEventType::RoomPowerLevels, "") - .await - .unwrap() - .is_none()); - let mut changes = StateChanges::default(); - changes.add_state_event(room_id, event, raw_event); - - store.save_changes(&changes).await.unwrap(); - assert!(store - .get_state_event(room_id, StateEventType::RoomPowerLevels, "") - .await - .unwrap() - .is_some()); + let store = get_store().await.unwrap().into_state_store(); + store.test_power_level_saving().await } #[async_test] async fn test_receipts_saving() { - let store = get_store().await.expect("creating store failed"); - - let room_id = room_id!("!test_receipts_saving:localhost"); - - let first_event_id = event_id!("$1435641916114394fHBLK:matrix.org"); - let second_event_id = event_id!("$fHBLK1435641916114394:matrix.org"); - - let first_receipt_event = 
serde_json::from_value(json!({ - first_event_id: { - "m.read": { - user_id(): { - "ts": 1436451550453u64 - } - } - } - })) - .expect("json creation failed"); - - let second_receipt_event = serde_json::from_value(json!({ - second_event_id: { - "m.read": { - user_id(): { - "ts": 1436451551453u64 - } - } - } - })) - .expect("json creation failed"); - - assert!(store - .get_user_room_receipt_event(room_id, ReceiptType::Read, user_id()) - .await - .expect("failed to read user room receipt") - .is_none()); - assert!(store - .get_event_room_receipt_events(room_id, ReceiptType::Read, &first_event_id) - .await - .expect("failed to read user room receipt for 1") - .is_empty()); - assert!(store - .get_event_room_receipt_events(room_id, ReceiptType::Read, &second_event_id) - .await - .expect("failed to read user room receipt for 2") - .is_empty()); - - let mut changes = StateChanges::default(); - changes.add_receipts(room_id, first_receipt_event); - - store.save_changes(&changes).await.expect("writing changes fauked"); - assert!(store - .get_user_room_receipt_event(room_id, ReceiptType::Read, user_id()) - .await - .expect("failed to read user room receipt after save") - .is_some()); - assert_eq!( - store - .get_event_room_receipt_events(room_id, ReceiptType::Read, &first_event_id) - .await - .expect("failed to read user room receipt for 1 after save") - .len(), - 1, - "Found a wrong number of receipts for 1 after save" - ); - assert!(store - .get_event_room_receipt_events(room_id, ReceiptType::Read, &second_event_id) - .await - .expect("failed to read user room receipt for 2 after save") - .is_empty()); - - let mut changes = StateChanges::default(); - changes.add_receipts(room_id, second_receipt_event); - - store.save_changes(&changes).await.expect("Saving works"); - assert!(store - .get_user_room_receipt_event(room_id, ReceiptType::Read, user_id()) - .await - .expect("Getting user room receipts failed") - .is_some()); - assert!(store - .get_event_room_receipt_events(room_id, 
ReceiptType::Read, &first_event_id) - .await - .expect("Getting event room receipt events for first event failed") - .is_empty()); - assert_eq!( - store - .get_event_room_receipt_events(room_id, ReceiptType::Read, &second_event_id) - .await - .expect("Getting event room receipt events for second event failed") - .len(), - 1, - "Found a wrong number of receipts for second event after save" - ); + let store = get_store().await.expect("creating store failed").into_state_store(); + store.test_receipts_saving().await; } #[async_test] async fn test_custom_storage() -> StoreResult<()> { - let key = "my_key"; - let value = &[0, 1, 2, 3]; - let store = get_store().await?; - - store.set_custom_value(key.as_bytes(), value.to_vec()).await?; - - let read = store.get_custom_value(key.as_bytes()).await?; - - assert_eq!(Some(value.as_ref()), read.as_deref()); - - Ok(()) + let store = get_store().await?.into_state_store(); + store.test_custom_storage().await } #[async_test] async fn test_persist_invited_room() -> StoreResult<()> { - let inner_store = get_store().await?; - let store = Arc::new(inner_store); - populate_store(store.clone()).await?; - - assert_eq!(store.get_stripped_room_infos().await?.len(), 1); - - Ok(()) + let store = get_store().await?.into_state_store(); + store.test_persist_invited_room().await } #[async_test] async fn test_stripped_non_stripped() -> StoreResult<()> { - let store = get_store().await.unwrap(); - let room_id = room_id!("!test_stripped_non_stripped:localhost"); - let user_id = user_id(); - - assert!(store.get_member_event(room_id, user_id).await.unwrap().is_none()); - assert_eq!(store.get_room_infos().await.unwrap().len(), 0); - assert_eq!(store.get_stripped_room_infos().await.unwrap().len(), 0); - - let mut changes = StateChanges::default(); - changes - .members - .entry(room_id.to_owned()) - .or_default() - .insert(user_id.to_owned(), membership_event()); - changes.add_room(RoomInfo::new(room_id, RoomType::Left)); - 
store.save_changes(&changes).await.unwrap(); - - let member_event = store - .get_member_event(room_id, user_id) - .await - .unwrap() - .unwrap() - .deserialize() - .unwrap(); - assert!(matches!(member_event, $crate::deserialized_responses::MemberEvent::Sync(_))); - assert_eq!(store.get_room_infos().await.unwrap().len(), 1); - assert_eq!(store.get_stripped_room_infos().await.unwrap().len(), 0); - - let members = store.get_user_ids(room_id).await.unwrap(); - assert_eq!(members, vec![user_id.to_owned()]); - - let mut changes = StateChanges::default(); - changes.add_stripped_member(room_id, user_id, custom_stripped_membership_event(user_id)); - changes.add_stripped_room(RoomInfo::new(room_id, RoomType::Invited)); - store.save_changes(&changes).await.unwrap(); - - let member_event = store - .get_member_event(room_id, user_id) - .await - .unwrap() - .unwrap() - .deserialize() - .unwrap(); - assert!( - matches!(member_event, $crate::deserialized_responses::MemberEvent::Stripped(_)) - ); - assert_eq!(store.get_room_infos().await.unwrap().len(), 0); - assert_eq!(store.get_stripped_room_infos().await.unwrap().len(), 1); - - let members = store.get_user_ids(room_id).await.unwrap(); - assert_eq!(members, vec![user_id.to_owned()]); - - Ok(()) + let store = get_store().await.unwrap().into_state_store(); + store.test_stripped_non_stripped().await } #[async_test] async fn test_room_removal() -> StoreResult<()> { - let room_id = room_id(); - let user_id = user_id(); - let inner_store = get_store().await?; - let stripped_room_id = stripped_room_id(); + let store = get_store().await?.into_state_store(); + store.test_room_removal().await + } + }; +} - let store = Arc::new(inner_store); - populate_store(store.clone()).await?; +fn user_id() -> &'static UserId { + user_id!("@example:localhost") +} - store.remove_room(room_id).await?; +fn invited_user_id() -> &'static UserId { + user_id!("@invited:localhost") +} - assert!(store.get_room_infos().await?.is_empty(), "room is still there"); - 
assert_eq!(store.get_stripped_room_infos().await?.len(), 1); +fn room_id() -> &'static RoomId { + room_id!("!test:localhost") +} - assert!(store - .get_state_event(room_id, StateEventType::RoomName, "") - .await? - .is_none()); - assert!( - store.get_state_events(room_id, StateEventType::RoomTopic).await?.is_empty(), - "still state events found" - ); - assert!(store.get_profile(room_id, user_id).await?.is_none()); - assert!(store.get_member_event(room_id, user_id).await?.is_none()); - assert!(store.get_user_ids(room_id).await?.is_empty(), "still user ids found"); - assert!( - store.get_invited_user_ids(room_id).await?.is_empty(), - "still invited user ids found" - ); - assert!( - store.get_joined_user_ids(room_id).await?.is_empty(), - "still joined users found" - ); - assert!( - store.get_users_with_display_name(room_id, "example").await?.is_empty(), - "still display names found" - ); - assert!(store - .get_room_account_data_event(room_id, RoomAccountDataEventType::Tag) - .await? - .is_none()); - assert!(store - .get_user_room_receipt_event(room_id, ReceiptType::Read, user_id) - .await? - .is_none()); - assert!( - store - .get_event_room_receipt_events( - room_id, - ReceiptType::Read, - first_receipt_event_id() - ) - .await? 
- .is_empty(), - "still event recepts in the store" - ); - - store.remove_room(stripped_room_id).await?; - - assert!(store.get_room_infos().await?.is_empty(), "still room info found"); - assert!( - store.get_stripped_room_infos().await?.is_empty(), - "still stripped room info found" - ); - Ok(()) - } - }; +fn stripped_room_id() -> &'static RoomId { + room_id!("!stripped:localhost") +} + +fn first_receipt_event_id() -> &'static EventId { + event_id!("$example") +} + +fn power_level_event() -> Raw { + let content = RoomPowerLevelsEventContent::default(); + + let event = json!({ + "event_id": "$h29iv0s8:example.com", + "content": content, + "sender": user_id(), + "type": "m.room.power_levels", + "origin_server_ts": 0u64, + "state_key": "", + }); + + serde_json::from_value(event).unwrap() +} + +fn stripped_membership_event() -> Raw { + custom_stripped_membership_event(user_id()) +} + +fn custom_stripped_membership_event(user_id: &UserId) -> Raw { + let ev_json = json!({ + "type": "m.room.member", + "content": RoomMemberEventContent::new(MembershipState::Join), + "sender": user_id, + "state_key": user_id, + }); + + Raw::new(&ev_json).unwrap().cast() +} + +fn membership_event() -> Raw { + custom_membership_event(user_id(), event_id!("$h29iv0s8:example.com").to_owned()) +} + +fn custom_membership_event(user_id: &UserId, event_id: OwnedEventId) -> Raw { + let ev_json = json!({ + "type": "m.room.member", + "content": RoomMemberEventContent::new(MembershipState::Join), + "event_id": event_id, + "origin_server_ts": 198, + "sender": user_id, + "state_key": user_id, + }); + + Raw::new(&ev_json).unwrap().cast() } diff --git a/crates/matrix-sdk-base/src/store/memory_store.rs b/crates/matrix-sdk-base/src/store/memory_store.rs index 2aa2ee2d6dc..81b35e886c6 100644 --- a/crates/matrix-sdk-base/src/store/memory_store.rs +++ b/crates/matrix-sdk-base/src/store/memory_store.rs @@ -25,7 +25,7 @@ use ruma::{ canonical_json::redact, events::{ presence::PresenceEvent, - receipt::{Receipt, 
ReceiptType}, + receipt::{Receipt, ReceiptThread, ReceiptType}, room::member::{MembershipState, StrippedRoomMemberEvent, SyncRoomMemberEvent}, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, @@ -37,7 +37,10 @@ use ruma::{ use tracing::{debug, info, warn}; use super::{Result, RoomInfo, StateChanges, StateStore, StoreError}; -use crate::{deserialized_responses::RawMemberEvent, media::MediaRequest, MinimalRoomMemberEvent}; +use crate::{ + deserialized_responses::RawMemberEvent, media::MediaRequest, MinimalRoomMemberEvent, + StateStoreDataKey, StateStoreDataValue, +}; /// In-Memory, non-persistent implementation of the `StateStore` /// @@ -45,6 +48,7 @@ use crate::{deserialized_responses::RawMemberEvent, media::MediaRequest, Minimal #[allow(clippy::type_complexity)] #[derive(Debug, Clone)] pub struct MemoryStore { + user_avatar_url: Arc>, sync_token: Arc>>, filters: Arc>, account_data: Arc>>, @@ -66,10 +70,17 @@ pub struct MemoryStore { stripped_joined_user_ids: Arc>>, stripped_invited_user_ids: Arc>>, presence: Arc>>, - room_user_receipts: - Arc>>>, + room_user_receipts: Arc< + DashMap< + OwnedRoomId, + DashMap<(String, Option), DashMap>, + >, + >, room_event_receipts: Arc< - DashMap>>>, + DashMap< + OwnedRoomId, + DashMap<(String, Option), DashMap>>, + >, >, custom: Arc, Vec>>, } @@ -85,6 +96,7 @@ impl MemoryStore { /// Create a new empty MemoryStore pub fn new() -> Self { Self { + user_avatar_url: Default::default(), sync_token: Default::default(), filters: Default::default(), account_data: Default::default(), @@ -112,18 +124,61 @@ impl MemoryStore { } } - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { - self.filters.insert(filter_name.to_owned(), filter_id.to_owned()); + async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result> { + match key { + StateStoreDataKey::SyncToken => { + 
Ok(self.sync_token.read().unwrap().clone().map(StateStoreDataValue::SyncToken)) + } + StateStoreDataKey::Filter(filter_name) => Ok(self + .filters + .get(filter_name) + .map(|f| StateStoreDataValue::Filter(f.value().clone()))), + StateStoreDataKey::UserAvatarUrl(user_id) => Ok(self + .user_avatar_url + .get(user_id.as_str()) + .map(|u| StateStoreDataValue::UserAvatarUrl(u.value().clone()))), + } + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<()> { + match key { + StateStoreDataKey::SyncToken => { + *self.sync_token.write().unwrap() = + Some(value.into_sync_token().expect("Session data not a sync token")) + } + StateStoreDataKey::Filter(filter_name) => { + self.filters.insert( + filter_name.to_owned(), + value.into_filter().expect("Session data not a filter"), + ); + } + StateStoreDataKey::UserAvatarUrl(user_id) => { + self.filters.insert( + user_id.to_string(), + value.into_user_avatar_url().expect("Session data not a user avatar url"), + ); + } + } Ok(()) } - async fn get_filter(&self, filter_name: &str) -> Result> { - Ok(self.filters.get(filter_name).map(|f| f.to_string())) - } + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { + match key { + StateStoreDataKey::SyncToken => *self.sync_token.write().unwrap() = None, + StateStoreDataKey::Filter(filter_name) => { + self.filters.remove(filter_name); + } + StateStoreDataKey::UserAvatarUrl(user_id) => { + self.filters.remove(user_id.as_str()); + } + } - async fn get_sync_token(&self) -> Result> { - Ok(self.sync_token.read().unwrap().clone()) + Ok(()) } async fn save_changes(&self, changes: &StateChanges) -> Result<()> { @@ -317,18 +372,21 @@ impl MemoryStore { for (event_id, receipts) in &content.0 { for (receipt_type, receipts) in receipts { for (user_id, receipt) in receipts { + let thread = receipt.thread.as_str().map(ToOwned::to_owned); // Add the receipt to the room user receipts if let Some((old_event, _)) = self 
.room_user_receipts .entry(room.clone()) .or_default() - .entry(receipt_type.to_string()) + .entry((receipt_type.to_string(), thread.clone())) .or_default() .insert(user_id.clone(), (event_id.clone(), receipt.clone())) { // Remove the old receipt from the room event receipts if let Some(receipt_map) = self.room_event_receipts.get(room) { - if let Some(event_map) = receipt_map.get(receipt_type.as_ref()) { + if let Some(event_map) = + receipt_map.get(&(receipt_type.to_string(), thread.clone())) + { if let Some(user_map) = event_map.get_mut(&old_event) { user_map.remove(user_id); } @@ -340,7 +398,7 @@ impl MemoryStore { self.room_event_receipts .entry(room.clone()) .or_default() - .entry(receipt_type.to_string()) + .entry((receipt_type.to_string(), thread)) .or_default() .entry(event_id.clone()) .or_default() @@ -509,10 +567,12 @@ impl MemoryStore { &self, room_id: &RoomId, receipt_type: ReceiptType, + thread: ReceiptThread, user_id: &UserId, ) -> Result> { Ok(self.room_user_receipts.get(room_id).and_then(|m| { - m.get(receipt_type.as_ref()).and_then(|m| m.get(user_id).map(|r| r.clone())) + m.get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned))) + .and_then(|m| m.get(user_id).map(|r| r.clone())) })) } @@ -520,16 +580,20 @@ impl MemoryStore { &self, room_id: &RoomId, receipt_type: ReceiptType, + thread: ReceiptThread, event_id: &EventId, ) -> Result> { Ok(self .room_event_receipts .get(room_id) .and_then(|m| { - m.get(receipt_type.as_ref()).and_then(|m| { - m.get(event_id) - .map(|m| m.iter().map(|r| (r.key().clone(), r.value().clone())).collect()) - }) + m.get(&(receipt_type.to_string(), thread.as_str().map(ToOwned::to_owned))).and_then( + |m| { + m.get(event_id).map(|m| { + m.iter().map(|r| (r.key().clone(), r.value().clone())).collect() + }) + }, + ) }) .unwrap_or_default()) } @@ -542,6 +606,10 @@ impl MemoryStore { Ok(self.custom.insert(key.to_vec(), value)) } + async fn remove_custom_value(&self, key: &[u8]) -> Result>> { + 
Ok(self.custom.remove(key).map(|entry| entry.1)) + } + // The in-memory store doesn't cache media async fn add_media_content(&self, _request: &MediaRequest, _data: Vec) -> Result<()> { Ok(()) @@ -578,20 +646,26 @@ impl MemoryStore { #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] #[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl StateStore for MemoryStore { - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { - self.save_filter(filter_name, filter_id).await + type Error = StoreError; + + async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result> { + self.get_kv_data(key).await } - async fn save_changes(&self, changes: &StateChanges) -> Result<()> { - self.save_changes(changes).await + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<()> { + self.set_kv_data(key, value).await } - async fn get_filter(&self, filter_id: &str) -> Result> { - self.get_filter(filter_id).await + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { + self.remove_kv_data(key).await } - async fn get_sync_token(&self) -> Result> { - self.get_sync_token().await + async fn save_changes(&self, changes: &StateChanges) -> Result<()> { + self.save_changes(changes).await } async fn get_presence_event(&self, user_id: &UserId) -> Result>> { @@ -690,18 +764,20 @@ impl StateStore for MemoryStore { &self, room_id: &RoomId, receipt_type: ReceiptType, + thread: ReceiptThread, user_id: &UserId, ) -> Result> { - self.get_user_room_receipt_event(room_id, receipt_type, user_id).await + self.get_user_room_receipt_event(room_id, receipt_type, thread, user_id).await } async fn get_event_room_receipt_events( &self, room_id: &RoomId, receipt_type: ReceiptType, + thread: ReceiptThread, event_id: &EventId, ) -> Result> { - self.get_event_room_receipt_events(room_id, receipt_type, event_id).await + self.get_event_room_receipt_events(room_id, receipt_type, thread, event_id).await } async fn 
get_custom_value(&self, key: &[u8]) -> Result>> { @@ -712,6 +788,10 @@ impl StateStore for MemoryStore { self.set_custom_value(key, value).await } + async fn remove_custom_value(&self, key: &[u8]) -> Result>> { + self.remove_custom_value(key).await + } + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { self.add_media_content(request, data).await } diff --git a/crates/matrix-sdk-base/src/store/mod.rs b/crates/matrix-sdk-base/src/store/mod.rs index 490d5c9dd5d..35a0cdb4114 100644 --- a/crates/matrix-sdk-base/src/store/mod.rs +++ b/crates/matrix-sdk-base/src/store/mod.rs @@ -21,7 +21,6 @@ //! store. use std::{ - borrow::Borrow, collections::{BTreeMap, BTreeSet}, fmt, ops::Deref, @@ -31,52 +30,55 @@ use std::{ sync::Arc, }; -use futures_signals::signal::{Mutable, ReadOnlyMutable}; +use eyeball::{shared::Observable as SharedObservable, Subscriber}; use once_cell::sync::OnceCell; #[cfg(any(test, feature = "testing"))] #[macro_use] pub mod integration_tests; +mod traits; -use async_trait::async_trait; use dashmap::DashMap; -use matrix_sdk_common::{locks::RwLock, AsyncTraitDeps}; +use matrix_sdk_common::locks::RwLock; #[cfg(feature = "e2e-encryption")] -use matrix_sdk_crypto::store::{CryptoStore, IntoCryptoStore}; +use matrix_sdk_crypto::store::{DynCryptoStore, IntoCryptoStore}; pub use matrix_sdk_store_encryption::Error as StoreEncryptionError; use ruma::{ api::client::push::get_notifications::v3::Notification, events::{ presence::PresenceEvent, - receipt::{Receipt, ReceiptEventContent, ReceiptType}, + receipt::ReceiptEventContent, room::{ member::{StrippedRoomMemberEvent, SyncRoomMemberEvent}, redaction::OriginalSyncRoomRedactionEvent, }, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnyStrippedStateEvent, - AnySyncStateEvent, EmptyStateKey, GlobalAccountDataEvent, GlobalAccountDataEventContent, - GlobalAccountDataEventType, RedactContent, RedactedStateEventContent, RoomAccountDataEvent, - RoomAccountDataEventContent, 
RoomAccountDataEventType, StateEventType, StaticEventContent, - StaticStateEventContent, SyncStateEvent, + AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, }, serde::Raw, - EventId, MxcUri, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId, + EventId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId, }; /// BoxStream of owned Types pub type BoxStream = Pin + Send>>; use crate::{ - deserialized_responses::RawMemberEvent, - media::MediaRequest, - rooms::{RoomInfo, RoomType}, + rooms::{RoomInfo, RoomState}, MinimalRoomMemberEvent, Room, Session, SessionMeta, SessionTokens, }; pub(crate) mod ambiguity_map; mod memory_store; -pub use self::memory_store::MemoryStore; +#[cfg(any(test, feature = "testing"))] +pub use self::integration_tests::StateStoreIntegrationTests; +pub use self::{ + memory_store::MemoryStore, + traits::{ + DynStateStore, IntoStateStore, StateStore, StateStoreDataKey, StateStoreDataValue, + StateStoreExt, + }, +}; /// State store specific error type. #[derive(Debug, thiserror::Error)] @@ -135,371 +137,15 @@ impl StoreError { /// A `StateStore` specific result type. pub type Result = std::result::Result; -/// An abstract state store trait that can be used to implement different stores -/// for the SDK. -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -pub trait StateStore: AsyncTraitDeps { - /// Save the given filter id under the given name. - /// - /// # Arguments - /// - /// * `filter_name` - The name that should be used to store the filter id. - /// - /// * `filter_id` - The filter id that should be stored in the state store. - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()>; - - /// Save the set of state changes in the store. - async fn save_changes(&self, changes: &StateChanges) -> Result<()>; - - /// Get the filter id that was stored under the given filter name. 
- /// - /// # Arguments - /// - /// * `filter_name` - The name that was used to store the filter id. - async fn get_filter(&self, filter_name: &str) -> Result>; - - /// Get the last stored sync token. - async fn get_sync_token(&self) -> Result>; - - /// Get the stored presence event for the given user. - /// - /// # Arguments - /// - /// * `user_id` - The id of the user for which we wish to fetch the presence - /// event for. - async fn get_presence_event(&self, user_id: &UserId) -> Result>>; - - /// Get a state event out of the state store. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room the state event was received for. - /// - /// * `event_type` - The event type of the state event. - async fn get_state_event( - &self, - room_id: &RoomId, - event_type: StateEventType, - state_key: &str, - ) -> Result>>; - - /// Get a list of state events for a given room and `StateEventType`. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room to find events for. - /// - /// * `event_type` - The event type. - async fn get_state_events( - &self, - room_id: &RoomId, - event_type: StateEventType, - ) -> Result>>; - - /// Get the current profile for the given user in the given room. - /// - /// # Arguments - /// - /// * `room_id` - The room id the profile is used in. - /// - /// * `user_id` - The id of the user the profile belongs to. - async fn get_profile( - &self, - room_id: &RoomId, - user_id: &UserId, - ) -> Result>; - - /// Get the `MemberEvent` for the given state key in the given room id. - /// - /// # Arguments - /// - /// * `room_id` - The room id the member event belongs to. - /// - /// * `state_key` - The user id that the member event defines the state for. - async fn get_member_event( - &self, - room_id: &RoomId, - state_key: &UserId, - ) -> Result>; - - /// Get all the user ids of members for a given room, for stripped and - /// regular rooms alike. 
- async fn get_user_ids(&self, room_id: &RoomId) -> Result>; - - /// Get all the user ids of members that are in the invited state for a - /// given room, for stripped and regular rooms alike. - async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result>; - - /// Get all the user ids of members that are in the joined state for a - /// given room, for stripped and regular rooms alike. - async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result>; - - /// Get all the pure `RoomInfo`s the store knows about. - async fn get_room_infos(&self) -> Result>; - - /// Get all the pure `RoomInfo`s the store knows about. - async fn get_stripped_room_infos(&self) -> Result>; - - /// Get all the users that use the given display name in the given room. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room for which the display name users should - /// be fetched for. - /// - /// * `display_name` - The display name that the users use. - async fn get_users_with_display_name( - &self, - room_id: &RoomId, - display_name: &str, - ) -> Result>; - - /// Get an event out of the account data store. - /// - /// # Arguments - /// - /// * `event_type` - The event type of the account data event. - async fn get_account_data_event( - &self, - event_type: GlobalAccountDataEventType, - ) -> Result>>; - - /// Get an event out of the room account data store. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room for which the room account data event - /// should - /// be fetched. - /// - /// * `event_type` - The event type of the room account data event. - async fn get_room_account_data_event( - &self, - room_id: &RoomId, - event_type: RoomAccountDataEventType, - ) -> Result>>; - - /// Get an event out of the user room receipt store. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room for which the receipt should be - /// fetched. - /// - /// * `receipt_type` - The type of the receipt. 
- /// - /// * `user_id` - The id of the user for who the receipt should be fetched. - async fn get_user_room_receipt_event( - &self, - room_id: &RoomId, - receipt_type: ReceiptType, - user_id: &UserId, - ) -> Result>; - - /// Get events out of the event room receipt store. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room for which the receipts should be - /// fetched. - /// - /// * `receipt_type` - The type of the receipts. - /// - /// * `event_id` - The id of the event for which the receipts should be - /// fetched. - async fn get_event_room_receipt_events( - &self, - room_id: &RoomId, - receipt_type: ReceiptType, - event_id: &EventId, - ) -> Result>; - - /// Get arbitrary data from the custom store - /// - /// # Arguments - /// - /// * `key` - The key to fetch data for - async fn get_custom_value(&self, key: &[u8]) -> Result>>; - - /// Put arbitrary data into the custom store - /// - /// # Arguments - /// - /// * `key` - The key to insert data into - /// - /// * `value` - The value to insert - async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>>; - - /// Add a media file's content in the media store. - /// - /// # Arguments - /// - /// * `request` - The `MediaRequest` of the file. - /// - /// * `content` - The content of the file. - async fn add_media_content(&self, request: &MediaRequest, content: Vec) -> Result<()>; - - /// Get a media file's content out of the media store. - /// - /// # Arguments - /// - /// * `request` - The `MediaRequest` of the file. - async fn get_media_content(&self, request: &MediaRequest) -> Result>>; - - /// Removes a media file's content from the media store. - /// - /// # Arguments - /// - /// * `request` - The `MediaRequest` of the file. - async fn remove_media_content(&self, request: &MediaRequest) -> Result<()>; - - /// Removes all the media files' content associated to an `MxcUri` from the - /// media store. - /// - /// # Arguments - /// - /// * `uri` - The `MxcUri` of the media files. 
- async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()>; - - /// Removes a room and all elements associated from the state store. - /// - /// # Arguments - /// - /// * `room_id` - The `RoomId` of the room to delete. - async fn remove_room(&self, room_id: &RoomId) -> Result<()>; -} - -/// Convenience functionality for state stores. -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -pub trait StateStoreExt: StateStore { - /// Get a specific state event of statically-known type. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room the state event was received for. - async fn get_state_event_static( - &self, - room_id: &RoomId, - ) -> Result>>> - where - C: StaticEventContent + StaticStateEventContent + RedactContent, - C::Redacted: RedactedStateEventContent, - { - Ok(self.get_state_event(room_id, C::TYPE.into(), "").await?.map(Raw::cast)) - } - - /// Get a specific state event of statically-known type. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room the state event was received for. - async fn get_state_event_static_for_key( - &self, - room_id: &RoomId, - state_key: &K, - ) -> Result>>> - where - C: StaticEventContent + StaticStateEventContent + RedactContent, - C::StateKey: Borrow, - C::Redacted: RedactedStateEventContent, - K: AsRef + ?Sized + Sync, - { - Ok(self.get_state_event(room_id, C::TYPE.into(), state_key.as_ref()).await?.map(Raw::cast)) - } - - /// Get a list of state events of a statically-known type for a given room. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room to find events for. 
- async fn get_state_events_static( - &self, - room_id: &RoomId, - ) -> Result>>> - where - C: StaticEventContent + StaticStateEventContent + RedactContent, - C::Redacted: RedactedStateEventContent, - { - // FIXME: Could be more efficient, if we had streaming store accessor functions - Ok(self - .get_state_events(room_id, C::TYPE.into()) - .await? - .into_iter() - .map(Raw::cast) - .collect()) - } - - /// Get an event of a statically-known type from the account data store. - async fn get_account_data_event_static( - &self, - ) -> Result>>> - where - C: StaticEventContent + GlobalAccountDataEventContent, - { - Ok(self.get_account_data_event(C::TYPE.into()).await?.map(Raw::cast)) - } - - /// Get an event of a statically-known type from the room account data - /// store. - /// - /// # Arguments - /// - /// * `room_id` - The id of the room for which the room account data event - /// should be fetched. - async fn get_room_account_data_event_static( - &self, - room_id: &RoomId, - ) -> Result>>> - where - C: StaticEventContent + RoomAccountDataEventContent, - { - Ok(self.get_room_account_data_event(room_id, C::TYPE.into()).await?.map(Raw::cast)) - } -} - -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -impl StateStoreExt for T {} - -/// A type that can be type-erased into `Arc`. -/// -/// This trait is not meant to be implemented directly outside -/// `matrix-sdk-crypto`, but it is automatically implemented for everything that -/// implements `StateStore`. -pub trait IntoStateStore { - #[doc(hidden)] - fn into_state_store(self) -> Arc; -} - -impl IntoStateStore for T -where - T: StateStore + Sized + 'static, -{ - fn into_state_store(self) -> Arc { - Arc::new(self) - } -} - -impl IntoStateStore for Arc -where - T: StateStore + 'static, -{ - fn into_state_store(self) -> Arc { - self - } -} - /// A state store wrapper for the SDK. 
/// /// This adds additional higher level store functionality on top of a /// `StateStore` implementation. #[derive(Clone)] pub(crate) struct Store { - pub(super) inner: Arc, + pub(super) inner: Arc, session_meta: Arc>, - pub(super) session_tokens: Mutable>, + pub(super) session_tokens: SharedObservable>, /// The current sync token that should be used for the next sync call. pub(super) sync_token: Arc>>, rooms: Arc>, @@ -514,7 +160,7 @@ pub(crate) struct Store { impl Store { /// Create a new store, wrapping the given `StateStore` - pub fn new(inner: Arc) -> Self { + pub fn new(inner: Arc) -> Self { Self { inner, session_meta: Default::default(), @@ -548,7 +194,8 @@ impl Store { self.stripped_rooms.insert(room.room_id().to_owned(), room); } - let token = self.get_sync_token().await?; + let token = + self.get_kv_data(StateStoreDataKey::SyncToken).await?.and_then(|s| s.into_sync_token()); *self.sync_token.write().await = token; self.session_meta.set(session_meta).expect("Session Meta was already set"); @@ -561,10 +208,10 @@ impl Store { self.session_meta.get() } - /// The current [`SessionTokens`] containing our access token and optional - /// refresh token. - pub fn session_tokens(&self) -> ReadOnlyMutable> { - self.session_tokens.read_only() + /// The [`SessionTokens`] containing our access token and optional refresh + /// token. + pub fn session_tokens(&self) -> Subscriber> { + self.session_tokens.subscribe() } /// Set the current [`SessionTokens`]. @@ -576,7 +223,7 @@ impl Store { /// token and optional refresh token. 
pub fn session(&self) -> Option { let meta = self.session_meta.get()?; - let tokens = self.session_tokens.get_cloned()?; + let tokens = self.session_tokens().get()?; Some(Session::from_parts(meta.to_owned(), tokens)) } @@ -589,10 +236,10 @@ impl Store { pub fn get_room(&self, room_id: &RoomId) -> Option { self.rooms .get(room_id) - .and_then(|r| match r.room_type() { - RoomType::Joined => Some(r.clone()), - RoomType::Left => Some(r.clone()), - RoomType::Invited => self.get_stripped_room(room_id), + .and_then(|r| match r.state() { + RoomState::Joined => Some(r.clone()), + RoomState::Left => Some(r.clone()), + RoomState::Invited => self.get_stripped_room(room_id), }) .or_else(|| self.get_stripped_room(room_id)) } @@ -618,14 +265,14 @@ impl Store { self.stripped_rooms .entry(room_id.to_owned()) - .or_insert_with(|| Room::new(user_id, self.inner.clone(), room_id, RoomType::Invited)) + .or_insert_with(|| Room::new(user_id, self.inner.clone(), room_id, RoomState::Invited)) .clone() } /// Lookup the Room for the given RoomId, or create one, if it didn't exist /// yet in the store - pub async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomType) -> Room { - if room_type == RoomType::Invited { + pub async fn get_or_create_room(&self, room_id: &RoomId, room_type: RoomState) -> Room { + if room_type == RoomState::Invited { return self.get_or_create_stripped_room(room_id).await; } @@ -642,6 +289,7 @@ impl Store { } } +#[cfg(not(tarpaulin_include))] impl fmt::Debug for Store { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Store") @@ -655,7 +303,7 @@ impl fmt::Debug for Store { } impl Deref for Store { - type Target = dyn StateStore; + type Target = DynStateStore; fn deref(&self) -> &Self::Target { self.inner.deref() @@ -832,8 +480,8 @@ impl StateChanges { #[derive(Clone)] pub struct StoreConfig { #[cfg(feature = "e2e-encryption")] - pub(crate) crypto_store: Arc, - pub(crate) state_store: Arc, + pub(crate) crypto_store: Arc, + 
pub(crate) state_store: Arc, } #[cfg(not(tarpaulin_include))] @@ -849,7 +497,7 @@ impl StoreConfig { pub fn new() -> Self { Self { #[cfg(feature = "e2e-encryption")] - crypto_store: Arc::new(matrix_sdk_crypto::store::MemoryStore::new()), + crypto_store: matrix_sdk_crypto::store::MemoryStore::new().into_crypto_store(), state_store: Arc::new(MemoryStore::new()), } } diff --git a/crates/matrix-sdk-base/src/store/traits.rs b/crates/matrix-sdk-base/src/store/traits.rs new file mode 100644 index 00000000000..972dc0e2b71 --- /dev/null +++ b/crates/matrix-sdk-base/src/store/traits.rs @@ -0,0 +1,700 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::{borrow::Borrow, collections::BTreeSet, fmt, sync::Arc}; + +use async_trait::async_trait; +use matrix_sdk_common::AsyncTraitDeps; +use ruma::{ + events::{ + presence::PresenceEvent, + receipt::{Receipt, ReceiptThread, ReceiptType}, + AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, EmptyStateKey, + GlobalAccountDataEvent, GlobalAccountDataEventContent, GlobalAccountDataEventType, + RedactContent, RedactedStateEventContent, RoomAccountDataEvent, + RoomAccountDataEventContent, RoomAccountDataEventType, StateEventType, StaticEventContent, + StaticStateEventContent, SyncStateEvent, + }, + serde::Raw, + EventId, MxcUri, OwnedEventId, OwnedUserId, RoomId, UserId, +}; + +use super::{StateChanges, StoreError}; +use crate::{ + deserialized_responses::RawMemberEvent, media::MediaRequest, MinimalRoomMemberEvent, RoomInfo, +}; + +/// An abstract state store trait that can be used to implement different stores +/// for the SDK. +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +pub trait StateStore: AsyncTraitDeps { + /// The error type used by this state store. + type Error: fmt::Debug + Into + From; + + /// Get key-value data from the store. + /// + /// # Arguments + /// + /// * `key` - The key to fetch data for. + async fn get_kv_data( + &self, + key: StateStoreDataKey<'_>, + ) -> Result, Self::Error>; + + /// Put key-value data into the store. + /// + /// # Arguments + /// + /// * `key` - The key to identify the data in the store. + /// + /// * `value` - The data to insert. + /// + /// Panics if the key and value variants do not match. + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<(), Self::Error>; + + /// Remove key-value data from the store. + /// + /// # Arguments + /// + /// * `key` - The key to remove the data for. 
+ async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error>; + + /// Save the set of state changes in the store. + async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error>; + + /// Get the stored presence event for the given user. + /// + /// # Arguments + /// + /// * `user_id` - The id of the user for which we wish to fetch the presence + /// event for. + async fn get_presence_event( + &self, + user_id: &UserId, + ) -> Result>, Self::Error>; + + /// Get a state event out of the state store. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room the state event was received for. + /// + /// * `event_type` - The event type of the state event. + async fn get_state_event( + &self, + room_id: &RoomId, + event_type: StateEventType, + state_key: &str, + ) -> Result>, Self::Error>; + + /// Get a list of state events for a given room and `StateEventType`. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room to find events for. + /// + /// * `event_type` - The event type. + async fn get_state_events( + &self, + room_id: &RoomId, + event_type: StateEventType, + ) -> Result>, Self::Error>; + + /// Get the current profile for the given user in the given room. + /// + /// # Arguments + /// + /// * `room_id` - The room id the profile is used in. + /// + /// * `user_id` - The id of the user the profile belongs to. + async fn get_profile( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result, Self::Error>; + + /// Get the `MemberEvent` for the given state key in the given room id. + /// + /// # Arguments + /// + /// * `room_id` - The room id the member event belongs to. + /// + /// * `state_key` - The user id that the member event defines the state for. + async fn get_member_event( + &self, + room_id: &RoomId, + state_key: &UserId, + ) -> Result, Self::Error>; + + /// Get all the user ids of members for a given room, for stripped and + /// regular rooms alike. 
+ async fn get_user_ids(&self, room_id: &RoomId) -> Result, Self::Error>; + + /// Get all the user ids of members that are in the invited state for a + /// given room, for stripped and regular rooms alike. + async fn get_invited_user_ids(&self, room_id: &RoomId) + -> Result, Self::Error>; + + /// Get all the user ids of members that are in the joined state for a + /// given room, for stripped and regular rooms alike. + async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result, Self::Error>; + + /// Get all the pure `RoomInfo`s the store knows about. + async fn get_room_infos(&self) -> Result, Self::Error>; + + /// Get all the pure `RoomInfo`s the store knows about. + async fn get_stripped_room_infos(&self) -> Result, Self::Error>; + + /// Get all the users that use the given display name in the given room. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which the display name users should + /// be fetched for. + /// + /// * `display_name` - The display name that the users use. + async fn get_users_with_display_name( + &self, + room_id: &RoomId, + display_name: &str, + ) -> Result, Self::Error>; + + /// Get an event out of the account data store. + /// + /// # Arguments + /// + /// * `event_type` - The event type of the account data event. + async fn get_account_data_event( + &self, + event_type: GlobalAccountDataEventType, + ) -> Result>, Self::Error>; + + /// Get an event out of the room account data store. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which the room account data event + /// should + /// be fetched. + /// + /// * `event_type` - The event type of the room account data event. + async fn get_room_account_data_event( + &self, + room_id: &RoomId, + event_type: RoomAccountDataEventType, + ) -> Result>, Self::Error>; + + /// Get an event out of the user room receipt store. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which the receipt should be + /// fetched. 
+ /// + /// * `receipt_type` - The type of the receipt. + /// + /// * `thread` - The thread containing this receipt. + /// + /// * `user_id` - The id of the user for who the receipt should be fetched. + async fn get_user_room_receipt_event( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + thread: ReceiptThread, + user_id: &UserId, + ) -> Result, Self::Error>; + + /// Get events out of the event room receipt store. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which the receipts should be + /// fetched. + /// + /// * `receipt_type` - The type of the receipts. + /// + /// * `thread` - The thread containing this receipt. + /// + /// * `event_id` - The id of the event for which the receipts should be + /// fetched. + async fn get_event_room_receipt_events( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + thread: ReceiptThread, + event_id: &EventId, + ) -> Result, Self::Error>; + + /// Get arbitrary data from the custom store + /// + /// # Arguments + /// + /// * `key` - The key to fetch data for + async fn get_custom_value(&self, key: &[u8]) -> Result>, Self::Error>; + + /// Put arbitrary data into the custom store + /// + /// # Arguments + /// + /// * `key` - The key to insert data into + /// + /// * `value` - The value to insert + async fn set_custom_value( + &self, + key: &[u8], + value: Vec, + ) -> Result>, Self::Error>; + + /// Remove arbitrary data from the custom store and return it if existed + /// + /// # Arguments + /// + /// * `key` - The key to remove data from + async fn remove_custom_value(&self, key: &[u8]) -> Result>, Self::Error>; + + /// Add a media file's content in the media store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + /// + /// * `content` - The content of the file. + async fn add_media_content( + &self, + request: &MediaRequest, + content: Vec, + ) -> Result<(), Self::Error>; + + /// Get a media file's content out of the media store. 
+ /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + async fn get_media_content( + &self, + request: &MediaRequest, + ) -> Result>, Self::Error>; + + /// Removes a media file's content from the media store. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the file. + async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error>; + + /// Removes all the media files' content associated to an `MxcUri` from the + /// media store. + /// + /// # Arguments + /// + /// * `uri` - The `MxcUri` of the media files. + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error>; + + /// Removes a room and all elements associated from the state store. + /// + /// # Arguments + /// + /// * `room_id` - The `RoomId` of the room to delete. + async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error>; +} + +#[repr(transparent)] +struct EraseStateStoreError(T); + +impl fmt::Debug for EraseStateStoreError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl StateStore for EraseStateStoreError { + type Error = StoreError; + + async fn get_kv_data( + &self, + key: StateStoreDataKey<'_>, + ) -> Result, Self::Error> { + self.0.get_kv_data(key).await.map_err(Into::into) + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<(), Self::Error> { + self.0.set_kv_data(key, value).await.map_err(Into::into) + } + + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> { + self.0.remove_kv_data(key).await.map_err(Into::into) + } + + async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> { + self.0.save_changes(changes).await.map_err(Into::into) + } + + async fn get_presence_event( + &self, + user_id: &UserId, + ) -> 
Result>, Self::Error> { + self.0.get_presence_event(user_id).await.map_err(Into::into) + } + + async fn get_state_event( + &self, + room_id: &RoomId, + event_type: StateEventType, + state_key: &str, + ) -> Result>, Self::Error> { + self.0.get_state_event(room_id, event_type, state_key).await.map_err(Into::into) + } + + async fn get_state_events( + &self, + room_id: &RoomId, + event_type: StateEventType, + ) -> Result>, Self::Error> { + self.0.get_state_events(room_id, event_type).await.map_err(Into::into) + } + + async fn get_profile( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result, Self::Error> { + self.0.get_profile(room_id, user_id).await.map_err(Into::into) + } + + async fn get_member_event( + &self, + room_id: &RoomId, + state_key: &UserId, + ) -> Result, Self::Error> { + self.0.get_member_event(room_id, state_key).await.map_err(Into::into) + } + + async fn get_user_ids(&self, room_id: &RoomId) -> Result, Self::Error> { + self.0.get_user_ids(room_id).await.map_err(Into::into) + } + + async fn get_invited_user_ids( + &self, + room_id: &RoomId, + ) -> Result, Self::Error> { + self.0.get_invited_user_ids(room_id).await.map_err(Into::into) + } + + async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result, Self::Error> { + self.0.get_joined_user_ids(room_id).await.map_err(Into::into) + } + + async fn get_room_infos(&self) -> Result, Self::Error> { + self.0.get_room_infos().await.map_err(Into::into) + } + + async fn get_stripped_room_infos(&self) -> Result, Self::Error> { + self.0.get_stripped_room_infos().await.map_err(Into::into) + } + + async fn get_users_with_display_name( + &self, + room_id: &RoomId, + display_name: &str, + ) -> Result, Self::Error> { + self.0.get_users_with_display_name(room_id, display_name).await.map_err(Into::into) + } + + async fn get_account_data_event( + &self, + event_type: GlobalAccountDataEventType, + ) -> Result>, Self::Error> { + self.0.get_account_data_event(event_type).await.map_err(Into::into) + } + + async 
fn get_room_account_data_event( + &self, + room_id: &RoomId, + event_type: RoomAccountDataEventType, + ) -> Result>, Self::Error> { + self.0.get_room_account_data_event(room_id, event_type).await.map_err(Into::into) + } + + async fn get_user_room_receipt_event( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + thread: ReceiptThread, + user_id: &UserId, + ) -> Result, Self::Error> { + self.0 + .get_user_room_receipt_event(room_id, receipt_type, thread, user_id) + .await + .map_err(Into::into) + } + + async fn get_event_room_receipt_events( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + thread: ReceiptThread, + event_id: &EventId, + ) -> Result, Self::Error> { + self.0 + .get_event_room_receipt_events(room_id, receipt_type, thread, event_id) + .await + .map_err(Into::into) + } + + async fn get_custom_value(&self, key: &[u8]) -> Result>, Self::Error> { + self.0.get_custom_value(key).await.map_err(Into::into) + } + + async fn set_custom_value( + &self, + key: &[u8], + value: Vec, + ) -> Result>, Self::Error> { + self.0.set_custom_value(key, value).await.map_err(Into::into) + } + + async fn remove_custom_value(&self, key: &[u8]) -> Result>, Self::Error> { + self.0.remove_custom_value(key).await.map_err(Into::into) + } + + async fn add_media_content( + &self, + request: &MediaRequest, + content: Vec, + ) -> Result<(), Self::Error> { + self.0.add_media_content(request, content).await.map_err(Into::into) + } + + async fn get_media_content( + &self, + request: &MediaRequest, + ) -> Result>, Self::Error> { + self.0.get_media_content(request).await.map_err(Into::into) + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error> { + self.0.remove_media_content(request).await.map_err(Into::into) + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error> { + self.0.remove_media_content_for_uri(uri).await.map_err(Into::into) + } + + async fn remove_room(&self, room_id: &RoomId) -> 
Result<(), Self::Error> { + self.0.remove_room(room_id).await.map_err(Into::into) + } +} + +/// Convenience functionality for state stores. +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +pub trait StateStoreExt: StateStore { + /// Get a specific state event of statically-known type. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room the state event was received for. + async fn get_state_event_static( + &self, + room_id: &RoomId, + ) -> Result>>, Self::Error> + where + C: StaticEventContent + StaticStateEventContent + RedactContent, + C::Redacted: RedactedStateEventContent, + { + Ok(self.get_state_event(room_id, C::TYPE.into(), "").await?.map(Raw::cast)) + } + + /// Get a specific state event of statically-known type. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room the state event was received for. + async fn get_state_event_static_for_key( + &self, + room_id: &RoomId, + state_key: &K, + ) -> Result>>, Self::Error> + where + C: StaticEventContent + StaticStateEventContent + RedactContent, + C::StateKey: Borrow, + C::Redacted: RedactedStateEventContent, + K: AsRef + ?Sized + Sync, + { + Ok(self.get_state_event(room_id, C::TYPE.into(), state_key.as_ref()).await?.map(Raw::cast)) + } + + /// Get a list of state events of a statically-known type for a given room. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room to find events for. + async fn get_state_events_static( + &self, + room_id: &RoomId, + ) -> Result>>, Self::Error> + where + C: StaticEventContent + StaticStateEventContent + RedactContent, + C::Redacted: RedactedStateEventContent, + { + // FIXME: Could be more efficient, if we had streaming store accessor functions + Ok(self + .get_state_events(room_id, C::TYPE.into()) + .await? + .into_iter() + .map(Raw::cast) + .collect()) + } + + /// Get an event of a statically-known type from the account data store. 
+ async fn get_account_data_event_static( + &self, + ) -> Result>>, Self::Error> + where + C: StaticEventContent + GlobalAccountDataEventContent, + { + Ok(self.get_account_data_event(C::TYPE.into()).await?.map(Raw::cast)) + } + + /// Get an event of a statically-known type from the room account data + /// store. + /// + /// # Arguments + /// + /// * `room_id` - The id of the room for which the room account data event + /// should be fetched. + async fn get_room_account_data_event_static( + &self, + room_id: &RoomId, + ) -> Result>>, Self::Error> + where + C: StaticEventContent + RoomAccountDataEventContent, + { + Ok(self.get_room_account_data_event(room_id, C::TYPE.into()).await?.map(Raw::cast)) + } +} + +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl StateStoreExt for T {} + +/// A type-erased [`StateStore`]. +pub type DynStateStore = dyn StateStore; + +/// A type that can be type-erased into `Arc`. +/// +/// This trait is not meant to be implemented directly outside +/// `matrix-sdk-crypto`, but it is automatically implemented for everything that +/// implements `StateStore`. +pub trait IntoStateStore { + #[doc(hidden)] + fn into_state_store(self) -> Arc; +} + +impl IntoStateStore for T +where + T: StateStore + Sized + 'static, +{ + fn into_state_store(self) -> Arc { + Arc::new(EraseStateStoreError(self)) + } +} + +// Turns a given `Arc` into `Arc` by attaching the +// StateStore impl vtable of `EraseStateStoreError`. +impl IntoStateStore for Arc +where + T: StateStore + 'static, +{ + fn into_state_store(self) -> Arc { + let ptr: *const T = Arc::into_raw(self); + let ptr_erased = ptr as *const EraseStateStoreError; + // SAFETY: EraseStateStoreError is repr(transparent) so T and + // EraseStateStoreError have the same layout and ABI + unsafe { Arc::from_raw(ptr_erased) } + } +} + +/// A value for key-value data that should be persisted into the store. 
+#[derive(Debug, Clone)] +pub enum StateStoreDataValue { + /// The sync token. + SyncToken(String), + + /// A filter with the given ID. + Filter(String), + + /// The user avatar url + UserAvatarUrl(String), +} + +impl StateStoreDataValue { + /// Get this value if it is a sync token. + pub fn into_sync_token(self) -> Option { + match self { + Self::SyncToken(token) => Some(token), + _ => None, + } + } + + /// Get this value if it is a filter. + pub fn into_filter(self) -> Option { + match self { + Self::Filter(filter) => Some(filter), + _ => None, + } + } + + /// Get this value if it is a user avatar url. + pub fn into_user_avatar_url(self) -> Option { + match self { + Self::UserAvatarUrl(user_avatar_url) => Some(user_avatar_url), + _ => None, + } + } +} + +/// A key for key-value data. +#[derive(Debug, Clone, Copy)] +pub enum StateStoreDataKey<'a> { + /// The sync token. + SyncToken, + + /// A filter with the given name. + Filter(&'a str), + + /// Avatar URL + UserAvatarUrl(&'a UserId), +} + +impl StateStoreDataKey<'_> { + /// Key to use for the [`SyncToken`][Self::SyncToken] variant. + pub const SYNC_TOKEN: &str = "sync_token"; + /// Key prefix to use for the [`Filter`][Self::Filter] variant. + pub const FILTER: &str = "filter"; + /// Key prefix to use for the [`UserAvatarUrl`][Self::UserAvatarUrl] + /// variant. 
+ pub const USER_AVATAR_URL: &str = "user_avatar_url"; +} diff --git a/crates/matrix-sdk-common/Cargo.toml b/crates/matrix-sdk-common/Cargo.toml index d46b0506a0d..9682fe8c27d 100644 --- a/crates/matrix-sdk-common/Cargo.toml +++ b/crates/matrix-sdk-common/Cargo.toml @@ -27,12 +27,12 @@ serde_json = { workspace = true } [target.'cfg(target_arch = "wasm32")'.dependencies] async-lock = "2.5.0" -futures-util = { version = "0.3.21", default-features = false, features = ["channel"] } +futures-util = { workspace = true, features = ["channel"] } wasm-bindgen-futures = { version = "0.4.33", optional = true } -wasm-timer = "0.2.5" +gloo-timers = { version = "0.2.6", features = ["futures"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -tokio = { version = "1.23.1", default-features = false, features = ["rt", "sync", "time"] } +tokio = { version = "1.24.2", default-features = false, features = ["rt", "sync", "time"] } [dev-dependencies] matrix-sdk-test = { path = "../../testing/matrix-sdk-test/", version= "0.6.0"} diff --git a/crates/matrix-sdk-common/src/deserialized_responses.rs b/crates/matrix-sdk-common/src/deserialized_responses.rs index 91bf0eaa28c..e7a91c86491 100644 --- a/crates/matrix-sdk-common/src/deserialized_responses.rs +++ b/crates/matrix-sdk-common/src/deserialized_responses.rs @@ -2,20 +2,145 @@ use std::collections::BTreeMap; use ruma::{ events::{AnySyncTimelineEvent, AnyTimelineEvent}, + push::Action, serde::Raw, DeviceKeyAlgorithm, OwnedDeviceId, OwnedEventId, OwnedUserId, }; use serde::{Deserialize, Serialize}; -/// The verification state of the device that sent an event to us. 
-#[derive(Clone, Debug, Deserialize, Serialize)] +const AUTHENTICITY_NOT_GUARANTEED: &str = + "The authenticity of this encrypted message can't be guaranteed on this device."; +const UNVERIFIED_IDENTITY: &str = "Encrypted by an unverified user."; +const UNSIGNED_DEVICE: &str = "Encrypted by a device not verified by its owner."; +const UNKNOWN_DEVICE: &str = "Encrypted by an unknown or deleted device."; + +/// Represents the state of verification for a decrypted message sent by a +/// device. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] pub enum VerificationState { - /// The device is trusted. - Trusted, - /// The device is not trusted. - Untrusted, - /// The device is not known to us. - UnknownDevice, + /// This message is guaranteed to be authentic as it is coming from a device + /// belonging to a user that we have verified. + /// + /// This is the only state where authenticity can be guaranteed. + Verified, + + /// The message could not be linked to a verified device. + /// + /// For more detailed information on why the message is considered + /// unverified, refer to the VerificationLevel sub-enum. + Unverified(VerificationLevel), +} + +impl VerificationState { + /// Convert the `VerificationState` into a `ShieldState` which can be + /// directly used to decorate messages in the recommended way. + /// + /// This method decorates messages using a strict ruleset, for a more lax + /// variant of this method take a look at + /// [`VerificationState::to_shield_state_lax()`]. 
+ pub fn to_shield_state_strict(&self) -> ShieldState { + match self { + VerificationState::Verified => ShieldState::None, + VerificationState::Unverified(level) => { + let message = match level { + VerificationLevel::UnverifiedIdentity | VerificationLevel::UnsignedDevice => { + UNVERIFIED_IDENTITY + } + VerificationLevel::None(link) => match link { + DeviceLinkProblem::MissingDevice => UNKNOWN_DEVICE, + DeviceLinkProblem::InsecureSource => AUTHENTICITY_NOT_GUARANTEED, + }, + }; + + ShieldState::Red { message } + } + } + } + + /// Convert the `VerificationState` into a `ShieldState` which can be used + /// to decorate messages in the recommended way. + /// + /// This implements a legacy, lax decoration mode. + /// + /// For a more strict variant of this method take a look at + /// [`VerificationState::to_shield_state_strict()`]. + pub fn to_shield_state_lax(&self) -> ShieldState { + match self { + VerificationState::Verified => ShieldState::None, + VerificationState::Unverified(level) => match level { + VerificationLevel::UnverifiedIdentity => { + // If you didn't show interest in verifying that user we don't + // nag you with an error message. + // TODO: We should detect identity rotation of a previously trusted identity and + // then warn see https://github.com/matrix-org/matrix-rust-sdk/issues/1129 + ShieldState::None + } + VerificationLevel::UnsignedDevice => { + // This is a high warning. The sender hasn't verified his own device. + ShieldState::Red { message: UNSIGNED_DEVICE } + } + VerificationLevel::None(link) => match link { + DeviceLinkProblem::MissingDevice => { + // Have to warn as it could have been a temporary injected device. + // Notice that the device might just not be known at this time, so callers + // should retry when there is a device change for that user. 
+ ShieldState::Red { message: UNKNOWN_DEVICE } + } + DeviceLinkProblem::InsecureSource => { + // In legacy mode, we tone down this warning as it is quite common and + // mostly noise (due to legacy backup and lack of trusted forwards). + ShieldState::Grey { message: AUTHENTICITY_NOT_GUARANTEED } + } + }, + }, + } + } +} + +/// The sub-enum containing detailed information on why a message is considered +/// to be unverified. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] +pub enum VerificationLevel { + /// The message was sent by a user identity we have not verified. + UnverifiedIdentity, + + /// The message was sent by a device not linked to (signed by) any user + /// identity. + UnsignedDevice, + + /// We weren't able to link the message back to any device. This might be + /// because the message claims to have been sent by a device which we have + /// not been able to obtain (for example, because the device was since + /// deleted) or because the key to decrypt the message was obtained from + /// an insecure source. + None(DeviceLinkProblem), +} + +/// The sub-enum containing detailed information on why we were not able to link +/// a message back to a device. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] +pub enum DeviceLinkProblem { + /// The device is missing, either because it was deleted, or you haven't + /// yet downoaled it or the server is erroneously omitting it (federation + /// lag). + MissingDevice, + /// The key was obtained from an insecure source: imported from a file, + /// obtained from a legacy (asymmetric) backup, unsafe key forward, etc. + InsecureSource, +} + +/// Recommended decorations for decrypted messages, representing the message's +/// authenticity properties. +#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] +pub enum ShieldState { + /// A red shield with a tooltip containing the associated message should be + /// presented. 
+ Red { message: &'static str }, + /// A grey shield with a tooltip containing the associated message should be + /// presented. + Grey { message: &'static str }, + /// No shield should be presented. + None, } /// The algorithm specific information of a decrypted event. @@ -37,16 +162,19 @@ pub enum AlgorithmInfo { #[derive(Clone, Debug, Deserialize, Serialize)] pub struct EncryptionInfo { /// The user ID of the event sender, note this is untrusted data unless the - /// `verification_state` is as well trusted. + /// `verification_state` is `Verified` as well. pub sender: OwnedUserId, /// The device ID of the device that sent us the event, note this is - /// untrusted data unless `verification_state` is as well trusted. + /// untrusted data unless `verification_state` is `Verified` as well. pub sender_device: Option, /// Information about the algorithm that was used to encrypt the event. pub algorithm_info: AlgorithmInfo, /// The verification state of the device that sent us the event, note this /// is the state of the device at the time of decryption. It may change in /// the future if a device gets verified or deleted. + /// + /// Callers that persist this should mark the state as dirty when a device + /// change is received down the sync. pub verification_state: VerificationState, } @@ -59,6 +187,9 @@ pub struct SyncTimelineEvent { /// The encryption info about the event. Will be `None` if the event was not /// encrypted. pub encryption_info: Option, + /// The push actions associated with this event. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub push_actions: Vec, } impl SyncTimelineEvent { @@ -71,7 +202,7 @@ impl SyncTimelineEvent { impl From> for SyncTimelineEvent { fn from(inner: Raw) -> Self { - Self { encryption_info: None, event: inner } + Self { encryption_info: None, event: inner, push_actions: Vec::default() } } } @@ -81,7 +212,11 @@ impl From for SyncTimelineEvent { // `TimelineEvent` without the `room_id`. 
By converting the raw value in // this way, we simply cause the `room_id` field in the json to be // ignored by a subsequent deserialization. - Self { encryption_info: o.encryption_info, event: o.event.cast() } + Self { + event: o.event.cast(), + encryption_info: o.encryption_info, + push_actions: o.push_actions, + } } } @@ -92,6 +227,18 @@ pub struct TimelineEvent { /// The encryption info about the event. Will be `None` if the event was not /// encrypted. pub encryption_info: Option, + /// The push actions associated with this event. + pub push_actions: Vec, +} + +impl TimelineEvent { + /// Create a new `TimelineEvent` from the given raw event. + /// + /// This is a convenience constructor for when you don't need to set + /// `encryption_info` or `push_action`, for example inside a test. + pub fn new(event: Raw) -> Self { + Self { event, encryption_info: None, push_actions: vec![] } + } } #[cfg(test)] @@ -106,7 +253,7 @@ mod tests { #[test] fn room_event_to_sync_room_event() { - let event = json! 
({ + let event = json!({ "content": RoomMessageEventContent::text_plain("foobar"), "type": "m.room.message", "event_id": "$xxxxx:example.org", @@ -115,9 +262,7 @@ mod tests { "sender": "@carl:example.com", }); - let room_event = - TimelineEvent { event: Raw::new(&event).unwrap().cast(), encryption_info: None }; - + let room_event = TimelineEvent::new(Raw::new(&event).unwrap().cast()); let converted_room_event: SyncTimelineEvent = room_event.into(); let converted_event: AnySyncTimelineEvent = diff --git a/crates/matrix-sdk-common/src/timeout.rs b/crates/matrix-sdk-common/src/timeout.rs index 594e160acc4..2bc4b53fd4e 100644 --- a/crates/matrix-sdk-common/src/timeout.rs +++ b/crates/matrix-sdk-common/src/timeout.rs @@ -1,10 +1,12 @@ use std::{error::Error, fmt, time::Duration}; use futures_core::Future; +#[cfg(target_arch = "wasm32")] +use futures_util::future::{select, Either}; +#[cfg(target_arch = "wasm32")] +use gloo_timers::future::TimeoutFuture; #[cfg(not(target_arch = "wasm32"))] use tokio::time::timeout as tokio_timeout; -#[cfg(target_arch = "wasm32")] -use wasm_timer::ext::TryFutureExt; /// Error type notifying that a timeout has elapsed. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -25,21 +27,24 @@ impl Error for ElapsedError {} /// an error. 
pub async fn timeout(future: F, duration: Duration) -> Result where - F: Future, + F: Future + Unpin, { #[cfg(not(target_arch = "wasm32"))] return tokio_timeout(duration, future).await.map_err(|_| ElapsedError(())); #[cfg(target_arch = "wasm32")] { - let try_future = async { Ok::(future.await) }; - try_future.timeout(duration).await.map_err(|_| ElapsedError(())) + let timeout_future = + TimeoutFuture::new(u32::try_from(duration.as_millis()).expect("Overlong duration")); + + match select(future, timeout_future).await { + Either::Left((res, _)) => Ok(res), + Either::Right((_, _)) => Err(ElapsedError(())), + } } } -// TODO: Enable tests for wasm32 and debug why -// `with_timeout` test fails https://github.com/matrix-org/matrix-rust-sdk/issues/896 -#[cfg(all(test, not(target_arch = "wasm32")))] +#[cfg(test)] pub(crate) mod tests { use std::{future, time::Duration}; diff --git a/crates/matrix-sdk-crypto/Cargo.toml b/crates/matrix-sdk-crypto/Cargo.toml index 709b88cac2f..29d63c5dd92 100644 --- a/crates/matrix-sdk-crypto/Cargo.toml +++ b/crates/matrix-sdk-crypto/Cargo.toml @@ -15,7 +15,8 @@ version = "0.6.0" rustdoc-args = ["--cfg", "docsrs"] [features] -default = [] +default = ["automatic-room-key-forwarding"] +automatic-room-key-forwarding = [] js = ["ruma/js", "vodozemac/js"] qrcode = ["dep:matrix-sdk-qrcode"] backups_v1 = ["dep:olm-rs", "dep:bs58"] @@ -28,6 +29,7 @@ testing = ["dep:http"] aes = "0.8.1" atomic = "0.5.1" aquamarine = "0.1.12" +async-std = { version = "1.12.0", features = ["unstable"] } async-trait = { workspace = true } base64 = { workspace = true } bs58 = { version = "0.4.0", optional = true } @@ -35,11 +37,12 @@ byteorder = { workspace = true } ctr = "0.9.1" dashmap = { workspace = true } event-listener = "2.5.2" +eyeball = { workspace = true } futures-core = "0.3.24" -futures-util = { version = "0.3.21", default-features = false, features = ["alloc"] } -futures-signals = { version = "0.3.31", default-features = false } +futures-util = { workspace = 
true } hmac = "0.12.1" http = { workspace = true, optional = true } # feature = testing only +itertools = "0.10.5" matrix-sdk-qrcode = { version = "0.4.0", path = "../matrix-sdk-qrcode", optional = true } matrix-sdk-common = { version = "0.6.0", path = "../matrix-sdk-common" } olm-rs = { version = "2.2.0", features = ["serde"], optional = true } @@ -47,6 +50,7 @@ pbkdf2 = { version = "0.11.0", default-features = false } rand = "0.8.5" ruma = { workspace = true, features = ["rand", "canonical-json", "unstable-msc2677"] } serde = { workspace = true, features = ["derive", "rc"] } +rmp-serde = "1.1.1" serde_json = { workspace = true } sha2 = "0.10.2" thiserror = { workspace = true } @@ -56,16 +60,17 @@ zeroize = { workspace = true, features = ["zeroize_derive"] } cfg-if = "1.0" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -tokio = { version = "1.23", default-features = false, features = ["time"] } +tokio = { version = "1.24", default-features = false, features = ["time"] } [dev-dependencies] anyhow = { workspace = true } assert_matches = "1.5.0" +ctor.workspace = true futures = { version = "0.3.21", default-features = false, features = ["executor"] } http = { workspace = true } indoc = "1.0.4" matrix-sdk-test = { version = "0.6.0", path = "../../testing/matrix-sdk-test" } proptest = { version = "1.0.0", default-features = false, features = ["std"] } # required for async_test macro -tokio = { version = "1.23.1", default-features = false, features = ["macros", "rt-multi-thread"] } -tracing-subscriber = "0.3.16" +tokio = { version = "1.24.2", default-features = false, features = ["macros", "rt-multi-thread"] } +tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } diff --git a/crates/matrix-sdk-crypto/src/file_encryption/attachments.rs b/crates/matrix-sdk-crypto/src/file_encryption/attachments.rs index 361352f9ac6..9ed02e8a5d4 100644 --- a/crates/matrix-sdk-crypto/src/file_encryption/attachments.rs +++ 
b/crates/matrix-sdk-crypto/src/file_encryption/attachments.rs @@ -203,7 +203,7 @@ impl<'a, R: Read + ?Sized + 'a> AttachmentEncryptor<'a, R> { /// /// # Arguments /// - /// * `reader` - The `Reader` that should be wrapped and enrypted. + /// * `reader` - The `Reader` that should be wrapped and encrypted. /// /// # Panics /// diff --git a/crates/matrix-sdk-crypto/src/gossiping/machine.rs b/crates/matrix-sdk-crypto/src/gossiping/machine.rs index dd9c7239526..ee77ba127fe 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/machine.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/machine.rs @@ -20,8 +20,12 @@ // If we don't trust the device store an object that remembers the request and // let the users introspect that object. -use std::{collections::BTreeMap, sync::Arc}; +use std::{ + collections::BTreeMap, + sync::{atomic::AtomicBool, Arc}, +}; +use atomic::Ordering; use dashmap::{mapref::entry::Entry, DashMap, DashSet}; use ruma::{ api::client::keys::claim_keys::v3::Request as KeysClaimRequest, @@ -34,10 +38,10 @@ use ruma::{ use tracing::{debug, info, trace, warn}; use vodozemac::{megolm::SessionOrdering, Curve25519PublicKey}; -use super::{GossipRequest, KeyForwardDecision, RequestEvent, RequestInfo, SecretInfo, WaitQueue}; +use super::{GossipRequest, RequestEvent, RequestInfo, SecretInfo, WaitQueue}; use crate::{ error::{EventError, OlmError, OlmResult}, - olm::{InboundGroupSession, Session, ShareState}, + olm::{InboundGroupSession, Session}, requests::{OutgoingRequest, ToDeviceRequest}, session_manager::GroupSessionCache, store::{Changes, CryptoStoreError, SecretImportError, Store}, @@ -45,7 +49,7 @@ use crate::{ forwarded_room_key::ForwardedRoomKeyContent, olm_v1::{DecryptedForwardedRoomKeyEvent, DecryptedSecretSendEvent}, room::encrypted::EncryptedEvent, - room_key_request::{Action, RequestedKeyInfo, RoomKeyRequestEvent}, + room_key_request::RoomKeyRequestEvent, secret_send::SecretSendContent, EventType, }, @@ -57,11 +61,13 @@ pub(crate) struct GossipMachine { 
user_id: Arc, device_id: Arc, store: Store, + #[cfg(feature = "automatic-room-key-forwarding")] outbound_group_sessions: GroupSessionCache, outgoing_requests: Arc>, incoming_key_requests: Arc>, wait_queue: WaitQueue, users_for_key_claim: Arc>>, + room_key_forwarding_enabled: Arc, } impl GossipMachine { @@ -69,21 +75,35 @@ impl GossipMachine { user_id: Arc, device_id: Arc, store: Store, - outbound_group_sessions: GroupSessionCache, + #[allow(unused)] outbound_group_sessions: GroupSessionCache, users_for_key_claim: Arc>>, ) -> Self { + let room_key_forwarding_enabled = + AtomicBool::new(cfg!(feature = "automatic-room-key-forwarding")).into(); + Self { user_id, device_id, store, + #[cfg(feature = "automatic-room-key-forwarding")] outbound_group_sessions, outgoing_requests: Default::default(), incoming_key_requests: Default::default(), wait_queue: WaitQueue::new(), users_for_key_claim, + room_key_forwarding_enabled, } } + #[cfg(feature = "automatic-room-key-forwarding")] + pub fn toggle_room_key_forwarding(&self, enabled: bool) { + self.room_key_forwarding_enabled.store(enabled, Ordering::SeqCst) + } + + pub fn is_room_key_forwarding_enabled(&self) -> bool { + self.room_key_forwarding_enabled.load(Ordering::SeqCst) + } + /// Load stored outgoing requests that were not yet sent out. async fn load_outgoing_requests(&self) -> Result, CryptoStoreError> { Ok(self @@ -171,8 +191,11 @@ impl GossipMachine { let event = item.value(); if let Some(s) = match event { + #[cfg(feature = "automatic-room-key-forwarding")] RequestEvent::KeyShare(e) => self.handle_key_request(e).await?, RequestEvent::Secret(e) => self.handle_secret_request(e).await?, + #[cfg(not(feature = "automatic-room-key-forwarding"))] + _ => None, } { changed_sessions.push(s); } @@ -312,6 +335,7 @@ impl GossipMachine { /// given `Device`, in that case we're going to queue up an /// `/keys/claim` request to be sent out and retry once the 1-to-1 Olm /// session has been established. 
+ #[cfg(feature = "automatic-room-key-forwarding")] async fn try_to_forward_room_key( &self, event: &RoomKeyRequestEvent, @@ -323,7 +347,7 @@ impl GossipMachine { user_id = ?device.user_id(), device_id = ?device.device_id(), session_id = session.session_id(), - room_id = ?session.room_id, + room_id = ?session.room_id(), ?message_index, "Serving a room key request", ); @@ -358,11 +382,14 @@ impl GossipMachine { /// Answer a room key request after we found the matching /// `InboundGroupSession`. + #[cfg(feature = "automatic-room-key-forwarding")] async fn answer_room_key_request( &self, event: &RoomKeyRequestEvent, session: InboundGroupSession, ) -> OlmResult> { + use super::KeyForwardDecision; + let device = self.store.get_device(&event.sender, &event.content.requesting_device_id).await?; @@ -403,6 +430,7 @@ impl GossipMachine { } } + #[cfg(feature = "automatic-room-key-forwarding")] async fn handle_supported_key_request( &self, event: &RoomKeyRequestEvent, @@ -427,27 +455,38 @@ impl GossipMachine { } /// Handle a single incoming key request. + #[cfg(feature = "automatic-room-key-forwarding")] async fn handle_key_request(&self, event: &RoomKeyRequestEvent) -> OlmResult> { - match &event.content.action { - Action::Request(info) => match info { - RequestedKeyInfo::MegolmV1AesSha2(i) => { - self.handle_supported_key_request(event, &i.room_id, &i.session_id).await - } - #[cfg(feature = "experimental-algorithms")] - RequestedKeyInfo::MegolmV2AesSha2(i) => { - self.handle_supported_key_request(event, &i.room_id, &i.session_id).await - } - RequestedKeyInfo::Unknown(i) => { - debug!( - sender = ?event.sender, - algorithm = ?i.algorithm, - "Received a room key request for a unsupported algorithm" - ); - Ok(None) - } - }, - // We ignore cancellations here since there's nothing to serve. 
- Action::Cancellation => Ok(None), + use crate::types::events::room_key_request::{Action, RequestedKeyInfo}; + + if self.room_key_forwarding_enabled.load(Ordering::SeqCst) { + match &event.content.action { + Action::Request(info) => match info { + RequestedKeyInfo::MegolmV1AesSha2(i) => { + self.handle_supported_key_request(event, &i.room_id, &i.session_id).await + } + #[cfg(feature = "experimental-algorithms")] + RequestedKeyInfo::MegolmV2AesSha2(i) => { + self.handle_supported_key_request(event, &i.room_id, &i.session_id).await + } + RequestedKeyInfo::Unknown(i) => { + debug!( + sender = ?event.sender, + algorithm = ?i.algorithm, + "Received a room key request for a unsupported algorithm" + ); + Ok(None) + } + }, + // We ignore cancellations here since there's nothing to serve. + Action::Cancellation => Ok(None), + } + } else { + debug!( + sender = ?event.sender, + "Received a room key request, but room key forwarding has been turned off" + ); + Ok(None) } } @@ -476,6 +515,7 @@ impl GossipMachine { Ok(used_session) } + #[cfg(feature = "automatic-room-key-forwarding")] async fn forward_room_key( &self, session: &InboundGroupSession, @@ -533,11 +573,16 @@ impl GossipMachine { /// i. /// - `Err(x)`: Should *refuse* to share the session. `x` is the reason for /// the refusal. + + #[cfg(feature = "automatic-room-key-forwarding")] async fn should_share_key( &self, device: &Device, session: &InboundGroupSession, - ) -> Result, KeyForwardDecision> { + ) -> Result, super::KeyForwardDecision> { + use super::KeyForwardDecision; + use crate::olm::ShareState; + let outbound_session = self .outbound_group_sessions .get_with_id(session.room_id(), session.session_id()) @@ -575,20 +620,21 @@ impl GossipMachine { /// /// * `key_info` - The info of our key request containing information about /// the key we wish to request. 
+ #[cfg(feature = "automatic-room-key-forwarding")] async fn should_request_key(&self, key_info: &SecretInfo) -> Result { - let request = self.store.get_secret_request_by_info(key_info).await?; - - // Don't send out duplicate requests, users can re-request them if they - // think a second request might succeed. - if request.is_none() { - let devices = self.store.get_user_devices(self.user_id()).await?; - - // Devices will only respond to key requests if the devices are - // verified, if the device isn't verified by us it's unlikely that - // we're verified by them either. Don't request keys if there isn't - // at least one verified device. - if devices.is_any_verified() { - Ok(true) + if self.room_key_forwarding_enabled.load(Ordering::SeqCst) { + let request = self.store.get_secret_request_by_info(key_info).await?; + + // Don't send out duplicate requests, users can re-request them if they + // think a second request might succeed. + if request.is_none() { + let devices = self.store.get_user_devices(self.user_id()).await?; + + // Devices will only respond to key requests if the devices are + // verified, if the device isn't verified by us it's unlikely that + // we're verified by them either. Don't request keys if there isn't + // at least one verified device. + Ok(devices.is_any_verified()) } else { Ok(false) } @@ -681,6 +727,7 @@ impl GossipMachine { /// * `room_id` - The id of the room where the key is used in. /// /// * `event` - The event for which we would like to request the room key. 
+ #[cfg(feature = "automatic-room-key-forwarding")] pub async fn create_outgoing_key_request( &self, room_id: &RoomId, @@ -882,7 +929,7 @@ impl GossipMachine { info!( ?sender_key, claimed_sender_key = ?session.sender_key(), - room_id = ?session.room_id, + room_id = ?session.room_id(), session_id = session.session_id(), algorithm = ?session.algorithm(), "Received a forwarded room key but we already have a better version of it", @@ -989,6 +1036,7 @@ impl GossipMachine { mod tests { use std::sync::Arc; + #[cfg(feature = "automatic-room-key-forwarding")] use assert_matches::assert_matches; use dashmap::DashMap; use matrix_sdk_common::locks::Mutex; @@ -997,35 +1045,40 @@ mod tests { device_id, event_id, events::{ secret::request::{RequestAction, SecretName, ToDeviceSecretRequestEventContent}, - AnyToDeviceEventContent, ToDeviceEvent as RumaToDeviceEvent, + ToDeviceEvent as RumaToDeviceEvent, }, room_id, serde::Raw, user_id, DeviceId, RoomId, UserId, }; + #[cfg(feature = "automatic-room-key-forwarding")] use serde::{de::DeserializeOwned, Serialize}; use serde_json::json; - use super::{GossipMachine, KeyForwardDecision}; + use super::GossipMachine; + #[cfg(feature = "automatic-room-key-forwarding")] use crate::{ - identities::{LocalTrust, ReadOnlyDevice}, - olm::{Account, OutboundGroupSession, PrivateCrossSigningIdentity, ReadOnlyAccount}, - session_manager::GroupSessionCache, - store::{Changes, CryptoStore, MemoryStore, Store}, + gossiping::KeyForwardDecision, + olm::OutboundGroupSession, + store::Changes, types::{ events::{ - forwarded_room_key::ForwardedRoomKeyContent, - olm_v1::{AnyDecryptedOlmEvent, DecryptedOlmV1Event}, - room::encrypted::{ - EncryptedEvent, EncryptedToDeviceEvent, RoomEncryptedEventContent, - }, - EventType, ToDeviceEvent, + forwarded_room_key::ForwardedRoomKeyContent, olm_v1::AnyDecryptedOlmEvent, + olm_v1::DecryptedOlmV1Event, room::encrypted::EncryptedToDeviceEvent, EventType, + ToDeviceEvent, }, EventEncryptionAlgorithm, }, - 
verification::VerificationMachine, EncryptionSettings, OutgoingRequest, OutgoingRequests, }; + use crate::{ + identities::{LocalTrust, ReadOnlyDevice}, + olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, + session_manager::GroupSessionCache, + store::{IntoCryptoStore, MemoryStore, Store}, + types::events::room::encrypted::{EncryptedEvent, RoomEncryptedEventContent}, + verification::VerificationMachine, + }; fn alice_id() -> &'static UserId { user_id!("@alice:example.org") @@ -1063,12 +1116,13 @@ mod tests { ReadOnlyAccount::new(alice_id(), alice2_device_id()) } + #[cfg(feature = "automatic-room-key-forwarding")] fn test_gossip_machine(user_id: &UserId) -> GossipMachine { let user_id = Arc::from(user_id); let device_id = DeviceId::new(); let account = ReadOnlyAccount::new(&user_id, &device_id); - let store: Arc = Arc::new(MemoryStore::new()); + let store = MemoryStore::new().into_crypto_store(); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); let verification = VerificationMachine::new(account, identity.clone(), store.clone()); let store = Store::new(user_id.to_owned(), identity, store, verification); @@ -1090,7 +1144,7 @@ mod tests { let another_device = ReadOnlyDevice::from_account(&ReadOnlyAccount::new(&user_id, alice2_device_id())).await; - let store: Arc = Arc::new(MemoryStore::new()); + let store = MemoryStore::new().into_crypto_store(); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); let verification = VerificationMachine::new(account, identity.clone(), store.clone()); @@ -1107,6 +1161,7 @@ mod tests { ) } + #[cfg(feature = "automatic-room-key-forwarding")] async fn machines_for_key_share( other_machine_owner: &UserId, create_sessions: bool, @@ -1148,9 +1203,10 @@ mod tests { .await .unwrap(); + bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap(); + let content = group_session.encrypt(json!({}), "m.dummy").await; let event = 
wrap_encrypted_content(bob_machine.user_id(), content); - bob_machine.store.save_inbound_group_sessions(&[inbound_group_session]).await.unwrap(); // Alice wants to request the outbound group session from bob. assert!( @@ -1172,10 +1228,11 @@ mod tests { (alice_machine, alice_account, group_session, bob_machine) } + #[cfg(feature = "automatic-room-key-forwarding")] fn extract_content<'a>( recipient: &UserId, request: &'a OutgoingRequest, - ) -> &'a Raw { + ) -> &'a Raw { request .request() .to_device() @@ -1204,6 +1261,7 @@ mod tests { } } + #[cfg(feature = "automatic-room-key-forwarding")] fn request_to_event( recipient: &UserId, sender: &UserId, @@ -1250,6 +1308,7 @@ mod tests { } #[async_test] + #[cfg(feature = "automatic-room-key-forwarding")] async fn create_key_request() { let machine = get_machine().await; let account = account(); @@ -1281,6 +1340,7 @@ mod tests { } #[async_test] + #[cfg(feature = "automatic-room-key-forwarding")] async fn receive_forwarded_key() { let machine = get_machine().await; let account = account(); @@ -1383,6 +1443,7 @@ mod tests { } #[async_test] + #[cfg(feature = "automatic-room-key-forwarding")] async fn should_share_key_test() { let machine = get_machine().await; let account = account(); @@ -1497,6 +1558,7 @@ mod tests { assert_matches!(machine.should_share_key(&own_device, &other_inbound).await, Ok(None)); } + #[cfg(feature = "automatic-room-key-forwarding")] async fn key_share_cycle(algorithm: EventEncryptionAlgorithm) { let (alice_machine, alice_account, group_session, bob_machine) = machines_for_key_share(alice_id(), true, algorithm).await; @@ -1556,6 +1618,7 @@ mod tests { } #[async_test] + #[cfg(feature = "automatic-room-key-forwarding")] async fn reject_forward_from_another_user() { let (alice_machine, alice_account, group_session, bob_machine) = machines_for_key_share(bob_id(), true, EventEncryptionAlgorithm::MegolmV1AesSha2).await; @@ -1605,12 +1668,13 @@ mod tests { } #[async_test] + #[cfg(feature = 
"automatic-room-key-forwarding")] async fn key_share_cycle_megolm_v1() { key_share_cycle(EventEncryptionAlgorithm::MegolmV1AesSha2).await; } - #[cfg(feature = "experimental-algorithms")] #[async_test] + #[cfg(all(feature = "experimental-algorithms", feature = "automatic-room-key-forwarding"))] async fn key_share_cycle_megolm_v2() { key_share_cycle(EventEncryptionAlgorithm::MegolmV2AesSha2).await; } @@ -1686,6 +1750,7 @@ mod tests { } #[async_test] + #[cfg(feature = "automatic-room-key-forwarding")] async fn key_share_cycle_without_session() { let (alice_machine, alice_account, group_session, bob_machine) = machines_for_key_share(alice_id(), false, EventEncryptionAlgorithm::MegolmV1AesSha2) diff --git a/crates/matrix-sdk-crypto/src/gossiping/mod.rs b/crates/matrix-sdk-crypto/src/gossiping/mod.rs index f0487d37154..3ed63c02e1e 100644 --- a/crates/matrix-sdk-crypto/src/gossiping/mod.rs +++ b/crates/matrix-sdk-crypto/src/gossiping/mod.rs @@ -32,8 +32,6 @@ use ruma::{ DeviceId, OwnedDeviceId, OwnedTransactionId, OwnedUserId, TransactionId, UserId, }; use serde::{Deserialize, Serialize}; -use thiserror::Error; -use tracing::error; use crate::{ requests::{OutgoingRequest, ToDeviceRequest}, @@ -44,7 +42,8 @@ use crate::{ }; /// An error describing why a key share request won't be honored. -#[derive(Debug, Clone, Error, PartialEq, Eq)] +#[cfg(feature = "automatic-room-key-forwarding")] +#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)] pub enum KeyForwardDecision { /// The key request is from a device that we don't own, we're only sharing /// sessions that we know the requesting device already was supposed to get. 
@@ -302,7 +301,7 @@ impl WaitQueue { } } - #[cfg(test)] + #[cfg(all(test, feature = "automatic-room-key-forwarding"))] fn is_empty(&self) -> bool { self.requests_ids_waiting.is_empty() && self.requests_waiting_for_session.is_empty() } diff --git a/crates/matrix-sdk-crypto/src/identities/device.rs b/crates/matrix-sdk-crypto/src/identities/device.rs index 91d8cd20fab..7d2ea5c12a1 100644 --- a/crates/matrix-sdk-crypto/src/identities/device.rs +++ b/crates/matrix-sdk-crypto/src/identities/device.rs @@ -31,7 +31,7 @@ use ruma::{ }; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_json::Value; -use tracing::warn; +use tracing::{trace, warn}; use vodozemac::{olm::SessionConfig, Curve25519PublicKey, Ed25519PublicKey}; use super::{atomic_bool_deserializer, atomic_bool_serializer}; @@ -41,7 +41,7 @@ use crate::{ error::{EventError, OlmError, OlmResult, SignatureError}, identities::{ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities}, olm::{InboundGroupSession, Session, SignedJsonObject, VerifyJson}, - store::{Changes, CryptoStore, DeviceChanges, Result as StoreResult}, + store::{Changes, DeviceChanges, DynCryptoStore, Result as StoreResult}, types::{ events::{ forwarded_room_key::ForwardedRoomKeyContent, @@ -160,12 +160,84 @@ impl Device { /// can be confirmed as the creator and owner of the `m.room_key`. pub fn is_owner_of_session(&self, session: &InboundGroupSession) -> Result { if session.has_been_imported() { + // An imported room key means that we did not receive the room key as a + // `m.room_key` event when the room key was initially exchanged. + // + // This could mean a couple of things: + // 1. We received the room key as a `m.forwarded_room_key`. + // 2. We imported the room key through a file export. + // 3. We imported the room key through a backup. + // + // To be certain that a `Device` is the owner of a room key we need to have a + // proof that the `Curve25519` key of this `Device` was used to + // initially exchange the room key. 
This proof is provided by the Olm decryption + // step, see below for further clarification. + // + // Each of the above room key methods that receive room keys do not contain this + // proof and we received only a claim that the room key is tied to a + // `Curve25519` key. + // + // Since there's no way to verify that the claim is true, we say that we don't + // know that the room key belongs to this device. Ok(false) } else if let Some(key) = - session.signing_keys.get(&DeviceKeyAlgorithm::Ed25519).and_then(|k| k.ed25519()) + session.signing_keys().get(&DeviceKeyAlgorithm::Ed25519).and_then(|k| k.ed25519()) { + // Room keys are received as an `m.room.encrypted` event using the `m.olm` + // algorithm. Upon decryption of the `m.room.encrypted` event, the + // decrypted content will contain also a `Ed25519` public key[1]. + // + // The inclusion of this key means that the `Curve25519` key of the `Device` and + // Olm `Session`, established using the DH authentication of the + // double ratchet, binds the `Ed25519` key of the `Device` + // + // On the other hand, the `Ed25519` key is binding the `Curve25519` key + // using a signature which is uploaded to the server as + // `device_keys` and downloaded by us using a `/keys/query` request. + // + // A `Device` is considered to be the owner of a room key iff: + // 1. The `Curve25519` key that was used to establish the Olm `Session` + // that was used to decrypt the event is binding the `Ed25519`key + // of this `Device`. + // 2. The `Ed25519` key of this device has signed a `device_keys` object + // that contains the `Curve25519` key from step 1. + // + // We don't need to check the signature of the `Device` here, since we don't + // accept a `Device` unless it has a valid `Ed25519` signature. + // + // We do check that the `Curve25519` that was used to decrypt the event carrying + // the `m.room_key` and the `Ed25519` key that was part of the + // decrypted content matches the keys found in this `Device`. 
+ // + // ```text + // β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + // β”‚ EncryptedToDeviceEventβ”‚ + // β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + // β”‚ + // β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ + // β”‚ Device β”‚ β–Ό + // β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + // β”‚ Device Keys β”‚ β”‚ Session β”‚ + // β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ + // β”‚ Ed25519 Key β”‚ Curve25519 Key │◄──────►│ Curve25519 Key β”‚ + // β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + // β–² β”‚ + // β”‚ β”‚ + // β”‚ β”‚ Decrypt + // β”‚ β”‚ + // β”‚ β–Ό + // β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + // β”‚ β”‚ DecryptedOlmV1Event β”‚ + // β”‚ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ + // β”‚ β”‚ keys β”‚ + // β”‚ β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ + // └────────────────────────────────►│ Ed25519 Key β”‚ + // β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + // ``` + // + // [1]: https://spec.matrix.org/v1.5/client-server-api/#molmv1curve25519-aes-sha2 let ed25519_comparison = self.ed25519_key().map(|k| k == key); - let curve25519_comparison = self.curve25519_key().map(|k| k == session.sender_key); + let curve25519_comparison = self.curve25519_key().map(|k| k == session.sender_key()); match (ed25519_comparison, curve25519_comparison) { // If we have any of the keys but they don't turn out to match, refuse to decrypt @@ 
-187,6 +259,40 @@ impl Device { } } + /// Is this device cross signed by its owner? + pub fn is_cross_signed_by_owner(&self) -> bool { + self.device_owner_identity + .as_ref() + .map(|device_identity| match device_identity { + // If it's one of our own devices, just check that + // we signed the device. + ReadOnlyUserIdentities::Own(identity) => { + identity.is_device_signed(&self.inner).is_ok() + } + // If it's a device from someone else, check + // if the other user has signed this device. + ReadOnlyUserIdentities::Other(device_identity) => { + device_identity.is_device_signed(&self.inner).is_ok() + } + }) + .unwrap_or(false) + } + + /// Is the device owner verified by us? + pub fn is_device_owner_verified(&self) -> bool { + self.device_owner_identity + .as_ref() + .map(|id| match id { + ReadOnlyUserIdentities::Own(own_identity) => own_identity.is_verified(), + ReadOnlyUserIdentities::Other(other_identity) => self + .own_identity + .as_ref() + .map(|oi| oi.is_verified() && oi.is_identity_signed(other_identity).is_ok()) + .unwrap_or(false), + }) + .unwrap_or(false) + } + /// Request an interactive verification with this `Device`. 
/// /// Returns a `VerificationRequest` object and a to-device request that @@ -570,7 +676,7 @@ impl ReadOnlyDevice { pub(crate) async fn encrypt( &self, - store: &dyn CryptoStore, + store: &DynCryptoStore, event_type: &str, content: Value, ) -> OlmResult<(Session, Raw)> { @@ -604,6 +710,13 @@ impl ReadOnlyDevice { let message = session.encrypt(self, event_type, content).await?; + trace!( + user_id = ?self.user_id(), + device_id = ?self.device_id(), + session_id = session.session_id(), + "Successfully encrypted a Megolm session", + ); + Ok((session, message)) } diff --git a/crates/matrix-sdk-crypto/src/identities/manager.rs b/crates/matrix-sdk-crypto/src/identities/manager.rs index bd9b7722b60..c2c842b5243 100644 --- a/crates/matrix-sdk-crypto/src/identities/manager.rs +++ b/crates/matrix-sdk-crypto/src/identities/manager.rs @@ -16,32 +16,32 @@ use std::{ collections::{BTreeMap, BTreeSet, HashSet}, ops::Deref, sync::Arc, - time::Duration, }; use futures_util::future::join_all; -use matrix_sdk_common::{ - executor::spawn, - timeout::{timeout, ElapsedError}, -}; +use itertools::Itertools; +use matrix_sdk_common::{executor::spawn, locks::Mutex}; use ruma::{ api::client::keys::get_keys::v3::Response as KeysQueryResponse, serde::Raw, DeviceId, - OwnedDeviceId, OwnedServerName, OwnedUserId, ServerName, UserId, + OwnedDeviceId, OwnedServerName, OwnedTransactionId, OwnedUserId, ServerName, TransactionId, + UserId, }; -use tracing::{debug, info, trace, warn}; +use tracing::{debug, info, instrument, trace, warn}; use crate::{ error::OlmResult, identities::{ - MasterPubkey, ReadOnlyDevice, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities, - ReadOnlyUserIdentity, SelfSigningPubkey, UserSigningPubkey, + ReadOnlyDevice, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities, ReadOnlyUserIdentity, }, olm::PrivateCrossSigningIdentity, requests::KeysQueryRequest, - store::{Changes, DeviceChanges, IdentityChanges, Result as StoreResult, Store}, - types::DeviceKeys, + store::{ + 
caches::SequenceNumber, Changes, DeviceChanges, IdentityChanges, Result as StoreResult, + Store, + }, + types::{CrossSigningKey, DeviceKeys, MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, utilities::FailuresCache, - LocalTrust, + LocalTrust, SignatureError, }; enum DeviceChange { @@ -50,90 +50,47 @@ enum DeviceChange { None, } -/// A listener that can notify if a `/keys/query` response has been received. -#[derive(Clone, Debug)] -pub(crate) struct KeysQueryListener { - inner: Arc, - store: Store, -} - -/// Result type telling us if a `/keys/query` response was expected for a given -/// user. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub(crate) enum UserKeyQueryResult { - WasPending, - WasNotPending, -} - -impl KeysQueryListener { - pub(crate) fn new(store: Store) -> Self { - Self { inner: event_listener::Event::new().into(), store } - } - - /// Notify our listeners that we received a `/keys/query` response. - fn notify(&self) { - self.inner.notify(usize::MAX); - } - - /// Wait for a `/keys/query` response to be received if one is expected for - /// the given user. - /// - /// If the given timeout has elapsed the method will stop waiting and return - /// an error. - pub async fn wait_if_user_pending( - &self, - timeout: Duration, - user: &UserId, - ) -> Result { - let users_for_key_query = self.store.users_for_key_query().await.unwrap_or_default(); - - if users_for_key_query.contains(user) { - if let Err(e) = self.wait(timeout).await { - warn!( - user_id = ?user, - "The user has a pending `/key/query` request which did \ - not finish yet, some devices might be missing." - ); - - Err(e) - } else { - Ok(UserKeyQueryResult::WasPending) - } - } else { - Ok(UserKeyQueryResult::WasNotPending) - } - } - - /// Wait for a `/keys/query` response to be received. - /// - /// If the given timeout has elapsed the method will stop waiting and return - /// an error. 
- pub async fn wait(&self, duration: Duration) -> Result<(), ElapsedError> { - timeout(self.inner.listen(), duration).await - } +struct IdentityChange { + public: ReadOnlyUserIdentities, + private: Option, } #[derive(Debug, Clone)] pub(crate) struct IdentityManager { user_id: Arc, device_id: Arc, - keys_query_listener: KeysQueryListener, failures: FailuresCache, store: Store, + + /// Details of the current "in-flight" key query request, if any + keys_query_request_details: Arc>>, +} + +/// Details of an in-flight key query request +#[derive(Debug, Clone, Default)] +struct KeysQueryRequestDetails { + /// The sequence number, to be passed to + /// `Store.mark_tracked_users_as_up_to_date`. + sequence_number: SequenceNumber, + + /// A single batch of queries returned by the Store is broken up into one or + /// more actual KeysQueryRequests, each with their own request id. We + /// record the outstanding request ids here. + request_ids: HashSet, } impl IdentityManager { const MAX_KEY_QUERY_USERS: usize = 250; pub fn new(user_id: Arc, device_id: Arc, store: Store) -> Self { - let keys_query_listener = KeysQueryListener::new(store.clone()); + let keys_query_request_details = Mutex::new(None); IdentityManager { user_id, device_id, store, - keys_query_listener, failures: Default::default(), + keys_query_request_details: keys_query_request_details.into(), } } @@ -141,10 +98,6 @@ impl IdentityManager { &self.user_id } - pub fn listen_for_received_queries(&self) -> KeysQueryListener { - self.keys_query_listener.clone() - } - /// Receive a successful keys query response. /// /// Returns a list of devices newly discovered devices and devices that @@ -152,13 +105,16 @@ impl IdentityManager { /// /// # Arguments /// + /// * `request_id` - The request_id returned by users_for_key_query /// * `response` - The keys query response of the request that the client /// performed. 
pub async fn receive_keys_query_response( &self, + request_id: &TransactionId, response: &KeysQueryResponse, ) -> OlmResult<(DeviceChanges, IdentityChanges)> { debug!( + ?request_id, users = ?response.device_keys.keys().collect::>(), failures = ?response.failures, "Handling a keys query response" @@ -192,8 +148,30 @@ impl IdentityManager { }; self.store.save_changes(changes).await?; - self.mark_tracked_users_as_up_to_date(response.device_keys.keys().map(Deref::deref)) - .await?; + + // if this request is one of those we expected to be in flight, pass the + // sequence number back to the store so that it can mark devices up to + // date + let sequence_number = { + let mut request_details = self.keys_query_request_details.lock().await; + + request_details.as_mut().and_then(|details| { + if details.request_ids.remove(request_id) { + Some(details.sequence_number) + } else { + None + } + }) + }; + + if let Some(sequence_number) = sequence_number { + self.store + .mark_tracked_users_as_up_to_date( + response.device_keys.keys().map(Deref::deref), + sequence_number, + ) + .await?; + } let changed_devices = devices.changed.iter().fold(BTreeMap::new(), |mut acc, d| { acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id()); @@ -215,6 +193,7 @@ impl IdentityManager { identities.changed.iter().map(|i| i.user_id()).collect::>(); debug!( + ?request_id, ?new_devices, ?changed_devices, ?deleted_devices, @@ -223,8 +202,6 @@ impl IdentityManager { "Finished handling of the keys/query response" ); - self.keys_query_listener.notify(); - Ok((devices, identities)) } @@ -414,193 +391,302 @@ impl IdentityManager { Ok(changes) } - /// Handle the device keys part of a key query response. - /// - /// # Arguments + /// Check if the given public identity matches our private one. /// - /// * `response` - The keys query response. + /// If they don't match remove the private keys since our identity got + /// rotated. /// - /// Returns a list of identities that changed. 
Changed here means either - /// they are new, one of their properties has changed or they got deleted. - async fn handle_cross_singing_keys( + /// If they do match, mark the public identity as verified. + async fn check_private_identity( &self, - response: &KeysQueryResponse, - ) -> StoreResult<(IdentityChanges, Option)> { - let mut changes = IdentityChanges::default(); - let mut changed_identity = None; + identity: &ReadOnlyOwnUserIdentity, + ) -> Option { + let private_identity = self.store.private_identity(); + let private_identity = private_identity.lock().await; + let result = private_identity.clear_if_differs(identity).await; + + if result.any_cleared() { + info!(cleared = ?result, "Removed some or all of our private cross signing keys"); + Some((*private_identity).clone()) + } else { + // If the master key didn't rotate above (`clear_if_differs`), + // then this means that the public part and the private parts of + // the master key match. We previously did a signature check, so + // this means that the private part of the master key has signed + // the identity. We can safely mark the public part of the + // identity as verified. + if private_identity.has_master_key().await { + trace!("Marked our own identity as verified"); + identity.mark_as_verified() + } - // TODO: this is a bit chunky, refactor this into smaller methods. 
+ None + } + } - for (user_id, master_key) in &response.master_keys { - match master_key.deserialize_as::() { - Ok(master_key) => { - let Some(self_signing) = response - .self_signing_keys - .get(user_id) - .and_then(|k| k.deserialize_as::().ok()) - else { + async fn handle_changed_identity( + &self, + response: &KeysQueryResponse, + master_key: MasterPubkey, + self_signing: SelfSigningPubkey, + i: ReadOnlyUserIdentities, + ) -> Result { + match i { + ReadOnlyUserIdentities::Own(mut identity) => { + if let Some(user_signing) = response + .user_signing_keys + .get(self.user_id()) + .and_then(|k| k.deserialize_as::().ok()) + { + if user_signing.user_id() != self.user_id() { warn!( - user_id = user_id.as_str(), - "A user identity didn't contain a self signing pubkey \ - or the key was invalid" + expected = ?self.user_id(), + got = ?user_signing.user_id(), + "User ID mismatch in our user-signing key", ); - continue; - }; - - let result = if let Some(mut i) = self.store.get_user_identity(user_id).await? 
{ - match &mut i { - ReadOnlyUserIdentities::Own(identity) => { - let Some(user_signing) = response - .user_signing_keys - .get(user_id) - .and_then(|k| k.deserialize_as::().ok()) - else { - warn!( - user_id = user_id.as_str(), - "User identity for our own user didn't \ - contain a user signing pubkey", - ); - continue; - }; - - identity - .update(master_key, self_signing, user_signing) - .map(|_| (i, false)) - } - ReadOnlyUserIdentities::Other(identity) => { - identity.update(master_key, self_signing).map(|_| (i, false)) - } - } - } else if user_id == self.user_id() { - if let Some(user_signing) = response - .user_signing_keys - .get(user_id) - .and_then(|k| k.deserialize_as::().ok()) - { - if master_key.user_id() != user_id - || self_signing.user_id() != user_id - || user_signing.user_id() != user_id - { - warn!( - user_id = user_id.as_str(), - "User ID mismatch in one of the cross signing keys", - ); - continue; - } - ReadOnlyOwnUserIdentity::new(master_key, self_signing, user_signing) - .map(|i| (ReadOnlyUserIdentities::Own(i), true)) - } else { - warn!( - user_id = user_id.as_str(), - "User identity for our own user didn't contain a \ - user signing pubkey or the key isn't valid", - ); - continue; - } - } else if master_key.user_id() != user_id || self_signing.user_id() != user_id { - warn!( - user = user_id.as_str(), - "User ID mismatch in one of the cross signing keys", - ); - continue; + Err(SignatureError::UserIdMismatch) } else { - ReadOnlyUserIdentity::new(master_key, self_signing) - .map(|i| (ReadOnlyUserIdentities::Other(i), true)) - }; - - match result { - Ok((i, new)) => { - if let Some(identity) = i.own() { - let private_identity = self.store.private_identity(); - let private_identity = private_identity.lock().await; - - let result = private_identity.clear_if_differs(identity).await; - - if result.any_cleared() { - changed_identity = Some((*private_identity).clone()); - info!(cleared = ?result, "Removed some or all of our private cross signing 
keys"); - } else if new && private_identity.has_master_key().await { - // If the master key didn't rotate above (`clear_if_differs`), - // then this means that the public part and the private parts of - // the master key match. We previously did a signature check, so - // this means that the private part of the master key has signed - // the identity. We can safely mark the public part of the - // identity as verified. - identity.mark_as_verified(); - trace!("Received our own user identity, for which we possess the private key. Marking as verified."); - } - } + identity.update(master_key, self_signing, user_signing)?; - if new { - trace!(user_id = user_id.as_str(), identity = ?i, "Created new user identity"); - changes.new.push(i); - } else { - trace!(user_id = user_id.as_str(), identity = ?i, "Updated a user identity"); - changes.changed.push(i); - } - } - Err(e) => { - warn!( - user_id = user_id.as_str(), - error = ?e, - "Couldn't update or create new user identity" - ); - continue; - } + let private = self.check_private_identity(&identity).await; + + Ok(IdentityChange { public: identity.into(), private }) } + } else { + warn!( + "User identity for our own user didn't contain a user signing public key" + ); + + Err(SignatureError::MissingSigningKey) } - Err(e) => { + } + ReadOnlyUserIdentities::Other(mut identity) => { + identity.update(master_key, self_signing)?; + Ok(IdentityChange { public: identity.into(), private: None }) + } + } + } + + async fn handle_new_identity( + &self, + response: &KeysQueryResponse, + master_key: MasterPubkey, + self_signing: SelfSigningPubkey, + ) -> Result { + if master_key.user_id() == self.user_id() { + if let Some(user_signing) = response + .user_signing_keys + .get(self.user_id()) + .and_then(|k| k.deserialize_as::().ok()) + { + if user_signing.user_id() != self.user_id() { warn!( - user_id = user_id.as_str(), - error = ?e, - "Couldn't update or create new user identity" + expected = ?self.user_id(), + got = 
?user_signing.user_id(), + "User ID mismatch in our user-signing key", ); - continue; + Err(SignatureError::UserIdMismatch) + } else { + let identity = + ReadOnlyOwnUserIdentity::new(master_key, self_signing, user_signing)?; + + let private = self.check_private_identity(&identity).await; + + Ok(IdentityChange { public: identity.into(), private }) } + } else { + warn!( + "User identity for our own user didn't contain a user signing pubkey or the key \ + isn't valid", + ); + + Err(SignatureError::MissingSigningKey) } + } else { + let identity = ReadOnlyUserIdentity::new(master_key, self_signing)?; + Ok(IdentityChange { public: identity.into(), private: None }) + } + } + + /// Try to deserialize the master key and self-signing key of an + /// identity. + /// + /// Each user identity *must* at least contain a master and self-signing + /// key. Our own identity, in addition to those two, also contains a + /// user-signing key. + fn get_minimal_set_of_keys( + master_key: &Raw, + response: &KeysQueryResponse, + ) -> Option<(MasterPubkey, SelfSigningPubkey)> { + match master_key.deserialize_as::() { + Ok(master_key) => { + if let Some(self_signing) = response + .self_signing_keys + .get(master_key.user_id()) + .and_then(|k| k.deserialize_as::().ok()) + { + Some((master_key, self_signing)) + } else { + warn!("A user identity didn't contain a self signing pubkey or the key was invalid"); + None + } + } + Err(e) => { + warn!( + error = ?e, + "Couldn't update or create new user identity" + ); + None + } + } + } + + #[instrument(skip_all, fields(user_id))] + async fn update_or_create_identity( + &self, + response: &KeysQueryResponse, + changes: &mut IdentityChanges, + changed_identity: &mut Option, + user_id: &UserId, + master_key: MasterPubkey, + self_signing: SelfSigningPubkey, + ) -> StoreResult<()> { + if master_key.user_id() != user_id || self_signing.user_id() != user_id { + warn!(?user_id, "User ID mismatch in one of the cross signing keys",); + } else if let Some(i) 
= self.store.get_user_identity(user_id).await? { + match self.handle_changed_identity(response, master_key, self_signing, i).await { + Ok(c) => { + trace!(identity = ?c.public, "Updated a user identity"); + changes.changed.push(c.public); + *changed_identity = c.private; + } + Err(e) => { + warn!(error = ?e, "Couldn't update an existing user identity"); + } + } + } else { + match self.handle_new_identity(response, master_key, self_signing).await { + Ok(c) => { + trace!(identity = ?c.public, "Created new user identity"); + changes.new.push(c.public); + *changed_identity = c.private; + } + Err(e) => { + warn!(error = ?e, "Couldn't create new user identity"); + } + } + }; + + Ok(()) + } + + /// Handle the cross signing keys part of a key query response. + /// + /// # Arguments + /// + /// * `response` - The keys query response. + /// + /// Returns a list of identities that changed. Changed here means either + /// they are new or one of their properties has changed. + async fn handle_cross_singing_keys( + &self, + response: &KeysQueryResponse, + ) -> StoreResult<(IdentityChanges, Option)> { + let mut changes = IdentityChanges::default(); + let mut changed_identity = None; + + for (user_id, master_key) in &response.master_keys { + // Get the master and self-signing key for each identity, those are required for + // every user identity type, if we don't have those we skip over. + let Some((master_key, self_signing)) = Self::get_minimal_set_of_keys(master_key.cast_ref(), response) else { + continue; + }; + + self.update_or_create_identity( + response, + &mut changes, + &mut changed_identity, + user_id, + master_key, + self_signing, + ) + .await?; } Ok((changes, changed_identity)) } - /// Get a key query request if one is needed. + /// Get a list of key query requests needed. + /// + /// # Returns /// - /// Returns a key query request if the client should query E2E keys, - /// otherwise None. + /// A map of a request ID to the `/keys/query` request. 
/// /// The response of a successful key query requests needs to be passed to /// the [`OlmMachine`] with the [`receive_keys_query_response`]. /// /// [`OlmMachine`]: struct.OlmMachine.html /// [`receive_keys_query_response`]: #method.receive_keys_query_response - pub async fn users_for_key_query(&self) -> StoreResult> { - let users = self.store.users_for_key_query().await?; + pub async fn users_for_key_query( + &self, + ) -> StoreResult> { + // Forget about any previous key queries in flight. + *self.keys_query_request_details.lock().await = None; + + let (users, sequence_number) = self.store.users_for_key_query().await?; // We always want to track our own user, but in case we aren't in an encrypted // room yet, we won't be tracking ourselves yet. This ensures we are always // tracking ourselves. // // The check for emptiness is done first for performance. - let users = + let (users, sequence_number) = if users.is_empty() && !self.store.tracked_users().await?.contains(self.user_id()) { self.store.mark_user_as_changed(self.user_id()).await?; self.store.users_for_key_query().await? } else { - users + (users, sequence_number) }; if users.is_empty() { - Ok(Vec::new()) + Ok(BTreeMap::new()) } else { - let users: Vec = - users.into_iter().filter(|u| !self.failures.contains(u.server_name())).collect(); - - Ok(users + // Let's remove users that are part of the `FailuresCache`. The cache, which is + // a TTL cache, remembers users for which a previous `/key/query` request has + // failed. We don't retry a `/keys/query` for such users for a + // certain amount of time. + let users = users.into_iter().filter(|u| !self.failures.contains(u.server_name())); + + // We don't want to create a single `/keys/query` request with an infinite + // amount of users. Some servers will likely bail out after a + // certain amount of users and the responses will be large. In the + // case of a transmission error, we'll have to retransmit the large + // response. 
+ // + // Convert the set of users into multiple /keys/query requests. + let requests: BTreeMap<_, _> = users .chunks(Self::MAX_KEY_QUERY_USERS) - .map(|u| u.iter().map(|u| (u.clone(), Vec::new())).collect()) - .map(KeysQueryRequest::new) - .collect()) + .into_iter() + .map(|user_chunk| { + let request_id = TransactionId::new(); + let request = KeysQueryRequest::new(user_chunk); + + debug!(?request_id, users = ?request.device_keys.keys(), "Created a /keys/query request"); + + (request_id, request) + }) + .collect(); + + // Collect the request IDs, these will be used later in the + // `receive_keys_query_response()` method to figure out if the user can be + // marked as up-to-date/non-dirty. + let request_ids = requests.keys().cloned().collect(); + let request_details = KeysQueryRequestDetails { sequence_number, request_ids }; + + *self.keys_query_request_details.lock().await = Some(request_details); + + Ok(requests) } } @@ -615,15 +701,7 @@ impl IdentityManager { &self, users: impl Iterator, ) -> StoreResult<()> { - let mut changed_user: Vec<(&UserId, bool)> = Vec::new(); - - for user_id in users { - if self.store.is_user_tracked(user_id).await? { - changed_user.push((user_id, true)) - } - } - - self.store.save_tracked_users(&changed_user).await + self.store.mark_tracked_users_as_changed(users).await } /// See the docs for [`OlmMachine::update_tracked_users()`]. @@ -631,23 +709,7 @@ impl IdentityManager { &self, users: impl IntoIterator, ) -> StoreResult<()> { - let mut tracked_users = Vec::new(); - - for user_id in users { - if !self.store.is_user_tracked(user_id).await? 
{ - tracked_users.push((user_id, true)); - } - } - - self.store.save_tracked_users(&tracked_users).await - } - - pub async fn mark_tracked_users_as_up_to_date( - &self, - users: impl Iterator, - ) -> StoreResult<()> { - let updated_users: Vec<(&UserId, bool)> = users.map(|u| (u, false)).collect(); - self.store.save_tracked_users(&updated_users).await + self.store.update_tracked_users(users.into_iter()).await } } @@ -667,7 +729,7 @@ pub(crate) mod testing { identities::IdentityManager, machine::testing::response_from_file, olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, - store::{CryptoStore, MemoryStore, Store}, + store::{DynCryptoStore, IntoCryptoStore, MemoryStore, Store}, types::DeviceKeys, verification::VerificationMachine, UploadSigningKeysRequest, @@ -690,10 +752,14 @@ pub(crate) mod testing { let identity = Arc::new(Mutex::new(identity)); let user_id = Arc::from(user_id()); let account = ReadOnlyAccount::new(&user_id, device_id()); - let store: Arc = Arc::new(MemoryStore::new()); + let store: Arc = MemoryStore::new().into_crypto_store(); let verification = VerificationMachine::new(account, identity.clone(), store); - let store = - Store::new(user_id.clone(), identity, Arc::new(MemoryStore::new()), verification); + let store = Store::new( + user_id.clone(), + identity, + MemoryStore::new().into_crypto_store(), + verification, + ); IdentityManager::new(user_id, device_id().into(), store) } @@ -891,29 +957,17 @@ pub(crate) mod testing { #[cfg(test)] pub(crate) mod tests { - use std::{ops::Deref, time::Duration}; + use std::ops::Deref; use matrix_sdk_test::{async_test, response_from_file}; use ruma::{ api::{client::keys::get_keys::v3::Response as KeysQueryResponse, IncomingResponse}, - device_id, user_id, + device_id, user_id, TransactionId, }; use serde_json::json; use super::testing::{device_id, key_query, manager, other_key_query, other_user_id, user_id}; - fn key_query_without_failures() -> KeysQueryResponse { - let response = json!({ - "device_keys": { 
- "@alice:example.org": { - }, - } - }); - - let response = response_from_file(&response); - - KeysQueryResponse::try_from_http_response(response).unwrap() - } fn key_query_with_failures() -> KeysQueryResponse { let response = json!({ "device_keys": { @@ -942,11 +996,11 @@ pub(crate) mod tests { ); manager.receive_device_changes([alice].iter().map(Deref::deref)).await.unwrap(); assert!( - !manager.store.is_user_tracked(alice).await.unwrap(), + !manager.store.tracked_users().await.unwrap().contains(alice), "Receiving a device changes update for a user we don't track does nothing" ); assert!( - !manager.store.users_for_key_query().await.unwrap().contains(alice), + !manager.store.users_for_key_query().await.unwrap().0.contains(alice), "The user we don't track doesn't end up in the `/keys/query` request" ); } @@ -964,13 +1018,10 @@ pub(crate) mod tests { let devices = manager.store.get_user_devices(other_user).await.unwrap(); assert_eq!(devices.devices().count(), 0); - let listener = manager.listen_for_received_queries(); - - let task = tokio::task::spawn(async move { listener.wait(Duration::from_secs(10)).await }); - - manager.receive_keys_query_response(&other_key_query()).await.unwrap(); - - task.await.unwrap().unwrap(); + manager + .receive_keys_query_response(&TransactionId::new(), &other_key_query()) + .await + .unwrap(); let devices = manager.store.get_user_devices(other_user).await.unwrap(); assert_eq!(devices.devices().count(), 1); @@ -1001,7 +1052,10 @@ pub(crate) mod tests { let device_keys = manager.store.account().device_keys().await; manager - .receive_keys_query_response(&key_query(identity_request, device_keys)) + .receive_keys_query_response( + &TransactionId::new(), + &key_query(identity_request, device_keys), + ) .await .unwrap(); @@ -1036,6 +1090,37 @@ pub(crate) mod tests { ); } + /// If a user is invalidated while a /keys/query request is in flight, that + /// user is not removed from the list of outdated users when the + /// response is received 
+ #[async_test] + async fn invalidation_race_handling() { + let manager = manager().await; + let alice = other_user_id(); + manager.update_tracked_users([alice]).await.unwrap(); + + // alice should be in the list of key queries + let (reqid, req) = manager.users_for_key_query().await.unwrap().pop_first().unwrap(); + assert!(req.device_keys.contains_key(alice)); + + // another invalidation turns up + manager.receive_device_changes([alice].into_iter()).await.unwrap(); + + // the response from the query arrives + manager.receive_keys_query_response(&reqid, &other_key_query()).await.unwrap(); + + // alice should *still* be in the list of key queries + let (reqid, req) = manager.users_for_key_query().await.unwrap().pop_first().unwrap(); + assert!(req.device_keys.contains_key(alice)); + + // another key query response + manager.receive_keys_query_response(&reqid, &other_key_query()).await.unwrap(); + + // finally alice should not be in the list + let queries = manager.users_for_key_query().await.unwrap(); + assert!(!queries.iter().any(|(_, r)| r.device_keys.contains_key(alice))); + } + #[async_test] async fn failure_handling() { let manager = manager().await; @@ -1051,40 +1136,27 @@ pub(crate) mod tests { manager.store.tracked_users().await.unwrap().contains(alice), "Alice is tracked after being marked as tracked" ); - assert!(manager - .users_for_key_query() - .await - .unwrap() - .iter() - .any(|r| r.device_keys.contains_key(alice))); + let (reqid, req) = manager.users_for_key_query().await.unwrap().pop_first().unwrap(); + assert!(req.device_keys.contains_key(alice)); + // a failure should stop us querying for the user's keys. 
let response = key_query_with_failures(); - - manager.receive_keys_query_response(&response).await.unwrap(); + manager.receive_keys_query_response(&reqid, &response).await.unwrap(); assert!(manager.failures.contains(alice.server_name())); assert!(!manager .users_for_key_query() .await .unwrap() .iter() - .any(|r| r.device_keys.contains_key(alice))); + .any(|(_, r)| r.device_keys.contains_key(alice))); - let response = key_query_without_failures(); - manager.receive_keys_query_response(&response).await.unwrap(); - assert!(!manager.failures.contains(alice.server_name())); - assert!(!manager - .users_for_key_query() - .await - .unwrap() - .iter() - .any(|r| r.device_keys.contains_key(alice))); - - manager.store.mark_user_as_changed(alice).await.unwrap(); + // clearing the failure flag should make the user reappear in the query list. + manager.failures.remove([alice.server_name().to_owned()].iter()); assert!(manager .users_for_key_query() .await .unwrap() .iter() - .any(|r| r.device_keys.contains_key(alice))); + .any(|(_, r)| r.device_keys.contains_key(alice))); } } diff --git a/crates/matrix-sdk-crypto/src/identities/mod.rs b/crates/matrix-sdk-crypto/src/identities/mod.rs index 1fad141450a..f69227f5d82 100644 --- a/crates/matrix-sdk-crypto/src/identities/mod.rs +++ b/crates/matrix-sdk-crypto/src/identities/mod.rs @@ -41,7 +41,7 @@ //! Both identity sets need to regularly fetched from the server using the //! `/keys/query` API call. 
pub(crate) mod device; -mod manager; +pub(crate) mod manager; pub(crate) mod user; use std::sync::{ @@ -50,11 +50,11 @@ use std::sync::{ }; pub use device::{Device, LocalTrust, ReadOnlyDevice, UserDevices}; -pub(crate) use manager::{IdentityManager, KeysQueryListener, UserKeyQueryResult}; +pub(crate) use manager::IdentityManager; use serde::{Deserialize, Deserializer, Serializer}; pub use user::{ - MasterPubkey, OwnUserIdentity, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities, - ReadOnlyUserIdentity, SelfSigningPubkey, UserIdentities, UserIdentity, UserSigningPubkey, + OwnUserIdentity, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities, ReadOnlyUserIdentity, + UserIdentities, UserIdentity, }; // These methods are only here because Serialize and Deserialize don't seem to diff --git a/crates/matrix-sdk-crypto/src/identities/user.rs b/crates/matrix-sdk-crypto/src/identities/user.rs index e64c219d8ee..8afc59545df 100644 --- a/crates/matrix-sdk-crypto/src/identities/user.rs +++ b/crates/matrix-sdk-crypto/src/identities/user.rs @@ -13,7 +13,6 @@ // limitations under the License. 
use std::{ - collections::btree_map::Iter, ops::Deref, sync::{ atomic::{AtomicBool, Ordering}, @@ -23,22 +22,19 @@ use std::{ use ruma::{ api::client::keys::upload_signatures::v3::Request as SignatureUploadRequest, - encryption::KeyUsage, events::{ key::verification::VerificationMethod, room::message::KeyVerificationRequestEventContent, }, - DeviceKeyId, EventId, OwnedDeviceId, OwnedDeviceKeyId, RoomId, UserId, + EventId, OwnedDeviceId, RoomId, UserId, }; use serde::{Deserialize, Serialize}; use tracing::error; -use vodozemac::Ed25519PublicKey; use super::{atomic_bool_deserializer, atomic_bool_serializer}; use crate::{ error::SignatureError, - olm::VerifyJson, store::{Changes, IdentityChanges}, - types::{CrossSigningKey, DeviceKeys, Signatures, SigningKey, SigningKeys}, + types::{MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, verification::VerificationMachine, CryptoStoreError, OutgoingVerificationRequest, ReadOnlyDevice, VerificationRequest, }; @@ -273,344 +269,6 @@ impl UserIdentity { } } -/// Wrapper for a cross signing key marking it as the master key. -/// -/// Master keys are used to sign other cross signing keys, the self signing and -/// user signing keys of an user will be signed by their master key. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(try_from = "CrossSigningKey")] -pub struct MasterPubkey(Arc); - -macro_rules! impl_partial_eq { - ($key_type: ty) => { - impl PartialEq for $key_type { - /// The `PartialEq` implementation compares the user ID, the usage and the - /// key material, ignoring signatures. - /// - /// The usage could be safely ignored since the type guarantees it has the - /// correct usage by construction -- it is impossible to construct a - /// value of a particular key type with an incorrect usage. However, we - /// check it anyway, to codify the notion that the same key material - /// with a different usage results in a logically different key. 
- /// - /// The signatures are provided by other devices and don't alter the - /// identity of the key itself. - fn eq(&self, other: &Self) -> bool { - self.user_id() == other.user_id() - && self.keys() == other.keys() - && self.usage() == other.usage() - } - } - impl Eq for $key_type {} - }; -} - -impl_partial_eq!(MasterPubkey); -impl_partial_eq!(SelfSigningPubkey); -impl_partial_eq!(UserSigningPubkey); - -/// Wrapper for a cross signing key marking it as a self signing key. -/// -/// Self signing keys are used to sign the user's own devices. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(try_from = "CrossSigningKey")] -pub struct SelfSigningPubkey(Arc); - -/// Wrapper for a cross signing key marking it as a user signing key. -/// -/// User signing keys are used to sign the master keys of other users. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(try_from = "CrossSigningKey")] -pub struct UserSigningPubkey(Arc); - -impl TryFrom for MasterPubkey { - type Error = serde_json::Error; - - fn try_from(key: CrossSigningKey) -> Result { - if key.usage.contains(&KeyUsage::Master) && key.usage.len() == 1 { - Ok(Self(key.into())) - } else { - Err(serde::de::Error::custom(format!( - "Expected cross signing key usage {} was not found", - KeyUsage::Master - ))) - } - } -} - -impl TryFrom for SelfSigningPubkey { - type Error = serde_json::Error; - - fn try_from(key: CrossSigningKey) -> Result { - if key.usage.contains(&KeyUsage::SelfSigning) && key.usage.len() == 1 { - Ok(Self(key.into())) - } else { - Err(serde::de::Error::custom(format!( - "Expected cross signing key usage {} was not found", - KeyUsage::SelfSigning - ))) - } - } -} - -impl TryFrom for UserSigningPubkey { - type Error = serde_json::Error; - - fn try_from(key: CrossSigningKey) -> Result { - if key.usage.contains(&KeyUsage::UserSigning) && key.usage.len() == 1 { - Ok(Self(key.into())) - } else { - Err(serde::de::Error::custom(format!( - "Expected cross signing key usage {} was not found", - 
KeyUsage::UserSigning - ))) - } - } -} - -impl AsRef for MasterPubkey { - fn as_ref(&self) -> &CrossSigningKey { - &self.0 - } -} - -impl AsRef for SelfSigningPubkey { - fn as_ref(&self) -> &CrossSigningKey { - &self.0 - } -} - -impl AsRef for UserSigningPubkey { - fn as_ref(&self) -> &CrossSigningKey { - &self.0 - } -} - -impl<'a> From<&'a SelfSigningPubkey> for CrossSigningSubKeys<'a> { - fn from(key: &'a SelfSigningPubkey) -> Self { - CrossSigningSubKeys::SelfSigning(key) - } -} - -impl<'a> From<&'a UserSigningPubkey> for CrossSigningSubKeys<'a> { - fn from(key: &'a UserSigningPubkey) -> Self { - CrossSigningSubKeys::UserSigning(key) - } -} - -/// Enum over the cross signing sub-keys. -pub(crate) enum CrossSigningSubKeys<'a> { - /// The self signing subkey. - SelfSigning(&'a SelfSigningPubkey), - /// The user signing subkey. - UserSigning(&'a UserSigningPubkey), -} - -impl<'a> CrossSigningSubKeys<'a> { - /// Get the id of the user that owns this cross signing subkey. - fn user_id(&self) -> &UserId { - match self { - CrossSigningSubKeys::SelfSigning(key) => &key.0.user_id, - CrossSigningSubKeys::UserSigning(key) => &key.0.user_id, - } - } - - /// Get the `CrossSigningKey` from an sub-keys enum - pub(crate) fn cross_signing_key(&self) -> &CrossSigningKey { - match self { - CrossSigningSubKeys::SelfSigning(key) => &key.0, - CrossSigningSubKeys::UserSigning(key) => &key.0, - } - } -} - -impl MasterPubkey { - /// Get the user id of the master key's owner. - pub fn user_id(&self) -> &UserId { - &self.0.user_id - } - - /// Get the keys map of containing the master keys. - pub fn keys(&self) -> &SigningKeys { - &self.0.keys - } - - /// Get the list of `KeyUsage` that is set for this key. - pub fn usage(&self) -> &[KeyUsage] { - &self.0.usage - } - - /// Get the signatures map of this cross signing key. - pub fn signatures(&self) -> &Signatures { - &self.0.signatures - } - - /// Get the master key with the given key id. 
- /// - /// # Arguments - /// - /// * `key_id` - The id of the key that should be fetched. - pub fn get_key(&self, key_id: &DeviceKeyId) -> Option<&SigningKey> { - self.0.keys.get(key_id) - } - - /// Get the first available master key. - /// - /// There's usually only a single master key so this will usually fetch the - /// only key. - pub fn get_first_key(&self) -> Option { - self.0.get_first_key_and_id().map(|(_, k)| k) - } - - /// Check if the given JSON is signed by this master key. - /// - /// This method should only be used if an object's signature needs to be - /// checked multiple times, and you'd like to avoid performing the - /// canonicalization step each time. - /// - /// **Note**: Use this method with caution, the `canonical_json` needs to be - /// correctly canonicalized and make sure that the object you are checking - /// the signature for is allowed to be signed by a master key. - #[cfg(any(feature = "backups_v1", test))] - pub(crate) fn has_signed_raw( - &self, - signatures: &Signatures, - canonical_json: &str, - ) -> Result<(), SignatureError> { - if let Some((key_id, key)) = self.0.get_first_key_and_id() { - key.verify_canonicalized_json(&self.0.user_id, key_id, signatures, canonical_json) - } else { - Err(SignatureError::UnsupportedAlgorithm) - } - } - - /// Check if the given cross signing sub-key is signed by the master key. - /// - /// # Arguments - /// - /// * `subkey` - The subkey that should be checked for a valid signature. - /// - /// Returns an empty result if the signature check succeeded, otherwise a - /// SignatureError indicating why the check failed. 
- pub(crate) fn verify_subkey<'a>( - &self, - subkey: impl Into>, - ) -> Result<(), SignatureError> { - let subkey: CrossSigningSubKeys<'_> = subkey.into(); - - if self.0.user_id != subkey.user_id() { - return Err(SignatureError::UserIdMismatch); - } - - if let Some((key_id, key)) = self.0.get_first_key_and_id() { - key.verify_json(&self.0.user_id, key_id, subkey.cross_signing_key()) - } else { - Err(SignatureError::UnsupportedAlgorithm) - } - } -} - -impl<'a> IntoIterator for &'a MasterPubkey { - type Item = (&'a OwnedDeviceKeyId, &'a SigningKey); - type IntoIter = Iter<'a, OwnedDeviceKeyId, SigningKey>; - - fn into_iter(self) -> Self::IntoIter { - self.keys().iter() - } -} - -impl UserSigningPubkey { - /// Get the user id of the user signing key's owner. - pub fn user_id(&self) -> &UserId { - &self.0.user_id - } - - /// Get the list of `KeyUsage` that is set for this key. - pub fn usage(&self) -> &[KeyUsage] { - &self.0.usage - } - - /// Get the keys map of containing the user signing keys. - pub fn keys(&self) -> &SigningKeys { - &self.0.keys - } - - /// Check if the given master key is signed by this user signing key. - /// - /// # Arguments - /// - /// * `master_key` - The master key that should be checked for a valid - /// signature. - /// - /// Returns an empty result if the signature check succeeded, otherwise a - /// SignatureError indicating why the check failed. 
- pub(crate) fn verify_master_key( - &self, - master_key: &MasterPubkey, - ) -> Result<(), SignatureError> { - if let Some((key_id, key)) = self.0.get_first_key_and_id() { - key.verify_json(&self.0.user_id, key_id, master_key.0.as_ref()) - } else { - Err(SignatureError::UnsupportedAlgorithm) - } - } -} - -impl<'a> IntoIterator for &'a UserSigningPubkey { - type Item = (&'a OwnedDeviceKeyId, &'a SigningKey); - type IntoIter = Iter<'a, OwnedDeviceKeyId, SigningKey>; - - fn into_iter(self) -> Self::IntoIter { - self.keys().iter() - } -} - -impl SelfSigningPubkey { - /// Get the user id of the self signing key's owner. - pub fn user_id(&self) -> &UserId { - &self.0.user_id - } - - /// Get the keys map of containing the self signing keys. - pub fn keys(&self) -> &SigningKeys { - &self.0.keys - } - - /// Get the list of `KeyUsage` that is set for this key. - pub fn usage(&self) -> &[KeyUsage] { - &self.0.usage - } - - fn verify_device_keys(&self, device_keys: &DeviceKeys) -> Result<(), SignatureError> { - if let Some((key_id, key)) = self.0.get_first_key_and_id() { - key.verify_json(&self.0.user_id, key_id, device_keys) - } else { - Err(SignatureError::UnsupportedAlgorithm) - } - } - - /// Check if the given device is signed by this self signing key. - /// - /// # Arguments - /// - /// * `device` - The device that should be checked for a valid signature. - /// - /// Returns an empty result if the signature check succeeded, otherwise a - /// SignatureError indicating why the check failed. - pub(crate) fn verify_device(&self, device: &ReadOnlyDevice) -> Result<(), SignatureError> { - self.verify_device_keys(device.as_device_keys()) - } -} - -impl<'a> IntoIterator for &'a SelfSigningPubkey { - type Item = (&'a OwnedDeviceKeyId, &'a SigningKey); - type IntoIter = Iter<'a, OwnedDeviceKeyId, SigningKey>; - - fn into_iter(self) -> Self::IntoIter { - self.keys().iter() - } -} - /// Enum over the different user identity types we can have. 
#[derive(Debug, Clone, Serialize, Deserialize)] pub enum ReadOnlyUserIdentities { @@ -727,7 +385,7 @@ impl ReadOnlyUserIdentity { ) -> Result { master_key.verify_subkey(&self_signing_key)?; - Ok(Self { user_id: (*master_key.0.user_id).into(), master_key, self_signing_key }) + Ok(Self { user_id: master_key.user_id().into(), master_key, self_signing_key }) } #[cfg(test)] @@ -840,7 +498,7 @@ impl ReadOnlyOwnUserIdentity { master_key.verify_subkey(&user_signing_key)?; Ok(Self { - user_id: (*master_key.0.user_id).into(), + user_id: master_key.user_id().into(), master_key, self_signing_key, user_signing_key, @@ -1046,23 +704,18 @@ pub(crate) mod tests { use assert_matches::assert_matches; use matrix_sdk_common::locks::Mutex; use matrix_sdk_test::async_test; - use ruma::{encryption::KeyUsage, user_id, DeviceKeyId}; + use ruma::user_id; use serde_json::{json, Value}; - use vodozemac::Ed25519Signature; use super::{ testing::{device, get_other_identity, get_own_identity}, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities, }; use crate::{ - identities::{ - manager::testing::{own_key_query, own_key_query_with_user_id}, - user::testing::get_other_own_identity, - Device, MasterPubkey, SelfSigningPubkey, UserSigningPubkey, - }, + identities::{manager::testing::own_key_query, Device}, olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, - store::MemoryStore, - types::CrossSigningKey, + store::{IntoCryptoStore, MemoryStore}, + types::{CrossSigningKey, MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, verification::VerificationMachine, }; @@ -1105,7 +758,7 @@ pub(crate) mod tests { let verification_machine = VerificationMachine::new( ReadOnlyAccount::new(second.user_id(), second.device_id()), private_identity, - Arc::new(MemoryStore::new()), + MemoryStore::new().into_crypto_store(), ); let first = Device { @@ -1146,7 +799,7 @@ pub(crate) mod tests { let verification_machine = VerificationMachine::new( ReadOnlyAccount::new(device.user_id(), device.device_id()), id.clone(), - 
Arc::new(MemoryStore::new()), + MemoryStore::new().into_crypto_store(), ); let public_identity = identity.to_public_identity().await.unwrap(); @@ -1222,55 +875,4 @@ pub(crate) mod tests { Err(_) ); } - - #[async_test] - async fn partial_eq_cross_signing_keys() { - macro_rules! test_partial_eq { - ($key_type:ident, $key_field:ident, $field:ident, $usage:expr) => { - let user_id = user_id!("@example:localhost"); - let response = own_key_query(); - let raw = response.$field.get(user_id).unwrap(); - let key: $key_type = raw.deserialize_as().unwrap(); - - // A different key is naturally not the same as our key. - let other_identity = get_other_own_identity().await; - let other_key = other_identity.$key_field; - assert_ne!(key, other_key); - - // However, not even our own key material with another user ID is the same. - let other_user_id = user_id!("@example2:localhost"); - let other_response = own_key_query_with_user_id(&other_user_id); - let other_raw = other_response.$field.get(other_user_id).unwrap(); - let other_key: $key_type = other_raw.deserialize_as().unwrap(); - assert_ne!(key, other_key); - - // Now let's add another signature to our key. - let signature = Ed25519Signature::from_base64( - "mia28GKixFzOWKJ0h7Bdrdy2fjxiHCsst1qpe467FbW85H61UlshtKBoAXfTLlVfi0FX+/noJ8B3noQPnY+9Cg" - ).expect("The signature can always be decoded"); - let mut other_key: CrossSigningKey = raw.deserialize_as().unwrap(); - other_key.signatures.add_signature( - user_id.to_owned(), - DeviceKeyId::from_parts(ruma::DeviceKeyAlgorithm::Ed25519, "DEVICEID".into()), - signature, - ); - let other_key = other_key.try_into().unwrap(); - - // Additional signatures are fine, adding more does not change the key's identity. - assert_eq!(key, other_key); - - // However changing the usage results in a different key. 
- let mut other_key: CrossSigningKey = raw.deserialize_as().unwrap(); - other_key.usage.push($usage); - let other_key = $key_type { 0: other_key.into() }; - assert_ne!(key, other_key); - }; - } - - // The last argument is deliberately some usage which is *not* correct for the - // type. - test_partial_eq!(MasterPubkey, master_key, master_keys, KeyUsage::SelfSigning); - test_partial_eq!(SelfSigningPubkey, self_signing_key, self_signing_keys, KeyUsage::Master); - test_partial_eq!(UserSigningPubkey, user_signing_key, user_signing_keys, KeyUsage::Master); - } } diff --git a/crates/matrix-sdk-crypto/src/lib.rs b/crates/matrix-sdk-crypto/src/lib.rs index 05590bf5b6b..4f59ba0dc01 100644 --- a/crates/matrix-sdk-crypto/src/lib.rs +++ b/crates/matrix-sdk-crypto/src/lib.rs @@ -76,7 +76,7 @@ pub use file_encryption::{ }; pub use gossiping::GossipRequest; pub use identities::{ - Device, LocalTrust, MasterPubkey, OwnUserIdentity, ReadOnlyDevice, ReadOnlyOwnUserIdentity, + Device, LocalTrust, OwnUserIdentity, ReadOnlyDevice, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities, ReadOnlyUserIdentity, UserDevices, UserIdentities, UserIdentity, }; pub use machine::OlmMachine; @@ -104,10 +104,24 @@ pub mod vodozemac { olm::{ DecryptionError as OlmDecryptionError, SessionCreationError as OlmSessionCreationError, }, - DecodeError, KeyError, PickleError, SignatureError, + DecodeError, KeyError, PickleError, SignatureError, VERSION, }; } +/// The version of the matrix-sdk-cypto crate being used +pub static VERSION: &str = env!("CARGO_PKG_VERSION"); + +// Enable tracing for tests in this crate +#[cfg(all(test, not(target_arch = "wasm32")))] +#[ctor::ctor] +fn init_logging() { + use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + tracing_subscriber::registry() + .with(tracing_subscriber::EnvFilter::from_default_env()) + .with(tracing_subscriber::fmt::layer().with_test_writer()) + .init(); +} + #[cfg_attr(doc, aquamarine::aquamarine)] /// A step by step guide that explains 
how to include [end-to-end-encryption] /// support in a [Matrix] client library. diff --git a/crates/matrix-sdk-crypto/src/machine.rs b/crates/matrix-sdk-crypto/src/machine.rs index 37a652662b8..68660fe298f 100644 --- a/crates/matrix-sdk-crypto/src/machine.rs +++ b/crates/matrix-sdk-crypto/src/machine.rs @@ -20,7 +20,10 @@ use std::{ use dashmap::DashMap; use matrix_sdk_common::{ - deserialized_responses::{AlgorithmInfo, EncryptionInfo, TimelineEvent, VerificationState}, + deserialized_responses::{ + AlgorithmInfo, DeviceLinkProblem, EncryptionInfo, TimelineEvent, VerificationLevel, + VerificationState, + }, locks::Mutex, }; use ruma::{ @@ -42,11 +45,12 @@ use ruma::{ RoomId, TransactionId, UInt, UserId, }; use serde_json::{value::to_raw_value, Value}; -use tracing::{debug, error, field::debug, info, instrument, warn}; -use vodozemac::{ - megolm::{DecryptionError, SessionOrdering}, - Curve25519PublicKey, Ed25519Signature, +use tracing::{ + debug, error, + field::{debug, display}, + info, instrument, warn, Span, }; +use vodozemac::{megolm::SessionOrdering, Curve25519PublicKey, Ed25519Signature}; #[cfg(feature = "backups_v1")] use crate::backups::BackupMachine; @@ -62,8 +66,8 @@ use crate::{ requests::{IncomingResponse, OutgoingRequest, UploadSigningKeysRequest}, session_manager::{GroupSessionManager, SessionManager}, store::{ - Changes, CryptoStore, DeviceChanges, IdentityChanges, MemoryStore, Result as StoreResult, - SecretImportError, Store, + Changes, DeviceChanges, DynCryptoStore, IdentityChanges, IntoCryptoStore, MemoryStore, + Result as StoreResult, SecretImportError, Store, }, types::{ events::{ @@ -78,7 +82,8 @@ use crate::{ Signatures, }, verification::{Verification, VerificationMachine, VerificationRequest}, - CrossSigningKeyExport, ReadOnlyDevice, RoomKeyImportResult, SignatureError, ToDeviceRequest, + CrossSigningKeyExport, CryptoStoreError, LocalTrust, ReadOnlyDevice, RoomKeyImportResult, + SignatureError, ToDeviceRequest, }; /// State machine 
implementation of the Olm/Megolm encryption protocol used for @@ -140,9 +145,7 @@ impl OlmMachine { /// /// * `device_id` - The unique id of the device that owns this machine. pub async fn new(user_id: &UserId, device_id: &DeviceId) -> Self { - let store: Arc = Arc::new(MemoryStore::new()); - - OlmMachine::with_store(user_id, device_id, store) + OlmMachine::with_store(user_id, device_id, MemoryStore::new()) .await .expect("Reading and writing to the memory store always succeeds") } @@ -150,7 +153,7 @@ impl OlmMachine { fn new_helper( user_id: &UserId, device_id: &DeviceId, - store: Arc, + store: Arc, account: ReadOnlyAccount, user_identity: PrivateCrossSigningIdentity, ) -> Self { @@ -178,14 +181,11 @@ impl OlmMachine { let identity_manager = IdentityManager::new(user_id.clone(), device_id.clone(), store.clone()); - let event = identity_manager.listen_for_received_queries(); - let session_manager = SessionManager::new( account.clone(), users_for_key_claim, key_request_machine.clone(), store.clone(), - event, ); #[cfg(feature = "backups_v1")] @@ -226,33 +226,51 @@ impl OlmMachine { /// the encryption keys. /// /// [`Cryptostore`]: trait.CryptoStore.html + #[instrument(skip(store), fields(ed25519_key, curve25519_key))] pub async fn with_store( user_id: &UserId, device_id: &DeviceId, - store: Arc, + store: impl IntoCryptoStore, ) -> StoreResult { + let store = store.into_crypto_store(); let account = match store.load_account().await? 
{ - Some(a) => { - debug!( - ed25519_key = a.identity_keys().ed25519.to_base64().as_str(), - "Restored an Olm account" - ); - a + Some(account) => { + if user_id != account.user_id() || device_id != account.device_id() { + return Err(CryptoStoreError::MismatchedAccount { + expected: (account.user_id().to_owned(), account.device_id().to_owned()), + got: (user_id.to_owned(), device_id.to_owned()), + }); + } else { + Span::current() + .record("ed25519_key", display(account.identity_keys().ed25519)) + .record("curve25519_key", display(account.identity_keys().curve25519)); + debug!("Restored an Olm account"); + + account + } } None => { let account = ReadOnlyAccount::new(user_id, device_id); let device = ReadOnlyDevice::from_account(&account).await; - debug!( - ed25519_key = account.identity_keys().ed25519.to_base64().as_str(), - "Created a new Olm account" - ); + // We just created this device from our own Olm `Account`. Since we are the + // owners of the private keys of this device we can safely mark + // the device as verified. + device.set_trust_state(LocalTrust::Verified); + + Span::current() + .record("ed25519_key", display(account.identity_keys().ed25519)) + .record("curve25519_key", display(account.identity_keys().curve25519)); + + debug!("Created a new Olm account"); + let changes = Changes { account: Some(account.clone()), devices: DeviceChanges { new: vec![device], ..Default::default() }, ..Default::default() }; store.save_changes(changes).await?; + account } }; @@ -275,6 +293,11 @@ impl OlmMachine { Ok(OlmMachine::new_helper(user_id, device_id, store, account, identity)) } + /// Get the crypto store associated with this `OlmMachine` instance. + pub fn store(&self) -> &Store { + &self.store + } + /// The unique user id that owns this `OlmMachine` instance. pub fn user_id(&self) -> &UserId { &self.user_id @@ -303,6 +326,21 @@ impl OlmMachine { self.store.tracked_users().await } + /// Enable or disable room key forwarding. 
+ /// + /// Room key forwarding allows the device to request room keys that it might + /// have missend in the original share using `m.room_key_request` + /// events. + #[cfg(feature = "automatic-room-key-forwarding")] + pub fn toggle_room_key_forwarding(&self, enable: bool) { + self.key_request_machine.toggle_room_key_forwarding(enable) + } + + /// Is room key forwarding enabled? + pub fn is_room_key_forwarding_enabled(&self) -> bool { + self.key_request_machine.is_room_key_forwarding_enabled() + } + /// Get the outgoing requests that need to be sent out. /// /// This returns a list of [`OutgoingRequest`]. Those requests need to be @@ -321,9 +359,13 @@ impl OlmMachine { requests.push(r); } - for request in self.identity_manager.users_for_key_query().await?.into_iter().map(|r| { - OutgoingRequest { request_id: TransactionId::new(), request: Arc::new(r.into()) } - }) { + for request in self + .identity_manager + .users_for_key_query() + .await? + .into_iter() + .map(|(request_id, r)| OutgoingRequest { request_id, request: Arc::new(r.into()) }) + { requests.push(request); } @@ -352,7 +394,7 @@ impl OlmMachine { self.receive_keys_upload_response(response).await?; } IncomingResponse::KeysQuery(response) => { - self.receive_keys_query_response(response).await?; + self.receive_keys_query_response(request_id, response).await?; } IncomingResponse::KeysClaim(response) => { self.receive_keys_claim_response(response).await?; @@ -506,9 +548,10 @@ impl OlmMachine { /// performed. async fn receive_keys_query_response( &self, + request_id: &TransactionId, response: &KeysQueryResponse, ) -> OlmResult<(DeviceChanges, IdentityChanges)> { - self.identity_manager.receive_keys_query_response(response).await + self.identity_manager.receive_keys_query_response(request_id, response).await } /// Get a request to upload E2EE keys to the server. 
@@ -555,7 +598,7 @@ impl OlmMachine { } #[instrument( - skip_all, + skip_all, // This function is only ever called by add_room_key via // handle_decrypted_to_device_event, so sender, sender_key, and algorithm are // already recorded. @@ -724,6 +767,9 @@ impl OlmMachine { /// used. /// /// `users` - The list of users that should receive the room key. + /// + /// `settings` - Encryption settings that affect when are room keys rotated + /// and who are they shared with pub async fn share_room_key( &self, room_id: &RoomId, @@ -801,6 +847,9 @@ impl OlmMachine { decrypted.result.raw_event = Raw::from_json(to_raw_value(&e)?); } } + AnyDecryptedOlmEvent::Dummy(_) => { + debug!("Received an `m.dummy` event"); + } AnyDecryptedOlmEvent::Custom(_) => { warn!("Received an unexpected encrypted to-device event"); } @@ -1067,36 +1116,63 @@ impl OlmMachine { session: &InboundGroupSession, sender: &UserId, ) -> MegolmResult<(VerificationState, Option)> { - Ok( - // First find the device corresponding to the Curve25519 identity - // key that sent us the session (recorded upon successful - // decryption of the `m.room_key` to-device message). - if let Some(device) = self - .get_user_devices(sender, None) - .await? - .devices() - .find(|d| d.curve25519_key() == Some(session.sender_key())) - { - // If the `Device` is confirmed to be the owner of the - // `InboundGroupSession` we will consider the session (i.e. - // "room key"), and by extension any events that are encrypted - // using this session, trusted if either: - // - // a) This is our own device, or - // b) The device itself is considered to be trusted. - if device.is_owner_of_session(session)? - && (device.is_our_own_device() || device.is_verified()) - { - (VerificationState::Trusted, Some(device.device_id().to_owned())) + let claimed_device = self + .get_user_devices(sender, None) + .await? 
+ .devices() + .find(|d| d.curve25519_key() == Some(session.sender_key())); + + Ok(match claimed_device { + None => { + // We didn't find a device, no way to know if we should trust the + // `InboundGroupSession` or not. + + let link_problem = if session.has_been_imported() { + DeviceLinkProblem::InsecureSource } else { - (VerificationState::Untrusted, Some(device.device_id().to_owned())) + DeviceLinkProblem::MissingDevice + }; + + (VerificationState::Unverified(VerificationLevel::None(link_problem)), None) + } + Some(device) => { + let device_id = device.device_id().to_owned(); + + // We found a matching device, let's check if it owns the session. + if !(device.is_owner_of_session(session)?) { + // The key cannot be linked to an owning device. + ( + VerificationState::Unverified(VerificationLevel::None( + DeviceLinkProblem::InsecureSource, + )), + Some(device_id), + ) + } else { + // We only consider cross trust and not local trust. If your own device is not + // signed and send a message, it will be seen as Unverified. + if device.is_cross_signed_by_owner() { + // The device is cross signed by this owner Meaning that the user did self + // verify it properly. Let's check if we trust the identity. + if device.is_device_owner_verified() { + (VerificationState::Verified, Some(device_id)) + } else { + ( + VerificationState::Unverified( + VerificationLevel::UnverifiedIdentity, + ), + Some(device_id), + ) + } + } else { + // The device owner hasn't self-verified its device. + ( + VerificationState::Unverified(VerificationLevel::UnsignedDevice), + Some(device_id), + ) + } } - } else { - // We didn't find a device, no way to know if we should trust - // the `InboundGroupSession` or not. - (VerificationState::UnknownDevice, None) - }, - ) + } + }) } /// Get some metadata pertaining to a given group session. 
@@ -1149,7 +1225,11 @@ impl OlmMachine { let (decrypted_event, _) = session.decrypt(event).await?; let encryption_info = self.get_encryption_info(&session, &event.sender).await?; - Ok(TimelineEvent { encryption_info: Some(encryption_info), event: decrypted_event }) + Ok(TimelineEvent { + event: decrypted_event, + encryption_info: Some(encryption_info), + push_actions: vec![], + }) } else { Err(MegolmError::MissingRoomKey) } @@ -1189,9 +1269,12 @@ impl OlmMachine { let result = self.decrypt_megolm_events(room_id, &event, &content).await; if let Err(e) = &result { + #[cfg(feature = "automatic-room-key-forwarding")] match e { MegolmError::MissingRoomKey - | MegolmError::Decryption(DecryptionError::UnknownMessageIndex(_, _)) => { + | MegolmError::Decryption( + vodozemac::megolm::DecryptionError::UnknownMessageIndex(_, _), + ) => { self.key_request_machine.create_outgoing_key_request(room_id, &event).await?; } _ => {} @@ -1230,9 +1313,7 @@ impl OlmMachine { async fn wait_if_user_pending(&self, user_id: &UserId, timeout: Option) { if let Some(timeout) = timeout { - let listener = self.identity_manager.listen_for_received_queries(); - - let _ = listener.wait_if_user_pending(timeout, user_id).await; + self.store.wait_if_user_key_query_pending(timeout, user_id).await; } } @@ -1587,11 +1668,16 @@ pub(crate) mod tests { use std::{collections::BTreeMap, iter, sync::Arc}; use assert_matches::assert_matches; + use matrix_sdk_common::deserialized_responses::{ + DeviceLinkProblem, ShieldState, VerificationLevel, VerificationState, + }; use matrix_sdk_test::{async_test, test_json}; use ruma::{ api::{ client::{ - keys::{claim_keys, get_keys, upload_keys}, + keys::{ + claim_keys, get_keys, get_keys::v3::Response as KeyQueryResponse, upload_keys, + }, sync::sync_events::DeviceLists, to_device::send_event_to_device::v3::Response as ToDeviceResponse, }, @@ -1609,29 +1695,30 @@ pub(crate) mod tests { room_id, serde::Raw, uint, user_id, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, 
MilliSecondsSinceUnixEpoch, - OwnedDeviceKeyId, UserId, + OwnedDeviceKeyId, TransactionId, UserId, }; use serde_json::json; use vodozemac::{ megolm::{GroupSession, SessionConfig}, - Ed25519PublicKey, + Curve25519PublicKey, Ed25519PublicKey, }; use super::testing::response_from_file; use crate::{ error::EventError, machine::OlmMachine, - olm::VerifyJson, + olm::{InboundGroupSession, OutboundGroupSession, VerifyJson}, types::{ events::{ room::encrypted::{EncryptedToDeviceEvent, ToDeviceEncryptedEventContent}, ToDeviceEvent, }, - DeviceKeys, SignedKey, SigningKeys, + CrossSigningKey, DeviceKeys, EventEncryptionAlgorithm, SignedKey, SigningKeys, }, utilities::json_convert, verification::tests::{outgoing_request_to_event, request_to_event}, - EncryptionSettings, MegolmError, OlmError, ReadOnlyDevice, ToDeviceRequest, + EncryptionSettings, LocalTrust, MegolmError, OlmError, ReadOnlyDevice, ToDeviceRequest, + UserIdentities, }; /// These keys need to be periodically uploaded to the server. @@ -1645,6 +1732,10 @@ pub(crate) mod tests { device_id!("JLAFKJWSCS") } + fn bob_device_id() -> &'static DeviceId { + device_id!("NTHHPZDPRN") + } + fn user_id() -> &'static UserId { user_id!("@bob:example.com") } @@ -1679,7 +1770,7 @@ pub(crate) mod tests { } pub(crate) async fn get_prepared_machine() -> (OlmMachine, OneTimeKeys) { - let machine = OlmMachine::new(user_id(), alice_device_id()).await; + let machine = OlmMachine::new(user_id(), bob_device_id()).await; machine.account.inner.update_uploaded_key_count(0); let request = machine.keys_for_upload().await.expect("Can't prepare initial key upload"); let response = keys_upload_response(); @@ -1691,8 +1782,9 @@ pub(crate) mod tests { async fn get_machine_after_query() -> (OlmMachine, OneTimeKeys) { let (machine, otk) = get_prepared_machine().await; let response = keys_query_response(); + let req_id = TransactionId::new(); - machine.receive_keys_query_response(&response).await.unwrap(); + machine.receive_keys_query_response(&req_id, 
&response).await.unwrap(); (machine, otk) } @@ -1757,6 +1849,14 @@ pub(crate) mod tests { async fn create_olm_machine() { let machine = OlmMachine::new(user_id(), alice_device_id()).await; assert!(!machine.account().shared()); + + let own_device = machine + .get_device(machine.user_id(), machine.device_id(), None) + .await + .unwrap() + .expect("We should always have our own device in the store"); + + assert!(own_device.is_locally_trusted(), "Our own device should always be locally trusted"); } #[async_test] @@ -1904,7 +2004,8 @@ pub(crate) mod tests { let alice_devices = machine.store.get_user_devices(alice_id).await.unwrap(); assert!(alice_devices.devices().peekable().peek().is_none()); - machine.receive_keys_query_response(&response).await.unwrap(); + let req_id = TransactionId::new(); + machine.receive_keys_query_response(&req_id, &response).await.unwrap(); let device = machine.store.get_device(alice_id, alice_device_id).await.unwrap().unwrap(); assert_eq!(device.user_id(), alice_id); @@ -2085,6 +2186,400 @@ pub(crate) mod tests { } #[async_test] + async fn test_decryption_verification_state() { + macro_rules! assert_shield { + ($foo: ident, $strict: ident, $lax: ident) => { + let lax = $foo.verification_state.to_shield_state_lax(); + let strict = $foo.verification_state.to_shield_state_strict(); + + assert_matches!(lax, ShieldState::$lax { .. }); + assert_matches!(strict, ShieldState::$strict { .. 
}); + }; + } + let (alice, bob) = get_machine_pair_with_setup_sessions().await; + let room_id = room_id!("!test:example.org"); + + let to_device_requests = alice + .share_room_key(room_id, iter::once(bob.user_id()), EncryptionSettings::default()) + .await + .unwrap(); + + let event = ToDeviceEvent::new( + alice.user_id().to_owned(), + to_device_requests_to_content(to_device_requests), + ); + + let group_session = + bob.decrypt_to_device_event(&event).await.unwrap().inbound_group_session; + + let export = group_session.as_ref().unwrap().clone().export().await; + + bob.store.save_inbound_group_sessions(&[group_session.unwrap()]).await.unwrap(); + + let plaintext = "It is a secret to everybody"; + + let content = RoomMessageEventContent::text_plain(plaintext); + + let encrypted_content = alice + .encrypt_room_event(room_id, AnyMessageLikeEventContent::RoomMessage(content.clone())) + .await + .unwrap(); + + let event = json!({ + "event_id": "$xxxxx:example.org", + "origin_server_ts": MilliSecondsSinceUnixEpoch::now(), + "sender": alice.user_id(), + "type": "m.room.encrypted", + "content": encrypted_content, + }); + + let event = json_convert(&event).unwrap(); + + let encryption_info = + bob.decrypt_room_event(&event, room_id).await.unwrap().encryption_info.unwrap(); + + assert_eq!( + VerificationState::Unverified(VerificationLevel::UnsignedDevice), + encryption_info.verification_state + ); + + assert_shield!(encryption_info, Red, Red); + + // Local trust state has no effect + bob.get_device(alice.user_id(), alice_device_id(), None) + .await + .unwrap() + .unwrap() + .set_trust_state(LocalTrust::Verified); + + let encryption_info = + bob.decrypt_room_event(&event, room_id).await.unwrap().encryption_info.unwrap(); + + assert_eq!( + VerificationState::Unverified(VerificationLevel::UnsignedDevice), + encryption_info.verification_state + ); + assert_shield!(encryption_info, Red, Red); + + setup_cross_signing_for_machine(&alice, &bob).await; + let bob_id_from_alice = 
alice.get_identity(bob.user_id(), None).await.unwrap(); + assert_matches!(bob_id_from_alice, Some(UserIdentities::Other(_))); + let alice_id_from_bob = bob.get_identity(alice.user_id(), None).await.unwrap(); + assert_matches!(alice_id_from_bob, Some(UserIdentities::Other(_))); + + // we setup cross signing but nothing is signed yet + let encryption_info = + bob.decrypt_room_event(&event, room_id).await.unwrap().encryption_info.unwrap(); + + assert_eq!( + VerificationState::Unverified(VerificationLevel::UnsignedDevice), + encryption_info.verification_state + ); + assert_shield!(encryption_info, Red, Red); + + // Let alice sign her device + sign_alice_device_for_machine(&alice, &bob).await; + + let encryption_info = + bob.decrypt_room_event(&event, room_id).await.unwrap().encryption_info.unwrap(); + + assert_eq!( + VerificationState::Unverified(VerificationLevel::UnverifiedIdentity), + encryption_info.verification_state + ); + + assert_shield!(encryption_info, Red, None); + + mark_alice_identity_as_verified(&alice, &bob).await; + let encryption_info = + bob.decrypt_room_event(&event, room_id).await.unwrap().encryption_info.unwrap(); + assert_eq!(VerificationState::Verified, encryption_info.verification_state); + assert_shield!(encryption_info, None, None); + + // Simulate an imported session, to change verification state + let imported = InboundGroupSession::from_export(&export).unwrap(); + bob.store.save_inbound_group_sessions(&[imported]).await.unwrap(); + + let encryption_info = + bob.decrypt_room_event(&event, room_id).await.unwrap().encryption_info.unwrap(); + + // As soon as the key source is unsafe the verification state (or existence) of + // the device is meaningless + assert_eq!( + VerificationState::Unverified(VerificationLevel::None( + DeviceLinkProblem::InsecureSource + )), + encryption_info.verification_state + ); + + assert_shield!(encryption_info, Red, Grey); + } + + async fn setup_cross_signing_for_machine(alice: &OlmMachine, bob: &OlmMachine) { + 
let (alice_upload_signing, _) = + alice.bootstrap_cross_signing(false).await.expect("Expect Alice x-signing key request"); + + let (bob_upload_signing, _) = + bob.bootstrap_cross_signing(false).await.expect("Expect Bob x-signing key request"); + + let bob_device_keys = bob + .get_device(bob.user_id(), bob.device_id(), None) + .await + .unwrap() + .unwrap() + .as_device_keys() + .to_owned(); + + let alice_device_keys = alice + .get_device(alice.user_id(), alice.device_id(), None) + .await + .unwrap() + .unwrap() + .as_device_keys() + .to_owned(); + + // We only want to setup cross signing we don't actually sign the current + // devices. so we ignore the new device signatures + let json = json!({ + "device_keys": { + bob.user_id() : { bob.device_id() : bob_device_keys}, + alice.user_id() : { alice.device_id(): alice_device_keys } + }, + "failures": {}, + "master_keys": { + bob.user_id() : bob_upload_signing.master_key.unwrap(), + alice.user_id() : alice_upload_signing.master_key.unwrap() + }, + "user_signing_keys": { + bob.user_id() : bob_upload_signing.user_signing_key.unwrap(), + alice.user_id() : alice_upload_signing.user_signing_key.unwrap() + }, + "self_signing_keys": { + bob.user_id() : bob_upload_signing.self_signing_key.unwrap(), + alice.user_id() : alice_upload_signing.self_signing_key.unwrap() + }, + } + ); + + let kq_response = KeyQueryResponse::try_from_http_response(response_from_file(&json)) + .expect("Can't parse the keys upload response"); + + alice.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap(); + bob.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap(); + } + + async fn sign_alice_device_for_machine(alice: &OlmMachine, bob: &OlmMachine) { + let (upload_signing, upload_signature) = + alice.bootstrap_cross_signing(false).await.expect("Expect Alice x-signing key request"); + + let mut device_keys = alice + .get_device(alice.user_id(), alice.device_id(), None) + .await + .unwrap() + .unwrap() + 
.as_device_keys() + .to_owned(); + + let raw_extracted = upload_signature + .signed_keys + .get(alice.user_id()) + .unwrap() + .iter() + .next() + .unwrap() + .1 + .get(); + + let new_signature: DeviceKeys = serde_json::from_str(raw_extracted).unwrap(); + + let self_sign_key_id = upload_signing + .self_signing_key + .as_ref() + .unwrap() + .get_first_key_and_id() + .unwrap() + .0 + .to_owned(); + + device_keys.signatures.add_signature( + alice.user_id().to_owned(), + self_sign_key_id.to_owned(), + new_signature.signatures.get_signature(alice.user_id(), &self_sign_key_id).unwrap(), + ); + + let updated_keys_with_x_signing = json!({ device_keys.device_id.to_string(): device_keys }); + + let json = json!({ + "device_keys": { + alice.user_id() : updated_keys_with_x_signing + }, + "failures": {}, + "master_keys": { + alice.user_id() : upload_signing.master_key.unwrap(), + }, + "user_signing_keys": { + alice.user_id() : upload_signing.user_signing_key.unwrap(), + }, + "self_signing_keys": { + alice.user_id() : upload_signing.self_signing_key.unwrap(), + }, + } + ); + + let kq_response = KeyQueryResponse::try_from_http_response(response_from_file(&json)) + .expect("Can't parse the keys upload response"); + + alice.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap(); + bob.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap(); + } + + async fn mark_alice_identity_as_verified(alice: &OlmMachine, bob: &OlmMachine) { + let alice_device = + bob.get_device(alice.user_id(), alice.device_id(), None).await.unwrap().unwrap(); + + let alice_identity = + bob.get_identity(alice.user_id(), None).await.unwrap().unwrap().other().unwrap(); + let upload_request = alice_identity.verify().await.unwrap(); + + let raw_extracted = + upload_request.signed_keys.get(alice.user_id()).unwrap().iter().next().unwrap().1.get(); + + let new_signature: CrossSigningKey = serde_json::from_str(raw_extracted).unwrap(); + + let user_key_id = bob + 
.bootstrap_cross_signing(false) + .await + .expect("Expect Alice x-signing key request") + .0 + .user_signing_key + .unwrap() + .get_first_key_and_id() + .unwrap() + .0 + .to_owned(); + + // add the new signature to alice msk + let mut alice_updated_msk = + alice_device.device_owner_identity.as_ref().unwrap().master_key().as_ref().to_owned(); + + alice_updated_msk.signatures.add_signature( + bob.user_id().to_owned(), + user_key_id.to_owned(), + new_signature.signatures.get_signature(bob.user_id(), &user_key_id).unwrap(), + ); + + let alice_x_keys = alice + .bootstrap_cross_signing(false) + .await + .expect("Expect Alice x-signing key request") + .0; + + let json = json!({ + "device_keys": { + alice.user_id() : { alice.device_id(): alice_device.as_device_keys().to_owned() } + }, + "failures": {}, + "master_keys": { + alice.user_id() : alice_updated_msk, + }, + "user_signing_keys": { + alice.user_id() : alice_x_keys.user_signing_key.unwrap(), + }, + "self_signing_keys": { + alice.user_id() : alice_x_keys.self_signing_key.unwrap(), + }, + } + ); + + let kq_response = KeyQueryResponse::try_from_http_response(response_from_file(&json)) + .expect("Can't parse the keys upload response"); + + alice.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap(); + bob.receive_keys_query_response(&TransactionId::new(), &kq_response).await.unwrap(); + + // so alice identity should be now trusted + + assert!(bob + .get_identity(alice.user_id(), None) + .await + .unwrap() + .unwrap() + .other() + .unwrap() + .is_verified()); + } + + #[async_test] + async fn test_verication_states_multiple_device() { + let (bob, _) = get_prepared_machine().await; + + let other_user_id = user_id!("@web2:localhost:8482"); + + let data = response_from_file(&test_json::KEYS_QUERY_TWO_DEVICES_ONE_SIGNED); + let response = get_keys::v3::Response::try_from_http_response(data) + .expect("Can't parse the keys upload response"); + + let (device_change, identity_change) = + 
bob.receive_keys_query_response(&TransactionId::new(), &response).await.unwrap(); + assert_eq!(device_change.new.len(), 2); + assert_eq!(identity_change.new.len(), 1); + // + let devices = bob.store.get_user_devices(other_user_id).await.unwrap(); + assert_eq!(devices.devices().count(), 2); + + let fake_room_id = room_id!("!roomid:example.com"); + + // We just need a fake session to export it + // We will use the export to create various inbounds with other claimed + // ownership + let id_keys = bob.identity_keys(); + let fake_device_id = bob.device_id.clone(); + let olm = OutboundGroupSession::new( + fake_device_id, + Arc::new(id_keys), + fake_room_id, + EncryptionSettings::default(), + ) + .unwrap() + .session_key() + .await; + + let web_unverified_inbound_session = InboundGroupSession::new( + Curve25519PublicKey::from_base64("LTpv2DGMhggPAXO02+7f68CNEp6A40F0Yl8B094Y8gc") + .unwrap(), + Ed25519PublicKey::from_base64("loz5i40dP+azDtWvsD0L/xpnCjNkmrcvtXVXzCHX8Vw").unwrap(), + fake_room_id, + &olm, + EventEncryptionAlgorithm::MegolmV1AesSha2, + None, + ) + .unwrap(); + + let (state, _) = bob + .get_verification_state(&web_unverified_inbound_session, other_user_id) + .await + .unwrap(); + assert_eq!(VerificationState::Unverified(VerificationLevel::UnsignedDevice), state); + + let web_signed_inbound_session = InboundGroupSession::new( + Curve25519PublicKey::from_base64("XJixbpnfIk+RqcK5T6moqVY9d9Q1veR8WjjSlNiQNT0") + .unwrap(), + Ed25519PublicKey::from_base64("48f3WQAMGwYLBg5M5qUhqnEVA8yeibjZpPsShoWMFT8").unwrap(), + fake_room_id, + &olm, + EventEncryptionAlgorithm::MegolmV1AesSha2, + None, + ) + .unwrap(); + + let (state, _) = + bob.get_verification_state(&web_signed_inbound_session, other_user_id).await.unwrap(); + + assert_eq!(VerificationState::Unverified(VerificationLevel::UnverifiedIdentity), state); + } + + #[async_test] + #[cfg(feature = "automatic-room-key-forwarding")] async fn test_query_ratcheted_key() { let (alice, bob) = 
get_machine_pair_with_setup_sessions().await; let room_id = room_id!("!test:example.org"); @@ -2139,7 +2634,7 @@ pub(crate) mod tests { let decrypt_error = bob.decrypt_room_event(&room_event, room_id).await.unwrap_err(); - if let MegolmError::Decryption(vodo_error) = decrypt_error { + if let crate::MegolmError::Decryption(vodo_error) = decrypt_error { if let vodozemac::megolm::DecryptionError::UnknownMessageIndex(_, _) = vodo_error { // check that key has been requested let outgoing_to_devices = @@ -2471,7 +2966,7 @@ pub(crate) mod tests { .unwrap() .into(); let signing_keys = SigningKeys::from([(DeviceKeyAlgorithm::Ed25519, fake_key)]); - inbound.signing_keys = signing_keys.into(); + inbound.creator_info.signing_keys = signing_keys.into(); let content = json!({}); let content = outbound.encrypt(content, "m.dummy").await; diff --git a/crates/matrix-sdk-crypto/src/olm/account.rs b/crates/matrix-sdk-crypto/src/olm/account.rs index 56d6e3399ae..07955615fde 100644 --- a/crates/matrix-sdk-crypto/src/olm/account.rs +++ b/crates/matrix-sdk-crypto/src/olm/account.rs @@ -36,7 +36,7 @@ use ruma::{ use serde::{Deserialize, Serialize}; use serde_json::{value::RawValue as RawJsonValue, Value}; use sha2::{Digest, Sha256}; -use tracing::{debug, info, trace, warn}; +use tracing::{debug, info, instrument, trace, warn, Span}; use vodozemac::{ olm::{ Account as InnerAccount, AccountPickle, IdentityKeys, OlmMessage, PreKeyMessage, @@ -53,7 +53,7 @@ use super::{ use crate::types::events::room::encrypted::OlmV2Curve25519AesSha2Content; use crate::{ error::{EventError, OlmResult, SessionCreationError}, - identities::{MasterPubkey, ReadOnlyDevice}, + identities::ReadOnlyDevice, requests::UploadSigningKeysRequest, store::{Changes, Store}, types::{ @@ -64,20 +64,20 @@ use crate::{ ToDeviceEncryptedEventContent, }, }, - CrossSigningKey, DeviceKeys, EventEncryptionAlgorithm, OneTimeKey, SignedKey, + CrossSigningKey, DeviceKeys, EventEncryptionAlgorithm, MasterPubkey, OneTimeKey, SignedKey, }, 
utilities::encode, CryptoStoreError, OlmError, SignatureError, }; #[derive(Debug, Clone)] -pub struct Account { +pub(crate) struct Account { pub inner: ReadOnlyAccount, pub store: Store, } #[derive(Debug, Clone)] -pub enum SessionType { +pub(crate) enum SessionType { New(Session), Existing(Session), } @@ -286,6 +286,7 @@ impl Account { } /// Decrypt an Olm message, creating a new Olm session if possible. + #[instrument(skip(self, message), fields(session_id))] async fn decrypt_olm_message( &self, sender: &UserId, @@ -293,59 +294,72 @@ impl Account { message: &OlmMessage, ) -> OlmResult<(SessionType, DecryptionResult)> { // First try to decrypt using an existing session. - let (session, plaintext) = - if let Some(d) = self.decrypt_with_existing_sessions(sender_key, message).await? { - // Decryption succeeded, de-structure the session/plaintext out of - // the Option. - (SessionType::Existing(d.0), d.1) - } else { - // Decryption failed with every known session, let's try to create a - // new session. - match message { - // A new session can only be created using a pre-key message, - // return with an error if it isn't one. - OlmMessage::Normal(_) => { - warn!( - ?sender_key, - "Failed to decrypt a non-pre-key message with all \ - available sessions", - ); - - return Err(OlmError::SessionWedged(sender.to_owned(), sender_key)); - } + let (session, plaintext) = if let Some(d) = + self.decrypt_with_existing_sessions(sender_key, message).await? + { + // Decryption succeeded, de-structure the session/plaintext out of + // the Option. + (SessionType::Existing(d.0), d.1) + } else { + // Decryption failed with every known session, let's try to create a + // new session. + match message { + // A new session can only be created using a pre-key message, + // return with an error if it isn't one. + OlmMessage::Normal(_) => { + let session_ids = if let Some(sessions) = + self.store.get_sessions(&sender_key.to_base64()).await? 
+ { + sessions.lock().await.iter().map(|s| s.session_id().to_owned()).collect() + } else { + vec![] + }; - OlmMessage::PreKey(m) => { - // Create the new session. - let result = match self.inner.create_inbound_session(sender_key, m).await { - Ok(r) => r, - Err(e) => { - warn!( - ?sender_key, - session_keys = ?m.session_keys(), - "Failed to create a new Olm session from a \ - pre-key message: {e:?}", - ); - return Err(OlmError::SessionWedged(sender.to_owned(), sender_key)); - } - }; + warn!( + ?session_ids, + "Failed to decrypt a non-pre-key message with all available sessions", + ); - // We need to add the new session to the session cache, otherwise - // we might try to create the same session again. - // TODO separate the session cache from the storage so we only add - // it to the cache but don't store it. - let changes = Changes { - account: Some(self.inner.clone()), - sessions: vec![result.session.clone()], - ..Default::default() - }; - self.store.save_changes(changes).await?; + return Err(OlmError::SessionWedged(sender.to_owned(), sender_key)); + } - (SessionType::New(result.session), result.plaintext) - } + OlmMessage::PreKey(m) => { + // Create the new session. + let result = match self.inner.create_inbound_session(sender_key, m).await { + Ok(r) => r, + Err(e) => { + warn!( + session_keys = ?m.session_keys(), + "Failed to create a new Olm session from a pre-key message: {e:?}", + ); + + return Err(OlmError::SessionWedged(sender.to_owned(), sender_key)); + } + }; + + // We need to add the new session to the session cache, otherwise + // we might try to create the same session again. + // TODO: separate the session cache from the storage so we only add + // it to the cache but don't store it. 
+ let changes = Changes { + account: Some(self.inner.clone()), + sessions: vec![result.session.clone()], + ..Default::default() + }; + self.store.save_changes(changes).await?; + + (SessionType::New(result.session), result.plaintext) } - }; + } + }; - trace!(?sender_key, "Successfully decrypted an Olm message"); + let session_id = match &session { + SessionType::New(s) => s.session_id(), + SessionType::Existing(s) => s.session_id(), + }; + + Span::current().record("session_id", session_id); + trace!("Successfully decrypted an Olm message"); match self.parse_decrypted_to_device_event(sender, sender_key, plaintext).await { Ok(result) => Ok((session, result)), @@ -368,7 +382,6 @@ impl Account { } warn!( - sender_key = sender_key.to_base64(), error = ?e, "A to-device message was successfully decrypted but \ parsing and checking the event fields failed" @@ -1045,22 +1058,28 @@ impl ReadOnlyAccount { /// /// * `message` - A pre-key Olm message that was sent to us by the other /// account. + #[instrument( + skip_all, + fields( + sender_key = ?their_identity_key, + session_id = message.session_id(), + session_keys = ?message.session_keys(), + ) + )] pub async fn create_inbound_session( &self, their_identity_key: Curve25519PublicKey, message: &PreKeyMessage, ) -> Result { - debug!( - sender_key = ?their_identity_key, - session_keys = ?message.session_keys(), - "Creating a new Olm session from a pre-key message" - ); + debug!("Creating a new Olm session from a pre-key message"); let result = self.inner.lock().await.create_inbound_session(their_identity_key, message)?; let now = SecondsSinceUnixEpoch::now(); let session_id = result.session.session_id(); + trace!(?session_id, "Olm session created successfully"); + let session = Session { user_id: self.user_id.clone(), device_id: self.device_id.clone(), diff --git a/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs b/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs index 5ba13c7c7ee..3e36f855a75 100644 --- 
a/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs +++ b/crates/matrix-sdk-crypto/src/olm/group_sessions/inbound.rs @@ -60,25 +60,91 @@ use crate::{ // sessions that were created between some time period, this should only be set // for non-imported sessions. -/// Inbound group session. +/// Information about the creator of an inbound group session. +#[derive(Clone)] +pub(crate) struct SessionCreatorInfo { + /// The Curve25519 identity key of the session creator. + /// + /// If the session was received directly from its creator device through an + /// `m.room_key` event (and therefore, session sender == session creator), + /// this key equals the Curve25519 device identity key of that device. Since + /// this key is one of three keys used to establish the Olm session through + /// which encrypted to-device messages (including `m.room_key`) are sent, + /// this constitutes a proof that this inbound group session is owned by + /// that particular Curve25519 key. + /// + /// However, if the session was simply forwarded to us in an + /// `m.forwarded_room_key` event (in which case sender != creator), this key + /// is just a *claim* made by the session sender of what the actual creator + /// device is. + pub curve25519_key: Curve25519PublicKey, + + /// A mapping of DeviceKeyAlgorithm to the public signing keys of the + /// [`Device`] that sent us the session. + /// + /// If the session was received directly from the creator via an + /// `m.room_key` event, this map is taken from the plaintext value of + /// the decrypted Olm event, and is a copy of the + /// [`DecryptedOlmV1Event::keys`] field as defined in the [spec]. + /// + /// If the session was forwarded to us using an `m.forwarded_room_key`, this + /// map is a copy of the claimed Ed25519 key from the content of the + /// event. 
+ /// + /// [spec]: https://spec.matrix.org/unstable/client-server-api/#molmv1curve25519-aes-sha2 + pub signing_keys: Arc>, +} + +/// A structure representing an inbound group session. +/// +/// Inbound group sessions, also known as "room keys", are used to facilitate +/// the exchange of room messages among a group of participants. The inbound +/// variant of the group session is used to decrypt the room messages. /// -/// Inbound group sessions are used to exchange room messages between a group of -/// participants. Inbound group sessions are used to decrypt the room messages. +/// This struct wraps the [vodozemac] type of the same name, and adds additional +/// Matrix-specific data to it. Additionally, the wrapper ensures thread-safe +/// access of the vodozemac type. +/// +/// [vodozemac]: https://matrix-org.github.io/vodozemac/vodozemac/index.html #[derive(Clone)] pub struct InboundGroupSession { inner: Arc>, - history_visibility: Arc>, - /// The SessionId associated to this GroupSession - pub session_id: Arc, + + /// A copy of [`InnerSession::session_id`] to avoid having to acquire a lock + /// to get to the sesison ID. + session_id: Arc, + + /// A copy of [`InnerSession::first_known_index`] to avoid having to acquire + /// a lock to get to the first known index. first_known_index: u32, - /// The sender_key associated to this GroupSession - pub sender_key: Curve25519PublicKey, - /// Map of DeviceKeyAlgorithm to the public ed25519 key of the account - pub signing_keys: Arc>, + + /// Information about the creator of the [`InboundGroupSession`] ("room + /// key"). The trustworthiness of the information in this field depends + /// on how the session was received. + pub(crate) creator_info: SessionCreatorInfo, + /// The Room this GroupSession belongs to pub room_id: Arc, + + /// A flag recording whether the `InboundGroupSession` was received directly + /// as a `m.room_key` event or indirectly via a forward or file import. 
+ /// + /// If the session is considered to be imported, the information contained + /// in the `InboundGroupSession::creator_info` field is not proven to be + /// correct. imported: bool, + + /// The messaging algorithm of this [`InboundGroupSession`] as defined by + /// the [spec]. Will be one of the `m.megolm.*` algorithms. + /// + /// [spec]: https://spec.matrix.org/unstable/client-server-api/#messaging-algorithms algorithm: Arc, + + /// The history visibility of the room at the time when the room key was + /// created. + history_visibility: Arc>, + + /// Was this room key backed up to the server. backed_up: Arc, } @@ -89,10 +155,10 @@ impl InboundGroupSession { /// /// # Arguments /// - /// * `sender_key` - The public curve25519 key of the account that - /// sent us the session + /// * `sender_key` - The public Curve25519 key of the account that + /// sent us the session. /// - /// * `signing_key` - The public ed25519 key of the account that + /// * `signing_key` - The public Ed25519 key of the account that /// sent us the session. /// /// * `room_id` - The id of the room that the session is used in. 
@@ -121,8 +187,10 @@ impl InboundGroupSession { history_visibility: history_visibility.into(), session_id: session_id.into(), first_known_index, - sender_key, - signing_keys: keys.into(), + creator_info: SessionCreatorInfo { + curve25519_key: sender_key, + signing_keys: keys.into(), + }, room_id: room_id.into(), imported: false, algorithm: encryption_algorithm.into(), @@ -173,9 +241,9 @@ impl InboundGroupSession { PickledInboundGroupSession { pickle, - sender_key: self.sender_key, - signing_key: (*self.signing_keys).clone(), - room_id: (*self.room_id).to_owned(), + sender_key: self.creator_info.curve25519_key, + signing_key: (*self.creator_info.signing_keys).clone(), + room_id: self.room_id().to_owned(), imported: self.imported, backed_up: self.backed_up(), history_visibility: self.history_visibility.as_ref().clone(), @@ -193,7 +261,7 @@ impl InboundGroupSession { /// Get the sender key that this session was received from. pub fn sender_key(&self) -> Curve25519PublicKey { - self.sender_key + self.creator_info.curve25519_key } /// Has the session been backed up to the server. @@ -214,7 +282,7 @@ impl InboundGroupSession { /// Get the map of signing keys this session was received from. pub fn signing_keys(&self) -> &SigningKeys { - &self.signing_keys + &self.creator_info.signing_keys } /// Export this session at the given message index. 
@@ -226,11 +294,11 @@ impl InboundGroupSession { ExportedRoomKey { algorithm: self.algorithm().to_owned(), - room_id: (*self.room_id).to_owned(), - sender_key: self.sender_key, + room_id: self.room_id().to_owned(), + sender_key: self.creator_info.curve25519_key, session_id: self.session_id().to_owned(), forwarding_curve25519_key_chain: vec![], - sender_claimed_keys: (*self.signing_keys).clone(), + sender_claimed_keys: (*self.creator_info.signing_keys).clone(), session_key, } } @@ -254,10 +322,12 @@ impl InboundGroupSession { Ok(InboundGroupSession { inner: Mutex::new(session).into(), session_id: session_id.into(), - sender_key: pickle.sender_key, + creator_info: SessionCreatorInfo { + curve25519_key: pickle.sender_key, + signing_keys: pickle.signing_key.into(), + }, history_visibility: pickle.history_visibility.into(), first_known_index, - signing_keys: pickle.signing_key.into(), room_id: (*pickle.room_id).into(), backed_up: AtomicBool::from(pickle.backed_up).into(), algorithm: pickle.algorithm.into(), @@ -420,7 +490,7 @@ impl PartialEq for InboundGroupSession { pub struct PickledInboundGroupSession { /// The pickle string holding the InboundGroupSession. pub pickle: InboundGroupSessionPickle, - /// The public curve25519 key of the account that sent us the session + /// The public Curve25519 key of the account that sent us the session #[serde(deserialize_with = "deserialize_curve_key", serialize_with = "serialize_curve_key")] pub sender_key: Curve25519PublicKey, /// The public ed25519 key of the account that sent us the session. 
@@ -455,10 +525,12 @@ impl TryFrom<&ExportedRoomKey> for InboundGroupSession { Ok(InboundGroupSession { inner: Mutex::new(session).into(), session_id: key.session_id.to_owned().into(), - sender_key: key.sender_key, + creator_info: SessionCreatorInfo { + curve25519_key: key.sender_key, + signing_keys: key.sender_claimed_keys.to_owned().into(), + }, history_visibility: None.into(), first_known_index, - signing_keys: key.sender_claimed_keys.to_owned().into(), room_id: key.room_id.to_owned().into(), imported: true, algorithm: key.algorithm.to_owned().into(), @@ -476,14 +548,16 @@ impl From<&ForwardedMegolmV1AesSha2Content> for InboundGroupSession { InboundGroupSession { inner: Mutex::new(session).into(), session_id, - sender_key: value.claimed_sender_key, + creator_info: SessionCreatorInfo { + curve25519_key: value.claimed_sender_key, + signing_keys: SigningKeys::from([( + DeviceKeyAlgorithm::Ed25519, + value.claimed_ed25519_key.into(), + )]) + .into(), + }, history_visibility: None.into(), first_known_index, - signing_keys: SigningKeys::from([( - DeviceKeyAlgorithm::Ed25519, - value.claimed_ed25519_key.into(), - )]) - .into(), room_id: value.room_id.to_owned().into(), imported: true, algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2.into(), @@ -501,10 +575,12 @@ impl From<&ForwardedMegolmV2AesSha2Content> for InboundGroupSession { InboundGroupSession { inner: Mutex::new(session).into(), session_id, - sender_key: value.claimed_sender_key, + creator_info: SessionCreatorInfo { + curve25519_key: value.claimed_sender_key, + signing_keys: value.claimed_signing_keys.to_owned().into(), + }, history_visibility: None.into(), first_known_index, - signing_keys: value.claimed_signing_keys.to_owned().into(), room_id: value.room_id.to_owned().into(), imported: true, algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2.into(), @@ -604,7 +680,7 @@ mod test { assert_eq!(inbound.compare(&inbound).await, SessionOrdering::Equal); assert_eq!(inbound.compare(©).await, 
SessionOrdering::Equal); - copy.sender_key = + copy.creator_info.curve25519_key = Curve25519PublicKey::from_base64("XbmrPa1kMwmdtNYng1B2gsfoo8UtF+NklzsTZiaVKyY") .unwrap(); diff --git a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs index 50f3c8a7cb2..300ce488f48 100644 --- a/crates/matrix-sdk-crypto/src/olm/signing/mod.rs +++ b/crates/matrix-sdk-crypto/src/olm/signing/mod.rs @@ -28,15 +28,13 @@ use ruma::{ DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedUserId, UserId, }; use serde::{Deserialize, Serialize}; -use serde_json::Error as JsonError; use vodozemac::Ed25519Signature; use crate::{ error::SignatureError, - identities::{MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, requests::UploadSigningKeysRequest, store::SecretImportError, - types::DeviceKeys, + types::{DeviceKeys, MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, OwnUserIdentity, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentity, }; @@ -610,7 +608,7 @@ impl PrivateCrossSigningIdentity { /// # Panics /// /// This will panic if the provided pickle key isn't 32 bytes long. - pub async fn pickle(&self) -> Result { + pub async fn pickle(&self) -> PickledCrossSigningIdentity { let master_key = self.master_key.lock().await.as_ref().map(|m| m.pickle()); let self_signing_key = self.self_signing_key.lock().await.as_ref().map(|m| m.pickle()); @@ -619,11 +617,11 @@ impl PrivateCrossSigningIdentity { let keys = PickledSignings { master_key, user_signing_key, self_signing_key }; - Ok(PickledCrossSigningIdentity { + PickledCrossSigningIdentity { user_id: self.user_id.as_ref().to_owned(), shared: self.shared(), keys, - }) + } } /// Restore the private cross signing identity from a pickle. 
@@ -736,7 +734,7 @@ mod tests { async fn identity_pickling() { let identity = PrivateCrossSigningIdentity::new(user_id().to_owned()).await; - let pickled = identity.pickle().await.unwrap(); + let pickled = identity.pickle().await; let unpickled = PrivateCrossSigningIdentity::from_pickle(pickled).await.unwrap(); diff --git a/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs b/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs index f3de22427b6..ed9ff1b3b9b 100644 --- a/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs +++ b/crates/matrix-sdk-crypto/src/olm/signing/pk_signing.rs @@ -20,9 +20,11 @@ use vodozemac::{Ed25519PublicKey, Ed25519SecretKey, Ed25519Signature, KeyError}; use crate::{ error::SignatureError, - identities::{MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, olm::utility::SignJson, - types::{CrossSigningKey, DeviceKeys, Signatures, SigningKeys}, + types::{ + CrossSigningKey, DeviceKeys, MasterPubkey, SelfSigningPubkey, Signatures, SigningKeys, + UserSigningPubkey, + }, utilities::{encode, DecodeError}, ReadOnlyUserIdentity, }; diff --git a/crates/matrix-sdk-crypto/src/requests.rs b/crates/matrix-sdk-crypto/src/requests.rs index 5cec2b344e4..068cd58744e 100644 --- a/crates/matrix-sdk-crypto/src/requests.rs +++ b/crates/matrix-sdk-crypto/src/requests.rs @@ -192,7 +192,9 @@ pub struct KeysQueryRequest { } impl KeysQueryRequest { - pub(crate) fn new(device_keys: BTreeMap>) -> Self { + pub(crate) fn new(users: impl Iterator) -> Self { + let device_keys = users.map(|u| (u, Vec::new())).collect(); + Self { timeout: None, device_keys, token: None } } } diff --git a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs index af8783a67d4..9d177b6bda1 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/group_sessions.rs @@ -95,6 +95,7 @@ impl GroupSessionCache { /// /// This is the same as 
[get_or_load()](#method.get_or_load) but it will /// filter out the session if it doesn't match the given session id. + #[cfg(feature = "automatic-room-key-forwarding")] pub async fn get_with_id( &self, room_id: &RoomId, diff --git a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs index df0e89dca01..7dc021fb20a 100644 --- a/crates/matrix-sdk-crypto/src/session_manager/sessions.rs +++ b/crates/matrix-sdk-crypto/src/session_manager/sessions.rs @@ -34,10 +34,9 @@ use vodozemac::Curve25519PublicKey; use crate::{ error::OlmResult, gossiping::GossipMachine, - identities::{KeysQueryListener, UserKeyQueryResult}, olm::Account, requests::{OutgoingRequest, ToDeviceRequest}, - store::{Changes, Result as StoreResult, Store}, + store::{Changes, Result as StoreResult, Store, UserKeyQueryResult}, types::{events::EventType, EventEncryptionAlgorithm}, utilities::FailuresCache, ReadOnlyDevice, @@ -55,7 +54,6 @@ pub(crate) struct SessionManager { wedged_devices: Arc>>, key_request_machine: GossipMachine, outgoing_to_device_requests: Arc>, - keys_query_listener: KeysQueryListener, failures: FailuresCache, } @@ -69,7 +67,6 @@ impl SessionManager { users_for_key_claim: Arc>>, key_request_machine: GossipMachine, store: Store, - keys_query_listener: KeysQueryListener, ) -> Self { Self { account, @@ -78,7 +75,6 @@ impl SessionManager { users_for_key_claim, wedged_devices: Default::default(), outgoing_to_device_requests: Default::default(), - keys_query_listener, failures: Default::default(), } } @@ -176,11 +172,11 @@ impl SessionManager { let user_devices = if user_devices.is_empty() { match self - .keys_query_listener - .wait_if_user_pending(Self::KEYS_QUERY_WAIT_TIME, user_id) + .store + .wait_if_user_key_query_pending(Self::KEYS_QUERY_WAIT_TIME, user_id) .await { - Ok(WasPending) => self.store.get_readonly_devices_filtered(user_id).await?, + WasPending => self.store.get_readonly_devices_filtered(user_id).await?, _ => 
user_devices, } } else { @@ -398,24 +394,31 @@ impl SessionManager { #[cfg(test)] mod tests { - use std::{collections::BTreeMap, iter, sync::Arc}; + use std::{collections::BTreeMap, iter, ops::Deref, sync::Arc}; use dashmap::DashMap; use matrix_sdk_common::locks::Mutex; use matrix_sdk_test::{async_test, response_from_file}; use ruma::{ - api::{client::keys::claim_keys::v3::Response as KeyClaimResponse, IncomingResponse}, + api::{ + client::keys::{ + claim_keys::v3::Response as KeyClaimResponse, + get_keys::v3::Response as KeysQueryResponse, + }, + IncomingResponse, + }, device_id, user_id, DeviceId, UserId, }; use serde_json::json; + use tracing::info; use super::SessionManager; use crate::{ gossiping::GossipMachine, - identities::{KeysQueryListener, ReadOnlyDevice}, + identities::{IdentityManager, ReadOnlyDevice}, olm::{Account, PrivateCrossSigningIdentity, ReadOnlyAccount}, session_manager::GroupSessionCache, - store::{CryptoStore, MemoryStore, Store}, + store::{IntoCryptoStore, MemoryStore, Store}, verification::VerificationMachine, }; @@ -464,7 +467,7 @@ mod tests { let users_for_key_claim = Arc::new(DashMap::new()); let account = ReadOnlyAccount::new(user_id, device_id); - let store: Arc = Arc::new(MemoryStore::new()); + let store = MemoryStore::new().into_crypto_store(); store.save_account(account.clone()).await.unwrap(); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id))); let verification = @@ -487,13 +490,7 @@ mod tests { users_for_key_claim.clone(), ); - SessionManager::new( - account, - users_for_key_claim, - key_request, - store.clone(), - KeysQueryListener::new(store), - ) + SessionManager::new(account, users_for_key_claim, key_request, store) } #[async_test] @@ -528,6 +525,64 @@ mod tests { assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none()); } + #[async_test] + async fn session_creation_waits_for_keys_query() { + let manager = session_manager().await; + let identity_manager = 
IdentityManager::new( + manager.account.user_id.clone(), + manager.account.device_id.clone(), + manager.store.clone(), + ); + + // start a keys query request. At this point, we are only interested in our own + // devices. + let (key_query_txn_id, key_query_request) = + identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap(); + info!("Initial key query: {:?}", key_query_request); + + // now bob turns up, and we start tracking his devices... + let bob = bob_account(); + let bob_device = ReadOnlyDevice::from_account(&bob).await; + manager.store.update_tracked_users(iter::once(bob.user_id())).await.unwrap(); + + // ... and start off an attempt to get the missing sessions. This should block + // for now. + let missing_sessions_task = { + let manager = manager.clone(); + let bob_user_id = bob.user_id.clone(); + + #[allow(unknown_lints, clippy::redundant_async_block)] // false positive + tokio::spawn(async move { + manager.get_missing_sessions(iter::once(bob_user_id.deref())).await + }) + }; + + // the initial keys query completes, and we start another + let response_json = json!({ "device_keys": { manager.account.user_id(): {}}}); + let response = + KeysQueryResponse::try_from_http_response(response_from_file(&response_json)).unwrap(); + identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap(); + + let (key_query_txn_id, key_query_request) = + identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap(); + info!("Second key query: {:?}", key_query_request); + + // that second request completes with info on bob's device + let response_json = json!({ "device_keys": { bob.user_id(): { + bob_device.device_id(): bob_device.as_device_keys() + }}}); + let response = + KeysQueryResponse::try_from_http_response(response_from_file(&response_json)).unwrap(); + identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap(); + + // the missing_sessions_task should now finally complete, with a 
claim + // including bob's device + let (_, keys_claim_request) = missing_sessions_task.await.unwrap().unwrap().unwrap(); + info!("Key claim request: {:?}", keys_claim_request.one_time_keys); + let bob_key_claims = keys_claim_request.one_time_keys.get(bob.user_id()).unwrap(); + assert!(bob_key_claims.contains_key(bob_device.device_id())); + } + // This test doesn't run on macos because we're modifying the session // creation time so we can get around the UNWEDGING_INTERVAL. #[async_test] diff --git a/crates/matrix-sdk-crypto/src/store/caches.rs b/crates/matrix-sdk-crypto/src/store/caches.rs index 1b5fe4e0987..04602135ff2 100644 --- a/crates/matrix-sdk-crypto/src/store/caches.rs +++ b/crates/matrix-sdk-crypto/src/store/caches.rs @@ -17,11 +17,17 @@ //! Note: You'll only be interested in these if you are implementing a custom //! `CryptoStore`. -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + fmt::Display, + sync::{atomic::AtomicBool, Arc, Weak}, +}; +use atomic::Ordering; use dashmap::DashMap; use matrix_sdk_common::locks::Mutex; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use tracing::{field::display, instrument, trace, Span}; use crate::{ identities::ReadOnlyDevice, @@ -85,7 +91,7 @@ impl GroupSessionStore { /// already in the store. pub fn add(&self, session: InboundGroupSession) -> bool { self.entries - .entry((*session.room_id).to_owned()) + .entry(session.room_id().to_owned()) .or_default() .insert(session.session_id().to_owned(), session) .is_none() @@ -163,16 +169,223 @@ impl DeviceStore { } } +/// A numeric type that can represent an infinite ordered sequence. +/// +/// It uses wrapping arithmetic to make sure we never run out of numbers. (2**64 +/// should be enough for anyone, but it's easy enough just to make it wrap.) +// +/// Internally it uses a *signed* counter so that we can compare values via a +/// subtraction. 
For example, suppose we've just overflowed from i64::MAX to +/// i64::MIN. (i64::MAX.wrapping_sub(i64::MIN)) is -1, which tells us that +/// i64::MAX comes before i64::MIN in the sequence. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub(crate) struct SequenceNumber(i64); + +impl Display for SequenceNumber { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl PartialOrd for SequenceNumber { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.0.wrapping_sub(other.0).cmp(&0)) + } +} + +impl Ord for SequenceNumber { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.wrapping_sub(other.0).cmp(&0) + } +} + +impl SequenceNumber { + fn increment(&mut self) { + self.0 = self.0.wrapping_add(1) + } + + fn previous(&self) -> Self { + Self(self.0.wrapping_sub(1)) + } +} + +/// Information on a task which is waiting for a `/keys/query` to complete. +#[derive(Debug)] +pub(super) struct KeysQueryWaiter { + /// The user that we are waiting for + user: OwnedUserId, + + /// The sequence number of the last invalidation of the user's device list + /// when we started waiting (ie, any `/keys/query` result with the same or + /// greater sequence number will satisfy this waiter) + sequence_number: SequenceNumber, + + /// Whether the `/keys/query` has completed. + /// + /// This is only modified whilst holding the mutex on `users_for_key_query`. + pub(super) completed: AtomicBool, +} + +/// Record of the users that are waiting for a /keys/query. +/// +/// To avoid races, we maintain a sequence number which is updated each time we +/// receive an invalidation notification. We also record the sequence number at +/// which each user was last invalidated. Then, we attach the current sequence +/// number to each `/keys/query` request, and when we get the response we can +/// tell if any users have been invalidated more recently than that request. 
+#[derive(Debug)] +pub(super) struct UsersForKeyQuery { + /// The sequence number we will assign to the next addition to user_map + next_sequence_number: SequenceNumber, + + /// The users pending a lookup, together with the sequence number at which + /// they were added to the list + user_map: HashMap, + + /// A list of tasks waiting for key queries to complete. + /// + /// We expect this list to remain fairly short, so don't bother partitioning + /// by user. + tasks_awaiting_key_query: Vec>, +} + +impl UsersForKeyQuery { + /// Create a new, empty, `UsersForKeyQuery` + pub(super) fn new() -> Self { + UsersForKeyQuery { + next_sequence_number: Default::default(), + user_map: Default::default(), + tasks_awaiting_key_query: Default::default(), + } + } + + /// Record a new user that requires a key query + pub(super) fn insert_user(&mut self, user: &UserId) { + let sequence_number = self.next_sequence_number; + + trace!(?user, %sequence_number, "Flagging user for key query"); + + self.user_map.insert(user.to_owned(), sequence_number); + self.next_sequence_number.increment(); + } + + /// Record that a user has received an update with the given sequence + /// number. + /// + /// If the sequence number is newer than the oldest invalidation for this + /// user, it is removed from the list of those needing an update. + /// + /// Returns true if the user is now up-to-date, else false + #[instrument(level = "trace", skip(self), fields(invalidation_sequence))] + pub(super) fn maybe_remove_user( + &mut self, + user: &UserId, + query_sequence: SequenceNumber, + ) -> bool { + let last_invalidation = self.user_map.get(user).copied(); + + // If there were any jobs waiting for this key query to complete, we can flag + // them as completed and remove them from our list. We also clear out any tasks + // that have been cancelled. 
+ self.tasks_awaiting_key_query.retain(|waiter| { + let Some(waiter) = waiter.upgrade() else { + // the TaskAwaitingKeyQuery has been dropped, so it probably timed out and the + // caller went away. We can remove it from our list whether or not it's for this + // user. + trace!("removing expired waiting task"); + + return false; + }; + + if waiter.user == user && waiter.sequence_number <= query_sequence { + trace!( + ?user, + %query_sequence, + waiter_sequence = %waiter.sequence_number, + "Removing completed waiting task" + ); + + waiter.completed.store(true, Ordering::Relaxed); + + false + } else { + trace!( + ?user, + %query_sequence, + waiter_user = ?waiter.user, + waiter_sequence= %waiter.sequence_number, + "Retaining still-waiting task" + ); + + true + } + }); + + if let Some(last_invalidation) = last_invalidation { + Span::current().record("invalidation_sequence", display(last_invalidation)); + + if last_invalidation > query_sequence { + trace!("User invalidated since this query started: still not up-to-date"); + false + } else { + trace!("User now up-to-date"); + self.user_map.remove(user); + true + } + } else { + trace!("User already up-to-date, nothing to do"); + true + } + } + + /// Fetch the list of users waiting for a key query, and the current + /// sequence number + pub(super) fn users_for_key_query(&self) -> (HashSet, SequenceNumber) { + // we return the sequence number of the last invalidation + let sequence_number = self.next_sequence_number.previous(); + (self.user_map.keys().cloned().collect(), sequence_number) + } + + /// Check if a key query is pending for a user, and register for a wakeup if + /// so. + /// + /// If no key query is currently pending, returns `None`. Otherwise, returns + /// (an `Arc` to) a `KeysQueryWaiter`, whose `completed` flag will + /// be set once the lookup completes. 
+ pub(super) fn maybe_register_waiting_task( + &mut self, + user: &UserId, + ) -> Option> { + match self.user_map.get(user) { + None => None, + Some(&sequence_number) => { + trace!(?user, %sequence_number, "Registering new waiting task"); + + let waiter = Arc::new(KeysQueryWaiter { + sequence_number, + user: user.to_owned(), + completed: AtomicBool::new(false), + }); + + self.tasks_awaiting_key_query.push(Arc::downgrade(&waiter)); + + Some(waiter) + } + } + } +} + #[cfg(test)] mod tests { use matrix_sdk_test::async_test; + use proptest::prelude::*; use ruma::room_id; use vodozemac::{Curve25519PublicKey, Ed25519PublicKey}; + use super::{DeviceStore, GroupSessionStore, SequenceNumber, SessionStore}; use crate::{ identities::device::testing::get_device, olm::{tests::get_account_and_session, InboundGroupSession}, - store::caches::{DeviceStore, GroupSessionStore, SessionStore}, }; #[async_test] @@ -263,4 +476,34 @@ mod tests { let loaded_device = store.get(device.user_id(), device.device_id()); assert!(loaded_device.is_none()); } + + #[test] + fn sequence_at_boundary() { + let first = SequenceNumber(i64::MAX); + let second = SequenceNumber(first.0.wrapping_add(1)); + let third = SequenceNumber(first.0.wrapping_sub(1)); + + assert!(second > first); + assert!(first < second); + assert!(third < first); + assert!(first > third); + assert!(second > third); + assert!(third < second); + } + + proptest! 
{ + #[test] + fn partial_eq_sequence_number(sequence in i64::MIN..i64::MAX) { + let first = SequenceNumber(sequence); + let second = SequenceNumber(first.0.wrapping_add(1)); + let third = SequenceNumber(first.0.wrapping_sub(1)); + + assert!(second > first); + assert!(first < second); + assert!(third < first); + assert!(first > third); + assert!(second > third); + assert!(third < second); + } + } } diff --git a/crates/matrix-sdk-crypto/src/store/error.rs b/crates/matrix-sdk-crypto/src/store/error.rs new file mode 100644 index 00000000000..81f721b19e7 --- /dev/null +++ b/crates/matrix-sdk-crypto/src/store/error.rs @@ -0,0 +1,100 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{convert::Infallible, fmt::Debug, io::Error as IoError}; + +use ruma::{IdParseError, OwnedDeviceId, OwnedUserId}; +use serde_json::Error as SerdeError; +use thiserror::Error; + +use crate::olm::SessionCreationError; + +/// A `CryptoStore` specific result type. +pub type Result = std::result::Result; + +/// The crypto store's error type. +#[derive(Debug, Error)] +pub enum CryptoStoreError { + /// The account that owns the sessions, group sessions, and devices wasn't + /// found. + #[error("can't save/load sessions or group sessions in the store before an account is stored")] + AccountUnset, + + /// The store doesn't support multiple accounts and data from another device + /// was discovered. 
+ #[error( + "the account in the store doesn't match the account in the constructor: \ + expected {}:{}, got {}:{}", .expected.0, .expected.1, .got.0, .got.1 + )] + MismatchedAccount { + /// The expected user/device id pair. + expected: (OwnedUserId, OwnedDeviceId), + /// The user/device id pair that was loaded from the store. + got: (OwnedUserId, OwnedDeviceId), + }, + + /// An IO error occurred. + #[error(transparent)] + Io(#[from] IoError), + + /// Failed to decrypt an pickled object. + #[error("An object failed to be decrypted while unpickling")] + UnpicklingError, + + /// Failed to decrypt an pickled object. + #[error(transparent)] + Pickle(#[from] vodozemac::PickleError), + + /// The received room key couldn't be converted into a valid Megolm session. + #[error(transparent)] + SessionCreation(#[from] SessionCreationError), + + /// A Matrix identifier failed to be validated. + #[error(transparent)] + IdentifierValidation(#[from] IdParseError), + + /// The store failed to (de)serialize a data type. + #[error(transparent)] + Serialization(#[from] SerdeError), + + /// The database format has changed in a backwards incompatible way. + #[error( + "The database format changed in an incompatible way, current \ + version: {0}, latest version: {1}" + )] + UnsupportedDatabaseVersion(usize, usize), + + /// A problem with the underlying database backend + #[error(transparent)] + Backend(Box), +} + +impl CryptoStoreError { + /// Create a new [`Backend`][Self::Backend] error. + /// + /// Shorthand for `StoreError::Backend(Box::new(error))`. 
+ #[inline] + pub fn backend(error: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self::Backend(Box::new(error)) + } +} + +impl From for CryptoStoreError { + fn from(never: Infallible) -> Self { + match never {} + } +} diff --git a/crates/matrix-sdk-crypto/src/store/integration_tests.rs b/crates/matrix-sdk-crypto/src/store/integration_tests.rs index b80414ff204..bc700b08fab 100644 --- a/crates/matrix-sdk-crypto/src/store/integration_tests.rs +++ b/crates/matrix-sdk-crypto/src/store/integration_tests.rs @@ -17,10 +17,12 @@ macro_rules! cryptostore_integration_tests { }, store::{ Changes, CryptoStore, DeviceChanges, GossipRequest, IdentityChanges, - RecoveryKey, + RecoveryKey, RoomSettings, }, testing::{get_device, get_other_identity, get_own_identity}, - types::events::room_key_request::MegolmV1AesSha2Content, + types::{ + events::room_key_request::MegolmV1AesSha2Content, EventEncryptionAlgorithm, + }, ReadOnlyDevice, SecretInfo, TrackedUser, }; @@ -646,6 +648,56 @@ macro_rules! 
cryptostore_integration_tests { "The loaded version matches to the one we stored" ); } + + #[async_test] + async fn room_settings_saving() { + let (account, store) = get_loaded_store("room_settings_saving").await; + + let room_1 = room_id!("!test_1:localhost"); + let settings_1 = RoomSettings { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + only_allow_trusted_devices: true, + }; + + let room_2 = room_id!("!test_2:localhost"); + let settings_2 = RoomSettings { + algorithm: EventEncryptionAlgorithm::OlmV1Curve25519AesSha2, + only_allow_trusted_devices: false, + }; + + let room_3 = room_id!("!test_3:localhost"); + + let changes = Changes { + room_settings: HashMap::from([ + (room_1.into(), settings_1.clone()), + (room_2.into(), settings_2.clone()), + ]), + ..Default::default() + }; + + store.save_changes(changes).await.unwrap(); + + let loaded_settings_1 = store.get_room_settings(room_1).await.unwrap(); + assert_eq!(Some(settings_1), loaded_settings_1); + + let loaded_settings_2 = store.get_room_settings(room_2).await.unwrap(); + assert_eq!(Some(settings_2), loaded_settings_2); + + let loaded_settings_3 = store.get_room_settings(room_3).await.unwrap(); + assert_eq!(None, loaded_settings_3); + } + + #[async_test] + async fn custom_value_saving() { + let (account, store) = get_loaded_store("custom_value_saving").await; + store.set_custom_value("A", "Hello".as_bytes().to_vec()).await.unwrap(); + + let loaded_1 = store.get_custom_value("A").await.unwrap(); + assert_eq!(Some("Hello".as_bytes().to_vec()), loaded_1); + + let loaded_2 = store.get_custom_value("B").await.unwrap(); + assert_eq!(None, loaded_2); + } } }; } diff --git a/crates/matrix-sdk-crypto/src/store/memorystore.rs b/crates/matrix-sdk-crypto/src/store/memorystore.rs index e7d02611545..7d31264358e 100644 --- a/crates/matrix-sdk-crypto/src/store/memorystore.rs +++ b/crates/matrix-sdk-crypto/src/store/memorystore.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions 
and // limitations under the License. -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, convert::Infallible, sync::Arc}; use async_trait::async_trait; use dashmap::{DashMap, DashSet}; @@ -20,11 +20,12 @@ use matrix_sdk_common::locks::Mutex; use ruma::{ DeviceId, OwnedDeviceId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId, UserId, }; +use tracing::warn; use super::{ caches::{DeviceStore, GroupSessionStore, SessionStore}, - BackupKeys, Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, Result, RoomKeyCounts, - Session, + BackupKeys, Changes, CryptoStore, InboundGroupSession, ReadOnlyAccount, RoomKeyCounts, + RoomSettings, Session, }; use crate::{ gossiping::{GossipRequest, SecretInfo}, @@ -99,9 +100,13 @@ impl MemoryStore { } } +type Result = std::result::Result; + #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] #[cfg_attr(not(target_arch = "wasm32"), async_trait)] impl CryptoStore for MemoryStore { + type Error = Infallible; + async fn load_account(&self) -> Result> { Ok(None) } @@ -267,6 +272,21 @@ impl CryptoStore for MemoryStore { async fn load_backup_keys(&self) -> Result { Ok(BackupKeys::default()) } + + async fn get_room_settings(&self, _room_id: &RoomId) -> Result> { + warn!("Method not implemented"); + Ok(None) + } + + async fn get_custom_value(&self, _key: &str) -> Result>> { + warn!("Method not implemented"); + Ok(None) + } + + async fn set_custom_value(&self, _key: &str, _value: Vec) -> Result<()> { + warn!("Method not implemented"); + Ok(()) + } } #[cfg(test)] diff --git a/crates/matrix-sdk-crypto/src/store/mod.rs b/crates/matrix-sdk-crypto/src/store/mod.rs index 7a6aae2ab13..5b60fe226ed 100644 --- a/crates/matrix-sdk-crypto/src/store/mod.rs +++ b/crates/matrix-sdk-crypto/src/store/mod.rs @@ -38,33 +38,22 @@ //! [`OlmMachine`]: /matrix_sdk_crypto/struct.OlmMachine.html //! 
[`CryptoStore`]: trait.Cryptostore.html -pub mod caches; -mod memorystore; - -#[cfg(any(test, feature = "testing"))] -#[macro_use] -#[allow(missing_docs)] -pub mod integration_tests; - use std::{ collections::{HashMap, HashSet}, fmt::Debug, - io::Error as IoError, ops::Deref, sync::{atomic::AtomicBool, Arc}, + time::Duration, }; -use async_trait::async_trait; +use async_std::sync::{Condvar, Mutex as AsyncStdMutex}; use atomic::Ordering; use dashmap::DashSet; -use matrix_sdk_common::{locks::Mutex, AsyncTraitDeps}; -pub use memorystore::MemoryStore; +use matrix_sdk_common::locks::Mutex; use ruma::{ - events::secret::request::SecretName, DeviceId, IdParseError, OwnedDeviceId, OwnedUserId, - RoomId, TransactionId, UserId, + events::secret::request::SecretName, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, UserId, }; -use serde::{Deserialize, Serialize}; -use serde_json::Error as SerdeError; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use thiserror::Error; use tracing::{info, warn}; use vodozemac::{megolm::SessionOrdering, Curve25519PublicKey}; @@ -77,15 +66,29 @@ use crate::{ }, olm::{ InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity, - ReadOnlyAccount, Session, SessionCreationError, + ReadOnlyAccount, Session, }, + types::EventEncryptionAlgorithm, utilities::encode, verification::VerificationMachine, CrossSigningStatus, }; -/// A `CryptoStore` specific result type. 
-pub type Result = std::result::Result; +pub mod caches; +mod error; +mod memorystore; +mod traits; + +#[cfg(any(test, feature = "testing"))] +#[macro_use] +#[allow(missing_docs)] +pub mod integration_tests; + +use caches::{SequenceNumber, UsersForKeyQuery}; +pub use error::{CryptoStoreError, Result}; +use matrix_sdk_common::timeout::timeout; +pub use memorystore::MemoryStore; +pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore}; pub use crate::gossiping::{GossipRequest, SecretInfo}; @@ -99,10 +102,20 @@ pub use crate::gossiping::{GossipRequest, SecretInfo}; pub struct Store { user_id: Arc, identity: Arc>, - inner: Arc, + inner: Arc, verification_machine: VerificationMachine, tracked_users_cache: Arc>, - users_for_key_query_cache: Arc>, + + /// Record of the users that are waiting for a /keys/query. + // + // This uses an async_std::sync::Mutex rather than a + // matrix_sdk_common::locks::Mutex because it has to match the Condvar (and tokio lacks a + // working Condvar implementation) + users_for_key_query: Arc>, + + // condition variable that is notified each time an update is received for a user. + users_for_key_query_condvar: Arc, + tracked_user_loading_lock: Arc>, tracked_users_loaded: Arc, } @@ -121,6 +134,7 @@ pub struct Changes { pub key_requests: Vec, pub identities: IdentityChanges, pub devices: DeviceChanges, + pub room_settings: HashMap, } /// A user for which we are tracking the list of devices. 
@@ -199,6 +213,7 @@ impl RecoveryKey { } } +#[cfg(not(tarpaulin_include))] impl Debug for RecoveryKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RecoveryKey").finish() @@ -249,6 +264,7 @@ pub struct CrossSigningKeyExport { pub user_signing_key: Option, } +#[cfg(not(tarpaulin_include))] impl Debug for CrossSigningKeyExport { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("CrossSigningKeyExport") @@ -278,12 +294,42 @@ pub enum SecretImportError { Store(#[from] CryptoStoreError), } +/// Result type telling us if a `/keys/query` response was expected for a given +/// user. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum UserKeyQueryResult { + WasPending, + WasNotPending, + + /// A query was pending, but we gave up waiting + TimeoutExpired, +} + +/// Room encryption settings which are modified by state events or user options +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] +pub struct RoomSettings { + /// The encryption algorithm that should be used in the room. + pub algorithm: EventEncryptionAlgorithm, + /// Should untrusted devices receive the room key, or should they be + /// excluded from the conversation. 
+ pub only_allow_trusted_devices: bool, +} + +impl Default for RoomSettings { + fn default() -> Self { + Self { + algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2, + only_allow_trusted_devices: false, + } + } +} + impl Store { /// Create a new Store - pub fn new( + pub(crate) fn new( user_id: Arc, identity: Arc>, - store: Arc, + store: Arc, verification_machine: VerificationMachine, ) -> Self { Self { @@ -292,40 +338,41 @@ impl Store { inner: store, verification_machine, tracked_users_cache: DashSet::new().into(), - users_for_key_query_cache: DashSet::new().into(), + users_for_key_query: AsyncStdMutex::new(UsersForKeyQuery::new()).into(), + users_for_key_query_condvar: Condvar::new().into(), tracked_users_loaded: AtomicBool::new(false).into(), tracked_user_loading_lock: Mutex::new(()).into(), } } /// UserId associated with this store - pub fn user_id(&self) -> &UserId { + pub(crate) fn user_id(&self) -> &UserId { &self.user_id } /// DeviceId associated with this store - pub fn device_id(&self) -> &DeviceId { + pub(crate) fn device_id(&self) -> &DeviceId { self.verification_machine.own_device_id() } /// The Account associated with this store - pub fn account(&self) -> &ReadOnlyAccount { + pub(crate) fn account(&self) -> &ReadOnlyAccount { &self.verification_machine.store.account } #[cfg(test)] /// test helper to reset the cross signing identity - pub async fn reset_cross_signing_identity(&self) { + pub(crate) async fn reset_cross_signing_identity(&self) { self.identity.lock().await.reset().await; } /// PrivateCrossSigningIdentity associated with this store - pub fn private_identity(&self) -> Arc> { + pub(crate) fn private_identity(&self) -> Arc> { self.identity.clone() } /// Save the given Sessions to the store - pub async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { + pub(crate) async fn save_sessions(&self, sessions: &[Session]) -> Result<()> { let changes = Changes { sessions: sessions.to_vec(), ..Default::default() }; 
self.save_changes(changes).await @@ -337,7 +384,7 @@ impl Store { /// This method returns `SessionOrdering::Better` if the given session is /// better than the one we already have or if we don't have such a /// session in the store. - pub async fn compare_group_session( + pub(crate) async fn compare_group_session( &self, session: &InboundGroupSession, ) -> Result { @@ -353,7 +400,7 @@ impl Store { #[cfg(test)] /// Testing helper to allow to save only a set of devices - pub async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { + pub(crate) async fn save_devices(&self, devices: &[ReadOnlyDevice]) -> Result<()> { let changes = Changes { devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() }, ..Default::default() @@ -364,7 +411,7 @@ impl Store { #[cfg(test)] /// Testing helper to allo to save only a set of InboundGroupSession - pub async fn save_inbound_group_sessions( + pub(crate) async fn save_inbound_group_sessions( &self, sessions: &[InboundGroupSession], ) -> Result<()> { @@ -374,7 +421,7 @@ impl Store { } /// Get the display name of our own device. - pub async fn device_display_name(&self) -> Result, CryptoStoreError> { + pub(crate) async fn device_display_name(&self) -> Result, CryptoStoreError> { Ok(self .inner .get_device(self.user_id(), self.device_id()) @@ -383,7 +430,7 @@ impl Store { } /// Get the read-only device associated with `device_id` for `user_id` - pub async fn get_readonly_device( + pub(crate) async fn get_readonly_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -394,7 +441,7 @@ impl Store { /// Get the read-only version of all the devices that the given user has. /// /// *Note*: This doesn't return our own device. - pub async fn get_readonly_devices_filtered( + pub(crate) async fn get_readonly_devices_filtered( &self, user_id: &UserId, ) -> Result> { @@ -409,7 +456,7 @@ impl Store { /// Get the read-only version of all the devices that the given user has. 
/// /// *Note*: This does also return our own device. - pub async fn get_readonly_devices_unfiltered( + pub(crate) async fn get_readonly_devices_unfiltered( &self, user_id: &UserId, ) -> Result> { @@ -419,7 +466,7 @@ impl Store { /// Get a device for the given user with the given curve25519 key. /// /// *Note*: This doesn't return our own device. - pub async fn get_device_from_curve_key( + pub(crate) async fn get_device_from_curve_key( &self, user_id: &UserId, curve_key: Curve25519PublicKey, @@ -432,7 +479,7 @@ impl Store { /// Get all devices associated with the given `user_id` /// /// *Note*: This doesn't return our own device. - pub async fn get_user_devices_filtered(&self, user_id: &UserId) -> Result { + pub(crate) async fn get_user_devices_filtered(&self, user_id: &UserId) -> Result { self.get_user_devices(user_id).await.map(|mut d| { if user_id == self.user_id() { d.inner.remove(self.device_id()); @@ -444,7 +491,7 @@ impl Store { /// Get all devices associated with the given `user_id` /// /// *Note*: This does also return our own device. - pub async fn get_user_devices(&self, user_id: &UserId) -> Result { + pub(crate) async fn get_user_devices(&self, user_id: &UserId) -> Result { let devices = self.get_readonly_devices_unfiltered(user_id).await?; let own_identity = @@ -460,7 +507,7 @@ impl Store { } /// Get a Device copy associated with `device_id` for `user_id` - pub async fn get_device( + pub(crate) async fn get_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -478,7 +525,7 @@ impl Store { } /// Get the Identity of `user_id` - pub async fn get_identity(&self, user_id: &UserId) -> Result> { + pub(crate) async fn get_identity(&self, user_id: &UserId) -> Result> { // let own_identity = // self.inner.get_user_identity(self.user_id()).await?.and_then(|i| i.own()); Ok(if let Some(identity) = self.inner.get_user_identity(user_id).await? 
{ @@ -518,7 +565,7 @@ impl Store { /// # Arguments /// /// * `secret_name` - The name of the secret that should be exported. - pub async fn export_secret(&self, secret_name: &SecretName) -> Option { + pub(crate) async fn export_secret(&self, secret_name: &SecretName) -> Option { match secret_name { SecretName::CrossSigningMasterKey | SecretName::CrossSigningUserSigningKey @@ -545,7 +592,7 @@ impl Store { } /// Import the Cross Signing Keys - pub async fn import_cross_signing_keys( + pub(crate) async fn import_cross_signing_keys( &self, export: CrossSigningKeyExport, ) -> Result { @@ -575,7 +622,7 @@ impl Store { } /// Import the given `secret` named `secret_name` into the keystore. - pub async fn import_secret( + pub(crate) async fn import_secret( &self, secret_name: &SecretName, secret: &str, @@ -615,30 +662,89 @@ impl Store { Ok(()) } - /// Mark that the given user has an outdated device list. + /// Mark the given user as being tracked for device lists, and mark that it + /// has an outdated device list. /// /// This means that the user will be considered for a `/keys/query` request /// next time [`Store::users_for_key_query()`] is called. - pub async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> { - self.save_tracked_users(&[(user, true)]).await + pub(crate) async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> { + self.users_for_key_query.lock().await.insert_user(user); + self.tracked_users_cache.insert(user.to_owned()); + + self.inner.save_tracked_users(&[(user, true)]).await } - /// Save the list of users and their outdated/dirty flags to the store. + /// Add entries to the list of users being tracked for device changes /// - /// This method will fill up the store-internal caches, unlike the method on - /// the various [`CryptoStore`] implementations. 
- pub async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<()> { - for &(user, dirty) in users { - if dirty { - self.users_for_key_query_cache.insert(user.to_owned()); - } else { - self.users_for_key_query_cache.remove(user); + /// Any users not already on the list are flagged as awaiting a key query. + /// Users that were already in the list are unaffected. + pub(crate) async fn update_tracked_users( + &self, + users: impl Iterator, + ) -> Result<()> { + self.load_tracked_users().await?; + + let mut store_updates = Vec::new(); + let mut key_query_lock = self.users_for_key_query.lock().await; + + for user_id in users { + if !self.tracked_users_cache.contains(user_id) { + self.tracked_users_cache.insert(user_id.to_owned()); + key_query_lock.insert_user(user_id); + store_updates.push((user_id, true)) } + } + + self.inner.save_tracked_users(&store_updates).await + } + + /// Process notifications that users have changed devices. + /// + /// This is used to handle the list of device-list updates that is received + /// from the `/sync` response. Any users *whose device lists we are + /// tracking* are flagged as needing a key query. Users whose devices we + /// are not tracking are ignored. + pub(crate) async fn mark_tracked_users_as_changed( + &self, + users: impl Iterator, + ) -> Result<()> { + self.load_tracked_users().await?; - self.tracked_users_cache.insert(user.to_owned()); + let mut store_updates: Vec<(&UserId, bool)> = Vec::new(); + let mut key_query_lock = self.users_for_key_query.lock().await; + + for user_id in users { + if self.tracked_users_cache.contains(user_id) { + key_query_lock.insert_user(user_id); + store_updates.push((user_id, true)); + } } - self.inner.save_tracked_users(users).await?; + self.inner.save_tracked_users(&store_updates).await + } + + /// Flag that the given users devices are now up-to-date. + /// + /// This is called after processing the response to a /keys/query request. 
+ /// Any users whose device lists we are tracking are removed from the + /// list of those pending a /keys/query. + pub(crate) async fn mark_tracked_users_as_up_to_date( + &self, + users: impl Iterator, + sequence_number: SequenceNumber, + ) -> Result<()> { + let mut store_updates: Vec<(&UserId, bool)> = Vec::new(); + let mut key_query_lock = self.users_for_key_query.lock().await; + + for user_id in users { + if self.tracked_users_cache.contains(user_id) { + let clean = key_query_lock.maybe_remove_user(user_id, sequence_number); + store_updates.push((user_id, !clean)); + } + } + self.inner.save_tracked_users(&store_updates).await?; + // wake up any tasks that may have been waiting for updates + self.users_for_key_query_condvar.notify_all(); Ok(()) } @@ -660,11 +766,12 @@ impl Store { if !self.tracked_users_loaded.load(Ordering::SeqCst) { let tracked_users = self.inner.load_tracked_users().await?; + let mut query_users_lock = self.users_for_key_query.lock().await; for user in tracked_users { self.tracked_users_cache.insert(user.user_id.to_owned()); if user.dirty { - self.users_for_key_query_cache.insert(user.user_id); + query_users_lock.insert_user(&user.user_id); } } @@ -675,277 +782,117 @@ impl Store { Ok(()) } - /// Are we tracking the list of devices this user has? - pub async fn is_user_tracked(&self, user_id: &UserId) -> Result { - self.load_tracked_users().await?; - - Ok(self.tracked_users_cache.contains(user_id)) - } - - /// Are there any users that have the outdated/dirty flag set for their list - /// of devices? - pub async fn has_users_for_key_query(&self) -> Result { - self.load_tracked_users().await?; - - Ok(!self.users_for_key_query_cache.is_empty()) - } - /// Get the set of users that has the outdate/dirty flag set for their list /// of devices. /// /// This set should be included in a `/keys/query` request which will update /// the device list. 
- pub async fn users_for_key_query(&self) -> Result> { - self.load_tracked_users().await?; - - Ok(self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()) - } - - /// See the docs for [`crate::OlmMachine::tracked_users()`]. - pub async fn tracked_users(&self) -> Result> { - self.load_tracked_users().await?; - - Ok(self.tracked_users_cache.iter().map(|u| u.clone()).collect()) - } -} - -impl Deref for Store { - type Target = dyn CryptoStore; - - fn deref(&self) -> &Self::Target { - self.inner.deref() - } -} - -/// The crypto store's error type. -#[derive(Debug, Error)] -pub enum CryptoStoreError { - /// The account that owns the sessions, group sessions, and devices wasn't - /// found. - #[error("can't save/load sessions or group sessions in the store before an account is stored")] - AccountUnset, - - /// An IO error occurred. - #[error(transparent)] - Io(#[from] IoError), - - /// Failed to decrypt an pickled object. - #[error("An object failed to be decrypted while unpickling")] - UnpicklingError, - - /// Failed to decrypt an pickled object. - #[error(transparent)] - Pickle(#[from] vodozemac::PickleError), - - /// The received room key couldn't be converted into a valid Megolm session. - #[error(transparent)] - SessionCreation(#[from] SessionCreationError), - - /// A Matrix identifier failed to be validated. - #[error(transparent)] - IdentifierValidation(#[from] IdParseError), - - /// The store failed to (de)serialize a data type. - #[error(transparent)] - Serialization(#[from] SerdeError), - - /// The database format has changed in a backwards incompatible way. - #[error( - "The database format changed in an incompatible way, current \ - version: {0}, latest version: {1}" - )] - UnsupportedDatabaseVersion(usize, usize), - - /// A problem with the underlying database backend - #[error(transparent)] - Backend(Box), -} - -impl CryptoStoreError { - /// Create a new [`Backend`][Self::Backend] error. 
/// - /// Shorthand for `StoreError::Backend(Box::new(error))`. - #[inline] - pub fn backend(error: E) -> Self - where - E: std::error::Error + Send + Sync + 'static, - { - Self::Backend(Box::new(error)) - } -} - -/// Represents a store that the `OlmMachine` uses to store E2EE data (such as -/// cryptographic keys). -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] -#[cfg_attr(not(target_arch = "wasm32"), async_trait)] -pub trait CryptoStore: AsyncTraitDeps { - /// Load an account that was previously stored. - async fn load_account(&self) -> Result>; - - /// Save the given account in the store. + /// # Returns /// - /// # Arguments - /// - /// * `account` - The account that should be stored. - async fn save_account(&self, account: ReadOnlyAccount) -> Result<()>; - - /// Try to load a private cross signing identity, if one is stored. - async fn load_identity(&self) -> Result>; - - /// Save the set of changes to the store. - /// - /// # Arguments - /// - /// * `changes` - The set of changes that should be stored. - async fn save_changes(&self, changes: Changes) -> Result<()>; - - /// Get all the sessions that belong to the given sender key. - /// - /// # Arguments - /// - /// * `sender_key` - The sender key that was used to establish the sessions. - async fn get_sessions(&self, sender_key: &str) -> Result>>>>; - - /// Get the inbound group session from our store. - /// - /// # Arguments - /// * `room_id` - The room id of the room that the session belongs to. - /// - /// * `sender_key` - The sender key that sent us the session. - /// - /// * `session_id` - The unique id of the session. - async fn get_inbound_group_session( + /// A pair `(users, sequence_number)`, where `users` is the list of users to + /// be queried, and `sequence_number` is the current sequence number, + /// which should be returned in `mark_tracked_users_as_up_to_date`. 
+ pub(crate) async fn users_for_key_query( &self, - room_id: &RoomId, - session_id: &str, - ) -> Result>; - - /// Get all the inbound group sessions we have stored. - async fn get_inbound_group_sessions(&self) -> Result>; + ) -> Result<(HashSet, SequenceNumber)> { + self.load_tracked_users().await?; - /// Get the number inbound group sessions we have and how many of them are - /// backed up. - async fn inbound_group_session_counts(&self) -> Result; + Ok(self.users_for_key_query.lock().await.users_for_key_query()) + } - /// Get all the inbound group sessions we have not backed up yet. - async fn inbound_group_sessions_for_backup( + /// Wait for a `/keys/query` response to be received if one is expected for + /// the given user. + /// + /// If the given timeout elapses, the method will stop waiting and return + /// `UserKeyQueryResult::TimeoutExpired` + pub(crate) async fn wait_if_user_key_query_pending( &self, - limit: usize, - ) -> Result>; - - /// Reset the backup state of all the stored inbound group sessions. - async fn reset_backup_state(&self) -> Result<()>; - - /// Get the backup keys we have stored. - async fn load_backup_keys(&self) -> Result; + timeout_duration: Duration, + user: &UserId, + ) -> UserKeyQueryResult { + let mut g = self.users_for_key_query.lock().await; - /// Get the outbound group session we have stored that is used for the - /// given room. - async fn get_outbound_group_session( - &self, - room_id: &RoomId, - ) -> Result>; + let Some(w) = g.maybe_register_waiting_task(user) else { + return UserKeyQueryResult::WasNotPending; + }; - /// Load the list of users whose devices we are keeping track of. - async fn load_tracked_users(&self) -> Result>; + let f1 = async { + while !w.completed.load(Ordering::Relaxed) { + g = self.users_for_key_query_condvar.wait(g).await; + } + }; - /// Save a list of users and their respective dirty/outdated flags to the - /// store. 
- async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<()>; + match timeout(Box::pin(f1), timeout_duration).await { + Err(_) => { + warn!( + user_id = ?user, + "The user has a pending `/key/query` request which did \ + not finish yet, some devices might be missing." + ); - /// Get the device for the given user with the given device ID. - /// - /// # Arguments - /// - /// * `user_id` - The user that the device belongs to. - /// - /// * `device_id` - The unique id of the device. - async fn get_device( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result>; - - /// Get all the devices of the given user. - /// - /// # Arguments - /// - /// * `user_id` - The user for which we should get all the devices. - async fn get_user_devices( - &self, - user_id: &UserId, - ) -> Result>; + UserKeyQueryResult::TimeoutExpired + } + _ => UserKeyQueryResult::WasPending, + } + } - /// Get the user identity that is attached to the given user id. - /// - /// # Arguments - /// - /// * `user_id` - The user for which we should get the identity. - async fn get_user_identity(&self, user_id: &UserId) -> Result>; + /// See the docs for [`crate::OlmMachine::tracked_users()`]. + pub(crate) async fn tracked_users(&self) -> Result> { + self.load_tracked_users().await?; - /// Check if a hash for an Olm message stored in the database. - async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result; + Ok(self.tracked_users_cache.iter().map(|u| u.clone()).collect()) + } - /// Get an outgoing secret request that we created that matches the given - /// request id. - /// - /// # Arguments - /// - /// * `request_id` - The unique request id that identifies this outgoing - /// secret request. - async fn get_outgoing_secret_requests( - &self, - request_id: &TransactionId, - ) -> Result>; + /// Check whether there is a global flag to only encrypt messages for + /// trusted devices or for everyone. 
+ pub async fn get_only_allow_trusted_devices(&self) -> Result { + let value = self.get_value("only_allow_trusted_devices").await?.unwrap_or_default(); + Ok(value) + } - /// Get an outgoing key request that we created that matches the given - /// requested key info. - /// - /// # Arguments - /// - /// * `key_info` - The key info of an outgoing secret request. - async fn get_secret_request_by_info( + /// Set global flag whether to encrypt messages for untrusted devices, or + /// whether they should be excluded from the conversation. + pub async fn set_only_allow_trusted_devices( &self, - secret_info: &SecretInfo, - ) -> Result>; + block_untrusted_devices: bool, + ) -> Result<()> { + self.set_value("only_allow_trusted_devices", &block_untrusted_devices).await + } - /// Get all outgoing secret requests that we have in the store. - async fn get_unsent_secret_requests(&self) -> Result>; + /// Get custom stored value associated with a key + pub async fn get_value(&self, key: &str) -> Result> { + let Some(value) = self.get_custom_value(key).await? else { + return Ok(None); + }; + let deserialized = self.deserialize_value(&value)?; + Ok(Some(deserialized)) + } - /// Delete an outgoing key request that we created that matches the given - /// request id. - /// - /// # Arguments - /// - /// * `request_id` - The unique request id that identifies this outgoing key - /// request. - async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()>; -} + /// Store custom value associated with a key + pub async fn set_value(&self, key: &str, value: &impl Serialize) -> Result<()> { + let serialized = self.serialize_value(value)?; + self.set_custom_value(key, serialized).await?; + Ok(()) + } -/// A type that can be type-erased into `Arc`. -/// -/// This trait is not meant to be implemented directly outside -/// `matrix-sdk-crypto`, but it is automatically implemented for everything that -/// implements `CryptoStore`. 
-pub trait IntoCryptoStore { - #[doc(hidden)] - fn into_crypto_store(self) -> Arc; -} + fn serialize_value(&self, value: &impl Serialize) -> Result> { + let serialized = + rmp_serde::to_vec_named(value).map_err(|x| CryptoStoreError::Backend(x.into()))?; + Ok(serialized) + } -impl IntoCryptoStore for T -where - T: CryptoStore + Sized + 'static, -{ - fn into_crypto_store(self) -> Arc { - Arc::new(self) + fn deserialize_value(&self, value: &[u8]) -> Result { + let deserialized = + rmp_serde::from_slice(value).map_err(|e| CryptoStoreError::Backend(e.into()))?; + Ok(deserialized) } } -impl IntoCryptoStore for Arc -where - T: CryptoStore + 'static, -{ - fn into_crypto_store(self) -> Arc { - self +impl Deref for Store { + type Target = DynCryptoStore; + + fn deref(&self) -> &Self::Target { + self.inner.deref() } } diff --git a/crates/matrix-sdk-crypto/src/store/traits.rs b/crates/matrix-sdk-crypto/src/store/traits.rs new file mode 100644 index 00000000000..43d7021b2da --- /dev/null +++ b/crates/matrix-sdk-crypto/src/store/traits.rs @@ -0,0 +1,397 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::{collections::HashMap, fmt, sync::Arc}; + +use async_trait::async_trait; +use matrix_sdk_common::{locks::Mutex, AsyncTraitDeps}; +use ruma::{DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId}; + +use super::{BackupKeys, Changes, CryptoStoreError, Result, RoomKeyCounts, RoomSettings}; +use crate::{ + olm::{ + InboundGroupSession, OlmMessageHash, OutboundGroupSession, PrivateCrossSigningIdentity, + Session, + }, + GossipRequest, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo, + TrackedUser, +}; + +/// Represents a store that the `OlmMachine` uses to store E2EE data (such as +/// cryptographic keys). +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +pub trait CryptoStore: AsyncTraitDeps { + /// The error type used by this crypto store. + type Error: fmt::Debug + Into; + + /// Load an account that was previously stored. + async fn load_account(&self) -> Result, Self::Error>; + + /// Save the given account in the store. + /// + /// # Arguments + /// + /// * `account` - The account that should be stored. + async fn save_account(&self, account: ReadOnlyAccount) -> Result<(), Self::Error>; + + /// Try to load a private cross signing identity, if one is stored. + async fn load_identity(&self) -> Result, Self::Error>; + + /// Save the set of changes to the store. + /// + /// # Arguments + /// + /// * `changes` - The set of changes that should be stored. + async fn save_changes(&self, changes: Changes) -> Result<(), Self::Error>; + + /// Get all the sessions that belong to the given sender key. + /// + /// # Arguments + /// + /// * `sender_key` - The sender key that was used to establish the sessions. + async fn get_sessions( + &self, + sender_key: &str, + ) -> Result>>>, Self::Error>; + + /// Get the inbound group session from our store. + /// + /// # Arguments + /// * `room_id` - The room id of the room that the session belongs to. 
+ /// + /// * `sender_key` - The sender key that sent us the session. + /// + /// * `session_id` - The unique id of the session. + async fn get_inbound_group_session( + &self, + room_id: &RoomId, + session_id: &str, + ) -> Result, Self::Error>; + + /// Get all the inbound group sessions we have stored. + async fn get_inbound_group_sessions(&self) -> Result, Self::Error>; + + /// Get the number inbound group sessions we have and how many of them are + /// backed up. + async fn inbound_group_session_counts(&self) -> Result; + + /// Get all the inbound group sessions we have not backed up yet. + async fn inbound_group_sessions_for_backup( + &self, + limit: usize, + ) -> Result, Self::Error>; + + /// Reset the backup state of all the stored inbound group sessions. + async fn reset_backup_state(&self) -> Result<(), Self::Error>; + + /// Get the backup keys we have stored. + async fn load_backup_keys(&self) -> Result; + + /// Get the outbound group session we have stored that is used for the + /// given room. + async fn get_outbound_group_session( + &self, + room_id: &RoomId, + ) -> Result, Self::Error>; + + /// Load the list of users whose devices we are keeping track of. + async fn load_tracked_users(&self) -> Result, Self::Error>; + + /// Save a list of users and their respective dirty/outdated flags to the + /// store. + async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), Self::Error>; + + /// Get the device for the given user with the given device ID. + /// + /// # Arguments + /// + /// * `user_id` - The user that the device belongs to. + /// + /// * `device_id` - The unique id of the device. + async fn get_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result, Self::Error>; + + /// Get all the devices of the given user. + /// + /// # Arguments + /// + /// * `user_id` - The user for which we should get all the devices. 
+ async fn get_user_devices( + &self, + user_id: &UserId, + ) -> Result, Self::Error>; + + /// Get the user identity that is attached to the given user id. + /// + /// # Arguments + /// + /// * `user_id` - The user for which we should get the identity. + async fn get_user_identity( + &self, + user_id: &UserId, + ) -> Result, Self::Error>; + + /// Check if a hash for an Olm message stored in the database. + async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result; + + /// Get an outgoing secret request that we created that matches the given + /// request id. + /// + /// # Arguments + /// + /// * `request_id` - The unique request id that identifies this outgoing + /// secret request. + async fn get_outgoing_secret_requests( + &self, + request_id: &TransactionId, + ) -> Result, Self::Error>; + + /// Get an outgoing key request that we created that matches the given + /// requested key info. + /// + /// # Arguments + /// + /// * `key_info` - The key info of an outgoing secret request. + async fn get_secret_request_by_info( + &self, + secret_info: &SecretInfo, + ) -> Result, Self::Error>; + + /// Get all outgoing secret requests that we have in the store. + async fn get_unsent_secret_requests(&self) -> Result, Self::Error>; + + /// Delete an outgoing key request that we created that matches the given + /// request id. + /// + /// # Arguments + /// + /// * `request_id` - The unique request id that identifies this outgoing key + /// request. + async fn delete_outgoing_secret_requests( + &self, + request_id: &TransactionId, + ) -> Result<(), Self::Error>; + + /// Get the room settings, such as the encryption algorithm or whether to + /// encrypt only for trusted devices. 
+ /// + /// # Arguments + /// + /// * `room_id` - The room id of the room + async fn get_room_settings( + &self, + room_id: &RoomId, + ) -> Result, Self::Error>; + + /// Get arbitrary data from the store + /// + /// # Arguments + /// + /// * `key` - The key to fetch data for + async fn get_custom_value(&self, key: &str) -> Result>, Self::Error>; + + /// Put arbitrary data into the store + /// + /// # Arguments + /// + /// * `key` - The key to insert data into + /// + /// * `value` - The value to insert + async fn set_custom_value(&self, key: &str, value: Vec) -> Result<(), Self::Error>; +} + +#[repr(transparent)] +struct EraseCryptoStoreError(T); + +impl fmt::Debug for EraseCryptoStoreError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait)] +impl CryptoStore for EraseCryptoStoreError { + type Error = CryptoStoreError; + + async fn load_account(&self) -> Result> { + self.0.load_account().await.map_err(Into::into) + } + + async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { + self.0.save_account(account).await.map_err(Into::into) + } + + async fn load_identity(&self) -> Result> { + self.0.load_identity().await.map_err(Into::into) + } + + async fn save_changes(&self, changes: Changes) -> Result<()> { + self.0.save_changes(changes).await.map_err(Into::into) + } + + async fn get_sessions(&self, sender_key: &str) -> Result>>>> { + self.0.get_sessions(sender_key).await.map_err(Into::into) + } + + async fn get_inbound_group_session( + &self, + room_id: &RoomId, + session_id: &str, + ) -> Result> { + self.0.get_inbound_group_session(room_id, session_id).await.map_err(Into::into) + } + + async fn get_inbound_group_sessions(&self) -> Result> { + self.0.get_inbound_group_sessions().await.map_err(Into::into) + } + + async fn inbound_group_session_counts(&self) -> Result { + 
self.0.inbound_group_session_counts().await.map_err(Into::into) + } + + async fn inbound_group_sessions_for_backup( + &self, + limit: usize, + ) -> Result> { + self.0.inbound_group_sessions_for_backup(limit).await.map_err(Into::into) + } + + async fn reset_backup_state(&self) -> Result<()> { + self.0.reset_backup_state().await.map_err(Into::into) + } + + async fn load_backup_keys(&self) -> Result { + self.0.load_backup_keys().await.map_err(Into::into) + } + + async fn get_outbound_group_session( + &self, + room_id: &RoomId, + ) -> Result> { + self.0.get_outbound_group_session(room_id).await.map_err(Into::into) + } + + async fn load_tracked_users(&self) -> Result> { + self.0.load_tracked_users().await.map_err(Into::into) + } + + async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<()> { + self.0.save_tracked_users(users).await.map_err(Into::into) + } + + async fn get_device( + &self, + user_id: &UserId, + device_id: &DeviceId, + ) -> Result> { + self.0.get_device(user_id, device_id).await.map_err(Into::into) + } + + async fn get_user_devices( + &self, + user_id: &UserId, + ) -> Result> { + self.0.get_user_devices(user_id).await.map_err(Into::into) + } + + async fn get_user_identity(&self, user_id: &UserId) -> Result> { + self.0.get_user_identity(user_id).await.map_err(Into::into) + } + + async fn is_message_known(&self, message_hash: &OlmMessageHash) -> Result { + self.0.is_message_known(message_hash).await.map_err(Into::into) + } + + async fn get_outgoing_secret_requests( + &self, + request_id: &TransactionId, + ) -> Result> { + self.0.get_outgoing_secret_requests(request_id).await.map_err(Into::into) + } + + async fn get_secret_request_by_info( + &self, + secret_info: &SecretInfo, + ) -> Result> { + self.0.get_secret_request_by_info(secret_info).await.map_err(Into::into) + } + + async fn get_unsent_secret_requests(&self) -> Result> { + self.0.get_unsent_secret_requests().await.map_err(Into::into) + } + + async fn 
delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> { + self.0.delete_outgoing_secret_requests(request_id).await.map_err(Into::into) + } + + async fn get_room_settings(&self, room_id: &RoomId) -> Result> { + self.0.get_room_settings(room_id).await.map_err(Into::into) + } + + async fn get_custom_value(&self, key: &str) -> Result>, Self::Error> { + self.0.get_custom_value(key).await.map_err(Into::into) + } + + async fn set_custom_value(&self, key: &str, value: Vec) -> Result<(), Self::Error> { + self.0.set_custom_value(key, value).await.map_err(Into::into) + } +} + +/// A type-erased [`CryptoStore`]. +pub type DynCryptoStore = dyn CryptoStore; + +/// A type that can be type-erased into `Arc`. +/// +/// This trait is not meant to be implemented directly outside +/// `matrix-sdk-crypto`, but it is automatically implemented for everything that +/// implements `CryptoStore`. +pub trait IntoCryptoStore { + #[doc(hidden)] + fn into_crypto_store(self) -> Arc; +} + +impl IntoCryptoStore for T +where + T: CryptoStore + 'static, +{ + fn into_crypto_store(self) -> Arc { + Arc::new(EraseCryptoStoreError(self)) + } +} + +// Turns a given `Arc` into `Arc` by attaching the +// CryptoStore impl vtable of `EraseCryptoStoreError`. 
+impl IntoCryptoStore for Arc +where + T: CryptoStore + 'static, +{ + fn into_crypto_store(self) -> Arc { + let ptr: *const T = Arc::into_raw(self); + let ptr_erased = ptr as *const EraseCryptoStoreError; + // SAFETY: EraseCryptoStoreError is repr(transparent) so T and + // EraseCryptoStoreError have the same layout and ABI + unsafe { Arc::from_raw(ptr_erased) } + } +} + +impl IntoCryptoStore for Arc { + fn into_crypto_store(self) -> Arc { + self + } +} diff --git a/crates/matrix-sdk-crypto/src/types/cross_signing_key.rs b/crates/matrix-sdk-crypto/src/types/cross_signing/common.rs similarity index 78% rename from crates/matrix-sdk-crypto/src/types/cross_signing_key.rs rename to crates/matrix-sdk-crypto/src/types/cross_signing/common.rs index e1014881a99..1036cc136ce 100644 --- a/crates/matrix-sdk-crypto/src/types/cross_signing_key.rs +++ b/crates/matrix-sdk-crypto/src/types/cross_signing/common.rs @@ -23,13 +23,14 @@ use std::collections::BTreeMap; use ruma::{ encryption::KeyUsage, serde::Raw, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceKeyId, - OwnedUserId, + OwnedUserId, UserId, }; use serde::{Deserialize, Serialize}; use serde_json::{value::to_raw_value, Value}; use vodozemac::{Ed25519PublicKey, KeyError}; -use super::{Signatures, SigningKeys}; +use super::{SelfSigningPubkey, UserSigningPubkey}; +use crate::types::{Signatures, SigningKeys}; /// A cross signing key. 
#[derive(Clone, Debug, Deserialize, Serialize)] @@ -132,38 +133,40 @@ impl From for SigningKey { } } -#[cfg(test)] -mod tests { - use ruma::user_id; - use serde_json::json; - - use super::CrossSigningKey; - - #[test] - fn serialization() { - let json = json!({ - "user_id": "@example:localhost", - "usage": [ - "master" - ], - "keys": { - "ed25519:rJ2TAGkEOP6dX41Ksll6cl8K3J48l8s/59zaXyvl2p0": "rJ2TAGkEOP6dX41Ksll6cl8K3J48l8s/59zaXyvl2p0" - }, - "signatures": { - "@example:localhost": { - "ed25519:WSKKLTJZCL": "ZzJp1wtmRdykXAUEItEjNiFlBrxx8L6/Vaen9am8AuGwlxxJtOkuY4m+4MPLvDPOgavKHLsrRuNLAfCeakMlCQ" - } - }, - "other_data": "other" - }); - - let key: CrossSigningKey = - serde_json::from_value(json.clone()).expect("Can't deserialize cross signing key"); - - assert_eq!(key.user_id, user_id!("@example:localhost")); - - let serialized = serde_json::to_value(key).expect("Can't reserialize cross signing key"); - - assert_eq!(json, serialized); +/// Enum over the cross signing sub-keys. +pub(crate) enum CrossSigningSubKeys<'a> { + /// The self signing subkey. + SelfSigning(&'a SelfSigningPubkey), + /// The user signing subkey. + UserSigning(&'a UserSigningPubkey), +} + +impl<'a> CrossSigningSubKeys<'a> { + /// Get the id of the user that owns this cross signing subkey. 
+ pub fn user_id(&self) -> &UserId { + match self { + CrossSigningSubKeys::SelfSigning(key) => key.user_id(), + CrossSigningSubKeys::UserSigning(key) => key.user_id(), + } + } + + /// Get the `CrossSigningKey` from an sub-keys enum + pub fn cross_signing_key(&self) -> &CrossSigningKey { + match self { + CrossSigningSubKeys::SelfSigning(key) => key.as_ref(), + CrossSigningSubKeys::UserSigning(key) => key.as_ref(), + } + } +} + +impl<'a> From<&'a UserSigningPubkey> for CrossSigningSubKeys<'a> { + fn from(key: &'a UserSigningPubkey) -> Self { + CrossSigningSubKeys::UserSigning(key) + } +} + +impl<'a> From<&'a SelfSigningPubkey> for CrossSigningSubKeys<'a> { + fn from(key: &'a SelfSigningPubkey) -> Self { + CrossSigningSubKeys::SelfSigning(key) } } diff --git a/crates/matrix-sdk-crypto/src/types/cross_signing/master.rs b/crates/matrix-sdk-crypto/src/types/cross_signing/master.rs new file mode 100644 index 00000000000..814487e5607 --- /dev/null +++ b/crates/matrix-sdk-crypto/src/types/cross_signing/master.rs @@ -0,0 +1,150 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::{collections::btree_map::Iter, sync::Arc}; + +use ruma::{encryption::KeyUsage, DeviceKeyId, OwnedDeviceKeyId, UserId}; +use serde::{Deserialize, Serialize}; +use vodozemac::Ed25519PublicKey; + +use super::{CrossSigningKey, CrossSigningSubKeys, SigningKey}; +use crate::{ + olm::VerifyJson, + types::{Signatures, SigningKeys}, + SignatureError, +}; + +/// Wrapper for a cross signing key marking it as the master key. +/// +/// Master keys are used to sign other cross signing keys, the self signing and +/// user signing keys of an user will be signed by their master key. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(try_from = "CrossSigningKey")] +pub struct MasterPubkey(pub(super) Arc); + +impl MasterPubkey { + /// Get the user id of the master key's owner. + pub fn user_id(&self) -> &UserId { + &self.0.user_id + } + + /// Get the keys map of containing the master keys. + pub fn keys(&self) -> &SigningKeys { + &self.0.keys + } + + /// Get the list of `KeyUsage` that is set for this key. + pub fn usage(&self) -> &[KeyUsage] { + &self.0.usage + } + + /// Get the signatures map of this cross signing key. + pub fn signatures(&self) -> &Signatures { + &self.0.signatures + } + + /// Get the master key with the given key id. + /// + /// # Arguments + /// + /// * `key_id` - The id of the key that should be fetched. + pub fn get_key(&self, key_id: &DeviceKeyId) -> Option<&SigningKey> { + self.0.keys.get(key_id) + } + + /// Get the first available master key. + /// + /// There's usually only a single master key so this will usually fetch the + /// only key. + pub fn get_first_key(&self) -> Option { + self.0.get_first_key_and_id().map(|(_, k)| k) + } + + /// Check if the given JSON is signed by this master key. + /// + /// This method should only be used if an object's signature needs to be + /// checked multiple times, and you'd like to avoid performing the + /// canonicalization step each time. 
+ /// + /// **Note**: Use this method with caution, the `canonical_json` needs to be + /// correctly canonicalized and make sure that the object you are checking + /// the signature for is allowed to be signed by a master key. + #[cfg(any(feature = "backups_v1", test))] + pub(crate) fn has_signed_raw( + &self, + signatures: &Signatures, + canonical_json: &str, + ) -> Result<(), SignatureError> { + if let Some((key_id, key)) = self.0.get_first_key_and_id() { + key.verify_canonicalized_json(&self.0.user_id, key_id, signatures, canonical_json) + } else { + Err(SignatureError::UnsupportedAlgorithm) + } + } + + /// Check if the given cross signing sub-key is signed by the master key. + /// + /// # Arguments + /// + /// * `subkey` - The subkey that should be checked for a valid signature. + /// + /// Returns an empty result if the signature check succeeded, otherwise a + /// SignatureError indicating why the check failed. + pub(crate) fn verify_subkey<'a>( + &self, + subkey: impl Into>, + ) -> Result<(), SignatureError> { + let subkey: CrossSigningSubKeys<'_> = subkey.into(); + + if self.0.user_id != subkey.user_id() { + return Err(SignatureError::UserIdMismatch); + } + + if let Some((key_id, key)) = self.0.get_first_key_and_id() { + key.verify_json(&self.0.user_id, key_id, subkey.cross_signing_key()) + } else { + Err(SignatureError::UnsupportedAlgorithm) + } + } +} + +impl<'a> IntoIterator for &'a MasterPubkey { + type Item = (&'a OwnedDeviceKeyId, &'a SigningKey); + type IntoIter = Iter<'a, OwnedDeviceKeyId, SigningKey>; + + fn into_iter(self) -> Self::IntoIter { + self.keys().iter() + } +} + +impl AsRef for MasterPubkey { + fn as_ref(&self) -> &CrossSigningKey { + &self.0 + } +} + +impl TryFrom for MasterPubkey { + type Error = serde_json::Error; + + fn try_from(key: CrossSigningKey) -> Result { + if key.usage.contains(&KeyUsage::Master) && key.usage.len() == 1 { + Ok(Self(key.into())) + } else { + Err(serde::de::Error::custom(format!( + "Expected cross signing key 
usage {} was not found", + KeyUsage::Master + ))) + } + } +} diff --git a/crates/matrix-sdk-crypto/src/types/cross_signing/mod.rs b/crates/matrix-sdk-crypto/src/types/cross_signing/mod.rs new file mode 100644 index 00000000000..49e308fe11f --- /dev/null +++ b/crates/matrix-sdk-crypto/src/types/cross_signing/mod.rs @@ -0,0 +1,153 @@ +// Copyright 2021 Devin Ragotzy. +// Copyright 2021 Timo KΓΆsters. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +mod common; +mod master; +mod self_signing; +mod user_signing; + +pub use common::*; +pub use master::*; +pub use self_signing::*; +pub use user_signing::*; + +macro_rules! impl_partial_eq { + ($key_type: ty) => { + impl PartialEq for $key_type { + /// The `PartialEq` implementation compares the user ID, the usage and the + /// key material, ignoring signatures. 
+ /// + /// The usage could be safely ignored since the type guarantees it has the + /// correct usage by construction -- it is impossible to construct a + /// value of a particular key type with an incorrect usage. However, we + /// check it anyway, to codify the notion that the same key material + /// with a different usage results in a logically different key. + /// + /// The signatures are provided by other devices and don't alter the + /// identity of the key itself. + fn eq(&self, other: &Self) -> bool { + self.user_id() == other.user_id() + && self.keys() == other.keys() + && self.usage() == other.usage() + } + } + impl Eq for $key_type {} + }; +} + +impl_partial_eq!(MasterPubkey); +impl_partial_eq!(SelfSigningPubkey); +impl_partial_eq!(UserSigningPubkey); + +#[cfg(test)] +mod tests { + use matrix_sdk_test::async_test; + use ruma::{encryption::KeyUsage, user_id, DeviceKeyId}; + use serde_json::json; + use vodozemac::Ed25519Signature; + + use crate::{ + identities::{ + manager::testing::{own_key_query, own_key_query_with_user_id}, + user::testing::get_other_own_identity, + }, + types::{CrossSigningKey, MasterPubkey, SelfSigningPubkey, UserSigningPubkey}, + }; + + #[test] + fn serialization() { + let json = json!({ + "user_id": "@example:localhost", + "usage": [ + "master" + ], + "keys": { + "ed25519:rJ2TAGkEOP6dX41Ksll6cl8K3J48l8s/59zaXyvl2p0": "rJ2TAGkEOP6dX41Ksll6cl8K3J48l8s/59zaXyvl2p0" + }, + "signatures": { + "@example:localhost": { + "ed25519:WSKKLTJZCL": "ZzJp1wtmRdykXAUEItEjNiFlBrxx8L6/Vaen9am8AuGwlxxJtOkuY4m+4MPLvDPOgavKHLsrRuNLAfCeakMlCQ" + } + }, + "other_data": "other" + }); + + let key: CrossSigningKey = + serde_json::from_value(json.clone()).expect("Can't deserialize cross signing key"); + + assert_eq!(key.user_id, user_id!("@example:localhost")); + + let serialized = serde_json::to_value(key).expect("Can't reserialize cross signing key"); + + assert_eq!(json, serialized); + } + + #[async_test] + async fn partial_eq_cross_signing_keys() { + 
macro_rules! test_partial_eq { + ($key_type:ident, $key_field:ident, $field:ident, $usage:expr) => { + let user_id = user_id!("@example:localhost"); + let response = own_key_query(); + let raw = response.$field.get(user_id).unwrap(); + let key: $key_type = raw.deserialize_as().unwrap(); + + // A different key is naturally not the same as our key. + let other_identity = get_other_own_identity().await; + let other_key = other_identity.$key_field(); + assert_ne!(&key, other_key); + + // However, not even our own key material with another user ID is the same. + let other_user_id = user_id!("@example2:localhost"); + let other_response = own_key_query_with_user_id(&other_user_id); + let other_raw = other_response.$field.get(other_user_id).unwrap(); + let other_key: $key_type = other_raw.deserialize_as().unwrap(); + assert_ne!(key, other_key); + + // Now let's add another signature to our key. + let signature = Ed25519Signature::from_base64( + "mia28GKixFzOWKJ0h7Bdrdy2fjxiHCsst1qpe467FbW85H61UlshtKBoAXfTLlVfi0FX+/noJ8B3noQPnY+9Cg" + ).expect("The signature can always be decoded"); + let mut other_key: CrossSigningKey = raw.deserialize_as().unwrap(); + other_key.signatures.add_signature( + user_id.to_owned(), + DeviceKeyId::from_parts(ruma::DeviceKeyAlgorithm::Ed25519, "DEVICEID".into()), + signature, + ); + let other_key = other_key.try_into().unwrap(); + + // Additional signatures are fine, adding more does not change the key's identity. + assert_eq!(key, other_key); + + // However changing the usage results in a different key. + let mut other_key: CrossSigningKey = raw.deserialize_as().unwrap(); + other_key.usage.push($usage); + let other_key = $key_type { 0: other_key.into() }; + assert_ne!(key, other_key); + }; + } + + // The last argument is deliberately some usage which is *not* correct for the + // type. 
+ test_partial_eq!(MasterPubkey, master_key, master_keys, KeyUsage::SelfSigning); + test_partial_eq!(SelfSigningPubkey, self_signing_key, self_signing_keys, KeyUsage::Master); + test_partial_eq!(UserSigningPubkey, user_signing_key, user_signing_keys, KeyUsage::Master); + } +} diff --git a/crates/matrix-sdk-crypto/src/types/cross_signing/self_signing.rs b/crates/matrix-sdk-crypto/src/types/cross_signing/self_signing.rs new file mode 100644 index 00000000000..ec9cb1ff530 --- /dev/null +++ b/crates/matrix-sdk-crypto/src/types/cross_signing/self_signing.rs @@ -0,0 +1,87 @@ +use std::{collections::btree_map::Iter, sync::Arc}; + +use ruma::{encryption::KeyUsage, OwnedDeviceKeyId, UserId}; +use serde::{Deserialize, Serialize}; + +use super::{CrossSigningKey, SigningKey}; +use crate::{ + olm::VerifyJson, + types::{DeviceKeys, SigningKeys}, + ReadOnlyDevice, SignatureError, +}; + +/// Wrapper for a cross signing key marking it as a self signing key. +/// +/// Self signing keys are used to sign the user's own devices. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(try_from = "CrossSigningKey")] +pub struct SelfSigningPubkey(pub(super) Arc); + +impl SelfSigningPubkey { + /// Get the user id of the self signing key's owner. + pub fn user_id(&self) -> &UserId { + &self.0.user_id + } + + /// Get the keys map of containing the self signing keys. + pub fn keys(&self) -> &SigningKeys { + &self.0.keys + } + + /// Get the list of `KeyUsage` that is set for this key. + pub fn usage(&self) -> &[KeyUsage] { + &self.0.usage + } + + /// Verify that the [`DeviceKeys`] have a valid signature from this + /// self-signing key. + pub fn verify_device_keys(&self, device_keys: &DeviceKeys) -> Result<(), SignatureError> { + if let Some((key_id, key)) = self.0.get_first_key_and_id() { + key.verify_json(&self.0.user_id, key_id, device_keys) + } else { + Err(SignatureError::UnsupportedAlgorithm) + } + } + + /// Check if the given device is signed by this self signing key. 
+ /// + /// # Arguments + /// + /// * `device` - The device that should be checked for a valid signature. + /// + /// Returns an empty result if the signature check succeeded, otherwise a + /// SignatureError indicating why the check failed. + pub(crate) fn verify_device(&self, device: &ReadOnlyDevice) -> Result<(), SignatureError> { + self.verify_device_keys(device.as_device_keys()) + } +} + +impl<'a> IntoIterator for &'a SelfSigningPubkey { + type Item = (&'a OwnedDeviceKeyId, &'a SigningKey); + type IntoIter = Iter<'a, OwnedDeviceKeyId, SigningKey>; + + fn into_iter(self) -> Self::IntoIter { + self.keys().iter() + } +} + +impl TryFrom for SelfSigningPubkey { + type Error = serde_json::Error; + + fn try_from(key: CrossSigningKey) -> Result { + if key.usage.contains(&KeyUsage::SelfSigning) && key.usage.len() == 1 { + Ok(Self(key.into())) + } else { + Err(serde::de::Error::custom(format!( + "Expected cross signing key usage {} was not found", + KeyUsage::SelfSigning + ))) + } + } +} + +impl AsRef for SelfSigningPubkey { + fn as_ref(&self) -> &CrossSigningKey { + &self.0 + } +} diff --git a/crates/matrix-sdk-crypto/src/types/cross_signing/user_signing.rs b/crates/matrix-sdk-crypto/src/types/cross_signing/user_signing.rs new file mode 100644 index 00000000000..a97ead111e5 --- /dev/null +++ b/crates/matrix-sdk-crypto/src/types/cross_signing/user_signing.rs @@ -0,0 +1,80 @@ +use std::{collections::btree_map::Iter, sync::Arc}; + +use ruma::{encryption::KeyUsage, OwnedDeviceKeyId, UserId}; +use serde::{Deserialize, Serialize}; + +use super::{CrossSigningKey, MasterPubkey, SigningKey}; +use crate::{olm::VerifyJson, types::SigningKeys, SignatureError}; + +/// Wrapper for a cross signing key marking it as a user signing key. +/// +/// User signing keys are used to sign the master keys of other users. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(try_from = "CrossSigningKey")] +pub struct UserSigningPubkey(pub(super) Arc); + +impl UserSigningPubkey { + /// Get the user id of the user signing key's owner. + pub fn user_id(&self) -> &UserId { + &self.0.user_id + } + + /// Get the list of `KeyUsage` that is set for this key. + pub fn usage(&self) -> &[KeyUsage] { + &self.0.usage + } + + /// Get the keys map of containing the user signing keys. + pub fn keys(&self) -> &SigningKeys { + &self.0.keys + } + + /// Check if the given master key is signed by this user signing key. + /// + /// # Arguments + /// + /// * `master_key` - The master key that should be checked for a valid + /// signature. + /// + /// Returns an empty result if the signature check succeeded, otherwise a + /// SignatureError indicating why the check failed. + pub(crate) fn verify_master_key( + &self, + master_key: &MasterPubkey, + ) -> Result<(), SignatureError> { + if let Some((key_id, key)) = self.0.get_first_key_and_id() { + key.verify_json(&self.0.user_id, key_id, master_key.as_ref()) + } else { + Err(SignatureError::UnsupportedAlgorithm) + } + } +} + +impl<'a> IntoIterator for &'a UserSigningPubkey { + type Item = (&'a OwnedDeviceKeyId, &'a SigningKey); + type IntoIter = Iter<'a, OwnedDeviceKeyId, SigningKey>; + + fn into_iter(self) -> Self::IntoIter { + self.keys().iter() + } +} + +impl TryFrom for UserSigningPubkey { + type Error = serde_json::Error; + + fn try_from(key: CrossSigningKey) -> Result { + if key.usage.contains(&KeyUsage::UserSigning) && key.usage.len() == 1 { + Ok(Self(key.into())) + } else { + Err(serde::de::Error::custom(format!( + "Expected cross signing key usage {} was not found", + KeyUsage::UserSigning + ))) + } + } +} +impl AsRef for UserSigningPubkey { + fn as_ref(&self) -> &CrossSigningKey { + &self.0 + } +} diff --git a/crates/matrix-sdk-crypto/src/types/events/dummy.rs b/crates/matrix-sdk-crypto/src/types/events/dummy.rs new file mode 100644 index 
00000000000..f5fb89a8deb --- /dev/null +++ b/crates/matrix-sdk-crypto/src/types/events/dummy.rs @@ -0,0 +1,66 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Types for `m.dummy` to-device events. + +use std::collections::BTreeMap; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use super::{EventType, ToDeviceEvent}; + +/// The `m.dummy` to-device event. +pub type DummyEvent = ToDeviceEvent; + +/// The content of an `m.dummy` event. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct DummyEventContent { + /// Any other, custom and non-specced fields of the content. 
+ #[serde(flatten)] + other: BTreeMap, +} + +impl EventType for DummyEventContent { + const EVENT_TYPE: &'static str = "m.dummy"; +} + +#[cfg(test)] +pub(super) mod test { + use serde_json::{json, Value}; + + use super::DummyEvent; + + pub fn json() -> Value { + json!({ + "sender": "@alice:example.org", + "content": { + "m.custom": "something custom", + }, + "type": "m.dummy", + "m.custom.top": "something custom in the top", + }) + } + + #[test] + fn deserialization() -> Result<(), serde_json::Error> { + let json = json(); + let event: DummyEvent = serde_json::from_value(json.clone())?; + + let serialized = serde_json::to_value(event)?; + assert_eq!(json, serialized); + + Ok(()) + } +} diff --git a/crates/matrix-sdk-crypto/src/types/events/mod.rs b/crates/matrix-sdk-crypto/src/types/events/mod.rs index 7a0464dffa0..da849998ba2 100644 --- a/crates/matrix-sdk-crypto/src/types/events/mod.rs +++ b/crates/matrix-sdk-crypto/src/types/events/mod.rs @@ -18,6 +18,7 @@ //! types. Once deserialized they aim to zeroize all the secret material once //! the type is dropped. 
+pub mod dummy; pub mod forwarded_room_key; pub mod olm_v1; pub mod room; diff --git a/crates/matrix-sdk-crypto/src/types/events/olm_v1.rs b/crates/matrix-sdk-crypto/src/types/events/olm_v1.rs index e1de4acb17a..a9652b58e5f 100644 --- a/crates/matrix-sdk-crypto/src/types/events/olm_v1.rs +++ b/crates/matrix-sdk-crypto/src/types/events/olm_v1.rs @@ -23,6 +23,7 @@ use serde_json::value::RawValue; use vodozemac::Ed25519PublicKey; use super::{ + dummy::DummyEventContent, forwarded_room_key::ForwardedRoomKeyContent, room_key::RoomKeyContent, room_key_request::{self, SupportedKeyInfo}, @@ -31,6 +32,10 @@ use super::{ }; use crate::types::{deserialize_ed25519_key, events::from_str, serialize_ed25519_key}; +/// An `m.dummy` event that was decrypted using the +/// `m.olm.v1.curve25519-aes-sha2` algorithm +pub type DecryptedDummyEvent = DecryptedOlmV1Event; + /// An `m.room_key` event that was decrypted using the /// `m.olm.v1.curve25519-aes-sha2` algorithm pub type DecryptedRoomKeyEvent = DecryptedOlmV1Event; @@ -82,6 +87,8 @@ pub enum AnyDecryptedOlmEvent { ForwardedRoomKey(DecryptedForwardedRoomKeyEvent), /// The `m.secret.send` decrypted to-device event. SecretSend(DecryptedSecretSendEvent), + /// The `m.dummy` decrypted to-device event. + Dummy(DecryptedDummyEvent), /// A decrypted to-device event of an unknown or custom type. 
Custom(Box), } @@ -94,6 +101,7 @@ impl AnyDecryptedOlmEvent { AnyDecryptedOlmEvent::ForwardedRoomKey(e) => &e.sender, AnyDecryptedOlmEvent::SecretSend(e) => &e.sender, AnyDecryptedOlmEvent::Custom(e) => &e.sender, + AnyDecryptedOlmEvent::Dummy(e) => &e.sender, } } @@ -104,6 +112,7 @@ impl AnyDecryptedOlmEvent { AnyDecryptedOlmEvent::ForwardedRoomKey(e) => &e.recipient, AnyDecryptedOlmEvent::SecretSend(e) => &e.recipient, AnyDecryptedOlmEvent::Custom(e) => &e.recipient, + AnyDecryptedOlmEvent::Dummy(e) => &e.recipient, } } @@ -114,6 +123,7 @@ impl AnyDecryptedOlmEvent { AnyDecryptedOlmEvent::ForwardedRoomKey(e) => &e.keys, AnyDecryptedOlmEvent::SecretSend(e) => &e.keys, AnyDecryptedOlmEvent::Custom(e) => &e.keys, + AnyDecryptedOlmEvent::Dummy(e) => &e.keys, } } @@ -124,6 +134,7 @@ impl AnyDecryptedOlmEvent { AnyDecryptedOlmEvent::ForwardedRoomKey(e) => &e.recipient_keys, AnyDecryptedOlmEvent::SecretSend(e) => &e.recipient_keys, AnyDecryptedOlmEvent::Custom(e) => &e.recipient_keys, + AnyDecryptedOlmEvent::Dummy(e) => &e.recipient_keys, } } @@ -134,6 +145,7 @@ impl AnyDecryptedOlmEvent { AnyDecryptedOlmEvent::RoomKey(e) => e.content.event_type(), AnyDecryptedOlmEvent::ForwardedRoomKey(e) => e.content.event_type(), AnyDecryptedOlmEvent::SecretSend(e) => e.content.event_type(), + AnyDecryptedOlmEvent::Dummy(e) => e.content.event_type(), } } } diff --git a/crates/matrix-sdk-crypto/src/types/events/to_device.rs b/crates/matrix-sdk-crypto/src/types/events/to_device.rs index 092e3f0da52..48ee6c79046 100644 --- a/crates/matrix-sdk-crypto/src/types/events/to_device.rs +++ b/crates/matrix-sdk-crypto/src/types/events/to_device.rs @@ -16,7 +16,6 @@ use std::{collections::BTreeMap, fmt::Debug}; use ruma::{ events::{ - dummy::ToDeviceDummyEvent, key::verification::{ accept::ToDeviceKeyVerificationAcceptEvent, cancel::ToDeviceKeyVerificationCancelEvent, done::ToDeviceKeyVerificationDoneEvent, key::ToDeviceKeyVerificationKeyEvent, @@ -37,6 +36,7 @@ use serde_json::{ use 
zeroize::Zeroize; use super::{ + dummy::DummyEvent, forwarded_room_key::{ForwardedRoomKeyContent, ForwardedRoomKeyEvent}, room::encrypted::EncryptedToDeviceEvent, room_key::RoomKeyEvent, @@ -52,7 +52,7 @@ pub enum ToDeviceEvents { /// A to-device event of an unknown or custom type. Custom(ToDeviceCustomEvent), /// The `m.dummy` to-device event. - Dummy(ToDeviceDummyEvent), + Dummy(DummyEvent), /// The `m.key.verification.accept` to-device event. KeyVerificationAccept(ToDeviceKeyVerificationAcceptEvent), @@ -115,7 +115,7 @@ impl ToDeviceEvents { pub fn event_type(&self) -> ToDeviceEventType { match self { ToDeviceEvents::Custom(e) => ToDeviceEventType::from(e.event_type.to_owned()), - ToDeviceEvents::Dummy(e) => e.content.event_type(), + ToDeviceEvents::Dummy(_) => ToDeviceEventType::Dummy, ToDeviceEvents::KeyVerificationAccept(e) => e.content.event_type(), ToDeviceEvents::KeyVerificationCancel(e) => e.content.event_type(), diff --git a/crates/matrix-sdk-crypto/src/types/mod.rs b/crates/matrix-sdk-crypto/src/types/mod.rs index 48d65c8c326..59ae3c017c8 100644 --- a/crates/matrix-sdk-crypto/src/types/mod.rs +++ b/crates/matrix-sdk-crypto/src/types/mod.rs @@ -24,7 +24,7 @@ //! data will. 
mod backup; -mod cross_signing_key; +mod cross_signing; mod device_keys; pub mod events; mod one_time_keys; @@ -38,7 +38,7 @@ use std::{ }; pub use backup::*; -pub use cross_signing_key::*; +pub use cross_signing::*; pub use device_keys::*; pub use one_time_keys::*; use ruma::{ diff --git a/crates/matrix-sdk-crypto/src/utilities.rs b/crates/matrix-sdk-crypto/src/utilities.rs index 36d37d84e29..09c7626dd5e 100644 --- a/crates/matrix-sdk-crypto/src/utilities.rs +++ b/crates/matrix-sdk-crypto/src/utilities.rs @@ -22,22 +22,23 @@ use std::{ pub use base64::DecodeError; use base64::{ - alphabet, decode_engine, encode_engine, - engine::fast_portable::{self, FastPortable}, + alphabet, + engine::{general_purpose, GeneralPurpose}, + Engine, }; use matrix_sdk_common::instant::Instant; -const STANDARD_NO_PAD: FastPortable = - FastPortable::from(&alphabet::STANDARD, fast_portable::NO_PAD); +const STANDARD_NO_PAD: GeneralPurpose = + GeneralPurpose::new(&alphabet::STANDARD, general_purpose::NO_PAD); /// Decode the input as base64 with no padding. pub fn decode(input: impl AsRef<[u8]>) -> Result, DecodeError> { - decode_engine(input, &STANDARD_NO_PAD) + STANDARD_NO_PAD.decode(input) } /// Encode the input as base64 with no padding. 
pub fn encode(input: impl AsRef<[u8]>) -> String { - encode_engine(input, &STANDARD_NO_PAD) + STANDARD_NO_PAD.encode(input) } #[cfg(test)] diff --git a/crates/matrix-sdk-crypto/src/verification/machine.rs b/crates/matrix-sdk-crypto/src/verification/machine.rs index b51cb88860a..fac58db70f3 100644 --- a/crates/matrix-sdk-crypto/src/verification/machine.rs +++ b/crates/matrix-sdk-crypto/src/verification/machine.rs @@ -40,7 +40,7 @@ use super::{ use crate::{ olm::PrivateCrossSigningIdentity, requests::OutgoingRequest, - store::{CryptoStore, CryptoStoreError}, + store::{CryptoStoreError, DynCryptoStore}, OutgoingVerificationRequest, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyUserIdentity, RoomMessageRequest, ToDeviceRequest, }; @@ -56,7 +56,7 @@ impl VerificationMachine { pub(crate) fn new( account: ReadOnlyAccount, identity: Arc>, - store: Arc, + store: Arc, ) -> Self { Self { store: VerificationStore { account, private_identity: identity, inner: store }, @@ -534,7 +534,7 @@ mod tests { use super::{Sas, VerificationMachine}; use crate::{ olm::PrivateCrossSigningIdentity, - store::MemoryStore, + store::{IntoCryptoStore, MemoryStore}, verification::{ cache::VerificationCache, event_enums::{AcceptContent, KeyContent, MacContent, OutgoingContent}, @@ -579,7 +579,7 @@ mod tests { let alice = ReadOnlyAccount::new(alice_id(), alice_device_id()); let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(alice_id()))); let store = MemoryStore::new(); - let _ = VerificationMachine::new(alice, identity, Arc::new(store)); + let _ = VerificationMachine::new(alice, identity, store.into_crypto_store()); } #[async_test] diff --git a/crates/matrix-sdk-crypto/src/verification/mod.rs b/crates/matrix-sdk-crypto/src/verification/mod.rs index 6b417c95456..ed97fb85777 100644 --- a/crates/matrix-sdk-crypto/src/verification/mod.rs +++ b/crates/matrix-sdk-crypto/src/verification/mod.rs @@ -52,7 +52,7 @@ use crate::{ error::SignatureError, gossiping::{GossipMachine, GossipRequest}, 
olm::{PrivateCrossSigningIdentity, ReadOnlyAccount, Session}, - store::{Changes, CryptoStore}, + store::{Changes, DynCryptoStore}, types::Signatures, CryptoStoreError, LocalTrust, OutgoingVerificationRequest, ReadOnlyDevice, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentities, @@ -62,7 +62,7 @@ use crate::{ pub(crate) struct VerificationStore { pub account: ReadOnlyAccount, pub private_identity: Arc>, - inner: Arc, + inner: Arc, } /// An emoji that is used for interactive verification using a short auth @@ -181,7 +181,7 @@ impl VerificationStore { .map(|d| d.signatures().to_owned())) } - pub fn inner(&self) -> &dyn CryptoStore { + pub fn inner(&self) -> &DynCryptoStore { self.inner.deref() } } @@ -812,15 +812,13 @@ pub(crate) mod tests { #[cfg(test)] mod test { - use std::sync::Arc; - use matrix_sdk_common::locks::Mutex; use ruma::{device_id, user_id, DeviceId, UserId}; use super::VerificationStore; use crate::{ olm::PrivateCrossSigningIdentity, - store::{Changes, CryptoStore, IdentityChanges, MemoryStore}, + store::{Changes, CryptoStore, IdentityChanges, IntoCryptoStore, MemoryStore}, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyOwnUserIdentity, ReadOnlyUserIdentity, }; @@ -886,13 +884,13 @@ mod test { let alice_store = VerificationStore { account: alice, - inner: Arc::new(alice_store), + inner: alice_store.into_crypto_store(), private_identity: alice_private_identity.into(), }; let bob_store = VerificationStore { account: bob.clone(), - inner: Arc::new(bob_store), + inner: bob_store.into_crypto_store(), private_identity: bob_private_identity.into(), }; diff --git a/crates/matrix-sdk-crypto/src/verification/qrcode.rs b/crates/matrix-sdk-crypto/src/verification/qrcode.rs index da9637cdd1e..f9f999dd1b5 100644 --- a/crates/matrix-sdk-crypto/src/verification/qrcode.rs +++ b/crates/matrix-sdk-crypto/src/verification/qrcode.rs @@ -14,8 +14,8 @@ use std::sync::Arc; +use eyeball::shared::{Observable as SharedObservable, ObservableWriteGuard}; use futures_core::Stream; -use 
futures_signals::signal::{Mutable, SignalExt}; use futures_util::StreamExt; use matrix_sdk_qrcode::{ qrcode::QrCode, EncodingError, QrVerificationData, SelfVerificationData, @@ -135,7 +135,7 @@ impl From<&InnerState> for QrVerificationState { pub struct QrVerification { flow_id: FlowId, inner: Arc, - state: Arc>, + state: SharedObservable, identities: IdentitiesBeingVerified, request_handle: Option, we_started: bool, @@ -145,8 +145,8 @@ impl std::fmt::Debug for QrVerification { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("QrVerification") .field("flow_id", &self.flow_id) - .field("inner", self.inner.as_ref()) - .field("state", &self.state.lock_ref()) + .field("inner", &self.inner) + .field("state", &self.state) .finish() } } @@ -157,12 +157,12 @@ impl QrVerification { /// When the verification object is in this state it's required that the /// user confirms that the other side has scanned the QR code. pub fn has_been_scanned(&self) -> bool { - matches!(&*self.state.lock_ref(), InnerState::Scanned(_)) + matches!(*self.state.read(), InnerState::Scanned(_)) } /// Has the scanning of the QR code been confirmed by us. pub fn has_been_confirmed(&self) -> bool { - matches!(&*self.state.lock_ref(), InnerState::Confirmed(_)) + matches!(*self.state.read(), InnerState::Confirmed(_)) } /// Get our own user id. @@ -189,7 +189,7 @@ impl QrVerification { /// Get info about the cancellation if the verification flow has been /// cancelled. pub fn cancel_info(&self) -> Option { - if let InnerState::Cancelled(c) = &*self.state.lock_ref() { + if let InnerState::Cancelled(c) = &*self.state.read() { Some(c.state.clone().into()) } else { None @@ -198,12 +198,12 @@ impl QrVerification { /// Has the verification flow completed. pub fn is_done(&self) -> bool { - matches!(&*self.state.lock_ref(), InnerState::Done(_)) + matches!(*self.state.read(), InnerState::Done(_)) } /// Has the verification flow been cancelled. 
pub fn is_cancelled(&self) -> bool { - matches!(&*self.state.lock_ref(), InnerState::Cancelled(_)) + matches!(*self.state.read(), InnerState::Cancelled(_)) } /// Is this a verification that is veryfying one of our own devices @@ -214,7 +214,7 @@ impl QrVerification { /// Have we successfully scanned the QR code and are able to send a /// reciprocation event. pub fn reciprocated(&self) -> bool { - matches!(&*self.state.lock_ref(), InnerState::Reciprocated(_)) + matches!(*self.state.read(), InnerState::Reciprocated(_)) } /// Get the unique ID that identifies this QR code verification flow. @@ -268,7 +268,7 @@ impl QrVerification { /// /// [`cancel()`]: #method.cancel pub fn cancel_with_code(&self, code: CancelCode) -> Option { - let mut state = self.state.lock_mut(); + let mut state = self.state.write(); if let Some(request) = &self.request_handle { request.cancel_with_code(&code); @@ -283,7 +283,7 @@ impl QrVerification { | InnerState::Scanned(_) | InnerState::Reciprocated(_) | InnerState::Done(_) => { - *state = InnerState::Cancelled(new_state); + ObservableWriteGuard::set(&mut state, InnerState::Cancelled(new_state)); Some(self.content_to_request(content)) } InnerState::Cancelled(_) => None, @@ -296,7 +296,7 @@ impl QrVerification { /// This will return some `OutgoingContent` if the object is in the correct /// state to start the verification flow, otherwise `None`. pub fn reciprocate(&self) -> Option { - match &*self.state.lock_ref() { + match &*self.state.read() { InnerState::Reciprocated(s) => { Some(self.content_to_request(s.as_content(self.flow_id()))) } @@ -310,13 +310,13 @@ impl QrVerification { /// Confirm that the other side has scanned our QR code. 
pub fn confirm_scanning(&self) -> Option { - let mut state = self.state.lock_mut(); + let mut state = self.state.write(); match &*state { InnerState::Scanned(s) => { let new_state = s.clone().confirm_scanning(); let content = new_state.as_content(&self.flow_id); - *state = InnerState::Confirmed(new_state); + ObservableWriteGuard::set(&mut state, InnerState::Confirmed(new_state)); Some(self.content_to_request(content)) } @@ -366,9 +366,7 @@ impl QrVerification { VerificationResult::SignatureUpload(s) => (None, Some(s)), }; - let mut guard = self.state.lock_mut(); - *guard = new_state; - + self.state.set(new_state); Ok((content.map(|c| self.content_to_request(c)), request)) } @@ -379,7 +377,7 @@ impl QrVerification { (Option, Option), CryptoStoreError, > { - let state = (*self.state.lock_ref()).clone(); + let state = self.state.get(); Ok(match state { InnerState::Confirmed(s) => { @@ -432,17 +430,17 @@ impl QrVerification { &self, content: &StartContent<'_>, ) -> Option { - let mut state = self.state.lock_mut(); + let mut state = self.state.write(); match &*state { InnerState::Created(s) => match s.clone().receive_reciprocate(content) { Ok(s) => { - *state = InnerState::Scanned(s); + ObservableWriteGuard::set(&mut state, InnerState::Scanned(s)); None } Err(s) => { let content = s.as_content(self.flow_id()); - *state = InnerState::Cancelled(s); + ObservableWriteGuard::set(&mut state, InnerState::Cancelled(s)); Some(self.content_to_request(content)) } }, @@ -456,7 +454,7 @@ impl QrVerification { pub(crate) fn receive_cancel(&self, sender: &UserId, content: &CancelContent<'_>) { if sender == self.other_user_id() { - let mut state = self.state.lock_mut(); + let mut state = self.state.write(); let new_state = match &*state { InnerState::Created(s) => s.clone().into_cancelled(content), @@ -472,7 +470,7 @@ impl QrVerification { "Cancelling a QR verification, other user has cancelled" ); - *state = InnerState::Cancelled(new_state); + ObservableWriteGuard::set(&mut state, 
InnerState::Cancelled(new_state)); } } @@ -630,10 +628,9 @@ impl QrVerification { Ok(Self { flow_id, inner: qr_code.into(), - state: Mutable::new(InnerState::Reciprocated(QrState { + state: SharedObservable::new(InnerState::Reciprocated(QrState { state: Reciprocated { secret, own_device_id }, - })) - .into(), + })), identities, we_started, request_handle, @@ -652,7 +649,9 @@ impl QrVerification { Self { flow_id, inner: inner.into(), - state: Mutable::new(InnerState::Created(QrState { state: Created { secret } })).into(), + state: SharedObservable::new(InnerState::Created(QrState { + state: Created { secret }, + })), identities, we_started, request_handle, @@ -663,7 +662,7 @@ impl QrVerification { /// /// The changes are presented as a stream of [`QrVerificationState`] values. pub fn changes(&self) -> impl Stream { - self.state.signal_cloned().to_stream().map(|s| (&s).into()) + self.state.subscribe().map(|s| (&s).into()) } /// Get the current state the verification process is in. @@ -671,7 +670,7 @@ impl QrVerification { /// To listen to changes to the [`QrVerificationState`] use the /// [`QrVerification::changes`] method. 
pub fn state(&self) -> QrVerificationState { - (&*self.state.lock_ref()).into() + (&*self.state.read()).into() } } @@ -855,7 +854,7 @@ mod tests { use crate::{ olm::{PrivateCrossSigningIdentity, ReadOnlyAccount}, - store::{Changes, CryptoStore, MemoryStore}, + store::{Changes, DynCryptoStore, IntoCryptoStore, MemoryStore}, verification::{ event_enums::{DoneContent, OutgoingContent, StartContent}, FlowId, VerificationStore, @@ -867,8 +866,8 @@ mod tests { user_id!("@example:localhost") } - fn memory_store() -> Arc { - Arc::new(MemoryStore::new()) + fn memory_store() -> Arc { + MemoryStore::new().into_crypto_store() } fn device_id() -> &'static DeviceId { diff --git a/crates/matrix-sdk-crypto/src/verification/requests.rs b/crates/matrix-sdk-crypto/src/verification/requests.rs index 21ffd2e27e3..1709f5d2139 100644 --- a/crates/matrix-sdk-crypto/src/verification/requests.rs +++ b/crates/matrix-sdk-crypto/src/verification/requests.rs @@ -14,8 +14,8 @@ use std::{sync::Arc, time::Duration}; +use eyeball::shared::{Observable as SharedObservable, ObservableWriteGuard}; use futures_core::Stream; -use futures_signals::signal::{Mutable, SignalExt}; use futures_util::StreamExt; use matrix_sdk_common::instant::Instant; #[cfg(feature = "qrcode")] @@ -136,7 +136,7 @@ pub struct VerificationRequest { account: ReadOnlyAccount, flow_id: Arc, other_user_id: Arc, - inner: Arc>, + inner: SharedObservable, creation_time: Arc, we_started: bool, recipient_devices: Arc>, @@ -152,20 +152,20 @@ pub struct VerificationRequest { /// `VerificationRequest` object. 
#[derive(Clone, Debug)] pub(crate) struct RequestHandle { - inner: Arc>, + inner: SharedObservable, } impl RequestHandle { pub fn cancel_with_code(&self, cancel_code: &CancelCode) { - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); if let Some(updated) = guard.cancel(true, cancel_code) { - *guard = updated; + ObservableWriteGuard::set(&mut guard, updated); } } } -impl From>> for RequestHandle { - fn from(inner: Arc>) -> Self { +impl From> for RequestHandle { + fn from(inner: SharedObservable) -> Self { Self { inner } } } @@ -180,14 +180,13 @@ impl VerificationRequest { methods: Option>, ) -> Self { let account = store.account.clone(); - let inner = Mutable::new(InnerRequest::Created(RequestState::new( + let inner = SharedObservable::new(InnerRequest::Created(RequestState::new( cache.clone(), store, other_user, &flow_id, methods, - ))) - .into(); + ))); Self { account, @@ -206,7 +205,7 @@ impl VerificationRequest { /// self-verifications and it should be sent to the specific device that we /// want to verify. pub(crate) fn request_to_device(&self) -> ToDeviceRequest { - let inner = self.inner.lock_ref(); + let inner = self.inner.read(); let methods = if let InnerRequest::Created(c) = &*inner { c.state.our_methods.clone() @@ -264,7 +263,7 @@ impl VerificationRequest { /// The id of the other device that is participating in this verification. pub fn other_device_id(&self) -> Option { - match &*self.inner.lock_ref() { + match &*self.inner.read() { InnerRequest::Requested(r) => Some(r.state.other_device_id.clone()), InnerRequest::Ready(r) => Some(r.state.other_device_id.clone()), InnerRequest::Created(_) @@ -285,7 +284,7 @@ impl VerificationRequest { /// Get info about the cancellation if the verification request has been /// cancelled. 
pub fn cancel_info(&self) -> Option { - if let InnerRequest::Cancelled(c) = &*self.inner.lock_ref() { + if let InnerRequest::Cancelled(c) = &*self.inner.read() { Some(c.state.clone().into()) } else { None @@ -294,12 +293,12 @@ impl VerificationRequest { /// Has the verification request been answered by another device. pub fn is_passive(&self) -> bool { - matches!(&*self.inner.lock_ref(), InnerRequest::Passive(_)) + matches!(*self.inner.read(), InnerRequest::Passive(_)) } /// Is the verification request ready to start a verification flow. pub fn is_ready(&self) -> bool { - matches!(&*self.inner.lock_ref(), InnerRequest::Ready(_)) + matches!(*self.inner.read(), InnerRequest::Ready(_)) } /// Has the verification flow timed out. @@ -312,7 +311,7 @@ impl VerificationRequest { /// Will be present only if the other side requested the verification or if /// we're in the ready state. pub fn their_supported_methods(&self) -> Option> { - match &*self.inner.lock_ref() { + match &*self.inner.read() { InnerRequest::Requested(r) => Some(r.state.their_methods.clone()), InnerRequest::Ready(r) => Some(r.state.their_methods.clone()), InnerRequest::Created(_) @@ -327,7 +326,7 @@ impl VerificationRequest { /// Will be present only we requested the verification or if we're in the /// ready state. pub fn our_supported_methods(&self) -> Option> { - match &*self.inner.lock_ref() { + match &*self.inner.read() { InnerRequest::Created(r) => Some(r.state.our_methods.clone()), InnerRequest::Ready(r) => Some(r.state.our_methods.clone()), InnerRequest::Requested(_) @@ -354,21 +353,20 @@ impl VerificationRequest { /// Has the verification flow that was started with this request finished. pub fn is_done(&self) -> bool { - matches!(&*self.inner.lock_ref(), InnerRequest::Done(_)) + matches!(*self.inner.read(), InnerRequest::Done(_)) } /// Has the verification flow that was started with this request been /// cancelled. 
pub fn is_cancelled(&self) -> bool { - matches!(&*self.inner.lock_ref(), InnerRequest::Cancelled(_)) + matches!(*self.inner.read(), InnerRequest::Cancelled(_)) } /// Generate a QR code that can be used by another client to start a QR code /// based verification. #[cfg(feature = "qrcode")] pub async fn generate_qr_code(&self) -> Result, CryptoStoreError> { - let inner = self.inner.lock_ref().clone(); - + let inner = self.inner.get(); inner.generate_qr_code(self.we_started, self.inner.clone().into()).await } @@ -384,7 +382,7 @@ impl VerificationRequest { &self, data: QrVerificationData, ) -> Result, ScanError> { - let future = if let InnerRequest::Ready(r) = &*self.inner.lock_ref() { + let future = if let InnerRequest::Ready(r) = &*self.inner.read() { QrVerification::from_scan( r.store.clone(), r.other_user_id.clone(), @@ -398,6 +396,7 @@ impl VerificationRequest { return Ok(None); }; + // await future after self.inner read guard is released let qr_verification = future.await?; // We may have previously started our own QR verification (e.g. 
two devices @@ -437,9 +436,9 @@ impl VerificationRequest { Self { verification_cache: cache.clone(), - inner: Arc::new(Mutable::new(InnerRequest::Requested( + inner: SharedObservable::new(InnerRequest::Requested( RequestState::from_request_event(cache, store, sender, &flow_id, content), - ))), + )), account, other_user_id: sender.into(), flow_id: flow_id.into(), @@ -459,13 +458,13 @@ impl VerificationRequest { &self, methods: Vec, ) -> Option { - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); let Some((updated, content)) = guard.accept(methods) else { return None; }; - *guard = updated; + ObservableWriteGuard::set(&mut guard, updated); let request = match content { OutgoingContent::ToDevice(content) => ToDeviceRequest::with_id( @@ -505,13 +504,13 @@ impl VerificationRequest { } fn cancel_with_code(&self, cancel_code: CancelCode) -> Option { - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); - let send_to_everyone = self.we_started() && matches!(&*guard, InnerRequest::Created(_)); + let send_to_everyone = self.we_started() && matches!(*guard, InnerRequest::Created(_)); let other_device = guard.other_device_id(); if let Some(updated) = guard.cancel(true, &cancel_code) { - *guard = updated; + ObservableWriteGuard::set(&mut guard, updated); } let content = if let InnerRequest::Cancelled(c) = &*guard { @@ -630,11 +629,12 @@ impl VerificationRequest { } pub(crate) fn receive_ready(&self, sender: &UserId, content: &ReadyContent<'_>) { - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); match &*guard { InnerRequest::Created(s) => { - *guard = InnerRequest::Ready(s.clone().into_ready(sender, content)); + let new_value = InnerRequest::Ready(s.clone().into_ready(sender, content)); + ObservableWriteGuard::set(&mut guard, new_value); if let Some(request) = self.cancel_for_other_devices(CancelCode::Accepted, Some(content.from_device())) @@ -645,7 +645,8 @@ impl VerificationRequest { 
InnerRequest::Requested(s) => { if sender == self.own_user_id() && content.from_device() != self.account.device_id() { - *guard = InnerRequest::Passive(s.clone().into_passive(content)) + let new_value = InnerRequest::Passive(s.clone().into_passive(content)); + ObservableWriteGuard::set(&mut guard, new_value); } } InnerRequest::Ready(_) @@ -660,7 +661,7 @@ impl VerificationRequest { sender: &UserId, content: &StartContent<'_>, ) -> Result<(), CryptoStoreError> { - let inner = self.inner.lock_ref().clone(); + let inner = self.inner.get(); let InnerRequest::Ready(s) = inner else { warn!( @@ -682,9 +683,9 @@ impl VerificationRequest { "Marking a verification request as done" ); - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); if let Some(updated) = guard.receive_done(content) { - *guard = updated; + ObservableWriteGuard::set(&mut guard, updated); } } } @@ -699,9 +700,9 @@ impl VerificationRequest { code = content.cancel_code().as_str(), "Cancelling a verification request, other user has cancelled" ); - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); if let Some(updated) = guard.cancel(false, content.cancel_code()) { - *guard = updated; + ObservableWriteGuard::set(&mut guard, updated); } if self.we_started() { @@ -717,7 +718,7 @@ impl VerificationRequest { pub async fn start_sas( &self, ) -> Result, CryptoStoreError> { - let inner = self.inner.lock_ref().clone(); + let inner = self.inner.get(); Ok(match &inner { InnerRequest::Ready(s) => { @@ -771,7 +772,7 @@ impl VerificationRequest { /// The changes are presented as a stream of [`VerificationRequestState`] /// values. pub fn changes(&self) -> impl Stream { - self.inner.signal_cloned().to_stream().map(|s| (&s).into()) + self.inner.subscribe().map(|s| (&s).into()) } /// Get the current state the verification request is in. 
@@ -779,7 +780,7 @@ impl VerificationRequest { /// To listen to changes to the [`VerificationRequestState`] use the /// [`VerificationRequest::changes`] method. pub fn state(&self) -> VerificationRequestState { - (&*self.inner.lock_ref()).into() + (&*self.inner.read()).into() } } diff --git a/crates/matrix-sdk-crypto/src/verification/sas/mod.rs b/crates/matrix-sdk-crypto/src/verification/sas/mod.rs index e77ab4cb8c0..efdf366c2b3 100644 --- a/crates/matrix-sdk-crypto/src/verification/sas/mod.rs +++ b/crates/matrix-sdk-crypto/src/verification/sas/mod.rs @@ -18,8 +18,8 @@ mod sas_state; use std::sync::Arc; +use eyeball::shared::{Observable as SharedObservable, ObservableWriteGuard}; use futures_core::Stream; -use futures_signals::signal::{Mutable, SignalExt}; use futures_util::StreamExt; use inner_sas::InnerSas; use ruma::{ @@ -49,7 +49,7 @@ use crate::{ /// Short authentication string object. #[derive(Clone, Debug)] pub struct Sas { - inner: Arc>, + inner: SharedObservable, account: ReadOnlyAccount, identities_being_verified: IdentitiesBeingVerified, flow_id: Arc, @@ -268,12 +268,12 @@ impl Sas { /// Does this verification flow support displaying emoji for the short /// authentication string. pub fn supports_emoji(&self) -> bool { - self.inner.lock_ref().supports_emoji() + self.inner.read().supports_emoji() } /// Did this verification flow start from a verification request. pub fn started_from_request(&self) -> bool { - self.inner.lock_ref().started_from_request() + self.inner.read().started_from_request() } /// Is this a verification that is veryfying one of our own devices. @@ -283,18 +283,18 @@ impl Sas { /// Have we confirmed that the short auth string matches. pub fn have_we_confirmed(&self) -> bool { - self.inner.lock_ref().have_we_confirmed() + self.inner.read().have_we_confirmed() } /// Has the verification been accepted by both parties. 
pub fn has_been_accepted(&self) -> bool { - self.inner.lock_ref().has_been_accepted() + self.inner.read().has_been_accepted() } /// Get info about the cancellation if the verification flow has been /// cancelled. pub fn cancel_info(&self) -> Option { - if let InnerSas::Cancelled(c) = &*self.inner.lock_ref() { + if let InnerSas::Cancelled(c) = &*self.inner.read() { Some(c.state.as_ref().clone().into()) } else { None @@ -309,7 +309,9 @@ impl Sas { #[cfg(test)] #[allow(dead_code)] pub(crate) fn set_creation_time(&self, time: matrix_sdk_common::instant::Instant) { - self.inner.lock_mut().set_creation_time(time) + self.inner.update(|inner| { + inner.set_creation_time(time); + }); } fn start_helper( @@ -331,7 +333,7 @@ impl Sas { ( Sas { - inner: Arc::new(Mutable::new(inner)), + inner: SharedObservable::new(inner), account, identities_being_verified: identities, flow_id: flow_id.into(), @@ -415,7 +417,7 @@ impl Sas { let account = identities.store.account.clone(); Ok(Sas { - inner: Arc::new(Mutable::new(inner)), + inner: SharedObservable::new(inner), account, identities_being_verified: identities, flow_id: flow_id.into(), @@ -445,12 +447,12 @@ impl Sas { let old_state = self.state_debug(); let request = { - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); let sas: InnerSas = (*guard).clone(); let methods = settings.allowed_methods; if let Some((sas, content)) = sas.accept(methods) { - *guard = sas; + ObservableWriteGuard::set(&mut guard, sas); Some(match content { OwnedAcceptContent::ToDevice(c) => { @@ -493,12 +495,12 @@ impl Sas { ) -> Result<(Vec, Option), CryptoStoreError> { let (contents, done) = { - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); let sas: InnerSas = (*guard).clone(); let (sas, contents) = sas.confirm(); - *guard = sas; + ObservableWriteGuard::set(&mut guard, sas); (contents, guard.is_done()) }; @@ -565,7 +567,7 @@ impl Sas { /// [`cancel()`]: #method.cancel pub fn cancel_with_code(&self, 
code: CancelCode) -> Option { let content = { - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); if let Some(request) = &self.request_handle { request.cancel_with_code(&code); @@ -573,7 +575,7 @@ impl Sas { let sas: InnerSas = (*guard).clone(); let (sas, content) = sas.cancel(true, code); - *guard = sas; + ObservableWriteGuard::set(&mut guard, sas); content.map(|c| match c { OutgoingContent::Room(room_id, content) => { @@ -598,22 +600,22 @@ impl Sas { /// Has the SAS verification flow timed out. pub fn timed_out(&self) -> bool { - self.inner.lock_ref().timed_out() + self.inner.read().timed_out() } /// Are we in a state where we can show the short auth string. pub fn can_be_presented(&self) -> bool { - self.inner.lock_ref().can_be_presented() + self.inner.read().can_be_presented() } /// Is the SAS flow done. pub fn is_done(&self) -> bool { - self.inner.lock_ref().is_done() + self.inner.read().is_done() } /// Is the SAS flow canceled. pub fn is_cancelled(&self) -> bool { - self.inner.lock_ref().is_cancelled() + self.inner.read().is_cancelled() } /// Get the emoji version of the short auth string. @@ -621,7 +623,7 @@ impl Sas { /// Returns None if we can't yet present the short auth string, otherwise /// seven tuples containing the emoji and description. pub fn emoji(&self) -> Option<[Emoji; 7]> { - self.inner.lock_ref().emoji() + self.inner.read().emoji() } /// Get the index of the emoji representing the short auth string @@ -631,7 +633,7 @@ impl Sas { /// converted to an emoji using the /// [relevant spec entry](https://spec.matrix.org/unstable/client-server-api/#sas-method-emoji). pub fn emoji_index(&self) -> Option<[u8; 7]> { - self.inner.lock_ref().emoji_index() + self.inner.read().emoji_index() } /// Get the decimal version of the short auth string. @@ -640,7 +642,7 @@ impl Sas { /// tuple containing three 4-digit integers that represent the short auth /// string. 
pub fn decimals(&self) -> Option<(u16, u16, u16)> { - self.inner.lock_ref().decimals() + self.inner.read().decimals() } /// Listen for changes in the SAS verification process. @@ -732,16 +734,16 @@ impl Sas { /// # anyhow::Ok(()) }); /// ``` pub fn changes(&self) -> impl Stream { - self.inner.signal_cloned().to_stream().map(|s| (&s).into()) + self.inner.subscribe().map(|s| (&s).into()) } /// Get the current state of the verification process. pub fn state(&self) -> SasState { - (&*self.inner.lock_ref()).into() + (&*self.inner.read()).into() } fn state_debug(&self) -> State { - (&*self.inner.lock_ref()).into() + (&*self.inner.read()).into() } pub(crate) fn receive_any_event( @@ -752,11 +754,11 @@ impl Sas { let old_state = self.state_debug(); let content = { - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); let sas: InnerSas = (*guard).clone(); let (sas, content) = sas.receive_any_event(sender, content); - *guard = sas; + ObservableWriteGuard::set(&mut guard, sas); content }; @@ -776,12 +778,12 @@ impl Sas { let old_state = self.state_debug(); { - let mut guard = self.inner.lock_mut(); + let mut guard = self.inner.write(); let sas: InnerSas = (*guard).clone(); if let Some(sas) = sas.mark_request_as_sent(request_id) { - *guard = sas; + ObservableWriteGuard::set(&mut guard, sas); } else { error!( flow_id = self.flow_id().as_str(), @@ -803,11 +805,11 @@ impl Sas { } pub(crate) fn verified_devices(&self) -> Option> { - self.inner.lock_ref().verified_devices() + self.inner.read().verified_devices() } pub(crate) fn verified_identities(&self) -> Option> { - self.inner.lock_ref().verified_identities() + self.inner.read().verified_identities() } pub(crate) fn content_to_request(&self, content: AnyToDeviceEventContent) -> ToDeviceRequest { @@ -851,8 +853,6 @@ impl AcceptSettings { #[cfg(test)] mod tests { - use std::sync::Arc; - use assert_matches::assert_matches; use matrix_sdk_common::locks::Mutex; use matrix_sdk_test::async_test; @@ -861,7 
+861,7 @@ mod tests { use super::Sas; use crate::{ olm::PrivateCrossSigningIdentity, - store::MemoryStore, + store::{IntoCryptoStore, MemoryStore}, verification::{ event_enums::{AcceptContent, KeyContent, MacContent, OutgoingContent, StartContent}, VerificationStore, @@ -895,7 +895,7 @@ mod tests { let alice_store = VerificationStore { account: alice.clone(), - inner: Arc::new(MemoryStore::new()), + inner: MemoryStore::new().into_crypto_store(), private_identity: Mutex::new(PrivateCrossSigningIdentity::empty(alice_id())).into(), }; @@ -904,7 +904,7 @@ mod tests { let bob_store = VerificationStore { account: bob.clone(), - inner: Arc::new(bob_store), + inner: bob_store.into_crypto_store(), private_identity: Mutex::new(PrivateCrossSigningIdentity::empty(bob_id())).into(), }; diff --git a/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs b/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs index 3ab0921ae69..7fc0a3acee9 100644 --- a/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs +++ b/crates/matrix-sdk-crypto/src/verification/sas/sas_state.rs @@ -80,6 +80,7 @@ fn the_protocol_definitions() -> SasV1Content { short_authentication_string: STRINGS.to_vec(), key_agreement_protocols: KEY_AGREEMENT_PROTOCOLS.to_vec(), message_authentication_codes: vec![ + #[allow(deprecated)] MessageAuthenticationCode::HkdfHmacSha256, MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"), ], @@ -1581,6 +1582,7 @@ mod tests { ], hashes: vec![HashAlgorithm::Sha256], message_authentication_codes: vec![ + #[allow(deprecated)] MessageAuthenticationCode::HkdfHmacSha256, MessageAuthenticationCode::from("org.matrix.msc3783.hkdf-hmac-sha256"), ], diff --git a/crates/matrix-sdk-indexeddb/Cargo.toml b/crates/matrix-sdk-indexeddb/Cargo.toml index b205178b16e..84247d624e6 100644 --- a/crates/matrix-sdk-indexeddb/Cargo.toml +++ b/crates/matrix-sdk-indexeddb/Cargo.toml @@ -22,7 +22,6 @@ anyhow = { workspace = true } async-trait = { workspace = true } base64 = { 
workspace = true } dashmap = { workspace = true, optional = true } -derive_builder = "0.11.2" gloo-utils = { version = "0.1", features = ["serde"] } indexed_db_futures = "0.3.0" js-sys = { version = "0.3.58" } @@ -42,9 +41,10 @@ web-sys = { version = "0.3.57", features = ["IdbKeyRange"] } getrandom = { version = "0.2.6", features = ["js"] } [dev-dependencies] +assert_matches = "1.5.0" matrix-sdk-base = { path = "../matrix-sdk-base", features = ["testing"] } matrix-sdk-common = { path = "../matrix-sdk-common", features = ["js"] } matrix-sdk-crypto = { path = "../matrix-sdk-crypto", features = ["js", "testing"] } matrix-sdk-test = { path = "../../testing/matrix-sdk-test" } -uuid = "1.0.0" +uuid = "1.3.0" wasm-bindgen-test = "0.3.33" diff --git a/crates/matrix-sdk-indexeddb/src/crypto_store.rs b/crates/matrix-sdk-indexeddb/src/crypto_store.rs index 5671804144a..17a2b60149d 100644 --- a/crates/matrix-sdk-indexeddb/src/crypto_store.rs +++ b/crates/matrix-sdk-indexeddb/src/crypto_store.rs @@ -28,6 +28,7 @@ use matrix_sdk_crypto::{ }, store::{ caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, RoomKeyCounts, + RoomSettings, }, GossipRequest, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo, TrackedUser, @@ -40,9 +41,8 @@ use web_sys::IdbKeyRange; use crate::safe_encode::SafeEncode; -#[allow(non_snake_case)] -mod KEYS { - // STORES +mod keys { + // stores pub const CORE: &str = "core"; pub const SESSION: &str = "session"; @@ -60,13 +60,14 @@ mod KEYS { pub const UNSENT_SECRET_REQUESTS: &str = "unsent_secret_requests"; pub const SECRET_REQUESTS_BY_INFO: &str = "secret_requests_by_info"; pub const KEY_REQUEST: &str = "key_request"; + pub const ROOM_SETTINGS: &str = "room_settings"; - // KEYS + // keys pub const STORE_CIPHER: &str = "store_cipher"; pub const ACCOUNT: &str = "account"; pub const PRIVATE_IDENTITY: &str = "private_identity"; - // BACKUP v1 + // backup v1 pub const BACKUP_KEYS: &str = "backup_keys"; pub const 
BACKUP_KEY_V1: &str = "backup_key_v1"; pub const RECOVERY_KEY_V1: &str = "recovery_key_v1"; @@ -143,7 +144,7 @@ impl IndexeddbCryptoStore { let name = format!("{prefix:0}::matrix-sdk-crypto"); // Open my_db v1 - let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, 1.1)?; + let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, 2.0)?; db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { let old_version = evt.old_version(); @@ -151,21 +152,21 @@ impl IndexeddbCryptoStore { // migrating to version 1 let db = evt.db(); - db.create_object_store(KEYS::CORE)?; - db.create_object_store(KEYS::SESSION)?; + db.create_object_store(keys::CORE)?; + db.create_object_store(keys::SESSION)?; - db.create_object_store(KEYS::INBOUND_GROUP_SESSIONS)?; - db.create_object_store(KEYS::OUTBOUND_GROUP_SESSIONS)?; - db.create_object_store(KEYS::TRACKED_USERS)?; - db.create_object_store(KEYS::OLM_HASHES)?; - db.create_object_store(KEYS::DEVICES)?; + db.create_object_store(keys::INBOUND_GROUP_SESSIONS)?; + db.create_object_store(keys::OUTBOUND_GROUP_SESSIONS)?; + db.create_object_store(keys::TRACKED_USERS)?; + db.create_object_store(keys::OLM_HASHES)?; + db.create_object_store(keys::DEVICES)?; - db.create_object_store(KEYS::IDENTITIES)?; - db.create_object_store(KEYS::OUTGOING_SECRET_REQUESTS)?; - db.create_object_store(KEYS::UNSENT_SECRET_REQUESTS)?; - db.create_object_store(KEYS::SECRET_REQUESTS_BY_INFO)?; + db.create_object_store(keys::IDENTITIES)?; + db.create_object_store(keys::OUTGOING_SECRET_REQUESTS)?; + db.create_object_store(keys::UNSENT_SECRET_REQUESTS)?; + db.create_object_store(keys::SECRET_REQUESTS_BY_INFO)?; - db.create_object_store(KEYS::BACKUP_KEYS)?; + db.create_object_store(keys::BACKUP_KEYS)?; } else if old_version < 1.1 { // We changed how we store inbound group sessions, the key used to // be a trippled of `(room_id, sender_key, session_id)` now it's a @@ -175,8 +176,13 @@ impl IndexeddbCryptoStore { let db = evt.db(); - 
db.delete_object_store(KEYS::INBOUND_GROUP_SESSIONS)?; - db.create_object_store(KEYS::INBOUND_GROUP_SESSIONS)?; + db.delete_object_store(keys::INBOUND_GROUP_SESSIONS)?; + db.create_object_store(keys::INBOUND_GROUP_SESSIONS)?; + } + + if old_version < 2.0 { + let db = evt.db(); + db.create_object_store(keys::ROOM_SETTINGS)?; } Ok(()) @@ -250,7 +256,7 @@ impl IndexeddbCryptoStore { let ob = tx.object_store("matrix-sdk-crypto")?; let store_cipher: Option> = ob - .get(&JsValue::from_str(KEYS::STORE_CIPHER))? + .get(&JsValue::from_str(keys::STORE_CIPHER))? .await? .map(|k| k.into_serde()) .transpose()?; @@ -272,7 +278,7 @@ impl IndexeddbCryptoStore { let ob = tx.object_store("matrix-sdk-crypto")?; ob.put_key_val( - &JsValue::from_str(KEYS::STORE_CIPHER), + &JsValue::from_str(keys::STORE_CIPHER), &JsValue::from_serde(&export.map_err(CryptoStoreError::backend)?)?, )?; tx.await.into_result()?; @@ -317,25 +323,58 @@ impl IndexeddbCryptoStore { fn get_account_info(&self) -> Option { self.account_info.read().unwrap().clone() } +} + +// Small hack to have the following macro invocation act as the appropriate +// trait impl block on wasm, but still be compiled on non-wasm as a regular +// impl block otherwise. +// +// The trait impl doesn't compile on non-wasm due to unfulfilled trait bounds, +// this hack allows us to still have most of rust-analyzer's IDE functionality +// within the impl block without having to set it up to check things against +// the wasm target (which would disable many other parts of the codebase). +#[cfg(target_arch = "wasm32")] +macro_rules! impl_crypto_store { + ( $($body:tt)* ) => { + #[async_trait(?Send)] + impl CryptoStore for IndexeddbCryptoStore { + type Error = IndexeddbCryptoStoreError; + + $($body)* + } + }; +} + +#[cfg(not(target_arch = "wasm32"))] +macro_rules! impl_crypto_store { + ( $($body:tt)* ) => { + impl IndexeddbCryptoStore { + $($body)* + } + }; +} +impl_crypto_store! 
{ async fn save_changes(&self, changes: Changes) -> Result<()> { let mut stores: Vec<&str> = [ - (changes.account.is_some() || changes.private_identity.is_some(), KEYS::CORE), - (changes.recovery_key.is_some() || changes.backup_version.is_some(), KEYS::BACKUP_KEYS), - (!changes.sessions.is_empty(), KEYS::SESSION), + (changes.account.is_some() || changes.private_identity.is_some(), keys::CORE), + (changes.recovery_key.is_some() || changes.backup_version.is_some(), keys::BACKUP_KEYS), + (!changes.sessions.is_empty(), keys::SESSION), ( !changes.devices.new.is_empty() || !changes.devices.changed.is_empty() || !changes.devices.deleted.is_empty(), - KEYS::DEVICES, + keys::DEVICES, ), ( !changes.identities.new.is_empty() || !changes.identities.changed.is_empty(), - KEYS::IDENTITIES, + keys::IDENTITIES, ), - (!changes.inbound_group_sessions.is_empty(), KEYS::INBOUND_GROUP_SESSIONS), - (!changes.outbound_group_sessions.is_empty(), KEYS::OUTBOUND_GROUP_SESSIONS), - (!changes.message_hashes.is_empty(), KEYS::OLM_HASHES), + + (!changes.inbound_group_sessions.is_empty(), keys::INBOUND_GROUP_SESSIONS), + (!changes.outbound_group_sessions.is_empty(), keys::OUTBOUND_GROUP_SESSIONS), + (!changes.message_hashes.is_empty(), keys::OLM_HASHES), + (!changes.room_settings.is_empty(), keys::ROOM_SETTINGS), ] .iter() .filter_map(|(id, key)| if *id { Some(*key) } else { None }) @@ -343,9 +382,9 @@ impl IndexeddbCryptoStore { if !changes.key_requests.is_empty() { stores.extend([ - KEYS::SECRET_REQUESTS_BY_INFO, - KEYS::UNSENT_SECRET_REQUESTS, - KEYS::OUTGOING_SECRET_REQUESTS, + keys::SECRET_REQUESTS_BY_INFO, + keys::UNSENT_SECRET_REQUESTS, + keys::OUTGOING_SECRET_REQUESTS, ]) } @@ -371,56 +410,56 @@ impl IndexeddbCryptoStore { }; let private_identity_pickle = - if let Some(i) = changes.private_identity { Some(i.pickle().await?) 
} else { None }; + if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None }; let recovery_key_pickle = changes.recovery_key; let backup_version = changes.backup_version; if let Some(a) = &account_pickle { - tx.object_store(KEYS::CORE)? - .put_key_val(&JsValue::from_str(KEYS::ACCOUNT), &self.serialize_value(&a)?)?; + tx.object_store(keys::CORE)? + .put_key_val(&JsValue::from_str(keys::ACCOUNT), &self.serialize_value(&a)?)?; } if let Some(i) = &private_identity_pickle { - tx.object_store(KEYS::CORE)?.put_key_val( - &JsValue::from_str(KEYS::PRIVATE_IDENTITY), + tx.object_store(keys::CORE)?.put_key_val( + &JsValue::from_str(keys::PRIVATE_IDENTITY), &self.serialize_value(i)?, )?; } if let Some(a) = &recovery_key_pickle { - tx.object_store(KEYS::BACKUP_KEYS)?.put_key_val( - &JsValue::from_str(KEYS::RECOVERY_KEY_V1), + tx.object_store(keys::BACKUP_KEYS)?.put_key_val( + &JsValue::from_str(keys::RECOVERY_KEY_V1), &self.serialize_value(&a)?, )?; } if let Some(a) = &backup_version { - tx.object_store(KEYS::BACKUP_KEYS)? - .put_key_val(&JsValue::from_str(KEYS::BACKUP_KEY_V1), &self.serialize_value(&a)?)?; + tx.object_store(keys::BACKUP_KEYS)? 
+ .put_key_val(&JsValue::from_str(keys::BACKUP_KEY_V1), &self.serialize_value(&a)?)?; } if !changes.sessions.is_empty() { - let sessions = tx.object_store(KEYS::SESSION)?; + let sessions = tx.object_store(keys::SESSION)?; for session in &changes.sessions { let sender_key = session.sender_key().to_base64(); let session_id = session.session_id(); let pickle = session.pickle().await; - let key = self.encode_key(KEYS::SESSION, (&sender_key, session_id)); + let key = self.encode_key(keys::SESSION, (&sender_key, session_id)); sessions.put_key_val(&key, &self.serialize_value(&pickle)?)?; } } if !changes.inbound_group_sessions.is_empty() { - let sessions = tx.object_store(KEYS::INBOUND_GROUP_SESSIONS)?; + let sessions = tx.object_store(keys::INBOUND_GROUP_SESSIONS)?; for session in changes.inbound_group_sessions { let room_id = session.room_id(); let session_id = session.session_id(); - let key = self.encode_key(KEYS::INBOUND_GROUP_SESSIONS, (room_id, session_id)); + let key = self.encode_key(keys::INBOUND_GROUP_SESSIONS, (room_id, session_id)); let pickle = session.pickle().await; sessions.put_key_val(&key, &self.serialize_value(&pickle)?)?; @@ -428,13 +467,13 @@ impl IndexeddbCryptoStore { } if !changes.outbound_group_sessions.is_empty() { - let sessions = tx.object_store(KEYS::OUTBOUND_GROUP_SESSIONS)?; + let sessions = tx.object_store(keys::OUTBOUND_GROUP_SESSIONS)?; for session in changes.outbound_group_sessions { let room_id = session.room_id(); let pickle = session.pickle().await; sessions.put_key_val( - &self.encode_key(KEYS::OUTBOUND_GROUP_SESSIONS, room_id), + &self.encode_key(keys::OUTBOUND_GROUP_SESSIONS, room_id), &self.serialize_value(&pickle)?, )?; } @@ -444,11 +483,12 @@ impl IndexeddbCryptoStore { let identity_changes = changes.identities; let olm_hashes = changes.message_hashes; let key_requests = changes.key_requests; + let room_settings_changes = changes.room_settings; if !device_changes.new.is_empty() || !device_changes.changed.is_empty() { - let 
device_store = tx.object_store(KEYS::DEVICES)?; + let device_store = tx.object_store(keys::DEVICES)?; for device in device_changes.new.iter().chain(&device_changes.changed) { - let key = self.encode_key(KEYS::DEVICES, (device.user_id(), device.device_id())); + let key = self.encode_key(keys::DEVICES, (device.user_id(), device.device_id())); let device = self.serialize_value(&device)?; device_store.put_key_val(&key, &device)?; @@ -456,43 +496,43 @@ impl IndexeddbCryptoStore { } if !device_changes.deleted.is_empty() { - let device_store = tx.object_store(KEYS::DEVICES)?; + let device_store = tx.object_store(keys::DEVICES)?; for device in &device_changes.deleted { - let key = self.encode_key(KEYS::DEVICES, (device.user_id(), device.device_id())); + let key = self.encode_key(keys::DEVICES, (device.user_id(), device.device_id())); device_store.delete(&key)?; } } if !identity_changes.changed.is_empty() || !identity_changes.new.is_empty() { - let identities = tx.object_store(KEYS::IDENTITIES)?; + let identities = tx.object_store(keys::IDENTITIES)?; for identity in identity_changes.changed.iter().chain(&identity_changes.new) { identities.put_key_val( - &self.encode_key(KEYS::IDENTITIES, identity.user_id()), + &self.encode_key(keys::IDENTITIES, identity.user_id()), &self.serialize_value(&identity)?, )?; } } if !olm_hashes.is_empty() { - let hashes = tx.object_store(KEYS::OLM_HASHES)?; + let hashes = tx.object_store(keys::OLM_HASHES)?; for hash in &olm_hashes { hashes.put_key_val( - &self.encode_key(KEYS::OLM_HASHES, (&hash.sender_key, &hash.hash)), + &self.encode_key(keys::OLM_HASHES, (&hash.sender_key, &hash.hash)), &JsValue::TRUE, )?; } } if !key_requests.is_empty() { - let secret_requests_by_info = tx.object_store(KEYS::SECRET_REQUESTS_BY_INFO)?; - let unsent_secret_requests = tx.object_store(KEYS::UNSENT_SECRET_REQUESTS)?; - let outgoing_secret_requests = tx.object_store(KEYS::OUTGOING_SECRET_REQUESTS)?; + let secret_requests_by_info = 
tx.object_store(keys::SECRET_REQUESTS_BY_INFO)?; + let unsent_secret_requests = tx.object_store(keys::UNSENT_SECRET_REQUESTS)?; + let outgoing_secret_requests = tx.object_store(keys::OUTGOING_SECRET_REQUESTS)?; for key_request in &key_requests { let key_request_id = - self.encode_key(KEYS::KEY_REQUEST, key_request.request_id.as_str()); + self.encode_key(keys::KEY_REQUEST, key_request.request_id.as_str()); secret_requests_by_info.put_key_val( - &self.encode_key(KEYS::KEY_REQUEST, key_request.info.as_key()), + &self.encode_key(keys::KEY_REQUEST, key_request.info.as_key()), &key_request_id, )?; @@ -508,6 +548,16 @@ impl IndexeddbCryptoStore { } } + if !room_settings_changes.is_empty() { + let settings_store = tx.object_store(keys::ROOM_SETTINGS)?; + + for (room_id, settings) in &room_settings_changes { + let key = self.encode_key(keys::ROOM_SETTINGS, room_id); + let value = self.serialize_value(&settings)?; + settings_store.put_key_val(&key, &value)?; + } + } + tx.await.into_result()?; // all good, let's update our caches:indexeddb @@ -521,8 +571,8 @@ impl IndexeddbCryptoStore { async fn load_tracked_users(&self) -> Result> { let tx = self .inner - .transaction_on_one_with_mode(KEYS::TRACKED_USERS, IdbTransactionMode::Readonly)?; - let os = tx.object_store(KEYS::TRACKED_USERS)?; + .transaction_on_one_with_mode(keys::TRACKED_USERS, IdbTransactionMode::Readonly)?; + let os = tx.object_store(keys::TRACKED_USERS)?; let user_ids = os.get_all_keys()?.await?; let mut users = Vec::new(); @@ -538,7 +588,7 @@ impl IndexeddbCryptoStore { Ok(users) } - async fn load_outbound_group_session( + async fn get_outbound_group_session( &self, room_id: &RoomId, ) -> Result> { @@ -546,11 +596,11 @@ impl IndexeddbCryptoStore { if let Some(value) = self .inner .transaction_on_one_with_mode( - KEYS::OUTBOUND_GROUP_SESSIONS, + keys::OUTBOUND_GROUP_SESSIONS, IdbTransactionMode::Readonly, )? - .object_store(KEYS::OUTBOUND_GROUP_SESSIONS)? 
- .get(&self.encode_key(KEYS::OUTBOUND_GROUP_SESSIONS, room_id))? + .object_store(keys::OUTBOUND_GROUP_SESSIONS)? + .get(&self.encode_key(keys::OUTBOUND_GROUP_SESSIONS, room_id))? .await? { Ok(Some( @@ -565,14 +615,18 @@ impl IndexeddbCryptoStore { Ok(None) } } - async fn get_outgoing_key_request_helper(&self, key: &str) -> Result> { + + async fn get_outgoing_secret_requests( + &self, + request_id: &TransactionId, + ) -> Result> { // in this internal we expect key to already be escaped or encrypted - let jskey = JsValue::from_str(key); - let dbs = [KEYS::OUTGOING_SECRET_REQUESTS, KEYS::UNSENT_SECRET_REQUESTS]; + let jskey = JsValue::from_str(request_id.as_str()); + let dbs = [keys::OUTGOING_SECRET_REQUESTS, keys::UNSENT_SECRET_REQUESTS]; let tx = self.inner.transaction_on_multi_with_mode(&dbs, IdbTransactionMode::Readonly)?; let request = tx - .object_store(KEYS::OUTGOING_SECRET_REQUESTS)? + .object_store(keys::OUTGOING_SECRET_REQUESTS)? .get(&jskey)? .await? .map(|i| self.deserialize_value(i)) @@ -580,7 +634,7 @@ impl IndexeddbCryptoStore { Ok(match request { None => tx - .object_store(KEYS::UNSENT_SECRET_REQUESTS)? + .object_store(keys::UNSENT_SECRET_REQUESTS)? .get(&jskey)? .await? .map(|i| self.deserialize_value(i)) @@ -592,9 +646,9 @@ impl IndexeddbCryptoStore { async fn load_account(&self) -> Result> { if let Some(pickle) = self .inner - .transaction_on_one_with_mode(KEYS::CORE, IdbTransactionMode::Readonly)? - .object_store(KEYS::CORE)? - .get(&JsValue::from_str(KEYS::ACCOUNT))? + .transaction_on_one_with_mode(keys::CORE, IdbTransactionMode::Readonly)? + .object_store(keys::CORE)? + .get(&JsValue::from_str(keys::ACCOUNT))? .await? 
{ let pickle = self.deserialize_value(pickle)?; @@ -615,12 +669,17 @@ impl IndexeddbCryptoStore { } } + async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { + self.save_changes(Changes { account: Some(account), ..Default::default() }) + .await + } + async fn load_identity(&self) -> Result> { if let Some(pickle) = self .inner - .transaction_on_one_with_mode(KEYS::CORE, IdbTransactionMode::Readonly)? - .object_store(KEYS::CORE)? - .get(&JsValue::from_str(KEYS::PRIVATE_IDENTITY))? + .transaction_on_one_with_mode(keys::CORE, IdbTransactionMode::Readonly)? + .object_store(keys::CORE)? + .get(&JsValue::from_str(keys::PRIVATE_IDENTITY))? .await? { let pickle = self.deserialize_value(pickle)?; @@ -639,11 +698,11 @@ impl IndexeddbCryptoStore { let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?; if self.session_cache.get(sender_key).is_none() { - let range = self.encode_to_range(KEYS::SESSION, sender_key)?; + let range = self.encode_to_range(keys::SESSION, sender_key)?; let sessions: Vec = self .inner - .transaction_on_one_with_mode(KEYS::SESSION, IdbTransactionMode::Readonly)? - .object_store(KEYS::SESSION)? + .transaction_on_one_with_mode(keys::SESSION, IdbTransactionMode::Readonly)? + .object_store(keys::SESSION)? .get_all_with_key(&range)? .await? .iter() @@ -669,14 +728,14 @@ impl IndexeddbCryptoStore { room_id: &RoomId, session_id: &str, ) -> Result> { - let key = self.encode_key(KEYS::INBOUND_GROUP_SESSIONS, (room_id, session_id)); + let key = self.encode_key(keys::INBOUND_GROUP_SESSIONS, (room_id, session_id)); if let Some(pickle) = self .inner .transaction_on_one_with_mode( - KEYS::INBOUND_GROUP_SESSIONS, + keys::INBOUND_GROUP_SESSIONS, IdbTransactionMode::Readonly, )? - .object_store(KEYS::INBOUND_GROUP_SESSIONS)? + .object_store(keys::INBOUND_GROUP_SESSIONS)? .get(&key)? .await? 
{ @@ -691,10 +750,10 @@ impl IndexeddbCryptoStore { Ok(self .inner .transaction_on_one_with_mode( - KEYS::INBOUND_GROUP_SESSIONS, + keys::INBOUND_GROUP_SESSIONS, IdbTransactionMode::Readonly, )? - .object_store(KEYS::INBOUND_GROUP_SESSIONS)? + .object_store(keys::INBOUND_GROUP_SESSIONS)? .get_all()? .await? .iter() @@ -747,8 +806,8 @@ impl IndexeddbCryptoStore { async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<()> { let tx = self .inner - .transaction_on_one_with_mode(KEYS::TRACKED_USERS, IdbTransactionMode::Readwrite)?; - let os = tx.object_store(KEYS::TRACKED_USERS)?; + .transaction_on_one_with_mode(keys::TRACKED_USERS, IdbTransactionMode::Readwrite)?; + let os = tx.object_store(keys::TRACKED_USERS)?; for (user, dirty) in users { os.put_key_val(&JsValue::from_str(user.as_str()), &JsValue::from(*dirty))?; @@ -763,11 +822,11 @@ impl IndexeddbCryptoStore { user_id: &UserId, device_id: &DeviceId, ) -> Result> { - let key = self.encode_key(KEYS::DEVICES, (user_id, device_id)); + let key = self.encode_key(keys::DEVICES, (user_id, device_id)); Ok(self .inner - .transaction_on_one_with_mode(KEYS::DEVICES, IdbTransactionMode::Readonly)? - .object_store(KEYS::DEVICES)? + .transaction_on_one_with_mode(keys::DEVICES, IdbTransactionMode::Readonly)? + .object_store(keys::DEVICES)? .get(&key)? .await? .map(|i| self.deserialize_value(i)) @@ -778,11 +837,11 @@ impl IndexeddbCryptoStore { &self, user_id: &UserId, ) -> Result> { - let range = self.encode_to_range(KEYS::DEVICES, user_id)?; + let range = self.encode_to_range(keys::DEVICES, user_id)?; Ok(self .inner - .transaction_on_one_with_mode(KEYS::DEVICES, IdbTransactionMode::Readonly)? - .object_store(KEYS::DEVICES)? + .transaction_on_one_with_mode(keys::DEVICES, IdbTransactionMode::Readonly)? + .object_store(keys::DEVICES)? .get_all_with_key(&range)? .await? 
.iter() @@ -796,9 +855,9 @@ impl IndexeddbCryptoStore { async fn get_user_identity(&self, user_id: &UserId) -> Result> { Ok(self .inner - .transaction_on_one_with_mode(KEYS::IDENTITIES, IdbTransactionMode::Readonly)? - .object_store(KEYS::IDENTITIES)? - .get(&self.encode_key(KEYS::IDENTITIES, user_id))? + .transaction_on_one_with_mode(keys::IDENTITIES, IdbTransactionMode::Readonly)? + .object_store(keys::IDENTITIES)? + .get(&self.encode_key(keys::IDENTITIES, user_id))? .await? .map(|i| self.deserialize_value(i)) .transpose()?) @@ -807,9 +866,9 @@ impl IndexeddbCryptoStore { async fn is_message_known(&self, hash: &OlmMessageHash) -> Result { Ok(self .inner - .transaction_on_one_with_mode(KEYS::OLM_HASHES, IdbTransactionMode::Readonly)? - .object_store(KEYS::OLM_HASHES)? - .get(&self.encode_key(KEYS::OLM_HASHES, (&hash.sender_key, &hash.hash)))? + .transaction_on_one_with_mode(keys::OLM_HASHES, IdbTransactionMode::Readonly)? + .object_store(keys::OLM_HASHES)? + .get(&self.encode_key(keys::OLM_HASHES, (&hash.sender_key, &hash.hash)))? .await? .is_some()) } @@ -821,15 +880,15 @@ impl IndexeddbCryptoStore { let id = self .inner .transaction_on_one_with_mode( - KEYS::SECRET_REQUESTS_BY_INFO, + keys::SECRET_REQUESTS_BY_INFO, IdbTransactionMode::Readonly, )? - .object_store(KEYS::SECRET_REQUESTS_BY_INFO)? - .get(&self.encode_key(KEYS::KEY_REQUEST, key_info.as_key()))? + .object_store(keys::SECRET_REQUESTS_BY_INFO)? + .get(&self.encode_key(keys::KEY_REQUEST, key_info.as_key()))? .await? .and_then(|i| i.as_string()); if let Some(id) = id { - self.get_outgoing_key_request_helper(&id).await + self.get_outgoing_secret_requests(id.as_str().into()).await } else { Ok(None) } @@ -839,10 +898,10 @@ impl IndexeddbCryptoStore { Ok(self .inner .transaction_on_one_with_mode( - KEYS::UNSENT_SECRET_REQUESTS, + keys::UNSENT_SECRET_REQUESTS, IdbTransactionMode::Readonly, )? - .object_store(KEYS::UNSENT_SECRET_REQUESTS)? + .object_store(keys::UNSENT_SECRET_REQUESTS)? .get_all()? .await? 
.iter() @@ -851,16 +910,16 @@ impl IndexeddbCryptoStore { } async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> { - let jskey = self.encode_key(KEYS::KEY_REQUEST, request_id); //.as_str()); + let jskey = self.encode_key(keys::KEY_REQUEST, request_id); //.as_str()); let dbs = [ - KEYS::OUTGOING_SECRET_REQUESTS, - KEYS::UNSENT_SECRET_REQUESTS, - KEYS::SECRET_REQUESTS_BY_INFO, + keys::OUTGOING_SECRET_REQUESTS, + keys::UNSENT_SECRET_REQUESTS, + keys::SECRET_REQUESTS_BY_INFO, ]; let tx = self.inner.transaction_on_multi_with_mode(&dbs, IdbTransactionMode::Readwrite)?; let request: Option = tx - .object_store(KEYS::OUTGOING_SECRET_REQUESTS)? + .object_store(keys::OUTGOING_SECRET_REQUESTS)? .get(&jskey)? .await? .map(|i| self.deserialize_value(i)) @@ -868,7 +927,7 @@ impl IndexeddbCryptoStore { let request = match request { None => tx - .object_store(KEYS::UNSENT_SECRET_REQUESTS)? + .object_store(keys::UNSENT_SECRET_REQUESTS)? .get(&jskey)? .await? .map(|i| self.deserialize_value(i)) @@ -877,12 +936,12 @@ impl IndexeddbCryptoStore { }; if let Some(inner) = request { - tx.object_store(KEYS::SECRET_REQUESTS_BY_INFO)? - .delete(&self.encode_key(KEYS::KEY_REQUEST, inner.info.as_key()))?; + tx.object_store(keys::SECRET_REQUESTS_BY_INFO)? 
+ .delete(&self.encode_key(keys::KEY_REQUEST, inner.info.as_key()))?; } - tx.object_store(KEYS::UNSENT_SECRET_REQUESTS)?.delete(&jskey)?; - tx.object_store(KEYS::OUTGOING_SECRET_REQUESTS)?.delete(&jskey)?; + tx.object_store(keys::UNSENT_SECRET_REQUESTS)?.delete(&jskey)?; + tx.object_store(keys::OUTGOING_SECRET_REQUESTS)?.delete(&jskey)?; tx.await.into_result().map_err(|e| e.into()) } @@ -891,17 +950,17 @@ impl IndexeddbCryptoStore { let key = { let tx = self .inner - .transaction_on_one_with_mode(KEYS::BACKUP_KEYS, IdbTransactionMode::Readonly)?; - let store = tx.object_store(KEYS::BACKUP_KEYS)?; + .transaction_on_one_with_mode(keys::BACKUP_KEYS, IdbTransactionMode::Readonly)?; + let store = tx.object_store(keys::BACKUP_KEYS)?; let backup_version = store - .get(&JsValue::from_str(KEYS::BACKUP_KEY_V1))? + .get(&JsValue::from_str(keys::BACKUP_KEY_V1))? .await? .map(|i| self.deserialize_value(i)) .transpose()?; let recovery_key = store - .get(&JsValue::from_str(KEYS::RECOVERY_KEY_V1))? + .get(&JsValue::from_str(keys::RECOVERY_KEY_V1))? .await? .map(|i| self.deserialize_value(i)) .transpose()?; @@ -911,141 +970,45 @@ impl IndexeddbCryptoStore { Ok(key) } -} - -impl Drop for IndexeddbCryptoStore { - fn drop(&mut self) { - // Must release the database access manually as it's not done when - // dropping it. 
- self.inner.close(); - } -} - -#[cfg(target_arch = "wasm32")] -#[async_trait(?Send)] -impl CryptoStore for IndexeddbCryptoStore { - async fn load_account(&self) -> Result, CryptoStoreError> { - self.load_account().await.map_err(|e| e.into()) - } - - async fn save_account(&self, account: ReadOnlyAccount) -> Result<(), CryptoStoreError> { - self.save_changes(Changes { account: Some(account), ..Default::default() }) - .await - .map_err(|e| e.into()) - } - - async fn load_identity(&self) -> Result, CryptoStoreError> { - self.load_identity().await.map_err(|e| e.into()) - } - - async fn save_changes(&self, changes: Changes) -> Result<(), CryptoStoreError> { - self.save_changes(changes).await.map_err(|e| e.into()) - } - - async fn get_sessions( - &self, - sender_key: &str, - ) -> Result>>>, CryptoStoreError> { - self.get_sessions(sender_key).await.map_err(|e| e.into()) - } - - async fn get_inbound_group_session( - &self, - room_id: &RoomId, - session_id: &str, - ) -> Result, CryptoStoreError> { - self.get_inbound_group_session(room_id, session_id).await.map_err(|e| e.into()) - } - - async fn get_inbound_group_sessions( - &self, - ) -> Result, CryptoStoreError> { - self.get_inbound_group_sessions().await.map_err(|e| e.into()) - } - - async fn get_outbound_group_session( - &self, - room_id: &RoomId, - ) -> Result, CryptoStoreError> { - self.load_outbound_group_session(room_id).await.map_err(|e| e.into()) - } - - async fn inbound_group_session_counts(&self) -> Result { - self.inbound_group_session_counts().await.map_err(|e| e.into()) - } - - async fn inbound_group_sessions_for_backup( - &self, - limit: usize, - ) -> Result, CryptoStoreError> { - self.inbound_group_sessions_for_backup(limit).await.map_err(|e| e.into()) - } - - async fn reset_backup_state(&self) -> Result<(), CryptoStoreError> { - self.reset_backup_state().await.map_err(|e| e.into()) - } - - async fn load_backup_keys(&self) -> Result { - self.load_backup_keys().await.map_err(|e| e.into()) - } - async fn 
save_tracked_users(&self, users: &[(&UserId, bool)]) -> Result<(), CryptoStoreError> { - self.save_tracked_users(users).await.map_err(Into::into) - } - - async fn load_tracked_users(&self) -> Result, CryptoStoreError> { - self.load_tracked_users().await.map_err(Into::into) - } - - async fn get_device( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result, CryptoStoreError> { - self.get_device(user_id, device_id).await.map_err(|e| e.into()) - } - - async fn get_user_devices( - &self, - user_id: &UserId, - ) -> Result, CryptoStoreError> { - self.get_user_devices(user_id).await.map_err(|e| e.into()) - } - - async fn get_user_identity( - &self, - user_id: &UserId, - ) -> Result, CryptoStoreError> { - self.get_user_identity(user_id).await.map_err(|e| e.into()) - } - - async fn is_message_known(&self, hash: &OlmMessageHash) -> Result { - self.is_message_known(hash).await.map_err(|e| e.into()) - } - - async fn get_outgoing_secret_requests( - &self, - request_id: &TransactionId, - ) -> Result, CryptoStoreError> { - self.get_outgoing_key_request_helper(request_id.as_str()).await.map_err(|e| e.into()) + async fn get_room_settings(&self, room_id: &RoomId) -> Result> { + let key = self.encode_key(keys::ROOM_SETTINGS, room_id); + Ok(self + .inner + .transaction_on_one_with_mode(keys::ROOM_SETTINGS, IdbTransactionMode::Readonly)? + .object_store(keys::ROOM_SETTINGS)? + .get(&key)? + .await? + .map(|v| self.deserialize_value(v)) + .transpose()?) } - async fn get_secret_request_by_info( - &self, - key_info: &SecretInfo, - ) -> Result, CryptoStoreError> { - self.get_secret_request_by_info(key_info).await.map_err(|e| e.into()) + async fn get_custom_value(&self, key: &str) -> Result>> { + Ok(self + .inner + .transaction_on_one_with_mode(keys::CORE, IdbTransactionMode::Readonly)? + .object_store(keys::CORE)? + .get(&JsValue::from_str(key))? + .await? + .map(|v| self.deserialize_value(v)) + .transpose()?) 
} - async fn get_unsent_secret_requests(&self) -> Result, CryptoStoreError> { - self.get_unsent_secret_requests().await.map_err(|e| e.into()) + async fn set_custom_value(&self, key: &str, value: Vec) -> Result<()> { + self + .inner + .transaction_on_one_with_mode(keys::CORE, IdbTransactionMode::Readwrite)? + .object_store(keys::CORE)? + .put_key_val(&JsValue::from_str(key), &self.serialize_value(&value)?)?; + Ok(()) } +} - async fn delete_outgoing_secret_requests( - &self, - request_id: &TransactionId, - ) -> Result<(), CryptoStoreError> { - self.delete_outgoing_secret_requests(request_id).await.map_err(|e| e.into()) +impl Drop for IndexeddbCryptoStore { + fn drop(&mut self) { + // Must release the database access manually as it's not done when + // dropping it. + self.inner.close(); } } diff --git a/crates/matrix-sdk-indexeddb/src/lib.rs b/crates/matrix-sdk-indexeddb/src/lib.rs index 2232f803705..e9a2245761f 100644 --- a/crates/matrix-sdk-indexeddb/src/lib.rs +++ b/crates/matrix-sdk-indexeddb/src/lib.rs @@ -22,11 +22,9 @@ async fn open_stores_with_name( name: &str, passphrase: Option<&str>, ) -> Result<(IndexeddbStateStore, IndexeddbCryptoStore), OpenStoreError> { - let mut builder = IndexeddbStateStore::builder(); - builder.name(name.to_owned()); - + let mut builder = IndexeddbStateStore::builder().name(name.to_owned()); if let Some(passphrase) = passphrase { - builder.passphrase(passphrase.to_owned()); + builder = builder.passphrase(passphrase.to_owned()); } let state_store = builder.build().await.map_err(StoreError::from)?; @@ -54,11 +52,10 @@ pub async fn make_store_config( #[cfg(not(feature = "e2e-encryption"))] { - let mut builder = IndexeddbStateStore::builder(); - builder.name(name.to_owned()); + let mut builder = IndexeddbStateStore::builder().name(name.to_owned()); if let Some(passphrase) = passphrase { - builder.passphrase(passphrase.to_owned()); + builder = builder.passphrase(passphrase.to_owned()); } let state_store = 
builder.build().await.map_err(StoreError::from)?; diff --git a/crates/matrix-sdk-indexeddb/src/safe_encode.rs b/crates/matrix-sdk-indexeddb/src/safe_encode.rs index 4838bda405a..d3dbbbc4cb6 100644 --- a/crates/matrix-sdk-indexeddb/src/safe_encode.rs +++ b/crates/matrix-sdk-indexeddb/src/safe_encode.rs @@ -1,7 +1,8 @@ #![allow(dead_code)] use base64::{ - alphabet, encode_engine as base64_encode, - engine::fast_portable::{self, FastPortable}, + alphabet, + engine::{general_purpose, GeneralPurpose}, + Engine, }; use matrix_sdk_store_encryption::StoreCipher; use ruma::{ @@ -25,8 +26,8 @@ pub const RANGE_END: &str = "\u{001E}"; /// (though super unlikely) pub const ESCAPED: &str = "\u{001E}\u{001D}"; -const STANDARD_NO_PAD: FastPortable = - FastPortable::from(&alphabet::STANDARD, fast_portable::NO_PAD); +const STANDARD_NO_PAD: GeneralPurpose = + GeneralPurpose::new(&alphabet::STANDARD, general_purpose::NO_PAD); /// Encode value as String/JsValue/IdbKeyRange for the JS APIs in a /// safe, escaped manner. @@ -58,10 +59,8 @@ pub trait SafeEncode { /// `store_cipher` hash_key, returns the value as a base64 encoded /// string without any padding. 
fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String { - base64_encode( - store_cipher.hash_key(table_name, self.as_encoded_string().as_bytes()), - &STANDARD_NO_PAD, - ) + STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.as_encoded_string().as_bytes())) } /// encode self into a JsValue, internally using `as_encoded_string` @@ -120,15 +119,11 @@ where fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String { [ - &base64_encode( - store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes()), - &STANDARD_NO_PAD, - ), + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes())), KEY_SEPARATOR, - &base64_encode( - store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes()), - &STANDARD_NO_PAD, - ), + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes())), ] .concat() } @@ -155,20 +150,14 @@ where fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String { [ - &base64_encode( - store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes()), - &STANDARD_NO_PAD, - ), + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes())), KEY_SEPARATOR, - &base64_encode( - store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes()), - &STANDARD_NO_PAD, - ), + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes())), KEY_SEPARATOR, - &base64_encode( - store_cipher.hash_key(table_name, self.2.as_encoded_string().as_bytes()), - &STANDARD_NO_PAD, - ), + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.2.as_encoded_string().as_bytes())), ] .concat() } @@ -198,25 +187,63 @@ where fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String { [ - &base64_encode( - store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes()), - 
&STANDARD_NO_PAD, - ), + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes())), + KEY_SEPARATOR, + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes())), + KEY_SEPARATOR, + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.2.as_encoded_string().as_bytes())), + KEY_SEPARATOR, + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.3.as_encoded_string().as_bytes())), + ] + .concat() + } +} + +/// Implement SafeEncode for tuple of five elements, separating the escaped +/// values with with `KEY_SEPARATOR`. +impl SafeEncode for (A, B, C, D, E) +where + A: SafeEncode, + B: SafeEncode, + C: SafeEncode, + D: SafeEncode, + E: SafeEncode, +{ + fn as_encoded_string(&self) -> String { + [ + &self.0.as_encoded_string(), + KEY_SEPARATOR, + &self.1.as_encoded_string(), + KEY_SEPARATOR, + &self.2.as_encoded_string(), + KEY_SEPARATOR, + &self.3.as_encoded_string(), + KEY_SEPARATOR, + &self.4.as_encoded_string(), + ] + .concat() + } + + fn as_secure_string(&self, table_name: &str, store_cipher: &StoreCipher) -> String { + [ + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.0.as_encoded_string().as_bytes())), + KEY_SEPARATOR, + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes())), KEY_SEPARATOR, - &base64_encode( - store_cipher.hash_key(table_name, self.1.as_encoded_string().as_bytes()), - &STANDARD_NO_PAD, - ), + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.2.as_encoded_string().as_bytes())), KEY_SEPARATOR, - &base64_encode( - store_cipher.hash_key(table_name, self.2.as_encoded_string().as_bytes()), - &STANDARD_NO_PAD, - ), + &STANDARD_NO_PAD + .encode(store_cipher.hash_key(table_name, self.3.as_encoded_string().as_bytes())), KEY_SEPARATOR, - &base64_encode( - store_cipher.hash_key(table_name, self.3.as_encoded_string().as_bytes()), - &STANDARD_NO_PAD, - ), + &STANDARD_NO_PAD 
+ .encode(store_cipher.hash_key(table_name, self.4.as_encoded_string().as_bytes())), ] .concat() } diff --git a/crates/matrix-sdk-indexeddb/src/state_store.rs b/crates/matrix-sdk-indexeddb/src/state_store.rs deleted file mode 100644 index 27cb501f4ec..00000000000 --- a/crates/matrix-sdk-indexeddb/src/state_store.rs +++ /dev/null @@ -1,1568 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{ - collections::{BTreeSet, HashSet}, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; - -use anyhow::anyhow; -use async_trait::async_trait; -use derive_builder::Builder; -use gloo_utils::format::JsValueSerdeExt; -use indexed_db_futures::prelude::*; -use js_sys::Date as JsDate; -use matrix_sdk_base::{ - deserialized_responses::RawMemberEvent, - media::{MediaRequest, UniqueKey}, - store::{Result as StoreResult, StateChanges, StateStore, StoreError}, - MinimalStateEvent, RoomInfo, -}; -use matrix_sdk_store_encryption::{Error as EncryptionError, StoreCipher}; -use ruma::{ - canonical_json::redact, - events::{ - presence::PresenceEvent, - receipt::{Receipt, ReceiptType}, - room::member::{MembershipState, RoomMemberEventContent}, - AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, - GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, - }, - serde::Raw, - CanonicalJsonObject, EventId, MxcUri, OwnedEventId, OwnedUserId, RoomId, RoomVersionId, UserId, -}; -use 
serde::{de::DeserializeOwned, Deserialize, Serialize}; -use tracing::{debug, warn}; -use wasm_bindgen::JsValue; -use web_sys::IdbKeyRange; - -use crate::safe_encode::SafeEncode; - -#[derive(Clone, Serialize, Deserialize)] -struct StoreKeyWrapper(Vec); - -#[derive(Debug, thiserror::Error)] -pub enum IndexeddbStateStoreError { - #[error(transparent)] - Json(#[from] serde_json::Error), - #[error(transparent)] - Encryption(#[from] EncryptionError), - #[error("DomException {name} ({code}): {message}")] - DomException { name: String, message: String, code: u16 }, - #[error(transparent)] - StoreError(#[from] StoreError), - #[error("Can't migrate {name} from {old_version} to {new_version} without deleting data. See MigrationConflictStrategy for ways to configure.")] - MigrationConflict { name: String, old_version: f64, new_version: f64 }, -} - -/// Sometimes Migrations can't proceed without having to drop existing -/// data. This allows you to configure, how these cases should be handled. -#[allow(dead_code)] -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum MigrationConflictStrategy { - /// Just drop the data, we don't care that we have to sync again - Drop, - /// Raise a [`IndexeddbStateStoreError::MigrationConflict`] error with the - /// path to the DB in question. The caller then has to take care about - /// what they want to do and try again after. - Raise, - /// Default. 
- BackupAndDrop, -} - -impl From for IndexeddbStateStoreError { - fn from(frm: indexed_db_futures::web_sys::DomException) -> IndexeddbStateStoreError { - IndexeddbStateStoreError::DomException { - name: frm.name(), - message: frm.message(), - code: frm.code(), - } - } -} - -impl From for StoreError { - fn from(e: IndexeddbStateStoreError) -> Self { - match e { - IndexeddbStateStoreError::Json(e) => StoreError::Json(e), - IndexeddbStateStoreError::StoreError(e) => e, - IndexeddbStateStoreError::Encryption(e) => StoreError::Encryption(e), - _ => StoreError::backend(e), - } - } -} - -#[allow(non_snake_case)] -mod KEYS { - // STORES - - pub const CURRENT_DB_VERSION: f64 = 1.1; - pub const CURRENT_META_DB_VERSION: f64 = 2.0; - - pub const INTERNAL_STATE: &str = "matrix-sdk-state"; - pub const BACKUPS_META: &str = "backups"; - - pub const SESSION: &str = "session"; - pub const ACCOUNT_DATA: &str = "account_data"; - - pub const MEMBERS: &str = "members"; - pub const PROFILES: &str = "profiles"; - pub const DISPLAY_NAMES: &str = "display_names"; - pub const JOINED_USER_IDS: &str = "joined_user_ids"; - pub const INVITED_USER_IDS: &str = "invited_user_ids"; - - pub const ROOM_STATE: &str = "room_state"; - pub const ROOM_INFOS: &str = "room_infos"; - pub const PRESENCE: &str = "presence"; - pub const ROOM_ACCOUNT_DATA: &str = "room_account_data"; - - pub const STRIPPED_ROOM_INFOS: &str = "stripped_room_infos"; - pub const STRIPPED_MEMBERS: &str = "stripped_members"; - pub const STRIPPED_ROOM_STATE: &str = "stripped_room_state"; - pub const STRIPPED_JOINED_USER_IDS: &str = "stripped_joined_user_ids"; - pub const STRIPPED_INVITED_USER_IDS: &str = "stripped_invited_user_ids"; - - pub const ROOM_USER_RECEIPTS: &str = "room_user_receipts"; - pub const ROOM_EVENT_RECEIPTS: &str = "room_event_receipts"; - - pub const MEDIA: &str = "media"; - - pub const CUSTOM: &str = "custom"; - - pub const SYNC_TOKEN: &str = "sync_token"; - - /// All names of the state stores for convenience. 
- pub const ALL_STORES: &[&str] = &[ - SESSION, - ACCOUNT_DATA, - MEMBERS, - PROFILES, - DISPLAY_NAMES, - JOINED_USER_IDS, - INVITED_USER_IDS, - ROOM_STATE, - ROOM_INFOS, - PRESENCE, - ROOM_ACCOUNT_DATA, - STRIPPED_ROOM_INFOS, - STRIPPED_MEMBERS, - STRIPPED_ROOM_STATE, - STRIPPED_JOINED_USER_IDS, - STRIPPED_INVITED_USER_IDS, - ROOM_USER_RECEIPTS, - ROOM_EVENT_RECEIPTS, - MEDIA, - CUSTOM, - SYNC_TOKEN, - ]; - - // static keys - - pub const STORE_KEY: &str = "store_key"; - pub const FILTER: &str = "filter"; -} - -pub use KEYS::ALL_STORES; - -fn drop_stores(db: &IdbDatabase) -> Result<(), JsValue> { - for name in ALL_STORES { - db.delete_object_store(name)?; - } - Ok(()) -} - -fn create_stores(db: &IdbDatabase) -> Result<(), JsValue> { - for name in ALL_STORES { - db.create_object_store(name)?; - } - Ok(()) -} - -async fn backup(source: &IdbDatabase, meta: &IdbDatabase) -> Result<()> { - let now = JsDate::now(); - let backup_name = format!("backup-{}-{now}", source.name()); - - let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&backup_name, source.version())?; - db_req.set_on_upgrade_needed(Some(move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { - // migrating to version 1 - let db = evt.db(); - for name in ALL_STORES { - db.create_object_store(name)?; - } - Ok(()) - })); - let target = db_req.into_future().await?; - - for name in ALL_STORES { - let tx = target.transaction_on_one_with_mode(name, IdbTransactionMode::Readwrite)?; - - let obj = tx.object_store(name)?; - - if let Some(curs) = source - .transaction_on_one_with_mode(name, IdbTransactionMode::Readonly)? - .object_store(name)? - .open_cursor()? - .await? 
- { - while let Some(key) = curs.key() { - obj.put_key_val(&key, &curs.value())?; - - curs.continue_cursor()?.await?; - } - } - - tx.await.into_result()?; - } - - let tx = - meta.transaction_on_one_with_mode(KEYS::BACKUPS_META, IdbTransactionMode::Readwrite)?; - let backup_store = tx.object_store(KEYS::BACKUPS_META)?; - backup_store.put_key_val(&JsValue::from_f64(now), &JsValue::from_str(&backup_name))?; - - tx.await; - - Ok(()) -} - -#[derive(Builder, Debug, PartialEq, Eq)] -#[builder(name = "IndexeddbStateStoreBuilder", build_fn(skip))] -pub struct IndexeddbStateStoreBuilderConfig { - /// The name for the indexeddb store to use, `state` is none given - name: String, - /// The password the indexeddb should be encrypted with. If not given, the - /// DB is not encrypted - passphrase: String, - /// The strategy to use when a merge conflict is found, see - /// [`MigrationConflictStrategy`] for details - #[builder(default = "MigrationConflictStrategy::BackupAndDrop")] - migration_conflict_strategy: MigrationConflictStrategy, -} - -impl IndexeddbStateStoreBuilder { - pub async fn build(&mut self) -> Result { - let migration_strategy = self - .migration_conflict_strategy - .clone() - .unwrap_or(MigrationConflictStrategy::BackupAndDrop); - let name = self.name.clone().unwrap_or_else(|| "state".to_owned()); - - let meta_name = format!("{name}::{}", KEYS::INTERNAL_STATE); - - let mut db_req: OpenDbRequest = - IdbDatabase::open_f64(&meta_name, KEYS::CURRENT_META_DB_VERSION)?; - db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { - let db = evt.db(); - if evt.old_version() < 1.0 { - // migrating to version 1 - - db.create_object_store(KEYS::INTERNAL_STATE)?; - db.create_object_store(KEYS::BACKUPS_META)?; - } else if evt.old_version() < 2.0 { - db.create_object_store(KEYS::BACKUPS_META)?; - } - Ok(()) - })); - - let meta_db: IdbDatabase = db_req.into_future().await?; - - let store_cipher = if let Some(passphrase) = &self.passphrase { - let 
tx: IdbTransaction<'_> = meta_db.transaction_on_one_with_mode( - KEYS::INTERNAL_STATE, - IdbTransactionMode::Readwrite, - )?; - let ob = tx.object_store(KEYS::INTERNAL_STATE)?; - - let cipher = if let Some(StoreKeyWrapper(inner)) = ob - .get(&JsValue::from_str(KEYS::STORE_KEY))? - .await? - .map(|v| v.into_serde()) - .transpose()? - { - StoreCipher::import(passphrase, &inner)? - } else { - let cipher = StoreCipher::new()?; - #[cfg(not(test))] - let export = cipher.export(passphrase)?; - #[cfg(test)] - let export = cipher._insecure_export_fast_for_testing(passphrase)?; - ob.put_key_val( - &JsValue::from_str(KEYS::STORE_KEY), - &JsValue::from_serde(&StoreKeyWrapper(export))?, - )?; - cipher - }; - - tx.await.into_result()?; - Some(Arc::new(cipher)) - } else { - None - }; - - let recreate_stores = { - // checkup up in a separate call, whether we have to backup or do anything else - // to the db. Unfortunately the set_on_upgrade_needed doesn't allow async fn - // which we need to execute the backup. - let has_store_cipher = store_cipher.is_some(); - let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, 1.0)?; - let created = Arc::new(AtomicBool::new(false)); - let created_inner = created.clone(); - - db_req.set_on_upgrade_needed(Some( - move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { - // in case this is a fresh db, we dont't want to trigger - // further migrations other than just creating the full - // schema. 
- if evt.old_version() < 1.0 { - create_stores(evt.db())?; - created_inner.store(true, Ordering::Relaxed); - } - Ok(()) - }, - )); - - let pre_db = db_req.into_future().await?; - let old_version = pre_db.version(); - - if created.load(Ordering::Relaxed) { - // this is a fresh DB, return - false - } else if old_version == 1.0 && has_store_cipher { - match migration_strategy { - MigrationConflictStrategy::BackupAndDrop => { - backup(&pre_db, &meta_db).await?; - true - } - MigrationConflictStrategy::Drop => true, - MigrationConflictStrategy::Raise => { - return Err(IndexeddbStateStoreError::MigrationConflict { - name, - old_version, - new_version: KEYS::CURRENT_DB_VERSION, - }) - } - } - } else { - // Nothing to be done - false - } - }; - - let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&name, KEYS::CURRENT_DB_VERSION)?; - db_req.set_on_upgrade_needed(Some( - move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { - // changing the format can only happen in the upgrade procedure - if recreate_stores { - drop_stores(evt.db())?; - create_stores(evt.db())?; - } - Ok(()) - }, - )); - - let db = db_req.into_future().await?; - Ok(IndexeddbStateStore { name, inner: db, meta: meta_db, store_cipher }) - } -} - -pub struct IndexeddbStateStore { - name: String, - pub(crate) inner: IdbDatabase, - pub(crate) meta: IdbDatabase, - pub(crate) store_cipher: Option>, -} - -impl std::fmt::Debug for IndexeddbStateStore { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("IndexeddbStateStore").field("name", &self.name).finish() - } -} - -type Result = std::result::Result; - -impl IndexeddbStateStore { - /// Generate a IndexeddbStateStoreBuilder with default parameters - pub fn builder() -> IndexeddbStateStoreBuilder { - IndexeddbStateStoreBuilder::default() - } - - /// Whether this database has any migration backups - pub async fn has_backups(&self) -> Result { - Ok(self - .meta - .transaction_on_one_with_mode(KEYS::BACKUPS_META, 
IdbTransactionMode::Readonly)? - .object_store(KEYS::BACKUPS_META)? - .count()? - .await? - > 0) - } - - /// What's the database name of the latest backup< - pub async fn latest_backup(&self) -> Result> { - Ok(self - .meta - .transaction_on_one_with_mode(KEYS::BACKUPS_META, IdbTransactionMode::Readonly)? - .object_store(KEYS::BACKUPS_META)? - .open_cursor_with_direction(indexed_db_futures::prelude::IdbCursorDirection::Prev)? - .await? - .and_then(|c| c.value().as_string())) - } - - fn serialize_event(&self, event: &impl Serialize) -> Result { - Ok(match &self.store_cipher { - Some(cipher) => JsValue::from_serde(&cipher.encrypt_value_typed(event)?)?, - None => JsValue::from_serde(event)?, - }) - } - - fn deserialize_event(&self, event: JsValue) -> Result { - match &self.store_cipher { - Some(cipher) => Ok(cipher.decrypt_value_typed(event.into_serde()?)?), - None => Ok(event.into_serde()?), - } - } - - fn encode_key(&self, table_name: &str, key: T) -> JsValue - where - T: SafeEncode, - { - match &self.store_cipher { - Some(cipher) => key.encode_secure(table_name, cipher), - None => key.encode(), - } - } - - fn encode_to_range(&self, table_name: &str, key: T) -> Result - where - T: SafeEncode, - { - match &self.store_cipher { - Some(cipher) => key.encode_to_range_secure(table_name, cipher), - None => key.encode_to_range(), - } - .map_err(|e| IndexeddbStateStoreError::StoreError(StoreError::Backend(anyhow!(e).into()))) - } - - pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { - let tx = self - .inner - .transaction_on_one_with_mode(KEYS::SESSION, IdbTransactionMode::Readwrite)?; - - let obj = tx.object_store(KEYS::SESSION)?; - - obj.put_key_val( - &self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)), - &self.serialize_event(&filter_id)?, - )?; - - tx.await.into_result()?; - - Ok(()) - } - - pub async fn get_filter(&self, filter_name: &str) -> Result> { - self.inner - .transaction_on_one_with_mode(KEYS::SESSION, 
IdbTransactionMode::Readonly)? - .object_store(KEYS::SESSION)? - .get(&self.encode_key(KEYS::FILTER, (KEYS::FILTER, filter_name)))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() - } - - pub async fn get_sync_token(&self) -> Result> { - self.inner - .transaction_on_one_with_mode(KEYS::SYNC_TOKEN, IdbTransactionMode::Readonly)? - .object_store(KEYS::SYNC_TOKEN)? - .get(&JsValue::from_str(KEYS::SYNC_TOKEN))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() - } - - pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> { - let mut stores: HashSet<&'static str> = [ - (changes.sync_token.is_some(), KEYS::SYNC_TOKEN), - (changes.session.is_some(), KEYS::SESSION), - (!changes.ambiguity_maps.is_empty(), KEYS::DISPLAY_NAMES), - (!changes.account_data.is_empty(), KEYS::ACCOUNT_DATA), - (!changes.presence.is_empty(), KEYS::PRESENCE), - (!changes.profiles.is_empty(), KEYS::PROFILES), - (!changes.room_account_data.is_empty(), KEYS::ROOM_ACCOUNT_DATA), - (!changes.receipts.is_empty(), KEYS::ROOM_EVENT_RECEIPTS), - (!changes.stripped_state.is_empty(), KEYS::STRIPPED_ROOM_STATE), - ] - .iter() - .filter_map(|(id, key)| if *id { Some(*key) } else { None }) - .collect(); - - if !changes.state.is_empty() { - stores.extend([KEYS::ROOM_STATE, KEYS::STRIPPED_ROOM_STATE]); - } - - if !changes.redactions.is_empty() { - stores.extend([KEYS::ROOM_STATE, KEYS::ROOM_INFOS]); - } - - if !changes.room_infos.is_empty() || !changes.stripped_room_infos.is_empty() { - stores.extend([KEYS::ROOM_INFOS, KEYS::STRIPPED_ROOM_INFOS]); - } - - if !changes.members.is_empty() { - stores.extend([ - KEYS::PROFILES, - KEYS::MEMBERS, - KEYS::INVITED_USER_IDS, - KEYS::JOINED_USER_IDS, - KEYS::STRIPPED_MEMBERS, - KEYS::STRIPPED_INVITED_USER_IDS, - KEYS::STRIPPED_JOINED_USER_IDS, - ]) - } - - if !changes.stripped_members.is_empty() { - stores.extend([ - KEYS::STRIPPED_MEMBERS, - KEYS::STRIPPED_INVITED_USER_IDS, - KEYS::STRIPPED_JOINED_USER_IDS, - ]) - } - - if 
!changes.receipts.is_empty() { - stores.extend([KEYS::ROOM_EVENT_RECEIPTS, KEYS::ROOM_USER_RECEIPTS]) - } - - if stores.is_empty() { - // nothing to do, quit early - return Ok(()); - } - - let stores: Vec<&'static str> = stores.into_iter().collect(); - let tx = - self.inner.transaction_on_multi_with_mode(&stores, IdbTransactionMode::Readwrite)?; - - if let Some(s) = &changes.sync_token { - tx.object_store(KEYS::SYNC_TOKEN)? - .put_key_val(&JsValue::from_str(KEYS::SYNC_TOKEN), &self.serialize_event(s)?)?; - } - - if !changes.ambiguity_maps.is_empty() { - let store = tx.object_store(KEYS::DISPLAY_NAMES)?; - for (room_id, ambiguity_maps) in &changes.ambiguity_maps { - for (display_name, map) in ambiguity_maps { - let key = self.encode_key(KEYS::DISPLAY_NAMES, (room_id, display_name)); - - store.put_key_val(&key, &self.serialize_event(&map)?)?; - } - } - } - - if !changes.account_data.is_empty() { - let store = tx.object_store(KEYS::ACCOUNT_DATA)?; - for (event_type, event) in &changes.account_data { - store.put_key_val( - &self.encode_key(KEYS::ACCOUNT_DATA, event_type), - &self.serialize_event(&event)?, - )?; - } - } - - if !changes.room_account_data.is_empty() { - let store = tx.object_store(KEYS::ROOM_ACCOUNT_DATA)?; - for (room, events) in &changes.room_account_data { - for (event_type, event) in events { - let key = self.encode_key(KEYS::ROOM_ACCOUNT_DATA, (room, event_type)); - store.put_key_val(&key, &self.serialize_event(&event)?)?; - } - } - } - - if !changes.state.is_empty() { - let state = tx.object_store(KEYS::ROOM_STATE)?; - let stripped_state = tx.object_store(KEYS::STRIPPED_ROOM_STATE)?; - for (room, event_types) in &changes.state { - for (event_type, events) in event_types { - for (state_key, event) in events { - let key = self.encode_key(KEYS::ROOM_STATE, (room, event_type, state_key)); - state.put_key_val(&key, &self.serialize_event(&event)?)?; - stripped_state.delete(&key)?; - } - } - } - } - - if !changes.room_infos.is_empty() { - let room_infos = 
tx.object_store(KEYS::ROOM_INFOS)?; - let stripped_room_infos = tx.object_store(KEYS::STRIPPED_ROOM_INFOS)?; - for (room_id, room_info) in &changes.room_infos { - room_infos.put_key_val( - &self.encode_key(KEYS::ROOM_INFOS, room_id), - &self.serialize_event(&room_info)?, - )?; - stripped_room_infos.delete(&self.encode_key(KEYS::STRIPPED_ROOM_INFOS, room_id))?; - } - } - - if !changes.presence.is_empty() { - let store = tx.object_store(KEYS::PRESENCE)?; - for (sender, event) in &changes.presence { - store.put_key_val( - &self.encode_key(KEYS::PRESENCE, sender), - &self.serialize_event(&event)?, - )?; - } - } - - if !changes.stripped_room_infos.is_empty() { - let stripped_room_infos = tx.object_store(KEYS::STRIPPED_ROOM_INFOS)?; - let room_infos = tx.object_store(KEYS::ROOM_INFOS)?; - for (room_id, info) in &changes.stripped_room_infos { - stripped_room_infos.put_key_val( - &self.encode_key(KEYS::STRIPPED_ROOM_INFOS, room_id), - &self.serialize_event(&info)?, - )?; - room_infos.delete(&self.encode_key(KEYS::ROOM_INFOS, room_id))?; - } - } - - if !changes.stripped_members.is_empty() { - let store = tx.object_store(KEYS::STRIPPED_MEMBERS)?; - let joined = tx.object_store(KEYS::STRIPPED_JOINED_USER_IDS)?; - let invited = tx.object_store(KEYS::STRIPPED_INVITED_USER_IDS)?; - for (room, raw_events) in &changes.stripped_members { - for raw_event in raw_events.values() { - let event = match raw_event.deserialize() { - Ok(ev) => ev, - Err(e) => { - let event_id: Option = - raw_event.get_field("event_id").ok().flatten(); - debug!(event_id, "Failed to deserialize stripped member event: {e}"); - continue; - } - }; - - let key = (room, &event.state_key); - - match event.content.membership { - MembershipState::Join => { - joined.put_key_val_owned( - &self.encode_key(KEYS::STRIPPED_JOINED_USER_IDS, key), - &self.serialize_event(&event.state_key)?, - )?; - invited - .delete(&self.encode_key(KEYS::STRIPPED_INVITED_USER_IDS, key))?; - } - MembershipState::Invite => { - 
invited.put_key_val_owned( - &self.encode_key(KEYS::STRIPPED_INVITED_USER_IDS, key), - &self.serialize_event(&event.state_key)?, - )?; - joined.delete(&self.encode_key(KEYS::STRIPPED_JOINED_USER_IDS, key))?; - } - _ => { - joined.delete(&self.encode_key(KEYS::STRIPPED_JOINED_USER_IDS, key))?; - invited - .delete(&self.encode_key(KEYS::STRIPPED_INVITED_USER_IDS, key))?; - } - } - store.put_key_val( - &self.encode_key(KEYS::STRIPPED_MEMBERS, key), - &self.serialize_event(&raw_event)?, - )?; - } - } - } - - if !changes.stripped_state.is_empty() { - let store = tx.object_store(KEYS::STRIPPED_ROOM_STATE)?; - for (room, event_types) in &changes.stripped_state { - for (event_type, events) in event_types { - for (state_key, event) in events { - let key = self - .encode_key(KEYS::STRIPPED_ROOM_STATE, (room, event_type, state_key)); - store.put_key_val(&key, &self.serialize_event(&event)?)?; - } - } - } - } - - if !changes.members.is_empty() { - let profiles = tx.object_store(KEYS::PROFILES)?; - let joined = tx.object_store(KEYS::JOINED_USER_IDS)?; - let invited = tx.object_store(KEYS::INVITED_USER_IDS)?; - let members = tx.object_store(KEYS::MEMBERS)?; - let stripped_members = tx.object_store(KEYS::STRIPPED_MEMBERS)?; - let stripped_joined = tx.object_store(KEYS::STRIPPED_JOINED_USER_IDS)?; - let stripped_invited = tx.object_store(KEYS::STRIPPED_INVITED_USER_IDS)?; - - for (room, raw_events) in &changes.members { - let profile_changes = changes.profiles.get(room); - - for raw_event in raw_events.values() { - let event = match raw_event.deserialize() { - Ok(ev) => ev, - Err(e) => { - let event_id: Option = - raw_event.get_field("event_id").ok().flatten(); - debug!(event_id, "Failed to deserialize member event: {e}"); - continue; - } - }; - - let key = (room, event.state_key()); - - stripped_joined - .delete(&self.encode_key(KEYS::STRIPPED_JOINED_USER_IDS, key))?; - stripped_invited - .delete(&self.encode_key(KEYS::STRIPPED_INVITED_USER_IDS, key))?; - - match 
event.membership() { - MembershipState::Join => { - joined.put_key_val_owned( - &self.encode_key(KEYS::JOINED_USER_IDS, key), - &self.serialize_event(event.state_key())?, - )?; - invited.delete(&self.encode_key(KEYS::INVITED_USER_IDS, key))?; - } - MembershipState::Invite => { - invited.put_key_val_owned( - &self.encode_key(KEYS::INVITED_USER_IDS, key), - &self.serialize_event(event.state_key())?, - )?; - joined.delete(&self.encode_key(KEYS::JOINED_USER_IDS, key))?; - } - _ => { - joined.delete(&self.encode_key(KEYS::JOINED_USER_IDS, key))?; - invited.delete(&self.encode_key(KEYS::INVITED_USER_IDS, key))?; - } - } - - members.put_key_val_owned( - &self.encode_key(KEYS::MEMBERS, key), - &self.serialize_event(&raw_event)?, - )?; - stripped_members.delete(&self.encode_key(KEYS::STRIPPED_MEMBERS, key))?; - - if let Some(profile) = profile_changes.and_then(|p| p.get(event.state_key())) { - profiles.put_key_val_owned( - &self.encode_key(KEYS::PROFILES, key), - &self.serialize_event(&profile)?, - )?; - } - } - } - } - - if !changes.receipts.is_empty() { - let room_user_receipts = tx.object_store(KEYS::ROOM_USER_RECEIPTS)?; - let room_event_receipts = tx.object_store(KEYS::ROOM_EVENT_RECEIPTS)?; - - for (room, content) in &changes.receipts { - for (event_id, receipts) in &content.0 { - for (receipt_type, receipts) in receipts { - for (user_id, receipt) in receipts { - let key = self.encode_key( - KEYS::ROOM_USER_RECEIPTS, - (room, receipt_type, user_id), - ); - - if let Some((old_event, _)) = - room_user_receipts.get(&key)?.await?.and_then(|f| { - self.deserialize_event::<(OwnedEventId, Receipt)>(f).ok() - }) - { - room_event_receipts.delete(&self.encode_key( - KEYS::ROOM_EVENT_RECEIPTS, - (room, receipt_type, &old_event, user_id), - ))?; - } - - room_user_receipts - .put_key_val(&key, &self.serialize_event(&(event_id, receipt))?)?; - - // Add the receipt to the room event receipts - room_event_receipts.put_key_val( - &self.encode_key( - KEYS::ROOM_EVENT_RECEIPTS, - (room, 
receipt_type, event_id, user_id), - ), - &self.serialize_event(&(user_id, receipt))?, - )?; - } - } - } - } - } - - if !changes.redactions.is_empty() { - let state = tx.object_store(KEYS::ROOM_STATE)?; - let room_info = tx.object_store(KEYS::ROOM_INFOS)?; - - for (room_id, redactions) in &changes.redactions { - let range = self.encode_to_range(KEYS::ROOM_STATE, room_id)?; - let Some(cursor) = state.open_cursor_with_range(&range)?.await? else { continue }; - - let mut room_version = None; - - while let Some(key) = cursor.key() { - let raw_evt = - self.deserialize_event::>(cursor.value())?; - if let Ok(Some(event_id)) = raw_evt.get_field::("event_id") { - if let Some(redaction) = redactions.get(&event_id) { - let version = { - if room_version.is_none() { - room_version.replace(room_info - .get(&self.encode_key(KEYS::ROOM_INFOS, room_id))? - .await? - .and_then(|f| self.deserialize_event::(f).ok()) - .and_then(|info| info.room_version().cloned()) - .unwrap_or_else(|| { - warn!(?room_id, "Unable to find the room version, assume version 9"); - RoomVersionId::V9 - }) - ); - } - room_version.as_ref().unwrap() - }; - - let redacted = redact( - raw_evt.deserialize_as::()?, - version, - Some(redaction.try_into()?), - ) - .map_err(StoreError::Redaction)?; - state.put_key_val(&key, &self.serialize_event(&redacted)?)?; - } - } - - // move forward. - cursor.advance(1)?.await?; - } - } - } - - tx.await.into_result().map_err(|e| e.into()) - } - - pub async fn get_presence_event(&self, user_id: &UserId) -> Result>> { - self.inner - .transaction_on_one_with_mode(KEYS::PRESENCE, IdbTransactionMode::Readonly)? - .object_store(KEYS::PRESENCE)? - .get(&self.encode_key(KEYS::PRESENCE, user_id))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() - } - - pub async fn get_state_event( - &self, - room_id: &RoomId, - event_type: StateEventType, - state_key: &str, - ) -> Result>> { - self.inner - .transaction_on_one_with_mode(KEYS::ROOM_STATE, IdbTransactionMode::Readonly)? 
- .object_store(KEYS::ROOM_STATE)? - .get(&self.encode_key(KEYS::ROOM_STATE, (room_id, event_type, state_key)))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() - } - - pub async fn get_state_events( - &self, - room_id: &RoomId, - event_type: StateEventType, - ) -> Result>> { - let range = self.encode_to_range(KEYS::ROOM_STATE, (room_id, event_type))?; - Ok(self - .inner - .transaction_on_one_with_mode(KEYS::ROOM_STATE, IdbTransactionMode::Readonly)? - .object_store(KEYS::ROOM_STATE)? - .get_all_with_key(&range)? - .await? - .iter() - .filter_map(|f| self.deserialize_event(f).ok()) - .collect::>()) - } - - pub async fn get_profile( - &self, - room_id: &RoomId, - user_id: &UserId, - ) -> Result>> { - self.inner - .transaction_on_one_with_mode(KEYS::PROFILES, IdbTransactionMode::Readonly)? - .object_store(KEYS::PROFILES)? - .get(&self.encode_key(KEYS::PROFILES, (room_id, user_id)))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() - } - - pub async fn get_member_event( - &self, - room_id: &RoomId, - state_key: &UserId, - ) -> Result> { - if let Some(e) = self - .inner - .transaction_on_one_with_mode(KEYS::STRIPPED_MEMBERS, IdbTransactionMode::Readonly)? - .object_store(KEYS::STRIPPED_MEMBERS)? - .get(&self.encode_key(KEYS::STRIPPED_MEMBERS, (room_id, state_key)))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose()? - { - Ok(Some(RawMemberEvent::Stripped(e))) - } else if let Some(e) = self - .inner - .transaction_on_one_with_mode(KEYS::MEMBERS, IdbTransactionMode::Readonly)? - .object_store(KEYS::MEMBERS)? - .get(&self.encode_key(KEYS::MEMBERS, (room_id, state_key)))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose()? - { - Ok(Some(RawMemberEvent::Sync(e))) - } else { - Ok(None) - } - } - - pub async fn get_user_ids_stream(&self, room_id: &RoomId) -> Result> { - Ok([self.get_invited_user_ids(room_id).await?, self.get_joined_user_ids(room_id).await?] 
- .concat()) - } - - pub async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result> { - let range = self.encode_to_range(KEYS::INVITED_USER_IDS, room_id)?; - let entries = self - .inner - .transaction_on_one_with_mode(KEYS::INVITED_USER_IDS, IdbTransactionMode::Readonly)? - .object_store(KEYS::INVITED_USER_IDS)? - .get_all_with_key(&range)? - .await? - .iter() - .filter_map(|f| self.deserialize_event::(f).ok()) - .collect::>(); - - Ok(entries) - } - - pub async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result> { - let range = self.encode_to_range(KEYS::JOINED_USER_IDS, room_id)?; - Ok(self - .inner - .transaction_on_one_with_mode(KEYS::JOINED_USER_IDS, IdbTransactionMode::Readonly)? - .object_store(KEYS::JOINED_USER_IDS)? - .get_all_with_key(&range)? - .await? - .iter() - .filter_map(|f| self.deserialize_event::(f).ok()) - .collect::>()) - } - - pub async fn get_stripped_user_ids_stream(&self, room_id: &RoomId) -> Result> { - Ok([ - self.get_stripped_invited_user_ids(room_id).await?, - self.get_stripped_joined_user_ids(room_id).await?, - ] - .concat()) - } - - pub async fn get_stripped_invited_user_ids( - &self, - room_id: &RoomId, - ) -> Result> { - let range = self.encode_to_range(KEYS::STRIPPED_INVITED_USER_IDS, room_id)?; - let entries = self - .inner - .transaction_on_one_with_mode( - KEYS::STRIPPED_INVITED_USER_IDS, - IdbTransactionMode::Readonly, - )? - .object_store(KEYS::STRIPPED_INVITED_USER_IDS)? - .get_all_with_key(&range)? - .await? - .iter() - .filter_map(|f| self.deserialize_event::(f).ok()) - .collect::>(); - - Ok(entries) - } - - pub async fn get_stripped_joined_user_ids(&self, room_id: &RoomId) -> Result> { - let range = self.encode_to_range(KEYS::STRIPPED_JOINED_USER_IDS, room_id)?; - Ok(self - .inner - .transaction_on_one_with_mode( - KEYS::STRIPPED_JOINED_USER_IDS, - IdbTransactionMode::Readonly, - )? - .object_store(KEYS::STRIPPED_JOINED_USER_IDS)? - .get_all_with_key(&range)? - .await? 
- .iter() - .filter_map(|f| self.deserialize_event::(f).ok()) - .collect::>()) - } - - pub async fn get_room_infos(&self) -> Result> { - let entries: Vec<_> = self - .inner - .transaction_on_one_with_mode(KEYS::ROOM_INFOS, IdbTransactionMode::Readonly)? - .object_store(KEYS::ROOM_INFOS)? - .get_all()? - .await? - .iter() - .filter_map(|f| self.deserialize_event::(f).ok()) - .collect(); - - Ok(entries) - } - - pub async fn get_stripped_room_infos(&self) -> Result> { - let entries = self - .inner - .transaction_on_one_with_mode(KEYS::STRIPPED_ROOM_INFOS, IdbTransactionMode::Readonly)? - .object_store(KEYS::STRIPPED_ROOM_INFOS)? - .get_all()? - .await? - .iter() - .filter_map(|f| self.deserialize_event(f).ok()) - .collect::>(); - - Ok(entries) - } - - pub async fn get_users_with_display_name( - &self, - room_id: &RoomId, - display_name: &str, - ) -> Result> { - self.inner - .transaction_on_one_with_mode(KEYS::DISPLAY_NAMES, IdbTransactionMode::Readonly)? - .object_store(KEYS::DISPLAY_NAMES)? - .get(&self.encode_key(KEYS::DISPLAY_NAMES, (room_id, display_name)))? - .await? - .map(|f| self.deserialize_event::>(f)) - .unwrap_or_else(|| Ok(Default::default())) - } - - pub async fn get_account_data_event( - &self, - event_type: GlobalAccountDataEventType, - ) -> Result>> { - self.inner - .transaction_on_one_with_mode(KEYS::ACCOUNT_DATA, IdbTransactionMode::Readonly)? - .object_store(KEYS::ACCOUNT_DATA)? - .get(&self.encode_key(KEYS::ACCOUNT_DATA, event_type))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() - } - - pub async fn get_room_account_data_event( - &self, - room_id: &RoomId, - event_type: RoomAccountDataEventType, - ) -> Result>> { - self.inner - .transaction_on_one_with_mode(KEYS::ROOM_ACCOUNT_DATA, IdbTransactionMode::Readonly)? - .object_store(KEYS::ROOM_ACCOUNT_DATA)? - .get(&self.encode_key(KEYS::ROOM_ACCOUNT_DATA, (room_id, event_type)))? - .await? 
- .map(|f| self.deserialize_event(f)) - .transpose() - } - - async fn get_user_room_receipt_event( - &self, - room_id: &RoomId, - receipt_type: ReceiptType, - user_id: &UserId, - ) -> Result> { - self.inner - .transaction_on_one_with_mode(KEYS::ROOM_USER_RECEIPTS, IdbTransactionMode::Readonly)? - .object_store(KEYS::ROOM_USER_RECEIPTS)? - .get(&self.encode_key(KEYS::ROOM_USER_RECEIPTS, (room_id, receipt_type, user_id)))? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() - } - - async fn get_event_room_receipt_events( - &self, - room_id: &RoomId, - receipt_type: ReceiptType, - event_id: &EventId, - ) -> Result> { - let range = - self.encode_to_range(KEYS::ROOM_EVENT_RECEIPTS, (room_id, &receipt_type, event_id))?; - let tx = self.inner.transaction_on_one_with_mode( - KEYS::ROOM_EVENT_RECEIPTS, - IdbTransactionMode::Readonly, - )?; - let store = tx.object_store(KEYS::ROOM_EVENT_RECEIPTS)?; - - Ok(store - .get_all_with_key(&range)? - .await? - .iter() - .filter_map(|f| self.deserialize_event(f).ok()) - .collect::>()) - } - - async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { - let key = self - .encode_key(KEYS::MEDIA, (request.source.unique_key(), request.format.unique_key())); - let tx = - self.inner.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readwrite)?; - - tx.object_store(KEYS::MEDIA)?.put_key_val(&key, &self.serialize_event(&data)?)?; - - tx.await.into_result().map_err(|e| e.into()) - } - - async fn get_media_content(&self, request: &MediaRequest) -> Result>> { - let key = self - .encode_key(KEYS::MEDIA, (request.source.unique_key(), request.format.unique_key())); - self.inner - .transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readonly)? - .object_store(KEYS::MEDIA)? - .get(&key)? - .await? 
- .map(|f| self.deserialize_event(f)) - .transpose() - } - - async fn get_custom_value(&self, key: &[u8]) -> Result>> { - let jskey = &JsValue::from_str(core::str::from_utf8(key).map_err(StoreError::Codec)?); - self.get_custom_value_for_js(jskey).await - } - - async fn get_custom_value_for_js(&self, jskey: &JsValue) -> Result>> { - self.inner - .transaction_on_one_with_mode(KEYS::CUSTOM, IdbTransactionMode::Readonly)? - .object_store(KEYS::CUSTOM)? - .get(jskey)? - .await? - .map(|f| self.deserialize_event(f)) - .transpose() - } - - async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>> { - let jskey = JsValue::from_str(core::str::from_utf8(key).map_err(StoreError::Codec)?); - - let prev = self.get_custom_value_for_js(&jskey).await?; - - let tx = - self.inner.transaction_on_one_with_mode(KEYS::CUSTOM, IdbTransactionMode::Readwrite)?; - - tx.object_store(KEYS::CUSTOM)?.put_key_val(&jskey, &self.serialize_event(&value)?)?; - - tx.await.into_result().map_err(IndexeddbStateStoreError::from)?; - Ok(prev) - } - - async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { - let key = self - .encode_key(KEYS::MEDIA, (request.source.unique_key(), request.format.unique_key())); - let tx = - self.inner.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readwrite)?; - - tx.object_store(KEYS::MEDIA)?.delete(&key)?; - - tx.await.into_result().map_err(|e| e.into()) - } - - async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { - let range = self.encode_to_range(KEYS::MEDIA, uri)?; - let tx = - self.inner.transaction_on_one_with_mode(KEYS::MEDIA, IdbTransactionMode::Readwrite)?; - let store = tx.object_store(KEYS::MEDIA)?; - - for k in store.get_all_keys_with_key(&range)?.await?.iter() { - store.delete(&k)?; - } - - tx.await.into_result().map_err(|e| e.into()) - } - - async fn remove_room(&self, room_id: &RoomId) -> Result<()> { - let direct_stores = [KEYS::ROOM_INFOS, KEYS::STRIPPED_ROOM_INFOS]; - - let 
prefixed_stores = [ - KEYS::MEMBERS, - KEYS::PROFILES, - KEYS::DISPLAY_NAMES, - KEYS::INVITED_USER_IDS, - KEYS::JOINED_USER_IDS, - KEYS::ROOM_STATE, - KEYS::ROOM_ACCOUNT_DATA, - KEYS::ROOM_EVENT_RECEIPTS, - KEYS::ROOM_USER_RECEIPTS, - KEYS::STRIPPED_ROOM_STATE, - KEYS::STRIPPED_MEMBERS, - ]; - - let all_stores = { - let mut v = Vec::new(); - v.extend(prefixed_stores); - v.extend(direct_stores); - v - }; - - let tx = self - .inner - .transaction_on_multi_with_mode(&all_stores, IdbTransactionMode::Readwrite)?; - - for store_name in direct_stores { - tx.object_store(store_name)?.delete(&self.encode_key(store_name, room_id))?; - } - - for store_name in prefixed_stores { - let store = tx.object_store(store_name)?; - let range = self.encode_to_range(store_name, room_id)?; - for key in store.get_all_keys_with_key(&range)?.await?.iter() { - store.delete(&key)?; - } - } - tx.await.into_result().map_err(|e| e.into()) - } -} - -#[cfg(target_arch = "wasm32")] -#[async_trait(?Send)] -impl StateStore for IndexeddbStateStore { - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> StoreResult<()> { - self.save_filter(filter_name, filter_id).await.map_err(|e| e.into()) - } - - async fn save_changes(&self, changes: &StateChanges) -> StoreResult<()> { - self.save_changes(changes).await.map_err(|e| e.into()) - } - - async fn get_filter(&self, filter_id: &str) -> StoreResult> { - self.get_filter(filter_id).await.map_err(|e| e.into()) - } - - async fn get_sync_token(&self) -> StoreResult> { - self.get_sync_token().await.map_err(|e| e.into()) - } - - async fn get_presence_event( - &self, - user_id: &UserId, - ) -> StoreResult>> { - self.get_presence_event(user_id).await.map_err(|e| e.into()) - } - - async fn get_state_event( - &self, - room_id: &RoomId, - event_type: StateEventType, - state_key: &str, - ) -> StoreResult>> { - self.get_state_event(room_id, event_type, state_key).await.map_err(|e| e.into()) - } - - async fn get_state_events( - &self, - room_id: &RoomId, - 
event_type: StateEventType, - ) -> StoreResult>> { - self.get_state_events(room_id, event_type).await.map_err(|e| e.into()) - } - - async fn get_profile( - &self, - room_id: &RoomId, - user_id: &UserId, - ) -> StoreResult>> { - self.get_profile(room_id, user_id).await.map_err(|e| e.into()) - } - - async fn get_member_event( - &self, - room_id: &RoomId, - state_key: &UserId, - ) -> StoreResult> { - self.get_member_event(room_id, state_key).await.map_err(|e| e.into()) - } - - async fn get_user_ids(&self, room_id: &RoomId) -> StoreResult> { - let ids: Vec = self.get_stripped_user_ids_stream(room_id).await?; - if !ids.is_empty() { - return Ok(ids); - } - self.get_user_ids_stream(room_id).await.map_err(|e| e.into()) - } - - async fn get_invited_user_ids(&self, room_id: &RoomId) -> StoreResult> { - let ids: Vec = self.get_stripped_invited_user_ids(room_id).await?; - if !ids.is_empty() { - return Ok(ids); - } - self.get_invited_user_ids(room_id).await.map_err(|e| e.into()) - } - - async fn get_joined_user_ids(&self, room_id: &RoomId) -> StoreResult> { - let ids: Vec = self.get_stripped_joined_user_ids(room_id).await?; - if !ids.is_empty() { - return Ok(ids); - } - self.get_joined_user_ids(room_id).await.map_err(|e| e.into()) - } - - async fn get_room_infos(&self) -> StoreResult> { - self.get_room_infos().await.map_err(|e| e.into()) - } - - async fn get_stripped_room_infos(&self) -> StoreResult> { - self.get_stripped_room_infos().await.map_err(|e| e.into()) - } - - async fn get_users_with_display_name( - &self, - room_id: &RoomId, - display_name: &str, - ) -> StoreResult> { - self.get_users_with_display_name(room_id, display_name).await.map_err(|e| e.into()) - } - - async fn get_account_data_event( - &self, - event_type: GlobalAccountDataEventType, - ) -> StoreResult>> { - self.get_account_data_event(event_type).await.map_err(|e| e.into()) - } - - async fn get_room_account_data_event( - &self, - room_id: &RoomId, - event_type: RoomAccountDataEventType, - ) -> StoreResult>> 
{ - self.get_room_account_data_event(room_id, event_type).await.map_err(|e| e.into()) - } - - async fn get_user_room_receipt_event( - &self, - room_id: &RoomId, - receipt_type: ReceiptType, - user_id: &UserId, - ) -> StoreResult> { - self.get_user_room_receipt_event(room_id, receipt_type, user_id).await.map_err(|e| e.into()) - } - - async fn get_event_room_receipt_events( - &self, - room_id: &RoomId, - receipt_type: ReceiptType, - event_id: &EventId, - ) -> StoreResult> { - self.get_event_room_receipt_events(room_id, receipt_type, event_id) - .await - .map_err(|e| e.into()) - } - - async fn get_custom_value(&self, key: &[u8]) -> StoreResult>> { - self.get_custom_value(key).await.map_err(|e| e.into()) - } - - async fn set_custom_value(&self, key: &[u8], value: Vec) -> StoreResult>> { - self.set_custom_value(key, value).await.map_err(|e| e.into()) - } - - async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> StoreResult<()> { - self.add_media_content(request, data).await.map_err(|e| e.into()) - } - - async fn get_media_content(&self, request: &MediaRequest) -> StoreResult>> { - self.get_media_content(request).await.map_err(|e| e.into()) - } - - async fn remove_media_content(&self, request: &MediaRequest) -> StoreResult<()> { - self.remove_media_content(request).await.map_err(|e| e.into()) - } - - async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> StoreResult<()> { - self.remove_media_content_for_uri(uri).await.map_err(|e| e.into()) - } - - async fn remove_room(&self, room_id: &RoomId) -> StoreResult<()> { - self.remove_room(room_id).await.map_err(|e| e.into()) - } -} - -#[cfg(all(test, target_arch = "wasm32"))] -mod tests { - #[cfg(target_arch = "wasm32")] - wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - - use matrix_sdk_base::statestore_integration_tests; - use uuid::Uuid; - - use super::{IndexeddbStateStore, Result}; - - async fn get_store() -> Result { - let db_name = format!("test-state-plain-{}", 
Uuid::new_v4().as_hyphenated()); - Ok(IndexeddbStateStore::builder().name(db_name).build().await?) - } - - statestore_integration_tests!(with_media_tests); -} - -#[cfg(all(test, target_arch = "wasm32"))] -mod encrypted_tests { - #[cfg(target_arch = "wasm32")] - wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - - use matrix_sdk_base::statestore_integration_tests; - use uuid::Uuid; - - use super::{IndexeddbStateStore, Result}; - - async fn get_store() -> Result { - let db_name = format!("test-state-encrypted-{}", Uuid::new_v4().as_hyphenated()); - let passphrase = format!("some_passphrase-{}", Uuid::new_v4().as_hyphenated()); - Ok(IndexeddbStateStore::builder().name(db_name).passphrase(passphrase).build().await?) - } - - statestore_integration_tests!(with_media_tests); -} - -#[cfg(all(test, target_arch = "wasm32"))] -mod migration_tests { - wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - - use indexed_db_futures::prelude::*; - use matrix_sdk_test::async_test; - use uuid::Uuid; - use wasm_bindgen::JsValue; - - use super::{ - IndexeddbStateStore, IndexeddbStateStoreError, MigrationConflictStrategy, Result, - ALL_STORES, - }; - - pub async fn create_fake_db(name: &str, version: f64) -> Result<()> { - let mut db_req: OpenDbRequest = IdbDatabase::open_f64(name, version)?; - db_req.set_on_upgrade_needed(Some( - move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { - // migrating to version 1 - let db = evt.db(); - for name in ALL_STORES { - db.create_object_store(name)?; - } - Ok(()) - }, - )); - db_req.into_future().await?; - Ok(()) - } - - #[async_test] - pub async fn test_no_upgrade() -> Result<()> { - let name = format!("simple-1.1-no-cipher-{}", Uuid::new_v4().as_hyphenated().to_string()); - - // this transparently migrates to the latest version - let store = IndexeddbStateStore::builder().name(name).build().await?; - // this didn't create any backup - assert_eq!(store.has_backups().await?, false); - // simple check that the 
layout exists. - assert_eq!(store.get_sync_token().await?, None); - Ok(()) - } - - #[async_test] - pub async fn test_migrating_v1_to_1_1_plain() -> Result<()> { - let name = - format!("migrating-1.1-no-cipher-{}", Uuid::new_v4().as_hyphenated().to_string()); - create_fake_db(&name, 1.0).await?; - - // this transparently migrates to the latest version - let store = IndexeddbStateStore::builder().name(name).build().await?; - // this didn't create any backup - assert_eq!(store.has_backups().await?, false); - assert_eq!(store.get_sync_token().await?, None); - Ok(()) - } - - #[async_test] - pub async fn test_migrating_v1_to_1_1_with_pw() -> Result<()> { - let name = - format!("migrating-1.1-with-cipher-{}", Uuid::new_v4().as_hyphenated().to_string()); - let passphrase = "somepassphrase".to_owned(); - create_fake_db(&name, 1.0).await?; - - // this transparently migrates to the latest version - let store = - IndexeddbStateStore::builder().name(name).passphrase(passphrase).build().await?; - // this creates a backup by default - assert_eq!(store.has_backups().await?, true); - assert!(store.latest_backup().await?.is_some(), "No backup_found"); - assert_eq!(store.get_sync_token().await?, None); - Ok(()) - } - - #[async_test] - pub async fn test_migrating_v1_to_1_1_with_pw_drops() -> Result<()> { - let name = format!( - "migrating-1.1-with-cipher-drops-{}", - Uuid::new_v4().as_hyphenated().to_string() - ); - let passphrase = "some-other-passphrase".to_owned(); - create_fake_db(&name, 1.0).await?; - - // this transparently migrates to the latest version - let store = IndexeddbStateStore::builder() - .name(name) - .passphrase(passphrase) - .migration_conflict_strategy(MigrationConflictStrategy::Drop) - .build() - .await?; - // this creates a backup by default - assert_eq!(store.has_backups().await?, false); - assert_eq!(store.get_sync_token().await?, None); - Ok(()) - } - - #[async_test] - pub async fn test_migrating_v1_to_1_1_with_pw_raise() -> Result<()> { - let name = 
format!( - "migrating-1.1-with-cipher-raises-{}", - Uuid::new_v4().as_hyphenated().to_string() - ); - let passphrase = "some-other-passphrase".to_owned(); - create_fake_db(&name, 1.0).await?; - - // this transparently migrates to the latest version - let store_res = IndexeddbStateStore::builder() - .name(name) - .passphrase(passphrase) - .migration_conflict_strategy(MigrationConflictStrategy::Raise) - .build() - .await; - - if let Err(IndexeddbStateStoreError::MigrationConflict { .. }) = store_res { - // all fine! - } else { - assert!(false, "Conflict didn't raise: {:?}", store_res) - } - Ok(()) - } -} diff --git a/crates/matrix-sdk-indexeddb/src/state_store/migrations.rs b/crates/matrix-sdk-indexeddb/src/state_store/migrations.rs new file mode 100644 index 00000000000..f9db21cbb46 --- /dev/null +++ b/crates/matrix-sdk-indexeddb/src/state_store/migrations.rs @@ -0,0 +1,748 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use gloo_utils::format::JsValueSerdeExt; +use indexed_db_futures::{prelude::*, request::OpenDbRequest, IdbDatabase, IdbVersionChangeEvent}; +use js_sys::Date as JsDate; +use matrix_sdk_base::StateStoreDataKey; +use matrix_sdk_store_encryption::StoreCipher; +use serde::{Deserialize, Serialize}; +use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue}; +use wasm_bindgen::JsValue; +use web_sys::IdbTransactionMode; + +use super::{ + deserialize_event, encode_key, encode_to_range, keys, serialize_event, Result, ALL_STORES, +}; +use crate::IndexeddbStateStoreError; + +const CURRENT_DB_VERSION: u32 = 4; +const CURRENT_META_DB_VERSION: u32 = 2; + +/// Sometimes Migrations can't proceed without having to drop existing +/// data. This allows you to configure, how these cases should be handled. +#[allow(dead_code)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum MigrationConflictStrategy { + /// Just drop the data, we don't care that we have to sync again + Drop, + /// Raise a [`IndexeddbStateStoreError::MigrationConflict`] error with the + /// path to the DB in question. The caller then has to take care about + /// what they want to do and try again after. + Raise, + /// Default. + BackupAndDrop, +} + +#[derive(Clone, Serialize, Deserialize)] +struct StoreKeyWrapper(Vec); + +mod old_keys { + pub const SESSION: &str = "session"; + pub const SYNC_TOKEN: &str = "sync_token"; +} + +pub async fn upgrade_meta_db( + meta_name: &str, + passphrase: Option<&str>, +) -> Result<(IdbDatabase, Option>)> { + // Meta database. 
+ let mut db_req: OpenDbRequest = IdbDatabase::open_u32(meta_name, CURRENT_META_DB_VERSION)?; + db_req.set_on_upgrade_needed(Some(|evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { + let db = evt.db(); + let old_version = evt.old_version() as u32; + + if old_version < 1 { + db.create_object_store(keys::INTERNAL_STATE)?; + } + + if old_version < 2 { + db.create_object_store(keys::BACKUPS_META)?; + } + + Ok(()) + })); + + let meta_db: IdbDatabase = db_req.into_future().await?; + + let store_cipher = if let Some(passphrase) = passphrase { + let tx: IdbTransaction<'_> = meta_db + .transaction_on_one_with_mode(keys::INTERNAL_STATE, IdbTransactionMode::Readwrite)?; + let ob = tx.object_store(keys::INTERNAL_STATE)?; + + let cipher = if let Some(StoreKeyWrapper(inner)) = ob + .get(&JsValue::from_str(keys::STORE_KEY))? + .await? + .map(|v| v.into_serde()) + .transpose()? + { + StoreCipher::import(passphrase, &inner)? + } else { + let cipher = StoreCipher::new()?; + #[cfg(not(test))] + let export = cipher.export(passphrase)?; + #[cfg(test)] + let export = cipher._insecure_export_fast_for_testing(passphrase)?; + ob.put_key_val( + &JsValue::from_str(keys::STORE_KEY), + &JsValue::from_serde(&StoreKeyWrapper(export))?, + )?; + cipher + }; + + tx.await.into_result()?; + Some(Arc::new(cipher)) + } else { + None + }; + + Ok((meta_db, store_cipher)) +} + +// Helper struct for upgrading the inner DB. +#[derive(Debug, Clone, Default)] +pub struct OngoingMigration { + // Names of stores to drop. + drop_stores: HashSet<&'static str>, + // Names of stores to create. + create_stores: HashSet<&'static str>, + // Store name => key-value data to add. + data: HashMap<&'static str, Vec<(JsValue, JsValue)>>, +} + +impl OngoingMigration { + /// Merge this migration with the given one. 
+ fn merge(&mut self, other: OngoingMigration) { + self.drop_stores.extend(other.drop_stores); + self.create_stores.extend(other.create_stores); + + for (store, data) in other.data { + let entry = self.data.entry(store).or_default(); + entry.extend(data); + } + } +} + +pub async fn upgrade_inner_db( + name: &str, + store_cipher: Option<&StoreCipher>, + migration_strategy: MigrationConflictStrategy, + meta_db: &IdbDatabase, +) -> Result { + let mut migration = OngoingMigration::default(); + { + // This is a hack, we need to open the database a first time to get the current + // version. + // The indexed_db_futures crate doesn't let us access the transaction so we + // can't migrate data inside the `onupgradeneeded` callback. Instead we see if + // we need to migrate some data before the upgrade, then let the store process + // the upgrade. + // See + let has_store_cipher = store_cipher.is_some(); + let pre_db = IdbDatabase::open(name)?.into_future().await?; + + // Even if the web-sys bindings expose the version as a f64, the IndexedDB API + // works with an unsigned integer. + // See + let mut old_version = pre_db.version() as u32; + + // Inside the `onupgradeneeded` callback we would know whether it's a new DB + // because the old version would be set to 0, here it is already set to 1 so we + // check if the stores exist. + if old_version == 1 && pre_db.object_store_names().next().is_none() { + old_version = 0; + } + + // Upgrades to v1 and v2 (re)create empty stores, while the other upgrades + // change data that is already in the stores, so we use exclusive branches here. 
+ if old_version == 0 { + migration.create_stores.extend(ALL_STORES); + } else if old_version < 2 && has_store_cipher { + match migration_strategy { + MigrationConflictStrategy::BackupAndDrop => { + backup_v1(&pre_db, meta_db).await?; + migration.drop_stores.extend(V1_STORES); + migration.create_stores.extend(ALL_STORES); + } + MigrationConflictStrategy::Drop => { + migration.drop_stores.extend(V1_STORES); + migration.create_stores.extend(ALL_STORES); + } + MigrationConflictStrategy::Raise => { + return Err(IndexeddbStateStoreError::MigrationConflict { + name: name.to_owned(), + old_version, + new_version: CURRENT_DB_VERSION, + }); + } + } + } else { + if old_version < 3 { + migrate_to_v3(&pre_db, store_cipher).await?; + } + if old_version < 4 { + migration.merge(migrate_to_v4(&pre_db, store_cipher).await?); + } + } + + pre_db.close(); + } + + let mut db_req: OpenDbRequest = IdbDatabase::open_u32(name, CURRENT_DB_VERSION)?; + db_req.set_on_upgrade_needed(Some(move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { + // Changing the format can only happen in the upgrade procedure + for store in &migration.drop_stores { + evt.db().delete_object_store(store)?; + } + for store in &migration.create_stores { + evt.db().create_object_store(store)?; + } + + Ok(()) + })); + + let db = db_req.into_future().await?; + + // Finally, we can add data to the newly created tables if needed. 
+ if !migration.data.is_empty() { + let stores: Vec<_> = migration.data.keys().copied().collect(); + let tx = db.transaction_on_multi_with_mode(&stores, IdbTransactionMode::Readwrite)?; + + for (name, data) in migration.data { + let store = tx.object_store(name)?; + for (key, value) in data { + store.put_key_val(&key, &value)?; + } + } + + tx.await.into_result()?; + } + + Ok(db) +} + +pub const V1_STORES: &[&str] = &[ + old_keys::SESSION, + keys::ACCOUNT_DATA, + keys::MEMBERS, + keys::PROFILES, + keys::DISPLAY_NAMES, + keys::JOINED_USER_IDS, + keys::INVITED_USER_IDS, + keys::ROOM_STATE, + keys::ROOM_INFOS, + keys::PRESENCE, + keys::ROOM_ACCOUNT_DATA, + keys::STRIPPED_ROOM_INFOS, + keys::STRIPPED_MEMBERS, + keys::STRIPPED_ROOM_STATE, + keys::STRIPPED_JOINED_USER_IDS, + keys::STRIPPED_INVITED_USER_IDS, + keys::ROOM_USER_RECEIPTS, + keys::ROOM_EVENT_RECEIPTS, + keys::MEDIA, + keys::CUSTOM, + old_keys::SYNC_TOKEN, +]; + +async fn backup_v1(source: &IdbDatabase, meta: &IdbDatabase) -> Result<()> { + let now = JsDate::now(); + let backup_name = format!("backup-{}-{now}", source.name()); + + let mut db_req: OpenDbRequest = IdbDatabase::open_f64(&backup_name, source.version())?; + db_req.set_on_upgrade_needed(Some(move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { + // migrating to version 1 + let db = evt.db(); + for name in V1_STORES { + db.create_object_store(name)?; + } + Ok(()) + })); + let target = db_req.into_future().await?; + + for name in V1_STORES { + let source_tx = source.transaction_on_one_with_mode(name, IdbTransactionMode::Readonly)?; + let source_obj = source_tx.object_store(name)?; + let Some(curs) = source_obj + .open_cursor()? + .await? 
else { + continue; + }; + + let data = curs.into_vec(0).await?; + + let target_tx = target.transaction_on_one_with_mode(name, IdbTransactionMode::Readwrite)?; + let target_obj = target_tx.object_store(name)?; + + for kv in data { + target_obj.put_key_val(kv.key(), kv.value())?; + } + + target_tx.await.into_result()?; + } + + let tx = + meta.transaction_on_one_with_mode(keys::BACKUPS_META, IdbTransactionMode::Readwrite)?; + let backup_store = tx.object_store(keys::BACKUPS_META)?; + backup_store.put_key_val(&JsValue::from_f64(now), &JsValue::from_str(&backup_name))?; + + tx.await; + + Ok(()) +} + +async fn v3_fix_store( + store: &IdbObjectStore<'_>, + store_cipher: Option<&StoreCipher>, +) -> Result<()> { + fn maybe_fix_json(raw_json: &RawJsonValue) -> Result> { + let json = raw_json.get(); + + if json.contains(r#""content":null"#) { + let mut value: JsonValue = serde_json::from_str(json)?; + if let Some(content) = value.get_mut("content") { + if matches!(content, JsonValue::Null) { + *content = JsonValue::Object(Default::default()); + return Ok(Some(value)); + } + } + } + + Ok(None) + } + + let cursor = store.open_cursor()?.await?; + + if let Some(cursor) = cursor { + loop { + let raw_json: Box = deserialize_event(store_cipher, cursor.value())?; + + if let Some(fixed_json) = maybe_fix_json(&raw_json)? { + cursor.update(&serialize_event(store_cipher, &fixed_json)?)?.await?; + } + + if !cursor.continue_cursor()?.await? { + break; + } + } + } + + Ok(()) +} + +/// Fix serialized redacted state events. 
+async fn migrate_to_v3(db: &IdbDatabase, store_cipher: Option<&StoreCipher>) -> Result<()> { + let tx = db.transaction_on_multi_with_mode( + &[keys::ROOM_STATE, keys::ROOM_INFOS], + IdbTransactionMode::Readwrite, + )?; + + v3_fix_store(&tx.object_store(keys::ROOM_STATE)?, store_cipher).await?; + v3_fix_store(&tx.object_store(keys::ROOM_INFOS)?, store_cipher).await?; + + tx.await.into_result().map_err(|e| e.into()) +} + +/// Move the content of the SYNC_TOKEN and SESSION stores to the new KV store. +async fn migrate_to_v4( + db: &IdbDatabase, + store_cipher: Option<&StoreCipher>, +) -> Result { + let tx = db.transaction_on_multi_with_mode( + &[old_keys::SYNC_TOKEN, old_keys::SESSION], + IdbTransactionMode::Readonly, + )?; + let mut values = Vec::new(); + + // Sync token + let sync_token_store = tx.object_store(old_keys::SYNC_TOKEN)?; + let sync_token = sync_token_store.get(&JsValue::from_str(old_keys::SYNC_TOKEN))?.await?; + + if let Some(sync_token) = sync_token { + values.push(( + encode_key(store_cipher, StateStoreDataKey::SYNC_TOKEN, StateStoreDataKey::SYNC_TOKEN), + sync_token, + )); + } + + // Filters + let session_store = tx.object_store(old_keys::SESSION)?; + let range = + encode_to_range(store_cipher, StateStoreDataKey::FILTER, StateStoreDataKey::FILTER)?; + if let Some(cursor) = session_store.open_cursor_with_range(&range)?.await? 
{ + while let Some(key) = cursor.key() { + let value = cursor.value(); + values.push((key, value)); + cursor.continue_cursor()?.await?; + } + } + + tx.await.into_result()?; + + let mut data = HashMap::new(); + if !values.is_empty() { + data.insert(keys::KV, values); + } + + Ok(OngoingMigration { + drop_stores: [old_keys::SYNC_TOKEN, old_keys::SESSION].into_iter().collect(), + create_stores: [keys::KV].into_iter().collect(), + data, + }) +} + +#[cfg(all(test, target_arch = "wasm32"))] +mod tests { + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + + use assert_matches::assert_matches; + use indexed_db_futures::prelude::*; + use matrix_sdk_base::{StateStore, StateStoreDataKey, StoreError}; + use matrix_sdk_test::async_test; + use ruma::{ + events::{AnySyncStateEvent, StateEventType}, + room_id, + }; + use serde_json::json; + use uuid::Uuid; + use wasm_bindgen::JsValue; + + use super::{ + old_keys, MigrationConflictStrategy, CURRENT_DB_VERSION, CURRENT_META_DB_VERSION, V1_STORES, + }; + use crate::{ + safe_encode::SafeEncode, + state_store::{encode_key, keys, serialize_event, Result, ALL_STORES}, + IndexeddbStateStore, IndexeddbStateStoreError, + }; + + const CUSTOM_DATA_KEY: &[u8] = b"custom_data_key"; + const CUSTOM_DATA: &[u8] = b"some_custom_data"; + + pub async fn create_fake_db(name: &str, version: u32) -> Result { + let mut db_req: OpenDbRequest = IdbDatabase::open_u32(name, version)?; + db_req.set_on_upgrade_needed(Some( + move |evt: &IdbVersionChangeEvent| -> Result<(), JsValue> { + let db = evt.db(); + + // Initialize stores. 
+ if version < 4 { + for name in V1_STORES { + db.create_object_store(name)?; + } + } else { + for name in ALL_STORES { + db.create_object_store(name)?; + } + } + + Ok(()) + }, + )); + db_req.into_future().await.map_err(Into::into) + } + + #[async_test] + pub async fn test_new_store() -> Result<()> { + let name = format!("new-store-no-cipher-{}", Uuid::new_v4().as_hyphenated().to_string()); + + // this transparently migrates to the latest version + let store = IndexeddbStateStore::builder().name(name).build().await?; + // this didn't create any backup + assert_eq!(store.has_backups().await?, false); + // simple check that the layout exists. + assert_eq!(store.get_custom_value(CUSTOM_DATA_KEY).await?, None); + + // Check versions. + assert_eq!(store.version(), CURRENT_DB_VERSION); + assert_eq!(store.meta_version(), CURRENT_META_DB_VERSION); + + Ok(()) + } + + #[async_test] + pub async fn test_migrating_v1_to_v2_plain() -> Result<()> { + let name = format!("migrating-v2-no-cipher-{}", Uuid::new_v4().as_hyphenated().to_string()); + + // Create and populate db. + { + let db = create_fake_db(&name, 1).await?; + let tx = + db.transaction_on_one_with_mode(keys::CUSTOM, IdbTransactionMode::Readwrite)?; + let custom = tx.object_store(keys::CUSTOM)?; + let jskey = JsValue::from_str( + core::str::from_utf8(CUSTOM_DATA_KEY).map_err(StoreError::Codec)?, + ); + custom.put_key_val(&jskey, &serialize_event(None, &CUSTOM_DATA)?)?; + tx.await.into_result()?; + db.close(); + } + + // this transparently migrates to the latest version + let store = IndexeddbStateStore::builder().name(name).build().await?; + // this didn't create any backup + assert_eq!(store.has_backups().await?, false); + // Custom data is still there. + let stored_data = assert_matches!( + store.get_custom_value(CUSTOM_DATA_KEY).await?, + Some(d) => d + ); + assert_eq!(stored_data, CUSTOM_DATA); + + // Check versions. 
+ assert_eq!(store.version(), CURRENT_DB_VERSION); + assert_eq!(store.meta_version(), CURRENT_META_DB_VERSION); + + Ok(()) + } + + #[async_test] + pub async fn test_migrating_v1_to_v2_with_pw() -> Result<()> { + let name = + format!("migrating-v2-with-cipher-{}", Uuid::new_v4().as_hyphenated().to_string()); + let passphrase = "somepassphrase".to_owned(); + + // Create and populate db. + { + let db = create_fake_db(&name, 1).await?; + let tx = + db.transaction_on_one_with_mode(keys::CUSTOM, IdbTransactionMode::Readwrite)?; + let custom = tx.object_store(keys::CUSTOM)?; + let jskey = JsValue::from_str( + core::str::from_utf8(CUSTOM_DATA_KEY).map_err(StoreError::Codec)?, + ); + custom.put_key_val(&jskey, &serialize_event(None, &CUSTOM_DATA)?)?; + tx.await.into_result()?; + db.close(); + } + + // this transparently migrates to the latest version + let store = + IndexeddbStateStore::builder().name(name).passphrase(passphrase).build().await?; + // this creates a backup by default + assert_eq!(store.has_backups().await?, true); + assert!(store.latest_backup().await?.is_some(), "No backup_found"); + // the data is gone + assert_eq!(store.get_custom_value(CUSTOM_DATA_KEY).await?, None); + + // Check versions. + assert_eq!(store.version(), CURRENT_DB_VERSION); + assert_eq!(store.meta_version(), CURRENT_META_DB_VERSION); + + Ok(()) + } + + #[async_test] + pub async fn test_migrating_v1_to_v2_with_pw_drops() -> Result<()> { + let name = format!( + "migrating-v2-with-cipher-drops-{}", + Uuid::new_v4().as_hyphenated().to_string() + ); + let passphrase = "some-other-passphrase".to_owned(); + + // Create and populate db. 
+ { + let db = create_fake_db(&name, 1).await?; + let tx = + db.transaction_on_one_with_mode(keys::CUSTOM, IdbTransactionMode::Readwrite)?; + let custom = tx.object_store(keys::CUSTOM)?; + let jskey = JsValue::from_str( + core::str::from_utf8(CUSTOM_DATA_KEY).map_err(StoreError::Codec)?, + ); + custom.put_key_val(&jskey, &serialize_event(None, &CUSTOM_DATA)?)?; + tx.await.into_result()?; + db.close(); + } + + // this transparently migrates to the latest version + let store = IndexeddbStateStore::builder() + .name(name) + .passphrase(passphrase) + .migration_conflict_strategy(MigrationConflictStrategy::Drop) + .build() + .await?; + // this doesn't create a backup + assert_eq!(store.has_backups().await?, false); + // the data is gone + assert_eq!(store.get_custom_value(CUSTOM_DATA_KEY).await?, None); + + // Check versions. + assert_eq!(store.version(), CURRENT_DB_VERSION); + assert_eq!(store.meta_version(), CURRENT_META_DB_VERSION); + + Ok(()) + } + + #[async_test] + pub async fn test_migrating_v1_to_v2_with_pw_raise() -> Result<()> { + let name = format!( + "migrating-v2-with-cipher-raises-{}", + Uuid::new_v4().as_hyphenated().to_string() + ); + let passphrase = "some-other-passphrase".to_owned(); + + // Create and populate db. + { + let db = create_fake_db(&name, 1).await?; + let tx = + db.transaction_on_one_with_mode(keys::CUSTOM, IdbTransactionMode::Readwrite)?; + let custom = tx.object_store(keys::CUSTOM)?; + let jskey = JsValue::from_str( + core::str::from_utf8(CUSTOM_DATA_KEY).map_err(StoreError::Codec)?, + ); + custom.put_key_val(&jskey, &serialize_event(None, &CUSTOM_DATA)?)?; + tx.await.into_result()?; + db.close(); + } + + // this transparently migrates to the latest version + let store_res = IndexeddbStateStore::builder() + .name(name) + .passphrase(passphrase) + .migration_conflict_strategy(MigrationConflictStrategy::Raise) + .build() + .await; + + assert_matches!(store_res, Err(IndexeddbStateStoreError::MigrationConflict { .. 
})); + + Ok(()) + } + + #[async_test] + pub async fn test_migrating_to_v3() -> Result<()> { + let name = format!("migrating-v3-{}", Uuid::new_v4().as_hyphenated().to_string()); + + // An event that fails to deserialize. + let wrong_redacted_state_event = json!({ + "content": null, + "event_id": "$wrongevent", + "origin_server_ts": 1673887516047_u64, + "sender": "@example:localhost", + "state_key": "", + "type": "m.room.topic", + "unsigned": { + "redacted_because": { + "type": "m.room.redaction", + "sender": "@example:localhost", + "content": {}, + "redacts": "$wrongevent", + "origin_server_ts": 1673893816047_u64, + "unsigned": {}, + "event_id": "$redactionevent", + }, + }, + }); + serde_json::from_value::(wrong_redacted_state_event.clone()) + .unwrap_err(); + + let room_id = room_id!("!some_room:localhost"); + + // Populate DB with wrong event. + { + let db = create_fake_db(&name, 2).await?; + let tx = + db.transaction_on_one_with_mode(keys::ROOM_STATE, IdbTransactionMode::Readwrite)?; + let state = tx.object_store(keys::ROOM_STATE)?; + let key = (room_id, StateEventType::RoomTopic, "").encode(); + state.put_key_val(&key, &serialize_event(None, &wrong_redacted_state_event)?)?; + tx.await.into_result()?; + db.close(); + } + + // this transparently migrates to the latest version + let store = IndexeddbStateStore::builder().name(name).build().await?; + let event = + store.get_state_event(room_id, StateEventType::RoomTopic, "").await.unwrap().unwrap(); + event.deserialize().unwrap(); + + // Check versions. 
+ assert_eq!(store.version(), CURRENT_DB_VERSION); + assert_eq!(store.meta_version(), CURRENT_META_DB_VERSION); + + Ok(()) + } + + #[async_test] + pub async fn test_migrating_to_v4() -> Result<()> { + let name = format!("migrating-v4-{}", Uuid::new_v4().as_hyphenated().to_string()); + + let sync_token = "a_very_unique_string"; + let filter_1 = "filter_1"; + let filter_1_id = "filter_1_id"; + let filter_2 = "filter_2"; + let filter_2_id = "filter_2_id"; + + // Populate DB with old table. + { + let db = create_fake_db(&name, 3).await?; + let tx = db.transaction_on_multi_with_mode( + &[old_keys::SYNC_TOKEN, old_keys::SESSION], + IdbTransactionMode::Readwrite, + )?; + + let sync_token_store = tx.object_store(old_keys::SYNC_TOKEN)?; + sync_token_store.put_key_val( + &JsValue::from_str(old_keys::SYNC_TOKEN), + &serialize_event(None, &sync_token)?, + )?; + + let session_store = tx.object_store(old_keys::SESSION)?; + session_store.put_key_val( + &encode_key(None, StateStoreDataKey::FILTER, (StateStoreDataKey::FILTER, filter_1)), + &serialize_event(None, &filter_1_id)?, + )?; + session_store.put_key_val( + &encode_key(None, StateStoreDataKey::FILTER, (StateStoreDataKey::FILTER, filter_2)), + &serialize_event(None, &filter_2_id)?, + )?; + + tx.await.into_result()?; + db.close(); + } + + // this transparently migrates to the latest version + let store = IndexeddbStateStore::builder().name(name).build().await?; + + let stored_sync_token = store + .get_kv_data(StateStoreDataKey::SyncToken) + .await? + .unwrap() + .into_sync_token() + .unwrap(); + assert_eq!(stored_sync_token, sync_token); + + let stored_filter_1_id = store + .get_kv_data(StateStoreDataKey::Filter(filter_1)) + .await? + .unwrap() + .into_filter() + .unwrap(); + assert_eq!(stored_filter_1_id, filter_1_id); + + let stored_filter_2_id = store + .get_kv_data(StateStoreDataKey::Filter(filter_2)) + .await? 
+ .unwrap() + .into_filter() + .unwrap(); + assert_eq!(stored_filter_2_id, filter_2_id); + + Ok(()) + } +} diff --git a/crates/matrix-sdk-indexeddb/src/state_store/mod.rs b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs new file mode 100644 index 00000000000..fd8fbe01d91 --- /dev/null +++ b/crates/matrix-sdk-indexeddb/src/state_store/mod.rs @@ -0,0 +1,1302 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + collections::{BTreeSet, HashSet}, + sync::Arc, +}; + +use anyhow::anyhow; +use async_trait::async_trait; +use gloo_utils::format::JsValueSerdeExt; +use indexed_db_futures::prelude::*; +use matrix_sdk_base::{ + deserialized_responses::RawMemberEvent, + media::{MediaRequest, UniqueKey}, + store::{StateChanges, StateStore, StoreError}, + MinimalStateEvent, RoomInfo, StateStoreDataKey, StateStoreDataValue, +}; +use matrix_sdk_store_encryption::{Error as EncryptionError, StoreCipher}; +use ruma::{ + canonical_json::redact, + events::{ + presence::PresenceEvent, + receipt::{Receipt, ReceiptThread, ReceiptType}, + room::member::{MembershipState, RoomMemberEventContent}, + AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, + GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, + }, + serde::Raw, + CanonicalJsonObject, EventId, MxcUri, OwnedEventId, OwnedUserId, RoomId, RoomVersionId, UserId, +}; +use serde::{de::DeserializeOwned, Serialize}; +use 
tracing::{debug, warn}; +use wasm_bindgen::JsValue; +use web_sys::IdbKeyRange; + +mod migrations; + +pub use self::migrations::MigrationConflictStrategy; +use self::migrations::{upgrade_inner_db, upgrade_meta_db}; +use crate::safe_encode::SafeEncode; + +#[derive(Debug, thiserror::Error)] +pub enum IndexeddbStateStoreError { + #[error(transparent)] + Json(#[from] serde_json::Error), + #[error(transparent)] + Encryption(#[from] EncryptionError), + #[error("DomException {name} ({code}): {message}")] + DomException { name: String, message: String, code: u16 }, + #[error(transparent)] + StoreError(#[from] StoreError), + #[error("Can't migrate {name} from {old_version} to {new_version} without deleting data. See MigrationConflictStrategy for ways to configure.")] + MigrationConflict { name: String, old_version: u32, new_version: u32 }, +} + +impl From for IndexeddbStateStoreError { + fn from(frm: indexed_db_futures::web_sys::DomException) -> IndexeddbStateStoreError { + IndexeddbStateStoreError::DomException { + name: frm.name(), + message: frm.message(), + code: frm.code(), + } + } +} + +impl From for StoreError { + fn from(e: IndexeddbStateStoreError) -> Self { + match e { + IndexeddbStateStoreError::Json(e) => StoreError::Json(e), + IndexeddbStateStoreError::StoreError(e) => e, + IndexeddbStateStoreError::Encryption(e) => StoreError::Encryption(e), + _ => StoreError::backend(e), + } + } +} + +mod keys { + pub const INTERNAL_STATE: &str = "matrix-sdk-state"; + pub const BACKUPS_META: &str = "backups"; + + pub const ACCOUNT_DATA: &str = "account_data"; + + pub const MEMBERS: &str = "members"; + pub const PROFILES: &str = "profiles"; + pub const DISPLAY_NAMES: &str = "display_names"; + pub const JOINED_USER_IDS: &str = "joined_user_ids"; + pub const INVITED_USER_IDS: &str = "invited_user_ids"; + + pub const ROOM_STATE: &str = "room_state"; + pub const ROOM_INFOS: &str = "room_infos"; + pub const PRESENCE: &str = "presence"; + pub const ROOM_ACCOUNT_DATA: &str = 
"room_account_data"; + + pub const STRIPPED_ROOM_INFOS: &str = "stripped_room_infos"; + pub const STRIPPED_MEMBERS: &str = "stripped_members"; + pub const STRIPPED_ROOM_STATE: &str = "stripped_room_state"; + pub const STRIPPED_JOINED_USER_IDS: &str = "stripped_joined_user_ids"; + pub const STRIPPED_INVITED_USER_IDS: &str = "stripped_invited_user_ids"; + + pub const ROOM_USER_RECEIPTS: &str = "room_user_receipts"; + pub const ROOM_EVENT_RECEIPTS: &str = "room_event_receipts"; + + pub const MEDIA: &str = "media"; + + pub const CUSTOM: &str = "custom"; + pub const KV: &str = "kv"; + + /// All names of the current state stores for convenience. + pub const ALL_STORES: &[&str] = &[ + ACCOUNT_DATA, + MEMBERS, + PROFILES, + DISPLAY_NAMES, + JOINED_USER_IDS, + INVITED_USER_IDS, + ROOM_STATE, + ROOM_INFOS, + PRESENCE, + ROOM_ACCOUNT_DATA, + STRIPPED_ROOM_INFOS, + STRIPPED_MEMBERS, + STRIPPED_ROOM_STATE, + STRIPPED_JOINED_USER_IDS, + STRIPPED_INVITED_USER_IDS, + ROOM_USER_RECEIPTS, + ROOM_EVENT_RECEIPTS, + MEDIA, + CUSTOM, + KV, + ]; + + // static keys + + pub const STORE_KEY: &str = "store_key"; +} + +pub use keys::ALL_STORES; + +fn serialize_event(store_cipher: Option<&StoreCipher>, event: &impl Serialize) -> Result { + Ok(match store_cipher { + Some(cipher) => JsValue::from_serde(&cipher.encrypt_value_typed(event)?)?, + None => JsValue::from_serde(event)?, + }) +} + +fn deserialize_event( + store_cipher: Option<&StoreCipher>, + event: JsValue, +) -> Result { + match store_cipher { + Some(cipher) => Ok(cipher.decrypt_value_typed(event.into_serde()?)?), + None => Ok(event.into_serde()?), + } +} + +fn encode_key(store_cipher: Option<&StoreCipher>, table_name: &str, key: T) -> JsValue +where + T: SafeEncode, +{ + match store_cipher { + Some(cipher) => key.encode_secure(table_name, cipher), + None => key.encode(), + } +} + +fn encode_to_range( + store_cipher: Option<&StoreCipher>, + table_name: &str, + key: T, +) -> Result +where + T: SafeEncode, +{ + match store_cipher { + 
Some(cipher) => key.encode_to_range_secure(table_name, cipher), + None => key.encode_to_range(), + } + .map_err(|e| IndexeddbStateStoreError::StoreError(StoreError::Backend(anyhow!(e).into()))) +} + +/// Builder for [`IndexeddbStateStore`]. +#[derive(Debug)] +pub struct IndexeddbStateStoreBuilder { + name: Option, + passphrase: Option, + migration_conflict_strategy: MigrationConflictStrategy, +} + +impl IndexeddbStateStoreBuilder { + fn new() -> Self { + Self { + name: None, + passphrase: None, + migration_conflict_strategy: MigrationConflictStrategy::BackupAndDrop, + } + } + + /// Set the name for the indexeddb store to use, `state` is none given. + pub fn name(mut self, value: String) -> Self { + self.name = Some(value); + self + } + + /// Set the password the indexeddb should be encrypted with. + /// + /// If not given, the DB is not encrypted. + pub fn passphrase(mut self, value: String) -> Self { + self.passphrase = Some(value); + self + } + + /// The strategy to use when a merge conflict is found. + /// + /// See [`MigrationConflictStrategy`] for details. 
+ pub fn migration_conflict_strategy(mut self, value: MigrationConflictStrategy) -> Self { + self.migration_conflict_strategy = value; + self + } + + pub async fn build(self) -> Result { + let migration_strategy = self.migration_conflict_strategy.clone(); + let name = self.name.unwrap_or_else(|| "state".to_owned()); + + let meta_name = format!("{name}::{}", keys::INTERNAL_STATE); + + let (meta, store_cipher) = upgrade_meta_db(&meta_name, self.passphrase.as_deref()).await?; + let inner = + upgrade_inner_db(&name, store_cipher.as_deref(), migration_strategy, &meta).await?; + + Ok(IndexeddbStateStore { name, inner, meta, store_cipher }) + } +} + +pub struct IndexeddbStateStore { + name: String, + pub(crate) inner: IdbDatabase, + pub(crate) meta: IdbDatabase, + pub(crate) store_cipher: Option>, +} + +impl std::fmt::Debug for IndexeddbStateStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("IndexeddbStateStore").field("name", &self.name).finish() + } +} + +type Result = std::result::Result; + +impl IndexeddbStateStore { + /// Generate a IndexeddbStateStoreBuilder with default parameters + pub fn builder() -> IndexeddbStateStoreBuilder { + IndexeddbStateStoreBuilder::new() + } + + /// The version of the database containing the data. + pub fn version(&self) -> u32 { + self.inner.version() as u32 + } + + /// The version of the database containing the metadata. + pub fn meta_version(&self) -> u32 { + self.meta.version() as u32 + } + + /// Whether this database has any migration backups + pub async fn has_backups(&self) -> Result { + Ok(self + .meta + .transaction_on_one_with_mode(keys::BACKUPS_META, IdbTransactionMode::Readonly)? + .object_store(keys::BACKUPS_META)? + .count()? + .await? + > 0) + } + + /// What's the database name of the latest backup< + pub async fn latest_backup(&self) -> Result> { + Ok(self + .meta + .transaction_on_one_with_mode(keys::BACKUPS_META, IdbTransactionMode::Readonly)? 
+ .object_store(keys::BACKUPS_META)? + .open_cursor_with_direction(indexed_db_futures::prelude::IdbCursorDirection::Prev)? + .await? + .and_then(|c| c.value().as_string())) + } + + fn serialize_event(&self, event: &impl Serialize) -> Result { + serialize_event(self.store_cipher.as_deref(), event) + } + + fn deserialize_event(&self, event: JsValue) -> Result { + deserialize_event(self.store_cipher.as_deref(), event) + } + + fn encode_key(&self, table_name: &str, key: T) -> JsValue + where + T: SafeEncode, + { + encode_key(self.store_cipher.as_deref(), table_name, key) + } + + fn encode_to_range(&self, table_name: &str, key: T) -> Result + where + T: SafeEncode, + { + encode_to_range(self.store_cipher.as_deref(), table_name, key) + } + + pub async fn get_user_ids_stream(&self, room_id: &RoomId) -> Result> { + Ok([ + self.get_invited_user_ids_inner(room_id).await?, + self.get_joined_user_ids_inner(room_id).await?, + ] + .concat()) + } + + pub async fn get_invited_user_ids_inner(&self, room_id: &RoomId) -> Result> { + let range = self.encode_to_range(keys::INVITED_USER_IDS, room_id)?; + let entries = self + .inner + .transaction_on_one_with_mode(keys::INVITED_USER_IDS, IdbTransactionMode::Readonly)? + .object_store(keys::INVITED_USER_IDS)? + .get_all_with_key(&range)? + .await? + .iter() + .filter_map(|f| self.deserialize_event::(f).ok()) + .collect::>(); + + Ok(entries) + } + + pub async fn get_joined_user_ids_inner(&self, room_id: &RoomId) -> Result> { + let range = self.encode_to_range(keys::JOINED_USER_IDS, room_id)?; + Ok(self + .inner + .transaction_on_one_with_mode(keys::JOINED_USER_IDS, IdbTransactionMode::Readonly)? + .object_store(keys::JOINED_USER_IDS)? + .get_all_with_key(&range)? + .await? 
+ .iter() + .filter_map(|f| self.deserialize_event::(f).ok()) + .collect::>()) + } + + pub async fn get_stripped_user_ids_stream(&self, room_id: &RoomId) -> Result> { + Ok([ + self.get_stripped_invited_user_ids(room_id).await?, + self.get_stripped_joined_user_ids(room_id).await?, + ] + .concat()) + } + + pub async fn get_stripped_invited_user_ids( + &self, + room_id: &RoomId, + ) -> Result> { + let range = self.encode_to_range(keys::STRIPPED_INVITED_USER_IDS, room_id)?; + let entries = self + .inner + .transaction_on_one_with_mode( + keys::STRIPPED_INVITED_USER_IDS, + IdbTransactionMode::Readonly, + )? + .object_store(keys::STRIPPED_INVITED_USER_IDS)? + .get_all_with_key(&range)? + .await? + .iter() + .filter_map(|f| self.deserialize_event::(f).ok()) + .collect::>(); + + Ok(entries) + } + + pub async fn get_stripped_joined_user_ids(&self, room_id: &RoomId) -> Result> { + let range = self.encode_to_range(keys::STRIPPED_JOINED_USER_IDS, room_id)?; + Ok(self + .inner + .transaction_on_one_with_mode( + keys::STRIPPED_JOINED_USER_IDS, + IdbTransactionMode::Readonly, + )? + .object_store(keys::STRIPPED_JOINED_USER_IDS)? + .get_all_with_key(&range)? + .await? + .iter() + .filter_map(|f| self.deserialize_event::(f).ok()) + .collect::>()) + } + + async fn get_custom_value_for_js(&self, jskey: &JsValue) -> Result>> { + self.inner + .transaction_on_one_with_mode(keys::CUSTOM, IdbTransactionMode::Readonly)? + .object_store(keys::CUSTOM)? + .get(jskey)? + .await? + .map(|f| self.deserialize_event(f)) + .transpose() + } + + fn encode_kv_data_key(&self, key: StateStoreDataKey<'_>) -> JsValue { + // Use the key (prefix) for the table name as well, to keep encoded + // keys compatible for the sync token and filters, which were in + // separate tables initially. 
+ match key { + StateStoreDataKey::SyncToken => { + self.encode_key(StateStoreDataKey::SYNC_TOKEN, StateStoreDataKey::SYNC_TOKEN) + } + StateStoreDataKey::Filter(filter_name) => { + self.encode_key(StateStoreDataKey::FILTER, (StateStoreDataKey::FILTER, filter_name)) + } + StateStoreDataKey::UserAvatarUrl(user_id) => { + self.encode_key(keys::KV, (StateStoreDataKey::USER_AVATAR_URL, user_id)) + } + } + } +} + +// Small hack to have the following macro invocation act as the appropriate +// trait impl block on wasm, but still be compiled on non-wasm as a regular +// impl block otherwise. +// +// The trait impl doesn't compile on non-wasm due to unfulfilled trait bounds, +// this hack allows us to still have most of rust-analyzer's IDE functionality +// within the impl block without having to set it up to check things against +// the wasm target (which would disable many other parts of the codebase). +#[cfg(target_arch = "wasm32")] +macro_rules! impl_state_store { + ( $($body:tt)* ) => { + #[async_trait(?Send)] + impl StateStore for IndexeddbStateStore { + type Error = IndexeddbStateStoreError; + + $($body)* + } + }; +} + +#[cfg(not(target_arch = "wasm32"))] +macro_rules! impl_state_store { + ( $($body:tt)* ) => { + impl IndexeddbStateStore { + $($body)* + } + }; +} + +impl_state_store! { + async fn get_kv_data( + &self, + key: StateStoreDataKey<'_>, + ) -> Result> { + let encoded_key = self.encode_kv_data_key(key); + + let value = self + .inner + .transaction_on_one_with_mode(keys::KV, IdbTransactionMode::Readonly)? + .object_store(keys::KV)? + .get(&encoded_key)? + .await? 
+ .map(|f| self.deserialize_event::(f)) + .transpose()?; + + let value = match key { + StateStoreDataKey::SyncToken => value.map(StateStoreDataValue::SyncToken), + StateStoreDataKey::Filter(_) => value.map(StateStoreDataValue::Filter), + StateStoreDataKey::UserAvatarUrl(_) => value.map(StateStoreDataValue::UserAvatarUrl), + }; + + Ok(value) + } + + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<()> { + let encoded_key = self.encode_kv_data_key(key); + + let value = match key { + StateStoreDataKey::SyncToken => { + value.into_sync_token().expect("Session data not a sync token") + } + StateStoreDataKey::Filter(_) => { + value.into_filter().expect("Session data not a filter") + } + StateStoreDataKey::UserAvatarUrl(_) => { + value.into_user_avatar_url().expect("Session data not an user avatar url") + } + }; + + let tx = self + .inner + .transaction_on_one_with_mode(keys::KV, IdbTransactionMode::Readwrite)?; + + let obj = tx.object_store(keys::KV)?; + + obj.put_key_val(&encoded_key, &self.serialize_event(&value)?)?; + + tx.await.into_result()?; + + Ok(()) + } + + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { + let encoded_key = self.encode_kv_data_key(key); + + let tx = self + .inner + .transaction_on_one_with_mode(keys::KV, IdbTransactionMode::Readwrite)?; + let obj = tx.object_store(keys::KV)?; + + obj.delete(&encoded_key)?; + + tx.await.into_result()?; + + Ok(()) + } + + async fn save_changes(&self, changes: &StateChanges) -> Result<()> { + let mut stores: HashSet<&'static str> = [ + (changes.sync_token.is_some(), keys::KV), + (!changes.ambiguity_maps.is_empty(), keys::DISPLAY_NAMES), + (!changes.account_data.is_empty(), keys::ACCOUNT_DATA), + (!changes.presence.is_empty(), keys::PRESENCE), + (!changes.profiles.is_empty(), keys::PROFILES), + (!changes.room_account_data.is_empty(), keys::ROOM_ACCOUNT_DATA), + (!changes.receipts.is_empty(), keys::ROOM_EVENT_RECEIPTS), + 
(!changes.stripped_state.is_empty(), keys::STRIPPED_ROOM_STATE), + ] + .iter() + .filter_map(|(id, key)| if *id { Some(*key) } else { None }) + .collect(); + + if !changes.state.is_empty() { + stores.extend([keys::ROOM_STATE, keys::STRIPPED_ROOM_STATE]); + } + + if !changes.redactions.is_empty() { + stores.extend([keys::ROOM_STATE, keys::ROOM_INFOS]); + } + + if !changes.room_infos.is_empty() || !changes.stripped_room_infos.is_empty() { + stores.extend([keys::ROOM_INFOS, keys::STRIPPED_ROOM_INFOS]); + } + + if !changes.members.is_empty() { + stores.extend([ + keys::PROFILES, + keys::MEMBERS, + keys::INVITED_USER_IDS, + keys::JOINED_USER_IDS, + keys::STRIPPED_MEMBERS, + keys::STRIPPED_INVITED_USER_IDS, + keys::STRIPPED_JOINED_USER_IDS, + ]) + } + + if !changes.stripped_members.is_empty() { + stores.extend([ + keys::STRIPPED_MEMBERS, + keys::STRIPPED_INVITED_USER_IDS, + keys::STRIPPED_JOINED_USER_IDS, + ]) + } + + if !changes.receipts.is_empty() { + stores.extend([keys::ROOM_EVENT_RECEIPTS, keys::ROOM_USER_RECEIPTS]) + } + + if stores.is_empty() { + // nothing to do, quit early + return Ok(()); + } + + let stores: Vec<&'static str> = stores.into_iter().collect(); + let tx = + self.inner.transaction_on_multi_with_mode(&stores, IdbTransactionMode::Readwrite)?; + + if let Some(s) = &changes.sync_token { + tx.object_store(keys::KV)?.put_key_val( + &self.encode_kv_data_key(StateStoreDataKey::SyncToken), + &self.serialize_event(s)?, + )?; + } + + if !changes.ambiguity_maps.is_empty() { + let store = tx.object_store(keys::DISPLAY_NAMES)?; + for (room_id, ambiguity_maps) in &changes.ambiguity_maps { + for (display_name, map) in ambiguity_maps { + let key = self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name)); + + store.put_key_val(&key, &self.serialize_event(&map)?)?; + } + } + } + + if !changes.account_data.is_empty() { + let store = tx.object_store(keys::ACCOUNT_DATA)?; + for (event_type, event) in &changes.account_data { + store.put_key_val( + 
&self.encode_key(keys::ACCOUNT_DATA, event_type), + &self.serialize_event(&event)?, + )?; + } + } + + if !changes.room_account_data.is_empty() { + let store = tx.object_store(keys::ROOM_ACCOUNT_DATA)?; + for (room, events) in &changes.room_account_data { + for (event_type, event) in events { + let key = self.encode_key(keys::ROOM_ACCOUNT_DATA, (room, event_type)); + store.put_key_val(&key, &self.serialize_event(&event)?)?; + } + } + } + + if !changes.state.is_empty() { + let state = tx.object_store(keys::ROOM_STATE)?; + let stripped_state = tx.object_store(keys::STRIPPED_ROOM_STATE)?; + for (room, event_types) in &changes.state { + for (event_type, events) in event_types { + for (state_key, event) in events { + let key = self.encode_key(keys::ROOM_STATE, (room, event_type, state_key)); + state.put_key_val(&key, &self.serialize_event(&event)?)?; + stripped_state.delete(&key)?; + } + } + } + } + + if !changes.room_infos.is_empty() { + let room_infos = tx.object_store(keys::ROOM_INFOS)?; + let stripped_room_infos = tx.object_store(keys::STRIPPED_ROOM_INFOS)?; + for (room_id, room_info) in &changes.room_infos { + room_infos.put_key_val( + &self.encode_key(keys::ROOM_INFOS, room_id), + &self.serialize_event(&room_info)?, + )?; + stripped_room_infos.delete(&self.encode_key(keys::STRIPPED_ROOM_INFOS, room_id))?; + } + } + + if !changes.presence.is_empty() { + let store = tx.object_store(keys::PRESENCE)?; + for (sender, event) in &changes.presence { + store.put_key_val( + &self.encode_key(keys::PRESENCE, sender), + &self.serialize_event(&event)?, + )?; + } + } + + if !changes.stripped_room_infos.is_empty() { + let stripped_room_infos = tx.object_store(keys::STRIPPED_ROOM_INFOS)?; + let room_infos = tx.object_store(keys::ROOM_INFOS)?; + for (room_id, info) in &changes.stripped_room_infos { + stripped_room_infos.put_key_val( + &self.encode_key(keys::STRIPPED_ROOM_INFOS, room_id), + &self.serialize_event(&info)?, + )?; + room_infos.delete(&self.encode_key(keys::ROOM_INFOS, 
room_id))?; + } + } + + if !changes.stripped_members.is_empty() { + let store = tx.object_store(keys::STRIPPED_MEMBERS)?; + let joined = tx.object_store(keys::STRIPPED_JOINED_USER_IDS)?; + let invited = tx.object_store(keys::STRIPPED_INVITED_USER_IDS)?; + for (room, raw_events) in &changes.stripped_members { + for raw_event in raw_events.values() { + let event = match raw_event.deserialize() { + Ok(ev) => ev, + Err(e) => { + let event_id: Option = + raw_event.get_field("event_id").ok().flatten(); + debug!(event_id, "Failed to deserialize stripped member event: {e}"); + continue; + } + }; + + let key = (room, &event.state_key); + + match event.content.membership { + MembershipState::Join => { + joined.put_key_val_owned( + &self.encode_key(keys::STRIPPED_JOINED_USER_IDS, key), + &self.serialize_event(&event.state_key)?, + )?; + invited + .delete(&self.encode_key(keys::STRIPPED_INVITED_USER_IDS, key))?; + } + MembershipState::Invite => { + invited.put_key_val_owned( + &self.encode_key(keys::STRIPPED_INVITED_USER_IDS, key), + &self.serialize_event(&event.state_key)?, + )?; + joined.delete(&self.encode_key(keys::STRIPPED_JOINED_USER_IDS, key))?; + } + _ => { + joined.delete(&self.encode_key(keys::STRIPPED_JOINED_USER_IDS, key))?; + invited + .delete(&self.encode_key(keys::STRIPPED_INVITED_USER_IDS, key))?; + } + } + store.put_key_val( + &self.encode_key(keys::STRIPPED_MEMBERS, key), + &self.serialize_event(&raw_event)?, + )?; + } + } + } + + if !changes.stripped_state.is_empty() { + let store = tx.object_store(keys::STRIPPED_ROOM_STATE)?; + for (room, event_types) in &changes.stripped_state { + for (event_type, events) in event_types { + for (state_key, event) in events { + let key = self + .encode_key(keys::STRIPPED_ROOM_STATE, (room, event_type, state_key)); + store.put_key_val(&key, &self.serialize_event(&event)?)?; + } + } + } + } + + if !changes.members.is_empty() { + let profiles = tx.object_store(keys::PROFILES)?; + let joined = 
tx.object_store(keys::JOINED_USER_IDS)?; + let invited = tx.object_store(keys::INVITED_USER_IDS)?; + let members = tx.object_store(keys::MEMBERS)?; + let stripped_members = tx.object_store(keys::STRIPPED_MEMBERS)?; + let stripped_joined = tx.object_store(keys::STRIPPED_JOINED_USER_IDS)?; + let stripped_invited = tx.object_store(keys::STRIPPED_INVITED_USER_IDS)?; + + for (room, raw_events) in &changes.members { + let profile_changes = changes.profiles.get(room); + + for raw_event in raw_events.values() { + let event = match raw_event.deserialize() { + Ok(ev) => ev, + Err(e) => { + let event_id: Option = + raw_event.get_field("event_id").ok().flatten(); + debug!(event_id, "Failed to deserialize member event: {e}"); + continue; + } + }; + + let key = (room, event.state_key()); + + stripped_joined + .delete(&self.encode_key(keys::STRIPPED_JOINED_USER_IDS, key))?; + stripped_invited + .delete(&self.encode_key(keys::STRIPPED_INVITED_USER_IDS, key))?; + + match event.membership() { + MembershipState::Join => { + joined.put_key_val_owned( + &self.encode_key(keys::JOINED_USER_IDS, key), + &self.serialize_event(event.state_key())?, + )?; + invited.delete(&self.encode_key(keys::INVITED_USER_IDS, key))?; + } + MembershipState::Invite => { + invited.put_key_val_owned( + &self.encode_key(keys::INVITED_USER_IDS, key), + &self.serialize_event(event.state_key())?, + )?; + joined.delete(&self.encode_key(keys::JOINED_USER_IDS, key))?; + } + _ => { + joined.delete(&self.encode_key(keys::JOINED_USER_IDS, key))?; + invited.delete(&self.encode_key(keys::INVITED_USER_IDS, key))?; + } + } + + members.put_key_val_owned( + &self.encode_key(keys::MEMBERS, key), + &self.serialize_event(&raw_event)?, + )?; + stripped_members.delete(&self.encode_key(keys::STRIPPED_MEMBERS, key))?; + + if let Some(profile) = profile_changes.and_then(|p| p.get(event.state_key())) { + profiles.put_key_val_owned( + &self.encode_key(keys::PROFILES, key), + &self.serialize_event(&profile)?, + )?; + } + } + } + } + + 
if !changes.receipts.is_empty() { + let room_user_receipts = tx.object_store(keys::ROOM_USER_RECEIPTS)?; + let room_event_receipts = tx.object_store(keys::ROOM_EVENT_RECEIPTS)?; + + for (room, content) in &changes.receipts { + for (event_id, receipts) in &content.0 { + for (receipt_type, receipts) in receipts { + for (user_id, receipt) in receipts { + let key = match receipt.thread.as_str() { + Some(thread_id) => self.encode_key( + keys::ROOM_USER_RECEIPTS, + (room, receipt_type, thread_id, user_id), + ), + None => self.encode_key( + keys::ROOM_USER_RECEIPTS, + (room, receipt_type, user_id), + ), + }; + + if let Some((old_event, _)) = + room_user_receipts.get(&key)?.await?.and_then(|f| { + self.deserialize_event::<(OwnedEventId, Receipt)>(f).ok() + }) + { + let key = match receipt.thread.as_str() { + Some(thread_id) => self.encode_key( + keys::ROOM_EVENT_RECEIPTS, + (room, receipt_type, thread_id, old_event, user_id), + ), + None => self.encode_key( + keys::ROOM_EVENT_RECEIPTS, + (room, receipt_type, old_event, user_id), + ), + }; + room_event_receipts.delete(&key)?; + } + + room_user_receipts + .put_key_val(&key, &self.serialize_event(&(event_id, receipt))?)?; + + // Add the receipt to the room event receipts + let key = match receipt.thread.as_str() { + Some(thread_id) => self.encode_key( + keys::ROOM_EVENT_RECEIPTS, + (room, receipt_type, thread_id, event_id, user_id), + ), + None => self.encode_key( + keys::ROOM_EVENT_RECEIPTS, + (room, receipt_type, event_id, user_id), + ), + }; + room_event_receipts + .put_key_val(&key, &self.serialize_event(&(user_id, receipt))?)?; + } + } + } + } + } + + if !changes.redactions.is_empty() { + let state = tx.object_store(keys::ROOM_STATE)?; + let room_info = tx.object_store(keys::ROOM_INFOS)?; + + for (room_id, redactions) in &changes.redactions { + let range = self.encode_to_range(keys::ROOM_STATE, room_id)?; + let Some(cursor) = state.open_cursor_with_range(&range)?.await? 
else { continue }; + + let mut room_version = None; + + while let Some(key) = cursor.key() { + let raw_evt = + self.deserialize_event::>(cursor.value())?; + if let Ok(Some(event_id)) = raw_evt.get_field::("event_id") { + if let Some(redaction) = redactions.get(&event_id) { + let version = { + if room_version.is_none() { + room_version.replace(room_info + .get(&self.encode_key(keys::ROOM_INFOS, room_id))? + .await? + .and_then(|f| self.deserialize_event::(f).ok()) + .and_then(|info| info.room_version().cloned()) + .unwrap_or_else(|| { + warn!(?room_id, "Unable to find the room version, assume version 9"); + RoomVersionId::V9 + }) + ); + } + room_version.as_ref().unwrap() + }; + + let redacted = redact( + raw_evt.deserialize_as::()?, + version, + Some(redaction.try_into()?), + ) + .map_err(StoreError::Redaction)?; + state.put_key_val(&key, &self.serialize_event(&redacted)?)?; + } + } + + // move forward. + cursor.advance(1)?.await?; + } + } + } + + tx.await.into_result().map_err(|e| e.into()) + } + + async fn get_presence_event(&self, user_id: &UserId) -> Result>> { + self.inner + .transaction_on_one_with_mode(keys::PRESENCE, IdbTransactionMode::Readonly)? + .object_store(keys::PRESENCE)? + .get(&self.encode_key(keys::PRESENCE, user_id))? + .await? + .map(|f| self.deserialize_event(f)) + .transpose() + } + + async fn get_state_event( + &self, + room_id: &RoomId, + event_type: StateEventType, + state_key: &str, + ) -> Result>> { + self.inner + .transaction_on_one_with_mode(keys::ROOM_STATE, IdbTransactionMode::Readonly)? + .object_store(keys::ROOM_STATE)? + .get(&self.encode_key(keys::ROOM_STATE, (room_id, event_type, state_key)))? + .await? 
+ .map(|f| self.deserialize_event(f)) + .transpose() + } + + async fn get_state_events( + &self, + room_id: &RoomId, + event_type: StateEventType, + ) -> Result>> { + let range = self.encode_to_range(keys::ROOM_STATE, (room_id, event_type))?; + Ok(self + .inner + .transaction_on_one_with_mode(keys::ROOM_STATE, IdbTransactionMode::Readonly)? + .object_store(keys::ROOM_STATE)? + .get_all_with_key(&range)? + .await? + .iter() + .filter_map(|f| self.deserialize_event(f).ok()) + .collect::>()) + } + + async fn get_profile( + &self, + room_id: &RoomId, + user_id: &UserId, + ) -> Result>> { + self.inner + .transaction_on_one_with_mode(keys::PROFILES, IdbTransactionMode::Readonly)? + .object_store(keys::PROFILES)? + .get(&self.encode_key(keys::PROFILES, (room_id, user_id)))? + .await? + .map(|f| self.deserialize_event(f)) + .transpose() + } + + async fn get_member_event( + &self, + room_id: &RoomId, + state_key: &UserId, + ) -> Result> { + if let Some(e) = self + .inner + .transaction_on_one_with_mode(keys::STRIPPED_MEMBERS, IdbTransactionMode::Readonly)? + .object_store(keys::STRIPPED_MEMBERS)? + .get(&self.encode_key(keys::STRIPPED_MEMBERS, (room_id, state_key)))? + .await? + .map(|f| self.deserialize_event(f)) + .transpose()? + { + Ok(Some(RawMemberEvent::Stripped(e))) + } else if let Some(e) = self + .inner + .transaction_on_one_with_mode(keys::MEMBERS, IdbTransactionMode::Readonly)? + .object_store(keys::MEMBERS)? + .get(&self.encode_key(keys::MEMBERS, (room_id, state_key)))? + .await? + .map(|f| self.deserialize_event(f)) + .transpose()? + { + Ok(Some(RawMemberEvent::Sync(e))) + } else { + Ok(None) + } + } + + async fn get_room_infos(&self) -> Result> { + let entries: Vec<_> = self + .inner + .transaction_on_one_with_mode(keys::ROOM_INFOS, IdbTransactionMode::Readonly)? + .object_store(keys::ROOM_INFOS)? + .get_all()? + .await? 
+ .iter() + .filter_map(|f| self.deserialize_event::(f).ok()) + .collect(); + + Ok(entries) + } + + async fn get_stripped_room_infos(&self) -> Result> { + let entries = self + .inner + .transaction_on_one_with_mode(keys::STRIPPED_ROOM_INFOS, IdbTransactionMode::Readonly)? + .object_store(keys::STRIPPED_ROOM_INFOS)? + .get_all()? + .await? + .iter() + .filter_map(|f| self.deserialize_event(f).ok()) + .collect::>(); + + Ok(entries) + } + + async fn get_users_with_display_name( + &self, + room_id: &RoomId, + display_name: &str, + ) -> Result> { + self.inner + .transaction_on_one_with_mode(keys::DISPLAY_NAMES, IdbTransactionMode::Readonly)? + .object_store(keys::DISPLAY_NAMES)? + .get(&self.encode_key(keys::DISPLAY_NAMES, (room_id, display_name)))? + .await? + .map(|f| self.deserialize_event::>(f)) + .unwrap_or_else(|| Ok(Default::default())) + } + + async fn get_account_data_event( + &self, + event_type: GlobalAccountDataEventType, + ) -> Result>> { + self.inner + .transaction_on_one_with_mode(keys::ACCOUNT_DATA, IdbTransactionMode::Readonly)? + .object_store(keys::ACCOUNT_DATA)? + .get(&self.encode_key(keys::ACCOUNT_DATA, event_type))? + .await? + .map(|f| self.deserialize_event(f)) + .transpose() + } + + async fn get_room_account_data_event( + &self, + room_id: &RoomId, + event_type: RoomAccountDataEventType, + ) -> Result>> { + self.inner + .transaction_on_one_with_mode(keys::ROOM_ACCOUNT_DATA, IdbTransactionMode::Readonly)? + .object_store(keys::ROOM_ACCOUNT_DATA)? + .get(&self.encode_key(keys::ROOM_ACCOUNT_DATA, (room_id, event_type)))? + .await? 
+ .map(|f| self.deserialize_event(f)) + .transpose() + } + + async fn get_user_room_receipt_event( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + thread: ReceiptThread, + user_id: &UserId, + ) -> Result> { + let key = match thread.as_str() { + Some(thread_id) => self + .encode_key(keys::ROOM_USER_RECEIPTS, (room_id, receipt_type, thread_id, user_id)), + None => self.encode_key(keys::ROOM_USER_RECEIPTS, (room_id, receipt_type, user_id)), + }; + self.inner + .transaction_on_one_with_mode(keys::ROOM_USER_RECEIPTS, IdbTransactionMode::Readonly)? + .object_store(keys::ROOM_USER_RECEIPTS)? + .get(&key)? + .await? + .map(|f| self.deserialize_event(f)) + .transpose() + } + + async fn get_event_room_receipt_events( + &self, + room_id: &RoomId, + receipt_type: ReceiptType, + thread: ReceiptThread, + event_id: &EventId, + ) -> Result> { + let range = match thread.as_str() { + Some(thread_id) => self.encode_to_range( + keys::ROOM_EVENT_RECEIPTS, + (room_id, receipt_type, thread_id, event_id), + ), + None => { + self.encode_to_range(keys::ROOM_EVENT_RECEIPTS, (room_id, receipt_type, event_id)) + } + }?; + let tx = self.inner.transaction_on_one_with_mode( + keys::ROOM_EVENT_RECEIPTS, + IdbTransactionMode::Readonly, + )?; + let store = tx.object_store(keys::ROOM_EVENT_RECEIPTS)?; + + Ok(store + .get_all_with_key(&range)? + .await? 
+ .iter() + .filter_map(|f| self.deserialize_event(f).ok()) + .collect::>()) + } + + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { + let key = self + .encode_key(keys::MEDIA, (request.source.unique_key(), request.format.unique_key())); + let tx = + self.inner.transaction_on_one_with_mode(keys::MEDIA, IdbTransactionMode::Readwrite)?; + + tx.object_store(keys::MEDIA)?.put_key_val(&key, &self.serialize_event(&data)?)?; + + tx.await.into_result().map_err(|e| e.into()) + } + + async fn get_media_content(&self, request: &MediaRequest) -> Result>> { + let key = self + .encode_key(keys::MEDIA, (request.source.unique_key(), request.format.unique_key())); + self.inner + .transaction_on_one_with_mode(keys::MEDIA, IdbTransactionMode::Readonly)? + .object_store(keys::MEDIA)? + .get(&key)? + .await? + .map(|f| self.deserialize_event(f)) + .transpose() + } + + async fn get_custom_value(&self, key: &[u8]) -> Result>> { + let jskey = &JsValue::from_str(core::str::from_utf8(key).map_err(StoreError::Codec)?); + self.get_custom_value_for_js(jskey).await + } + + async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>> { + let jskey = JsValue::from_str(core::str::from_utf8(key).map_err(StoreError::Codec)?); + + let prev = self.get_custom_value_for_js(&jskey).await?; + + let tx = + self.inner.transaction_on_one_with_mode(keys::CUSTOM, IdbTransactionMode::Readwrite)?; + + tx.object_store(keys::CUSTOM)?.put_key_val(&jskey, &self.serialize_event(&value)?)?; + + tx.await.into_result().map_err(IndexeddbStateStoreError::from)?; + Ok(prev) + } + + async fn remove_custom_value(&self, key: &[u8]) -> Result>> { + let jskey = JsValue::from_str(core::str::from_utf8(key).map_err(StoreError::Codec)?); + + let prev = self.get_custom_value_for_js(&jskey).await?; + + let tx = + self.inner.transaction_on_one_with_mode(keys::CUSTOM, IdbTransactionMode::Readwrite)?; + + tx.object_store(keys::CUSTOM)?.delete(&jskey)?; + + 
tx.await.into_result().map_err(IndexeddbStateStoreError::from)?; + Ok(prev) + } + + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { + let key = self + .encode_key(keys::MEDIA, (request.source.unique_key(), request.format.unique_key())); + let tx = + self.inner.transaction_on_one_with_mode(keys::MEDIA, IdbTransactionMode::Readwrite)?; + + tx.object_store(keys::MEDIA)?.delete(&key)?; + + tx.await.into_result().map_err(|e| e.into()) + } + + async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { + let range = self.encode_to_range(keys::MEDIA, uri)?; + let tx = + self.inner.transaction_on_one_with_mode(keys::MEDIA, IdbTransactionMode::Readwrite)?; + let store = tx.object_store(keys::MEDIA)?; + + for k in store.get_all_keys_with_key(&range)?.await?.iter() { + store.delete(&k)?; + } + + tx.await.into_result().map_err(|e| e.into()) + } + + async fn remove_room(&self, room_id: &RoomId) -> Result<()> { + let direct_stores = [keys::ROOM_INFOS, keys::STRIPPED_ROOM_INFOS]; + + let prefixed_stores = [ + keys::MEMBERS, + keys::PROFILES, + keys::DISPLAY_NAMES, + keys::INVITED_USER_IDS, + keys::JOINED_USER_IDS, + keys::ROOM_STATE, + keys::ROOM_ACCOUNT_DATA, + keys::ROOM_EVENT_RECEIPTS, + keys::ROOM_USER_RECEIPTS, + keys::STRIPPED_ROOM_STATE, + keys::STRIPPED_MEMBERS, + ]; + + let all_stores = { + let mut v = Vec::new(); + v.extend(prefixed_stores); + v.extend(direct_stores); + v + }; + + let tx = self + .inner + .transaction_on_multi_with_mode(&all_stores, IdbTransactionMode::Readwrite)?; + + for store_name in direct_stores { + tx.object_store(store_name)?.delete(&self.encode_key(store_name, room_id))?; + } + + for store_name in prefixed_stores { + let store = tx.object_store(store_name)?; + let range = self.encode_to_range(store_name, room_id)?; + for key in store.get_all_keys_with_key(&range)?.await?.iter() { + store.delete(&key)?; + } + } + tx.await.into_result().map_err(|e| e.into()) + } + + async fn get_user_ids(&self, room_id: 
&RoomId) -> Result> { + let ids: Vec = self.get_stripped_user_ids_stream(room_id).await?; + if !ids.is_empty() { + return Ok(ids); + } + self.get_user_ids_stream(room_id).await + } + + async fn get_invited_user_ids(&self, room_id: &RoomId) -> Result> { + let ids: Vec = self.get_stripped_invited_user_ids(room_id).await?; + if !ids.is_empty() { + return Ok(ids); + } + self.get_invited_user_ids_inner(room_id).await + } + + async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result> { + let ids: Vec = self.get_stripped_joined_user_ids(room_id).await?; + if !ids.is_empty() { + return Ok(ids); + } + self.get_joined_user_ids_inner(room_id).await + } +} + +#[cfg(all(test, target_arch = "wasm32"))] +mod tests { + #[cfg(target_arch = "wasm32")] + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + + use matrix_sdk_base::statestore_integration_tests; + use uuid::Uuid; + + use super::{IndexeddbStateStore, Result}; + + async fn get_store() -> Result { + let db_name = format!("test-state-plain-{}", Uuid::new_v4().as_hyphenated()); + Ok(IndexeddbStateStore::builder().name(db_name).build().await?) + } + + statestore_integration_tests!(with_media_tests); +} + +#[cfg(all(test, target_arch = "wasm32"))] +mod encrypted_tests { + #[cfg(target_arch = "wasm32")] + wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + + use matrix_sdk_base::statestore_integration_tests; + use uuid::Uuid; + + use super::{IndexeddbStateStore, Result}; + + async fn get_store() -> Result { + let db_name = format!("test-state-encrypted-{}", Uuid::new_v4().as_hyphenated()); + let passphrase = format!("some_passphrase-{}", Uuid::new_v4().as_hyphenated()); + Ok(IndexeddbStateStore::builder().name(db_name).passphrase(passphrase).build().await?) 
+ } + + statestore_integration_tests!(with_media_tests); +} diff --git a/crates/matrix-sdk-sled/Cargo.toml b/crates/matrix-sdk-sled/Cargo.toml index f2f1f420190..67f0d2b40e1 100644 --- a/crates/matrix-sdk-sled/Cargo.toml +++ b/crates/matrix-sdk-sled/Cargo.toml @@ -27,10 +27,9 @@ crypto-store = [ async-stream = { workspace = true } async-trait = { workspace = true } dashmap = { workspace = true } -derive_builder = "0.11.2" fs_extra = "1.2.0" futures-core = "0.3.21" -futures-util = { version = "0.3.21", default-features = false } +futures-util = { workspace = true } matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base", optional = true } matrix-sdk-common = { version = "0.6.0", path = "../matrix-sdk-common" } matrix-sdk-crypto = { version = "0.6.0", path = "../matrix-sdk-crypto", optional = true } @@ -40,7 +39,7 @@ serde = { workspace = true } serde_json = { workspace = true } sled = "0.34.7" thiserror = { workspace = true } -tokio = { version = "1.23.1", default-features = false, features = ["sync", "fs"] } +tokio = { version = "1.24.2", default-features = false, features = ["sync", "fs"] } tracing = { workspace = true } [dev-dependencies] @@ -50,4 +49,4 @@ matrix-sdk-crypto = { path = "../matrix-sdk-crypto", features = ["testing"] } matrix-sdk-test = { path = "../../testing/matrix-sdk-test" } once_cell = { workspace = true } tempfile = "3.3.0" -tokio = { version = "1.23.1", default-features = false, features = ["rt-multi-thread", "macros"] } +tokio = { version = "1.24.2", default-features = false, features = ["rt-multi-thread", "macros"] } diff --git a/crates/matrix-sdk-sled/src/crypto_store.rs b/crates/matrix-sdk-sled/src/crypto_store.rs index b386db6386e..01652fb597b 100644 --- a/crates/matrix-sdk-sled/src/crypto_store.rs +++ b/crates/matrix-sdk-sled/src/crypto_store.rs @@ -28,7 +28,7 @@ use matrix_sdk_crypto::{ }, store::{ caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, Result, - RoomKeyCounts, + RoomKeyCounts, 
RoomSettings, }, types::{events::room_key_request::SupportedKeyInfo, EventEncryptionAlgorithm}, GossipRequest, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo, @@ -37,7 +37,6 @@ use matrix_sdk_crypto::{ use matrix_sdk_store_encryption::StoreCipher; use ruma::{DeviceId, OwnedDeviceId, RoomId, TransactionId, UserId}; use serde::{de::DeserializeOwned, Serialize}; -pub use sled::Error; use sled::{ transaction::{ConflictableTransactionError, TransactionError}, Batch, Config, Db, IVec, Transactional, Tree, @@ -47,7 +46,7 @@ use tracing::debug; use super::OpenStoreError; use crate::encode_key::{EncodeKey, ENCODE_SEPARATOR}; -const DATABASE_VERSION: u8 = 6; +const DATABASE_VERSION: u8 = 7; // Table names that are used to derive a separate key for each tree. This ensure // that user ids encoded for different trees won't end up as the same byte @@ -59,6 +58,7 @@ const INBOUND_GROUP_TABLE_NAME: &str = "crypto-store-inbound-group-sessions"; const OUTBOUND_GROUP_TABLE_NAME: &str = "crypto-store-outbound-group-sessions"; const SECRET_REQUEST_BY_INFO_TABLE: &str = "crypto-store-secret-request-by-info"; const TRACKED_USERS_TABLE: &str = "crypto-store-secret-tracked-users"; +const ROOM_SETTINGS_TABLE: &str = "crypto-store-secret-room-settings"; impl EncodeKey for InboundGroupSession { fn encode(&self) -> Vec { @@ -186,6 +186,8 @@ pub struct SledCryptoStore { identities: Tree, tracked_users: Tree, + + room_settings: Tree, } impl std::fmt::Debug for SledCryptoStore { @@ -387,6 +389,8 @@ impl SledCryptoStore { let unsent_secret_requests = db.open_tree("unsent_secret_requests")?; let secret_requests_by_info = db.open_tree("secret_requests_by_info")?; + let room_settings = db.open_tree("room_settings")?; + let session_cache = SessionStore::new(); let database = Self { @@ -407,6 +411,7 @@ impl SledCryptoStore { tracked_users, olm_hashes, identities, + room_settings, }; database.upgrade().await?; @@ -463,7 +468,7 @@ impl SledCryptoStore { }; let private_identity_pickle 
= - if let Some(i) = changes.private_identity { Some(i.pickle().await?) } else { None }; + if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None }; let recovery_key_pickle = changes.recovery_key; @@ -500,6 +505,7 @@ impl SledCryptoStore { let olm_hashes = changes.message_hashes; let key_requests = changes.key_requests; let backup_version = changes.backup_version; + let room_settings_changes = changes.room_settings; let ret: Result<(), TransactionError> = ( &self.account, @@ -513,6 +519,7 @@ impl SledCryptoStore { &self.outgoing_secret_requests, &self.unsent_secret_requests, &self.secret_requests_by_info, + &self.room_settings, ) .transaction( |( @@ -527,6 +534,7 @@ impl SledCryptoStore { outgoing_secret_requests, unsent_secret_requests, secret_requests_by_info, + room_settings, )| { if let Some(a) = &account_pickle { account.insert( @@ -636,6 +644,15 @@ impl SledCryptoStore { } } + for (room_id, settings) in &room_settings_changes { + let key = self.encode_key(ROOM_SETTINGS_TABLE, room_id); + room_settings.insert( + key.as_slice(), + self.serialize_value(&settings) + .map_err(ConflictableTransactionError::Abort)?, + )?; + } + Ok(()) }, ); @@ -698,6 +715,8 @@ impl SledCryptoStore { #[async_trait] impl CryptoStore for SledCryptoStore { + type Error = CryptoStoreError; + async fn load_account(&self) -> Result> { if let Some(pickle) = self.account.get("account".encode()).map_err(CryptoStoreError::backend)? @@ -1009,6 +1028,25 @@ impl CryptoStore for SledCryptoStore { Ok(key) } + + async fn get_room_settings(&self, room_id: &RoomId) -> Result> { + let key = self.encode_key(ROOM_SETTINGS_TABLE, room_id); + self.room_settings + .get(key) + .map_err(CryptoStoreError::backend)? 
+ .map(|p| self.deserialize_value(&p)) + .transpose() + } + + async fn get_custom_value(&self, key: &str) -> Result>> { + let value = self.inner.get(key).map_err(CryptoStoreError::backend)?.map(|v| v.to_vec()); + Ok(value) + } + + async fn set_custom_value(&self, key: &str, value: Vec) -> Result<()> { + self.inner.insert(key, value).map_err(CryptoStoreError::backend)?; + Ok(()) + } } #[cfg(test)] diff --git a/crates/matrix-sdk-sled/src/encode_key.rs b/crates/matrix-sdk-sled/src/encode_key.rs index 505222aea3c..7b4a20c7816 100644 --- a/crates/matrix-sdk-sled/src/encode_key.rs +++ b/crates/matrix-sdk-sled/src/encode_key.rs @@ -247,3 +247,44 @@ where .concat() } } + +impl EncodeKey for (A, B, C, D, E) +where + A: EncodeKey, + B: EncodeKey, + C: EncodeKey, + D: EncodeKey, + E: EncodeKey, +{ + fn encode(&self) -> Vec { + [ + self.0.encode_as_bytes().deref(), + &[ENCODE_SEPARATOR], + self.1.encode_as_bytes().deref(), + &[ENCODE_SEPARATOR], + self.2.encode_as_bytes().deref(), + &[ENCODE_SEPARATOR], + self.3.encode_as_bytes().deref(), + &[ENCODE_SEPARATOR], + self.4.encode_as_bytes().deref(), + &[ENCODE_SEPARATOR], + ] + .concat() + } + + fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec { + [ + store_cipher.hash_key(table_name, &self.0.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.1.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.2.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.3.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + store_cipher.hash_key(table_name, &self.4.encode_as_bytes()).as_slice(), + &[ENCODE_SEPARATOR], + ] + .concat() + } +} diff --git a/crates/matrix-sdk-sled/src/lib.rs b/crates/matrix-sdk-sled/src/lib.rs index 9eb38fbe8fc..8dc8e293dab 100644 --- a/crates/matrix-sdk-sled/src/lib.rs +++ b/crates/matrix-sdk-sled/src/lib.rs @@ -63,12 +63,11 @@ pub async fn make_store_config( 
#[cfg(not(feature = "crypto-store"))] { - let mut store_builder = SledStateStore::builder(); - store_builder.path(path.as_ref().to_path_buf()); + let mut store_builder = SledStateStore::builder().path(path.as_ref().to_path_buf()); if let Some(passphrase) = passphrase { - store_builder.passphrase(passphrase.to_owned()); - }; + store_builder = store_builder.passphrase(passphrase.to_owned()); + } let state_store = store_builder.build().map_err(StoreError::backend)?; Ok(StoreConfig::new().state_store(state_store)) @@ -82,11 +81,9 @@ async fn open_stores_with_path( path: impl AsRef, passphrase: Option<&str>, ) -> Result<(SledStateStore, SledCryptoStore), OpenStoreError> { - let mut store_builder = SledStateStore::builder(); - store_builder.path(path.as_ref().to_path_buf()); - + let mut store_builder = SledStateStore::builder().path(path.as_ref().to_path_buf()); if let Some(passphrase) = passphrase { - store_builder.passphrase(passphrase.to_owned()); + store_builder = store_builder.passphrase(passphrase.to_owned()); } let state_store = store_builder.build().map_err(StoreError::backend)?; diff --git a/crates/matrix-sdk-sled/src/state_store/migrations.rs b/crates/matrix-sdk-sled/src/state_store/migrations.rs new file mode 100644 index 00000000000..e399e5922e4 --- /dev/null +++ b/crates/matrix-sdk-sled/src/state_store/migrations.rs @@ -0,0 +1,449 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use matrix_sdk_base::{ + store::{Result as StoreResult, StoreError}, + StateStoreDataKey, +}; +use serde_json::value::{RawValue as RawJsonValue, Value as JsonValue}; +use sled::{transaction::TransactionError, Batch, Transactional, Tree}; +use tracing::debug; + +use super::{keys, Result, SledStateStore, SledStoreError}; +use crate::encode_key::EncodeKey; + +const DATABASE_VERSION: u8 = 4; + +const VERSION_KEY: &str = "state-store-version"; + +/// Sometimes Migrations can't proceed without having to drop existing +/// data. This allows you to configure, how these cases should be handled. +#[derive(PartialEq, Eq, Clone, Debug)] +pub enum MigrationConflictStrategy { + /// Just drop the data, we don't care that we have to sync again + Drop, + /// Raise a `SledStoreError::MigrationConflict` error with the path to the + /// DB in question. The caller then has to take care about what they want + /// to do and try again after. + Raise, + /// _Default_: The _entire_ database is backed up under + /// `$path.$timestamp.backup` (this includes the crypto store if they + /// are linked), before the state tables are dropped. + BackupAndDrop, +} + +impl SledStateStore { + pub(super) fn upgrade(&mut self) -> Result<()> { + let old_version = self.db_version()?; + + if old_version == 0 { + // we are fresh, let's write the current version + return self.set_db_version(DATABASE_VERSION); + } + if old_version == DATABASE_VERSION { + // current, we don't have to do anything + return Ok(()); + }; + + debug!(old_version, new_version = DATABASE_VERSION, "Upgrading the Sled state store"); + + if old_version == 1 && self.store_cipher.is_some() { + // we stored some fields un-encrypted. 
Drop them to force re-creation + return Err(SledStoreError::MigrationConflict { + path: self.path.take().expect("Path must exist for a migration to fail"), + old_version: old_version.into(), + new_version: DATABASE_VERSION.into(), + }); + } + + if old_version < 3 { + self.migrate_to_v3()?; + } + + if old_version < 4 { + self.migrate_to_v4()?; + return Ok(()); + } + + // FUTURE UPGRADE CODE GOES HERE + + // can't upgrade from that version to the new one + Err(SledStoreError::MigrationConflict { + path: self.path.take().expect("Path must exist for a migration to fail"), + old_version: old_version.into(), + new_version: DATABASE_VERSION.into(), + }) + } + + /// Get the version of the database. + /// + /// Returns `0` for a new database. + fn db_version(&self) -> Result { + Ok(self + .inner + .get(VERSION_KEY)? + .map(|v| { + let (version_bytes, _) = v.split_at(std::mem::size_of::()); + u8::from_be_bytes(version_bytes.try_into().unwrap_or_default()) + }) + .unwrap_or_default()) + } + + fn set_db_version(&self, version: u8) -> Result<()> { + self.inner.insert(VERSION_KEY, version.to_be_bytes().as_ref())?; + self.inner.flush()?; + Ok(()) + } + + pub fn drop_v1_tables(self) -> StoreResult<()> { + for name in V1_DB_STORES { + self.inner.drop_tree(name).map_err(StoreError::backend)?; + } + self.inner.remove(VERSION_KEY).map_err(StoreError::backend)?; + + Ok(()) + } + + fn v3_fix_tree(&self, tree: &Tree, batch: &mut Batch) -> Result<()> { + fn maybe_fix_json(raw_json: &RawJsonValue) -> Result> { + let json = raw_json.get(); + + if json.contains(r#""content":null"#) { + let mut value: JsonValue = serde_json::from_str(json)?; + if let Some(content) = value.get_mut("content") { + if matches!(content, JsonValue::Null) { + *content = JsonValue::Object(Default::default()); + return Ok(Some(value)); + } + } + } + + Ok(None) + } + + for entry in tree.iter() { + let (key, value) = entry?; + let raw_json: Box = self.deserialize_value(&value)?; + + if let Some(fixed_json) = 
maybe_fix_json(&raw_json)? { + batch.insert(key, self.serialize_value(&fixed_json)?); + } + } + + Ok(()) + } + + fn migrate_to_v3(&self) -> Result<()> { + let mut room_info_batch = sled::Batch::default(); + self.v3_fix_tree(&self.room_info, &mut room_info_batch)?; + + let mut room_state_batch = sled::Batch::default(); + self.v3_fix_tree(&self.room_state, &mut room_state_batch)?; + + let ret: Result<(), TransactionError> = (&self.room_info, &self.room_state) + .transaction(|(room_info, room_state)| { + room_info.apply_batch(&room_info_batch)?; + room_state.apply_batch(&room_state_batch)?; + + Ok(()) + }); + ret?; + + self.set_db_version(3u8) + } + + /// Replace the SYNC_TOKEN and SESSION trees by KV. + fn migrate_to_v4(&self) -> Result<()> { + { + let session = &self.inner.open_tree(old_keys::SESSION)?; + let mut batch = sled::Batch::default(); + + // Sync token + let sync_token = session.get(StateStoreDataKey::SYNC_TOKEN.encode())?; + if let Some(sync_token) = sync_token { + batch.insert(StateStoreDataKey::SYNC_TOKEN.encode(), sync_token); + } + + // Filters + let key = self.encode_key(keys::SESSION, StateStoreDataKey::FILTER); + for res in session.scan_prefix(key) { + let (key, value) = res?; + batch.insert(key, value); + } + self.kv.apply_batch(batch)?; + } + + // This was unused so we can just drop it. + self.inner.drop_tree(old_keys::SYNC_TOKEN)?; + self.inner.drop_tree(old_keys::SESSION)?; + + self.set_db_version(4) + } +} + +mod old_keys { + /// Old stores. 
+ pub const SYNC_TOKEN: &str = "sync_token"; + pub const SESSION: &str = "session"; +} + +pub const V1_DB_STORES: &[&str] = &[ + keys::ACCOUNT_DATA, + old_keys::SYNC_TOKEN, + keys::DISPLAY_NAME, + keys::INVITED_USER_ID, + keys::JOINED_USER_ID, + keys::MEDIA, + keys::MEMBER, + keys::PRESENCE, + keys::PROFILE, + keys::ROOM_ACCOUNT_DATA, + keys::ROOM_EVENT_RECEIPT, + keys::ROOM_INFO, + keys::ROOM_STATE, + keys::ROOM_USER_RECEIPT, + keys::ROOM, + old_keys::SESSION, + keys::STRIPPED_INVITED_USER_ID, + keys::STRIPPED_JOINED_USER_ID, + keys::STRIPPED_ROOM_INFO, + keys::STRIPPED_ROOM_MEMBER, + keys::STRIPPED_ROOM_STATE, + keys::CUSTOM, +]; + +#[cfg(test)] +mod test { + use matrix_sdk_base::StateStoreDataKey; + use matrix_sdk_test::async_test; + use ruma::{ + events::{AnySyncStateEvent, StateEventType}, + room_id, + }; + use serde_json::json; + use tempfile::TempDir; + + use super::{old_keys, MigrationConflictStrategy}; + use crate::{ + encode_key::EncodeKey, + state_store::{keys, Result, SledStateStore, SledStoreError}, + }; + + #[async_test] + pub async fn migrating_v1_to_2_plain() -> Result<()> { + let folder = TempDir::new()?; + + let store = SledStateStore::builder().path(folder.path().to_path_buf()).build()?; + + store.set_db_version(1u8)?; + drop(store); + + // this transparently migrates to the latest version + let _store = SledStateStore::builder().path(folder.path().to_path_buf()).build()?; + Ok(()) + } + + #[async_test] + pub async fn migrating_v1_to_2_with_pw_backed_up() -> Result<()> { + let folder = TempDir::new()?; + + let store = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("something".to_owned()) + .build()?; + + store.set_db_version(1u8)?; + drop(store); + + // this transparently creates a backup and a fresh db + let _store = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("something".to_owned()) + .build()?; + assert_eq!(std::fs::read_dir(folder.path())?.count(), 2); + Ok(()) + } + + 
#[async_test] + pub async fn migrating_v1_to_2_with_pw_drop() -> Result<()> { + let folder = TempDir::new()?; + + let store = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("other thing".to_owned()) + .build()?; + + store.set_db_version(1u8)?; + drop(store); + + // this transparently creates a backup and a fresh db + let _store = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("other thing".to_owned()) + .migration_conflict_strategy(MigrationConflictStrategy::Drop) + .build()?; + assert_eq!(std::fs::read_dir(folder.path())?.count(), 1); + Ok(()) + } + + #[async_test] + pub async fn migrating_v1_to_2_with_pw_raises() -> Result<()> { + let folder = TempDir::new()?; + + let store = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("secret".to_owned()) + .build()?; + + store.set_db_version(1u8)?; + drop(store); + + // this transparently creates a backup and a fresh db + let res = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("secret".to_owned()) + .migration_conflict_strategy(MigrationConflictStrategy::Raise) + .build(); + if let Err(SledStoreError::MigrationConflict { .. }) = res { + // all good + } else { + panic!("Didn't raise the expected error: {res:?}"); + } + assert_eq!(std::fs::read_dir(folder.path())?.count(), 1); + Ok(()) + } + + #[async_test] + pub async fn migrating_v2_to_v3() { + // An event that fails to deserialize. 
+ let wrong_redacted_state_event = json!({ + "content": null, + "event_id": "$wrongevent", + "origin_server_ts": 1673887516047_u64, + "sender": "@example:localhost", + "state_key": "", + "type": "m.room.topic", + "unsigned": { + "redacted_because": { + "type": "m.room.redaction", + "sender": "@example:localhost", + "content": {}, + "redacts": "$wrongevent", + "origin_server_ts": 1673893816047_u64, + "unsigned": {}, + "event_id": "$redactionevent", + }, + }, + }); + serde_json::from_value::(wrong_redacted_state_event.clone()) + .unwrap_err(); + + let room_id = room_id!("!some_room:localhost"); + let folder = TempDir::new().unwrap(); + + let store = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("secret".to_owned()) + .build() + .unwrap(); + + store + .room_state + .insert( + store.encode_key(keys::ROOM_STATE, (room_id, StateEventType::RoomTopic, "")), + store.serialize_value(&wrong_redacted_state_event).unwrap(), + ) + .unwrap(); + store.set_db_version(2u8).unwrap(); + drop(store); + + let store = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("secret".to_owned()) + .build() + .unwrap(); + let event = + store.get_state_event(room_id, StateEventType::RoomTopic, "").await.unwrap().unwrap(); + event.deserialize().unwrap(); + } + + #[async_test] + pub async fn migrating_v3_to_v4() { + let sync_token = "a_very_unique_string"; + let filter_1 = "filter_1"; + let filter_1_id = "filter_1_id"; + let filter_2 = "filter_2"; + let filter_2_id = "filter_2_id"; + + let folder = TempDir::new().unwrap(); + let store = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("secret".to_owned()) + .build() + .unwrap(); + + let session = store.inner.open_tree(old_keys::SESSION).unwrap(); + let mut batch = sled::Batch::default(); + batch.insert( + StateStoreDataKey::SYNC_TOKEN.encode(), + store.serialize_value(&sync_token).unwrap(), + ); + batch.insert( + store.encode_key(keys::SESSION, 
(StateStoreDataKey::FILTER, filter_1)), + store.serialize_value(&filter_1_id).unwrap(), + ); + batch.insert( + store.encode_key(keys::SESSION, (StateStoreDataKey::FILTER, filter_2)), + store.serialize_value(&filter_2_id).unwrap(), + ); + session.apply_batch(batch).unwrap(); + + store.set_db_version(3).unwrap(); + drop(session); + drop(store); + + let store = SledStateStore::builder() + .path(folder.path().to_path_buf()) + .passphrase("secret".to_owned()) + .build() + .unwrap(); + + let stored_sync_token = store + .get_kv_data(StateStoreDataKey::SyncToken) + .await + .unwrap() + .unwrap() + .into_sync_token() + .unwrap(); + assert_eq!(stored_sync_token, sync_token); + + let stored_filter_1_id = store + .get_kv_data(StateStoreDataKey::Filter(filter_1)) + .await + .unwrap() + .unwrap() + .into_filter() + .unwrap(); + assert_eq!(stored_filter_1_id, filter_1_id); + + let stored_filter_2_id = store + .get_kv_data(StateStoreDataKey::Filter(filter_2)) + .await + .unwrap() + .unwrap() + .into_filter() + .unwrap(); + assert_eq!(stored_filter_2_id, filter_2_id); + } +} diff --git a/crates/matrix-sdk-sled/src/state_store.rs b/crates/matrix-sdk-sled/src/state_store/mod.rs similarity index 72% rename from crates/matrix-sdk-sled/src/state_store.rs rename to crates/matrix-sdk-sled/src/state_store/mod.rs index 2c3c122accd..e5d8f68cc38 100644 --- a/crates/matrix-sdk-sled/src/state_store.rs +++ b/crates/matrix-sdk-sled/src/state_store/mod.rs @@ -20,21 +20,20 @@ use std::{ }; use async_trait::async_trait; -use derive_builder::Builder; use futures_core::stream::Stream; use futures_util::stream::{self, StreamExt, TryStreamExt}; use matrix_sdk_base::{ deserialized_responses::RawMemberEvent, media::{MediaRequest, UniqueKey}, store::{Result as StoreResult, StateChanges, StateStore, StoreError}, - MinimalStateEvent, RoomInfo, + MinimalStateEvent, RoomInfo, StateStoreDataKey, StateStoreDataValue, }; use matrix_sdk_store_encryption::{Error as KeyEncryptionError, StoreCipher}; use ruma::{ 
canonical_json::redact, events::{ presence::PresenceEvent, - receipt::{Receipt, ReceiptType}, + receipt::{Receipt, ReceiptThread, ReceiptType}, room::member::{MembershipState, RoomMemberEventContent}, AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, @@ -51,6 +50,9 @@ use sled::{ use tokio::task::spawn_blocking; use tracing::{debug, info, warn}; +mod migrations; + +pub use self::migrations::MigrationConflictStrategy; #[cfg(feature = "crypto-store")] use super::OpenStoreError; use crate::encode_key::{EncodeKey, EncodeUnchecked}; @@ -79,22 +81,6 @@ pub enum SledStoreError { MigrationConflict { path: PathBuf, old_version: usize, new_version: usize }, } -/// Sometimes Migrations can't proceed without having to drop existing -/// data. This allows you to configure, how these cases should be handled. -#[derive(PartialEq, Eq, Clone, Debug)] -pub enum MigrationConflictStrategy { - /// Just drop the data, we don't care that we have to sync again - Drop, - /// Raise a `SledStoreError::MigrationConflict` error with the path to the - /// DB in question. The caller then has to take care about what they want - /// to do and try again after. - Raise, - /// _Default_: The _entire_ database is backed up under - /// `$path.$timestamp.backup` (this includes the crypto store if they - /// are linked), before the state tables are dropped. 
- BackupAndDrop, -} - impl From> for SledStoreError { fn from(e: TransactionError) -> Self { match e { @@ -115,58 +101,34 @@ impl From for StoreError { } } } -const DATABASE_VERSION: u8 = 2; - -const VERSION_KEY: &str = "state-store-version"; - -const ACCOUNT_DATA: &str = "account-data"; -const CUSTOM: &str = "custom"; -const SYNC_TOKEN: &str = "sync_token"; -const DISPLAY_NAME: &str = "display-name"; -const INVITED_USER_ID: &str = "invited-user-id"; -const JOINED_USER_ID: &str = "joined-user-id"; -const MEDIA: &str = "media"; -const MEMBER: &str = "member"; -const PRESENCE: &str = "presence"; -const PROFILE: &str = "profile"; -const ROOM_ACCOUNT_DATA: &str = "room-account-data"; -const ROOM_EVENT_RECEIPT: &str = "room-event-receipt"; -const ROOM_INFO: &str = "room-info"; -const ROOM_STATE: &str = "room-state"; -const ROOM_USER_RECEIPT: &str = "room-user-receipt"; -const ROOM: &str = "room"; -const SESSION: &str = "session"; -const STRIPPED_INVITED_USER_ID: &str = "stripped-invited-user-id"; -const STRIPPED_JOINED_USER_ID: &str = "stripped-joined-user-id"; -const STRIPPED_ROOM_INFO: &str = "stripped-room-info"; -const STRIPPED_ROOM_MEMBER: &str = "stripped-room-member"; -const STRIPPED_ROOM_STATE: &str = "stripped-room-state"; - -const ALL_DB_STORES: &[&str] = &[ - ACCOUNT_DATA, - SYNC_TOKEN, - DISPLAY_NAME, - INVITED_USER_ID, - JOINED_USER_ID, - MEDIA, - MEMBER, - PRESENCE, - PROFILE, - ROOM_ACCOUNT_DATA, - ROOM_EVENT_RECEIPT, - ROOM_INFO, - ROOM_STATE, - ROOM_USER_RECEIPT, - ROOM, - SESSION, - STRIPPED_INVITED_USER_ID, - STRIPPED_JOINED_USER_ID, - STRIPPED_ROOM_INFO, - STRIPPED_ROOM_MEMBER, - STRIPPED_ROOM_STATE, - CUSTOM, -]; -const ALL_GLOBAL_KEYS: &[&str] = &[VERSION_KEY]; + +mod keys { + // Static keys + pub const SESSION: &str = "session"; + + // Stores + pub const ACCOUNT_DATA: &str = "account-data"; + pub const CUSTOM: &str = "custom"; + pub const DISPLAY_NAME: &str = "display-name"; + pub const INVITED_USER_ID: &str = "invited-user-id"; + pub const 
JOINED_USER_ID: &str = "joined-user-id"; + pub const MEDIA: &str = "media"; + pub const MEMBER: &str = "member"; + pub const PRESENCE: &str = "presence"; + pub const PROFILE: &str = "profile"; + pub const ROOM_ACCOUNT_DATA: &str = "room-account-data"; + pub const ROOM_EVENT_RECEIPT: &str = "room-event-receipt"; + pub const ROOM_INFO: &str = "room-info"; + pub const ROOM_STATE: &str = "room-state"; + pub const ROOM_USER_RECEIPT: &str = "room-user-receipt"; + pub const ROOM: &str = "room"; + pub const STRIPPED_INVITED_USER_ID: &str = "stripped-invited-user-id"; + pub const STRIPPED_JOINED_USER_ID: &str = "stripped-joined-user-id"; + pub const STRIPPED_ROOM_INFO: &str = "stripped-room-info"; + pub const STRIPPED_ROOM_MEMBER: &str = "stripped-room-member"; + pub const STRIPPED_ROOM_STATE: &str = "stripped-room-state"; + pub const KV: &str = "kv"; +} type Result = std::result::Result; @@ -176,25 +138,27 @@ enum DbOrPath { Path(PathBuf), } -#[derive(Builder, Debug)] -#[builder(name = "SledStateStoreBuilder", build_fn(skip))] -#[allow(dead_code)] -pub struct SledStateStoreBuilderConfig { - #[builder(setter(custom))] - db_or_path: DbOrPath, - /// Set the password the sled store is encrypted with (if any) - passphrase: String, - /// The strategy to use when a merge conflict is found, see - /// [`MigrationConflictStrategy`] for details - #[builder(default = "MigrationConflictStrategy::BackupAndDrop")] +/// Builder for [`SledStateStore`]. +#[derive(Debug)] +pub struct SledStateStoreBuilder { + db_or_path: Option, + passphrase: Option, migration_conflict_strategy: MigrationConflictStrategy, } impl SledStateStoreBuilder { + fn new() -> Self { + Self { + db_or_path: None, + passphrase: None, + migration_conflict_strategy: MigrationConflictStrategy::BackupAndDrop, + } + } + /// Path to the sled store files, created if not it doesn't exist yet. /// /// Mutually exclusive with [`db`][Self::db], whichever is called last wins. 
- pub fn path(&mut self, path: PathBuf) -> &mut SledStateStoreBuilder { + pub fn path(mut self, path: PathBuf) -> Self { self.db_or_path = Some(DbOrPath::Path(path)); self } @@ -203,11 +167,25 @@ impl SledStateStoreBuilder { /// /// Mutually exclusive with [`path`][Self::path], whichever is called last /// wins. - pub fn db(&mut self, db: Db) -> &mut SledStateStoreBuilder { + pub fn db(mut self, db: Db) -> Self { self.db_or_path = Some(DbOrPath::Db(db)); self } + /// Set the password the sled store is encrypted with (if any). + pub fn passphrase(mut self, value: String) -> Self { + self.passphrase = Some(value); + self + } + + /// Set the strategy to use when a merge conflict is found. + /// + /// See [`MigrationConflictStrategy`] for details. + pub fn migration_conflict_strategy(mut self, value: MigrationConflictStrategy) -> Self { + self.migration_conflict_strategy = value; + self + } + /// Create a [`SledStateStore`] with the options set on this builder. /// /// # Errors @@ -218,7 +196,7 @@ impl SledStateStoreBuilder { /// path. /// * Migration error: The migration to a newer version of the schema /// failed, see `SledStoreError::MigrationConflict`. - pub fn build(&mut self) -> Result { + pub fn build(self) -> Result { let (db, path) = match &self.db_or_path { None => { let db = Config::new().temporary(true).open().map_err(StoreError::backend)?; @@ -253,11 +231,7 @@ impl SledStateStoreBuilder { let migration_res = store.upgrade(); if let Err(SledStoreError::MigrationConflict { path, .. }) = &migration_res { // how are supposed to react about this? 
- match self - .migration_conflict_strategy - .as_ref() - .unwrap_or(&MigrationConflictStrategy::BackupAndDrop) - { + match self.migration_conflict_strategy { MigrationConflictStrategy::BackupAndDrop => { let mut new_path = path.clone(); new_path.set_extension(format!( @@ -269,11 +243,11 @@ impl SledStateStoreBuilder { )); fs_extra::dir::create_all(&new_path, false)?; fs_extra::dir::copy(path, new_path, &fs_extra::dir::CopyOptions::new())?; - store.drop_tables()?; + store.drop_v1_tables()?; return self.build(); } MigrationConflictStrategy::Drop => { - store.drop_tables()?; + store.drop_v1_tables()?; return self.build(); } MigrationConflictStrategy::Raise => migration_res?, @@ -304,7 +278,7 @@ pub struct SledStateStore { path: Option, pub(crate) inner: Db, store_cipher: Option>, - session: Tree, + kv: Tree, account_data: Tree, members: Tree, profiles: Tree, @@ -342,38 +316,38 @@ impl SledStateStore { path: Option, store_cipher: Option>, ) -> Result { - let session = db.open_tree(SESSION)?; - let account_data = db.open_tree(ACCOUNT_DATA)?; + let kv = db.open_tree(keys::KV)?; + let account_data = db.open_tree(keys::ACCOUNT_DATA)?; - let members = db.open_tree(MEMBER)?; - let profiles = db.open_tree(PROFILE)?; - let display_names = db.open_tree(DISPLAY_NAME)?; - let joined_user_ids = db.open_tree(JOINED_USER_ID)?; - let invited_user_ids = db.open_tree(INVITED_USER_ID)?; + let members = db.open_tree(keys::MEMBER)?; + let profiles = db.open_tree(keys::PROFILE)?; + let display_names = db.open_tree(keys::DISPLAY_NAME)?; + let joined_user_ids = db.open_tree(keys::JOINED_USER_ID)?; + let invited_user_ids = db.open_tree(keys::INVITED_USER_ID)?; - let room_state = db.open_tree(ROOM_STATE)?; - let room_info = db.open_tree(ROOM_INFO)?; - let presence = db.open_tree(PRESENCE)?; - let room_account_data = db.open_tree(ROOM_ACCOUNT_DATA)?; + let room_state = db.open_tree(keys::ROOM_STATE)?; + let room_info = db.open_tree(keys::ROOM_INFO)?; + let presence = 
db.open_tree(keys::PRESENCE)?; + let room_account_data = db.open_tree(keys::ROOM_ACCOUNT_DATA)?; - let stripped_joined_user_ids = db.open_tree(STRIPPED_JOINED_USER_ID)?; - let stripped_invited_user_ids = db.open_tree(STRIPPED_INVITED_USER_ID)?; - let stripped_room_infos = db.open_tree(STRIPPED_ROOM_INFO)?; - let stripped_members = db.open_tree(STRIPPED_ROOM_MEMBER)?; - let stripped_room_state = db.open_tree(STRIPPED_ROOM_STATE)?; + let stripped_joined_user_ids = db.open_tree(keys::STRIPPED_JOINED_USER_ID)?; + let stripped_invited_user_ids = db.open_tree(keys::STRIPPED_INVITED_USER_ID)?; + let stripped_room_infos = db.open_tree(keys::STRIPPED_ROOM_INFO)?; + let stripped_members = db.open_tree(keys::STRIPPED_ROOM_MEMBER)?; + let stripped_room_state = db.open_tree(keys::STRIPPED_ROOM_STATE)?; - let room_user_receipts = db.open_tree(ROOM_USER_RECEIPT)?; - let room_event_receipts = db.open_tree(ROOM_EVENT_RECEIPT)?; + let room_user_receipts = db.open_tree(keys::ROOM_USER_RECEIPT)?; + let room_event_receipts = db.open_tree(keys::ROOM_EVENT_RECEIPT)?; - let media = db.open_tree(MEDIA)?; + let media = db.open_tree(keys::MEDIA)?; - let custom = db.open_tree(CUSTOM)?; + let custom = db.open_tree(keys::CUSTOM)?; Ok(Self { path, inner: db, store_cipher, - session, + kv, account_data, members, profiles, @@ -396,70 +370,9 @@ impl SledStateStore { }) } - /// Generate a SledStateStoreBuilder with default parameters + /// Create a [`SledStateStoreBuilder`] with default parameters. 
pub fn builder() -> SledStateStoreBuilder { - SledStateStoreBuilder::default() - } - - fn drop_tables(self) -> StoreResult<()> { - for name in ALL_DB_STORES { - self.inner.drop_tree(name).map_err(StoreError::backend)?; - } - for name in ALL_GLOBAL_KEYS { - self.inner.remove(name).map_err(StoreError::backend)?; - } - - Ok(()) - } - - fn set_db_version(&self, version: u8) -> Result<()> { - self.inner.insert(VERSION_KEY, version.to_be_bytes().as_ref())?; - self.inner.flush()?; - Ok(()) - } - - fn upgrade(&mut self) -> Result<()> { - let db_version = self.inner.get(VERSION_KEY)?.map(|v| { - let (version_bytes, _) = v.split_at(std::mem::size_of::()); - u8::from_be_bytes(version_bytes.try_into().unwrap_or_default()) - }); - - let old_version = match db_version { - None => { - // we are fresh, let's write the current version - return self.set_db_version(DATABASE_VERSION); - } - Some(version) if version == DATABASE_VERSION => { - // current, we don't have to do anything - return Ok(()); - } - Some(version) => version, - }; - - debug!(old_version, new_version = DATABASE_VERSION, "Upgrading the Sled state store"); - - if old_version == 1 { - if self.store_cipher.is_some() { - // we stored some fields un-encrypted. Drop them to force re-creation - return Err(SledStoreError::MigrationConflict { - path: self.path.take().expect("Path must exist for a migration to fail"), - old_version: old_version.into(), - new_version: DATABASE_VERSION.into(), - }); - } - // no migration to handle - self.set_db_version(2u8)?; - return Ok(()); - } - - // FUTURE UPGRADE CODE GOES HERE - - // can't upgrade from that version to the new one - Err(SledStoreError::MigrationConflict { - path: self.path.take().expect("Path must exist for a migration to fail"), - old_version: old_version.into(), - new_version: DATABASE_VERSION.into(), - }) + SledStateStoreBuilder::new() } /// Open a `SledCryptoStore` that uses the same database as this store. 
@@ -496,23 +409,61 @@ impl SledStateStore { } } - pub async fn save_filter(&self, filter_name: &str, filter_id: &str) -> Result<()> { - self.session.insert( - self.encode_key(SESSION, ("filter", filter_name)), - self.serialize_value(&filter_id)?, - )?; - Ok(()) + fn encode_kv_data_key(&self, key: StateStoreDataKey<'_>) -> Vec { + match key { + StateStoreDataKey::SyncToken => StateStoreDataKey::SYNC_TOKEN.encode(), + StateStoreDataKey::Filter(filter_name) => { + self.encode_key(keys::SESSION, (StateStoreDataKey::FILTER, filter_name)) + } + StateStoreDataKey::UserAvatarUrl(user_id) => { + self.encode_key(keys::SESSION, (StateStoreDataKey::USER_AVATAR_URL, user_id)) + } + } + } + + async fn get_kv_data(&self, key: StateStoreDataKey<'_>) -> Result> { + let encoded_key = self.encode_kv_data_key(key); + + let value = + self.kv.get(encoded_key)?.map(|e| self.deserialize_value::(&e)).transpose()?; + + let value = match key { + StateStoreDataKey::SyncToken => value.map(StateStoreDataValue::SyncToken), + StateStoreDataKey::Filter(_) => value.map(StateStoreDataValue::Filter), + StateStoreDataKey::UserAvatarUrl(_) => value.map(StateStoreDataValue::UserAvatarUrl), + }; + + Ok(value) } - pub async fn get_filter(&self, filter_name: &str) -> Result> { - self.session - .get(self.encode_key(SESSION, ("filter", filter_name)))? 
- .map(|f| self.deserialize_value(&f)) - .transpose() + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> Result<()> { + let encoded_key = self.encode_kv_data_key(key); + + let value = match key { + StateStoreDataKey::SyncToken => { + value.into_sync_token().expect("Session data not a sync token") + } + StateStoreDataKey::Filter(_) => value.into_filter().expect("Session data not a filter"), + StateStoreDataKey::UserAvatarUrl(_) => { + value.into_user_avatar_url().expect("Session data not an user avatar url") + } + }; + + self.kv.insert(encoded_key, self.serialize_value(&value)?)?; + + Ok(()) } - pub async fn get_sync_token(&self) -> Result> { - self.session.get(SYNC_TOKEN.encode())?.map(|t| self.deserialize_value(&t)).transpose() + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<()> { + let encoded_key = self.encode_kv_data_key(key); + + self.kv.remove(encoded_key)?; + + Ok(()) } pub async fn save_changes(&self, changes: &StateChanges) -> Result<()> { @@ -567,45 +518,46 @@ impl SledStateStore { let key = (room, event.state_key()); stripped_joined - .remove(self.encode_key(STRIPPED_JOINED_USER_ID, key))?; + .remove(self.encode_key(keys::STRIPPED_JOINED_USER_ID, key))?; stripped_invited - .remove(self.encode_key(STRIPPED_INVITED_USER_ID, key))?; + .remove(self.encode_key(keys::STRIPPED_INVITED_USER_ID, key))?; match event.membership() { MembershipState::Join => { joined.insert( - self.encode_key(JOINED_USER_ID, key), + self.encode_key(keys::JOINED_USER_ID, key), self.serialize_value(event.state_key()) .map_err(ConflictableTransactionError::Abort)?, )?; - invited.remove(self.encode_key(INVITED_USER_ID, key))?; + invited.remove(self.encode_key(keys::INVITED_USER_ID, key))?; } MembershipState::Invite => { invited.insert( - self.encode_key(INVITED_USER_ID, key), + self.encode_key(keys::INVITED_USER_ID, key), self.serialize_value(event.state_key()) .map_err(ConflictableTransactionError::Abort)?, )?; - 
joined.remove(self.encode_key(JOINED_USER_ID, key))?; + joined.remove(self.encode_key(keys::JOINED_USER_ID, key))?; } _ => { - joined.remove(self.encode_key(JOINED_USER_ID, key))?; - invited.remove(self.encode_key(INVITED_USER_ID, key))?; + joined.remove(self.encode_key(keys::JOINED_USER_ID, key))?; + invited.remove(self.encode_key(keys::INVITED_USER_ID, key))?; } } members.insert( - self.encode_key(MEMBER, key), + self.encode_key(keys::MEMBER, key), self.serialize_value(&raw_event) .map_err(ConflictableTransactionError::Abort)?, )?; - stripped_members.remove(self.encode_key(STRIPPED_ROOM_MEMBER, key))?; + stripped_members + .remove(self.encode_key(keys::STRIPPED_ROOM_MEMBER, key))?; if let Some(profile) = profile_changes.and_then(|p| p.get(event.state_key())) { profiles.insert( - self.encode_key(PROFILE, key), + self.encode_key(keys::PROFILE, key), self.serialize_value(&profile) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -616,7 +568,7 @@ impl SledStateStore { for (room_id, ambiguity_maps) in &changes.ambiguity_maps { for (display_name, map) in ambiguity_maps { display_names.insert( - self.encode_key(DISPLAY_NAME, (room_id, display_name)), + self.encode_key(keys::DISPLAY_NAME, (room_id, display_name)), self.serialize_value(&map) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -626,7 +578,7 @@ impl SledStateStore { for (room, events) in &changes.room_account_data { for (event_type, event) in events { room_account_data.insert( - self.encode_key(ROOM_ACCOUNT_DATA, (room, event_type)), + self.encode_key(keys::ROOM_ACCOUNT_DATA, (room, event_type)), self.serialize_value(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -637,12 +589,15 @@ impl SledStateStore { for (event_type, events) in event_types { for (state_key, event) in events { state.insert( - self.encode_key(ROOM_STATE, (room, event_type, state_key)), + self.encode_key( + keys::ROOM_STATE, + (room, event_type, state_key), + ), self.serialize_value(&event) 
.map_err(ConflictableTransactionError::Abort)?, )?; stripped_state.remove(self.encode_key( - STRIPPED_ROOM_STATE, + keys::STRIPPED_ROOM_STATE, (room, event_type, state_key), ))?; } @@ -651,20 +606,21 @@ impl SledStateStore { for (room_id, room_info) in &changes.room_infos { rooms.insert( - self.encode_key(ROOM, room_id), + self.encode_key(keys::ROOM, room_id), self.serialize_value(room_info) .map_err(ConflictableTransactionError::Abort)?, )?; - stripped_rooms.remove(self.encode_key(STRIPPED_ROOM_INFO, room_id))?; + stripped_rooms + .remove(self.encode_key(keys::STRIPPED_ROOM_INFO, room_id))?; } for (room_id, info) in &changes.stripped_room_infos { stripped_rooms.insert( - self.encode_key(STRIPPED_ROOM_INFO, room_id), + self.encode_key(keys::STRIPPED_ROOM_INFO, room_id), self.serialize_value(&info) .map_err(ConflictableTransactionError::Abort)?, )?; - rooms.remove(self.encode_key(ROOM, room_id))?; + rooms.remove(self.encode_key(keys::ROOM, room_id))?; } for (room, raw_events) in &changes.stripped_members { @@ -687,31 +643,35 @@ impl SledStateStore { match event.content.membership { MembershipState::Join => { stripped_joined.insert( - self.encode_key(STRIPPED_JOINED_USER_ID, key), + self.encode_key(keys::STRIPPED_JOINED_USER_ID, key), self.serialize_value(&event.state_key) .map_err(ConflictableTransactionError::Abort)?, )?; - stripped_invited - .remove(self.encode_key(STRIPPED_INVITED_USER_ID, key))?; + stripped_invited.remove( + self.encode_key(keys::STRIPPED_INVITED_USER_ID, key), + )?; } MembershipState::Invite => { stripped_invited.insert( - self.encode_key(STRIPPED_INVITED_USER_ID, key), + self.encode_key(keys::STRIPPED_INVITED_USER_ID, key), self.serialize_value(&event.state_key) .map_err(ConflictableTransactionError::Abort)?, )?; - stripped_joined - .remove(self.encode_key(STRIPPED_JOINED_USER_ID, key))?; + stripped_joined.remove( + self.encode_key(keys::STRIPPED_JOINED_USER_ID, key), + )?; } _ => { - stripped_joined - 
.remove(self.encode_key(STRIPPED_JOINED_USER_ID, key))?; - stripped_invited - .remove(self.encode_key(STRIPPED_INVITED_USER_ID, key))?; + stripped_joined.remove( + self.encode_key(keys::STRIPPED_JOINED_USER_ID, key), + )?; + stripped_invited.remove( + self.encode_key(keys::STRIPPED_INVITED_USER_ID, key), + )?; } } stripped_members.insert( - self.encode_key(STRIPPED_ROOM_MEMBER, key), + self.encode_key(keys::STRIPPED_ROOM_MEMBER, key), self.serialize_value(&raw_event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -723,7 +683,7 @@ impl SledStateStore { for (state_key, event) in events { stripped_state.insert( self.encode_key( - STRIPPED_ROOM_STATE, + keys::STRIPPED_ROOM_STATE, (room, event_type.to_string(), state_key), ), self.serialize_value(&event) @@ -745,7 +705,7 @@ impl SledStateStore { let make_room_version = |room_id| { self.room_info - .get(self.encode_key(ROOM_INFO, room_id)) + .get(self.encode_key(keys::ROOM, room_id)) .ok() .flatten() .map(|r| self.deserialize_value::(&r)) @@ -760,7 +720,7 @@ impl SledStateStore { }; for (room_id, redactions) in &changes.redactions { - let key_prefix = self.encode_key(ROOM_STATE, room_id); + let key_prefix = self.encode_key(keys::ROOM_STATE, room_id); let mut room_version = None; // iterate through all saved state events and check whether they are among the @@ -796,11 +756,18 @@ impl SledStateStore { for (receipt_type, receipts) in receipts { for (user_id, receipt) in receipts { // Add the receipt to the room user receipts - if let Some(old) = room_user_receipts.insert( - self.encode_key( - ROOM_USER_RECEIPT, + let key = match receipt.thread.as_str() { + Some(thread_id) => self.encode_key( + keys::ROOM_USER_RECEIPT, + (room, receipt_type, thread_id, user_id), + ), + None => self.encode_key( + keys::ROOM_USER_RECEIPT, (room, receipt_type, user_id), ), + }; + if let Some(old) = room_user_receipts.insert( + key, self.serialize_value(&(event_id, receipt)) .map_err(ConflictableTransactionError::Abort)?, )? 
{ @@ -808,18 +775,32 @@ impl SledStateStore { let (old_event, _): (OwnedEventId, Receipt) = self .deserialize_value(&old) .map_err(ConflictableTransactionError::Abort)?; - room_event_receipts.remove(self.encode_key( - ROOM_EVENT_RECEIPT, - (room, receipt_type, old_event, user_id), - ))?; + let key = match receipt.thread.as_str() { + Some(thread_id) => self.encode_key( + keys::ROOM_EVENT_RECEIPT, + (room, receipt_type, thread_id, old_event, user_id), + ), + None => self.encode_key( + keys::ROOM_EVENT_RECEIPT, + (room, receipt_type, old_event, user_id), + ), + }; + room_event_receipts.remove(key)?; } // Add the receipt to the room event receipts - room_event_receipts.insert( - self.encode_key( - ROOM_EVENT_RECEIPT, + let key = match receipt.thread.as_str() { + Some(thread_id) => self.encode_key( + keys::ROOM_EVENT_RECEIPT, + (room, receipt_type, thread_id, event_id, user_id), + ), + None => self.encode_key( + keys::ROOM_EVENT_RECEIPT, (room, receipt_type, event_id, user_id), ), + }; + room_event_receipts.insert( + key, self.serialize_value(&(user_id, receipt)) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -830,7 +811,7 @@ impl SledStateStore { for (sender, event) in &changes.presence { presence.insert( - self.encode_key(PRESENCE, sender), + self.encode_key(keys::PRESENCE, sender), self.serialize_value(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -843,18 +824,18 @@ impl SledStateStore { ret?; // user state - let ret: Result<(), TransactionError> = (&self.session, &self.account_data) - .transaction(|(session, account_data)| { + let ret: Result<(), TransactionError> = (&self.kv, &self.account_data) + .transaction(|(kv, account_data)| { if let Some(s) = &changes.sync_token { - session.insert( - SYNC_TOKEN.encode(), + kv.insert( + self.encode_kv_data_key(StateStoreDataKey::SyncToken), self.serialize_value(s).map_err(ConflictableTransactionError::Abort)?, )?; } for (event_type, event) in &changes.account_data { account_data.insert( - 
self.encode_key(ACCOUNT_DATA, event_type), + self.encode_key(keys::ACCOUNT_DATA, event_type), self.serialize_value(&event) .map_err(ConflictableTransactionError::Abort)?, )?; @@ -874,7 +855,7 @@ impl SledStateStore { pub async fn get_presence_event(&self, user_id: &UserId) -> Result>> { let db = self.clone(); - let key = self.encode_key(PRESENCE, user_id); + let key = self.encode_key(keys::PRESENCE, user_id); spawn_blocking(move || db.presence.get(key)?.map(|e| db.deserialize_value(&e)).transpose()) .await? } @@ -886,7 +867,7 @@ impl SledStateStore { state_key: &str, ) -> Result>> { let db = self.clone(); - let key = self.encode_key(ROOM_STATE, (room_id, event_type.to_string(), state_key)); + let key = self.encode_key(keys::ROOM_STATE, (room_id, event_type.to_string(), state_key)); spawn_blocking(move || { db.room_state.get(key)?.map(|e| db.deserialize_value(&e)).transpose() }) @@ -899,7 +880,7 @@ impl SledStateStore { event_type: StateEventType, ) -> Result>> { let db = self.clone(); - let key = self.encode_key(ROOM_STATE, (room_id, event_type.to_string())); + let key = self.encode_key(keys::ROOM_STATE, (room_id, event_type.to_string())); spawn_blocking(move || { db.room_state .scan_prefix(key) @@ -915,7 +896,7 @@ impl SledStateStore { user_id: &UserId, ) -> Result>> { let db = self.clone(); - let key = self.encode_key(PROFILE, (room_id, user_id)); + let key = self.encode_key(keys::PROFILE, (room_id, user_id)); spawn_blocking(move || db.profiles.get(key)?.map(|p| db.deserialize_value(&p)).transpose()) .await? 
} @@ -926,8 +907,8 @@ impl SledStateStore { state_key: &UserId, ) -> Result> { let db = self.clone(); - let key = self.encode_key(MEMBER, (room_id, state_key)); - let stripped_key = self.encode_key(STRIPPED_ROOM_MEMBER, (room_id, state_key)); + let key = self.encode_key(keys::MEMBER, (room_id, state_key)); + let stripped_key = self.encode_key(keys::STRIPPED_ROOM_MEMBER, (room_id, state_key)); spawn_blocking(move || { if let Some(e) = db .stripped_members @@ -971,7 +952,7 @@ impl SledStateStore { room_id: &RoomId, ) -> StoreResult>> { let db = self.clone(); - let key = self.encode_key(INVITED_USER_ID, room_id); + let key = self.encode_key(keys::INVITED_USER_ID, room_id); spawn_blocking(move || { stream::iter(db.invited_user_ids.scan_prefix(key).map(move |u| { db.deserialize_value(&u.map_err(StoreError::backend)?.1) @@ -987,7 +968,7 @@ impl SledStateStore { room_id: &RoomId, ) -> StoreResult>> { let db = self.clone(); - let key = self.encode_key(JOINED_USER_ID, room_id); + let key = self.encode_key(keys::JOINED_USER_ID, room_id); spawn_blocking(move || { stream::iter(db.joined_user_ids.scan_prefix(key).map(move |u| { db.deserialize_value(&u.map_err(StoreError::backend)?.1) @@ -1003,7 +984,7 @@ impl SledStateStore { room_id: &RoomId, ) -> StoreResult>> { let db = self.clone(); - let key = self.encode_key(STRIPPED_INVITED_USER_ID, room_id); + let key = self.encode_key(keys::STRIPPED_INVITED_USER_ID, room_id); spawn_blocking(move || { stream::iter(db.stripped_invited_user_ids.scan_prefix(key).map(move |u| { db.deserialize_value(&u.map_err(StoreError::backend)?.1) @@ -1019,7 +1000,7 @@ impl SledStateStore { room_id: &RoomId, ) -> StoreResult>> { let db = self.clone(); - let key = self.encode_key(STRIPPED_JOINED_USER_ID, room_id); + let key = self.encode_key(keys::STRIPPED_JOINED_USER_ID, room_id); spawn_blocking(move || { stream::iter(db.stripped_joined_user_ids.scan_prefix(key).map(move |u| { db.deserialize_value(&u.map_err(StoreError::backend)?.1) @@ -1054,7 +1035,7 @@ 
impl SledStateStore { display_name: &str, ) -> Result> { let db = self.clone(); - let key = self.encode_key(DISPLAY_NAME, (room_id, display_name)); + let key = self.encode_key(keys::DISPLAY_NAME, (room_id, display_name)); spawn_blocking(move || { Ok(db .display_names @@ -1071,7 +1052,7 @@ impl SledStateStore { event_type: GlobalAccountDataEventType, ) -> Result>> { let db = self.clone(); - let key = self.encode_key(ACCOUNT_DATA, event_type); + let key = self.encode_key(keys::ACCOUNT_DATA, event_type); spawn_blocking(move || { db.account_data.get(key)?.map(|m| db.deserialize_value(&m)).transpose() }) @@ -1084,7 +1065,7 @@ impl SledStateStore { event_type: RoomAccountDataEventType, ) -> Result>> { let db = self.clone(); - let key = self.encode_key(ROOM_ACCOUNT_DATA, (room_id, event_type)); + let key = self.encode_key(keys::ROOM_ACCOUNT_DATA, (room_id, event_type)); spawn_blocking(move || { db.room_account_data.get(key)?.map(|m| db.deserialize_value(&m)).transpose() }) @@ -1095,10 +1076,15 @@ impl SledStateStore { &self, room_id: &RoomId, receipt_type: ReceiptType, + thread: ReceiptThread, user_id: &UserId, ) -> Result> { let db = self.clone(); - let key = self.encode_key(ROOM_USER_RECEIPT, (room_id, receipt_type, user_id)); + let key = match thread.as_str() { + Some(thread_id) => self + .encode_key(keys::ROOM_USER_RECEIPT, (room_id, receipt_type, thread_id, user_id)), + None => self.encode_key(keys::ROOM_USER_RECEIPT, (room_id, receipt_type, user_id)), + }; spawn_blocking(move || { db.room_user_receipts.get(key)?.map(|m| db.deserialize_value(&m)).transpose() }) @@ -1109,10 +1095,15 @@ impl SledStateStore { &self, room_id: &RoomId, receipt_type: ReceiptType, + thread: ReceiptThread, event_id: &EventId, ) -> StoreResult> { let db = self.clone(); - let key = self.encode_key(ROOM_EVENT_RECEIPT, (room_id, receipt_type, event_id)); + let key = match thread.as_str() { + Some(thread_id) => self + .encode_key(keys::ROOM_EVENT_RECEIPT, (room_id, receipt_type, thread_id, 
event_id)), + None => self.encode_key(keys::ROOM_EVENT_RECEIPT, (room_id, receipt_type, event_id)), + }; spawn_blocking(move || { db.room_event_receipts .scan_prefix(key) @@ -1129,7 +1120,10 @@ impl SledStateStore { async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> Result<()> { self.media.insert( - self.encode_key(MEDIA, (request.source.unique_key(), request.format.unique_key())), + self.encode_key( + keys::MEDIA, + (request.source.unique_key(), request.format.unique_key()), + ), self.serialize_value(&data)?, )?; @@ -1140,8 +1134,8 @@ impl SledStateStore { async fn get_media_content(&self, request: &MediaRequest) -> Result>> { let db = self.clone(); - let key = - self.encode_key(MEDIA, (request.source.unique_key(), request.format.unique_key())); + let key = self + .encode_key(keys::MEDIA, (request.source.unique_key(), request.format.unique_key())); spawn_blocking(move || { db.media.get(key)?.map(move |m| db.deserialize_value(&m)).transpose() @@ -1152,13 +1146,13 @@ impl SledStateStore { async fn get_custom_value(&self, key: &[u8]) -> Result>> { let custom = self.custom.clone(); let me = self.clone(); - let key = self.encode_key(CUSTOM, EncodeUnchecked::from(key)); + let key = self.encode_key(keys::CUSTOM, EncodeUnchecked::from(key)); spawn_blocking(move || custom.get(key)?.map(move |v| me.deserialize_value(&v)).transpose()) .await? 
} async fn set_custom_value(&self, key: &[u8], value: Vec) -> Result>> { - let key = self.encode_key(CUSTOM, EncodeUnchecked::from(key)); + let key = self.encode_key(keys::CUSTOM, EncodeUnchecked::from(key)); let me = self.clone(); let ret = self .custom @@ -1170,16 +1164,28 @@ impl SledStateStore { ret } + async fn remove_custom_value(&self, key: &[u8]) -> Result>> { + let key = self.encode_key(keys::CUSTOM, EncodeUnchecked::from(key)); + let me = self.clone(); + let ret = self.custom.remove(key)?.map(|v| me.deserialize_value(&v)).transpose(); + self.inner.flush_async().await?; + + ret + } + async fn remove_media_content(&self, request: &MediaRequest) -> Result<()> { self.media.remove( - self.encode_key(MEDIA, (request.source.unique_key(), request.format.unique_key())), + self.encode_key( + keys::MEDIA, + (request.source.unique_key(), request.format.unique_key()), + ), )?; Ok(()) } async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> { - let keys = self.media.scan_prefix(self.encode_key(MEDIA, uri)).keys(); + let keys = self.media.scan_prefix(self.encode_key(keys::MEDIA, uri)).keys(); let mut batch = sled::Batch::default(); for key in keys { @@ -1191,29 +1197,34 @@ impl SledStateStore { async fn remove_room(&self, room_id: &RoomId) -> Result<()> { let mut members_batch = sled::Batch::default(); - for key in self.members.scan_prefix(self.encode_key(MEMBER, room_id)).keys() { + for key in self.members.scan_prefix(self.encode_key(keys::MEMBER, room_id)).keys() { members_batch.remove(key?); } let mut stripped_members_batch = sled::Batch::default(); - for key in - self.stripped_members.scan_prefix(self.encode_key(STRIPPED_ROOM_MEMBER, room_id)).keys() + for key in self + .stripped_members + .scan_prefix(self.encode_key(keys::STRIPPED_ROOM_MEMBER, room_id)) + .keys() { stripped_members_batch.remove(key?); } let mut profiles_batch = sled::Batch::default(); - for key in self.profiles.scan_prefix(self.encode_key(PROFILE, room_id)).keys() { + for key in 
self.profiles.scan_prefix(self.encode_key(keys::PROFILE, room_id)).keys() { profiles_batch.remove(key?); } let mut display_names_batch = sled::Batch::default(); - for key in self.display_names.scan_prefix(self.encode_key(DISPLAY_NAME, room_id)).keys() { + for key in + self.display_names.scan_prefix(self.encode_key(keys::DISPLAY_NAME, room_id)).keys() + { display_names_batch.remove(key?); } let mut joined_user_ids_batch = sled::Batch::default(); - for key in self.joined_user_ids.scan_prefix(self.encode_key(JOINED_USER_ID, room_id)).keys() + for key in + self.joined_user_ids.scan_prefix(self.encode_key(keys::JOINED_USER_ID, room_id)).keys() { joined_user_ids_batch.remove(key?); } @@ -1221,15 +1232,17 @@ impl SledStateStore { let mut stripped_joined_user_ids_batch = sled::Batch::default(); for key in self .stripped_joined_user_ids - .scan_prefix(self.encode_key(STRIPPED_JOINED_USER_ID, room_id)) + .scan_prefix(self.encode_key(keys::STRIPPED_JOINED_USER_ID, room_id)) .keys() { stripped_joined_user_ids_batch.remove(key?); } let mut invited_user_ids_batch = sled::Batch::default(); - for key in - self.invited_user_ids.scan_prefix(self.encode_key(INVITED_USER_ID, room_id)).keys() + for key in self + .invited_user_ids + .scan_prefix(self.encode_key(keys::INVITED_USER_ID, room_id)) + .keys() { invited_user_ids_batch.remove(key?); } @@ -1237,29 +1250,31 @@ impl SledStateStore { let mut stripped_invited_user_ids_batch = sled::Batch::default(); for key in self .stripped_invited_user_ids - .scan_prefix(self.encode_key(STRIPPED_INVITED_USER_ID, room_id)) + .scan_prefix(self.encode_key(keys::STRIPPED_INVITED_USER_ID, room_id)) .keys() { stripped_invited_user_ids_batch.remove(key?); } let mut room_state_batch = sled::Batch::default(); - for key in self.room_state.scan_prefix(self.encode_key(ROOM_STATE, room_id)).keys() { + for key in self.room_state.scan_prefix(self.encode_key(keys::ROOM_STATE, room_id)).keys() { room_state_batch.remove(key?); } let mut stripped_room_state_batch = 
sled::Batch::default(); for key in self .stripped_room_state - .scan_prefix(self.encode_key(STRIPPED_ROOM_STATE, room_id)) + .scan_prefix(self.encode_key(keys::STRIPPED_ROOM_STATE, room_id)) .keys() { stripped_room_state_batch.remove(key?); } let mut room_account_data_batch = sled::Batch::default(); - for key in - self.room_account_data.scan_prefix(self.encode_key(ROOM_ACCOUNT_DATA, room_id)).keys() + for key in self + .room_account_data + .scan_prefix(self.encode_key(keys::ROOM_ACCOUNT_DATA, room_id)) + .keys() { room_account_data_batch.remove(key?); } @@ -1295,8 +1310,8 @@ impl SledStateStore { stripped_state, room_account_data, )| { - rooms.remove(self.encode_key(ROOM, room_id))?; - stripped_rooms.remove(self.encode_key(STRIPPED_ROOM_INFO, room_id))?; + rooms.remove(self.encode_key(keys::ROOM, room_id))?; + stripped_rooms.remove(self.encode_key(keys::STRIPPED_ROOM_INFO, room_id))?; members.apply_batch(&members_batch)?; stripped_members.apply_batch(&stripped_members_batch)?; @@ -1316,8 +1331,10 @@ impl SledStateStore { ret?; let mut room_user_receipts_batch = sled::Batch::default(); - for key in - self.room_user_receipts.scan_prefix(self.encode_key(ROOM_USER_RECEIPT, room_id)).keys() + for key in self + .room_user_receipts + .scan_prefix(self.encode_key(keys::ROOM_USER_RECEIPT, room_id)) + .keys() { room_user_receipts_batch.remove(key?); } @@ -1325,7 +1342,7 @@ impl SledStateStore { let mut room_event_receipts_batch = sled::Batch::default(); for key in self .room_event_receipts - .scan_prefix(self.encode_key(ROOM_EVENT_RECEIPT, room_id)) + .scan_prefix(self.encode_key(keys::ROOM_EVENT_RECEIPT, room_id)) .keys() { room_event_receipts_batch.remove(key?); @@ -1349,20 +1366,29 @@ impl SledStateStore { #[async_trait] impl StateStore for SledStateStore { - async fn save_filter(&self, filter_name: &str, filter_id: &str) -> StoreResult<()> { - self.save_filter(filter_name, filter_id).await.map_err(Into::into) + type Error = StoreError; + + async fn get_kv_data( + &self, 
+ key: StateStoreDataKey<'_>, + ) -> StoreResult> { + self.get_kv_data(key).await.map_err(Into::into) } - async fn save_changes(&self, changes: &StateChanges) -> StoreResult<()> { - self.save_changes(changes).await.map_err(Into::into) + async fn set_kv_data( + &self, + key: StateStoreDataKey<'_>, + value: StateStoreDataValue, + ) -> StoreResult<()> { + self.set_kv_data(key, value).await.map_err(Into::into) } - async fn get_filter(&self, filter_id: &str) -> StoreResult> { - self.get_filter(filter_id).await.map_err(Into::into) + async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> StoreResult<()> { + self.remove_kv_data(key).await.map_err(Into::into) } - async fn get_sync_token(&self) -> StoreResult> { - self.get_sync_token().await.map_err(Into::into) + async fn save_changes(&self, changes: &StateChanges) -> StoreResult<()> { + self.save_changes(changes).await.map_err(Into::into) } async fn get_presence_event( @@ -1477,18 +1503,22 @@ impl StateStore for SledStateStore { &self, room_id: &RoomId, receipt_type: ReceiptType, + thread: ReceiptThread, user_id: &UserId, ) -> StoreResult> { - self.get_user_room_receipt_event(room_id, receipt_type, user_id).await.map_err(Into::into) + self.get_user_room_receipt_event(room_id, receipt_type, thread, user_id) + .await + .map_err(Into::into) } async fn get_event_room_receipt_events( &self, room_id: &RoomId, receipt_type: ReceiptType, + thread: ReceiptThread, event_id: &EventId, ) -> StoreResult> { - self.get_event_room_receipt_events(room_id, receipt_type, event_id) + self.get_event_room_receipt_events(room_id, receipt_type, thread, event_id) .await .map_err(Into::into) } @@ -1501,6 +1531,10 @@ impl StateStore for SledStateStore { self.set_custom_value(key, value).await.map_err(Into::into) } + async fn remove_custom_value(&self, key: &[u8]) -> StoreResult>> { + self.remove_custom_value(key).await.map_err(Into::into) + } + async fn add_media_content(&self, request: &MediaRequest, data: Vec) -> StoreResult<()> { 
self.add_media_content(request, data).await.map_err(Into::into) } @@ -1547,95 +1581,3 @@ mod encrypted_tests { statestore_integration_tests!(with_media_tests); } - -#[cfg(test)] -mod migration { - use matrix_sdk_test::async_test; - use tempfile::TempDir; - - use super::{MigrationConflictStrategy, Result, SledStateStore, SledStoreError}; - - #[async_test] - pub async fn migrating_v1_to_2_plain() -> Result<()> { - let folder = TempDir::new()?; - - let store = SledStateStore::builder().path(folder.path().to_path_buf()).build()?; - - store.set_db_version(1u8)?; - drop(store); - - // this transparently migrates to the latest version - let _store = SledStateStore::builder().path(folder.path().to_path_buf()).build()?; - Ok(()) - } - - #[async_test] - pub async fn migrating_v1_to_2_with_pw_backed_up() -> Result<()> { - let folder = TempDir::new()?; - - let store = SledStateStore::builder() - .path(folder.path().to_path_buf()) - .passphrase("something".to_owned()) - .build()?; - - store.set_db_version(1u8)?; - drop(store); - - // this transparently creates a backup and a fresh db - let _store = SledStateStore::builder() - .path(folder.path().to_path_buf()) - .passphrase("something".to_owned()) - .build()?; - assert_eq!(std::fs::read_dir(folder.path())?.count(), 2); - Ok(()) - } - - #[async_test] - pub async fn migrating_v1_to_2_with_pw_drop() -> Result<()> { - let folder = TempDir::new()?; - - let store = SledStateStore::builder() - .path(folder.path().to_path_buf()) - .passphrase("other thing".to_owned()) - .build()?; - - store.set_db_version(1u8)?; - drop(store); - - // this transparently creates a backup and a fresh db - let _store = SledStateStore::builder() - .path(folder.path().to_path_buf()) - .passphrase("other thing".to_owned()) - .migration_conflict_strategy(MigrationConflictStrategy::Drop) - .build()?; - assert_eq!(std::fs::read_dir(folder.path())?.count(), 1); - Ok(()) - } - - #[async_test] - pub async fn migrating_v1_to_2_with_pw_raises() -> Result<()> { - let 
folder = TempDir::new()?; - - let store = SledStateStore::builder() - .path(folder.path().to_path_buf()) - .passphrase("secret".to_owned()) - .build()?; - - store.set_db_version(1u8)?; - drop(store); - - // this transparently creates a backup and a fresh db - let res = SledStateStore::builder() - .path(folder.path().to_path_buf()) - .passphrase("secret".to_owned()) - .migration_conflict_strategy(MigrationConflictStrategy::Raise) - .build(); - if let Err(SledStoreError::MigrationConflict { .. }) = res { - // all good - } else { - panic!("Didn't raise the expected error: {res:?}"); - } - assert_eq!(std::fs::read_dir(folder.path())?.count(), 1); - Ok(()) - } -} diff --git a/crates/matrix-sdk-sqlite/Cargo.toml b/crates/matrix-sdk-sqlite/Cargo.toml index 15a7c80b90c..576f5a9e338 100644 --- a/crates/matrix-sdk-sqlite/Cargo.toml +++ b/crates/matrix-sdk-sqlite/Cargo.toml @@ -23,7 +23,7 @@ dashmap = { workspace = true } deadpool-sqlite = "0.5.0" fs_extra = "1.2.0" futures-core = "0.3.21" -futures-util = { version = "0.3.21", default-features = false } +futures-util = { workspace = true } matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base", optional = true } matrix-sdk-common = { version = "0.6.0", path = "../matrix-sdk-common" } matrix-sdk-crypto = { version = "0.6.0", path = "../matrix-sdk-crypto", optional = true } @@ -32,13 +32,10 @@ rmp-serde = "1.1.1" ruma = { workspace = true } rusqlite = { version = "0.28.0", features = ["bundled"] } serde = { workspace = true } -serde_json = { workspace = true } thiserror = { workspace = true } -tokio = { version = "1.23.1", default-features = false, features = [ - "sync", - "fs", -] } +tokio = { version = "1.24.2", default-features = false, features = ["sync", "fs"] } tracing = { workspace = true } +vodozemac = { workspace = true } [dev-dependencies] ctor = { workspace = true } @@ -48,8 +45,5 @@ matrix-sdk-crypto = { path = "../matrix-sdk-crypto", features = ["testing"] } matrix-sdk-test = { path = 
"../../testing/matrix-sdk-test" } once_cell = { workspace = true } tempfile = "3.3.0" -tokio = { version = "1.23.1", default-features = false, features = [ - "rt-multi-thread", - "macros", -] } +tokio = { version = "1.24.2", default-features = false, features = ["rt-multi-thread", "macros"] } tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } diff --git a/crates/matrix-sdk-sqlite/migrations/002_reset_olm_hash.sql b/crates/matrix-sdk-sqlite/migrations/002_reset_olm_hash.sql new file mode 100644 index 00000000000..8df4c6f5fc5 --- /dev/null +++ b/crates/matrix-sdk-sqlite/migrations/002_reset_olm_hash.sql @@ -0,0 +1,4 @@ +-- Hashes in the olm_hash table were initially stored as JSON, even though +-- everything else is MessagePack. Alongside this migration, the encoding is +-- updated. +DELETE FROM "olm_hash"; diff --git a/crates/matrix-sdk-sqlite/migrations/003_room_settings.sql b/crates/matrix-sdk-sqlite/migrations/003_room_settings.sql new file mode 100644 index 00000000000..91fcbae1fd4 --- /dev/null +++ b/crates/matrix-sdk-sqlite/migrations/003_room_settings.sql @@ -0,0 +1,4 @@ +CREATE TABLE room_settings( + "room_id" BLOB PRIMARY KEY NOT NULL, + "data" BLOB NOT NULL +); diff --git a/crates/matrix-sdk-sqlite/src/crypto_store.rs b/crates/matrix-sdk-sqlite/src/crypto_store.rs index 897286ec83d..eedf4f0132b 100644 --- a/crates/matrix-sdk-sqlite/src/crypto_store.rs +++ b/crates/matrix-sdk-sqlite/src/crypto_store.rs @@ -14,6 +14,7 @@ use std::{ collections::HashMap, + fmt, path::{Path, PathBuf}, sync::{Arc, RwLock}, }; @@ -26,10 +27,7 @@ use matrix_sdk_crypto::{ IdentityKeys, InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession, PrivateCrossSigningIdentity, Session, }, - store::{ - caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, - Result as StoreResult, RoomKeyCounts, - }, + store::{caches::SessionStore, BackupKeys, Changes, CryptoStore, RoomKeyCounts, RoomSettings}, GossipRequest, ReadOnlyAccount, 
ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo, TrackedUser, }; @@ -41,9 +39,10 @@ use tokio::fs; use tracing::{debug, error, instrument, warn}; use crate::{ + error::{Error, Result}, get_or_create_store_cipher, - utils::{Key, SqliteObjectExt}, - OpenStoreError, SqliteConnectionExt as _, SqliteObjectStoreExt, + utils::{Key, SqliteConnectionExt as _, SqliteObjectExt, SqliteObjectStoreExt as _}, + OpenStoreError, }; #[derive(Clone, Debug)] @@ -53,43 +52,6 @@ pub struct AccountInfo { identity_keys: Arc, } -#[derive(Debug)] -enum Error { - Crypto(CryptoStoreError), - Sqlite(rusqlite::Error), - Pool(deadpool_sqlite::PoolError), -} - -impl From for Error { - fn from(value: CryptoStoreError) -> Self { - Self::Crypto(value) - } -} - -impl From for Error { - fn from(value: rusqlite::Error) -> Self { - Self::Sqlite(value) - } -} - -impl From for Error { - fn from(value: deadpool_sqlite::PoolError) -> Self { - Self::Pool(value) - } -} - -impl From for CryptoStoreError { - fn from(value: Error) -> Self { - match value { - Error::Crypto(c) => c, - Error::Sqlite(b) => CryptoStoreError::backend(b), - Error::Pool(b) => CryptoStoreError::backend(b), - } - } -} - -type Result = std::result::Result; - /// A sqlite based cryptostore. 
#[derive(Clone)] pub struct SqliteCryptoStore { @@ -102,12 +64,13 @@ pub struct SqliteCryptoStore { session_cache: SessionStore, } -impl std::fmt::Debug for SqliteCryptoStore { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +#[cfg(not(tarpaulin_include))] +impl fmt::Debug for SqliteCryptoStore { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(path) = &self.path { - f.debug_struct("SledCryptoStore").field("path", &path).finish() + f.debug_struct("SqliteCryptoStore").field("path", &path).finish() } else { - f.debug_struct("SledCryptoStore").field("path", &"memory store").finish() + f.debug_struct("SqliteCryptoStore").field("path", &"memory store").finish() } } } @@ -120,7 +83,7 @@ impl SqliteCryptoStore { passphrase: Option<&str>, ) -> Result { let path = path.as_ref(); - fs::create_dir_all(path).await.map_err(CryptoStoreError::from)?; + fs::create_dir_all(path).await.map_err(OpenStoreError::CreateDir)?; let cfg = deadpool_sqlite::Config::new(path.join("matrix-sdk-crypto.sqlite3")); let pool = cfg.create_pool(Runtime::Tokio1)?; @@ -133,8 +96,8 @@ impl SqliteCryptoStore { pool: SqlitePool, passphrase: Option<&str>, ) -> Result { - let conn = pool.get().await.map_err(CryptoStoreError::backend)?; - run_migrations(&conn).await?; + let conn = pool.get().await?; + run_migrations(&conn).await.map_err(OpenStoreError::Migration)?; let store_cipher = match passphrase { Some(p) => Some(Arc::new(get_or_create_store_cipher(p, &conn).await?)), None => None, @@ -149,26 +112,25 @@ impl SqliteCryptoStore { }) } - fn serialize_value(&self, value: &impl Serialize) -> Result, CryptoStoreError> { - let serialized = rmp_serde::to_vec_named(value).map_err(CryptoStoreError::backend)?; + fn serialize_value(&self, value: &impl Serialize) -> Result> { + let serialized = rmp_serde::to_vec_named(value)?; if let Some(key) = &self.store_cipher { - let encrypted = - key.encrypt_value_data(serialized).map_err(CryptoStoreError::backend)?; - 
rmp_serde::to_vec_named(&encrypted).map_err(CryptoStoreError::backend) + let encrypted = key.encrypt_value_data(serialized)?; + Ok(rmp_serde::to_vec_named(&encrypted)?) } else { Ok(serialized) } } - fn deserialize_value(&self, value: &[u8]) -> Result { + fn deserialize_value(&self, value: &[u8]) -> Result { if let Some(key) = &self.store_cipher { - let encrypted = rmp_serde::from_slice(value).map_err(CryptoStoreError::backend)?; - let decrypted = key.decrypt_value_data(encrypted).map_err(CryptoStoreError::backend)?; + let encrypted = rmp_serde::from_slice(value)?; + let decrypted = key.decrypt_value_data(encrypted)?; - rmp_serde::from_slice(&decrypted).map_err(CryptoStoreError::backend) + Ok(rmp_serde::from_slice(&decrypted)?) } else { - rmp_serde::from_slice(value).map_err(CryptoStoreError::backend) + Ok(rmp_serde::from_slice(value)?) } } @@ -176,7 +138,7 @@ impl SqliteCryptoStore { &self, value: &[u8], backed_up: bool, - ) -> Result { + ) -> Result { let mut pickle: PickledInboundGroupSession = self.deserialize_value(value)?; // backed_up SQL column is source of truth, backed_up field in pickle // needed for other stores though @@ -184,11 +146,7 @@ impl SqliteCryptoStore { Ok(pickle) } - fn deserialize_key_request( - &self, - value: &[u8], - sent_out: bool, - ) -> Result { + fn deserialize_key_request(&self, value: &[u8], sent_out: bool) -> Result { let mut request: GossipRequest = self.deserialize_value(value)?; // sent_out SQL column is source of truth, sent_out field in serialized value // needed for other stores though @@ -212,46 +170,18 @@ impl SqliteCryptoStore { async fn acquire(&self) -> Result { Ok(self.pool.get().await?) } - - async fn load_tracked_users(&self) -> Result> { - self.acquire() - .await? - .get_tracked_users() - .await? 
- .iter() - .map(|value| Ok(self.deserialize_value(value)?)) - .collect() - } - - async fn save_tracked_users( - &self, - tracked_users: &[(&UserId, bool)], - ) -> Result<(), CryptoStoreError> { - let users: Vec<(Key, Vec)> = tracked_users - .iter() - .map(|(u, d)| { - let user_id = self.encode_key("tracked_users", u.as_bytes()); - let data = - self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?; - Ok((user_id, data)) - }) - .collect::>()?; - - Ok(self.acquire().await?.add_tracked_users(users).await?) - } } -const DATABASE_VERSION: u8 = 1; +const DATABASE_VERSION: u8 = 3; -async fn run_migrations(conn: &SqliteConn) -> Result<(), CryptoStoreError> { +async fn run_migrations(conn: &SqliteConn) -> rusqlite::Result<()> { let kv_exists = conn .query_row( "SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = 'kv'", (), |row| row.get::<_, u32>(0), ) - .await - .map_err(CryptoStoreError::backend)? + .await? > 0; let version = if kv_exists { @@ -279,15 +209,26 @@ async fn run_migrations(conn: &SqliteConn) -> Result<(), CryptoStoreError> { if version < 1 { // First turn on WAL mode, this can't be done in the transaction, it fails with // the error message: "cannot change into wal mode from within a transaction". 
- conn.execute_batch("PRAGMA journal_mode = wal;") - .await - .map_err(CryptoStoreError::backend)?; + conn.execute_batch("PRAGMA journal_mode = wal;").await?; conn.with_transaction(|txn| txn.execute_batch(include_str!("../migrations/001_init.sql"))) - .await - .map_err(CryptoStoreError::backend)?; + .await?; } - conn.set_kv("version", vec![DATABASE_VERSION]).await.map_err(CryptoStoreError::backend)?; + if version < 2 { + conn.with_transaction(|txn| { + txn.execute_batch(include_str!("../migrations/002_reset_olm_hash.sql")) + }) + .await?; + } + + if version < 3 { + conn.with_transaction(|txn| { + txn.execute_batch(include_str!("../migrations/003_room_settings.sql")) + }) + .await?; + } + + conn.set_kv("version", vec![DATABASE_VERSION]).await?; Ok(()) } @@ -323,6 +264,8 @@ trait SqliteConnectionExt { sent_out: bool, data: &[u8], ) -> rusqlite::Result<()>; + + fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()>; } impl SqliteConnectionExt for rusqlite::Connection { @@ -414,6 +357,16 @@ impl SqliteConnectionExt for rusqlite::Connection { )?; Ok(()) } + + fn set_room_settings(&self, room_id: &[u8], data: &[u8]) -> rusqlite::Result<()> { + self.execute( + "INSERT INTO room_settings (room_id, data) + VALUES (?1, ?2) + ON CONFLICT (room_id) DO UPDATE SET data = ?2", + (room_id, data), + )?; + Ok(()) + } } #[async_trait] @@ -581,6 +534,15 @@ trait SqliteObjectCryptoStoreExt: SqliteObjectExt { self.execute("DELETE FROM key_requests WHERE request_id = ?", (request_id,)).await?; Ok(()) } + + async fn get_room_settings(&self, room_id: Key) -> Result>> { + Ok(self + .query_row("SELECT data FROM room_settings WHERE room_id = ?", (room_id,), |row| { + row.get(0) + }) + .await + .optional()?) 
+ } } #[async_trait] @@ -588,12 +550,14 @@ impl SqliteObjectCryptoStoreExt for deadpool_sqlite::Object {} #[async_trait] impl CryptoStore for SqliteCryptoStore { - async fn load_account(&self) -> StoreResult> { + type Error = Error; + + async fn load_account(&self) -> Result> { let conn = self.acquire().await?; if let Some(pickle) = conn.get_kv("account").await? { let pickle = self.deserialize_value(&pickle)?; - let account = ReadOnlyAccount::from_pickle(pickle)?; + let account = ReadOnlyAccount::from_pickle(pickle).map_err(|_| Error::Unpickle)?; let account_info = AccountInfo { user_id: account.user_id.clone(), @@ -609,7 +573,7 @@ impl CryptoStore for SqliteCryptoStore { } } - async fn save_account(&self, account: ReadOnlyAccount) -> StoreResult<()> { + async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> { let account_info = AccountInfo { user_id: account.user_id.clone(), device_id: account.device_id.clone(), @@ -623,21 +587,21 @@ impl CryptoStore for SqliteCryptoStore { Ok(()) } - async fn load_identity(&self) -> StoreResult> { + async fn load_identity(&self) -> Result> { let conn = self.acquire().await?; if let Some(i) = conn.get_kv("identity").await? { let pickle = self.deserialize_value(&i)?; Ok(Some( PrivateCrossSigningIdentity::from_pickle(pickle) .await - .map_err(|_| CryptoStoreError::UnpicklingError)?, + .map_err(|_| Error::Unpickle)?, )) } else { Ok(None) } } - async fn save_changes(&self, changes: Changes) -> StoreResult<()> { + async fn save_changes(&self, changes: Changes) -> Result<()> { let pickled_account = if let Some(account) = changes.account { let account_info = AccountInfo { user_id: account.user_id.clone(), @@ -652,7 +616,7 @@ impl CryptoStore for SqliteCryptoStore { }; let pickled_private_identity = - if let Some(i) = changes.private_identity { Some(i.pickle().await?) 
} else { None }; + if let Some(i) = changes.private_identity { Some(i.pickle().await) } else { None }; let mut session_changes = Vec::new(); for session in changes.sessions { @@ -744,7 +708,7 @@ impl CryptoStore for SqliteCryptoStore { } for hash in &changes.message_hashes { - let hash = serde_json::to_vec(hash).map_err(CryptoStoreError::from)?; + let hash = rmp_serde::to_vec(hash)?; txn.add_olm_hash(&hash)?; } @@ -754,6 +718,12 @@ impl CryptoStore for SqliteCryptoStore { txn.set_key_request(&request_id, request.sent_out, &serialized_request)?; } + for (room_id, settings) in changes.room_settings { + let room_id = this.encode_key("room_settings", room_id.as_bytes()); + let value = this.serialize_value(&settings)?; + txn.set_room_settings(&room_id, &value)?; + } + Ok::<_, Error>(()) }) .await?; @@ -761,11 +731,8 @@ impl CryptoStore for SqliteCryptoStore { Ok(()) } - async fn get_sessions( - &self, - sender_key: &str, - ) -> StoreResult>>>> { - let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?; + async fn get_sessions(&self, sender_key: &str) -> Result>>>> { + let account_info = self.get_account_info().ok_or(Error::AccountUnset)?; if self.session_cache.get(sender_key).is_none() { let sessions = self @@ -796,7 +763,7 @@ impl CryptoStore for SqliteCryptoStore { &self, room_id: &RoomId, session_id: &str, - ) -> StoreResult> { + ) -> Result> { let session_id = self.encode_key("inbound_group_session", session_id); let Some((room_id_from_db, value)) = self.acquire().await?.get_inbound_group_session(session_id).await? @@ -815,7 +782,7 @@ impl CryptoStore for SqliteCryptoStore { Ok(Some(InboundGroupSession::from_pickle(pickle)?)) } - async fn get_inbound_group_sessions(&self) -> StoreResult> { + async fn get_inbound_group_sessions(&self) -> Result> { self.acquire() .await? 
.get_inbound_group_sessions() @@ -828,14 +795,14 @@ impl CryptoStore for SqliteCryptoStore { .collect() } - async fn inbound_group_session_counts(&self) -> StoreResult { + async fn inbound_group_session_counts(&self) -> Result { Ok(self.acquire().await?.get_inbound_group_session_counts().await?) } async fn inbound_group_sessions_for_backup( &self, limit: usize, - ) -> StoreResult> { + ) -> Result> { self.acquire() .await? .get_inbound_group_sessions_for_backup(limit) @@ -848,11 +815,11 @@ impl CryptoStore for SqliteCryptoStore { .collect() } - async fn reset_backup_state(&self) -> StoreResult<()> { + async fn reset_backup_state(&self) -> Result<()> { Ok(self.acquire().await?.reset_inbound_group_session_backup_state().await?) } - async fn load_backup_keys(&self) -> StoreResult { + async fn load_backup_keys(&self) -> Result { let conn = self.acquire().await?; let backup_version = conn @@ -873,37 +840,54 @@ impl CryptoStore for SqliteCryptoStore { async fn get_outbound_group_session( &self, room_id: &RoomId, - ) -> StoreResult> { + ) -> Result> { let room_id = self.encode_key("outbound_group_session", room_id.as_bytes()); let Some(value) = self.acquire().await?.get_outbound_group_session(room_id).await? else { return Ok(None); }; - let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?; + let account_info = self.get_account_info().ok_or(Error::AccountUnset)?; let pickle = self.deserialize_value(&value)?; let session = OutboundGroupSession::from_pickle( account_info.device_id, account_info.identity_keys, pickle, - )?; + ) + .map_err(|_| Error::Unpickle)?; return Ok(Some(session)); } - async fn load_tracked_users(&self) -> StoreResult> { - Ok(self.load_tracked_users().await?) + async fn load_tracked_users(&self) -> Result> { + self.acquire() + .await? + .get_tracked_users() + .await? 
+ .iter() + .map(|value| self.deserialize_value(value)) + .collect() } - async fn save_tracked_users(&self, users: &[(&UserId, bool)]) -> StoreResult<()> { - self.save_tracked_users(users).await + async fn save_tracked_users(&self, tracked_users: &[(&UserId, bool)]) -> Result<()> { + let users: Vec<(Key, Vec)> = tracked_users + .iter() + .map(|(u, d)| { + let user_id = self.encode_key("tracked_users", u.as_bytes()); + let data = + self.serialize_value(&TrackedUser { user_id: (*u).into(), dirty: *d })?; + Ok((user_id, data)) + }) + .collect::>()?; + + Ok(self.acquire().await?.add_tracked_users(users).await?) } async fn get_device( &self, user_id: &UserId, device_id: &DeviceId, - ) -> StoreResult> { + ) -> Result> { let user_id = self.encode_key("device", user_id.as_bytes()); let device_id = self.encode_key("device", device_id.as_bytes()); Ok(self @@ -918,7 +902,7 @@ impl CryptoStore for SqliteCryptoStore { async fn get_user_devices( &self, user_id: &UserId, - ) -> StoreResult> { + ) -> Result> { let user_id = self.encode_key("device", user_id.as_bytes()); self.acquire() .await? @@ -932,10 +916,7 @@ impl CryptoStore for SqliteCryptoStore { .collect() } - async fn get_user_identity( - &self, - user_id: &UserId, - ) -> StoreResult> { + async fn get_user_identity(&self, user_id: &UserId) -> Result> { let user_id = self.encode_key("identity", user_id.as_bytes()); Ok(self .acquire() @@ -949,15 +930,15 @@ impl CryptoStore for SqliteCryptoStore { async fn is_message_known( &self, message_hash: &matrix_sdk_crypto::olm::OlmMessageHash, - ) -> StoreResult { - let value = serde_json::to_vec(message_hash)?; + ) -> Result { + let value = rmp_serde::to_vec(message_hash)?; Ok(self.acquire().await?.has_olm_hash(value).await?) 
} async fn get_outgoing_secret_requests( &self, request_id: &TransactionId, - ) -> StoreResult> { + ) -> Result> { let request_id = self.encode_key("key_requests", request_id.as_bytes()); Ok(self .acquire() @@ -971,7 +952,7 @@ impl CryptoStore for SqliteCryptoStore { async fn get_secret_request_by_info( &self, key_info: &SecretInfo, - ) -> StoreResult> { + ) -> Result> { let requests = self.acquire().await?.get_outgoing_secret_requests().await?; for (request, sent_out) in requests { let request = self.deserialize_key_request(&request, sent_out)?; @@ -982,7 +963,7 @@ impl CryptoStore for SqliteCryptoStore { Ok(None) } - async fn get_unsent_secret_requests(&self) -> StoreResult> { + async fn get_unsent_secret_requests(&self) -> Result> { self.acquire() .await? .get_unsent_secret_requests() @@ -995,10 +976,47 @@ impl CryptoStore for SqliteCryptoStore { .collect() } - async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> StoreResult<()> { + async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> { let request_id = self.encode_key("key_requests", request_id.as_bytes()); Ok(self.acquire().await?.delete_key_request(request_id).await?) } + + async fn get_room_settings(&self, room_id: &RoomId) -> Result> { + let room_id = self.encode_key("room_settings", room_id.as_bytes()); + let Some(value) = self.acquire().await?.get_room_settings(room_id).await? else { + return Ok(None); + }; + + let settings = self.deserialize_value(&value)?; + + return Ok(Some(settings)); + } + + async fn get_custom_value(&self, key: &str) -> Result>> { + let Some(serialized) = self.acquire().await?.get_kv(key).await? else { + return Ok(None); + }; + let value = if let Some(cipher) = &self.store_cipher { + let encrypted = rmp_serde::from_slice(&serialized)?; + cipher.decrypt_value_data(encrypted)? 
+ } else { + serialized + }; + + Ok(Some(value)) + } + + async fn set_custom_value(&self, key: &str, value: Vec) -> Result<()> { + let serialized = if let Some(cipher) = &self.store_cipher { + let encrypted = cipher.encrypt_value_data(value)?; + rmp_serde::to_vec_named(&encrypted)? + } else { + value + }; + + self.acquire().await?.set_kv(key, serialized).await?; + Ok(()) + } } #[cfg(test)] diff --git a/crates/matrix-sdk-sqlite/src/error.rs b/crates/matrix-sdk-sqlite/src/error.rs new file mode 100644 index 00000000000..dbcb435a5e1 --- /dev/null +++ b/crates/matrix-sdk-sqlite/src/error.rs @@ -0,0 +1,97 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use deadpool_sqlite::{CreatePoolError, PoolError}; +#[cfg(feature = "crypto-store")] +use matrix_sdk_crypto::CryptoStoreError; +use thiserror::Error; +use tokio::io; + +/// All the errors that can occur when opening a sled store. +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum OpenStoreError { + /// Failed to create the DB's parent directory. + #[error("Failed to create the database's parent directory")] + CreateDir(#[source] io::Error), + + /// Failed to create the DB pool. + #[error(transparent)] + CreatePool(#[from] CreatePoolError), + + /// Failed to apply migrations. + #[error("Failed to run migrations")] + Migration(#[source] rusqlite::Error), + + /// Failed to get a DB connection from the pool. 
+ #[error(transparent)] + Pool(#[from] PoolError), + + /// Failed to initialize the store cipher. + #[error("Failed to initialize the store cipher")] + InitCipher(#[from] matrix_sdk_store_encryption::Error), + + /// Failed to load the store cipher from the DB. + #[error("Failed to load the store cipher from the DB")] + LoadCipher(#[source] rusqlite::Error), + + /// Failed to save the store cipher to the DB. + #[error("Failed to save the store cipher to the DB")] + SaveCipher(#[source] rusqlite::Error), +} + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Sqlite(rusqlite::Error), + #[error(transparent)] + Pool(PoolError), + #[error(transparent)] + Encode(rmp_serde::encode::Error), + #[error(transparent)] + Decode(rmp_serde::decode::Error), + #[error(transparent)] + Encryption(matrix_sdk_store_encryption::Error), + #[error("can't save/load sessions or group sessions in the store before an account is stored")] + AccountUnset, + #[error(transparent)] + Pickle(#[from] vodozemac::PickleError), + #[error("An object failed to be decrypted while unpickling")] + Unpickle, +} + +macro_rules! 
impl_from { + ( $ty:ty => $enum:ident::$variant:ident ) => { + impl From<$ty> for $enum { + fn from(value: $ty) -> Self { + Self::$variant(value) + } + } + }; +} + +impl_from!(rusqlite::Error => Error::Sqlite); +impl_from!(PoolError => Error::Pool); +impl_from!(rmp_serde::encode::Error => Error::Encode); +impl_from!(rmp_serde::decode::Error => Error::Decode); +impl_from!(matrix_sdk_store_encryption::Error => Error::Encryption); + +#[cfg(feature = "crypto-store")] +impl From for CryptoStoreError { + fn from(e: Error) -> Self { + CryptoStoreError::backend(e) + } +} + +pub(crate) type Result = std::result::Result; diff --git a/crates/matrix-sdk-sqlite/src/lib.rs b/crates/matrix-sdk-sqlite/src/lib.rs index 97858e2321a..a707b8244f4 100644 --- a/crates/matrix-sdk-sqlite/src/lib.rs +++ b/crates/matrix-sdk-sqlite/src/lib.rs @@ -11,109 +11,42 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#![cfg_attr(not(feature = "crypto-store"), allow(dead_code, unused_imports))] -#[cfg(feature = "crypto-store")] -use async_trait::async_trait; -use deadpool_sqlite::CreatePoolError; -#[cfg(feature = "crypto-store")] use deadpool_sqlite::Object as SqliteConn; -#[cfg(feature = "crypto-store")] -use matrix_sdk_crypto::{store::Result, CryptoStoreError}; -#[cfg(feature = "crypto-store")] use matrix_sdk_store_encryption::StoreCipher; -#[cfg(feature = "crypto-store")] -use rusqlite::OptionalExtension; -use thiserror::Error; -use tracing::error; #[cfg(feature = "crypto-store")] mod crypto_store; -#[cfg(feature = "crypto-store")] +mod error; mod utils; #[cfg(feature = "crypto-store")] pub use self::crypto_store::SqliteCryptoStore; -#[cfg(feature = "crypto-store")] -use self::utils::SqliteObjectExt; +pub use self::error::OpenStoreError; +use self::utils::SqliteObjectStoreExt; -/// All the errors that can occur when opening a sled store. 
-#[derive(Error, Debug)] -#[non_exhaustive] -pub enum OpenStoreError { - /// An error occurred with the crypto store implementation. - #[cfg(feature = "crypto-store")] - #[error(transparent)] - Crypto(#[from] CryptoStoreError), - - /// An error occurred with sqlite. - #[error(transparent)] - Sqlite(#[from] CreatePoolError), -} - -#[cfg(feature = "crypto-store")] -async fn get_or_create_store_cipher(passphrase: &str, conn: &SqliteConn) -> Result { - let encrypted_cipher = conn.get_kv("cipher").await?; +async fn get_or_create_store_cipher( + passphrase: &str, + conn: &SqliteConn, +) -> Result { + let encrypted_cipher = conn.get_kv("cipher").await.map_err(OpenStoreError::LoadCipher)?; let cipher = if let Some(encrypted) = encrypted_cipher { - StoreCipher::import(passphrase, &encrypted) - .map_err(|_| CryptoStoreError::UnpicklingError)? + StoreCipher::import(passphrase, &encrypted)? } else { - let cipher = StoreCipher::new().map_err(CryptoStoreError::backend)?; + let cipher = StoreCipher::new()?; #[cfg(not(test))] let export = cipher.export(passphrase); #[cfg(test)] let export = cipher._insecure_export_fast_for_testing(passphrase); - conn.set_kv("cipher", export.map_err(CryptoStoreError::backend)?).await?; + conn.set_kv("cipher", export?).await.map_err(OpenStoreError::SaveCipher)?; cipher }; Ok(cipher) } -#[cfg(feature = "crypto-store")] -trait SqliteConnectionExt { - fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>; -} - -#[cfg(feature = "crypto-store")] -impl SqliteConnectionExt for rusqlite::Connection { - fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> { - self.execute( - "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2", - (key, value), - )?; - Ok(()) - } -} - -#[cfg(feature = "crypto-store")] -#[async_trait] -trait SqliteObjectStoreExt: SqliteObjectExt { - async fn get_kv(&self, key: &str) -> Result>> { - let key = key.to_owned(); - self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| 
row.get(0)) - .await - .optional() - .map_err(CryptoStoreError::backend) - } - - async fn set_kv(&self, key: &str, value: Vec) -> Result<()>; -} - -#[cfg(feature = "crypto-store")] -#[async_trait] -impl SqliteObjectStoreExt for deadpool_sqlite::Object { - async fn set_kv(&self, key: &str, value: Vec) -> Result<()> { - let key = key.to_owned(); - self.interact(move |conn| conn.set_kv(&key, &value)) - .await - .unwrap() - .map_err(CryptoStoreError::backend)?; - - Ok(()) - } -} - #[cfg(test)] #[ctor::ctor] fn init_logging() { diff --git a/crates/matrix-sdk-sqlite/src/utils.rs b/crates/matrix-sdk-sqlite/src/utils.rs index 79ea904e36e..9decfec447b 100644 --- a/crates/matrix-sdk-sqlite/src/utils.rs +++ b/crates/matrix-sdk-sqlite/src/utils.rs @@ -15,7 +15,7 @@ use std::ops::Deref; use async_trait::async_trait; -use rusqlite::{Params, Row, Statement, Transaction}; +use rusqlite::{OptionalExtension, Params, Row, Statement, Transaction}; #[derive(Debug)] pub(crate) enum Key { @@ -112,3 +112,39 @@ impl SqliteObjectExt for deadpool_sqlite::Object { .unwrap() } } + +pub(crate) trait SqliteConnectionExt { + fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()>; +} + +impl SqliteConnectionExt for rusqlite::Connection { + fn set_kv(&self, key: &str, value: &[u8]) -> rusqlite::Result<()> { + self.execute( + "INSERT INTO kv VALUES (?1, ?2) ON CONFLICT (key) DO UPDATE SET value = ?2", + (key, value), + )?; + Ok(()) + } +} + +#[async_trait] +pub(crate) trait SqliteObjectStoreExt: SqliteObjectExt { + async fn get_kv(&self, key: &str) -> rusqlite::Result>> { + let key = key.to_owned(); + self.query_row("SELECT value FROM kv WHERE key = ?", (key,), |row| row.get(0)) + .await + .optional() + } + + async fn set_kv(&self, key: &str, value: Vec) -> rusqlite::Result<()>; +} + +#[async_trait] +impl SqliteObjectStoreExt for deadpool_sqlite::Object { + async fn set_kv(&self, key: &str, value: Vec) -> rusqlite::Result<()> { + let key = key.to_owned(); + self.interact(move |conn| 
conn.set_kv(&key, &value)).await.unwrap()?; + + Ok(()) + } +} diff --git a/crates/matrix-sdk-store-encryption/Cargo.toml b/crates/matrix-sdk-store-encryption/Cargo.toml index 31c797289ae..7286e3074cb 100644 --- a/crates/matrix-sdk-store-encryption/Cargo.toml +++ b/crates/matrix-sdk-store-encryption/Cargo.toml @@ -21,6 +21,7 @@ getrandom = { version = "0.2.6", optional = true } hmac = "0.12.1" pbkdf2 = "0.11.0" rand = "0.8.5" +rmp-serde = "1.1.1" serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } sha2 = "0.10.2" diff --git a/crates/matrix-sdk-store-encryption/src/lib.rs b/crates/matrix-sdk-store-encryption/src/lib.rs index 1eebd65e6b8..da17d2fcb14 100644 --- a/crates/matrix-sdk-store-encryption/src/lib.rs +++ b/crates/matrix-sdk-store-encryption/src/lib.rs @@ -41,8 +41,12 @@ type MacKeySeed = [u8; 32]; /// Error type for the `StoreCipher` operations. #[derive(Debug, Display, thiserror::Error)] pub enum Error { - /// Failed to serialize or deserialize a value {0} - Serialization(#[from] serde_json::Error), + /// Failed to serialize a value {0} + Serialization(#[from] rmp_serde::encode::Error), + /// Failed to deserialize a value {0} + Deserialization(#[from] rmp_serde::decode::Error), + /// Failed to deserialize or serialize a JSON value {0} + Json(#[from] serde_json::Error), /// Error encrypting or decrypting a value {0} Encryption(#[from] EncryptionError), /// Coulnd't generate enough randomness for a cryptographic operation: {0} @@ -51,6 +55,11 @@ pub enum Error { Version(u8, u8), /// The ciphertext had an invalid length, expected {0}, got {1} Length(usize, usize), + /** + * Failed to import a store cipher, the export used a passphrase while + * we're trying to import it using a key or vice-versa. + */ + KdfMismatch, } /// An encryption key that can be used to encrypt data for key/value stores. @@ -90,8 +99,8 @@ impl StoreCipher { /// Encrypt the store cipher using the given passphrase and export it. 
/// - /// This method can be used to persist the `StoreCipher` in the key/value - /// store in a safe manner. + /// This method can be used to persist the `StoreCipher` in an unencrypted + /// key/value store in a safe manner. /// /// The `StoreCipher` can later on be restored using /// [`StoreCipher::import`]. @@ -117,21 +126,47 @@ impl StoreCipher { /// # anyhow::Ok(()) }; /// ``` pub fn export(&self, passphrase: &str) -> Result, Error> { - self.export_impl(passphrase, KDF_ROUNDS) + self.export_kdf(passphrase, KDF_ROUNDS) } - #[doc(hidden)] - pub fn _insecure_export_fast_for_testing(&self, passphrase: &str) -> Result, Error> { - self.export_impl(passphrase, 1000) + /// Encrypt the store cipher using the given key and export it. + /// + /// This method can be used to persist the `StoreCipher` in an unencrypted + /// key/value store in a safe manner. + /// + /// The `StoreCipher` can later on be restored using + /// [`StoreCipher::import_with_key`]. + /// + /// # Arguments + /// + /// * `key` - The 32-byte key to be used to encrypt the store cipher. It's + /// recommended to use a freshly and securely generated random key. + /// + /// # Examples + /// + /// ``` + /// # let example = || { + /// use matrix_sdk_store_encryption::StoreCipher; + /// use serde_json::json; + /// + /// let store_cipher = StoreCipher::new()?; + /// + /// // Export the store cipher and persist it in your key/value store + /// let export = store_cipher.export_with_key(&[0u8; 32]); + /// + /// // Save the export in your key/value store. 
+ /// # anyhow::Ok(()) }; + /// ``` + pub fn export_with_key(&self, key: &[u8; 32]) -> Result, Error> { + let store_cipher = self.export_helper(key, KdfInfo::None)?; + Ok(rmp_serde::to_vec_named(&store_cipher).expect("Can't serialize the store cipher")) } - fn export_impl(&self, passphrase: &str, kdf_rounds: u32) -> Result, Error> { - let mut rng = thread_rng(); - - let mut salt = [0u8; KDF_SALT_SIZE]; - salt.try_fill(&mut rng)?; - - let key = StoreCipher::expand_key(passphrase, &salt, kdf_rounds); + fn export_helper( + &self, + key: &[u8; 32], + kdf_info: KdfInfo, + ) -> Result { let key = ChachaKey::from_slice(key.as_ref()); let cipher = XChaCha20Poly1305::new(key); @@ -146,15 +181,62 @@ impl StoreCipher { keys.zeroize(); - let store_cipher = EncryptedStoreCipher { - kdf_info: KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds: kdf_rounds, kdf_salt: salt }, + Ok(EncryptedStoreCipher { + kdf_info, ciphertext_info: CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext }, + }) + } + + #[doc(hidden)] + pub fn _insecure_export_fast_for_testing(&self, passphrase: &str) -> Result, Error> { + self.export_kdf(passphrase, 1000) + } + + fn export_kdf(&self, passphrase: &str, kdf_rounds: u32) -> Result, Error> { + let mut rng = thread_rng(); + + let mut salt = [0u8; KDF_SALT_SIZE]; + salt.try_fill(&mut rng)?; + + let key = StoreCipher::expand_key(passphrase, &salt, kdf_rounds); + + let store_cipher = self.export_helper( + &key, + KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds: kdf_rounds, kdf_salt: salt }, + )?; + + Ok(rmp_serde::to_vec_named(&store_cipher).expect("Can't serialize the store cipher")) + } + + fn import_helper(key: &ChachaKey, encrypted: EncryptedStoreCipher) -> Result { + let mut decrypted = match encrypted.ciphertext_info { + CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext } => { + let cipher = XChaCha20Poly1305::new(key); + let nonce = XNonce::from_slice(&nonce); + cipher.decrypt(nonce, ciphertext.as_ref())? 
+ } }; - Ok(serde_json::to_vec(&store_cipher).expect("Can't serialize the store cipher")) + if decrypted.len() != 64 { + decrypted.zeroize(); + + Err(Error::Length(64, decrypted.len())) + } else { + let mut encryption_key = Box::new([0u8; 32]); + let mut mac_key_seed = Box::new([0u8; 32]); + + encryption_key.copy_from_slice(&decrypted[0..32]); + mac_key_seed.copy_from_slice(&decrypted[32..64]); + + let keys = Keys { encryption_key, mac_key_seed }; + + decrypted.zeroize(); + + Ok(Self { inner: keys }) + } } - /// Restore a store cipher from an encrypted export. + /// Restore a store cipher from an export encrypted with a passphrase. /// /// # Arguments /// @@ -182,41 +264,66 @@ impl StoreCipher { /// # anyhow::Ok(()) }; /// ``` pub fn import(passphrase: &str, encrypted: &[u8]) -> Result { - let encrypted: EncryptedStoreCipher = serde_json::from_slice(encrypted)?; + // Our old export format used serde_json for the serialization format. Let's + // first try the new format and if that fails, try the old one. + let encrypted: EncryptedStoreCipher = + if let Ok(deserialized) = rmp_serde::from_slice(encrypted) { + deserialized + } else { + serde_json::from_slice(encrypted)? + }; let key = match encrypted.kdf_info { KdfInfo::Pbkdf2ToChaCha20Poly1305 { rounds, kdf_salt } => { Self::expand_key(passphrase, &kdf_salt, rounds) } - }; - - let key = ChachaKey::from_slice(key.as_ref()); - - let mut decrypted = match encrypted.ciphertext_info { - CipherTextInfo::ChaCha20Poly1305 { nonce, ciphertext } => { - let cipher = XChaCha20Poly1305::new(key); - let nonce = XNonce::from_slice(&nonce); - cipher.decrypt(nonce, ciphertext.as_ref())? 
+ KdfInfo::None => { + return Err(Error::KdfMismatch); } }; - if decrypted.len() != 64 { - decrypted.zeroize(); + let key = ChachaKey::from_slice(key.as_ref()); - Err(Error::Length(64, decrypted.len())) - } else { - let mut encryption_key = Box::new([0u8; 32]); - let mut mac_key_seed = Box::new([0u8; 32]); + Self::import_helper(key, encrypted) + } - encryption_key.copy_from_slice(&decrypted[0..32]); - mac_key_seed.copy_from_slice(&decrypted[32..64]); + /// Restore a store cipher from an export encrypted with a random key. + /// + /// # Arguments + /// + /// * `key` - The 32-byte decryption key that was previously used to + /// encrypt the store cipher. + /// + /// * `encrypted` - The exported and encrypted version of the store cipher. + /// + /// # Examples + /// + /// ``` + /// # let example = || { + /// use matrix_sdk_store_encryption::StoreCipher; + /// use serde_json::json; + /// + /// let store_cipher = StoreCipher::new()?; + /// + /// // Export the store cipher and persist it in your key/value store + /// let export = store_cipher.export_with_key(&[0u8; 32])?; + /// + /// // This is now the same as `store_cipher`. + /// let imported = StoreCipher::import_with_key(&[0u8; 32], &export)?; + /// + /// // Save the export in your key/value store. + /// # anyhow::Ok(()) }; + /// ``` + pub fn import_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result { + let encrypted: EncryptedStoreCipher = rmp_serde::from_slice(encrypted).unwrap(); - let keys = Keys { encryption_key, mac_key_seed }; + if let KdfInfo::Pbkdf2ToChaCha20Poly1305 { .. } = encrypted.kdf_info { + return Err(Error::KdfMismatch); + }; - decrypted.zeroize(); + let key = ChachaKey::from_slice(key.as_ref()); - Ok(Self { inner: keys }) - } + Self::import_helper(key, encrypted) } /// Hash a key before it is inserted into the key/value store. @@ -567,6 +674,7 @@ impl Keys { /// Version specific info for the key derivation method that is used. 
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] enum KdfInfo { + None, /// The PBKDF2 to Chacha key derivation variant. Pbkdf2ToChaCha20Poly1305 { /// The number of PBKDF rounds that were used when deriving the store @@ -635,6 +743,63 @@ mod tests { assert_eq!(value, decrypted_value); + // Can't use assert matches here since we don't have a Debug implementation for + // StoreCipher. + match StoreCipher::import_with_key(&[0u8; 32], &encrypted) { + Err(Error::KdfMismatch) => {} + _ => panic!( + "Invalid error when importing a passphrase-encrypted store cipher with a key" + ), + } + + let store_cipher = StoreCipher::new()?; + let encrypted_value = store_cipher.encrypt_value(&value)?; + + let export = store_cipher.export_with_key(&[0u8; 32])?; + let decrypted = StoreCipher::import_with_key(&[0u8; 32], &export)?; + + let decrypted_value: Value = decrypted.decrypt_value(&encrypted_value)?; + assert_eq!(value, decrypted_value); + + // Same as above, can't use assert_matches. + match StoreCipher::import_with_key(&[0u8; 32], &encrypted) { + Err(Error::KdfMismatch) => {} + _ => panic!( + "Invalid error when importing a key-encrypted store cipher with a passphrase" + ), + } + + let old_export = json!({ + "ciphertext_info": { + "ChaCha20Poly1305":{ + "ciphertext":[ + 136,202,212,194,9,223,171,109,152,84,140,183,14,55,198,22,150,130,80,135, + 161,202,79,205,151,202,120,91,108,154,252,94,56,178,108,216,186,179,167,128, + 154,107,243,195,14,138,86,78,140,159,245,170,204,227,27,84,255,161,196,69, + 60,150,69,123,67,134,28,50,10,179,250,141,221,19,202,132,28,122,92,116 + ], + "nonce":[ + 108,3,115,54,65,135,250,188,212,204,93,223,78,11,52,46, + 124,140,218,73,88,167,50,230 + ] + } + }, + "kdf_info":{ + "Pbkdf2ToChaCha20Poly1305":{ + "kdf_salt":[ + 221,133,149,116,199,122,172,189,236,42,26,204,53,164,245,158,137,113, + 31,220,239,66,64,51,242,164,185,166,176,218,209,245 + ], + "rounds":1000 + } + } + }); + + let old_export = serde_json::to_vec(&old_export)?; + + 
StoreCipher::import(passphrase, &old_export) + .expect("We can import the old store-cipher export"); + Ok(()) } diff --git a/crates/matrix-sdk/CHANGELOG.md b/crates/matrix-sdk/CHANGELOG.md new file mode 100644 index 00000000000..35c8a151f3f --- /dev/null +++ b/crates/matrix-sdk/CHANGELOG.md @@ -0,0 +1,7 @@ +# 0.6.2 + +- Fix the access token being printed in tracing span fields. + +# 0.6.1 + +- Fixes a bug where the access token used for Matrix requests was added as a field to a tracing span. diff --git a/crates/matrix-sdk/Cargo.toml b/crates/matrix-sdk/Cargo.toml index 0f485696826..25190954651 100644 --- a/crates/matrix-sdk/Cargo.toml +++ b/crates/matrix-sdk/Cargo.toml @@ -18,6 +18,7 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = [ "e2e-encryption", + "automatic-room-key-forwarding", "sled", "native-tls", ] @@ -25,6 +26,7 @@ testing = [] e2e-encryption = [ "matrix-sdk-base/e2e-encryption", + "matrix-sdk-base/automatic-room-key-forwarding", "matrix-sdk-sled?/crypto-store", # activate crypto-store on sled if given "matrix-sdk-indexeddb?/e2e-encryption", # activate on indexeddb if given ] @@ -34,11 +36,12 @@ sled = ["dep:matrix-sdk-sled", "matrix-sdk-sled?/state-store"] indexeddb = ["dep:matrix-sdk-indexeddb"] qrcode = ["e2e-encryption", "matrix-sdk-base/qrcode"] +automatic-room-key-forwarding = ["e2e-encryption", "matrix-sdk-base/automatic-room-key-forwarding"] markdown = ["ruma/markdown"] native-tls = ["reqwest/native-tls"] rustls-tls = ["reqwest/rustls-tls"] socks = ["reqwest/socks"] -sso-login = ["dep:hyper", "dep:rand", "dep:tokio-stream", "dep:tower"] +sso-login = ["dep:hyper", "dep:rand", "dep:tower"] appservice = ["ruma/appservice-api-s"] image-proc = ["dep:image"] image-rayon = ["image-proc", "image?/jpeg_rayon"] @@ -49,7 +52,7 @@ experimental-sliding-sync = [ "matrix-sdk-base/experimental-sliding-sync", "experimental-timeline", "reqwest/gzip", - "dep:derive_builder", + "dep:uuid", ] docsrs = [ @@ -69,13 +72,14 @@ bytes = "1.1.0" bytesize = "1.1" 
chrono = { version = "0.4.23", optional = true } dashmap = { workspace = true } -derive_builder = { version = "0.11.2", optional = true } event-listener = "2.5.2" +eyeball = { workspace = true } +eyeball-im = { workspace = true } eyre = { version = "0.6.8", optional = true } futures-core = "0.3.21" -futures-signals = { version = "0.3.30", default-features = false } -futures-util = { version = "0.3.21", default-features = false } +futures-util = { workspace = true } http = { workspace = true } +im = { version = "15.1.0", features = ["serde"] } indexmap = "1.9.1" hyper = { version = "0.14.20", features = ["http1", "http2", "server"], optional = true } matrix-sdk-base = { version = "0.6.0", path = "../matrix-sdk-base", default_features = false } @@ -83,17 +87,20 @@ matrix-sdk-common = { version = "0.6.0", path = "../matrix-sdk-common" } matrix-sdk-indexeddb = { version = "0.2.0", path = "../matrix-sdk-indexeddb", default-features = false, optional = true } matrix-sdk-sled = { version = "0.2.0", path = "../matrix-sdk-sled", default-features = false, optional = true } mime = "0.3.16" +mime_guess = "2.0.4" +pin-project-lite = "0.2.9" rand = { version = "0.8.5", optional = true } reqwest = { version = "0.11.10", default_features = false } -ruma = { workspace = true, features = ["compat", "rand", "unstable-msc2448", "unstable-msc2965"] } +ruma = { workspace = true, features = ["rand", "unstable-msc2448", "unstable-msc2965"] } serde = { workspace = true } serde_html_form = { workspace = true } serde_json = { workspace = true } +tempfile = "3.3.0" thiserror = { workspace = true } -tokio-stream = { version = "0.1.8", features = ["net"], optional = true } tower = { version = "0.4.13", features = ["make"], optional = true } tracing = { workspace = true, features = ["attributes"] } url = "2.2.2" +uuid = { version = "1.3.0", optional = true } zeroize = { workspace = true } [dependencies.image] @@ -118,11 +125,12 @@ optional = true [target.'cfg(target_arch = 
"wasm32")'.dependencies] async-once-cell = "0.4.2" -wasm-timer = "0.2.5" +gloo-timers = { version = "0.2.6", features = ["futures"] } +tokio = { version = "1.24.2", default-features = false, features = ["sync"] } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] backoff = { version = "0.4.0", features = ["tokio"] } -tokio = { version = "1.23.1", default-features = false, features = ["fs", "rt"] } +tokio = { version = "1.24.2", default-features = false, features = ["fs", "rt"] } [dev-dependencies] anyhow = { workspace = true } @@ -131,7 +139,6 @@ dirs = "4.0.0" futures = { version = "0.3.21", default-features = false, features = ["executor"] } matrix-sdk-test = { version = "0.6.0", path = "../../testing/matrix-sdk-test" } once_cell = { workspace = true } -tempfile = "3.3.0" tracing-subscriber = { version = "0.3.11", features = ["env-filter"] } [target.'cfg(target_arch = "wasm32")'.dev-dependencies] @@ -140,5 +147,5 @@ wasm-bindgen-test = "0.3.33" [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] ctor = { workspace = true } -tokio = { version = "1.23.1", default-features = false, features = ["rt-multi-thread", "macros"] } +tokio = { version = "1.24.2", default-features = false, features = ["rt-multi-thread", "macros"] } wiremock = "0.5.13" diff --git a/crates/matrix-sdk/README.md b/crates/matrix-sdk/README.md index af0d40f4241..450d789fabf 100644 --- a/crates/matrix-sdk/README.md +++ b/crates/matrix-sdk/README.md @@ -12,8 +12,7 @@ other lower-level crates. If you're attempting something more custom, you might - [`matrix_sdk_base`]: A no-network-IO client state machine which can be used to embed a Matrix client into an existing network stack or to build a new Matrix client library on top. 
-- [`matrix_sdk_crypto`](https://docs.rs/matrix-sdk-crypto/*/matrix_sdk_crypto/): - A no-network-IO encryption state machine which can be used to add Matrix E2EE +- [`matrix_sdk_crypto`]: A no-network-IO encryption state machine which can be used to add Matrix E2EE support into an existing client or library. # Getting started @@ -91,5 +90,7 @@ The `RUST_LOG` variable also supports a more advanced syntax for filtering log output more precisely, for instance with crate-level granularity. For more information on this, check out the [tracing_subscriber documentation]. -[examples]: https://github.com/matrix-org/matrix-rust-sdk/tree/main/crates/matrix-sdk/examples +[examples]: https://github.com/matrix-org/matrix-rust-sdk/tree/main/examples/ [tracing_subscriber documentation]: https://tracing.rs/tracing_subscriber/filter/struct.envfilter +[`matrix_sdk_crypto`]: https://docs.rs/matrix-sdk-crypto/ +[`matrix_sdk_base`]: https://docs.rs/matrix-sdk-base/ \ No newline at end of file diff --git a/crates/matrix-sdk/src/account.rs b/crates/matrix-sdk/src/account.rs index 656292a44d4..9579106f16c 100644 --- a/crates/matrix-sdk/src/account.rs +++ b/crates/matrix-sdk/src/account.rs @@ -17,6 +17,7 @@ use matrix_sdk_base::{ media::{MediaFormat, MediaRequest}, store::StateStoreExt, + StateStoreDataKey, StateStoreDataValue, }; use mime::Mime; use ruma::{ @@ -133,9 +134,32 @@ impl Account { let config = Some(RequestConfig::new().force_auth()); let response = self.client.send(request, config).await?; + if let Some(url) = response.avatar_url.clone() { + // If an avatar is found cache it. + let _ = self + .client + .store() + .set_kv_data( + StateStoreDataKey::UserAvatarUrl(user_id), + StateStoreDataValue::UserAvatarUrl(url.to_string()), + ) + .await; + } else { + // If there is no avatar the user has removed it and we uncache it. 
+ let _ = + self.client.store().remove_kv_data(StateStoreDataKey::UserAvatarUrl(user_id)).await; + } Ok(response.avatar_url) } + /// Get the URL of the account's avatar, if is stored in cache. + pub async fn get_cached_avatar_url(&self) -> Result> { + let user_id = self.client.user_id().ok_or(Error::AuthenticationRequired)?; + let data = + self.client.store().get_kv_data(StateStoreDataKey::UserAvatarUrl(user_id)).await?; + Ok(data.map(|v| v.into_user_avatar_url().expect("Session data is not a user avatar url"))) + } + /// Set the MXC URI of the account's avatar. /// /// The avatar is unset if `url` is `None`. diff --git a/crates/matrix-sdk/src/attachment.rs b/crates/matrix-sdk/src/attachment.rs index ad5e039b510..3aa0eda9ede 100644 --- a/crates/matrix-sdk/src/attachment.rs +++ b/crates/matrix-sdk/src/attachment.rs @@ -293,43 +293,39 @@ impl Default for AttachmentConfig { /// # Examples /// /// ```no_run -/// # use std::{path::PathBuf, fs::File, io::{BufReader, Cursor, Read, Seek}}; -/// # use matrix_sdk::{ -/// # Client, -/// # attachment::{AttachmentConfig, Thumbnail, generate_image_thumbnail}, -/// # ruma::room_id -/// # }; +/// use std::{io::Cursor, path::PathBuf}; +/// +/// use matrix_sdk::attachment::{ +/// generate_image_thumbnail, AttachmentConfig, Thumbnail, +/// }; +/// use mime; +/// # use matrix_sdk::{Client, ruma::room_id }; /// # use url::Url; -/// # use mime; /// # use futures::executor::block_on; /// # block_on(async { /// # let homeserver = Url::parse("http://localhost:8080")?; /// # let mut client = Client::new(homeserver).await?; /// # let room_id = room_id!("!test:localhost"); /// let path = PathBuf::from("/home/example/my-cat.jpg"); -/// let mut image = BufReader::new(File::open(path)?); +/// let image = tokio::fs::read(path).await?; /// -/// let (thumbnail_data, thumbnail_info) = generate_image_thumbnail( -/// &mime::IMAGE_JPEG, -/// &mut image, -/// None -/// )?; -/// let mut cursor = Cursor::new(thumbnail_data); +/// let cursor = 
Cursor::new(&image); +/// let (thumbnail_data, thumbnail_info) = +/// generate_image_thumbnail(&mime::IMAGE_JPEG, cursor, None)?; /// let config = AttachmentConfig::with_thumbnail(Thumbnail { -/// reader: &mut cursor, -/// content_type: &mime::IMAGE_JPEG, +/// data: thumbnail_data, +/// content_type: mime::IMAGE_JPEG, /// info: Some(thumbnail_info), /// }); /// -/// image.rewind()?; -/// /// if let Some(room) = client.get_joined_room(&room_id) { /// room.send_attachment( /// "My favorite cat", /// &mime::IMAGE_JPEG, -/// &mut image, +/// image, /// config, -/// ).await?; +/// ) +/// .await?; /// } /// # anyhow::Ok(()) }); /// ``` diff --git a/crates/matrix-sdk/src/client/builder.rs b/crates/matrix-sdk/src/client/builder.rs index cc3fe02e2a9..a4ec7370a68 100644 --- a/crates/matrix-sdk/src/client/builder.rs +++ b/crates/matrix-sdk/src/client/builder.rs @@ -27,8 +27,14 @@ use ruma::{ OwnedServerName, ServerName, }; use thiserror::Error; +use tokio::sync::broadcast; #[cfg(not(target_arch = "wasm32"))] use tokio::sync::OnceCell; +use tracing::{ + debug, + field::{self, debug}, + instrument, span, Level, Span, +}; use url::Url; use super::{Client, ClientInner}; @@ -88,10 +94,19 @@ pub struct ClientBuilder { appservice_mode: bool, server_versions: Option>, handle_refresh_tokens: bool, + root_span: Span, } impl ClientBuilder { pub(crate) fn new() -> Self { + let root_span = span!( + Level::INFO, + "matrix-sdk", + user_id = field::Empty, + device_id = field::Empty, + ed25519_key = field::Empty + ); + Self { homeserver_cfg: None, http_cfg: None, @@ -101,6 +116,7 @@ impl ClientBuilder { appservice_mode: false, server_versions: None, handle_refresh_tokens: false, + root_span, } } @@ -303,7 +319,7 @@ impl ClientBuilder { /// is encountered, it means that the user needs to be logged in again. 
/// /// * The access token and refresh token need to be watched for changes, - /// using [`Client::session_tokens_signal()`] for example, to be able to + /// using [`Client::session_tokens_stream()`] for example, to be able to /// [restore the session] later. /// /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens @@ -326,8 +342,12 @@ impl ClientBuilder { /// server discovery request is made which can fail; if you didn't set /// [`server_versions(false)`][Self::server_versions], that amounts to /// another request that can fail + #[instrument(skip_all, parent = &self.root_span, target = "matrix_sdk::client", fields(homeserver))] pub async fn build(self) -> Result { + debug!("Starting to build the Client"); + let homeserver_cfg = self.homeserver_cfg.ok_or(ClientBuildError::MissingHomeserver)?; + Span::current().record("homeserver", debug(&homeserver_cfg)); let inner_http_client = match self.http_cfg.unwrap_or_default() { #[allow(unused_mut)] @@ -359,9 +379,13 @@ impl ClientBuilder { let http_client = HttpClient::new(inner_http_client.clone(), self.request_config); let mut authentication_issuer: Option = None; + #[cfg(feature = "experimental-sliding-sync")] + let mut sliding_sync_proxy: Option = None; let homeserver = match homeserver_cfg { HomeserverConfig::Url(url) => url, HomeserverConfig::ServerName(server_name) => { + debug!("Trying to discover the homeserver"); + let homeserver = homeserver_from_name(&server_name); let well_known = http_client .send( @@ -381,6 +405,11 @@ impl ClientBuilder { if let Some(issuer) = well_known.authentication.map(|auth| auth.issuer) { authentication_issuer = Url::parse(&issuer).ok(); } + #[cfg(feature = "experimental-sliding-sync")] + if let Some(proxy) = well_known.sliding_sync_proxy.map(|p| p.url) { + sliding_sync_proxy = Url::parse(&proxy).ok(); + } + debug!(homserver_url = well_known.homeserver.base_url, "Discovered the homeserver"); well_known.homeserver.base_url } @@ 
-388,10 +417,16 @@ impl ClientBuilder { let homeserver = RwLock::new(Url::parse(&homeserver)?); let authentication_issuer = authentication_issuer.map(RwLock::new); + #[cfg(feature = "experimental-sliding-sync")] + let sliding_sync_proxy = sliding_sync_proxy.map(RwLock::new); + + let (unknown_token_error_sender, _) = broadcast::channel(1); let inner = Arc::new(ClientInner { homeserver, authentication_issuer, + #[cfg(feature = "experimental-sliding-sync")] + sliding_sync_proxy, http_client, base_client, server_versions: OnceCell::new_with(self.server_versions), @@ -409,8 +444,12 @@ impl ClientBuilder { sync_beat: event_listener::Event::new(), handle_refresh_tokens: self.handle_refresh_tokens, refresh_token_lock: Mutex::new(Ok(())), + unknown_token_error_sender, + root_span: self.root_span, }); + debug!("Done building the Client"); + Ok(Client { inner }) } } @@ -474,6 +513,7 @@ enum BuilderStoreConfig { Custom(StoreConfig), } +#[cfg(not(tarpaulin_include))] impl fmt::Debug for BuilderStoreConfig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { #[allow(clippy::infallible_destructuring_match)] diff --git a/crates/matrix-sdk/src/client/login_builder.rs b/crates/matrix-sdk/src/client/login_builder.rs index 8bfc6130c91..ee83b5ed44e 100644 --- a/crates/matrix-sdk/src/client/login_builder.rs +++ b/crates/matrix-sdk/src/client/login_builder.rs @@ -22,6 +22,7 @@ use std::{ use ruma::{ api::client::{session::login, uiaa::UserIdentifier}, assign, + serde::JsonObject, }; use tracing::{info, instrument}; @@ -35,16 +36,20 @@ use crate::{config::RequestConfig, Result}; /// [the spec]: https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3login enum LoginMethod { /// Login type `m.login.password` - UserPassword { id: UserIdentifier, password: String }, + UserPassword { + id: UserIdentifier, + password: String, + }, /// Login type `m.token` Token(String), + Custom(login::v3::LoginInfo), } impl LoginMethod { fn id(&self) -> Option<&UserIdentifier> { match self 
{ LoginMethod::UserPassword { id, .. } => Some(id), - LoginMethod::Token(_) => None, + LoginMethod::Token(_) | LoginMethod::Custom(_) => None, } } @@ -52,6 +57,7 @@ impl LoginMethod { match self { LoginMethod::UserPassword { .. } => "identifier and password", LoginMethod::Token(_) => "token", + LoginMethod::Custom(_) => "custom", } } @@ -61,6 +67,7 @@ impl LoginMethod { login::v3::LoginInfo::Password(login::v3::Password::new(id, password)) } LoginMethod::Token(token) => login::v3::LoginInfo::Token(login::v3::Token::new(token)), + LoginMethod::Custom(login_info) => login_info, } } } @@ -98,6 +105,15 @@ impl LoginBuilder { Self::new(client, LoginMethod::Token(token)) } + pub(super) fn new_custom( + client: Client, + login_type: &str, + data: JsonObject, + ) -> serde_json::Result { + let login_info = login::v3::LoginInfo::new(login_type, data)?; + Ok(Self::new(client, LoginMethod::Custom(login_info))) + } + /// Set the device ID. /// /// The device ID is a unique ID that will be associated with this session. @@ -142,6 +158,7 @@ impl LoginBuilder { /// Instead of calling this function and `.await`ing its return value, you /// can also `.await` the `LoginBuilder` directly. #[instrument( + parent = &self.client.inner.root_span, target = "matrix_sdk::client", name = "login", skip_all, @@ -278,7 +295,13 @@ where /// /// Instead of calling this function and `.await`ing its return value, you /// can also `.await` the `SsoLoginBuilder` directly. 
- #[instrument(target = "matrix_sdk::client", name = "login", skip_all, fields(method = "sso"))] + #[instrument( + parent = &self.client.inner.root_span, + target = "matrix_sdk::client", + name = "login", + skip_all, + fields(method = "sso"), + )] pub async fn send(self) -> Result { use std::{ convert::Infallible, diff --git a/crates/matrix-sdk/src/client/mod.rs b/crates/matrix-sdk/src/client/mod.rs index 26cd2f2e81b..5d44f88a2da 100644 --- a/crates/matrix-sdk/src/client/mod.rs +++ b/crates/matrix-sdk/src/client/mod.rs @@ -15,6 +15,7 @@ // limitations under the License. use std::{ + collections::BTreeMap, fmt::{self, Debug}, future::Future, pin::Pin, @@ -24,11 +25,11 @@ use std::{ #[cfg(target_arch = "wasm32")] use async_once_cell::OnceCell; use dashmap::DashMap; -use futures_core::stream::Stream; -use futures_signals::signal::Signal; +use futures_core::Stream; +use futures_util::StreamExt; use matrix_sdk_base::{ - BaseClient, RoomType, SendOutsideWasm, Session, SessionMeta, SessionTokens, StateStore, - SyncOutsideWasm, + store::DynStateStore, BaseClient, RoomState, SendOutsideWasm, Session, SessionMeta, + SessionTokens, SyncOutsideWasm, }; use matrix_sdk_common::{ instant::Instant, @@ -41,7 +42,7 @@ use ruma::{ client::{ account::{register, whoami}, alias::get_alias, - device::{delete_devices, get_devices}, + device::{delete_devices, get_devices, update_device}, directory::{get_public_rooms, get_public_rooms_filtered}, discovery::{ get_capabilities::{self, Capabilities}, @@ -50,7 +51,7 @@ use ruma::{ error::ErrorKind, filter::{create_filter::v3::Request as FilterUploadRequest, FilterDefinition}, membership::{join_room_by_id, join_room_by_id_or_alias}, - push::get_notifications::v3::Notification, + push::{get_notifications::v3::Notification, set_pusher, Pusher}, room::create_room, session::{ get_login_types, login, logout, refresh_token, sso_login, sso_login_with_provider, @@ -61,15 +62,18 @@ use ruma::{ error::FromHttpResponseError, MatrixVersion, OutgoingRequest, 
SendAccessToken, }, - assign, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, - RoomOrAliasId, ServerName, UInt, UserId, + assign, + serde::JsonObject, + DeviceId, OwnedDeviceId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, + ServerName, UInt, UserId, }; use serde::de::DeserializeOwned; +use tokio::sync::broadcast; #[cfg(not(target_arch = "wasm32"))] use tokio::sync::OnceCell; #[cfg(feature = "e2e-encryption")] use tracing::error; -use tracing::{debug, info, instrument}; +use tracing::{debug, field::display, info, instrument, trace, Instrument, Span}; use url::Url; #[cfg(feature = "e2e-encryption")] @@ -122,6 +126,13 @@ pub enum LoopCtrl { Break, } +/// Wrapper struct for ErrorKind::UnknownToken +#[derive(Debug, Clone)] +pub struct UnknownToken { + /// Whether or not the session was soft logged out + pub soft_logout: bool, +} + /// An async/await enabled Matrix client. /// /// All of the state is held in an `Arc` so the `Client` can be cloned freely. @@ -135,6 +146,9 @@ pub(crate) struct ClientInner { homeserver: RwLock, /// The OIDC Provider that is trusted by the homeserver. authentication_issuer: Option>, + /// The sliding sync proxy that is trusted by the homeserver. + #[cfg(feature = "experimental-sliding-sync")] + sliding_sync_proxy: Option>, /// The underlying HTTP client. http_client: HttpClient, /// User session data. @@ -144,11 +158,11 @@ pub(crate) struct ClientInner { /// Locks making sure we only have one group session sharing request in /// flight per room. #[cfg(feature = "e2e-encryption")] - pub(crate) group_session_locks: DashMap>>, + pub(crate) group_session_locks: Mutex>>>, /// Lock making sure we're only doing one key claim request at a time. #[cfg(feature = "e2e-encryption")] pub(crate) key_claim_lock: Mutex<()>, - pub(crate) members_request_locks: DashMap>>, + pub(crate) members_request_locks: Mutex>>>, /// Locks for requests on the encryption state of rooms. 
pub(crate) encryption_state_request_locks: DashMap>>, pub(crate) typing_notice_times: DashMap, @@ -174,6 +188,11 @@ pub(crate) struct ClientInner { /// wait for the sync to get the data to fetch a room object from the state /// store. pub(crate) sync_beat: event_listener::Event, + /// Client API UnknownToken error publisher. Allows the subscriber logout + /// the user when any request fails because of an invalid access token + pub(crate) unknown_token_error_sender: broadcast::Sender, + /// Root span for `tracing`. + pub(crate) root_span: Span, } #[cfg(not(tarpaulin_include))] @@ -316,6 +335,13 @@ impl Client { Some(server.read().await.clone()) } + /// The sliding sync proxy that is trusted by the homeserver. + #[cfg(feature = "experimental-sliding-sync")] + pub async fn sliding_sync_proxy(&self) -> Option { + let server = self.inner.sliding_sync_proxy.as_ref()?; + Some(server.read().await.clone()) + } + fn session_meta(&self) -> Option<&SessionMeta> { self.base_client().session_meta() } @@ -340,7 +366,7 @@ impl Client { /// /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens pub fn session_tokens(&self) -> Option { - self.base_client().session_tokens().get_cloned() + self.base_client().session_tokens().get() } /// Get the current access token for this session. @@ -368,7 +394,7 @@ impl Client { self.session_tokens().and_then(|tokens| tokens.refresh_token) } - /// [`Signal`] to get notified when the current access token and optional + /// [`Stream`] to get notified when the current access token and optional /// refresh token for this session change. 
/// /// This can be used with [`Client::session()`] to persist the [`Session`] @@ -380,7 +406,7 @@ impl Client { /// # Example /// /// ```no_run - /// use futures_signals::signal::SignalExt; + /// use futures_util::StreamExt; /// use matrix_sdk::Client; /// # use matrix_sdk::Session; /// # use futures::executor::block_on; @@ -404,7 +430,7 @@ impl Client { /// persist_session(client.session()); /// /// // Handle when at least one of the tokens changed. - /// let future = client.session_tokens_changed_signal().for_each(move |_| { + /// let future = client.session_tokens_changed_stream().for_each(move |_| { /// let client = client.clone(); /// async move { /// persist_session(client.session()); @@ -417,15 +443,12 @@ impl Client { /// ``` /// /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens - pub fn session_tokens_changed_signal(&self) -> impl Signal { - self.base_client().session_tokens().signal_ref(|_| ()) + pub fn session_tokens_changed_stream(&self) -> impl Stream { + self.session_tokens_stream().map(|_| ()) } - /// Get the current access token and optional refresh token for this - /// session as a [`Signal`]. - /// - /// This can be used to watch changes of the tokens by calling methods like - /// `for_each()` or `to_stream()`. + /// Get changes to the access token and optional refresh token for this + /// session as a [`Stream`]. /// /// The value will be `None` if the client has not been logged in. /// @@ -436,7 +459,6 @@ impl Client { /// /// ```no_run /// use futures::StreamExt; - /// use futures_signals::signal::SignalExt; /// use matrix_sdk::Client; /// # use matrix_sdk::Session; /// # use futures::executor::block_on; @@ -461,7 +483,7 @@ impl Client { /// persist_session(&session); /// /// // Handle when at least one of the tokens changed. 
- /// let mut tokens_stream = client.session_tokens_signal().to_stream(); + /// let mut tokens_stream = client.session_tokens_stream(); /// loop { /// if let Some(tokens) = tokens_stream.next().await.flatten() { /// session.access_token = tokens.access_token; @@ -478,8 +500,8 @@ impl Client { /// ``` /// /// [refreshing access tokens]: https://spec.matrix.org/v1.3/client-server-api/#refreshing-access-tokens - pub fn session_tokens_signal(&self) -> impl Signal> { - self.base_client().session_tokens().signal_cloned() + pub fn session_tokens_stream(&self) -> impl Stream> { + self.base_client().session_tokens() } /// Get the whole session info of this client. @@ -493,7 +515,7 @@ impl Client { } /// Get a reference to the state store. - pub fn store(&self) -> &dyn StateStore { + pub fn store(&self) -> &DynStateStore { self.base_client().store() } @@ -557,6 +579,7 @@ impl Client { /// push_rules::PushRulesEvent, /// room::{message::SyncRoomMessageEvent, topic::SyncRoomTopicEvent}, /// }, + /// push::Action, /// Int, MilliSecondsSinceUnixEpoch, /// }, /// Client, @@ -584,6 +607,16 @@ impl Client { /// } /// }, /// ); + /// client.add_event_handler( + /// |ev: SyncRoomMessageEvent, room: Room, push_actions: Vec| { + /// async move { + /// // A `Vec` parameter allows you to know which push actions + /// // are applicable for an event. For example, an event with + /// // `Action::SetTweak(Tweak::Highlight(true))` should be highlighted + /// // in the timeline. + /// } + /// }, + /// ); /// client.add_event_handler(|ev: SyncRoomTopicEvent| async move { /// // You can omit any or all arguments after the first. /// }); @@ -1006,6 +1039,50 @@ impl Client { LoginBuilder::new_password(self.clone(), id, password.to_owned()) } + /// Login to the server with a custom login type + /// + /// # Arguments + /// + /// * `login_type` - Identifier of the custom login type, e.g. 
+ /// `org.matrix.login.jwt` + /// + /// * `data` - The additional data which should be attached to the login + /// request. + /// + /// ```no_run + /// # use futures::executor::block_on; + /// # use url::Url; + /// # let homeserver = Url::parse("http://example.com").unwrap(); + /// # block_on(async { + /// use matrix_sdk::Client; + /// + /// let client = Client::new(homeserver).await?; + /// let user = "example"; + /// + /// let response = client + /// .login_custom( + /// "org.matrix.login.jwt", + /// [("token".to_owned(), "jwt_token_content".into())] + /// .into_iter() + /// .collect(), + /// )? + /// .initial_device_display_name("My bot") + /// .await?; + /// + /// println!( + /// "Logged in as {user}, got device_id {} and access_token {}", + /// response.device_id, response.access_token, + /// ); + /// # anyhow::Ok(()) }); + /// ``` + pub fn login_custom( + &self, + login_type: &str, + data: JsonObject, + ) -> serde_json::Result { + LoginBuilder::new_custom(self.clone(), login_type, data) + } + /// Login to the server with a token. 
/// /// This token is usually received in the SSO flow after following the URL @@ -1142,6 +1219,16 @@ impl Client { } } + self.inner + .root_span + .record("user_id", display(&response.user_id)) + .record("device_id", display(&response.device_id)); + + #[cfg(feature = "e2e-encryption")] + if let Some(key) = self.encryption().ed25519_key().await { + self.inner.root_span.record("ed25519_key", key); + } + self.inner.base_client.receive_login_response(response).await?; Ok(()) @@ -1206,10 +1293,28 @@ impl Client { /// ``` /// /// [`login`]: #method.login + #[instrument(skip_all, parent = &self.inner.root_span)] pub async fn restore_session(&self, session: Session) -> Result<()> { + debug!("Restoring session"); + let (meta, tokens) = session.into_parts(); + + self.inner + .root_span + .record("user_id", display(&meta.user_id)) + .record("device_id", display(&meta.device_id)); + self.base_client().set_session_tokens(tokens); - Ok(self.base_client().set_session_meta(meta).await?) + self.base_client().set_session_meta(meta).await?; + + #[cfg(feature = "e2e-encryption")] + if let Some(key) = self.encryption().ed25519_key().await { + self.inner.root_span.record("ed25519_key", key); + } + + debug!("Done restoring session"); + + Ok(()) } /// Refresh the access token. 
@@ -1391,7 +1496,7 @@ impl Client { /// client.register(request).await; /// # }) /// ``` - #[instrument(skip_all)] + #[instrument(skip_all, parent = &self.inner.root_span)] pub async fn register( &self, request: register::v3::Request, @@ -1454,7 +1559,7 @@ impl Client { /// /// let response = client.sync_once(sync_settings).await.unwrap(); /// # }); - #[instrument(skip(self, definition))] + #[instrument(skip(self, definition), parent = &self.inner.root_span)] pub async fn get_or_upload_filter( &self, filter_name: &str, @@ -1590,7 +1695,7 @@ impl Client { let response = self.send(request, None).await?; let base_room = - self.base_client().get_or_create_room(&response.room_id, RoomType::Joined).await; + self.base_client().get_or_create_room(&response.room_id, RoomState::Joined).await; Ok(room::Joined::new(self, base_room).unwrap()) } @@ -1763,7 +1868,8 @@ impl Client { None => self.homeserver().await.to_string(), }; - self.inner + let response = self + .inner .http_client .send( request, @@ -1773,7 +1879,20 @@ impl Client { self.user_id(), self.server_versions().await?, ) - .await + .await; + + if let Err(http_error) = &response { + if let Some(ErrorKind::UnknownToken { soft_logout }) = + http_error.client_api_error_kind() + { + _ = self + .inner + .unknown_token_error_sender + .send(UnknownToken { soft_logout: *soft_logout }); + } + } + + response } async fn request_server_versions(&self) -> HttpResult> { @@ -1891,6 +2010,26 @@ impl Client { self.send(request, None).await } + /// Change the display name of a device owned by the current user. + /// + /// Returns a `update_device::Response` which specifies the result + /// of the operation. + /// + /// # Arguments + /// + /// * `device_id` - The ID of the device to change the display name of. + /// * `display_name` - The new display name to set. 
+ pub async fn rename_device( + &self, + device_id: &DeviceId, + display_name: &str, + ) -> HttpResult { + let mut request = update_device::v3::Request::new(device_id.to_owned()); + request.display_name = Some(display_name.to_owned()); + + self.send(request, None).await + } + /// Synchronize the client's state with the latest state on the server. /// /// ## Syncing Events @@ -2147,7 +2286,7 @@ impl Client { /// .await; /// }) /// ``` - #[instrument(skip(self, callback))] + #[instrument(skip_all, parent = &self.inner.root_span)] pub async fn sync_with_callback( &self, sync_settings: crate::config::SyncSettings, @@ -2244,11 +2383,15 @@ impl Client { } loop { + trace!("Syncing"); let result = self.sync_loop_helper(&mut sync_settings).await; + trace!("Running callback"); if callback(result).await? == LoopCtrl::Break { + trace!("Callback told us to stop"); break; } + trace!("Done running callback"); Client::delay_sync(&mut last_sync_time).await } @@ -2298,7 +2441,8 @@ impl Client { /// /// # anyhow::Ok(()) }); /// ``` - #[instrument(skip(self))] + #[allow(unknown_lints, clippy::let_with_type_underscore)] // triggered by instrument macro + #[instrument(skip(self), parent = &self.inner.root_span)] pub async fn sync_stream( &self, mut sync_settings: crate::config::SyncSettings, @@ -2309,9 +2453,11 @@ impl Client { sync_settings.token = self.sync_token().await; } + let parent_span = Span::current(); + async_stream::stream! 
{ loop { - yield self.sync_loop_helper(&mut sync_settings).await; + yield self.sync_loop_helper(&mut sync_settings).instrument(parent_span.clone()).await; Client::delay_sync(&mut last_sync_time).await } @@ -2335,6 +2481,18 @@ impl Client { let request = logout::v3::Request::new(); self.send(request, None).await } + + /// Subscribes a new receiver to client UnknownToken errors + pub fn subscribe_to_unknown_token_errors(&self) -> broadcast::Receiver { + let broadcast = &self.inner.unknown_token_error_sender; + broadcast.subscribe() + } + + /// Sets a given pusher + pub async fn set_pusher(&self, pusher: Pusher) -> HttpResult { + let request = set_pusher::v3::Request::post(pusher); + self.send(request, None).await + } } // The http mocking library is not supported for wasm32 diff --git a/crates/matrix-sdk/src/config/sync.rs b/crates/matrix-sdk/src/config/sync.rs index b44225d5f80..ef5cb8b9940 100644 --- a/crates/matrix-sdk/src/config/sync.rs +++ b/crates/matrix-sdk/src/config/sync.rs @@ -34,6 +34,7 @@ impl Default for SyncSettings { } } +#[cfg(not(tarpaulin_include))] impl fmt::Debug for SyncSettings { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut s = f.debug_struct("SyncSettings"); diff --git a/crates/matrix-sdk/src/encryption/identities/mod.rs b/crates/matrix-sdk/src/encryption/identities/mod.rs index bb154f883e3..966acd19b29 100644 --- a/crates/matrix-sdk/src/encryption/identities/mod.rs +++ b/crates/matrix-sdk/src/encryption/identities/mod.rs @@ -90,7 +90,7 @@ mod devices; mod users; pub use devices::{Device, UserDevices}; -pub use matrix_sdk_base::crypto::MasterPubkey; +pub use matrix_sdk_base::crypto::types::MasterPubkey; pub use users::UserIdentity; /// Error for the manual verification step, when we manually sign users or diff --git a/crates/matrix-sdk/src/encryption/identities/users.rs b/crates/matrix-sdk/src/encryption/identities/users.rs index 711a6416e98..d4f02c67dcc 100644 --- a/crates/matrix-sdk/src/encryption/identities/users.rs +++ 
b/crates/matrix-sdk/src/encryption/identities/users.rs @@ -16,7 +16,8 @@ use std::sync::Arc; use matrix_sdk_base::{ crypto::{ - MasterPubkey, OwnUserIdentity as InnerOwnUserIdentity, UserIdentity as InnerUserIdentity, + types::MasterPubkey, OwnUserIdentity as InnerOwnUserIdentity, + UserIdentity as InnerUserIdentity, }, locks::RwLock, }; diff --git a/crates/matrix-sdk/src/encryption/mod.rs b/crates/matrix-sdk/src/encryption/mod.rs index 0f0755379d6..dea2e59aa17 100644 --- a/crates/matrix-sdk/src/encryption/mod.rs +++ b/crates/matrix-sdk/src/encryption/mod.rs @@ -33,7 +33,7 @@ pub use matrix_sdk_base::crypto::{ }, vodozemac, CryptoStoreError, DecryptorError, EventError, KeyExportError, LocalTrust, MediaEncryptionInfo, MegolmError, OlmError, RoomKeyImportResult, SecretImportError, - SessionCreationError, SignatureError, + SessionCreationError, SignatureError, VERSION, }; use matrix_sdk_base::crypto::{ CrossSigningStatus, OutgoingRequest, RoomMessageRequest, ToDeviceRequest, @@ -108,6 +108,66 @@ impl Client { Ok(response) } + /// Construct a [`EncryptedFile`][ruma::events::room::EncryptedFile] by + /// encrypting and uploading a provided reader. + /// + /// # Arguments + /// * `content_type` - The content type of the file. + /// * `reader` - The reader that should be encrypted and uploaded. 
+ /// + /// # Example + /// ```no_run + /// # use futures::executor::block_on; + /// # use matrix_sdk::Client; + /// # use url::Url; + /// # use matrix_sdk::ruma::{room_id, OwnedRoomId}; + /// use serde::{Deserialize, Serialize}; + /// use matrix_sdk::ruma::events::macros::EventContent; + /// + /// #[derive(Clone, Debug, Deserialize, Serialize, EventContent)] + /// #[ruma_event(type = "com.example.custom", kind = MessageLike)] + /// struct CustomEventContent { + /// encrypted_file: matrix_sdk::ruma::events::room::EncryptedFile, + /// } + /// # block_on(async { + /// # let homeserver = Url::parse("http://example.com")?; + /// # let client = Client::new(homeserver).await?; + /// # let room = client.get_joined_room(&room_id!("!test:example.com")).unwrap(); + /// + /// let mut reader = std::io::Cursor::new(b"Hello, world!"); + /// let encrypted_file = client.prepare_encrypted_file(&mime::TEXT_PLAIN, &mut reader).await?; + /// + /// room.send(CustomEventContent { encrypted_file }, None).await?; + /// # anyhow::Ok(()) }); + /// ``` + #[cfg(feature = "e2e-encryption")] + pub async fn prepare_encrypted_file<'a, R: Read + ?Sized + 'a>( + &self, + content_type: &mime::Mime, + reader: &'a mut R, + ) -> Result { + let mut encryptor = matrix_sdk_base::crypto::AttachmentEncryptor::new(reader); + + let mut buf = Vec::new(); + encryptor.read_to_end(&mut buf)?; + + let response = self.media().upload(content_type, buf).await?; + + let file: ruma::events::room::EncryptedFile = { + let keys = encryptor.finish(); + ruma::events::room::EncryptedFileInit { + url: response.content_uri, + key: keys.key, + iv: keys.iv, + hashes: keys.hashes, + v: keys.version, + } + .into() + }; + + Ok(file) + } + /// Encrypt and upload the file to be read from `reader` and construct an /// attachment message with `body`, `content_type`, `info` and `thumbnail`. 
#[cfg(feature = "e2e-encryption")] @@ -121,25 +181,8 @@ impl Client { ) -> Result { let (thumbnail_source, thumbnail_info) = if let Some(thumbnail) = thumbnail { let mut cursor = Cursor::new(thumbnail.data); - let mut encryptor = matrix_sdk_base::crypto::AttachmentEncryptor::new(&mut cursor); - - let mut buf = Vec::new(); - encryptor.read_to_end(&mut buf)?; - - let response = self.media().upload(&thumbnail.content_type, buf).await?; - - let file: ruma::events::room::EncryptedFile = { - let keys = encryptor.finish(); - ruma::events::room::EncryptedFileInit { - url: response.content_uri, - key: keys.key, - iv: keys.iv, - hashes: keys.hashes, - v: keys.version, - } - .into() - }; + let file = self.prepare_encrypted_file(content_type, &mut cursor).await?; use ruma::events::room::ThumbnailInfo; #[rustfmt::skip] @@ -154,23 +197,7 @@ impl Client { }; let mut cursor = Cursor::new(data); - let mut encryptor = matrix_sdk_base::crypto::AttachmentEncryptor::new(&mut cursor); - let mut buf = Vec::new(); - encryptor.read_to_end(&mut buf)?; - - let response = self.media().upload(content_type, buf).await?; - - let file: ruma::events::room::EncryptedFile = { - let keys = encryptor.finish(); - ruma::events::room::EncryptedFileInit { - url: response.content_uri, - key: keys.key, - iv: keys.iv, - hashes: keys.hashes, - v: keys.version, - } - .into() - }; + let file = self.prepare_encrypted_file(content_type, &mut cursor).await?; use std::io::Cursor; diff --git a/crates/matrix-sdk/src/error.rs b/crates/matrix-sdk/src/error.rs index 687c14867c6..b73275885b0 100644 --- a/crates/matrix-sdk/src/error.rs +++ b/crates/matrix-sdk/src/error.rs @@ -171,6 +171,10 @@ pub enum Error { #[error("the queried endpoint requires authentication but was called before logging in")] AuthenticationRequired, + /// This request failed because the local data wasn't sufficient. 
+ #[error("Local cache doesn't contain all necessary data to perform the action.")] + InsufficientData, + /// Attempting to restore a session after the olm-machine has already been /// set up fails #[cfg(feature = "e2e-encryption")] diff --git a/crates/matrix-sdk/src/event_handler/context.rs b/crates/matrix-sdk/src/event_handler/context.rs index 0b4dd5fe093..7e02d54179d 100644 --- a/crates/matrix-sdk/src/event_handler/context.rs +++ b/crates/matrix-sdk/src/event_handler/context.rs @@ -16,6 +16,7 @@ use std::ops::Deref; use matrix_sdk_base::deserialized_responses::EncryptionInfo; +use ruma::push::Action; use serde_json::value::RawValue as RawJsonValue; use super::{EventHandlerData, EventHandlerHandle}; @@ -81,6 +82,12 @@ impl EventHandlerContext for Option { } } +impl EventHandlerContext for Vec { + fn from_data(data: &EventHandlerData<'_>) -> Option { + Some(data.push_actions.to_owned()) + } +} + /// A custom value registered with /// [`.add_event_handler_context`][Client::add_event_handler_context]. 
#[derive(Debug)] diff --git a/crates/matrix-sdk/src/event_handler/mod.rs b/crates/matrix-sdk/src/event_handler/mod.rs index b9bdfec214c..01742de947a 100644 --- a/crates/matrix-sdk/src/event_handler/mod.rs +++ b/crates/matrix-sdk/src/event_handler/mod.rs @@ -50,7 +50,7 @@ use matrix_sdk_base::{ deserialized_responses::{EncryptionInfo, SyncTimelineEvent}, SendOutsideWasm, SyncOutsideWasm, }; -use ruma::{events::AnySyncStateEvent, serde::Raw, OwnedRoomId}; +use ruma::{events::AnySyncStateEvent, push::Action, serde::Raw, OwnedRoomId}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::value::RawValue as RawJsonValue; use tracing::{debug, error, field::debug, instrument, warn}; @@ -234,6 +234,7 @@ pub struct EventHandlerData<'a> { room: Option, raw: &'a RawJsonValue, encryption_info: Option<&'a EncryptionInfo>, + push_actions: &'a [Action], handle: EventHandlerHandle, } @@ -338,7 +339,7 @@ impl Client { for raw_event in events { let event_type = raw_event.deserialize_as::>()?.event_type; - self.call_event_handlers(room, raw_event.json(), kind, &event_type, None).await; + self.call_event_handlers(room, raw_event.json(), kind, &event_type, None, &[]).await; } Ok(()) @@ -365,7 +366,8 @@ impl Client { let redacted = unsigned.and_then(|u| u.redacted_because).is_some(); let handler_kind = HandlerKind::state_redacted(redacted); - self.call_event_handlers(room, raw_event.json(), handler_kind, &event_type, None).await; + self.call_event_handlers(room, raw_event.json(), handler_kind, &event_type, None, &[]) + .await; } Ok(()) @@ -396,18 +398,41 @@ impl Client { let raw_event = item.event.json(); let encryption_info = item.encryption_info.as_ref(); + let push_actions = &item.push_actions; // Event handlers for possibly-redacted timeline events - self.call_event_handlers(room, raw_event, handler_kind_g, &event_type, encryption_info) - .await; + self.call_event_handlers( + room, + raw_event, + handler_kind_g, + &event_type, + encryption_info, + push_actions, + ) + 
.await; // Event handlers specifically for redacted OR unredacted timeline events - self.call_event_handlers(room, raw_event, handler_kind_r, &event_type, encryption_info) - .await; + self.call_event_handlers( + room, + raw_event, + handler_kind_r, + &event_type, + encryption_info, + push_actions, + ) + .await; // Event handlers for `AnySyncTimelineEvent` let kind = HandlerKind::Timeline; - self.call_event_handlers(room, raw_event, kind, &event_type, encryption_info).await; + self.call_event_handlers( + room, + raw_event, + kind, + &event_type, + encryption_info, + push_actions, + ) + .await; } Ok(()) @@ -421,6 +446,7 @@ impl Client { event_kind: HandlerKind, event_type: &str, encryption_info: Option<&EncryptionInfo>, + push_actions: &[Action], ) { let room_id = room.as_ref().map(|r| r.room_id()); if let Some(room_id) = room_id { @@ -441,6 +467,7 @@ impl Client { room: room.clone(), raw, encryption_info, + push_actions, handle, }; diff --git a/crates/matrix-sdk/src/http_client.rs b/crates/matrix-sdk/src/http_client.rs index 1cc19f986e7..1656cb839c8 100644 --- a/crates/matrix-sdk/src/http_client.rs +++ b/crates/matrix-sdk/src/http_client.rs @@ -319,8 +319,7 @@ impl HttpClient { )?; let request_size = ByteSize(request.body().len().try_into().unwrap_or(u64::MAX)); - span.record("path", request.uri().path()) - .record("request_size", request_size.to_string_as(true)); + span.record("request_size", request_size.to_string_as(true)); // Since sliding sync is experimental, and the proxy might not do what we expect // it to do given a specific request body, it's useful to log the @@ -329,8 +328,14 @@ impl HttpClient { #[cfg(feature = "experimental-sliding-sync")] if type_name::() == "ruma_client_api::sync::sync_events::v4::Request" { span.record("request_body", debug(request.body())); + span.record("path", request.uri().path_and_query().map(|p| p.as_str())); + } else { + span.record("path", request.uri().path()); } + #[cfg(not(feature = "experimental-sliding-sync"))] + 
span.record("path", request.uri().path()); + debug!("Sending request"); match self.send_request::(request, config).await { Ok((status_code, response_size, response)) => { diff --git a/crates/matrix-sdk/src/lib.rs b/crates/matrix-sdk/src/lib.rs index c03466e1a12..d5bcafee35d 100644 --- a/crates/matrix-sdk/src/lib.rs +++ b/crates/matrix-sdk/src/lib.rs @@ -20,7 +20,7 @@ pub use async_trait::async_trait; pub use bytes; pub use matrix_sdk_base::{ deserialized_responses, DisplayName, Room as BaseRoom, RoomInfo, RoomMember as BaseRoomMember, - RoomType, Session, StateChanges, StoreError, + RoomState, Session, StateChanges, StoreError, }; pub use matrix_sdk_common::*; pub use reqwest; @@ -39,7 +39,7 @@ pub mod room; pub mod sync; #[cfg(feature = "experimental-sliding-sync")] -mod sliding_sync; +pub mod sliding_sync; #[cfg(feature = "e2e-encryption")] pub mod encryption; @@ -49,16 +49,17 @@ mod events; pub use account::Account; #[cfg(feature = "sso-login")] pub use client::SsoLoginBuilder; -pub use client::{Client, ClientBuildError, ClientBuilder, LoginBuilder, LoopCtrl}; +pub use client::{Client, ClientBuildError, ClientBuilder, LoginBuilder, LoopCtrl, UnknownToken}; #[cfg(feature = "image-proc")] pub use error::ImageError; pub use error::{Error, HttpError, HttpResult, RefreshTokenError, Result, RumaApiError}; pub use http_client::HttpSend; pub use media::Media; +pub use ruma::{IdParseError, OwnedServerName, ServerName}; #[cfg(feature = "experimental-sliding-sync")] pub use sliding_sync::{ - RoomListEntry, SlidingSync, SlidingSyncBuilder, SlidingSyncMode, SlidingSyncRoom, - SlidingSyncState, SlidingSyncView, SlidingSyncViewBuilder, UpdateSummary, + RoomListEntry, SlidingSync, SlidingSyncBuilder, SlidingSyncList, SlidingSyncListBuilder, + SlidingSyncMode, SlidingSyncRoom, SlidingSyncState, UpdateSummary, }; #[cfg(any(test, feature = "testing"))] @@ -73,3 +74,39 @@ fn init_logging() { .with(tracing_subscriber::fmt::layer().with_test_writer()) .init(); } + +/// Creates a 
server name from a user supplied string. The string is first +/// sanitized by removing whitespace, the http(s) scheme and any trailing +/// slashes before being parsed. +pub fn sanitize_server_name(s: &str) -> Result { + ServerName::parse( + s.trim().trim_start_matches("http://").trim_start_matches("https://").trim_end_matches('/'), + ) +} + +#[cfg(test)] +mod tests { + use assert_matches::assert_matches; + + use crate::sanitize_server_name; + + #[test] + fn test_sanitize_server_name() { + assert_eq!(sanitize_server_name("matrix.org").unwrap().as_str(), "matrix.org"); + assert_eq!(sanitize_server_name("https://matrix.org").unwrap().as_str(), "matrix.org"); + assert_eq!(sanitize_server_name("http://matrix.org").unwrap().as_str(), "matrix.org"); + assert_eq!( + sanitize_server_name("https://matrix.server.org").unwrap().as_str(), + "matrix.server.org" + ); + assert_eq!( + sanitize_server_name("https://matrix.server.org/").unwrap().as_str(), + "matrix.server.org" + ); + assert_eq!( + sanitize_server_name(" https://matrix.server.org// ").unwrap().as_str(), + "matrix.server.org" + ); + assert_matches!(sanitize_server_name("https://matrix.server.org/something"), Err(_)) + } +} diff --git a/crates/matrix-sdk/src/media.rs b/crates/matrix-sdk/src/media.rs index cd7e5a4e12c..972f4b27342 100644 --- a/crates/matrix-sdk/src/media.rs +++ b/crates/matrix-sdk/src/media.rs @@ -17,16 +17,24 @@ #[cfg(feature = "e2e-encryption")] use std::io::Read; +#[cfg(not(target_arch = "wasm32"))] +use std::path::Path; use std::time::Duration; pub use matrix_sdk_base::media::*; use mime::Mime; +#[cfg(not(target_arch = "wasm32"))] +use mime_guess; use ruma::{ api::client::media::{create_content, get_content, get_content_thumbnail}, assign, events::room::MediaSource, MxcUri, }; +#[cfg(not(target_arch = "wasm32"))] +use tempfile::{Builder as TempFileBuilder, NamedTempFile}; +#[cfg(not(target_arch = "wasm32"))] +use tokio::{fs::File as TokioFile, io::AsyncWriteExt}; use crate::{ 
attachment::{AttachmentInfo, Thumbnail}, @@ -45,6 +53,23 @@ pub struct Media { client: Client, } +/// A file handle that takes ownership of a media file on disk. When the handle +/// is dropped, the file will be removed from the disk. +#[derive(Debug)] +#[cfg(not(target_arch = "wasm32"))] +pub struct MediaFileHandle { + /// The temporary file that contains the media. + file: NamedTempFile, +} + +#[cfg(not(target_arch = "wasm32"))] +impl MediaFileHandle { + /// Get the media file's path. + pub fn path(&self) -> &Path { + self.file.path() + } +} + impl Media { pub(crate) fn new(client: Client) -> Self { Self { client } @@ -96,6 +121,43 @@ impl Media { Ok(self.client.send(request, Some(request_config)).await?) } + /// Gets a media file by copying it to a temporary location on disk. + /// + /// The file won't be encrypted even if it is encrypted on the server. + /// + /// Returns a `MediaFileHandle` which takes ownership of the file. When the + /// handle is dropped, the file will be deleted from the temporary location. + /// + /// # Arguments + /// + /// * `request` - The `MediaRequest` of the content. + /// + /// * `content_type` - The type of the media, this will be used to set the + /// temporary file's extension. + /// + /// * `use_cache` - If we should use the media cache for this request. + #[cfg(not(target_arch = "wasm32"))] + pub async fn get_media_file( + &self, + request: &MediaRequest, + content_type: &Mime, + use_cache: bool, + ) -> Result { + let data = self.get_media_content(request, use_cache).await?; + + let mut suffix = String::from(""); + if let Some(extension) = + mime_guess::get_mime_extensions(content_type).and_then(|a| a.first()) + { + suffix = String::from(".") + extension; + } + + let file = TempFileBuilder::new().suffix(&suffix).tempfile()?; + TokioFile::from_std(file.reopen()?).write_all(&data).await?; + + Ok(MediaFileHandle { file }) + } + /// Get a media file's content. 
/// /// If the content is encrypted and encryption is enabled, the content will diff --git a/crates/matrix-sdk/src/room/common.rs b/crates/matrix-sdk/src/room/common.rs index e307c4b8677..aab0694b3c1 100644 --- a/crates/matrix-sdk/src/room/common.rs +++ b/crates/matrix-sdk/src/room/common.rs @@ -28,9 +28,12 @@ use ruma::{ assign, events::{ direct::DirectEventContent, + push_rules::PushRulesEventContent, + receipt::{Receipt, ReceiptThread, ReceiptType}, room::{ encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility, - server_acl::RoomServerAclEventContent, MediaSource, + power_levels::RoomPowerLevelsEventContent, server_acl::RoomServerAclEventContent, + MediaSource, }, tag::{TagInfo, TagName}, AnyRoomAccountDataEvent, AnyStateEvent, AnySyncStateEvent, EmptyStateKey, RedactContent, @@ -38,8 +41,10 @@ use ruma::{ RoomAccountDataEventType, StateEventType, StaticEventContent, StaticStateEventContent, SyncStateEvent, }, + push::{PushConditionRoomCtx, Ruleset}, serde::Raw, - uint, EventId, MatrixToUri, MatrixUri, OwnedEventId, OwnedServerName, RoomId, UInt, UserId, + uint, EventId, MatrixToUri, MatrixUri, OwnedEventId, OwnedServerName, OwnedUserId, RoomId, + UInt, UserId, }; use serde::de::DeserializeOwned; @@ -49,7 +54,7 @@ use super::Joined; use crate::{ event_handler::{EventHandler, EventHandlerHandle, SyncEvent}, media::{MediaFormat, MediaRequest}, - room::{Left, RoomMember, RoomType}, + room::{Left, RoomMember, RoomState}, BaseRoom, Client, Error, HttpError, HttpResult, Result, }; @@ -206,11 +211,7 @@ impl Common { start: http_response.start, end: http_response.end, #[cfg(not(feature = "e2e-encryption"))] - chunk: http_response - .chunk - .into_iter() - .map(|event| TimelineEvent { event, encryption_info: None }) - .collect(), + chunk: http_response.chunk.into_iter().map(TimelineEvent::new).collect(), #[cfg(feature = "e2e-encryption")] chunk: Vec::with_capacity(http_response.chunk.len()), state: http_response.state, @@ -226,21 +227,30 @@ impl 
Common { if let Ok(event) = machine.decrypt_room_event(event.cast_ref(), room_id).await { event } else { - TimelineEvent { event, encryption_info: None } + TimelineEvent::new(event) } } else { - TimelineEvent { event, encryption_info: None } + TimelineEvent::new(event) }; response.chunk.push(decrypted_event); } } else { - response.chunk.extend( - http_response - .chunk - .into_iter() - .map(|event| TimelineEvent { event, encryption_info: None }), - ); + response.chunk.extend(http_response.chunk.into_iter().map(TimelineEvent::new)); + } + + if let Some(push_context) = self.push_context().await? { + let push_rules = self + .client() + .account() + .account_data::() + .await? + .and_then(|r| r.deserialize().ok().map(|r| r.global)) + .unwrap_or_else(|| Ruleset::server_default(self.own_user_id())); + + for event in &mut response.chunk { + event.push_actions = push_rules.get_actions(&event.event, &push_context).to_owned(); + } } Ok(response) @@ -270,7 +280,7 @@ impl Common { /// independent events. #[cfg(feature = "experimental-timeline")] pub async fn timeline(&self) -> Timeline { - Timeline::new(self).with_fully_read_tracking().await + Timeline::builder(self).track_read_marker_and_receipts().build().await } /// Fetch the event with the given `EventId` in this room. @@ -289,30 +299,29 @@ impl Common { return Ok(event); } } - Ok(TimelineEvent { event, encryption_info: None }) + Ok(TimelineEvent::new(event)) } #[cfg(not(feature = "e2e-encryption"))] - Ok(TimelineEvent { event, encryption_info: None }) + Ok(TimelineEvent::new(event)) } pub(crate) async fn request_members(&self) -> Result> { - if let Some(mutex) = - self.client.inner.members_request_locks.get(self.inner.room_id()).map(|m| m.clone()) - { + let mut map = self.client.inner.members_request_locks.lock().await; + + if let Some(mutex) = map.get(self.inner.room_id()).cloned() { // If a member request is already going on, await the release of // the lock. 
+ drop(map); _ = mutex.lock().await; Ok(None) } else { let mutex = Arc::new(Mutex::new(())); - self.client - .inner - .members_request_locks - .insert(self.inner.room_id().to_owned(), mutex.clone()); + map.insert(self.inner.room_id().to_owned(), mutex.clone()); let _guard = mutex.lock().await; + drop(map); let request = get_member_events::v3::Request::new(self.inner.room_id().to_owned()); let response = self.client.send(request, None).await?; @@ -320,7 +329,7 @@ impl Common { let response = self.client.base_client().receive_members(self.inner.room_id(), &response).await?; - self.client.inner.members_request_locks.remove(self.inner.room_id()); + self.client.inner.members_request_locks.lock().await.remove(self.inner.room_id()); Ok(Some(response)) } @@ -390,20 +399,20 @@ impl Common { } } - pub(crate) async fn ensure_members(&self) -> Result<()> { + pub(crate) async fn ensure_members(&self) -> Result> { if !self.are_events_visible() { - return Ok(()); + return Ok(None); } if !self.are_members_synced() { - self.request_members().await?; + self.request_members().await + } else { + Ok(None) } - - Ok(()) } fn are_events_visible(&self) -> bool { - if let RoomType::Invited = self.inner.room_type() { + if let RoomState::Invited = self.inner.state() { return matches!( self.inner.history_visibility(), HistoryVisibility::WorldReadable | HistoryVisibility::Invited @@ -416,9 +425,10 @@ impl Common { /// Sync the member list with the server. /// /// This method will de-duplicate requests if it is called multiple times in - /// quick succession, in that case the return value will be `None`. + /// quick succession, in that case the return value will be `None`. This + /// method does nothing if the members are already synced. pub async fn sync_members(&self) -> Result> { - self.request_members().await + self.ensure_members().await } /// Get active members for this room, includes invited, joined members. 
@@ -973,6 +983,85 @@ impl Common { let via = self.route().await?; Ok(self.room_id().matrix_event_uri_via(event_id, via)) } + + /// Get the latest receipt of a user in this room. + /// + /// # Arguments + /// + /// * `receipt_type` - The type of receipt to get. + /// + /// * `thread` - The thread containing the event of the receipt, if any. + /// + /// * `user_id` - The ID of the user. + /// + /// Returns the ID of the event on which the receipt applies and the + /// receipt. + pub async fn user_receipt( + &self, + receipt_type: ReceiptType, + thread: ReceiptThread, + user_id: &UserId, + ) -> Result> { + self.inner.user_receipt(receipt_type, thread, user_id).await.map_err(Into::into) + } + + /// Get the receipts for an event in this room. + /// + /// # Arguments + /// + /// * `receipt_type` - The type of receipt to get. + /// + /// * `thread` - The thread containing the event of the receipt, if any. + /// + /// * `event_id` - The ID of the event. + /// + /// Returns a list of IDs of users who have sent a receipt for the event and + /// the corresponding receipts. + pub async fn event_receipts( + &self, + receipt_type: ReceiptType, + thread: ReceiptThread, + event_id: &EventId, + ) -> Result> { + self.inner.event_receipts(receipt_type, thread, event_id).await.map_err(Into::into) + } + + /// Get the push context for this room. + /// + /// Returns `None` if some data couldn't be found. This should only happen + /// in brand new rooms, while we process its state. + async fn push_context(&self) -> Result> { + let room_id = self.room_id(); + let user_id = self.own_user_id(); + let room_info = self.clone_info(); + let member_count = room_info.active_members_count(); + + let user_display_name = if let Some(member) = self.get_member_no_sync(user_id).await? { + member.name().to_owned() + } else { + return Ok(None); + }; + + let room_power_levels = if let Some(event) = self + .get_state_event_static::() + .await? 
+ .and_then(|e| e.deserialize().ok()) + { + event.power_levels() + } else { + return Ok(None); + }; + + Ok(Some(PushConditionRoomCtx { + user_id: user_id.to_owned(), + room_id: room_id.to_owned(), + member_count: UInt::new(member_count).unwrap_or(UInt::MAX), + user_display_name, + users_power_levels: room_power_levels.users, + default_power_level: room_power_levels.users_default, + notification_power_levels: room_power_levels.notifications, + })) + } } /// Options for [`messages`][Common::messages]. diff --git a/crates/matrix-sdk/src/room/invited.rs b/crates/matrix-sdk/src/room/invited.rs index a7ea8c5c0d7..a99e210c69a 100644 --- a/crates/matrix-sdk/src/room/invited.rs +++ b/crates/matrix-sdk/src/room/invited.rs @@ -5,14 +5,14 @@ use thiserror::Error; use super::{Joined, Left}; use crate::{ room::{Common, RoomMember}, - BaseRoom, Client, Error, Result, RoomType, + BaseRoom, Client, Error, Result, RoomState, }; /// A room in the invited state. /// -/// This struct contains all methods specific to a `Room` with type -/// `RoomType::Invited`. Operations may fail once the underlying `Room` changes -/// `RoomType`. +/// This struct contains all methods specific to a `Room` with +/// `RoomState::Invited`. Operations may fail once the underlying `Room` changes +/// `RoomState`. #[derive(Debug, Clone)] pub struct Invited { pub(crate) inner: Common, @@ -37,15 +37,15 @@ pub enum InvitationError { } impl Invited { - /// Create a new `room::Invited` if the underlying `Room` has type - /// `RoomType::Invited`. + /// Create a new `room::Invited` if the underlying `Room` has + /// `RoomState::Invited`. /// /// # Arguments /// * `client` - The client used to make requests. /// /// * `room` - The underlying room. 
pub(crate) fn new(client: &Client, room: BaseRoom) -> Option { - if room.room_type() == RoomType::Invited { + if room.state() == RoomState::Invited { Some(Self { inner: Common::new(client.clone(), room) }) } else { None diff --git a/crates/matrix-sdk/src/room/joined.rs b/crates/matrix-sdk/src/room/joined.rs index 827c3d84159..98d4bb368d8 100644 --- a/crates/matrix-sdk/src/room/joined.rs +++ b/crates/matrix-sdk/src/room/joined.rs @@ -24,21 +24,28 @@ use ruma::{ }, assign, events::{ - room::message::RoomMessageEventContent, EmptyStateKey, MessageLikeEventContent, - StateEventContent, + receipt::ReceiptThread, + room::{ + avatar::{ImageInfo, RoomAvatarEventContent}, + message::RoomMessageEventContent, + name::RoomNameEventContent, + power_levels::RoomPowerLevelsEventContent, + topic::RoomTopicEventContent, + }, + EmptyStateKey, MessageLikeEventContent, StateEventContent, }, serde::Raw, - EventId, OwnedTransactionId, TransactionId, UserId, + EventId, Int, MxcUri, OwnedEventId, OwnedTransactionId, TransactionId, UserId, }; use serde_json::Value; -use tracing::debug; -#[cfg(feature = "e2e-encryption")] -use tracing::instrument; +use tracing::{debug, instrument}; use super::Left; use crate::{ - attachment::AttachmentConfig, error::HttpResult, room::Common, BaseRoom, Client, Result, - RoomType, + attachment::AttachmentConfig, + error::{Error, HttpResult}, + room::Common, + BaseRoom, Client, Result, RoomState, }; #[cfg(feature = "image-proc")] use crate::{ @@ -51,9 +58,9 @@ const TYPING_NOTICE_RESEND_TIMEOUT: Duration = Duration::from_secs(3); /// A room in the joined state. /// -/// The `JoinedRoom` contains all methods specific to a `Room` with type -/// `RoomType::Joined`. Operations may fail once the underlying `Room` changes -/// `RoomType`. +/// The `JoinedRoom` contains all methods specific to a `Room` with +/// `RoomState::Joined`. Operations may fail once the underlying `Room` changes +/// `RoomState`. 
#[derive(Debug, Clone)] pub struct Joined { pub(crate) inner: Common, @@ -68,15 +75,15 @@ impl Deref for Joined { } impl Joined { - /// Create a new `room::Joined` if the underlying `BaseRoom` has type - /// `RoomType::Joined`. + /// Create a new `room::Joined` if the underlying `BaseRoom` has + /// `RoomState::Joined`. /// /// # Arguments /// * `client` - The client used to make requests. /// /// * `room` - The underlying room. pub(crate) fn new(client: &Client, room: BaseRoom) -> Option { - if room.room_type() == RoomType::Joined { + if room.state() == RoomState::Joined { Some(Self { inner: Common::new(client.clone(), room) }) } else { None @@ -84,6 +91,7 @@ impl Joined { } /// Leave this room. + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn leave(&self) -> Result { self.inner.leave().await } @@ -95,6 +103,7 @@ impl Joined { /// * `user_id` - The user to ban with `UserId`. /// /// * `reason` - The reason for banning this user. + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn ban_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> { let request = assign!( ban_user::v3::Request::new(self.inner.room_id().to_owned(), user_id.to_owned()), @@ -112,6 +121,7 @@ impl Joined { /// room. /// /// * `reason` - Optional reason why the room member is being kicked out. + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn kick_user(&self, user_id: &UserId, reason: Option<&str>) -> Result<()> { let request = assign!( kick_user::v3::Request::new(self.inner.room_id().to_owned(), user_id.to_owned()), @@ -126,6 +136,7 @@ impl Joined { /// # Arguments /// /// * `user_id` - The `UserId` of the user to invite to the room. 
+ #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn invite_user_by_id(&self, user_id: &UserId) -> Result<()> { let recipient = InvitationRecipient::UserId { user_id: user_id.to_owned() }; @@ -140,6 +151,7 @@ impl Joined { /// # Arguments /// /// * `invite_id` - A third party id of a user to invite to the room. + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn invite_user_by_3pid(&self, invite_id: Invite3pid) -> Result<()> { let recipient = InvitationRecipient::ThirdPartyId(invite_id); let request = invite_user::v3::Request::new(self.inner.room_id().to_owned(), recipient); @@ -207,63 +219,85 @@ impl Joined { }; if send { - let typing = if typing { - self.client - .inner - .typing_notice_times - .insert(self.inner.room_id().to_owned(), Instant::now()); - Typing::Yes(TYPING_NOTICE_TIMEOUT) - } else { - self.client.inner.typing_notice_times.remove(self.inner.room_id()); - Typing::No - }; - - let request = TypingRequest::new( - self.inner.own_user_id().to_owned(), - self.inner.room_id().to_owned(), - typing, - ); - self.client.send(request, None).await?; + self.send_typing_notice(typing).await?; } Ok(()) } - /// Send a request to notify this room that the user has read specific - /// event. + #[instrument(name = "typing_notice", skip(self), parent = &self.client.inner.root_span)] + async fn send_typing_notice(&self, typing: bool) -> Result<()> { + let typing = if typing { + self.client + .inner + .typing_notice_times + .insert(self.inner.room_id().to_owned(), Instant::now()); + Typing::Yes(TYPING_NOTICE_TIMEOUT) + } else { + self.client.inner.typing_notice_times.remove(self.inner.room_id()); + Typing::No + }; + + let request = TypingRequest::new( + self.inner.own_user_id().to_owned(), + self.inner.room_id().to_owned(), + typing, + ); + + self.client.send(request, None).await?; + + Ok(()) + } + + /// Send a request to set a single receipt. 
/// /// # Arguments /// - /// * `event_id` - The `EventId` specifies the event to set the read receipt - /// on. - pub async fn read_receipt(&self, event_id: &EventId) -> Result<()> { - let request = create_receipt::v3::Request::new( + /// * `receipt_type` - The type of the receipt to set. Note that it is + /// possible to set the fully-read marker although it is technically not a + /// receipt. + /// + /// * `thread` - The thread where this receipt should apply, if any. Note + /// that this must be [`ReceiptThread::Unthreaded`] when sending a + /// [`ReceiptType::FullyRead`]. + /// + /// * `event_id` - The `EventId` of the event to set the receipt on. + #[instrument(skip_all, parent = &self.client.inner.root_span)] + pub async fn send_single_receipt( + &self, + receipt_type: ReceiptType, + thread: ReceiptThread, + event_id: OwnedEventId, + ) -> Result<()> { + let mut request = create_receipt::v3::Request::new( self.inner.room_id().to_owned(), - ReceiptType::Read, - event_id.to_owned(), + receipt_type, + event_id, ); + request.thread = thread; self.client.send(request, None).await?; Ok(()) } - /// Send a request to notify this room that the user has read up to specific - /// event. + /// Send a request to set multiple receipts at once. /// /// # Arguments /// - /// * fully_read - The `EventId` of the event the user has read to. + /// * `receipts` - The `Receipts` to send. /// - /// * read_receipt - An `EventId` to specify the event to set the read - /// receipt on. - pub async fn read_marker( - &self, - fully_read: &EventId, - read_receipt: Option<&EventId>, - ) -> Result<()> { + /// If `receipts` is empty, this is a no-op. 
+ #[instrument(skip_all, parent = &self.client.inner.root_span)] + pub async fn send_multiple_receipts(&self, receipts: Receipts) -> Result<()> { + if receipts.is_empty() { + return Ok(()); + } + + let Receipts { fully_read, read_receipt, private_read_receipt } = receipts; let request = assign!(set_read_marker::v3::Request::new(self.inner.room_id().to_owned()), { - fully_read: Some(fully_read.to_owned()), - read_receipt: read_receipt.map(ToOwned::to_owned), + fully_read, + read_receipt, + private_read_receipt, }); self.client.send(request, None).await?; @@ -301,6 +335,7 @@ impl Joined { /// } /// # anyhow::Ok(()) }); /// ``` + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn enable_encryption(&self) -> Result<()> { use ruma::{ events::room::encryption::RoomEncryptionEventContent, EventEncryptionAlgorithm, @@ -327,25 +362,25 @@ impl Joined { /// room if necessary and share a room key that can be shared with them. /// /// Does nothing if no room key needs to be shared. + // TODO: expose this publicly so people can pre-share a group session if + // e.g. a user starts to type a message for a room. #[cfg(feature = "e2e-encryption")] #[instrument(skip_all, fields(room_id = ?self.room_id()))] async fn preshare_room_key(&self) -> Result<()> { - // TODO: expose this publicly so people can pre-share a group session if - // e.g. a user starts to type a message for a room. - if let Some(mutex) = - self.client.inner.group_session_locks.get(self.inner.room_id()).map(|m| m.clone()) - { + let mut map = self.client.inner.group_session_locks.lock().await; + + if let Some(mutex) = map.get(self.inner.room_id()).cloned() { // If a group session share request is already going on, await the // release of the lock. + drop(map); _ = mutex.lock().await; } else { // Otherwise create a new lock and share the group // session. 
let mutex = Arc::new(Mutex::new(())); - self.client - .inner - .group_session_locks - .insert(self.inner.room_id().to_owned(), mutex.clone()); + map.insert(self.inner.room_id().to_owned(), mutex.clone()); + + drop(map); let _guard = mutex.lock().await; @@ -359,7 +394,7 @@ impl Joined { let response = self.share_room_key().await; - self.client.inner.group_session_locks.remove(self.inner.room_id()); + self.client.inner.group_session_locks.lock().await.remove(self.inner.room_id()); // If one of the responses failed invalidate the group // session as using it would end up in undecryptable @@ -401,8 +436,9 @@ impl Joined { /// Warning: This waits until a sync happens and does not return if no sync /// is happening! It can also return early when the room is not a joined /// room anymore! + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn sync_up(&self) { - while !self.is_synced() && self.room_type() == RoomType::Joined { + while !self.is_synced() && self.state() == RoomState::Joined { self.client.inner.sync_beat.listen().wait_timeout(Duration::from_secs(1)); } } @@ -591,7 +627,7 @@ impl Joined { ); if !self.are_members_synced() { - self.request_members().await?; + self.ensure_members().await?; // TODO query keys here? 
} @@ -671,6 +707,7 @@ impl Joined { /// } /// # anyhow::Ok(()) }); /// ``` + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn send_attachment( &self, body: &str, @@ -687,12 +724,26 @@ impl Joined { #[cfg(feature = "image-proc")] let data_slot; #[cfg(feature = "image-proc")] - let thumbnail = if config.generate_thumbnail { - match generate_image_thumbnail( - content_type, - Cursor::new(&data), - config.thumbnail_size, - ) { + let (data, thumbnail) = if config.generate_thumbnail { + let content_type = content_type.clone(); + let make_thumbnail = move |data| { + let res = generate_image_thumbnail( + &content_type, + Cursor::new(&data), + config.thumbnail_size, + ); + (data, res) + }; + + #[cfg(not(target_arch = "wasm32"))] + let (data, res) = tokio::task::spawn_blocking(move || make_thumbnail(data)) + .await + .expect("Task join error"); + + #[cfg(target_arch = "wasm32")] + let (data, res) = make_thumbnail(data); + + let thumbnail = match res { Ok((thumbnail_data, thumbnail_info)) => { data_slot = thumbnail_data; Some(Thumbnail { @@ -705,9 +756,11 @@ impl Joined { ImageError::ThumbnailBiggerThanOriginal | ImageError::FormatNotSupported, ) => None, Err(error) => return Err(error.into()), - } + }; + + (data, thumbnail) } else { - None + (data, None) }; let config = AttachmentConfig { @@ -781,6 +834,93 @@ impl Joined { self.send(RoomMessageEventContent::new(content), config.txn_id.as_deref()).await } + /// Update the power levels of a select set of users of this room. + /// + /// Issue a `power_levels` state event request to the server, changing the + /// given UserId -> Int levels. May fail if the `power_levels` aren't + /// locally known yet or the server rejects the state event update, e.g. + /// because of insufficient permissions. Neither permissions to update + /// nor whether the data might be stale is checked prior to issuing the + /// request. 
+ pub async fn update_power_levels( + &self, + updates: Vec<(&UserId, Int)>, + ) -> Result { + let raw_pl_event = self + .get_state_event_static::() + .await? + .ok_or(Error::InsufficientData)?; + + let mut power_levels = raw_pl_event.deserialize()?.power_levels(); + + for (user_id, new_level) in updates { + if new_level == power_levels.users_default { + power_levels.users.remove(user_id); + } else { + power_levels.users.insert(user_id.to_owned(), new_level); + } + } + + self.send_state_event(RoomPowerLevelsEventContent::from(power_levels)).await + } + + /// Sets the name of this room. + pub async fn set_name(&self, name: Option) -> Result { + self.send_state_event(RoomNameEventContent::new(name)).await + } + + /// Sets a new topic for this room. + pub async fn set_room_topic(&self, topic: &str) -> Result { + let topic_event = RoomTopicEventContent::new(topic.into()); + + self.send_state_event(topic_event).await + } + + /// Sets the new avatar url for this room. + /// + /// # Arguments + /// * `avatar_url` - The owned matrix uri that represents the avatar + /// * `info` - The optional image info that can be provided for the avatar + pub async fn set_avatar_url( + &self, + url: &MxcUri, + info: Option, + ) -> Result { + let mut room_avatar_event = RoomAvatarEventContent::new(); + room_avatar_event.url = Some(url.to_owned()); + room_avatar_event.info = info.map(Box::new); + + self.send_state_event(room_avatar_event).await + } + + /// Removes the avatar from the room + pub async fn remove_avatar(&self) -> Result { + let room_avatar_event = RoomAvatarEventContent::new(); + + self.send_state_event(room_avatar_event).await + } + + /// Uploads a new avatar for this room. 
+ /// + /// # Arguments + /// * `mime` - The mime type describing the data + /// * `data` - The data representation of the avatar + /// * `info` - The optional image info provided for the avatar, + /// the blurhash and the mimetype will always be updated + pub async fn upload_avatar( + &self, + mime: &Mime, + data: Vec, + info: Option, + ) -> Result { + let upload_response = self.client.media().upload(mime, data).await?; + let mut info = info.unwrap_or_else(ImageInfo::new); + info.blurhash = upload_response.blurhash; + info.mimetype = Some(mime.to_string()); + + self.set_avatar_url(&upload_response.content_uri, Some(info)).await + } + /// Send a state event with an empty state key to the homeserver. /// /// For state events with a non-empty state key, see @@ -824,6 +964,7 @@ impl Joined { /// joined_room.send_state_event(content).await?; /// # anyhow::Ok(()) }; /// ``` + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn send_state_event( &self, content: impl StateEventContent, @@ -923,6 +1064,7 @@ impl Joined { /// } /// # anyhow::Ok(()) }); /// ``` + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn send_state_event_raw( &self, content: Value, @@ -973,6 +1115,7 @@ impl Joined { /// } /// # anyhow::Ok(()) }); /// ``` + #[instrument(skip_all, parent = &self.client.inner.root_span)] pub async fn redact( &self, event_id: &EventId, @@ -988,3 +1131,56 @@ impl Joined { self.client.send(request, None).await } } + +/// Receipts to send all at once. +#[derive(Debug, Clone, Default)] +pub struct Receipts { + pub(super) fully_read: Option, + pub(super) read_receipt: Option, + pub(super) private_read_receipt: Option, +} + +impl Receipts { + /// Create an empty `Receipts`. + pub fn new() -> Self { + Self::default() + } + + /// Set the last event the user has read. + /// + /// It means that the user has read all the events before this event. + /// + /// This is a private marker only visible by the user. 
+ /// + /// Note that this is technically not a receipt as it is persisted in the + /// room account data. + pub fn fully_read_marker(mut self, event_id: impl Into>) -> Self { + self.fully_read = event_id.into(); + self + } + + /// Set the last event presented to the user and forward it to the other + /// users in the room. + /// + /// This is used to reset the unread messages/notification count and + /// advertise to other users the last event that the user has likely seen. + pub fn public_read_receipt(mut self, event_id: impl Into>) -> Self { + self.read_receipt = event_id.into(); + self + } + + /// Set the last event presented to the user and don't forward it. + /// + /// This is used to reset the unread messages/notification count. + pub fn private_read_receipt(mut self, event_id: impl Into>) -> Self { + self.private_read_receipt = event_id.into(); + self + } + + /// Whether this `Receipts` is empty. + pub fn is_empty(&self) -> bool { + self.fully_read.is_none() + && self.read_receipt.is_none() + && self.private_read_receipt.is_none() + } +} diff --git a/crates/matrix-sdk/src/room/left.rs b/crates/matrix-sdk/src/room/left.rs index 97875e37644..9690262e87d 100644 --- a/crates/matrix-sdk/src/room/left.rs +++ b/crates/matrix-sdk/src/room/left.rs @@ -3,28 +3,28 @@ use std::ops::Deref; use ruma::api::client::membership::forget_room; use super::Joined; -use crate::{room::Common, BaseRoom, Client, Result, RoomType}; +use crate::{room::Common, BaseRoom, Client, Result, RoomState}; /// A room in the left state. /// -/// This struct contains all methods specific to a `Room` with type -/// `RoomType::Left`. Operations may fail once the underlying `Room` changes -/// `RoomType`. +/// This struct contains all methods specific to a `Room` with +/// `RoomState::Left`. Operations may fail once the underlying `Room` changes +/// `RoomState`. 
#[derive(Debug, Clone)] pub struct Left { pub(crate) inner: Common, } impl Left { - /// Create a new `room::Left` if the underlying `Room` has type - /// `RoomType::Left`. + /// Create a new `room::Left` if the underlying `Room` has + /// `RoomState::Left`. /// /// # Arguments /// * `client` - The client used to make requests. /// /// * `room` - The underlying room. pub(crate) fn new(client: &Client, room: BaseRoom) -> Option { - if room.room_type() == RoomType::Left { + if room.state() == RoomState::Left { Some(Self { inner: Common::new(client.clone(), room) }) } else { None diff --git a/crates/matrix-sdk/src/room/mod.rs b/crates/matrix-sdk/src/room/mod.rs index 633213bec6d..333b7f737c8 100644 --- a/crates/matrix-sdk/src/room/mod.rs +++ b/crates/matrix-sdk/src/room/mod.rs @@ -2,7 +2,7 @@ use std::ops::Deref; -use crate::RoomType; +use crate::RoomState; mod common; mod invited; @@ -15,7 +15,7 @@ pub mod timeline; pub use self::{ common::{Common, Messages, MessagesOptions}, invited::Invited, - joined::Joined, + joined::{Joined, Receipts}, left::Left, member::RoomMember, }; @@ -45,10 +45,10 @@ impl Deref for Room { impl From for Room { fn from(room: Common) -> Self { - match room.room_type() { - RoomType::Joined => Self::Joined(Joined { inner: room }), - RoomType::Left => Self::Left(Left { inner: room }), - RoomType::Invited => Self::Invited(Invited { inner: room }), + match room.state() { + RoomState::Joined => Self::Joined(Joined { inner: room }), + RoomState::Left => Self::Left(Left { inner: room }), + RoomState::Invited => Self::Invited(Invited { inner: room }), } } } @@ -56,10 +56,10 @@ impl From for Room { impl From for Room { fn from(room: Joined) -> Self { let room = (*room).clone(); - match room.room_type() { - RoomType::Joined => Self::Joined(Joined { inner: room }), - RoomType::Left => Self::Left(Left { inner: room }), - RoomType::Invited => Self::Invited(Invited { inner: room }), + match room.state() { + RoomState::Joined => Self::Joined(Joined { inner: 
room }), + RoomState::Left => Self::Left(Left { inner: room }), + RoomState::Invited => Self::Invited(Invited { inner: room }), } } } @@ -67,10 +67,10 @@ impl From for Room { impl From for Room { fn from(room: Left) -> Self { let room = (*room).clone(); - match room.room_type() { - RoomType::Joined => Self::Joined(Joined { inner: room }), - RoomType::Left => Self::Left(Left { inner: room }), - RoomType::Invited => Self::Invited(Invited { inner: room }), + match room.state() { + RoomState::Joined => Self::Joined(Joined { inner: room }), + RoomState::Left => Self::Left(Left { inner: room }), + RoomState::Invited => Self::Invited(Invited { inner: room }), } } } @@ -78,10 +78,10 @@ impl From for Room { impl From for Room { fn from(room: Invited) -> Self { let room = (*room).clone(); - match room.room_type() { - RoomType::Joined => Self::Joined(Joined { inner: room }), - RoomType::Left => Self::Left(Left { inner: room }), - RoomType::Invited => Self::Invited(Invited { inner: room }), + match room.state() { + RoomState::Joined => Self::Joined(Joined { inner: room }), + RoomState::Left => Self::Left(Left { inner: room }), + RoomState::Invited => Self::Invited(Invited { inner: room }), } } } diff --git a/crates/matrix-sdk/src/room/timeline/builder.rs b/crates/matrix-sdk/src/room/timeline/builder.rs new file mode 100644 index 00000000000..6e57190ee37 --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/builder.rs @@ -0,0 +1,202 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use im::Vector; +use matrix_sdk_base::{ + deserialized_responses::{EncryptionInfo, SyncTimelineEvent}, + locks::Mutex, +}; +use ruma::{ + events::receipt::{ReceiptThread, ReceiptType, SyncReceiptEvent}, + push::Action, +}; +use tracing::error; + +#[cfg(feature = "e2e-encryption")] +use super::to_device::{handle_forwarded_room_key_event, handle_room_key_event}; +use super::{inner::TimelineInner, Timeline, TimelineEventHandlerHandles}; +use crate::room; + +/// Builder that allows creating and configuring various parts of a +/// [`Timeline`]. +#[must_use] +#[derive(Debug)] +pub(crate) struct TimelineBuilder { + room: room::Common, + prev_token: Option, + events: Vector, + track_read_marker_and_receipts: bool, +} + +impl TimelineBuilder { + pub(super) fn new(room: &room::Common) -> Self { + Self { + room: room.clone(), + prev_token: None, + events: Vector::new(), + track_read_marker_and_receipts: false, + } + } + + /// Add initial events to the timeline. + #[cfg(feature = "experimental-sliding-sync")] + pub(crate) fn events( + mut self, + prev_token: Option, + events: Vector, + ) -> Self { + self.prev_token = prev_token; + self.events = events; + self + } + + /// Enable tracking of the fully-read marker and the read receipts on the + /// timeline. + pub(crate) fn track_read_marker_and_receipts(mut self) -> Self { + self.track_read_marker_and_receipts = true; + self + } + + /// Create a [`Timeline`] with the options set on this builder. 
+ pub(crate) async fn build(self) -> Timeline { + let Self { room, prev_token, events, track_read_marker_and_receipts } = self; + let has_events = !events.is_empty(); + + let mut inner = + TimelineInner::new(room).with_read_receipt_tracking(track_read_marker_and_receipts); + + if track_read_marker_and_receipts { + match inner + .room() + .user_receipt( + ReceiptType::Read, + ReceiptThread::Unthreaded, + inner.room().own_user_id(), + ) + .await + { + Ok(Some(read_receipt)) => { + inner.set_initial_user_receipt(ReceiptType::Read, read_receipt); + } + Err(e) => { + error!("Failed to get public read receipt of own user from the store: {e}"); + } + _ => {} + } + match inner + .room() + .user_receipt( + ReceiptType::ReadPrivate, + ReceiptThread::Unthreaded, + inner.room().own_user_id(), + ) + .await + { + Ok(Some(private_read_receipt)) => { + inner.set_initial_user_receipt(ReceiptType::ReadPrivate, private_read_receipt); + } + Err(e) => { + error!("Failed to get private read receipt of own user from the store: {e}"); + } + _ => {} + } + } + + if has_events { + inner.add_initial_events(events).await; + } + + let inner = Arc::new(inner); + let room = inner.room(); + + let timeline_event_handle = room.add_event_handler({ + let inner = inner.clone(); + move |event, encryption_info: Option, push_actions: Vec| { + let inner = inner.clone(); + async move { + inner.handle_live_event(event, encryption_info, push_actions).await; + } + } + }); + + // Not using room.add_event_handler here because RoomKey events are + // to-device events that are not received in the context of a room. 
+ #[cfg(feature = "e2e-encryption")] + let room_key_handle = room + .client + .add_event_handler(handle_room_key_event(inner.clone(), room.room_id().to_owned())); + #[cfg(feature = "e2e-encryption")] + let forwarded_room_key_handle = room.client.add_event_handler( + handle_forwarded_room_key_event(inner.clone(), room.room_id().to_owned()), + ); + + let mut handles = vec![ + timeline_event_handle, + #[cfg(feature = "e2e-encryption")] + room_key_handle, + #[cfg(feature = "e2e-encryption")] + forwarded_room_key_handle, + ]; + + if track_read_marker_and_receipts { + inner.load_fully_read_event().await; + + let fully_read_handle = room.add_event_handler({ + let inner = inner.clone(); + move |event| { + let inner = inner.clone(); + async move { + inner.handle_fully_read(event).await; + } + } + }); + handles.push(fully_read_handle); + + let read_receipts_handle = room.add_event_handler({ + let inner = inner.clone(); + move |read_receipts: SyncReceiptEvent| { + let inner = inner.clone(); + async move { + inner.handle_read_receipts(read_receipts.content).await; + } + } + }); + handles.push(read_receipts_handle); + } + + let client = room.client.clone(); + let timeline = Timeline { + inner, + start_token: Mutex::new(prev_token), + _end_token: Mutex::new(None), + event_handler_handles: Arc::new(TimelineEventHandlerHandles { client, handles }), + }; + + #[cfg(feature = "e2e-encryption")] + if has_events { + // The events we're injecting might be encrypted events, but we might + // have received the room key to decrypt them while nobody was listening to the + // `m.room_key` event, let's retry now. + // + // TODO: We could spawn a task here and put this into the background, though it + // might not be worth it depending on the number of events we injected. + // Some measuring needs to be done. 
+ timeline.retry_decryption_for_all_events().await; + } + + timeline + } +} diff --git a/crates/matrix-sdk/src/room/timeline/event_handler.rs b/crates/matrix-sdk/src/room/timeline/event_handler.rs index 85d229ca23b..1487bc94b97 100644 --- a/crates/matrix-sdk/src/room/timeline/event_handler.rs +++ b/crates/matrix-sdk/src/room/timeline/event_handler.rs @@ -14,16 +14,17 @@ use std::{collections::HashMap, sync::Arc}; -use chrono::{DateTime, Datelike, Local, TimeZone}; -use futures_signals::signal_vec::MutableVecLockMut; +use chrono::{Datelike, Local, TimeZone}; +use eyeball_im::ObservableVector; use indexmap::{map::Entry, IndexMap, IndexSet}; use matrix_sdk_base::deserialized_responses::EncryptionInfo; use ruma::{ events::{ reaction::ReactionEventContent, + receipt::{Receipt, ReceiptType}, relation::{Annotation, Replacement}, room::{ - encrypted::{self, RoomEncryptedEventContent}, + encrypted::RoomEncryptedEventContent, member::{Change, RoomMemberEventContent}, message::{self, MessageType, RoomMessageEventContent}, redaction::{ @@ -46,9 +47,11 @@ use super::{ MemberProfileChange, OtherState, Profile, RemoteEventTimelineItem, RoomMembershipChange, Sticker, }, - find_read_marker, rfind_event_by_id, rfind_event_item, EventTimelineItem, InReplyToDetails, - Message, ReactionGroup, TimelineDetails, TimelineInnerMetadata, TimelineItem, - TimelineItemContent, VirtualTimelineItem, + find_read_marker, + read_receipts::maybe_add_implicit_read_receipt, + rfind_event_by_id, rfind_event_item, EventTimelineItem, InReplyToDetails, Message, + ReactionGroup, TimelineDetails, TimelineInnerState, TimelineItem, TimelineItemContent, + VirtualTimelineItem, }; use crate::{events::SyncTimelineEventWithoutContent, room::timeline::MembershipChange}; @@ -72,6 +75,8 @@ pub(super) struct TimelineEventMetadata { pub(super) is_own_event: bool, pub(super) relations: BundledRelations, pub(super) encryption_info: Option, + pub(super) read_receipts: IndexMap, + pub(super) is_highlighted: bool, } 
#[derive(Clone)] @@ -181,6 +186,8 @@ pub(super) enum TimelineItemPosition { #[derive(Default)] pub(super) struct HandleEventResult { pub(super) item_added: bool, + #[cfg(feature = "e2e-encryption")] + pub(super) item_removed: bool, pub(super) items_updated: u16, } @@ -188,10 +195,10 @@ pub(super) struct HandleEventResult { // of handling an event (figuring out whether it should update an existing // timeline item, transforming that item or creating a new one, updating the // reactive Vec). -pub(super) struct TimelineEventHandler<'a, 'i> { +pub(super) struct TimelineEventHandler<'a> { meta: TimelineEventMetadata, flow: Flow, - timeline_items: &'a mut MutableVecLockMut<'i, Arc>, + items: &'a mut ObservableVector>, #[allow(clippy::type_complexity)] reaction_map: &'a mut HashMap< (Option, Option), @@ -200,6 +207,9 @@ pub(super) struct TimelineEventHandler<'a, 'i> { pending_reactions: &'a mut HashMap>, fully_read_event: &'a mut Option, fully_read_event_in_timeline: &'a mut bool, + track_read_receipts: bool, + users_read_receipts: + &'a mut HashMap>, result: HandleEventResult, } @@ -209,7 +219,7 @@ pub(super) struct TimelineEventHandler<'a, 'i> { macro_rules! update_timeline_item { ($this:ident, $event_id:expr, $action:expr, $update:expr) => { _update_timeline_item( - &mut *$this.timeline_items, + &mut *$this.items, &mut $this.result.items_updated, $event_id, $action, @@ -218,21 +228,23 @@ macro_rules! 
update_timeline_item { }; } -impl<'a, 'i> TimelineEventHandler<'a, 'i> { +impl<'a> TimelineEventHandler<'a> { pub(super) fn new( event_meta: TimelineEventMetadata, flow: Flow, - timeline_items: &'a mut MutableVecLockMut<'i, Arc>, - timeline_meta: &'a mut TimelineInnerMetadata, + state: &'a mut TimelineInnerState, + track_read_receipts: bool, ) -> Self { Self { meta: event_meta, flow, - timeline_items, - reaction_map: &mut timeline_meta.reaction_map, - pending_reactions: &mut timeline_meta.pending_reactions, - fully_read_event: &mut timeline_meta.fully_read_event, - fully_read_event_in_timeline: &mut timeline_meta.fully_read_event_in_timeline, + items: &mut state.items, + reaction_map: &mut state.reaction_map, + pending_reactions: &mut state.pending_reactions, + fully_read_event: &mut state.fully_read_event, + fully_read_event_in_timeline: &mut state.fully_read_event_in_timeline, + track_read_receipts, + users_read_receipts: &mut state.users_read_receipts, result: HandleEventResult::default(), } } @@ -259,6 +271,8 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { } } + trace!("Handling event"); + match event_kind { TimelineEventKind::Message { content } => match content { AnyMessageLikeEventContent::Reaction(c) => { @@ -312,6 +326,17 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { } if !self.result.item_added { + trace!("No new item added"); + + #[cfg(feature = "e2e-encryption")] + if let Flow::Remote { position: TimelineItemPosition::Update(idx), .. } = self.flow { + // If add was not called, that means the UTD event is one that + // wouldn't normally be visible. Remove it. 
+ trace!("Removing UTD that was successfully retried"); + self.items.remove(idx); + self.result.item_removed = true; + } + // TODO: Add event as raw } @@ -378,7 +403,7 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { } }; - if let Some((idx, event_item)) = rfind_event_by_id(self.timeline_items, event_id) { + if let Some((idx, event_item)) = rfind_event_by_id(self.items, event_id) { let EventTimelineItem::Remote(remote_event_item) = event_item else { error!("inconsistent state: reaction received on a non-remote event item"); return; @@ -386,11 +411,11 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { // Handling of reactions on redacted events is an open question. // For now, ignore reactions on redacted events like Element does. - if let TimelineItemContent::RedactedMessage = remote_event_item.content { + if let TimelineItemContent::RedactedMessage = remote_event_item.content() { debug!("Ignoring reaction on redacted event"); return; } else { - let mut reactions = remote_event_item.reactions.clone(); + let mut reactions = remote_event_item.reactions().clone(); let reaction_group = reactions.entry(c.relates_to.key.clone()).or_default(); if let Some(txn_id) = old_txn_id { @@ -405,7 +430,7 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { reaction_group.0.insert(reaction_id.clone(), self.meta.sender.clone()); trace!("Adding reaction"); - self.timeline_items.set_cloned( + self.items.set( idx, Arc::new(TimelineItem::Event( remote_event_item.with_reactions(reactions).into(), @@ -439,16 +464,8 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { #[instrument(skip_all)] fn handle_room_encrypted(&mut self, c: RoomEncryptedEventContent) { - match c.relates_to { - Some(encrypted::Relation::Replacement(_) | encrypted::Relation::Annotation(_)) => { - // Do nothing for these, as they would not produce a new - // timeline item when decrypted either - debug!("Ignoring aggregating event that failed to decrypt"); - } - _ => { - self.add(NewEventTimelineItem::unable_to_decrypt(c)); - } - } + // 
TODO: Handle replacements if the replaced event is also UTD + self.add(NewEventTimelineItem::unable_to_decrypt(c)); } // Redacted redactions are no-ops (unfortunately) @@ -461,7 +478,7 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { return None; }; - let mut reactions = remote_event_item.reactions.clone(); + let mut reactions = remote_event_item.reactions().clone(); let count = { let Entry::Occupied(mut group_entry) = reactions.entry(rel.key.clone()) else { @@ -529,15 +546,17 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { let sender_profile = TimelineDetails::from_initial_value(self.meta.sender_profile.clone()); let mut reactions = self.pending_reactions().unwrap_or_default(); - let item = match &self.flow { - Flow::Local { txn_id, timestamp } => EventTimelineItem::Local(LocalEventTimelineItem { - send_state: EventSendState::NotSentYet, - transaction_id: txn_id.to_owned(), - sender, - sender_profile, - timestamp: *timestamp, - content, - }), + let mut item = match &self.flow { + Flow::Local { txn_id, timestamp } => { + EventTimelineItem::Local(LocalEventTimelineItem::new( + EventSendState::NotSentYet, + txn_id.to_owned(), + sender, + sender_profile, + *timestamp, + content, + )) + } Flow::Remote { event_id, origin_server_ts, raw_event, .. } => { // Drop pending reactions if the message is redacted. 
if let TimelineItemContent::RedactedMessage = content { @@ -546,51 +565,67 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { } } - EventTimelineItem::Remote(RemoteEventTimelineItem { - event_id: event_id.clone(), + EventTimelineItem::Remote(RemoteEventTimelineItem::new( + event_id.clone(), sender, sender_profile, - timestamp: *origin_server_ts, + *origin_server_ts, content, reactions, - is_own: self.meta.is_own_event, - encryption_info: self.meta.encryption_info.clone(), - raw: raw_event.clone(), - }) + self.meta.read_receipts.clone(), + self.meta.is_own_event, + self.meta.encryption_info.clone(), + raw_event.clone(), + self.meta.is_highlighted, + )) } }; - let item = Arc::new(TimelineItem::Event(item)); - match &self.flow { Flow::Local { timestamp, .. } => { + trace!("Adding new local timeline item"); + // Check if the latest event has the same date as this event. - if let Some(latest_event) = self - .timeline_items - .iter() - .rfind(|item| item.as_event().is_some()) - .and_then(|item| item.as_event()) + if let Some(latest_event) = self.items.iter().rev().find_map(|item| item.as_event()) { let old_ts = latest_event.timestamp(); if let Some(day_divider_item) = maybe_create_day_divider_from_timestamps(old_ts, *timestamp) { - self.timeline_items.push_cloned(Arc::new(day_divider_item)); + trace!("Adding day divider"); + self.items.push_back(Arc::new(day_divider_item)); } } else { // If there is no event item, there is no day divider yet. - self.timeline_items - .push_cloned(Arc::new(TimelineItem::day_divider(*timestamp))); + trace!("Adding first day divider"); + self.items.push_back(Arc::new(TimelineItem::day_divider(*timestamp))); } - self.timeline_items.push_cloned(item); + self.items.push_back(Arc::new(item.into())); } - Flow::Remote { position: TimelineItemPosition::Start, origin_server_ts, .. } => { + Flow::Remote { + position: TimelineItemPosition::Start, + event_id, + origin_server_ts, + .. 
+ } => { + if self + .items + .iter() + .filter_map(|ev| ev.as_event()?.event_id()) + .any(|id| id == event_id) + { + trace!("Skipping back-paginated event that has already been seen"); + return; + } + + trace!("Adding new remote timeline item at the start"); + // If there is a loading indicator at the top, check for / insert the day // divider at position 1 and the new event at 2 rather than 0 and 1. - let offset = match self.timeline_items.first().and_then(|item| item.as_virtual()) { + let offset = match self.items.front().and_then(|item| item.as_virtual()) { Some( VirtualTimelineItem::LoadingIndicator | VirtualTimelineItem::TimelineStart, ) => 1, @@ -599,22 +634,30 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { // Check if the earliest day divider has the same date as this event. if let Some(VirtualTimelineItem::DayDivider(divider_ts)) = - self.timeline_items.get(offset).and_then(|item| item.as_virtual()) + self.items.get(offset).and_then(|item| item.as_virtual()) { if let Some(day_divider_item) = maybe_create_day_divider_from_timestamps(*divider_ts, *origin_server_ts) { - self.timeline_items.insert_cloned(offset, Arc::new(day_divider_item)); + self.items.insert(offset, Arc::new(day_divider_item)); } } else { // The list must always start with a day divider. - self.timeline_items.insert_cloned( + self.items + .insert(offset, Arc::new(TimelineItem::day_divider(*origin_server_ts))); + } + + if self.track_read_receipts { + maybe_add_implicit_read_receipt( offset, - Arc::new(TimelineItem::day_divider(*origin_server_ts)), + &mut item, + self.meta.is_own_event, + self.items, + self.users_read_receipts, ); } - self.timeline_items.insert_cloned(offset + 1, item); + self.items.insert(offset + 1, Arc::new(item.into())); } Flow::Remote { @@ -624,24 +667,17 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { origin_server_ts, .. 
} => { - let result = rfind_event_item(self.timeline_items, |it| { + let result = rfind_event_item(self.items, |it| { txn_id.is_some() && it.transaction_id() == txn_id.as_deref() || it.event_id() == Some(event_id) }); if let Some((idx, old_item)) = result { if let EventTimelineItem::Remote(old_item) = old_item { - // Item was previously received by the server. Until we - // implement forwards pagination, this indicates a bug - // somewhere. - warn!(?item, ?old_item, "Received duplicate event"); - - // With /messages and /sync sometimes disagreeing on - // order of messages, we might want to change the - // position in some circumstances, but for now this - // should be good enough. - self.timeline_items.set_cloned(idx, item); - return; + // Item was previously received from the server. This + // should be very rare normally, but with the sliding- + // sync proxy, it is actually very common. + trace!(?item, ?old_item, "Received duplicate event"); }; if txn_id.is_none() { @@ -653,10 +689,59 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { trace!("Received remote echo without transaction ID"); } - // Remove local echo, remote echo will be added below // TODO: Check whether anything is different about the // old and new item? - self.timeline_items.remove(idx); + + if idx == self.items.len() - 1 + && timestamp_to_date(old_item.timestamp()) + == timestamp_to_date(*origin_server_ts) + { + // If the old item is the last one and no day divider + // changes need to happen, replace and return early. + + if self.track_read_receipts { + maybe_add_implicit_read_receipt( + idx, + &mut item, + self.meta.is_own_event, + self.items, + self.users_read_receipts, + ); + } + + trace!(idx, "Replacing existing event"); + self.items.set(idx, Arc::new(item.into())); + return; + } else { + // In more complex cases, remove the item and day + // divider (if necessary) before re-adding the item. 
+ trace!("Removing local echo or duplicate timeline item"); + self.items.remove(idx); + + assert_ne!( + idx, 0, + "there is never an event item at index 0 because \ + the first event item is preceded by a day divider" + ); + + // Pre-requisites for removing the day divider: + // 1. there is one preceding the old item at all + if self.items[idx - 1].is_day_divider() + // 2. the item after the old one that was removed + // is virtual (it should be impossible for this + // to be a read marker) + && self + .items + .get(idx) + .map_or(true, |item| item.is_virtual()) + { + trace!("Removing day divider"); + self.items.remove(idx - 1); + } + + // no return here, below code for adding a new event + // will run to re-add the removed item + } } else if txn_id.is_some() { warn!( "Received event with transaction ID, but didn't \ @@ -665,35 +750,47 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { } // Check if the latest event has the same date as this event. - if let Some(latest_event) = - self.timeline_items.iter().rev().find_map(|item| item.as_event()) + if let Some(latest_event) = self.items.iter().rev().find_map(|item| item.as_event()) { let old_ts = latest_event.timestamp(); if let Some(day_divider_item) = maybe_create_day_divider_from_timestamps(old_ts, *origin_server_ts) { - self.timeline_items.push_cloned(Arc::new(day_divider_item)); + trace!("Adding day divider"); + self.items.push_back(Arc::new(day_divider_item)); } } else { - // If there is not event item, there is no day divider yet. - self.timeline_items - .push_cloned(Arc::new(TimelineItem::day_divider(*origin_server_ts))); + // If there is no event item, there is no day divider yet. 
+ trace!("Adding first day divider"); + self.items.push_back(Arc::new(TimelineItem::day_divider(*origin_server_ts))); + } + + if self.track_read_receipts { + maybe_add_implicit_read_receipt( + self.items.len(), + &mut item, + self.meta.is_own_event, + self.items, + self.users_read_receipts, + ); } - self.timeline_items.push_cloned(item); + trace!("Adding new remote timeline item at the end"); + self.items.push_back(Arc::new(item.into())); } #[cfg(feature = "e2e-encryption")] Flow::Remote { position: TimelineItemPosition::Update(idx), .. } => { - self.timeline_items.set_cloned(*idx, item); + trace!("Updating timeline item at position {idx}"); + self.items.set(*idx, Arc::new(item.into())); } } // See if we got the event corresponding to the read marker now. if !*self.fully_read_event_in_timeline { update_read_marker( - self.timeline_items, + self.items, self.fully_read_event.as_deref(), self.fully_read_event_in_timeline, ); @@ -728,19 +825,21 @@ impl<'a, 'i> TimelineEventHandler<'a, 'i> { } pub(crate) fn update_read_marker( - items_lock: &mut MutableVecLockMut<'_, Arc>, + items: &mut ObservableVector>, fully_read_event: Option<&EventId>, fully_read_event_in_timeline: &mut bool, ) { let Some(fully_read_event) = fully_read_event else { return }; - let read_marker_idx = find_read_marker(items_lock); - let fully_read_event_idx = rfind_event_by_id(items_lock, fully_read_event).map(|(idx, _)| idx); + trace!(?fully_read_event, "Updating read marker"); + + let read_marker_idx = find_read_marker(items); + let fully_read_event_idx = rfind_event_by_id(items, fully_read_event).map(|(idx, _)| idx); match (read_marker_idx, fully_read_event_idx) { (None, None) => {} (None, Some(idx)) => { *fully_read_event_in_timeline = true; - items_lock.insert_cloned(idx + 1, Arc::new(TimelineItem::read_marker())); + items.insert(idx + 1, Arc::new(TimelineItem::read_marker())); } (Some(_), None) => { // Keep the current position of the read marker, hopefully we @@ -752,22 +851,28 @@ pub(crate) fn 
update_read_marker( // The read marker can't move backwards. if from < to { - items_lock.move_from_to(from, to); + let item = items.remove(from); + // Since the fully-read event's index was shifted to the left + // by one position by the remove call above, insert the fully- + // read marker at its previous position, rather than that + 1 + items.insert(to, item); } } } } fn _update_timeline_item( - timeline_items: &mut MutableVecLockMut<'_, Arc>, + items: &mut ObservableVector>, items_updated: &mut u16, event_id: &EventId, action: &str, update: impl FnOnce(&EventTimelineItem) -> Option, ) { - if let Some((idx, item)) = rfind_event_by_id(timeline_items, event_id) { + if let Some((idx, item)) = rfind_event_by_id(items, event_id) { + trace!("Found timeline item to update"); if let Some(new_item) = update(item) { - timeline_items.set_cloned(idx, Arc::new(TimelineItem::Event(new_item))); + trace!("Updating item"); + items.set(idx, Arc::new(TimelineItem::Event(new_item))); *items_updated += 1; } } else { @@ -775,19 +880,24 @@ fn _update_timeline_item( } } -/// Converts a timestamp since Unix Epoch to a local date and time. -fn timestamp_to_local_datetime(ts: MilliSecondsSinceUnixEpoch) -> DateTime { - Local +#[derive(PartialEq)] +struct Date { + year: i32, + month: u32, + day: u32, +} + +/// Converts a timestamp since Unix Epoch to a year, month and day. +fn timestamp_to_date(ts: MilliSecondsSinceUnixEpoch) -> Date { + let datetime = Local .timestamp_millis_opt(ts.0.into()) // Only returns `None` if date is after Dec 31, 262143 BCE. .single() // Fallback to the current date to avoid issues with malicious // homeservers. 
- .unwrap_or_else(Local::now) -} + .unwrap_or_else(Local::now); -fn datetime_to_ymd(datetime: DateTime) -> (i32, u32, u32) { - (datetime.year(), datetime.month(), datetime.day()) + Date { year: datetime.year(), month: datetime.month(), day: datetime.day() } } /// Returns a new day divider item for the new timestamp if it is on a different @@ -796,14 +906,8 @@ fn maybe_create_day_divider_from_timestamps( old_ts: MilliSecondsSinceUnixEpoch, new_ts: MilliSecondsSinceUnixEpoch, ) -> Option { - let old_date = timestamp_to_local_datetime(old_ts); - let new_date = timestamp_to_local_datetime(new_ts); - - if datetime_to_ymd(old_date) != datetime_to_ymd(new_date) { - Some(TimelineItem::day_divider(new_ts)) - } else { - None - } + (timestamp_to_date(old_ts) != timestamp_to_date(new_ts)) + .then(|| TimelineItem::day_divider(new_ts)) } struct NewEventTimelineItem { diff --git a/crates/matrix-sdk/src/room/timeline/event_item.rs b/crates/matrix-sdk/src/room/timeline/event_item/content.rs similarity index 54% rename from crates/matrix-sdk/src/room/timeline/event_item.rs rename to crates/matrix-sdk/src/room/timeline/event_item/content.rs index 54159be7693..6d239b07a07 100644 --- a/crates/matrix-sdk/src/room/timeline/event_item.rs +++ b/crates/matrix-sdk/src/room/timeline/event_item/content.rs @@ -1,21 +1,7 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- use std::{fmt, ops::Deref, sync::Arc}; use indexmap::IndexMap; -use matrix_sdk_base::deserialized_responses::{EncryptionInfo, TimelineEvent}; +use matrix_sdk_base::deserialized_responses::TimelineEvent; use ruma::{ events::{ policy::rule::{ @@ -44,368 +30,17 @@ use ruma::{ }, space::{child::SpaceChildEventContent, parent::SpaceParentEventContent}, sticker::StickerEventContent, - AnyFullStateEventContent, AnyMessageLikeEventContent, AnySyncTimelineEvent, - AnyTimelineEvent, FullStateEventContent, MessageLikeEventType, StateEventType, + AnyFullStateEventContent, AnyMessageLikeEventContent, AnyTimelineEvent, + FullStateEventContent, MessageLikeEventType, StateEventType, }, - serde::Raw, - EventId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, OwnedMxcUri, - OwnedTransactionId, OwnedUserId, TransactionId, UserId, + OwnedDeviceId, OwnedEventId, OwnedMxcUri, OwnedTransactionId, OwnedUserId, UserId, }; -use super::inner::ProfileProvider; -use crate::{Error, Result}; - -/// An item in the timeline that represents at least one event. -/// -/// There is always one main event that gives the `EventTimelineItem` its -/// identity but in many cases, additional events like reactions and edits are -/// also part of the item. -#[derive(Debug, Clone)] -pub enum EventTimelineItem { - /// An event item that has been sent, but not yet acknowledged by the - /// server. - Local(LocalEventTimelineItem), - /// An event item that has eben sent _and_ acknowledged by the server. - Remote(RemoteEventTimelineItem), -} - -impl EventTimelineItem { - /// Get the `LocalEventTimelineItem` if `self` is `Local`. - pub fn as_local(&self) -> Option<&LocalEventTimelineItem> { - match self { - Self::Local(local_event_item) => Some(local_event_item), - Self::Remote(_) => None, - } - } - - /// Get the `RemoteEventTimelineItem` if `self` is `Remote`. 
- pub fn as_remote(&self) -> Option<&RemoteEventTimelineItem> { - match self { - Self::Local(_) => None, - Self::Remote(remote_event_item) => Some(remote_event_item), - } - } - - /// Get a unique identifier to identify the event item, either by using - /// transaction ID or event ID in case of a local event, or by event ID in - /// case of a remote event. - pub fn unique_identifier(&self) -> String { - match self { - Self::Local(LocalEventTimelineItem { transaction_id, send_state, .. }) => { - match send_state { - EventSendState::Sent { event_id } => event_id.to_string(), - _ => transaction_id.to_string(), - } - } - - Self::Remote(RemoteEventTimelineItem { event_id, .. }) => event_id.to_string(), - } - } - - /// Get the transaction ID of this item. - /// - /// The transaction ID is only kept until the remote echo for a local event - /// is received, at which point the `EventTimelineItem::Local` is - /// transformed to `EventTimelineItem::Remote` and the transaction ID - /// discarded. - pub fn transaction_id(&self) -> Option<&TransactionId> { - match self { - Self::Local(local) => Some(&local.transaction_id), - Self::Remote(_) => None, - } - } - - /// Get the event ID of this item. - /// - /// If this returns `Some(_)`, the event was successfully created by the - /// server. - /// - /// Even if this is a [`Local`](Self::Local) event,, this can be `Some(_)` - /// as the event ID can be known not just from the remote echo via - /// `sync_events`, but also from the response of the send request that - /// created the event. - pub fn event_id(&self) -> Option<&EventId> { - match self { - Self::Local(local_event) => local_event.event_id(), - Self::Remote(remote_event) => Some(&remote_event.event_id), - } - } - - /// Get the sender of this item. - pub fn sender(&self) -> &UserId { - match self { - Self::Local(local_event) => &local_event.sender, - Self::Remote(remote_event) => &remote_event.sender, - } - } - - /// Get the profile of the sender. 
- pub fn sender_profile(&self) -> &TimelineDetails { - match self { - Self::Local(local_event) => &local_event.sender_profile, - Self::Remote(remote_event) => &remote_event.sender_profile, - } - } - - /// Get the content of this item. - pub fn content(&self) -> &TimelineItemContent { - match self { - Self::Local(local_event) => &local_event.content, - Self::Remote(remote_event) => &remote_event.content, - } - } - - /// Get the timestamp of this item. - /// - /// If this event hasn't been echoed back by the server yet, returns the - /// time the local event was created. Otherwise, returns the origin - /// server timestamp. - pub fn timestamp(&self) -> MilliSecondsSinceUnixEpoch { - match self { - Self::Local(local_event) => local_event.timestamp, - Self::Remote(remote_event) => remote_event.timestamp, - } - } - - /// Whether this timeline item was sent by the logged-in user themselves. - pub fn is_own(&self) -> bool { - match self { - Self::Local(_) => true, - Self::Remote(remote_event) => remote_event.is_own, - } - } - - /// Flag indicating this timeline item can be edited by current user. - pub fn is_editable(&self) -> bool { - match self.content() { - TimelineItemContent::Message(message) => { - self.is_own() - && matches!(message.msgtype(), MessageType::Text(_) | MessageType::Emote(_)) - } - _ => false, - } - } - - /// Get the raw JSON representation of the initial event (the one that - /// caused this timeline item to be created). - /// - /// Returns `None` if this event hasn't been echoed back by the server - /// yet. - pub fn raw(&self) -> Option<&Raw> { - match self { - Self::Local(_local_event) => None, - Self::Remote(remote_event) => Some(&remote_event.raw), - } - } - - /// Clone the current event item, and update its `content`. 
- pub(super) fn with_content(&self, content: TimelineItemContent) -> Self { - match self { - Self::Local(local_event_item) => { - Self::Local(LocalEventTimelineItem { content, ..local_event_item.clone() }) - } - Self::Remote(remote_event_item) => { - Self::Remote(RemoteEventTimelineItem { content, ..remote_event_item.clone() }) - } - } - } - - /// Clone the current event item, and update its `sender_profile`. - pub(super) fn with_sender_profile(&self, sender_profile: TimelineDetails) -> Self { - match self { - EventTimelineItem::Local(item) => { - Self::Local(LocalEventTimelineItem { sender_profile, ..item.clone() }) - } - EventTimelineItem::Remote(item) => { - Self::Remote(RemoteEventTimelineItem { sender_profile, ..item.clone() }) - } - } - } -} - -/// This type represents the "send state" of a local event timeline item. -#[derive(Clone, Debug)] -pub enum EventSendState { - /// The local event has not been sent yet. - NotSentYet, - /// The local event has been sent to the server, but unsuccessfully: The - /// sending has failed. - SendingFailed { - /// Details about how sending the event failed. - error: Arc, - }, - /// The local event has been sent successfully to the server. - Sent { - /// The event ID assigned by the server. - event_id: OwnedEventId, - }, -} - -#[derive(Debug, Clone)] -pub struct LocalEventTimelineItem { - /// The send state of this local event. - pub send_state: EventSendState, - /// The transaction ID. - pub transaction_id: OwnedTransactionId, - /// The sender of the event. - pub sender: OwnedUserId, - /// The sender's profile of the event. - pub sender_profile: TimelineDetails, - /// The timestamp of the event. - pub timestamp: MilliSecondsSinceUnixEpoch, - /// The content of the event. - pub content: TimelineItemContent, -} - -impl LocalEventTimelineItem { - /// Get the event ID of this item. - /// - /// Will be `Some` if and only if `send_state` is `EventSendState::Sent`. 
- pub fn event_id(&self) -> Option<&EventId> { - match &self.send_state { - EventSendState::Sent { event_id } => Some(event_id), - _ => None, - } - } - - /// Clone the current event item, and update its `send_state`. - pub(super) fn with_send_state(&self, send_state: EventSendState) -> Self { - Self { send_state, ..self.clone() } - } -} - -impl From for EventTimelineItem { - fn from(value: LocalEventTimelineItem) -> Self { - Self::Local(value) - } -} - -#[derive(Clone)] -pub struct RemoteEventTimelineItem { - /// The event ID. - pub event_id: OwnedEventId, - /// The sender of the event. - pub sender: OwnedUserId, - /// The sender's profile of the event. - pub sender_profile: TimelineDetails, - /// The timestamp of the event. - pub timestamp: MilliSecondsSinceUnixEpoch, - /// The content of the event. - pub content: TimelineItemContent, - /// All bundled reactions about the event. - pub reactions: BundledReactions, - /// Whether the event has been sent by the the logged-in user themselves. - pub is_own: bool, - /// Encryption information. - pub encryption_info: Option, - // FIXME: Expose the raw JSON of aggregated events somehow - pub raw: Raw, -} - -impl RemoteEventTimelineItem { - /// Clone the current event item, and update its `reactions`. - pub(super) fn with_reactions(&self, reactions: BundledReactions) -> Self { - Self { reactions, ..self.clone() } - } - - /// Clone the current event item, and update its `content`. - pub(super) fn with_content(&self, content: TimelineItemContent) -> Self { - Self { content, ..self.clone() } - } - - /// Clone the current event item, change its `content` to - /// [`TimelineItemContent::RedactedMessage`], and reset its `reactions`. - pub(super) fn to_redacted(&self) -> Self { - Self { - // FIXME: Change when we support state events - content: TimelineItemContent::RedactedMessage, - reactions: BundledReactions::default(), - ..self.clone() - } - } - - /// Get the reactions of this item. 
- pub fn reactions(&self) -> &BundledReactions { - // FIXME: Find out the state of incomplete bundled reactions, adjust - // Ruma if necessary, return the whole BundledReactions field - &self.reactions - } -} - -impl From for EventTimelineItem { - fn from(value: RemoteEventTimelineItem) -> Self { - Self::Remote(value) - } -} - -impl fmt::Debug for RemoteEventTimelineItem { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RemoteEventTimelineItem") - .field("event_id", &self.event_id) - .field("sender", &self.sender) - .field("timestamp", &self.timestamp) - .field("content", &self.content) - .field("reactions", &self.reactions) - .field("is_own", &self.is_own) - .field("encryption_info", &self.encryption_info) - // skip raw, too noisy - .finish_non_exhaustive() - } -} - -/// The display name and avatar URL of a room member. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Profile { - /// The display name, if set. - pub display_name: Option, - /// Whether the display name is ambiguous. - /// - /// Note that in rooms with lazy-loading enabled, this could be `false` even - /// though the display name is actually ambiguous if not all member events - /// have been seen yet. - pub display_name_ambiguous: bool, - /// The avatar URL, if set. - pub avatar_url: Option, -} - -/// Some details of an [`EventTimelineItem`] that may require server requests -/// other than just the regular -/// [`sync_events`][ruma::api::client::sync::sync_events]. -#[derive(Clone, Debug)] -pub enum TimelineDetails { - /// The details are not available yet, and have not been request from the - /// server. - Unavailable, - - /// The details are not available yet, but have been requested. - Pending, - - /// The details are available. - Ready(T), - - /// An error occurred when fetching the details. 
- Error(Arc), -} - -impl TimelineDetails { - pub(crate) fn from_initial_value(value: Option) -> Self { - match value { - Some(v) => Self::Ready(v), - None => Self::Unavailable, - } - } - - pub(crate) fn is_unavailable(&self) -> bool { - matches!(self, Self::Unavailable) - } - - pub(crate) fn contains(&self, value: &U) -> bool - where - T: PartialEq, - { - matches!(self, Self::Ready(v) if v == value) - } -} +use super::{Profile, TimelineDetails}; +use crate::{ + room::timeline::{inner::RoomDataProvider, Error as TimelineError}, + Result, +}; /// The content of an [`EventTimelineItem`]. #[derive(Clone, Debug)] @@ -476,9 +111,9 @@ impl TimelineItemContent { /// An `m.room.message` event or extensible event, including edits. #[derive(Clone)] pub struct Message { - pub(super) msgtype: MessageType, - pub(super) in_reply_to: Option, - pub(super) edited: bool, + pub(in crate::room::timeline) msgtype: MessageType, + pub(in crate::room::timeline) in_reply_to: Option, + pub(in crate::room::timeline) edited: bool, } impl Message { @@ -504,11 +139,15 @@ impl Message { self.edited } - pub(super) fn with_in_reply_to(&self, in_reply_to: InReplyToDetails) -> Self { + pub(in crate::room::timeline) fn with_in_reply_to( + &self, + in_reply_to: InReplyToDetails, + ) -> Self { Self { in_reply_to: Some(in_reply_to), ..self.clone() } } } +#[cfg(not(tarpaulin_include))] impl fmt::Debug for Message { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // since timeline items are logged, don't include all fields here so @@ -535,7 +174,7 @@ pub struct InReplyToDetails { } impl InReplyToDetails { - pub(super) fn from_relation(relation: Relation) -> Option { + pub(in crate::room::timeline) fn from_relation(relation: Relation) -> Option { match relation { message::Relation::Reply { in_reply_to } => { Some(Self { event_id: in_reply_to.event_id, details: TimelineDetails::Unavailable }) @@ -548,9 +187,9 @@ impl InReplyToDetails { /// An event that is replied to. 
#[derive(Clone, Debug)] pub struct RepliedToEvent { - pub(super) message: Message, - pub(super) sender: OwnedUserId, - pub(super) sender_profile: TimelineDetails, + pub(in crate::room::timeline) message: Message, + pub(in crate::room::timeline) sender: OwnedUserId, + pub(in crate::room::timeline) sender_profile: TimelineDetails, } impl RepliedToEvent { @@ -569,19 +208,19 @@ impl RepliedToEvent { &self.sender_profile } - pub(super) async fn try_from_timeline_event( + pub(in crate::room::timeline) async fn try_from_timeline_event( timeline_event: TimelineEvent, - profile_provider: &P, + room_data_provider: &P, ) -> Result { let event = match timeline_event.event.deserialize() { Ok(AnyTimelineEvent::MessageLike(event)) => event, _ => { - return Err(super::Error::UnsupportedEvent.into()); + return Err(TimelineError::UnsupportedEvent.into()); } }; let Some(AnyMessageLikeEventContent::RoomMessage(c)) = event.original_content() else { - return Err(super::Error::UnsupportedEvent.into()); + return Err(TimelineError::UnsupportedEvent.into()); }; let message = Message { @@ -591,7 +230,7 @@ impl RepliedToEvent { }; let sender = event.sender().to_owned(); let sender_profile = - TimelineDetails::from_initial_value(profile_provider.profile(&sender).await); + TimelineDetails::from_initial_value(room_data_provider.profile(&sender).await); Ok(Self { message, sender, sender_profile }) } @@ -647,14 +286,16 @@ impl From for EncryptedMessage { /// Value: The group of reactions. pub type BundledReactions = IndexMap; +// The long type after a long visibility specifier trips up rustfmt currently. +// This works around it. Report: https://github.com/rust-lang/rustfmt/issues/5703 +type ReactionGroupInner = IndexMap<(Option, Option), OwnedUserId>; + /// A group of reaction events on the same event with the same key. /// /// This is a map of the event ID or transaction ID of the reactions to the ID /// of the sender of the reaction.
#[derive(Clone, Debug, Default)] -pub struct ReactionGroup( - pub(super) IndexMap<(Option, Option), OwnedUserId>, -); +pub struct ReactionGroup(pub(in crate::room::timeline) ReactionGroupInner); impl ReactionGroup { /// The senders of the reactions in this group. @@ -674,7 +315,7 @@ impl Deref for ReactionGroup { /// An `m.sticker` event. #[derive(Clone, Debug)] pub struct Sticker { - pub(super) content: StickerEventContent, + pub(in crate::room::timeline) content: StickerEventContent, } impl Sticker { @@ -687,9 +328,9 @@ impl Sticker { /// An event changing a room membership. #[derive(Clone, Debug)] pub struct RoomMembershipChange { - pub(super) user_id: OwnedUserId, - pub(super) content: FullStateEventContent, - pub(super) change: Option, + pub(in crate::room::timeline) user_id: OwnedUserId, + pub(in crate::room::timeline) content: FullStateEventContent, + pub(in crate::room::timeline) change: Option, } impl RoomMembershipChange { @@ -776,9 +417,9 @@ pub enum MembershipChange { /// membership is already `join`. #[derive(Clone, Debug)] pub struct MemberProfileChange { - pub(super) user_id: OwnedUserId, - pub(super) displayname_change: Option>>, - pub(super) avatar_url_change: Option>>, + pub(in crate::room::timeline) user_id: OwnedUserId, + pub(in crate::room::timeline) displayname_change: Option>>, + pub(in crate::room::timeline) avatar_url_change: Option>>, } impl MemberProfileChange { @@ -932,8 +573,8 @@ impl AnyOtherFullStateEventContent { /// A state event that doesn't have its own variant. 
#[derive(Clone, Debug)] pub struct OtherState { - pub(super) state_key: String, - pub(super) content: AnyOtherFullStateEventContent, + pub(in crate::room::timeline) state_key: String, + pub(in crate::room::timeline) content: AnyOtherFullStateEventContent, } impl OtherState { diff --git a/crates/matrix-sdk/src/room/timeline/event_item/local.rs b/crates/matrix-sdk/src/room/timeline/event_item/local.rs new file mode 100644 index 00000000000..b1a58ae7968 --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/event_item/local.rs @@ -0,0 +1,101 @@ +use ruma::{ + EventId, MilliSecondsSinceUnixEpoch, OwnedTransactionId, OwnedUserId, TransactionId, UserId, +}; + +use super::{EventSendState, Profile, TimelineDetails, TimelineItemContent}; + +/// An item for an event that was created locally and not yet echoed back by +/// the homeserver. +#[derive(Debug, Clone)] +pub struct LocalEventTimelineItem { + /// The send state of this local event. + send_state: EventSendState, + /// The transaction ID. + transaction_id: OwnedTransactionId, + /// The sender of the event. + sender: OwnedUserId, + /// The sender's profile of the event. + sender_profile: TimelineDetails, + /// The timestamp of the event. + timestamp: MilliSecondsSinceUnixEpoch, + /// The content of the event. + content: TimelineItemContent, +} + +impl LocalEventTimelineItem { + pub(in crate::room::timeline) fn new( + send_state: EventSendState, + transaction_id: OwnedTransactionId, + sender: OwnedUserId, + sender_profile: TimelineDetails, + timestamp: MilliSecondsSinceUnixEpoch, + content: TimelineItemContent, + ) -> Self { + Self { send_state, transaction_id, sender, sender_profile, timestamp, content } + } + + /// Get the event's send state. + pub fn send_state(&self) -> &EventSendState { + &self.send_state + } + + /// Get the event ID of this item. + /// + /// Will be `Some` if and only if `send_state` is + /// `EventSendState::Sent`. 
+ pub fn event_id(&self) -> Option<&EventId> { + match &self.send_state { + EventSendState::Sent { event_id } => Some(event_id), + _ => None, + } + } + + /// Get the transaction ID of the event. + pub fn transaction_id(&self) -> &TransactionId { + &self.transaction_id + } + + /// Get the sender of the event. + /// + /// This is always the user's own user ID. + pub(crate) fn sender(&self) -> &UserId { + &self.sender + } + + /// Get the profile of the event's sender. + /// + /// Since `LocalEventTimelineItem`s are always sent by the user that is + /// logged in with the client that created the timeline, this effectively + /// gives the sender's own (possibly room-specific) profile. + pub fn sender_profile(&self) -> &TimelineDetails { + &self.sender_profile + } + + /// Get the timestamp when the event was created locally. + pub fn timestamp(&self) -> MilliSecondsSinceUnixEpoch { + self.timestamp + } + + /// Get the content of the event. + pub fn content(&self) -> &TimelineItemContent { + &self.content + } + + /// Clone the current event item, and update its `send_state`. + pub(in crate::room::timeline) fn with_send_state(&self, send_state: EventSendState) -> Self { + Self { send_state, ..self.clone() } + } + + /// Clone the current event item, and update its `sender_profile`. + pub(in crate::room::timeline) fn with_sender_profile( + &self, + sender_profile: TimelineDetails, + ) -> Self { + Self { sender_profile, ..self.clone() } + } + + /// Clone the current event item, and update its `content`. + pub(in crate::room::timeline) fn with_content(&self, content: TimelineItemContent) -> Self { + Self { content, ..self.clone() } + } +} diff --git a/crates/matrix-sdk/src/room/timeline/event_item/mod.rs b/crates/matrix-sdk/src/room/timeline/event_item/mod.rs new file mode 100644 index 00000000000..662a71180bf --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/event_item/mod.rs @@ -0,0 +1,282 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use ruma::{ + events::{room::message::MessageType, AnySyncTimelineEvent}, + serde::Raw, + EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedMxcUri, TransactionId, UserId, +}; + +use crate::Error; + +mod content; +mod local; +mod remote; + +pub use self::{ + content::{ + AnyOtherFullStateEventContent, BundledReactions, EncryptedMessage, InReplyToDetails, + MemberProfileChange, MembershipChange, Message, OtherState, ReactionGroup, RepliedToEvent, + RoomMembershipChange, Sticker, TimelineItemContent, + }, + local::LocalEventTimelineItem, + remote::RemoteEventTimelineItem, +}; + +/// An item in the timeline that represents at least one event. +/// +/// There is always one main event that gives the `EventTimelineItem` its +/// identity but in many cases, additional events like reactions and edits are +/// also part of the item. +#[derive(Debug, Clone)] +pub enum EventTimelineItem { + /// An event item that has been sent, but not yet acknowledged by the + /// server. + Local(LocalEventTimelineItem), + /// An event item that has been sent _and_ acknowledged by the server. + Remote(RemoteEventTimelineItem), +} + +impl EventTimelineItem { + /// Get the `LocalEventTimelineItem` if `self` is `Local`.
+ pub fn as_local(&self) -> Option<&LocalEventTimelineItem> { + match self { + Self::Local(local_event_item) => Some(local_event_item), + Self::Remote(_) => None, + } + } + + /// Get the `RemoteEventTimelineItem` if `self` is `Remote`. + pub fn as_remote(&self) -> Option<&RemoteEventTimelineItem> { + match self { + Self::Local(_) => None, + Self::Remote(remote_event_item) => Some(remote_event_item), + } + } + + /// Get a unique identifier to identify the event item, either by using + /// transaction ID or event ID in case of a local event, or by event ID in + /// case of a remote event. + pub fn unique_identifier(&self) -> String { + match self { + Self::Local(item) => match item.send_state() { + EventSendState::Sent { event_id } => event_id.to_string(), + _ => item.transaction_id().to_string(), + }, + Self::Remote(item) => item.event_id().to_string(), + } + } + + /// Get the transaction ID of this item. + /// + /// The transaction ID is only kept until the remote echo for a local event + /// is received, at which point the `EventTimelineItem::Local` is + /// transformed to `EventTimelineItem::Remote` and the transaction ID + /// discarded. + pub fn transaction_id(&self) -> Option<&TransactionId> { + match self { + Self::Local(local) => Some(local.transaction_id()), + Self::Remote(_) => None, + } + } + + /// Get the event ID of this item. + /// + /// If this returns `Some(_)`, the event was successfully created by the + /// server. + /// + /// Even if this is a [`Local`](Self::Local) event, this can be `Some(_)` + /// as the event ID can be known not just from the remote echo via + /// `sync_events`, but also from the response of the send request that + /// created the event. + pub fn event_id(&self) -> Option<&EventId> { + match self { + Self::Local(local_event) => local_event.event_id(), + Self::Remote(remote_event) => Some(remote_event.event_id()), + } + } + + /// Get the sender of this item.
+ pub fn sender(&self) -> &UserId { + match self { + Self::Local(local_event) => local_event.sender(), + Self::Remote(remote_event) => remote_event.sender(), + } + } + + /// Get the profile of the sender. + pub fn sender_profile(&self) -> &TimelineDetails { + match self { + Self::Local(local_event) => local_event.sender_profile(), + Self::Remote(remote_event) => remote_event.sender_profile(), + } + } + + /// Get the content of this item. + pub fn content(&self) -> &TimelineItemContent { + match self { + Self::Local(local_event) => local_event.content(), + Self::Remote(remote_event) => remote_event.content(), + } + } + + /// Get the timestamp of this item. + /// + /// If this event hasn't been echoed back by the server yet, returns the + /// time the local event was created. Otherwise, returns the origin + /// server timestamp. + pub fn timestamp(&self) -> MilliSecondsSinceUnixEpoch { + match self { + Self::Local(local_event) => local_event.timestamp(), + Self::Remote(remote_event) => remote_event.timestamp(), + } + } + + /// Whether this timeline item was sent by the logged-in user themselves. + pub fn is_own(&self) -> bool { + match self { + Self::Local(_) => true, + Self::Remote(remote_event) => remote_event.is_own(), + } + } + + /// Flag indicating this timeline item can be edited by current user. + pub fn is_editable(&self) -> bool { + match self.content() { + TimelineItemContent::Message(message) => { + self.is_own() + && matches!(message.msgtype(), MessageType::Text(_) | MessageType::Emote(_)) + } + _ => false, + } + } + + /// Get the raw JSON representation of the initial event (the one that + /// caused this timeline item to be created). + /// + /// Returns `None` if this event hasn't been echoed back by the server + /// yet. + pub fn raw(&self) -> Option<&Raw> { + match self { + Self::Local(_local_event) => None, + Self::Remote(remote_event) => Some(remote_event.raw()), + } + } + + /// Clone the current event item, and update its `content`. 
+ pub(super) fn with_content(&self, content: TimelineItemContent) -> Self { + match self { + Self::Local(local_event) => Self::Local(local_event.with_content(content)), + Self::Remote(remote_event) => Self::Remote(remote_event.with_content(content)), + } + } + + /// Clone the current event item, and update its `sender_profile`. + pub(super) fn with_sender_profile(&self, sender_profile: TimelineDetails) -> Self { + match self { + EventTimelineItem::Local(local_event) => { + Self::Local(local_event.with_sender_profile(sender_profile)) + } + EventTimelineItem::Remote(remote_event) => { + Self::Remote(remote_event.with_sender_profile(sender_profile)) + } + } + } +} + +/// This type represents the "send state" of a local event timeline item. +#[derive(Clone, Debug)] +pub enum EventSendState { + /// The local event has not been sent yet. + NotSentYet, + /// The local event has been sent to the server, but unsuccessfully: The + /// sending has failed. + SendingFailed { + /// Details about how sending the event failed. + error: Arc, + }, + /// The local event has been sent successfully to the server. + Sent { + /// The event ID assigned by the server. + event_id: OwnedEventId, + }, +} + +impl From for EventTimelineItem { + fn from(value: LocalEventTimelineItem) -> Self { + Self::Local(value) + } +} + +impl From for EventTimelineItem { + fn from(value: RemoteEventTimelineItem) -> Self { + Self::Remote(value) + } +} + +/// The display name and avatar URL of a room member. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Profile { + /// The display name, if set. + pub display_name: Option, + /// Whether the display name is ambiguous. + /// + /// Note that in rooms with lazy-loading enabled, this could be `false` even + /// though the display name is actually ambiguous if not all member events + /// have been seen yet. + pub display_name_ambiguous: bool, + /// The avatar URL, if set. 
+ pub avatar_url: Option, +} + +/// Some details of an [`EventTimelineItem`] that may require server requests +/// other than just the regular +/// [`sync_events`][ruma::api::client::sync::sync_events]. +#[derive(Clone, Debug)] +pub enum TimelineDetails { + /// The details are not available yet, and have not been requested from the + /// server. + Unavailable, + + /// The details are not available yet, but have been requested. + Pending, + + /// The details are available. + Ready(T), + + /// An error occurred when fetching the details. + Error(Arc), +} + +impl TimelineDetails { + pub(crate) fn from_initial_value(value: Option) -> Self { + match value { + Some(v) => Self::Ready(v), + None => Self::Unavailable, + } + } + + pub(crate) fn is_unavailable(&self) -> bool { + matches!(self, Self::Unavailable) + } + + pub(crate) fn contains(&self, value: &U) -> bool + where + T: PartialEq, + { + matches!(self, Self::Ready(v) if v == value) + } +} diff --git a/crates/matrix-sdk/src/room/timeline/event_item/remote.rs b/crates/matrix-sdk/src/room/timeline/event_item/remote.rs new file mode 100644 index 00000000000..e824a44ff83 --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/event_item/remote.rs @@ -0,0 +1,200 @@ +use std::fmt; + +use indexmap::IndexMap; +use matrix_sdk_base::deserialized_responses::EncryptionInfo; +use ruma::{ + events::{receipt::Receipt, AnySyncTimelineEvent}, + serde::Raw, + EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedUserId, UserId, +}; + +use super::{BundledReactions, Profile, TimelineDetails, TimelineItemContent}; + +/// An item for an event that was received from the homeserver. +#[derive(Clone)] +pub struct RemoteEventTimelineItem { + /// The event ID. + event_id: OwnedEventId, + /// The sender of the event. + sender: OwnedUserId, + /// The sender's profile of the event. + sender_profile: TimelineDetails, + /// The timestamp of the event. + timestamp: MilliSecondsSinceUnixEpoch, + /// The content of the event.
+ content: TimelineItemContent, + /// All bundled reactions about the event. + reactions: BundledReactions, + /// All read receipts for the event. + /// + /// The key is the ID of a room member and the values are details about the + /// read receipt. + /// + /// Note that currently this ignores threads. + read_receipts: IndexMap, + /// Whether the event has been sent by the logged-in user themselves. + is_own: bool, + /// Encryption information. + encryption_info: Option, + // FIXME: Expose the raw JSON of aggregated events somehow + raw: Raw, + /// Whether the item should be highlighted in the timeline. + is_highlighted: bool, +} + +impl RemoteEventTimelineItem { + #[allow(clippy::too_many_arguments)] // Would be nice to fix, but unclear how + pub(in crate::room::timeline) fn new( + event_id: OwnedEventId, + sender: OwnedUserId, + sender_profile: TimelineDetails, + timestamp: MilliSecondsSinceUnixEpoch, + content: TimelineItemContent, + reactions: BundledReactions, + read_receipts: IndexMap, + is_own: bool, + encryption_info: Option, + raw: Raw, + is_highlighted: bool, + ) -> Self { + Self { + event_id, + sender, + sender_profile, + timestamp, + content, + reactions, + read_receipts, + is_own, + encryption_info, + raw, + is_highlighted, + } + } + + /// Get the ID of the event. + pub fn event_id(&self) -> &EventId { + &self.event_id + } + + /// Get the sender of the event. + pub fn sender(&self) -> &UserId { + &self.sender + } + + /// Get the profile of the event's sender. + pub fn sender_profile(&self) -> &TimelineDetails { + &self.sender_profile + } + + /// Get the event timestamp as set by the homeserver that created the event. + pub fn timestamp(&self) -> MilliSecondsSinceUnixEpoch { + self.timestamp + } + + /// Get the content of the event. + pub fn content(&self) -> &TimelineItemContent { + &self.content + } + + /// Get the reactions of this item.
+ pub fn reactions(&self) -> &BundledReactions { + // FIXME: Find out the state of incomplete bundled reactions, adjust + // Ruma if necessary, return the whole BundledReactions field + &self.reactions + } + + /// Get the read receipts of this item. + /// + /// The key is the ID of a room member and the values are details about the + /// read receipt. + /// + /// Note that currently this ignores threads. + pub fn read_receipts(&self) -> &IndexMap { + &self.read_receipts + } + + /// Whether the event has been sent by the logged-in user themselves. + pub fn is_own(&self) -> bool { + self.is_own + } + + /// Get the encryption information for the event. + pub fn encryption_info(&self) -> Option<&EncryptionInfo> { + self.encryption_info.as_ref() + } + + /// Get the raw JSON representation of the primary event. + pub fn raw(&self) -> &Raw { + &self.raw + } + + /// Whether the event should be highlighted in the timeline. + pub fn is_highlighted(&self) -> bool { + self.is_highlighted + } + + pub(in crate::room::timeline) fn set_content(&mut self, content: TimelineItemContent) { + self.content = content; + } + + pub(in crate::room::timeline) fn add_read_receipt( + &mut self, + user_id: OwnedUserId, + receipt: Receipt, + ) { + self.read_receipts.insert(user_id, receipt); + } + + /// Remove the read receipt for the given user. + /// + /// Returns `true` if there was one, `false` if not. + pub(in crate::room::timeline) fn remove_read_receipt(&mut self, user_id: &UserId) -> bool { + self.read_receipts.remove(user_id).is_some() + } + + /// Clone the current event item, and update its `reactions`. + pub(in crate::room::timeline) fn with_reactions(&self, reactions: BundledReactions) -> Self { + Self { reactions, ..self.clone() } + } + + /// Clone the current event item, and update its `content`.
+ pub(in crate::room::timeline) fn with_content(&self, content: TimelineItemContent) -> Self { + Self { content, ..self.clone() } + } + + /// Clone the current event item, and update its `sender_profile`. + pub(in crate::room::timeline) fn with_sender_profile( + &self, + sender_profile: TimelineDetails, + ) -> Self { + Self { sender_profile, ..self.clone() } + } + + /// Clone the current event item, change its `content` to + /// [`TimelineItemContent::RedactedMessage`], and reset its `reactions`. + pub(in crate::room::timeline) fn to_redacted(&self) -> Self { + Self { + // FIXME: Change when we support state events + content: TimelineItemContent::RedactedMessage, + reactions: BundledReactions::default(), + ..self.clone() + } + } +} + +#[cfg(not(tarpaulin_include))] +impl fmt::Debug for RemoteEventTimelineItem { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RemoteEventTimelineItem") + .field("event_id", &self.event_id) + .field("sender", &self.sender) + .field("timestamp", &self.timestamp) + .field("content", &self.content) + .field("reactions", &self.reactions) + .field("is_own", &self.is_own) + .field("encryption_info", &self.encryption_info) + // skip raw, too noisy + .finish_non_exhaustive() + } +} diff --git a/crates/matrix-sdk/src/room/timeline/inner.rs b/crates/matrix-sdk/src/room/timeline/inner.rs index cb5618a57f9..efb5e4fd808 100644 --- a/crates/matrix-sdk/src/room/timeline/inner.rs +++ b/crates/matrix-sdk/src/room/timeline/inner.rs @@ -1,38 +1,63 @@ -use std::{ - collections::{BTreeSet, HashMap}, - sync::Arc, -}; +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(feature = "e2e-encryption")] +use std::collections::BTreeSet; +use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; -use futures_signals::signal_vec::{MutableVec, MutableVecLockRef, SignalVec}; -use indexmap::IndexSet; -#[cfg(any(test, feature = "experimental-sliding-sync"))] -use matrix_sdk_base::deserialized_responses::SyncTimelineEvent; +use eyeball_im::{ObservableVector, VectorSubscriber}; +use im::Vector; +use indexmap::{IndexMap, IndexSet}; +#[cfg(feature = "e2e-encryption")] +use matrix_sdk_base::crypto::OlmMachine; use matrix_sdk_base::{ - crypto::OlmMachine, - deserialized_responses::{EncryptionInfo, TimelineEvent}, - locks::Mutex, + deserialized_responses::{EncryptionInfo, SyncTimelineEvent, TimelineEvent}, + locks::{Mutex, MutexGuard}, }; +#[cfg(feature = "e2e-encryption")] +use ruma::RoomId; use ruma::{ + api::client::receipt::create_receipt::v3::ReceiptType as SendReceiptType, events::{ - fully_read::FullyReadEvent, relation::Annotation, AnyMessageLikeEventContent, - AnySyncTimelineEvent, + fully_read::FullyReadEvent, + receipt::{Receipt, ReceiptEventContent, ReceiptThread, ReceiptType}, + relation::Annotation, + AnyMessageLikeEventContent, AnySyncTimelineEvent, }, + push::Action, serde::Raw, - EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedTransactionId, OwnedUserId, RoomId, + EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedTransactionId, OwnedUserId, TransactionId, UserId, }; -use tracing::{debug, error, field::debug, info, warn}; +use tracing::{debug, error, field::debug, instrument, 
trace, warn}; #[cfg(feature = "e2e-encryption")] -use tracing::{instrument, trace}; +use tracing::{field, info, info_span, Instrument as _}; use super::{ + compare_events_positions, event_handler::{ update_read_marker, Flow, HandleEventResult, TimelineEventHandler, TimelineEventKind, TimelineEventMetadata, TimelineItemPosition, }, + read_receipts::{ + handle_explicit_read_receipts, latest_user_read_receipt, load_read_receipts_for_event, + user_receipt, + }, rfind_event_by_id, rfind_event_item, EventSendState, EventTimelineItem, InReplyToDetails, - Message, Profile, RepliedToEvent, TimelineDetails, TimelineItem, TimelineItemContent, + Message, Profile, RelativePosition, RepliedToEvent, TimelineDetails, TimelineItem, + TimelineItemContent, }; use crate::{ events::SyncTimelineEventWithoutContent, @@ -41,15 +66,15 @@ use crate::{ }; #[derive(Debug)] -pub(super) struct TimelineInner { - items: MutableVec>, - metadata: Mutex, - profile_provider: P, +pub(super) struct TimelineInner { + state: Mutex, + room_data_provider: P, + track_read_receipts: bool, } -/// Non-signalling parts of `TimelineInner`. #[derive(Debug, Default)] -pub(super) struct TimelineInnerMetadata { +pub(super) struct TimelineInnerState { + pub(super) items: ObservableVector>, /// Reaction event / txn ID => sender and reaction data. pub(super) reaction_map: HashMap<(Option, Option), (OwnedUserId, Annotation)>, @@ -60,39 +85,78 @@ pub(super) struct TimelineInnerMetadata { /// Whether the event that the fully-ready event _refers to_ is part of the /// timeline. pub(super) fully_read_event_in_timeline: bool, + /// User ID => Receipt type => Read receipt of the user of the given type. + pub(super) users_read_receipts: + HashMap>, } -impl TimelineInner

{ - pub(super) fn new(profile_provider: P) -> Self { - Self { items: Default::default(), metadata: Default::default(), profile_provider } +impl TimelineInner

{ + pub(super) fn new(room_data_provider: P) -> Self { + let state = TimelineInnerState { + // Upstream default capacity is currently 16, which is making + // sliding-sync tests with 20 events lag. This should still be + // small enough. + items: ObservableVector::with_capacity(32), + ..Default::default() + }; + Self { state: Mutex::new(state), room_data_provider, track_read_receipts: false } } - pub(super) fn items(&self) -> MutableVecLockRef<'_, Arc> { - self.items.lock_ref() + pub(super) fn with_read_receipt_tracking(mut self, track_read_receipts: bool) -> Self { + self.track_read_receipts = track_read_receipts; + self } - pub(super) fn items_signal(&self) -> impl SignalVec> { - self.items.signal_vec_cloned() + /// Get a copy of the current items in the list. + /// + /// Cheap because `im::Vector` is cheap to clone. + pub(super) async fn items(&self) -> Vector> { + self.state.lock().await.items.clone() } - #[cfg(any(test, feature = "experimental-sliding-sync"))] - pub(super) async fn add_initial_events(&mut self, events: Vec) { + pub(super) async fn subscribe( + &self, + ) -> (Vector>, VectorSubscriber>) { + trace!("Creating timeline items signal"); + let state = self.state.lock().await; + // auto-deref to the inner vector's clone method + let items = state.items.clone(); + let stream = state.items.subscribe(); + (items, stream) + } + + pub(super) fn set_initial_user_receipt( + &mut self, + receipt_type: ReceiptType, + receipt: (OwnedEventId, Receipt), + ) { + let own_user_id = self.room_data_provider.own_user_id().to_owned(); + self.state + .get_mut() + .users_read_receipts + .entry(own_user_id) + .or_default() + .insert(receipt_type, receipt); + } + + pub(super) async fn add_initial_events(&mut self, events: Vector) { if events.is_empty() { return; } debug!("Adding {} initial events", events.len()); - let timeline_meta = self.metadata.get_mut(); + let state = self.state.get_mut(); for event in events { handle_remote_event( event.event, event.encryption_info, + 
event.push_actions, TimelineItemPosition::End, - &self.items, - timeline_meta, - &self.profile_provider, + state, + &self.room_data_provider, + self.track_read_receipts, ) .await; } @@ -100,41 +164,44 @@ impl TimelineInner

{ #[cfg(feature = "experimental-sliding-sync")] pub(super) async fn clear(&self) { - let mut timeline_meta = self.metadata.lock().await; - let mut timeline_items = self.items.lock_mut(); - - timeline_meta.reaction_map.clear(); - timeline_meta.fully_read_event = None; - timeline_meta.fully_read_event_in_timeline = false; + trace!("Clearing timeline"); - timeline_items.clear(); + let mut state = self.state.lock().await; + state.items.clear(); + state.reaction_map.clear(); + state.fully_read_event = None; + state.fully_read_event_in_timeline = false; } + #[instrument(skip_all)] pub(super) async fn handle_live_event( &self, raw: Raw, encryption_info: Option, + push_actions: Vec, ) { - let mut timeline_meta = self.metadata.lock().await; + let mut state = self.state.lock().await; handle_remote_event( raw, encryption_info, + push_actions, TimelineItemPosition::End, - &self.items, - &mut timeline_meta, - &self.profile_provider, + &mut state, + &self.room_data_provider, + self.track_read_receipts, ) .await; } /// Handle the creation of a new local event. + #[instrument(skip_all)] pub(super) async fn handle_local_event( &self, txn_id: OwnedTransactionId, content: AnyMessageLikeEventContent, ) { - let sender = self.profile_provider.own_user_id().to_owned(); - let sender_profile = self.profile_provider.profile(&sender).await; + let sender = self.room_data_provider.own_user_id().to_owned(); + let sender_profile = self.room_data_provider.profile(&sender).await; let event_meta = TimelineEventMetadata { sender, sender_profile, @@ -142,26 +209,29 @@ impl TimelineInner

{ relations: Default::default(), // FIXME: Should we supply something here for encrypted rooms? encryption_info: None, + read_receipts: Default::default(), + // An event sent by ourself is never matched against push rules. + is_highlighted: false, }; let flow = Flow::Local { txn_id, timestamp: MilliSecondsSinceUnixEpoch::now() }; let kind = TimelineEventKind::Message { content }; - let mut timeline_meta = self.metadata.lock().await; - let mut timeline_items = self.items.lock_mut(); - TimelineEventHandler::new(event_meta, flow, &mut timeline_items, &mut timeline_meta) + let mut state = self.state.lock().await; + TimelineEventHandler::new(event_meta, flow, &mut state, self.track_read_receipts) .handle_event(kind); } /// Update the send state of a local event represented by a transaction ID. /// /// If no local event is found, a warning is raised. - pub(super) fn update_event_send_state( + #[instrument(skip_all, fields(txn_id))] + pub(super) async fn update_event_send_state( &self, txn_id: &TransactionId, send_state: EventSendState, ) { - let mut lock = self.items.lock_mut(); + let mut state = self.state.lock().await; let new_event_id: Option<&EventId> = match &send_state { EventSendState::Sent { event_id } => Some(event_id), @@ -169,79 +239,84 @@ impl TimelineInner

{ }; // Look for the local event by the transaction ID or event ID. - let result = rfind_event_item(&lock, |it| { + let result = rfind_event_item(&state.items, |it| { it.transaction_id() == Some(txn_id) || new_event_id.is_some() && it.event_id() == new_event_id }); let Some((idx, item)) = result else { // Event isn't found at all. - warn!(?txn_id, "Timeline item not found, can't add event ID"); + warn!("Timeline item not found, can't add event ID"); return; }; let EventTimelineItem::Local(item) = item else { // Remote echo already received. This is very unlikely. - trace!(?txn_id, "Remote echo received before send-event response"); + trace!("Remote echo received before send-event response"); return; }; // The event was already marked as sent, that's a broken state, let's // emit an error but also override to the given sent state. - if let EventSendState::Sent { event_id: existing_event_id } = &item.send_state { + if let EventSendState::Sent { event_id: existing_event_id } = item.send_state() { let new_event_id = new_event_id.map(debug); - error!(?existing_event_id, ?new_event_id, ?txn_id, "Local echo already marked as sent"); + error!(?existing_event_id, ?new_event_id, "Local echo already marked as sent"); } let new_item = TimelineItem::Event(item.with_send_state(send_state).into()); - lock.set_cloned(idx, Arc::new(new_item)); + state.items.set(idx, Arc::new(new_item)); } /// Handle a back-paginated event. /// /// Returns the number of timeline updates that were made. 
+ #[instrument(skip_all)] pub(super) async fn handle_back_paginated_event( &self, event: TimelineEvent, ) -> HandleEventResult { - let mut metadata_lock = self.metadata.lock().await; + let mut state = self.state.lock().await; handle_remote_event( event.event.cast(), event.encryption_info, + event.push_actions, TimelineItemPosition::Start, - &self.items, - &mut metadata_lock, - &self.profile_provider, + &mut state, + &self.room_data_provider, + self.track_read_receipts, ) .await } #[instrument(skip_all)] - pub(super) fn add_loading_indicator(&self) { - let mut lock = self.items.lock_mut(); - if lock.first().map_or(false, |item| item.is_loading_indicator()) { + pub(super) async fn add_loading_indicator(&self) { + let mut state = self.state.lock().await; + + if state.items.front().map_or(false, |item| item.is_loading_indicator()) { warn!("There is already a loading indicator"); return; } - lock.insert_cloned(0, Arc::new(TimelineItem::loading_indicator())); + state.items.push_front(Arc::new(TimelineItem::loading_indicator())); } #[instrument(skip(self))] - pub(super) fn remove_loading_indicator(&self, more_messages: bool) { - let mut lock = self.items.lock_mut(); - if !lock.first().map_or(false, |item| item.is_loading_indicator()) { + pub(super) async fn remove_loading_indicator(&self, more_messages: bool) { + let mut state = self.state.lock().await; + + if !state.items.front().map_or(false, |item| item.is_loading_indicator()) { warn!("There is no loading indicator"); return; } if more_messages { - lock.remove(0); + state.items.pop_front(); } else { - lock.set_cloned(0, Arc::new(TimelineItem::timeline_start())) + state.items.set(0, Arc::new(TimelineItem::timeline_start())); } } + #[instrument(skip_all)] pub(super) async fn handle_fully_read(&self, raw: Raw) { let fully_read_event_id = match raw.deserialize() { Ok(ev) => ev.content.event_id, @@ -254,280 +329,391 @@ impl TimelineInner

{ self.set_fully_read_event(fully_read_event_id).await; } + #[instrument(skip_all)] pub(super) async fn set_fully_read_event(&self, fully_read_event_id: OwnedEventId) { - let mut metadata_lock = self.metadata.lock().await; + let mut state = self.state.lock().await; // A similar event has been handled already. We can ignore it. - if metadata_lock.fully_read_event.as_ref().map_or(false, |id| *id == fully_read_event_id) { + if state.fully_read_event.as_ref().map_or(false, |id| *id == fully_read_event_id) { return; } - metadata_lock.fully_read_event = Some(fully_read_event_id); + state.fully_read_event = Some(fully_read_event_id); - let mut items_lock = self.items.lock_mut(); - let metadata = &mut *metadata_lock; + let state = &mut *state; update_read_marker( - &mut items_lock, - metadata.fully_read_event.as_deref(), - &mut metadata.fully_read_event_in_timeline, + &mut state.items, + state.fully_read_event.as_deref(), + &mut state.fully_read_event_in_timeline, ); } - /// Collect events and their metadata that are unable-to-decrypt (UTD) - /// events in the timeline. - fn collect_utds( + #[cfg(feature = "e2e-encryption")] + #[instrument(skip(self, olm_machine))] + pub(super) async fn retry_event_decryption( &self, + room_id: &RoomId, + olm_machine: &OlmMachine, session_ids: Option>, - ) -> Vec<(usize, OwnedEventId, String, Raw)> { + ) { use super::EncryptedMessage; + trace!("Retrying decryption"); let should_retry = |session_id: &str| { - let session_ids = &session_ids; - - if let Some(session_ids) = session_ids { + if let Some(session_ids) = &session_ids { session_ids.contains(session_id) } else { true } }; - self.items - .lock_ref() - .iter() - .enumerate() - .filter_map(|(idx, item)| { - let event_item = &item.as_event()?; - let utd = event_item.content().as_unable_to_decrypt()?; + let retry_one = |item: Arc| { + async move { + let event_item = item.as_event()?; - match utd { + let session_id = match event_item.content().as_unable_to_decrypt()? 
{ EncryptedMessage::MegolmV1AesSha2 { session_id, .. } if should_retry(session_id) => { - let EventTimelineItem::Remote(RemoteEventTimelineItem { event_id, raw, .. }) = event_item else { - error!("Key for unable-to-decrypt timeline item is not an event ID"); - return None; - }; - - Some(( - idx, - event_id.to_owned(), - session_id.to_owned(), - raw.clone(), - )) + session_id } EncryptedMessage::MegolmV1AesSha2 { .. } | EncryptedMessage::OlmV1Curve25519AesSha2 { .. } - | EncryptedMessage::Unknown => None, - } - }) - .collect() - } + | EncryptedMessage::Unknown => return None, + }; - #[cfg(feature = "e2e-encryption")] - #[instrument(skip(self, olm_machine))] - pub(super) async fn retry_event_decryption( - &self, - room_id: &RoomId, - olm_machine: &OlmMachine, - session_ids: Option>, - ) { - debug!("Retrying decryption"); + tracing::Span::current().record("session_id", session_id); - let utds_for_session = self.collect_utds(session_ids); + let EventTimelineItem::Remote(remote_event) = event_item else { + error!("Key for unable-to-decrypt timeline item is not an event ID"); + return None; + }; - if utds_for_session.is_empty() { - trace!("Found no events to retry decryption for"); - return; - } + tracing::Span::current().record("event_id", debug(remote_event.event_id())); - let mut metadata_lock = self.metadata.lock().await; - for (idx, event_id, session_id, utd) in utds_for_session.iter().rev() { - let event = match olm_machine.decrypt_room_event(utd.cast_ref(), room_id).await { - Ok(ev) => ev, - Err(e) => { - info!( - ?event_id, - ?session_id, - "Failed to decrypt event after receiving room key: {e}" - ); - continue; + let raw = remote_event.raw().cast_ref(); + match olm_machine.decrypt_room_event(raw, room_id).await { + Ok(event) => { + trace!("Successfully decrypted event that previously failed to decrypt"); + Some(event) + } + Err(e) => { + info!("Failed to decrypt event after receiving room key: {e}"); + None + } } - }; + } + .instrument(info_span!( + "retry_one", 
+ session_id = field::Empty, + event_id = field::Empty + )) + }; - trace!( - ?event_id, - ?session_id, - "Successfully decrypted event that previously failed to decrypt" - ); + let mut state = self.state.lock().await; - handle_remote_event( + // We loop through all the items in the timeline, if we successfully + // decrypt a UTD item we either replace it or remove it and update + // another one. + let mut idx = 0; + while let Some(item) = state.items.get(idx) { + let Some(event) = retry_one(item.clone()).await else { + idx += 1; + continue; + }; + + let result = handle_remote_event( event.event.cast(), event.encryption_info, - TimelineItemPosition::Update(*idx), - &self.items, - &mut metadata_lock, - &self.profile_provider, + event.push_actions, + TimelineItemPosition::Update(idx), + &mut state, + &self.room_data_provider, + self.track_read_receipts, ) .await; + + // If the UTD was removed rather than updated, run the loop again + // with the same index. + if !result.item_removed { + idx += 1; + } } } - pub(super) fn set_sender_profiles_pending(&self) { - self.set_non_ready_sender_profiles(TimelineDetails::Pending); + pub(super) async fn set_sender_profiles_pending(&self) { + self.set_non_ready_sender_profiles(TimelineDetails::Pending).await; } - pub(super) fn set_sender_profiles_error(&self, error: Arc) { - self.set_non_ready_sender_profiles(TimelineDetails::Error(error)); + pub(super) async fn set_sender_profiles_error(&self, error: Arc) { + self.set_non_ready_sender_profiles(TimelineDetails::Error(error)).await; } - fn set_non_ready_sender_profiles(&self, state: TimelineDetails) { - let mut timeline_items = self.items.lock_mut(); - for idx in 0..timeline_items.len() { - let Some(event_item) = timeline_items[idx].as_event() else { continue }; + async fn set_non_ready_sender_profiles(&self, profile_state: TimelineDetails) { + let mut state = self.state.lock().await; + for idx in 0..state.items.len() { + let Some(event_item) = state.items[idx].as_event() else { 
continue }; if !matches!(event_item.sender_profile(), TimelineDetails::Ready(_)) { - timeline_items.set_cloned( - idx, - Arc::new(TimelineItem::Event(event_item.with_sender_profile(state.clone()))), - ); + let item = Arc::new(TimelineItem::Event( + event_item.with_sender_profile(profile_state.clone()), + )); + state.items.set(idx, item); } } } pub(super) async fn update_sender_profiles(&self) { - // Can't lock the timeline items across .await points without making the - // resulting future `!Send`. As a (brittle) hack around that, lock the - // timeline items in each loop iteration but keep a lock of the metadata - // so no event handler runs in parallel and assert the number of items - // doesn't change between iterations. - let _guard = self.metadata.lock().await; - let num_items = self.items().len(); + trace!("Updating sender profiles"); + + let mut state = self.state.lock().await; + let num_items = state.items.len(); for idx in 0..num_items { - let sender = match self.items()[idx].as_event() { + let sender = match state.items[idx].as_event() { Some(event_item) => event_item.sender().to_owned(), None => continue, }; - let maybe_profile = self.profile_provider.profile(&sender).await; + let maybe_profile = self.room_data_provider.profile(&sender).await; - let mut timeline_items = self.items.lock_mut(); - assert_eq!(timeline_items.len(), num_items); + assert_eq!(state.items.len(), num_items); - let event_item = timeline_items[idx].as_event().unwrap(); + let event_item = state.items[idx].as_event().unwrap(); match maybe_profile { Some(profile) => { if !event_item.sender_profile().contains(&profile) { let updated_item = event_item.with_sender_profile(TimelineDetails::Ready(profile)); - timeline_items.set_cloned(idx, Arc::new(TimelineItem::Event(updated_item))); + state.items.set(idx, Arc::new(TimelineItem::Event(updated_item))); } } None => { if !event_item.sender_profile().is_unavailable() { let updated_item = 
event_item.with_sender_profile(TimelineDetails::Unavailable); - timeline_items.set_cloned(idx, Arc::new(TimelineItem::Event(updated_item))); + state.items.set(idx, Arc::new(TimelineItem::Event(updated_item))); } } } } } - fn update_event_item(&self, index: usize, event_item: EventTimelineItem) { - self.items.lock_mut().set_cloned(index, Arc::new(TimelineItem::Event(event_item))) + pub(super) async fn handle_read_receipts(&self, receipt_event_content: ReceiptEventContent) { + let mut state = self.state.lock().await; + let own_user_id = self.room_data_provider.own_user_id(); + + handle_explicit_read_receipts(receipt_event_content, own_user_id, &mut state) } } impl TimelineInner { pub(super) fn room(&self) -> &room::Common { - &self.profile_provider + &self.room_data_provider + } + + /// Get the current fully-read event. + pub(super) async fn fully_read_event(&self) -> Option { + match self.room().account_data_static().await { + Ok(Some(fully_read)) => match fully_read.deserialize() { + Ok(fully_read) => Some(fully_read), + Err(e) => { + error!("Failed to deserialize fully-read account data: {e}"); + None + } + }, + Err(e) => { + error!("Failed to get fully-read account data from the store: {e}"); + None + } + _ => None, + } + } + + /// Load the current fully-read event in this inner timeline. 
+ pub(super) async fn load_fully_read_event(&self) { + if let Some(fully_read) = self.fully_read_event().await { + self.set_fully_read_event(fully_read.content.event_id).await; + } } pub(super) async fn fetch_in_reply_to_details( &self, - index: usize, - mut item: RemoteEventTimelineItem, + event_id: &EventId, ) -> Result { - let TimelineItemContent::Message(message) = item.content.clone() else { + let state = self.state.lock().await; + let (index, item) = rfind_event_by_id(&state.items, event_id) + .and_then(|(pos, item)| item.as_remote().map(|item| (pos, item.clone()))) + .ok_or(super::Error::RemoteEventNotInTimeline)?; + + let TimelineItemContent::Message(message) = item.content().clone() else { return Ok(item); }; let Some(in_reply_to) = message.in_reply_to() else { return Ok(item); }; - let details = - self.fetch_replied_to_event(index, &item, &message, &in_reply_to.event_id).await; + let details = fetch_replied_to_event( + state, + index, + &item, + &message, + &in_reply_to.event_id, + self.room(), + ) + .await; // We need to be sure to have the latest position of the event as it might have // changed while waiting for the request. - let (index, _) = rfind_event_by_id(&self.items(), &item.event_id) + let mut state = self.state.lock().await; + let (index, mut item) = rfind_event_by_id(&state.items, item.event_id()) + .and_then(|(pos, item)| item.as_remote().map(|item| (pos, item.clone()))) .ok_or(super::Error::RemoteEventNotInTimeline)?; - item = item.with_content(TimelineItemContent::Message(message.with_in_reply_to( + // Check the state of the event again, it might have been redacted while + // the request was in-flight. 
+ let TimelineItemContent::Message(message) = item.content().clone() else { + return Ok(item); + }; + let Some(in_reply_to) = message.in_reply_to() else { + return Ok(item); + }; + + item.set_content(TimelineItemContent::Message(message.with_in_reply_to( InReplyToDetails { event_id: in_reply_to.event_id.clone(), details }, ))); - self.update_event_item(index, item.clone().into()); + state.items.set(index, Arc::new(TimelineItem::Event(item.clone().into()))); Ok(item) } - async fn fetch_replied_to_event( + /// Get the latest read receipt for the given user. + /// + /// Useful to get the latest read receipt, whether it's private or public. + pub(super) async fn latest_user_read_receipt( &self, - index: usize, - item: &RemoteEventTimelineItem, - message: &Message, - in_reply_to: &EventId, - ) -> TimelineDetails> { - if let Some((_, item)) = rfind_event_by_id(&self.items(), in_reply_to) { - let details = match item.content() { - TimelineItemContent::Message(message) => { - TimelineDetails::Ready(Box::new(RepliedToEvent { - message: message.clone(), - sender: item.sender().to_owned(), - sender_profile: item.sender_profile().clone(), - })) - } - _ => TimelineDetails::Error(Arc::new(super::Error::UnsupportedEvent.into())), - }; + user_id: &UserId, + ) -> Option<(OwnedEventId, Receipt)> { + let state = self.state.lock().await; + let room = self.room(); - return details; - }; + latest_user_read_receipt(user_id, &state, room).await + } - self.update_event_item( - index, - item.with_content(TimelineItemContent::Message(message.with_in_reply_to( - InReplyToDetails { - event_id: in_reply_to.to_owned(), - details: TimelineDetails::Pending, - }, - ))) - .into(), - ); + /// Check whether the given receipt should be sent. + /// + /// Returns `false` if the given receipt is older than the current one. 
+ pub(super) async fn should_send_receipt( + &self, + receipt_type: &SendReceiptType, + thread: &ReceiptThread, + event_id: &EventId, + ) -> bool { + // We don't support threaded receipts yet. + if *thread != ReceiptThread::Unthreaded { + return true; + } + + let own_user_id = self.room().own_user_id(); + let state = self.state.lock().await; + let room = self.room(); - match self.room().event(in_reply_to).await { - Ok(timeline_event) => { - match RepliedToEvent::try_from_timeline_event( - timeline_event, - &self.profile_provider, - ) - .await + match receipt_type { + SendReceiptType::Read => { + if let Some((old_pub_read, _)) = + user_receipt(own_user_id, ReceiptType::Read, &state, room).await + { + if let Some(relative_pos) = + compare_events_positions(&old_pub_read, event_id, &state.items) + { + return relative_pos == RelativePosition::After; + } + } + } + // Implicit read receipts are saved as public read receipts, so get the latest. It also + // doesn't make sense to have a private read receipt behind a public one. + SendReceiptType::ReadPrivate => { + if let Some((old_priv_read, _)) = + latest_user_read_receipt(own_user_id, &state, room).await { - Ok(event) => TimelineDetails::Ready(Box::new(event)), - Err(e) => TimelineDetails::Error(Arc::new(e)), + if let Some(relative_pos) = + compare_events_positions(&old_priv_read, event_id, &state.items) + { + return relative_pos == RelativePosition::After; + } + } + } + SendReceiptType::FullyRead => { + if let Some(old_fully_read) = self.fully_read_event().await { + if let Some(relative_pos) = compare_events_positions( + &old_fully_read.content.event_id, + event_id, + &state.items, + ) { + return relative_pos == RelativePosition::After; + } } } - Err(e) => TimelineDetails::Error(Arc::new(e)), + _ => {} } + + // Let the server handle unknown receipts. 
+ true + } +} + +async fn fetch_replied_to_event( + mut state: MutexGuard<'_, TimelineInnerState>, + index: usize, + item: &RemoteEventTimelineItem, + message: &Message, + in_reply_to: &EventId, + room: &room::Common, +) -> TimelineDetails> { + if let Some((_, item)) = rfind_event_by_id(&state.items, in_reply_to) { + let details = match item.content() { + TimelineItemContent::Message(message) => { + TimelineDetails::Ready(Box::new(RepliedToEvent { + message: message.clone(), + sender: item.sender().to_owned(), + sender_profile: item.sender_profile().clone(), + })) + } + _ => TimelineDetails::Error(Arc::new(super::Error::UnsupportedEvent.into())), + }; + + return details; + }; + + let event_item = item + .with_content(TimelineItemContent::Message(message.with_in_reply_to(InReplyToDetails { + event_id: in_reply_to.to_owned(), + details: TimelineDetails::Pending, + }))) + .into(); + state.items.set(index, Arc::new(TimelineItem::Event(event_item))); + + // Don't hold the state lock while the network request is made + drop(state); + + match room.event(in_reply_to).await { + Ok(timeline_event) => { + match RepliedToEvent::try_from_timeline_event(timeline_event, room).await { + Ok(event) => TimelineDetails::Ready(Box::new(event)), + Err(e) => TimelineDetails::Error(Arc::new(e)), + } + } + Err(e) => TimelineDetails::Error(Arc::new(e)), } } #[async_trait] -pub(super) trait ProfileProvider { +pub(super) trait RoomDataProvider { fn own_user_id(&self) -> &UserId; async fn profile(&self, user_id: &UserId) -> Option; + async fn read_receipts_for_event(&self, event_id: &EventId) -> IndexMap; } #[async_trait] -impl ProfileProvider for room::Common { +impl RoomDataProvider for room::Common { fn own_user_id(&self) -> &UserId { (**self).own_user_id() } @@ -551,20 +737,29 @@ impl ProfileProvider for room::Common { } } } + + async fn read_receipts_for_event(&self, event_id: &EventId) -> IndexMap { + match self.event_receipts(ReceiptType::Read, ReceiptThread::Unthreaded, event_id).await 
{ + Ok(receipts) => receipts.into_iter().collect(), + Err(e) => { + error!(?event_id, "Failed to get read receipts for event: {e}"); + IndexMap::new() + } + } + } } /// Handle a remote event. /// /// Returns the number of timeline updates that were made. -async fn handle_remote_event( +async fn handle_remote_event( raw: Raw, encryption_info: Option, + push_actions: Vec, position: TimelineItemPosition, - // MutableVecLock can't be held across `.await`s in `Send` futures, so we - // can't lock it ahead of time like `timeline_meta`. - timeline_items: &MutableVec>, - timeline_meta: &mut TimelineInnerMetadata, - profile_provider: &P, + timeline_state: &mut TimelineInnerState, + room_data_provider: &P, + track_read_receipts: bool, ) -> HandleEventResult { let (event_id, sender, origin_server_ts, txn_id, relations, event_kind) = match raw.deserialize() { @@ -594,13 +789,25 @@ async fn handle_remote_event( }, }; - let is_own_event = sender == profile_provider.own_user_id(); - let sender_profile = profile_provider.profile(&sender).await; - let event_meta = - TimelineEventMetadata { sender, sender_profile, is_own_event, relations, encryption_info }; + let is_own_event = sender == room_data_provider.own_user_id(); + let sender_profile = room_data_provider.profile(&sender).await; + let read_receipts = if track_read_receipts { + load_read_receipts_for_event(&event_id, timeline_state, room_data_provider).await + } else { + Default::default() + }; + let is_highlighted = push_actions.iter().any(Action::is_highlight); + let event_meta = TimelineEventMetadata { + sender, + sender_profile, + is_own_event, + relations, + encryption_info, + read_receipts, + is_highlighted, + }; let flow = Flow::Remote { event_id, origin_server_ts, raw_event: raw, txn_id, position }; - let mut timeline_items = timeline_items.lock_mut(); - TimelineEventHandler::new(event_meta, flow, &mut timeline_items, timeline_meta) + TimelineEventHandler::new(event_meta, flow, timeline_state, track_read_receipts) 
.handle_event(event_kind) } diff --git a/crates/matrix-sdk/src/room/timeline/mod.rs b/crates/matrix-sdk/src/room/timeline/mod.rs index ae9d589b6a2..a0cdb981731 100644 --- a/crates/matrix-sdk/src/room/timeline/mod.rs +++ b/crates/matrix-sdk/src/room/timeline/mod.rs @@ -16,52 +16,56 @@ //! //! See [`Timeline`] for details. -use std::sync::Arc; +use std::{pin::Pin, sync::Arc, task::Poll}; +use eyeball_im::{VectorDiff, VectorSubscriber}; use futures_core::Stream; -use futures_signals::signal_vec::{SignalVec, SignalVecExt, VecDiff}; -#[cfg(feature = "experimental-sliding-sync")] -use matrix_sdk_base::deserialized_responses::SyncTimelineEvent; -use matrix_sdk_base::{deserialized_responses::EncryptionInfo, locks::Mutex}; +use im::Vector; +use matrix_sdk_base::locks::Mutex; +use pin_project_lite::pin_project; use ruma::{ + api::client::receipt::create_receipt::v3::ReceiptType, assign, - events::{fully_read::FullyReadEventContent, AnyMessageLikeEventContent}, - EventId, MilliSecondsSinceUnixEpoch, TransactionId, + events::{ + receipt::{Receipt, ReceiptThread}, + AnyMessageLikeEventContent, + }, + EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, TransactionId, UserId, }; use thiserror::Error; use tracing::{error, instrument, warn}; -use super::Joined; +use super::{Joined, Receipts}; use crate::{ event_handler::EventHandlerHandle, room::{self, MessagesOptions}, - Result, + Client, Result, }; +mod builder; mod event_handler; mod event_item; mod inner; mod pagination; +mod read_receipts; #[cfg(test)] mod tests; #[cfg(feature = "e2e-encryption")] mod to_device; mod virtual_item; +pub(crate) use self::builder::TimelineBuilder; +use self::inner::{TimelineInner, TimelineInnerState}; pub use self::{ event_item::{ AnyOtherFullStateEventContent, BundledReactions, EncryptedMessage, EventSendState, - EventTimelineItem, InReplyToDetails, MemberProfileChange, MembershipChange, Message, - OtherState, Profile, ReactionGroup, RepliedToEvent, RoomMembershipChange, Sticker, - 
TimelineDetails, TimelineItemContent, + EventTimelineItem, InReplyToDetails, LocalEventTimelineItem, MemberProfileChange, + MembershipChange, Message, OtherState, Profile, ReactionGroup, RemoteEventTimelineItem, + RepliedToEvent, RoomMembershipChange, Sticker, TimelineDetails, TimelineItemContent, }, pagination::{PaginationOptions, PaginationOutcome}, virtual_item::VirtualTimelineItem, }; -use self::{ - inner::{TimelineInner, TimelineInnerMetadata}, - to_device::{handle_forwarded_room_key_event, handle_room_key_event}, -}; /// A high-level view into a regularΒΉ room's contents. /// @@ -73,119 +77,18 @@ pub struct Timeline { inner: Arc>, start_token: Mutex>, _end_token: Mutex>, - event_handler_handles: Vec, -} - -impl Drop for Timeline { - fn drop(&mut self) { - for handle in self.event_handler_handles.drain(..) { - self.inner.room().client().remove_event_handler(handle); - } - } + event_handler_handles: Arc, } impl Timeline { - pub(super) fn new(room: &room::Common) -> Self { - Self::from_inner(Arc::new(TimelineInner::new(room.to_owned())), None) - } - - #[cfg(feature = "experimental-sliding-sync")] - pub(crate) async fn with_events( - room: &room::Common, - prev_token: Option, - events: Vec, - ) -> Self { - let mut inner = TimelineInner::new(room.to_owned()); - inner.add_initial_events(events).await; - - let timeline = Self::from_inner(Arc::new(inner), prev_token); - - // The events we're injecting might be encrypted events, but we might - // have received the room key to decrypt them while nobody was listening to the - // `m.room_key` event, let's retry now. - // - // TODO: We could spawn a task here and put this into the background, though it - // might not be worth it depending on the number of events we injected. - // Some measuring needs to be done. 
- #[cfg(feature = "e2e-encryption")] - timeline.retry_decryption_for_all_events().await; - - timeline - } - - fn from_inner(inner: Arc, prev_token: Option) -> Timeline { - let room = inner.room(); - - let timeline_event_handle = room.add_event_handler({ - let inner = inner.clone(); - move |event, encryption_info: Option| { - let inner = inner.clone(); - async move { - inner.handle_live_event(event, encryption_info).await; - } - } - }); - - // Not using room.add_event_handler here because RoomKey events are - // to-device events that are not received in the context of a room. - #[cfg(feature = "e2e-encryption")] - let room_key_handle = room - .client - .add_event_handler(handle_room_key_event(inner.clone(), room.room_id().to_owned())); - #[cfg(feature = "e2e-encryption")] - let forwarded_room_key_handle = room.client.add_event_handler( - handle_forwarded_room_key_event(inner.clone(), room.room_id().to_owned()), - ); - - let event_handler_handles = vec![ - timeline_event_handle, - #[cfg(feature = "e2e-encryption")] - room_key_handle, - #[cfg(feature = "e2e-encryption")] - forwarded_room_key_handle, - ]; - - Timeline { - inner, - start_token: Mutex::new(prev_token), - _end_token: Mutex::new(None), - event_handler_handles, - } + pub(crate) fn builder(room: &room::Common) -> TimelineBuilder { + TimelineBuilder::new(room) } fn room(&self) -> &room::Common { self.inner.room() } - /// Enable tracking of the fully-read marker on this `Timeline`. 
- pub async fn with_fully_read_tracking(mut self) -> Self { - match self.room().account_data_static::().await { - Ok(Some(fully_read)) => match fully_read.deserialize() { - Ok(fully_read) => { - self.inner.set_fully_read_event(fully_read.content.event_id).await; - } - Err(e) => { - error!("Failed to deserialize fully-read account data: {e}"); - } - }, - Err(e) => { - error!("Failed to get fully-read account data from the store: {e}"); - } - _ => {} - } - - let inner = self.inner.clone(); - let fully_read_handle = self.room().add_event_handler(move |event| { - let inner = inner.clone(); - async move { - inner.handle_fully_read(event).await; - } - }); - self.event_handler_handles.push(fully_read_handle); - - self - } - /// Clear all timeline items, and reset pagination parameters. #[cfg(feature = "experimental-sliding-sync")] pub async fn clear(&self) { @@ -199,25 +102,17 @@ impl Timeline { } /// Add more events to the start of the timeline. - /// - /// # Arguments - /// - /// * `initial_pagination_size`: The number of events to fetch from the - /// server in the first pagination request. The server may choose return - /// fewer events, for example because the supplied number is too big or - /// the beginning of the visible timeline was reached. 
- /// * ` #[instrument(skip_all, fields(initial_pagination_size, room_id = ?self.room().room_id()))] pub async fn paginate_backwards(&self, mut opts: PaginationOptions<'_>) -> Result<()> { let mut start_lock = self.start_token.lock().await; if start_lock.is_none() - && self.inner.items().first().map_or(false, |item| item.is_timeline_start()) + && self.inner.items().await.front().map_or(false, |item| item.is_timeline_start()) { warn!("Start of timeline reached, ignoring backwards-pagination request"); return Ok(()); } - self.inner.add_loading_indicator(); + self.inner.add_loading_indicator().await; let mut from = start_lock.clone(); let mut outcome = PaginationOutcome::new(); @@ -265,7 +160,7 @@ impl Timeline { } } - self.inner.remove_loading_indicator(from.is_some()); + self.inner.remove_loading_indicator(from.is_some()).await; *start_lock = from; Ok(()) @@ -311,7 +206,7 @@ impl Timeline { .await; } - #[cfg(all(feature = "experimental-sliding-sync", feature = "e2e-encryption"))] + #[cfg(feature = "e2e-encryption")] async fn retry_decryption_for_all_events(&self) { self.inner .retry_event_decryption( @@ -322,27 +217,27 @@ impl Timeline { .await; } - /// Get the latest of the timeline's event items. - pub fn latest_event(&self) -> Option { - self.inner.items().last()?.as_event().cloned() + /// Get the current list of timeline items. Do not use this in production! + #[cfg(feature = "testing")] + pub async fn items(&self) -> Vector> { + self.inner.items().await } - /// Get a signal of the timeline's items. - /// - /// You can poll this signal to receive updates, the first of which will - /// be the full list of items currently available. - /// - /// See [`SignalVecExt`](futures_signals::signal_vec::SignalVecExt) for a - /// high-level API on top of [`SignalVec`]. - pub fn signal(&self) -> impl SignalVec> { - self.inner.items_signal() + /// Get the latest of the timeline's event items. 
+ pub async fn latest_event(&self) -> Option { + self.inner.items().await.last()?.as_event().cloned() } - /// Get a stream of timeline changes. + /// Get the current timeline items, and a stream of changes. /// - /// This is a convenience shorthand for `timeline.signal().to_stream()`. - pub fn stream(&self) -> impl Stream>> { - self.signal().to_stream() + /// You can poll this stream to receive updates. See + /// [`futures_util::StreamExt`] for a high-level API on top of [`Stream`]. + pub async fn subscribe( + &self, + ) -> (Vector>, impl Stream>>) { + let (items, stream) = self.inner.subscribe().await; + let stream = TimelineStream::new(stream, self.event_handler_handles.clone()); + (items, stream) } /// Send a message to the room, and add it to the timeline as a local echo. @@ -374,7 +269,7 @@ impl Timeline { /// /// [`MessageLikeUnsigned`]: ruma::events::MessageLikeUnsigned /// [`SyncMessageLikeEvent`]: ruma::events::SyncMessageLikeEvent - #[instrument(skip(self, content), fields(room_id = ?self.room().room_id()))] + #[instrument(skip(self, content), parent = &self.inner.room().client.inner.root_span, fields(room_id = ?self.room().room_id()))] pub async fn send(&self, content: AnyMessageLikeEventContent, txn_id: Option<&TransactionId>) { let txn_id = txn_id.map_or_else(TransactionId::new, ToOwned::to_owned); self.inner.handle_local_event(txn_id.clone(), content.clone()).await; @@ -389,7 +284,7 @@ impl Timeline { Ok(response) => EventSendState::Sent { event_id: response.event_id }, Err(error) => EventSendState::SendingFailed { error: Arc::new(error) }, }; - self.inner.update_event_send_state(&txn_id, send_state); + self.inner.update_event_send_state(&txn_id, send_state).await; } /// Fetch unavailable details about the event with the given ID. @@ -413,12 +308,7 @@ impl Timeline { /// before all requests are handled. 
#[instrument(skip(self), fields(room_id = ?self.room().room_id()))] pub async fn fetch_event_details(&self, event_id: &EventId) -> Result<()> { - let (index, item) = rfind_event_by_id(&self.inner.items(), event_id) - .and_then(|(pos, item)| item.as_remote().map(|item| (pos, item.clone()))) - .ok_or(Error::RemoteEventNotInTimeline)?; - - self.inner.fetch_in_reply_to_details(index, item).await?; - + self.inner.fetch_in_reply_to_details(event_id).await?; Ok(()) } @@ -431,15 +321,146 @@ impl Timeline { /// the `sender_profile` set to [`TimelineDetails::Error`]. #[instrument(skip_all)] pub async fn fetch_members(&self) { - self.inner.set_sender_profiles_pending(); + self.inner.set_sender_profiles_pending().await; match self.room().ensure_members().await { Ok(_) => { self.inner.update_sender_profiles().await; } Err(e) => { - self.inner.set_sender_profiles_error(Arc::new(e)); + self.inner.set_sender_profiles_error(Arc::new(e)).await; + } + } + } + + /// Get the latest read receipt for the given user. + /// + /// Contrary to [`Common::user_receipt()`](super::Common::user_receipt) that + /// only keeps track of read receipts received from the homeserver, this + /// keeps also track of implicit read receipts in this timeline, i.e. + /// when a room member sends an event. + #[instrument(skip(self), parent = &self.room().client.inner.root_span)] + pub async fn latest_user_read_receipt( + &self, + user_id: &UserId, + ) -> Option<(OwnedEventId, Receipt)> { + self.inner.latest_user_read_receipt(user_id).await + } + + /// Send the given receipt. + /// + /// This uses [`Joined::send_single_receipt`] internally, but checks + /// first if the receipt points to an event in this timeline that is more + /// recent than the current ones, to avoid unnecessary requests. 
+ #[instrument(skip(self), parent = &self.room().client.inner.root_span)] + pub async fn send_single_receipt( + &self, + receipt_type: ReceiptType, + thread: ReceiptThread, + event_id: OwnedEventId, + ) -> Result<()> { + if !self.inner.should_send_receipt(&receipt_type, &thread, &event_id).await { + return Ok(()); + } + + // If this room isn't actually in joined state, we'll get a server error. + // Not ideal, but works for now. + let room = Joined { inner: self.room().clone() }; + + room.send_single_receipt(receipt_type, thread, event_id).await + } + + /// Send the given receipts. + /// + /// This uses [`Joined::send_multiple_receipts`] internally, but checks + /// first if the receipts point to events in this timeline that are more + /// recent than the current ones, to avoid unnecessary requests. + #[instrument(skip(self), parent = &self.room().client.inner.root_span)] + pub async fn send_multiple_receipts(&self, mut receipts: Receipts) -> Result<()> { + if let Some(fully_read) = &receipts.fully_read { + if !self + .inner + .should_send_receipt( + &ReceiptType::FullyRead, + &ReceiptThread::Unthreaded, + fully_read, + ) + .await + { + receipts.fully_read = None; + } + } + + if let Some(read_receipt) = &receipts.read_receipt { + if !self + .inner + .should_send_receipt(&ReceiptType::Read, &ReceiptThread::Unthreaded, read_receipt) + .await + { + receipts.read_receipt = None; + } + } + + if let Some(private_read_receipt) = &receipts.private_read_receipt { + if !self + .inner + .should_send_receipt( + &ReceiptType::ReadPrivate, + &ReceiptThread::Unthreaded, + private_read_receipt, + ) + .await + { + receipts.private_read_receipt = None; } } + + // If this room isn't actually in joined state, we'll get a server error. + // Not ideal, but works for now. 
+ let room = Joined { inner: self.room().clone() }; + + room.send_multiple_receipts(receipts).await + } +} + +#[derive(Debug)] +struct TimelineEventHandlerHandles { + client: Client, + handles: Vec, +} + +impl Drop for TimelineEventHandlerHandles { + fn drop(&mut self) { + for handle in self.handles.drain(..) { + self.client.remove_event_handler(handle); + } + } +} + +pin_project! { + struct TimelineStream { + #[pin] + inner: VectorSubscriber>, + event_handler_handles: Arc, + } +} + +impl TimelineStream { + fn new( + inner: VectorSubscriber>, + event_handler_handles: Arc, + ) -> Self { + Self { inner, event_handler_handles } + } +} + +impl Stream for TimelineStream { + type Item = VectorDiff>; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().inner.poll_next(cx) } } @@ -489,6 +510,14 @@ impl TimelineItem { Self::Virtual(VirtualTimelineItem::TimelineStart) } + fn is_virtual(&self) -> bool { + matches!(self, Self::Virtual(_)) + } + + fn is_day_divider(&self) -> bool { + matches!(self, Self::Virtual(VirtualTimelineItem::DayDivider(_))) + } + fn is_read_marker(&self) -> bool { matches!(self, Self::Virtual(VirtualTimelineItem::ReadMarker)) } @@ -502,10 +531,22 @@ impl TimelineItem { } } +impl From for TimelineItem { + fn from(item: EventTimelineItem) -> Self { + Self::Event(item) + } +} + +impl From for TimelineItem { + fn from(item: VirtualTimelineItem) -> Self { + Self::Virtual(item) + } +} + // FIXME: Put an upper bound on timeline size or add a separate map to look up // the index of a timeline item by its key, to avoid large linear scans. 
fn rfind_event_item( - items: &[Arc], + items: &Vector>, mut f: impl FnMut(&EventTimelineItem) -> bool, ) -> Option<(usize, &EventTimelineItem)> { items @@ -516,13 +557,13 @@ fn rfind_event_item( } fn rfind_event_by_id<'a>( - items: &'a [Arc], + items: &'a Vector>, event_id: &EventId, ) -> Option<(usize, &'a EventTimelineItem)> { rfind_event_item(items, |it| it.event_id() == Some(event_id)) } -fn find_read_marker(items: &[Arc]) -> Option { +fn find_read_marker(items: &Vector>) -> Option { items.iter().rposition(|item| item.is_read_marker()) } @@ -538,3 +579,33 @@ pub enum Error { #[error("Unsupported event")] UnsupportedEvent, } + +/// Result of comparing events position in the timeline. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum RelativePosition { + /// Event B is after (more recent than) event A. + After, + /// They are the same event. + Same, + /// Event B is before (older than) event A. + Before, +} + +fn compare_events_positions( + event_a: &EventId, + event_b: &EventId, + timeline_items: &Vector>, +) -> Option { + if event_a == event_b { + return Some(RelativePosition::Same); + } + + let (pos_event_a, _) = rfind_event_by_id(timeline_items, event_a)?; + let (pos_event_b, _) = rfind_event_by_id(timeline_items, event_b)?; + + if pos_event_a > pos_event_b { + Some(RelativePosition::Before) + } else { + Some(RelativePosition::After) + } +} diff --git a/crates/matrix-sdk/src/room/timeline/pagination.rs b/crates/matrix-sdk/src/room/timeline/pagination.rs index 285d8941b56..98a0188f5cd 100644 --- a/crates/matrix-sdk/src/room/timeline/pagination.rs +++ b/crates/matrix-sdk/src/room/timeline/pagination.rs @@ -1,3 +1,17 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::fmt; /// Options for pagination. diff --git a/crates/matrix-sdk/src/room/timeline/read_receipts.rs b/crates/matrix-sdk/src/room/timeline/read_receipts.rs new file mode 100644 index 00000000000..72ef2c0c833 --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/read_receipts.rs @@ -0,0 +1,304 @@ +// Copyright 2023 KΓ©vin Commaille +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::{collections::HashMap, sync::Arc}; + +use eyeball_im::ObservableVector; +use indexmap::IndexMap; +use ruma::{ + events::receipt::{Receipt, ReceiptEventContent, ReceiptThread, ReceiptType}, + EventId, OwnedEventId, OwnedUserId, UserId, +}; +use tracing::error; + +use super::{ + compare_events_positions, + inner::{RoomDataProvider, TimelineInnerState}, + rfind_event_by_id, EventTimelineItem, RelativePosition, TimelineItem, +}; +use crate::room; + +struct FullReceipt<'a> { + event_id: &'a EventId, + user_id: &'a UserId, + receipt_type: ReceiptType, + receipt: &'a Receipt, +} + +pub(super) fn handle_explicit_read_receipts( + receipt_event_content: ReceiptEventContent, + own_user_id: &UserId, + timeline_state: &mut TimelineInnerState, +) { + for (event_id, receipt_types) in receipt_event_content.0 { + for (receipt_type, receipts) in receipt_types { + // We only care about read receipts here. + if !matches!(receipt_type, ReceiptType::Read | ReceiptType::ReadPrivate) { + continue; + } + + for (user_id, receipt) in receipts { + if receipt.thread != ReceiptThread::Unthreaded { + continue; + } + + let receipt_item_pos = + rfind_event_by_id(&timeline_state.items, &event_id).map(|(pos, _)| pos); + let is_own_user_id = user_id == own_user_id; + let full_receipt = FullReceipt { + event_id: &event_id, + user_id: &user_id, + receipt_type: receipt_type.clone(), + receipt: &receipt, + }; + + let read_receipt_updated = maybe_update_read_receipt( + full_receipt, + receipt_item_pos, + is_own_user_id, + &mut timeline_state.items, + &mut timeline_state.users_read_receipts, + ); + + if read_receipt_updated && !is_own_user_id { + // Update the new item pointed to by the user's read receipt. 
+ let new_receipt_event_item = receipt_item_pos.and_then(|pos| { + let e = timeline_state.items[pos].as_event()?.as_remote()?; + Some((pos, e.clone())) + }); + + if let Some((pos, mut remote_event_item)) = new_receipt_event_item { + remote_event_item.add_read_receipt(user_id, receipt); + timeline_state + .items + .set(pos, Arc::new(TimelineItem::Event(remote_event_item.into()))); + } + } + } + } + } +} + +/// Add an implicit read receipt to the given event item, if it is more recent +/// than the current read receipt for the sender of the event. +/// +/// According to the spec, read receipts should not point to events sent by our +/// own user, but these events are used to reset the notification count, so we +/// need to handle them locally too. For that we create an "implicit" read +/// receipt, compared to the "explicit" ones sent by the client. +pub(super) fn maybe_add_implicit_read_receipt( + item_pos: usize, + event_item: &mut EventTimelineItem, + is_own_event: bool, + timeline_items: &mut ObservableVector>, + users_read_receipts: &mut HashMap>, +) { + let EventTimelineItem::Remote(remote_event_item) = event_item else { + return; + }; + + let receipt = Receipt::new(remote_event_item.timestamp()); + let new_receipt = FullReceipt { + event_id: remote_event_item.event_id(), + user_id: remote_event_item.sender(), + receipt_type: ReceiptType::Read, + receipt: &receipt, + }; + + let read_receipt_updated = maybe_update_read_receipt( + new_receipt, + Some(item_pos), + is_own_event, + timeline_items, + users_read_receipts, + ); + if read_receipt_updated && !is_own_event { + remote_event_item.add_read_receipt(remote_event_item.sender().to_owned(), receipt); + } +} + +/// Update the timeline items with the given read receipt if it is more recent +/// than the current one. +/// +/// In the process, this method removes the corresponding receipt from its old +/// item, if applicable, and updates the `users_read_receipts` map to use the +/// new receipt. 
+/// +/// Returns true if the read receipt was saved. +/// +/// Currently this method only works reliably if the timeline was started from +/// the end of the timeline. +fn maybe_update_read_receipt( + receipt: FullReceipt<'_>, + new_item_pos: Option, + is_own_user_id: bool, + timeline_items: &mut ObservableVector>, + users_read_receipts: &mut HashMap>, +) -> bool { + let old_event_id = users_read_receipts + .get(receipt.user_id) + .and_then(|receipts| receipts.get(&receipt.receipt_type)) + .map(|(event_id, _)| event_id); + if old_event_id.map_or(false, |id| id == receipt.event_id) { + // Nothing to do. + return false; + } + + let old_item = old_event_id.and_then(|e| { + let (pos, item) = rfind_event_by_id(timeline_items, e)?; + Some((pos, item.as_remote()?)) + }); + + if let Some((old_receipt_pos, old_event_item)) = old_item { + let Some(new_receipt_pos) = new_item_pos else { + // The old receipt is likely more recent since we can't find the event of the + // new receipt in the timeline. Even if it isn't, we wouldn't know where to put + // it. + return false; + }; + + if old_receipt_pos > new_receipt_pos { + // The old receipt is more recent than the new one. + return false; + } + + if !is_own_user_id { + // Remove the read receipt for this user from the old event. + let mut old_event_item = old_event_item.clone(); + if !old_event_item.remove_read_receipt(receipt.user_id) { + error!( + "inconsistent state: old event item for user's read \ + receipt doesn't have a receipt for the user" + ); + } + timeline_items + .set(old_receipt_pos, Arc::new(TimelineItem::Event(old_event_item.into()))); + } + } + + // The new receipt is deemed more recent from now on because: + // - If old_receipt_item is Some, we already checked all the cases where it + // wouldn't be more recent. + // - If both old_receipt_item and new_receipt_item are None, they are both + // explicit read receipts so the server should only send us a more recent + // receipt. 
+ // - If old_receipt_item is None and new_receipt_item is Some, the new receipt + // is likely more recent because it has a place in the timeline. + users_read_receipts + .entry(receipt.user_id.to_owned()) + .or_default() + .insert(receipt.receipt_type, (receipt.event_id.to_owned(), receipt.receipt.clone())); + + true +} + +/// Load the read receipts from the store for the given event ID. +pub(super) async fn load_read_receipts_for_event( + event_id: &EventId, + timeline_state: &mut TimelineInnerState, + room_data_provider: &P, +) -> IndexMap { + let read_receipts = room_data_provider.read_receipts_for_event(event_id).await; + + // Filter out receipts for our own user. + let own_user_id = room_data_provider.own_user_id(); + let read_receipts: IndexMap = + read_receipts.into_iter().filter(|(user_id, _)| user_id != own_user_id).collect(); + + // Keep track of the user's read receipt. + for (user_id, receipt) in read_receipts.clone() { + // Only insert the read receipt if the user is not known to avoid conflicts with + // `TimelineInner::handle_read_receipts`. + if !timeline_state.users_read_receipts.contains_key(&user_id) { + timeline_state + .users_read_receipts + .entry(user_id) + .or_default() + .insert(ReceiptType::Read, (event_id.to_owned(), receipt)); + } + } + + read_receipts +} + +/// Get the unthreaded receipt of the given type for the given user in the +/// timeline. 
+pub(super) async fn user_receipt( + user_id: &UserId, + receipt_type: ReceiptType, + timeline_state: &TimelineInnerState, + room: &room::Common, +) -> Option<(OwnedEventId, Receipt)> { + if let Some(receipt) = timeline_state + .users_read_receipts + .get(user_id) + .and_then(|user_map| user_map.get(&receipt_type)) + .cloned() + { + return Some(receipt); + } + + room.user_receipt(receipt_type.clone(), ReceiptThread::Unthreaded, user_id) + .await + .unwrap_or_else(|e| { + error!("Could not get user read receipt of type {receipt_type:?}: {e}"); + None + }) +} + +/// Get the latest read receipt for the given user. +/// +/// Useful to get the latest read receipt, whether it's private or public. +pub(super) async fn latest_user_read_receipt( + user_id: &UserId, + timeline_state: &TimelineInnerState, + room: &room::Common, +) -> Option<(OwnedEventId, Receipt)> { + let public_read_receipt = user_receipt(user_id, ReceiptType::Read, timeline_state, room).await; + let private_read_receipt = + user_receipt(user_id, ReceiptType::ReadPrivate, timeline_state, room).await; + + // If we only have one, return it. + let Some((pub_event_id, pub_receipt)) = &public_read_receipt else { + return private_read_receipt; + }; + let Some((priv_event_id, priv_receipt)) = &private_read_receipt else { + return public_read_receipt; + }; + + // Compare by position in the timeline. + if let Some(relative_pos) = + compare_events_positions(pub_event_id, priv_event_id, &timeline_state.items) + { + if relative_pos == RelativePosition::After { + return private_read_receipt; + } + + return public_read_receipt; + } + + // Compare by timestamp. + if let Some((pub_ts, priv_ts)) = pub_receipt.ts.zip(priv_receipt.ts) { + if priv_ts > pub_ts { + return private_read_receipt; + } + + return public_read_receipt; + } + + // As a fallback, let's assume that a private read receipt should be more recent + // than a public read receipt, otherwise there's no point in the private read + // receipt. 
+ private_read_receipt +} diff --git a/crates/matrix-sdk/src/room/timeline/tests.rs b/crates/matrix-sdk/src/room/timeline/tests.rs deleted file mode 100644 index 430b5fc7d26..00000000000 --- a/crates/matrix-sdk/src/room/timeline/tests.rs +++ /dev/null @@ -1,939 +0,0 @@ -// Copyright 2022 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Unit tests (based on private methods) for the timeline API. - -use std::{ - io, - sync::{ - atomic::{AtomicU32, Ordering::SeqCst}, - Arc, - }, -}; - -use assert_matches::assert_matches; -use async_trait::async_trait; -use chrono::{Datelike, Local, TimeZone}; -use futures_core::Stream; -use futures_signals::signal_vec::{SignalVecExt, VecDiff}; -use futures_util::StreamExt; -use matrix_sdk_base::{crypto::OlmMachine, deserialized_responses::SyncTimelineEvent}; -use matrix_sdk_test::async_test; -use once_cell::sync::Lazy; -use ruma::{ - assign, event_id, - events::{ - reaction::ReactionEventContent, - relation::{Annotation, Replacement}, - room::{ - encrypted::{ - EncryptedEventScheme, MegolmV1AesSha2ContentInit, RoomEncryptedEventContent, - }, - member::{MembershipState, RedactedRoomMemberEventContent, RoomMemberEventContent}, - message::{self, MessageType, RoomMessageEventContent}, - name::RoomNameEventContent, - topic::RedactedRoomTopicEventContent, - }, - AnyMessageLikeEventContent, EmptyStateKey, FullStateEventContent, MessageLikeEventContent, - MessageLikeEventType, 
RedactedStateEventContent, StateEventContent, StateEventType, - StaticStateEventContent, - }, - room_id, - serde::Raw, - server_name, uint, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedTransactionId, - TransactionId, UserId, -}; -use serde_json::{json, Value as JsonValue}; - -use super::{ - event_item::AnyOtherFullStateEventContent, inner::ProfileProvider, EncryptedMessage, - EventTimelineItem, MembershipChange, Profile, TimelineInner, TimelineItem, TimelineItemContent, - VirtualTimelineItem, -}; -use crate::{room::timeline::event_item::EventSendState, Error}; - -static ALICE: Lazy<&UserId> = Lazy::new(|| user_id!("@alice:server.name")); -static BOB: Lazy<&UserId> = Lazy::new(|| user_id!("@bob:other.server")); - -#[async_test] -async fn reaction_redaction() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - timeline.handle_live_message_event(&ALICE, RoomMessageEventContent::text_plain("hi!")).await; - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let event = item.as_event().unwrap().as_remote().unwrap(); - assert_eq!(event.reactions().len(), 0); - - let msg_event_id = &event.event_id; - - let rel = Annotation::new(msg_event_id.to_owned(), "+1".to_owned()); - timeline.handle_live_message_event(&BOB, ReactionEventContent::new(rel)).await; - let item = - assert_matches!(stream.next().await, Some(VecDiff::UpdateAt { index: 1, value }) => value); - let event = item.as_event().unwrap().as_remote().unwrap(); - assert_eq!(event.reactions().len(), 1); - - // TODO: After adding raw timeline items, check for one here - - let reaction_event_id = event.event_id.as_ref(); - - timeline.handle_live_redaction(&BOB, reaction_event_id).await; - let item = - assert_matches!(stream.next().await, Some(VecDiff::UpdateAt { index: 1, value }) => value); - let event = 
item.as_event().unwrap().as_remote().unwrap(); - assert_eq!(event.reactions().len(), 0); -} - -#[async_test] -async fn invalid_edit() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - timeline.handle_live_message_event(&ALICE, RoomMessageEventContent::text_plain("test")).await; - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let event = item.as_event().unwrap().as_remote().unwrap(); - let msg = event.content.as_message().unwrap(); - assert_eq!(msg.body(), "test"); - - let msg_event_id = &event.event_id; - - let edit = assign!(RoomMessageEventContent::text_plain(" * fake"), { - relates_to: Some(message::Relation::Replacement(Replacement::new( - msg_event_id.to_owned(), - MessageType::text_plain("fake"), - ))), - }); - // Edit is from a different user than the previous event - timeline.handle_live_message_event(&BOB, edit).await; - - // Can't easily test the non-arrival of an item using the stream. Instead - // just assert that there is still just a couple items in the timeline. 
- assert_eq!(timeline.inner.items().len(), 2); -} - -#[async_test] -async fn edit_redacted() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - // Ruma currently fails to serialize most redacted events correctly - timeline - .handle_live_custom_event(json!({ - "content": {}, - "event_id": "$eeG0HA0FAZ37wP8kXlNkxx3I", - "origin_server_ts": 10, - "sender": "@alice:example.org", - "type": "m.room.message", - "unsigned": { - "redacted_because": { - "content": {}, - "redacts": "$eeG0HA0FAZ37wP8kXlNkxx3K", - "event_id": "$N6eUCBc3vu58PL8TobGaVQzM", - "sender": "@alice:example.org", - "origin_server_ts": 5, - "type": "m.room.redaction", - }, - }, - })) - .await; - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - - let redacted_event_id = item.as_event().unwrap().event_id().unwrap(); - - let edit = assign!(RoomMessageEventContent::text_plain(" * test"), { - relates_to: Some(message::Relation::Replacement(Replacement::new( - redacted_event_id.to_owned(), - MessageType::text_plain("test"), - ))), - }); - timeline.handle_live_message_event(&ALICE, edit).await; - - assert_eq!(timeline.inner.items().len(), 2); -} - -#[cfg(not(target_arch = "wasm32"))] -#[async_test] -async fn unable_to_decrypt() { - use std::{io::Cursor, iter}; - - use matrix_sdk_base::crypto::decrypt_room_key_export; - - const SESSION_ID: &str = "gM8i47Xhu0q52xLfgUXzanCMpLinoyVyH7R58cBuVBU"; - const SESSION_KEY: &[u8] = b"\ - -----BEGIN MEGOLM SESSION DATA-----\n\ - ASKcWoiAVUM97482UAi83Avce62hSLce7i5JhsqoF6xeAAAACqt2Cg3nyJPRWTTMXxXH7TXnkfdlmBXbQtq5\ - bpHo3LRijcq2Gc6TXilESCmJN14pIsfKRJrWjZ0squ/XsoTFytuVLWwkNaW3QF6obeg2IoVtJXLMPdw3b2vO\ - vgwGY3OMP0XafH13j1vcb6YLzvgLkZQLnYvd47hv3yK/9GmKS9tokuaQ7dCVYckYcIOS09EDTs70YdxUd5WG\ - rQynATCLFP1p/NAGv70r9MK7Cy/mNpjD0r4qC7UEDIoi1kOWzHgnLo19wtvwsb8Fg8ATxcs3Wmtj8hIUYpDx\ - 
ia4sM10zbytUuaPUAfCDf42IyxdmOnGe1CueXhgI71y+RW0s0argNqUt7jB70JT0o9CyX6UBGRaqLk2MPY9T\ - hUu5J8X3UgIa6rcbWigzohzWm9rdbEHFrSWqjpfQYMaAKQQgETrjSy4XTrp2RhC2oNqG/hylI4ab+F4X6fpH\ - DYP1NqNMP5g36xNu7LhDnrUB5qsPjYOmWORxGLfudpF3oLYCSlr3DgHqEIB6HjQblLZ3KQuPBse3zxyROTnS\ - AhdPH4a/z1wioFtKNVph3hecsiKEdqnz4Y2coSIdhz58mJ9JWNQoFAENE5CSsoEZAGvafYZVpW4C75YY2zq1\ - wIeiFi1dT43/jLAUGkslsi1VvnyfUu8qO404RxYO3XHoGLMFoFLOO+lZ+VGci2Vz10AhxJhEBHxRKxw4k2uB\ - HztoSJUr/2Y\n\ - -----END MEGOLM SESSION DATA-----"; - - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - timeline - .handle_live_message_event( - &BOB, - RoomEncryptedEventContent::new( - EncryptedEventScheme::MegolmV1AesSha2( - MegolmV1AesSha2ContentInit { - ciphertext: "\ - AwgAEtABPRMavuZMDJrPo6pGQP4qVmpcuapuXtzKXJyi3YpEsjSWdzuRKIgJzD4P\ - cSqJM1A8kzxecTQNJsC5q22+KSFEPxPnI4ltpm7GFowSoPSW9+bFdnlfUzEP1jPq\ - YevHAsMJp2fRKkzQQbPordrUk1gNqEpGl4BYFeRqKl9GPdKFwy45huvQCLNNueql\ - CFZVoYMuhxrfyMiJJAVNTofkr2um2mKjDTlajHtr39pTG8k0eOjSXkLOSdZvNOMz\ - hGhSaFNeERSA2G2YbeknOvU7MvjiO0AKuxaAe1CaVhAI14FCgzrJ8g0y5nly+n7x\ - QzL2G2Dn8EoXM5Iqj8W99iokQoVsSrUEnaQ1WnSIfewvDDt4LCaD/w7PGETMCQ" - .to_owned(), - sender_key: "DeHIg4gwhClxzFYcmNntPNF9YtsdZbmMy8+3kzCMXHA".to_owned(), - device_id: "NLAZCWIOCO".into(), - session_id: SESSION_ID.into(), - } - .into(), - ), - None, - ), - ) - .await; - - assert_eq!(timeline.inner.items().len(), 2); - - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let event = item.as_event().unwrap(); - let session_id = assert_matches!( - event.content(), - TimelineItemContent::UnableToDecrypt( - EncryptedMessage::MegolmV1AesSha2 { session_id, .. 
}, - ) => session_id - ); - assert_eq!(session_id, SESSION_ID); - - let own_user_id = user_id!("@example:morheus.localhost"); - let exported_keys = decrypt_room_key_export(Cursor::new(SESSION_KEY), "1234").unwrap(); - - let olm_machine = OlmMachine::new(own_user_id, "SomeDeviceId".into()).await; - olm_machine.import_room_keys(exported_keys, false, |_, _| {}).await.unwrap(); - - timeline - .inner - .retry_event_decryption( - room_id!("!DovneieKSTkdHKpIXy:morpheus.localhost"), - &olm_machine, - Some(iter::once(SESSION_ID).collect()), - ) - .await; - - assert_eq!(timeline.inner.items().len(), 2); - - let item = - assert_matches!(stream.next().await, Some(VecDiff::UpdateAt { index: 1, value }) => value); - let event = item.as_event().unwrap().as_remote().unwrap(); - assert_matches!(&event.encryption_info, Some(_)); - let text = assert_matches!(&event.content, TimelineItemContent::Message(msg) => msg.body()); - assert_eq!(text, "It's a secret to everybody"); -} - -#[async_test] -async fn update_read_marker() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - timeline.handle_live_message_event(&ALICE, RoomMessageEventContent::text_plain("A")).await; - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let event_id = item.as_event().unwrap().event_id().unwrap().to_owned(); - - timeline.inner.set_fully_read_event(event_id).await; - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - assert_matches!(item.as_virtual(), Some(VirtualTimelineItem::ReadMarker)); - - timeline.handle_live_message_event(&BOB, RoomMessageEventContent::text_plain("B")).await; - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let event_id = item.as_event().unwrap().event_id().unwrap().to_owned(); - - 
timeline.inner.set_fully_read_event(event_id.clone()).await; - assert_matches!(stream.next().await, Some(VecDiff::Move { old_index: 2, new_index: 3 })); - - // Nothing should happen if the fully read event is set back to the same event - // as before. - timeline.inner.set_fully_read_event(event_id.clone()).await; - - // Nothing should happen if the fully read event isn't found. - timeline.inner.set_fully_read_event(event_id!("$fake_event_id").to_owned()).await; - - // Nothing should happen if the fully read event is referring to an old event - // that has already been marked as fully read. - timeline.inner.set_fully_read_event(event_id).await; - - timeline.handle_live_message_event(&ALICE, RoomMessageEventContent::text_plain("C")).await; - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let event_id = item.as_event().unwrap().event_id().unwrap().to_owned(); - - timeline.inner.set_fully_read_event(event_id).await; - assert_matches!(stream.next().await, Some(VecDiff::Move { old_index: 3, new_index: 4 })); -} - -#[async_test] -async fn invalid_event_content() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - // m.room.message events must have a msgtype and body in content, so this - // event with an empty content object should fail to deserialize. 
- timeline - .handle_live_custom_event(json!({ - "content": {}, - "event_id": "$eeG0HA0FAZ37wP8kXlNkxx3I", - "origin_server_ts": 10, - "sender": "@alice:example.org", - "type": "m.room.message", - })) - .await; - - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let event_item = item.as_event().unwrap().as_remote().unwrap(); - assert_eq!(event_item.sender, "@alice:example.org"); - assert_eq!(event_item.event_id, event_id!("$eeG0HA0FAZ37wP8kXlNkxx3I").to_owned()); - assert_eq!(event_item.timestamp, MilliSecondsSinceUnixEpoch(uint!(10))); - let event_type = assert_matches!( - &event_item.content, - TimelineItemContent::FailedToParseMessageLike { event_type, .. } => event_type - ); - assert_eq!(*event_type, MessageLikeEventType::RoomMessage); - - // Similar to above, the m.room.member state event must also not have an - // empty content object. - timeline - .handle_live_custom_event(json!({ - "content": {}, - "event_id": "$d5G0HA0FAZ37wP8kXlNkxx3I", - "origin_server_ts": 2179, - "sender": "@alice:example.org", - "type": "m.room.member", - "state_key": "@alice:example.org", - })) - .await; - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let event_item = item.as_event().unwrap().as_remote().unwrap(); - assert_eq!(event_item.sender, "@alice:example.org"); - assert_eq!(event_item.event_id, event_id!("$d5G0HA0FAZ37wP8kXlNkxx3I").to_owned()); - assert_eq!(event_item.timestamp, MilliSecondsSinceUnixEpoch(uint!(2179))); - let (event_type, state_key) = assert_matches!( - &event_item.content, - TimelineItemContent::FailedToParseState { - event_type, - state_key, - .. 
- } => (event_type, state_key) - ); - assert_eq!(*event_type, StateEventType::RoomMember); - assert_eq!(state_key, "@alice:example.org"); -} - -#[async_test] -async fn invalid_event() { - let timeline = TestTimeline::new(); - - // This event is missing the sender field which the homeserver must add to - // all timeline events. Because the event is malformed, it will be ignored. - timeline - .handle_live_custom_event(json!({ - "content": { - "body": "hello world", - "msgtype": "m.text" - }, - "event_id": "$eeG0HA0FAZ37wP8kXlNkxx3I", - "origin_server_ts": 10, - "type": "m.room.message", - })) - .await; - assert_eq!(timeline.inner.items().len(), 0); -} - -#[async_test] -async fn remote_echo_full_trip() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - // Given a local event… - let txn_id = timeline - .handle_local_event(AnyMessageLikeEventContent::RoomMessage( - RoomMessageEventContent::text_plain("echo"), - )) - .await; - - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - - // Scenario 1: The local event has not been sent yet to the server. - { - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let event = item.as_event().unwrap().as_local().unwrap(); - assert_matches!(event.send_state, EventSendState::NotSentYet); - } - - // Scenario 2: The local event has not been sent to the server successfully, it - // has failed. In this case, there is no event ID. - { - let some_io_error = Error::Io(io::Error::new(io::ErrorKind::Other, "this is a test")); - timeline.inner.update_event_send_state( - &txn_id, - EventSendState::SendingFailed { error: Arc::new(some_io_error) }, - ); - - let item = assert_matches!( - stream.next().await, - Some(VecDiff::UpdateAt { value, index: 1 }) => value - ); - let event = item.as_event().unwrap().as_local().unwrap(); - assert_matches!(event.send_state, EventSendState::SendingFailed { .. 
}); - } - - // Scenario 3: The local event has been sent successfully to the server and an - // event ID has been received as part of the server's response. - let event_id = { - let event_id = event_id!("$W6mZSLWMmfuQQ9jhZWeTxFIM"); - - timeline.inner.update_event_send_state( - &txn_id, - EventSendState::Sent { event_id: event_id.to_owned() }, - ); - - let item = assert_matches!( - stream.next().await, - Some(VecDiff::UpdateAt { value, index: 1 }) => value - ); - let event = item.as_event().unwrap().as_local().unwrap(); - assert_matches!(event.send_state, EventSendState::Sent { .. }); - - event_id - }; - - // Now, a sync has been run against the server, and an event with the same ID - // comes in. - timeline - .handle_live_custom_event(json!({ - "content": { - "body": "echo", - "msgtype": "m.text", - }, - "sender": &*ALICE, - "event_id": event_id, - "origin_server_ts": 5, - "type": "m.room.message", - })) - .await; - - // The local echo is removed - assert_matches!(stream.next().await, Some(VecDiff::Pop { .. })); - - // This day divider shouldn't be present, or the previous one should be - // removed. There being a two day dividers in a row is a bug, but it's - // non-trivial to fix and rare enough that we can fix it later (only happens - // when the first message on a given day is a local echo). - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - - // … and the remote echo is added. 
- let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - assert_matches!(item.as_event().unwrap(), EventTimelineItem::Remote(_)); -} - -#[async_test] -async fn remote_echo_new_position() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - // Given a local event… - let txn_id = timeline - .handle_local_event(AnyMessageLikeEventContent::RoomMessage( - RoomMessageEventContent::text_plain("echo"), - )) - .await; - - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let txn_id_from_event = item.as_event().unwrap().as_local().unwrap(); - assert_eq!(txn_id, *txn_id_from_event.transaction_id); - - // … and another event that comes back before the remote echo - timeline.handle_live_message_event(&BOB, RoomMessageEventContent::text_plain("test")).await; - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let _bob_message = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - - // When the remote echo comes in… - timeline - .handle_live_custom_event(json!({ - "content": { - "body": "echo", - "msgtype": "m.text", - }, - "sender": &*ALICE, - "event_id": "$eeG0HA0FAZ37wP8kXlNkxx3I", - "origin_server_ts": 6, - "type": "m.room.message", - "unsigned": { - "transaction_id": txn_id, - }, - })) - .await; - - // … the local echo should be removed - assert_matches!(stream.next().await, Some(VecDiff::RemoveAt { index: 1 })); - - // … and the remote echo added - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - assert_matches!(item.as_event().unwrap(), EventTimelineItem::Remote(_)); -} - -#[async_test] -async fn day_divider() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - timeline - .handle_live_custom_event(json!({ - "content": { - 
"msgtype": "m.text", - "body": "This is a first message on the first day" - }, - "event_id": "$eeG0HA0FAZ37wP8kXlNkxx3I", - "origin_server_ts": 1669897395000u64, - "sender": "@alice:example.org", - "type": "m.room.message", - })) - .await; - - let day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let ts = assert_matches!( - day_divider.as_virtual().unwrap(), - VirtualTimelineItem::DayDivider(ts) => *ts - ); - let date = Local.timestamp_millis_opt(ts.0.into()).single().unwrap(); - assert_eq!(date.year(), 2022); - assert_eq!(date.month(), 12); - assert_eq!(date.day(), 1); - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - item.as_event().unwrap(); - - timeline - .handle_live_custom_event(json!({ - "content": { - "msgtype": "m.text", - "body": "This is a second message on the first day" - }, - "event_id": "$feG0HA0FAZ37wP8kXlNkxx3I", - "origin_server_ts": 1669906604000u64, - "sender": "@alice:example.org", - "type": "m.room.message", - })) - .await; - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - item.as_event().unwrap(); - - timeline - .handle_live_custom_event(json!({ - "content": { - "msgtype": "m.text", - "body": "This is a first message on the next day" - }, - "event_id": "$geG0HA0FAZ37wP8kXlNkxx3I", - "origin_server_ts": 1669992963000u64, - "sender": "@alice:example.org", - "type": "m.room.message", - })) - .await; - - let day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let ts = assert_matches!( - day_divider.as_virtual().unwrap(), - VirtualTimelineItem::DayDivider(ts) => *ts - ); - let date = Local.timestamp_millis_opt(ts.0.into()).single().unwrap(); - assert_eq!(date.year(), 2022); - assert_eq!(date.month(), 12); - assert_eq!(date.day(), 2); - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - item.as_event().unwrap(); - - let _ = timeline - 
.handle_local_event(AnyMessageLikeEventContent::RoomMessage( - RoomMessageEventContent::text_plain("A message I'm sending just now"), - )) - .await; - - // The other events are in the past so a local event always creates a new day - // divider. - let day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - assert_matches!(day_divider.as_virtual().unwrap(), VirtualTimelineItem::DayDivider { .. }); - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - item.as_event().unwrap(); -} - -#[async_test] -async fn sticker() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - timeline - .handle_live_custom_event(json!({ - "content": { - "body": "Happy sticker", - "info": { - "h": 398, - "mimetype": "image/jpeg", - "size": 31037, - "w": 394 - }, - "url": "mxc://server.name/JWEIFJgwEIhweiWJE", - }, - "event_id": "$143273582443PhrSn", - "origin_server_ts": 143273582, - "sender": "@alice:server.name", - "type": "m.sticker", - })) - .await; - - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::Sticker(_)); -} - -#[async_test] -async fn initial_events() { - let timeline = TestTimeline::with_initial_events([ - (*ALICE, RoomMessageEventContent::text_plain("A").into()), - (*BOB, RoomMessageEventContent::text_plain("B").into()), - ]) - .await; - let mut stream = timeline.stream(); - - let items = assert_matches!(stream.next().await, Some(VecDiff::Replace { values }) => values); - assert_eq!(items.len(), 3); - assert_matches!(items[0].as_virtual().unwrap(), VirtualTimelineItem::DayDivider { .. 
}); - assert_eq!(items[1].as_event().unwrap().sender(), *ALICE); - assert_eq!(items[2].as_event().unwrap().sender(), *BOB); -} - -#[async_test] -async fn other_state() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - timeline - .handle_live_original_state_event( - &ALICE, - RoomNameEventContent::new(Some("Alice's room".to_owned())), - None, - ) - .await; - - let _day_divider = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let ev = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::OtherState(ev) => ev); - let full_content = - assert_matches!(ev.content(), AnyOtherFullStateEventContent::RoomName(c) => c); - let (content, prev_content) = assert_matches!(full_content, FullStateEventContent::Original { content, prev_content } => (content, prev_content)); - assert_eq!(content.name.as_ref().unwrap(), "Alice's room"); - assert_matches!(prev_content, None); - - timeline.handle_live_redacted_state_event(&ALICE, RedactedRoomTopicEventContent::new()).await; - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let ev = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::OtherState(ev) => ev); - let full_content = - assert_matches!(ev.content(), AnyOtherFullStateEventContent::RoomTopic(c) => c); - assert_matches!(full_content, FullStateEventContent::Redacted(_)); -} - -#[async_test] -async fn room_member() { - let timeline = TestTimeline::new(); - let mut stream = timeline.stream(); - - let mut first_room_member_content = RoomMemberEventContent::new(MembershipState::Invite); - first_room_member_content.displayname = Some("Alice".to_owned()); - timeline - .handle_live_original_state_event_with_state_key( - &BOB, - ALICE.to_owned(), - first_room_member_content.clone(), - None, - ) - .await; - - let _day_divider = 
assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let membership = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::MembershipChange(ev) => ev); - assert_matches!(membership.content(), FullStateEventContent::Original { .. }); - assert_matches!(membership.change(), Some(MembershipChange::Invited)); - - let mut second_room_member_content = RoomMemberEventContent::new(MembershipState::Join); - second_room_member_content.displayname = Some("Alice".to_owned()); - timeline - .handle_live_original_state_event_with_state_key( - &ALICE, - ALICE.to_owned(), - second_room_member_content.clone(), - Some(first_room_member_content), - ) - .await; - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let membership = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::MembershipChange(ev) => ev); - assert_matches!(membership.content(), FullStateEventContent::Original { .. 
}); - assert_matches!(membership.change(), Some(MembershipChange::InvitationAccepted)); - - let mut third_room_member_content = RoomMemberEventContent::new(MembershipState::Join); - third_room_member_content.displayname = Some("Alice In Wonderland".to_owned()); - timeline - .handle_live_original_state_event_with_state_key( - &ALICE, - ALICE.to_owned(), - third_room_member_content, - Some(second_room_member_content), - ) - .await; - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let profile = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::ProfileChange(ev) => ev); - assert_matches!(profile.displayname_change(), Some(_)); - assert_matches!(profile.avatar_url_change(), None); - - timeline - .handle_live_redacted_state_event_with_state_key( - &ALICE, - ALICE.to_owned(), - RedactedRoomMemberEventContent::new(MembershipState::Join), - ) - .await; - - let item = assert_matches!(stream.next().await, Some(VecDiff::Push { value }) => value); - let membership = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::MembershipChange(ev) => ev); - assert_matches!(membership.content(), FullStateEventContent::Redacted(_)); - assert_matches!(membership.change(), None); -} - -struct TestTimeline { - inner: TimelineInner, -} - -impl TestTimeline { - fn new() -> Self { - Self { inner: TimelineInner::new(TestProfileProvider) } - } - - async fn with_initial_events<'a>( - events: impl IntoIterator, - ) -> Self { - let mut inner = TimelineInner::new(TestProfileProvider); - inner - .add_initial_events( - events - .into_iter() - .map(|(sender, content)| { - let event = - serde_json::from_value(make_message_event(sender, content)).unwrap(); - SyncTimelineEvent { event, encryption_info: None } - }) - .collect(), - ) - .await; - - Self { inner } - } - - fn stream(&self) -> impl Stream>> { - self.inner.items_signal().to_stream() - } - - async fn handle_live_message_event(&self, sender: &UserId, content: 
C) - where - C: MessageLikeEventContent, - { - let ev = make_message_event(sender, content); - let raw = Raw::new(&ev).unwrap().cast(); - self.inner.handle_live_event(raw, None).await; - } - - async fn handle_live_original_state_event( - &self, - sender: &UserId, - content: C, - prev_content: Option, - ) where - C: StaticStateEventContent, - { - let ev = make_state_event(sender, "", content, prev_content); - let raw = Raw::new(&ev).unwrap().cast(); - self.inner.handle_live_event(raw, None).await; - } - - async fn handle_live_original_state_event_with_state_key( - &self, - sender: &UserId, - state_key: C::StateKey, - content: C, - prev_content: Option, - ) where - C: StaticStateEventContent, - { - let ev = make_state_event(sender, state_key.as_ref(), content, prev_content); - let raw = Raw::new(&ev).unwrap().cast(); - self.inner.handle_live_event(raw, None).await; - } - - async fn handle_live_redacted_state_event(&self, sender: &UserId, content: C) - where - C: RedactedStateEventContent, - { - let ev = make_redacted_state_event(sender, "", content); - let raw = Raw::new(&ev).unwrap().cast(); - self.inner.handle_live_event(raw, None).await; - } - - async fn handle_live_redacted_state_event_with_state_key( - &self, - sender: &UserId, - state_key: C::StateKey, - content: C, - ) where - C: RedactedStateEventContent, - { - let ev = make_redacted_state_event(sender, state_key.as_ref(), content); - let raw = Raw::new(&ev).unwrap().cast(); - self.inner.handle_live_event(raw, None).await; - } - - async fn handle_live_custom_event(&self, event: JsonValue) { - let raw = Raw::new(&event).unwrap().cast(); - self.inner.handle_live_event(raw, None).await; - } - - async fn handle_live_redaction(&self, sender: &UserId, redacts: &EventId) { - let ev = json!({ - "type": "m.room.redaction", - "content": {}, - "redacts": redacts, - "event_id": EventId::new(server_name!("dummy.server")), - "sender": sender, - "origin_server_ts": next_server_ts(), - }); - let raw = 
Raw::new(&ev).unwrap().cast(); - self.inner.handle_live_event(raw, None).await; - } - - async fn handle_local_event(&self, content: AnyMessageLikeEventContent) -> OwnedTransactionId { - let txn_id = TransactionId::new(); - self.inner.handle_local_event(txn_id.clone(), content).await; - txn_id - } -} - -struct TestProfileProvider; - -#[async_trait] -impl ProfileProvider for TestProfileProvider { - fn own_user_id(&self) -> &UserId { - &ALICE - } - - async fn profile(&self, _user_id: &UserId) -> Option { - None - } -} - -fn make_message_event(sender: &UserId, content: C) -> JsonValue { - json!({ - "type": content.event_type(), - "content": content, - "event_id": EventId::new(server_name!("dummy.server")), - "sender": sender, - "origin_server_ts": next_server_ts(), - }) -} - -fn make_state_event( - sender: &UserId, - state_key: &str, - content: C, - prev_content: Option, -) -> JsonValue { - let unsigned = if let Some(prev_content) = prev_content { - json!({ "prev_content": prev_content }) - } else { - json!({}) - }; - - json!({ - "type": content.event_type(), - "state_key": state_key, - "content": content, - "event_id": EventId::new(server_name!("dummy.server")), - "sender": sender, - "origin_server_ts": next_server_ts(), - "unsigned": unsigned, - }) -} - -fn make_redacted_state_event( - sender: &UserId, - state_key: &str, - content: C, -) -> JsonValue { - json!({ - "type": content.event_type(), - "state_key": state_key, - "content": content, - "event_id": EventId::new(server_name!("dummy.server")), - "sender": sender, - "origin_server_ts": next_server_ts(), - "unsigned": make_redacted_unsigned(sender), - }) -} - -fn make_redacted_unsigned(sender: &UserId) -> JsonValue { - json!({ - "redacted_because": { - "content": {}, - "event_id": EventId::new(server_name!("dummy.server")), - "sender": sender, - "origin_server_ts": next_server_ts(), - "type": "m.room.redaction", - }, - }) -} - -fn next_server_ts() -> MilliSecondsSinceUnixEpoch { - static NEXT_TS: AtomicU32 = 
AtomicU32::new(0); - MilliSecondsSinceUnixEpoch(NEXT_TS.fetch_add(1, SeqCst).into()) -} diff --git a/crates/matrix-sdk/src/room/timeline/tests/basic.rs b/crates/matrix-sdk/src/room/timeline/tests/basic.rs new file mode 100644 index 00000000000..e8ead03fec0 --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/tests/basic.rs @@ -0,0 +1,283 @@ +use assert_matches::assert_matches; +use eyeball_im::VectorDiff; +use futures_util::StreamExt; +use im::vector; +use matrix_sdk_base::deserialized_responses::SyncTimelineEvent; +use matrix_sdk_test::async_test; +use ruma::{ + assign, + events::{ + reaction::ReactionEventContent, + relation::{Annotation, Replacement}, + room::{ + member::{MembershipState, RedactedRoomMemberEventContent, RoomMemberEventContent}, + message::{ + self, MessageType, RedactedRoomMessageEventContent, RoomMessageEventContent, + }, + name::RoomNameEventContent, + topic::RedactedRoomTopicEventContent, + }, + FullStateEventContent, + }, +}; +use serde_json::{json, Value as JsonValue}; + +use super::{TestTimeline, ALICE, BOB}; +use crate::room::timeline::{ + event_item::AnyOtherFullStateEventContent, MembershipChange, TimelineItem, TimelineItemContent, + VirtualTimelineItem, +}; + +fn sync_timeline_event(event: JsonValue) -> SyncTimelineEvent { + let event = serde_json::from_value(event).unwrap(); + SyncTimelineEvent { event, encryption_info: None, push_actions: Vec::default() } +} + +#[async_test] +async fn initial_events() { + let mut timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + timeline + .inner + .add_initial_events(vector![ + sync_timeline_event( + timeline.make_message_event(*ALICE, RoomMessageEventContent::text_plain("A")), + ), + sync_timeline_event( + timeline.make_message_event(*BOB, RoomMessageEventContent::text_plain("B")), + ), + ]) + .await; + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + assert_matches!(&*item, 
TimelineItem::Virtual(VirtualTimelineItem::DayDivider(_))); + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + assert_eq!(item.as_event().unwrap().sender(), *ALICE); + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + assert_eq!(item.as_event().unwrap().sender(), *BOB); +} + +#[async_test] +async fn reaction_redaction() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + timeline.handle_live_message_event(&ALICE, RoomMessageEventContent::text_plain("hi!")).await; + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event = item.as_event().unwrap().as_remote().unwrap(); + assert_eq!(event.reactions().len(), 0); + + let msg_event_id = event.event_id(); + + let rel = Annotation::new(msg_event_id.to_owned(), "+1".to_owned()); + timeline.handle_live_message_event(&BOB, ReactionEventContent::new(rel)).await; + let item = + assert_matches!(stream.next().await, Some(VectorDiff::Set { index: 1, value }) => value); + let event = item.as_event().unwrap().as_remote().unwrap(); + assert_eq!(event.reactions().len(), 1); + + // TODO: After adding raw timeline items, check for one here + + let reaction_event_id = event.event_id(); + + timeline.handle_live_redaction(&BOB, reaction_event_id).await; + let item = + assert_matches!(stream.next().await, Some(VectorDiff::Set { index: 1, value }) => value); + let event = item.as_event().unwrap().as_remote().unwrap(); + assert_eq!(event.reactions().len(), 0); +} + +#[async_test] +async fn edit_redacted() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + timeline + .handle_live_redacted_message_event(*ALICE, RedactedRoomMessageEventContent::new()) + .await; + let _day_divider = + 
assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + + let redacted_event_id = item.as_event().unwrap().event_id().unwrap(); + + let edit = assign!(RoomMessageEventContent::text_plain(" * test"), { + relates_to: Some(message::Relation::Replacement(Replacement::new( + redacted_event_id.to_owned(), + MessageType::text_plain("test"), + ))), + }); + timeline.handle_live_message_event(&ALICE, edit).await; + + assert_eq!(timeline.inner.items().await.len(), 2); +} + +#[async_test] +async fn sticker() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + timeline + .handle_live_custom_event(json!({ + "content": { + "body": "Happy sticker", + "info": { + "h": 398, + "mimetype": "image/jpeg", + "size": 31037, + "w": 394 + }, + "url": "mxc://server.name/JWEIFJgwEIhweiWJE", + }, + "event_id": "$143273582443PhrSn", + "origin_server_ts": 143273582, + "sender": "@alice:server.name", + "type": "m.sticker", + })) + .await; + + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::Sticker(_)); +} + +#[async_test] +async fn room_member() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + let mut first_room_member_content = RoomMemberEventContent::new(MembershipState::Invite); + first_room_member_content.displayname = Some("Alice".to_owned()); + timeline + .handle_live_state_event_with_state_key( + &BOB, + ALICE.to_owned(), + first_room_member_content.clone(), + None, + ) + .await; + + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + + let item = assert_matches!(stream.next().await, 
Some(VectorDiff::PushBack { value }) => value); + let membership = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::MembershipChange(ev) => ev); + assert_matches!(membership.content(), FullStateEventContent::Original { .. }); + assert_matches!(membership.change(), Some(MembershipChange::Invited)); + + let mut second_room_member_content = RoomMemberEventContent::new(MembershipState::Join); + second_room_member_content.displayname = Some("Alice".to_owned()); + timeline + .handle_live_state_event_with_state_key( + &ALICE, + ALICE.to_owned(), + second_room_member_content.clone(), + Some(first_room_member_content), + ) + .await; + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let membership = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::MembershipChange(ev) => ev); + assert_matches!(membership.content(), FullStateEventContent::Original { .. }); + assert_matches!(membership.change(), Some(MembershipChange::InvitationAccepted)); + + let mut third_room_member_content = RoomMemberEventContent::new(MembershipState::Join); + third_room_member_content.displayname = Some("Alice In Wonderland".to_owned()); + timeline + .handle_live_state_event_with_state_key( + &ALICE, + ALICE.to_owned(), + third_room_member_content, + Some(second_room_member_content), + ) + .await; + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let profile = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::ProfileChange(ev) => ev); + assert_matches!(profile.displayname_change(), Some(_)); + assert_matches!(profile.avatar_url_change(), None); + + timeline + .handle_live_redacted_state_event_with_state_key( + &ALICE, + ALICE.to_owned(), + RedactedRoomMemberEventContent::new(MembershipState::Join), + ) + .await; + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let membership = 
assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::MembershipChange(ev) => ev); + assert_matches!(membership.content(), FullStateEventContent::Redacted(_)); + assert_matches!(membership.change(), None); +} + +#[async_test] +async fn other_state() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + timeline + .handle_live_state_event( + &ALICE, + RoomNameEventContent::new(Some("Alice's room".to_owned())), + None, + ) + .await; + + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let ev = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::OtherState(ev) => ev); + let full_content = + assert_matches!(ev.content(), AnyOtherFullStateEventContent::RoomName(c) => c); + let (content, prev_content) = assert_matches!(full_content, FullStateEventContent::Original { content, prev_content } => (content, prev_content)); + assert_eq!(content.name.as_ref().unwrap(), "Alice's room"); + assert_matches!(prev_content, None); + + timeline.handle_live_redacted_state_event(&ALICE, RedactedRoomTopicEventContent::new()).await; + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let ev = assert_matches!(item.as_event().unwrap().content(), TimelineItemContent::OtherState(ev) => ev); + let full_content = + assert_matches!(ev.content(), AnyOtherFullStateEventContent::RoomTopic(c) => c); + assert_matches!(full_content, FullStateEventContent::Redacted(_)); +} + +#[async_test] +async fn dedup_pagination() { + let timeline = TestTimeline::new(); + + let event = timeline.make_message_event(*ALICE, RoomMessageEventContent::text_plain("o/")); + timeline.handle_live_custom_event(event.clone()).await; + timeline.handle_back_paginated_custom_event(event).await; + + let timeline_items = timeline.inner.items().await; + 
assert_eq!(timeline_items.len(), 2); + assert_matches!(*timeline_items[0], TimelineItem::Virtual(VirtualTimelineItem::DayDivider(_))); + assert_matches!(*timeline_items[1], TimelineItem::Event(_)); +} + +#[async_test] +async fn dedup_initial() { + let mut timeline = TestTimeline::new(); + + let event_a = sync_timeline_event( + timeline.make_message_event(*ALICE, RoomMessageEventContent::text_plain("A")), + ); + let event_b = sync_timeline_event( + timeline.make_message_event(*BOB, RoomMessageEventContent::text_plain("B")), + ); + + timeline.inner.add_initial_events(vector![event_a.clone(), event_b, event_a]).await; + + let timeline_items = timeline.inner.items().await; + assert_eq!(timeline_items.len(), 3); + assert_eq!(timeline_items[1].as_event().unwrap().sender(), *BOB); + assert_eq!(timeline_items[2].as_event().unwrap().sender(), *ALICE); +} diff --git a/crates/matrix-sdk/src/room/timeline/tests/echo.rs b/crates/matrix-sdk/src/room/timeline/tests/echo.rs new file mode 100644 index 00000000000..04afb758c75 --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/tests/echo.rs @@ -0,0 +1,157 @@ +use std::{io, sync::Arc}; + +use assert_matches::assert_matches; +use eyeball_im::VectorDiff; +use futures_util::StreamExt; +use matrix_sdk_test::async_test; +use ruma::{ + event_id, + events::{room::message::RoomMessageEventContent, AnyMessageLikeEventContent}, +}; +use serde_json::json; + +use super::{TestTimeline, ALICE, BOB}; +use crate::{ + room::timeline::{event_item::EventSendState, EventTimelineItem}, + Error, +}; + +#[async_test] +async fn remote_echo_full_trip() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + // Given a local event… + let txn_id = timeline + .handle_local_event(AnyMessageLikeEventContent::RoomMessage( + RoomMessageEventContent::text_plain("echo"), + )) + .await; + + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + + // Scenario 1: The local event 
has not been sent yet to the server. + { + let item = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event = item.as_event().unwrap().as_local().unwrap(); + assert_matches!(event.send_state(), EventSendState::NotSentYet); + } + + // Scenario 2: The local event has not been sent to the server successfully, it + // has failed. In this case, there is no event ID. + { + let some_io_error = Error::Io(io::Error::new(io::ErrorKind::Other, "this is a test")); + timeline + .inner + .update_event_send_state( + &txn_id, + EventSendState::SendingFailed { error: Arc::new(some_io_error) }, + ) + .await; + + let item = assert_matches!( + stream.next().await, + Some(VectorDiff::Set { value, index: 1 }) => value + ); + let event = item.as_event().unwrap().as_local().unwrap(); + assert_matches!(event.send_state(), EventSendState::SendingFailed { .. }); + } + + // Scenario 3: The local event has been sent successfully to the server and an + // event ID has been received as part of the server's response. + let event_id = event_id!("$W6mZSLWMmfuQQ9jhZWeTxFIM"); + let timestamp = { + timeline + .inner + .update_event_send_state( + &txn_id, + EventSendState::Sent { event_id: event_id.to_owned() }, + ) + .await; + + let item = assert_matches!( + stream.next().await, + Some(VectorDiff::Set { value, index: 1 }) => value + ); + let event_item = item.as_event().unwrap().as_local().unwrap(); + assert_matches!(event_item.send_state(), EventSendState::Sent { .. }); + + event_item.timestamp() + }; + + // Now, a sync has been run against the server, and an event with the same ID + // comes in. 
+ timeline + .handle_live_custom_event(json!({ + "content": { + "body": "echo", + "msgtype": "m.text", + }, + "sender": &*ALICE, + "event_id": event_id, + "origin_server_ts": timestamp, + "type": "m.room.message", + })) + .await; + + // The local echo is replaced with the remote echo + let item = + assert_matches!(stream.next().await, Some(VectorDiff::Set { index: 1, value }) => value); + assert_matches!(item.as_event().unwrap(), EventTimelineItem::Remote(_)); +} + +#[async_test] +async fn remote_echo_new_position() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + // Given a local event… + let txn_id = timeline + .handle_local_event(AnyMessageLikeEventContent::RoomMessage( + RoomMessageEventContent::text_plain("echo"), + )) + .await; + + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let txn_id_from_event = item.as_event().unwrap().as_local().unwrap(); + assert_eq!(txn_id, *txn_id_from_event.transaction_id()); + + // … and another event that comes back before the remote echo + timeline.handle_live_message_event(&BOB, RoomMessageEventContent::text_plain("test")).await; + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let _bob_message = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + + // When the remote echo comes in… + timeline + .handle_live_custom_event(json!({ + "content": { + "body": "echo", + "msgtype": "m.text", + }, + "sender": &*ALICE, + "event_id": "$eeG0HA0FAZ37wP8kXlNkxx3I", + "origin_server_ts": 6, + "type": "m.room.message", + "unsigned": { + "transaction_id": txn_id, + }, + })) + .await; + + // … the local echo should be removed + assert_matches!(stream.next().await, Some(VectorDiff::Remove { index: 1 })); + // … along with its day divider + 
assert_matches!(stream.next().await, Some(VectorDiff::Remove { index: 0 })); + + // … and the remote echo added (no new day divider because both bob's and + // alice's message are from the same day according to server timestamps) + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + assert_matches!(item.as_event().unwrap(), EventTimelineItem::Remote(_)); +} diff --git a/crates/matrix-sdk/src/room/timeline/tests/encryption.rs b/crates/matrix-sdk/src/room/timeline/tests/encryption.rs new file mode 100644 index 00000000000..3deba7770b1 --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/tests/encryption.rs @@ -0,0 +1,312 @@ +#![cfg(not(target_arch = "wasm32"))] + +use std::{collections::BTreeSet, io::Cursor, iter}; + +use assert_matches::assert_matches; +use eyeball_im::VectorDiff; +use futures_util::StreamExt; +use matrix_sdk_base::crypto::{decrypt_room_key_export, OlmMachine}; +use matrix_sdk_test::async_test; +use ruma::{ + assign, + events::room::encrypted::{ + EncryptedEventScheme, MegolmV1AesSha2ContentInit, Relation, Replacement, + RoomEncryptedEventContent, + }, + room_id, user_id, +}; + +use super::{TestTimeline, BOB}; +use crate::room::timeline::{EncryptedMessage, TimelineItemContent}; + +#[async_test] +async fn retry_message_decryption() { + const SESSION_ID: &str = "gM8i47Xhu0q52xLfgUXzanCMpLinoyVyH7R58cBuVBU"; + const SESSION_KEY: &[u8] = b"\ + -----BEGIN MEGOLM SESSION DATA-----\n\ + ASKcWoiAVUM97482UAi83Avce62hSLce7i5JhsqoF6xeAAAACqt2Cg3nyJPRWTTMXxXH7TXnkfdlmBXbQtq5\ + bpHo3LRijcq2Gc6TXilESCmJN14pIsfKRJrWjZ0squ/XsoTFytuVLWwkNaW3QF6obeg2IoVtJXLMPdw3b2vO\ + vgwGY3OMP0XafH13j1vcb6YLzvgLkZQLnYvd47hv3yK/9GmKS9tokuaQ7dCVYckYcIOS09EDTs70YdxUd5WG\ + rQynATCLFP1p/NAGv70r9MK7Cy/mNpjD0r4qC7UEDIoi1kOWzHgnLo19wtvwsb8Fg8ATxcs3Wmtj8hIUYpDx\ + ia4sM10zbytUuaPUAfCDf42IyxdmOnGe1CueXhgI71y+RW0s0argNqUt7jB70JT0o9CyX6UBGRaqLk2MPY9T\ + hUu5J8X3UgIa6rcbWigzohzWm9rdbEHFrSWqjpfQYMaAKQQgETrjSy4XTrp2RhC2oNqG/hylI4ab+F4X6fpH\ + 
DYP1NqNMP5g36xNu7LhDnrUB5qsPjYOmWORxGLfudpF3oLYCSlr3DgHqEIB6HjQblLZ3KQuPBse3zxyROTnS\ + AhdPH4a/z1wioFtKNVph3hecsiKEdqnz4Y2coSIdhz58mJ9JWNQoFAENE5CSsoEZAGvafYZVpW4C75YY2zq1\ + wIeiFi1dT43/jLAUGkslsi1VvnyfUu8qO404RxYO3XHoGLMFoFLOO+lZ+VGci2Vz10AhxJhEBHxRKxw4k2uB\ + HztoSJUr/2Y\n\ + -----END MEGOLM SESSION DATA-----"; + + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + timeline + .handle_live_message_event( + &BOB, + RoomEncryptedEventContent::new( + EncryptedEventScheme::MegolmV1AesSha2( + MegolmV1AesSha2ContentInit { + ciphertext: "\ + AwgAEtABPRMavuZMDJrPo6pGQP4qVmpcuapuXtzKXJyi3YpEsjSWdzuRKIgJzD4P\ + cSqJM1A8kzxecTQNJsC5q22+KSFEPxPnI4ltpm7GFowSoPSW9+bFdnlfUzEP1jPq\ + YevHAsMJp2fRKkzQQbPordrUk1gNqEpGl4BYFeRqKl9GPdKFwy45huvQCLNNueql\ + CFZVoYMuhxrfyMiJJAVNTofkr2um2mKjDTlajHtr39pTG8k0eOjSXkLOSdZvNOMz\ + hGhSaFNeERSA2G2YbeknOvU7MvjiO0AKuxaAe1CaVhAI14FCgzrJ8g0y5nly+n7x\ + QzL2G2Dn8EoXM5Iqj8W99iokQoVsSrUEnaQ1WnSIfewvDDt4LCaD/w7PGETMCQ" + .to_owned(), + sender_key: "DeHIg4gwhClxzFYcmNntPNF9YtsdZbmMy8+3kzCMXHA".to_owned(), + device_id: "NLAZCWIOCO".into(), + session_id: SESSION_ID.into(), + } + .into(), + ), + None, + ), + ) + .await; + + assert_eq!(timeline.inner.items().await.len(), 2); + + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event = item.as_event().unwrap(); + let session_id = assert_matches!( + event.content(), + TimelineItemContent::UnableToDecrypt( + EncryptedMessage::MegolmV1AesSha2 { session_id, .. 
}, + ) => session_id + ); + assert_eq!(session_id, SESSION_ID); + + let own_user_id = user_id!("@example:morheus.localhost"); + let exported_keys = decrypt_room_key_export(Cursor::new(SESSION_KEY), "1234").unwrap(); + + let olm_machine = OlmMachine::new(own_user_id, "SomeDeviceId".into()).await; + olm_machine.import_room_keys(exported_keys, false, |_, _| {}).await.unwrap(); + + timeline + .inner + .retry_event_decryption( + room_id!("!DovneieKSTkdHKpIXy:morpheus.localhost"), + &olm_machine, + Some(iter::once(SESSION_ID).collect()), + ) + .await; + + assert_eq!(timeline.inner.items().await.len(), 2); + + let item = + assert_matches!(stream.next().await, Some(VectorDiff::Set { index: 1, value }) => value); + let event = item.as_event().unwrap().as_remote().unwrap(); + assert_matches!(event.encryption_info(), Some(_)); + let text = assert_matches!(event.content(), TimelineItemContent::Message(msg) => msg.body()); + assert_eq!(text, "It's a secret to everybody"); +} + +#[async_test] +async fn retry_edit_decryption() { + const SESSION1_KEY: &[u8] = b"\ + -----BEGIN MEGOLM SESSION DATA-----\n\ + AXou7bY+PWm0GrxTioyoKTkxAgfrQ5lGIla62WoBMrqWAAAACgXidLIt0gaK5NT3mGigzFAPjh/M0ibXjSvo\ + P9haNoJN2839XPCqHpErqje9x25Vy830vQXu9OpwT/QNgVXoffK6rXvIMvom6V2ElopBSVVHqgJdfqRrlGKH\ + okfW6AE+ApVPk31BclxuUuxCy+Ph9sWBTW3MA64YGog5Ddp2PAz2Vk/iZ9Dcmtf5CDLbhIRsWiLuSEvO56ok\ + 8/ZxCsiuI4SXx+hikBs+krMTIHn74NL5ffpIlnPSOVtbiY49wE1SRwVgdeJUO9qjHpQX3fZKldBBC01l0BuB\ + WK+W/f/LlOPgLr9Eac/u66fCK6Y81ziJOyn3l1wQuu3MQvuuJfwOqcljl47/yg6SaoTYhZ3ytHXkkBtYx0E6\ + h+J3dgXvW6r0prqci/0gljDQR7KtWEUhXb0BwPK7ojRZWBIzF9T/5uKOio/hBZJ7MQHXt8S2HGOB+gKuzrG8\ + azLt5EB48zgeciNlvQ5zh+AltVEErbyENhCAOxEMoO2sTjK1WZ58ZZmti8uaEZ2mJOCciAp6QiFFDnx2FiPv\ + 5DN4g22qr4A2Z4rFZNgum4VosoDA8hBvqr+G9TN5ZxVyi4IPOlqv7ycf6WGOLB6022HmZMX74KHlimDtiYlv\ + G6q7EyfpmeT5rKs51f83rQNkRzcNXKlK83YwIBxCdv9EQXZ4WATAvRqeVF8/m2qpv58zIHjLmq7irckNDmPF\ + W8aUFGxYXuU\n\ + -----END MEGOLM SESSION DATA-----"; + + const SESSION2_KEY: &[u8] = b"\ + 
-----BEGIN MEGOLM SESSION DATA-----\n\ + AbMgil4w2zS9PcZ25f+vdcBdv0/YVaOg52K49DwCmMUkAAAAChEzP9tvnK3jd0NA+BjFfm0zzHYOiu5EyRK/\ + F+2mFmC5vYzSiT6Zcx3dn23cU+BpmkCH/HxFli1TMZ29jLZt/ri6FgwRZtkNqmcRDnPi18xnY1GTDFYtdZEZ\ + 8Fv4L29JVOWLgEIGRdH1ct8HAqxxgSCAEcuVY7ns8xjGWKrX6gs2yanF9vUbdMyRHzBqgytzwnXl+sg5TvQS\ + a5Hh8D0eGewv0gWzUVh4PIhpwTxbEJ97k6Dklq2UneJiBo4kmna4uCRz3khq69k0kajIEiqT6eZtwIz0lDDT\ + V+MQz7YUKkFI6Th88VL9/eehcnuYQgefEEbHeb3zvoA6LSJGpvJEPcHaVNpFgnxNlQaDowtb5XMGZfI/YU4O\ + exTiEdtbYSjGnwDEuVUXtFfHCElvrBhvO3MAiXrk1QbZRNzyNUvU+1+ZmPc0IBsDHJiCN/15MKuEWF9kKqt+\ + 9FsFoRnKbXwUfDk9azdOtzymiel6xiD7kr5RTEmyxBIbTQukqZSSyTzKcTxiWQyK7HL0vxztf7Vdy7o1qtKo\ + 9Q48eyIc4fc3HwcSLz6CqRlJENsuhqdPcovE4TeIrv72/WBFLot+gGFltrhdXeaNdzLo+xTSdIjXRpnPtNob\ + dld8OyD3F7GpNdtMXoNhpQNfeOWca0eKUkL/gJw5T7kNkTwso2t1gfcIezEge1UpigAQxUgVDRLTdZZ+C1mM\ + rHCyB4ElRjU\n\ + -----END MEGOLM SESSION DATA-----"; + + let timeline = TestTimeline::new(); + + let encrypted = EncryptedEventScheme::MegolmV1AesSha2( + MegolmV1AesSha2ContentInit { + ciphertext: "\ + AwgAEpABqOCAaP6NqXquQcEsrGCVInjRTLHmVH8exqYO0b5Aulhgzqrt6oWVUZCpSRBCnlmvnc96\ + n/wpjlALt6vYUcNr2lMkXpuKuYaQhHx5c4in2OJCkPzGmbpXRRdw6WC25uzzKr5Vi5Fa8B5o1C5E\ + DGgNsJg8jC+cVZbcbVCFisQcLATG8UBDuZUGn3WtVFzw0aHzgxGEc+t4C8J9aWwqwokaEF7fRjTK\ + ma5GZJZKR9KfdmeHR2TsnlnLPiPh5F12hqwd5XaOMQemS2j4pENfxpBlYIy5Wk3FQN0G" + .to_owned(), + sender_key: "sKSGv2uD9zUncgL6GiLedvuky3fjVcEz9qVKZkpzN14".to_owned(), + device_id: "PNQBRWYIJL".into(), + session_id: "gI3QWFyqg55EDS8d0omSJwDw8ZWBNEGUw8JxoZlzJgU".into(), + } + .into(), + ); + timeline.handle_live_message_event(&BOB, RoomEncryptedEventContent::new(encrypted, None)).await; + + let event_id = + timeline.inner.items().await[1].as_event().unwrap().event_id().unwrap().to_owned(); + + let encrypted = EncryptedEventScheme::MegolmV1AesSha2( + MegolmV1AesSha2ContentInit { + ciphertext: "\ + AwgAEtABWuWeRLintqVP5ez5kki8sDsX7zSq++9AJo9lELGTDjNKzbF8sowUgg0DaGoP\ + 
dgWyBmuUxT2bMggwM0fAevtu4XcFtWUx1c/sj1vhekrng9snmXpz4a30N8jhQ7N4WoIg\ + /G5wsPKtOITjUHeon7EKjTPFU7xoYXmxbjDL/9R4hGQdRqogs1hj0ZnWRxNCvr3ahq24\ + E0j8WyBrQXOb2PIHVNfV/9eW8AB744UQXn8FJpmQO8c0Us3YorXtIFrwAtvI3FknD7Lj\ + eeYFpR9oeyZKuzo2Wzp7eiEZt0Lm+xb7Lfp9yY52RhAO7JLlCM4oPff2yXHpUmcjdGsi\ + 9Zc9Z92hiILkZoKOSGccYQoLjYlfL8rVsIVvl4tDDQ" + .to_owned(), + sender_key: "sKSGv2uD9zUncgL6GiLedvuky3fjVcEz9qVKZkpzN14".to_owned(), + device_id: "PNQBRWYIJL".into(), + session_id: "HSRlM67FgLYl0J0l1luflfGwpnFcLKHnNoRqUuIhQ5Q".into(), + } + .into(), + ); + timeline + .handle_live_message_event( + &BOB, + assign!(RoomEncryptedEventContent::new(encrypted, None), { + relates_to: Some(Relation::Replacement(Replacement::new(event_id))), + }), + ) + .await; + + let mut keys = decrypt_room_key_export(Cursor::new(SESSION1_KEY), "1234").unwrap(); + keys.extend(decrypt_room_key_export(Cursor::new(SESSION2_KEY), "1234").unwrap()); + + let own_user_id = user_id!("@example:morheus.localhost"); + let olm_machine = OlmMachine::new(own_user_id, "SomeDeviceId".into()).await; + olm_machine.import_room_keys(keys, false, |_, _| {}).await.unwrap(); + + timeline + .inner + .retry_event_decryption( + room_id!("!bdsREiCPHyZAPkpXer:morpheus.localhost"), + &olm_machine, + None, + ) + .await; + + let items = timeline.inner.items().await; + assert_eq!(items.len(), 2); + + let item = items[1].as_event().unwrap().as_remote().unwrap(); + + assert_matches!(item.encryption_info(), Some(_)); + let msg = assert_matches!(item.content(), TimelineItemContent::Message(msg) => msg); + assert!(msg.is_edited()); + assert_eq!(msg.body(), "This is Error"); +} + +#[async_test] +async fn retry_edit_and_more() { + const DEVICE_ID: &str = "MTEGRRVPEN"; + const SENDER_KEY: &str = "NFPM2+ucU3n3sEdbDdwwv48Bsj4AiQ185lGuRFjy+gs"; + const SESSION_ID: &str = "SMNh04luorH5E8J3b4XYuOBFp8dldO5njacq0OFO70o"; + const SESSION_KEY: &[u8] = b"\ + -----BEGIN MEGOLM SESSION DATA-----\n\ + 
AXT1CtOfPgmZRXEk4st3ZwIGShWtZ6iDW0+fwku7AIonAAAACr31UJxAbryf6bH3eF5y+WrOipWmZ6G/59A3\ + kuCwntIOrdIC5ShTRWo0qmcWHav2TaFBCx7kWFUs1ryFZjzksCB7sRnVhfXsDUgGGKgj0MOESlPH9Px+IOcV\ + B6Dr9rjj2STtapCknlit9FMrOcfQhsV5q+ymZwm1C32Zc3UTEtyxfpXiIVyru4Xsrzti61fDIiWFj7Mie4Wn\ + 7YQ8SQ1Q9CZUnOCzflP4Yw+5cXHwMRDcz7/kIPzczCYILLp89G//Uh8QN25tN+oCPhBmTxMxoHhabEwkZ/rK\ + D1T+jXDK/dClfXqDXxjjAhQpcUI0soWeAGEq8nMEE5J2D/42AOpKVYqfq2GPiGoPQk3suy4GtDJQlXZaFuz/\ + l4fmHwB1CJCxMUlgpRJ4PhRHAfJn9zfiskM19/dj/G9foGt8KQBRnnbxDVM4eYuoMJZn7SaQfXFmybBTY+Z/\ + bYGg9FUKn/LyjYc8jqbyXCnddzCHB+YENwEOP3WQQrZccyvjuTv5oB/TqK4yS90phIvkLlqEyJXKxxPnzAvV\ + CArjU7naYXMeVieMqcntbeaXutLftLUIF7KUUCPu357sTKjaAp8z98YfPZBctrHRrx7Oo2t6Wtph0A5N/NwA\ + dSN2ceRzRzkoupc4FCxvH6o6PmmtD9DfxtZsk+HA+8NQhgFpvm/VYalikckW+wGFxB4nn1nVViS4GN5n8fc/\ + Ug\n\ + -----END MEGOLM SESSION DATA-----"; + + fn encrypted_message(ciphertext: &str) -> RoomEncryptedEventContent { + RoomEncryptedEventContent::new( + EncryptedEventScheme::MegolmV1AesSha2( + MegolmV1AesSha2ContentInit { + ciphertext: ciphertext.into(), + sender_key: SENDER_KEY.into(), + device_id: DEVICE_ID.into(), + session_id: SESSION_ID.into(), + } + .into(), + ), + None, + ) + } + + let timeline = TestTimeline::new(); + + timeline + .handle_live_message_event( + &BOB, + encrypted_message( + "AwgDEoABQsTrPTYDh22PTmfODR9EucX3qLl3buDcahHPjKJA8QIM+wW0s+e08Zi7/JbLdnZL1VL\ + jO47HcRhxDTyHZPXPg8wd1l0Qb3irjnCnS7LFAc98+ko18CFJUGNeRZZwzGiorKK5VLMv0WQZI8\ + mBZdKIaqDTUBFvcvbn2gQaWtUipQdJQRKyv2h0AWveVkv75lp5hRb7jolCi08oMX8cM+V3Zzyi7\ + mlPAzZjDz0PaRbQwfbMTTHkcL7TZybBi4vLX4f5ZR2Iiysc7gw", + ), + ) + .await; + + let event_id = + timeline.inner.items().await[1].as_event().unwrap().event_id().unwrap().to_owned(); + + let msg2 = encrypted_message( + "AwgEErABt7svMEHDYJTjCQEHypR21l34f9IZLNyFaAbI+EiCIN7C8X5iKmkzuYSmGUodyGKbFRYrW9l5dLj\ + 35xIRli3SZ6duZpmBI7D4pBGPj2T2Jkc/I9kd/I4EhpvV2emDTioB7jwUfFoATfdA0z/6ciTmU73PStKHZM\ + 
+WYNxCWZERsCQBtiINzC80FymwLjh4nBhnyW0nlMihGGasakn+3wKQUY0HkVoFM8TXQlCXl1RM2oxL9nn0C\ + dRu2LPArXc5K/1GBSyfluSrdQuA9DciLwVHJB9NwvbZ/7flIkaOC7ahahmk2ws+QeSz8MmHt+9QityK3ZUB\ + 4uEzsQ0", + ); + timeline + .handle_live_message_event( + &BOB, + assign!(msg2, { relates_to: Some(Relation::Replacement(Replacement::new(event_id))) }), + ) + .await; + + timeline + .handle_live_message_event( + &BOB, + encrypted_message( + "AwgFEoABUAwzBLYStHEa1RaZtojePQ6sue9terXNMFufeLKci/UcpOpZC9o3lDxp9rxlNjk4Ii+\ + fkOeSClib/qxt+wLszeQZVa04bRr6byK1dOhlptvAPjUCcEsaHyMMR1AnjT2vmFlJRGviwN6cvQ\ + 2r/fEvAW/9QB+N6fX4g9729bt5ftXRqa5QI7NA351RNUveRHxVvx+2x0WJArQjYGRk7tMS2rUto\ + IYt2ZY17nE1UJjN7M87STnCF9c9qy4aGNqIpeVIht6XbtgD7gQ", + ), + ) + .await; + + assert_eq!(timeline.inner.items().await.len(), 4); + + let olm_machine = OlmMachine::new(user_id!("@jptest:matrix.org"), DEVICE_ID.into()).await; + let keys = decrypt_room_key_export(Cursor::new(SESSION_KEY), "testing").unwrap(); + olm_machine.import_room_keys(keys, false, |_, _| {}).await.unwrap(); + + timeline + .inner + .retry_event_decryption( + room_id!("!wFnAUSQbxMcfIMgvNX:flipdot.org"), + &olm_machine, + Some(BTreeSet::from_iter([SESSION_ID])), + ) + .await; + + let timeline_items = timeline.inner.items().await; + assert_eq!(timeline_items.len(), 3); + assert!(timeline_items[0].is_day_divider()); + assert_eq!( + timeline_items[1].as_event().unwrap().content().as_message().unwrap().body(), + "edited" + ); + assert_eq!( + timeline_items[2].as_event().unwrap().content().as_message().unwrap().body(), + "Another message" + ); +} diff --git a/crates/matrix-sdk/src/room/timeline/tests/invalid.rs b/crates/matrix-sdk/src/room/timeline/tests/invalid.rs new file mode 100644 index 00000000000..c3cd412a387 --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/tests/invalid.rs @@ -0,0 +1,126 @@ +use assert_matches::assert_matches; +use eyeball_im::VectorDiff; +use futures_util::StreamExt; +use matrix_sdk_test::async_test; +use ruma::{ + assign, event_id, 
+ events::{ + relation::Replacement, + room::message::{self, MessageType, RoomMessageEventContent}, + MessageLikeEventType, StateEventType, + }, + uint, MilliSecondsSinceUnixEpoch, +}; +use serde_json::json; + +use super::{TestTimeline, ALICE, BOB}; +use crate::room::timeline::TimelineItemContent; + +#[async_test] +async fn invalid_edit() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + timeline.handle_live_message_event(&ALICE, RoomMessageEventContent::text_plain("test")).await; + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event = item.as_event().unwrap().as_remote().unwrap(); + let msg = event.content().as_message().unwrap(); + assert_eq!(msg.body(), "test"); + + let msg_event_id = event.event_id(); + + let edit = assign!(RoomMessageEventContent::text_plain(" * fake"), { + relates_to: Some(message::Relation::Replacement(Replacement::new( + msg_event_id.to_owned(), + MessageType::text_plain("fake"), + ))), + }); + // Edit is from a different user than the previous event + timeline.handle_live_message_event(&BOB, edit).await; + + // Can't easily test the non-arrival of an item using the stream. Instead + // just assert that there is still just a couple items in the timeline. + assert_eq!(timeline.inner.items().await.len(), 2); +} + +#[async_test] +async fn invalid_event_content() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + // m.room.message events must have a msgtype and body in content, so this + // event with an empty content object should fail to deserialize. 
+ timeline + .handle_live_custom_event(json!({ + "content": {}, + "event_id": "$eeG0HA0FAZ37wP8kXlNkxx3I", + "origin_server_ts": 10, + "sender": "@alice:example.org", + "type": "m.room.message", + })) + .await; + + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event_item = item.as_event().unwrap().as_remote().unwrap(); + assert_eq!(event_item.sender(), "@alice:example.org"); + assert_eq!(event_item.event_id(), event_id!("$eeG0HA0FAZ37wP8kXlNkxx3I").to_owned()); + assert_eq!(event_item.timestamp(), MilliSecondsSinceUnixEpoch(uint!(10))); + let event_type = assert_matches!( + event_item.content(), + TimelineItemContent::FailedToParseMessageLike { event_type, .. } => event_type + ); + assert_eq!(*event_type, MessageLikeEventType::RoomMessage); + + // Similar to above, the m.room.member state event must also not have an + // empty content object. + timeline + .handle_live_custom_event(json!({ + "content": {}, + "event_id": "$d5G0HA0FAZ37wP8kXlNkxx3I", + "origin_server_ts": 2179, + "sender": "@alice:example.org", + "type": "m.room.member", + "state_key": "@alice:example.org", + })) + .await; + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event_item = item.as_event().unwrap().as_remote().unwrap(); + assert_eq!(event_item.sender(), "@alice:example.org"); + assert_eq!(event_item.event_id(), event_id!("$d5G0HA0FAZ37wP8kXlNkxx3I").to_owned()); + assert_eq!(event_item.timestamp(), MilliSecondsSinceUnixEpoch(uint!(2179))); + let (event_type, state_key) = assert_matches!( + event_item.content(), + TimelineItemContent::FailedToParseState { + event_type, + state_key, + .. 
+ } => (event_type, state_key) + ); + assert_eq!(*event_type, StateEventType::RoomMember); + assert_eq!(state_key, "@alice:example.org"); +} + +#[async_test] +async fn invalid_event() { + let timeline = TestTimeline::new(); + + // This event is missing the sender field which the homeserver must add to + // all timeline events. Because the event is malformed, it will be ignored. + timeline + .handle_live_custom_event(json!({ + "content": { + "body": "hello world", + "msgtype": "m.text" + }, + "event_id": "$eeG0HA0FAZ37wP8kXlNkxx3I", + "origin_server_ts": 10, + "type": "m.room.message", + })) + .await; + assert_eq!(timeline.inner.items().await.len(), 0); +} diff --git a/crates/matrix-sdk/src/room/timeline/tests/mod.rs b/crates/matrix-sdk/src/room/timeline/tests/mod.rs new file mode 100644 index 00000000000..9105846afae --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/tests/mod.rs @@ -0,0 +1,308 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Unit tests (based on private methods) for the timeline API. 
+ +use std::{ + collections::BTreeMap, + sync::{ + atomic::{AtomicU64, Ordering::SeqCst}, + Arc, + }, +}; + +use async_trait::async_trait; +use eyeball_im::VectorDiff; +use futures_core::Stream; +use indexmap::IndexMap; +use matrix_sdk_base::deserialized_responses::TimelineEvent; +use once_cell::sync::Lazy; +use ruma::{ + events::{ + receipt::{Receipt, ReceiptEventContent, ReceiptThread, ReceiptType}, + AnyMessageLikeEventContent, AnySyncTimelineEvent, EmptyStateKey, MessageLikeEventContent, + RedactedMessageLikeEventContent, RedactedStateEventContent, StateEventContent, + StaticStateEventContent, + }, + serde::Raw, + server_name, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedTransactionId, + OwnedUserId, TransactionId, UserId, +}; +use serde_json::{json, Value as JsonValue}; + +use super::{inner::RoomDataProvider, Profile, TimelineInner, TimelineItem}; + +mod basic; +mod echo; +#[cfg(feature = "e2e-encryption")] +mod encryption; +mod invalid; +mod read_receipts; +mod virt; + +static ALICE: Lazy<&UserId> = Lazy::new(|| user_id!("@alice:server.name")); +static BOB: Lazy<&UserId> = Lazy::new(|| user_id!("@bob:other.server")); + +struct TestTimeline { + inner: TimelineInner, + next_ts: AtomicU64, +} + +impl TestTimeline { + fn new() -> Self { + Self { inner: TimelineInner::new(TestRoomDataProvider), next_ts: AtomicU64::new(0) } + } + + fn with_read_receipt_tracking(mut self) -> Self { + self.inner = self.inner.with_read_receipt_tracking(true); + self + } + + async fn subscribe(&self) -> impl Stream>> { + let (items, stream) = self.inner.subscribe().await; + assert_eq!(items.len(), 0, "Please subscribe to TestTimeline before adding items to it"); + stream + } + + async fn handle_live_message_event(&self, sender: &UserId, content: C) + where + C: MessageLikeEventContent, + { + let ev = self.make_message_event(sender, content); + self.handle_live_event(Raw::new(&ev).unwrap().cast()).await; + } + + async fn handle_live_redacted_message_event(&self, 
sender: &UserId, content: C) + where + C: RedactedMessageLikeEventContent, + { + let ev = self.make_redacted_message_event(sender, content); + self.handle_live_event(Raw::new(&ev).unwrap().cast()).await; + } + + async fn handle_live_state_event(&self, sender: &UserId, content: C, prev_content: Option) + where + C: StaticStateEventContent, + { + let ev = self.make_state_event(sender, "", content, prev_content); + self.handle_live_event(Raw::new(&ev).unwrap().cast()).await; + } + + async fn handle_live_state_event_with_state_key( + &self, + sender: &UserId, + state_key: C::StateKey, + content: C, + prev_content: Option, + ) where + C: StaticStateEventContent, + { + let ev = self.make_state_event(sender, state_key.as_ref(), content, prev_content); + self.handle_live_event(Raw::new(&ev).unwrap().cast()).await; + } + + async fn handle_live_redacted_state_event(&self, sender: &UserId, content: C) + where + C: RedactedStateEventContent, + { + let ev = self.make_redacted_state_event(sender, "", content); + self.handle_live_event(Raw::new(&ev).unwrap().cast()).await; + } + + async fn handle_live_redacted_state_event_with_state_key( + &self, + sender: &UserId, + state_key: C::StateKey, + content: C, + ) where + C: RedactedStateEventContent, + { + let ev = self.make_redacted_state_event(sender, state_key.as_ref(), content); + self.handle_live_event(Raw::new(&ev).unwrap().cast()).await; + } + + async fn handle_live_custom_event(&self, event: JsonValue) { + let raw = Raw::new(&event).unwrap().cast(); + self.handle_live_event(raw).await; + } + + async fn handle_live_redaction(&self, sender: &UserId, redacts: &EventId) { + let ev = json!({ + "type": "m.room.redaction", + "content": {}, + "redacts": redacts, + "event_id": EventId::new(server_name!("dummy.server")), + "sender": sender, + "origin_server_ts": self.next_server_ts(), + }); + let raw = Raw::new(&ev).unwrap().cast(); + self.handle_live_event(raw).await; + } + + async fn handle_live_event(&self, raw: Raw) { + 
self.inner.handle_live_event(raw, None, vec![]).await + } + + async fn handle_local_event(&self, content: AnyMessageLikeEventContent) -> OwnedTransactionId { + let txn_id = TransactionId::new(); + self.inner.handle_local_event(txn_id.clone(), content).await; + txn_id + } + + async fn handle_back_paginated_custom_event(&self, event: JsonValue) { + let timeline_event = TimelineEvent::new(Raw::new(&event).unwrap().cast()); + self.inner.handle_back_paginated_event(timeline_event).await; + } + + async fn handle_read_receipts( + &self, + receipts: impl IntoIterator, + ) { + let ev_content = self.make_receipt_event_content(receipts); + self.inner.handle_read_receipts(ev_content).await; + } + + /// Set the next server timestamp. + /// + /// Timestamps will continue to increase by 1 (millisecond) from that value. + fn set_next_ts(&self, value: u64) { + self.next_ts.store(value, SeqCst); + } + + fn make_message_event( + &self, + sender: &UserId, + content: C, + ) -> JsonValue { + json!({ + "type": content.event_type(), + "content": content, + "event_id": EventId::new(server_name!("dummy.server")), + "sender": sender, + "origin_server_ts": self.next_server_ts(), + }) + } + + fn make_redacted_message_event( + &self, + sender: &UserId, + content: C, + ) -> JsonValue { + json!({ + "type": content.event_type(), + "content": content, + "event_id": EventId::new(server_name!("dummy.server")), + "sender": sender, + "origin_server_ts": self.next_server_ts(), + "unsigned": self.make_redacted_unsigned(sender), + }) + } + + fn make_state_event( + &self, + sender: &UserId, + state_key: &str, + content: C, + prev_content: Option, + ) -> JsonValue { + let unsigned = if let Some(prev_content) = prev_content { + json!({ "prev_content": prev_content }) + } else { + json!({}) + }; + + json!({ + "type": content.event_type(), + "state_key": state_key, + "content": content, + "event_id": EventId::new(server_name!("dummy.server")), + "sender": sender, + "origin_server_ts": self.next_server_ts(), + 
"unsigned": unsigned, + }) + } + + fn make_redacted_state_event( + &self, + sender: &UserId, + state_key: &str, + content: C, + ) -> JsonValue { + json!({ + "type": content.event_type(), + "state_key": state_key, + "content": content, + "event_id": EventId::new(server_name!("dummy.server")), + "sender": sender, + "origin_server_ts": self.next_server_ts(), + "unsigned": self.make_redacted_unsigned(sender), + }) + } + + fn make_redacted_unsigned(&self, sender: &UserId) -> JsonValue { + json!({ + "redacted_because": { + "content": {}, + "event_id": EventId::new(server_name!("dummy.server")), + "sender": sender, + "origin_server_ts": self.next_server_ts(), + "type": "m.room.redaction", + }, + }) + } + + fn make_receipt_event_content( + &self, + receipts: impl IntoIterator, + ) -> ReceiptEventContent { + let mut ev_content = ReceiptEventContent(BTreeMap::new()); + for (event_id, receipt_type, user_id, thread) in receipts { + let event_map = ev_content.entry(event_id).or_default(); + let receipt_map = event_map.entry(receipt_type).or_default(); + + let mut receipt = Receipt::new(self.next_server_ts()); + receipt.thread = thread; + + receipt_map.insert(user_id, receipt); + } + + ev_content + } + + fn next_server_ts(&self) -> MilliSecondsSinceUnixEpoch { + MilliSecondsSinceUnixEpoch( + self.next_ts + .fetch_add(1, SeqCst) + .try_into() + .expect("server timestamp should fit in js_int::UInt"), + ) + } +} + +struct TestRoomDataProvider; + +#[async_trait] +impl RoomDataProvider for TestRoomDataProvider { + fn own_user_id(&self) -> &UserId { + &ALICE + } + + async fn profile(&self, _user_id: &UserId) -> Option { + None + } + + async fn read_receipts_for_event(&self, _event_id: &EventId) -> IndexMap { + IndexMap::new() + } +} diff --git a/crates/matrix-sdk/src/room/timeline/tests/read_receipts.rs b/crates/matrix-sdk/src/room/timeline/tests/read_receipts.rs new file mode 100644 index 00000000000..a8fb13a876f --- /dev/null +++ 
b/crates/matrix-sdk/src/room/timeline/tests/read_receipts.rs @@ -0,0 +1,91 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use assert_matches::assert_matches; +use eyeball_im::VectorDiff; +use futures_util::StreamExt; +use matrix_sdk_test::async_test; +use ruma::events::{ + receipt::{ReceiptThread, ReceiptType}, + room::message::RoomMessageEventContent, +}; + +use super::{TestTimeline, ALICE, BOB}; + +#[async_test] +async fn read_receipts_updates() { + let timeline = TestTimeline::new().with_read_receipt_tracking(); + let mut stream = timeline.subscribe().await; + + timeline.handle_live_message_event(*ALICE, RoomMessageEventContent::text_plain("A")).await; + timeline.handle_live_message_event(*BOB, RoomMessageEventContent::text_plain("B")).await; + + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + + // No read receipt for our own user. + let item_a = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event_a = item_a.as_event().unwrap().as_remote().unwrap(); + assert!(event_a.read_receipts().is_empty()); + + // Implicit read receipt of Bob. 
+ let item_b = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event_b = item_b.as_event().unwrap().as_remote().unwrap(); + assert_eq!(event_b.read_receipts().len(), 1); + assert!(event_b.read_receipts().get(*BOB).is_some()); + + // Implicit read receipt of Bob is updated. + timeline.handle_live_message_event(*BOB, RoomMessageEventContent::text_plain("C")).await; + + let item_a = + assert_matches!(stream.next().await, Some(VectorDiff::Set { index: 2, value }) => value); + let event_a = item_a.as_event().unwrap().as_remote().unwrap(); + assert!(event_a.read_receipts().is_empty()); + + let item_c = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event_c = item_c.as_event().unwrap().as_remote().unwrap(); + assert_eq!(event_c.read_receipts().len(), 1); + assert!(event_c.read_receipts().get(*BOB).is_some()); + + timeline.handle_live_message_event(*ALICE, RoomMessageEventContent::text_plain("D")).await; + + let item_d = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event_d = item_d.as_event().unwrap().as_remote().unwrap(); + assert!(event_d.read_receipts().is_empty()); + + // Explicit read receipt is updated. 
+ timeline + .handle_read_receipts([( + event_d.event_id().to_owned(), + ReceiptType::Read, + BOB.to_owned(), + ReceiptThread::Unthreaded, + )]) + .await; + + let item_c = + assert_matches!(stream.next().await, Some(VectorDiff::Set { index: 3, value }) => value); + let event_c = item_c.as_event().unwrap().as_remote().unwrap(); + assert!(event_c.read_receipts().is_empty()); + + let item_d = + assert_matches!(stream.next().await, Some(VectorDiff::Set { index: 4, value }) => value); + let event_d = item_d.as_event().unwrap().as_remote().unwrap(); + assert_eq!(event_d.read_receipts().len(), 1); + assert!(event_d.read_receipts().get(*BOB).is_some()); +} diff --git a/crates/matrix-sdk/src/room/timeline/tests/virt.rs b/crates/matrix-sdk/src/room/timeline/tests/virt.rs new file mode 100644 index 00000000000..b26483b7a65 --- /dev/null +++ b/crates/matrix-sdk/src/room/timeline/tests/virt.rs @@ -0,0 +1,136 @@ +use assert_matches::assert_matches; +use chrono::{Datelike, Local, TimeZone}; +use eyeball_im::VectorDiff; +use futures_util::StreamExt; +use matrix_sdk_test::async_test; +use ruma::{ + event_id, + events::{room::message::RoomMessageEventContent, AnyMessageLikeEventContent}, +}; + +use super::{TestTimeline, ALICE, BOB}; +use crate::room::timeline::{TimelineItem, VirtualTimelineItem}; + +#[async_test] +async fn day_divider() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + timeline + .handle_live_message_event( + *ALICE, + RoomMessageEventContent::text_plain("This is a first message on the first day"), + ) + .await; + + let day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let ts = assert_matches!( + day_divider.as_virtual().unwrap(), + VirtualTimelineItem::DayDivider(ts) => *ts + ); + let date = Local.timestamp_millis_opt(ts.0.into()).single().unwrap(); + assert_eq!(date.year(), 1970); + assert_eq!(date.month(), 1); + assert_eq!(date.day(), 1); + + let item = 
assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + item.as_event().unwrap(); + + timeline + .handle_live_message_event( + *ALICE, + RoomMessageEventContent::text_plain("This is a second message on the first day"), + ) + .await; + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + item.as_event().unwrap(); + + // Timestamps start at unix epoch, advance to one day later + timeline.set_next_ts(24 * 60 * 60 * 1000); + + timeline + .handle_live_message_event( + *ALICE, + RoomMessageEventContent::text_plain("This is a first message on the next day"), + ) + .await; + + let day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let ts = assert_matches!( + day_divider.as_virtual().unwrap(), + VirtualTimelineItem::DayDivider(ts) => *ts + ); + let date = Local.timestamp_millis_opt(ts.0.into()).single().unwrap(); + assert_eq!(date.year(), 1970); + assert_eq!(date.month(), 1); + assert_eq!(date.day(), 2); + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + item.as_event().unwrap(); + + let _ = timeline + .handle_local_event(AnyMessageLikeEventContent::RoomMessage( + RoomMessageEventContent::text_plain("A message I'm sending just now"), + )) + .await; + + // The other events are in the past so a local event always creates a new day + // divider. + let day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + assert_matches!(day_divider.as_virtual().unwrap(), VirtualTimelineItem::DayDivider { .. 
}); + + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + item.as_event().unwrap(); +} + +#[async_test] +async fn update_read_marker() { + let timeline = TestTimeline::new(); + let mut stream = timeline.subscribe().await; + + timeline.handle_live_message_event(&ALICE, RoomMessageEventContent::text_plain("A")).await; + let _day_divider = + assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event_id = item.as_event().unwrap().event_id().unwrap().to_owned(); + + timeline.inner.set_fully_read_event(event_id).await; + let item = + assert_matches!(stream.next().await, Some(VectorDiff::Insert { index: 2, value }) => value); + assert_matches!(item.as_virtual(), Some(VirtualTimelineItem::ReadMarker)); + + timeline.handle_live_message_event(&BOB, RoomMessageEventContent::text_plain("B")).await; + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event_id = item.as_event().unwrap().event_id().unwrap().to_owned(); + + timeline.inner.set_fully_read_event(event_id.clone()).await; + assert_matches!(stream.next().await, Some(VectorDiff::Remove { index: 2 })); + let marker = + assert_matches!(stream.next().await, Some(VectorDiff::Insert { index: 3, value }) => value); + assert_matches!(*marker, TimelineItem::Virtual(VirtualTimelineItem::ReadMarker)); + + // Nothing should happen if the fully read event is set back to the same event + // as before. + timeline.inner.set_fully_read_event(event_id.clone()).await; + + // Nothing should happen if the fully read event isn't found. + timeline.inner.set_fully_read_event(event_id!("$fake_event_id").to_owned()).await; + + // Nothing should happen if the fully read event is referring to an old event + // that has already been marked as fully read. 
+ timeline.inner.set_fully_read_event(event_id).await; + + timeline.handle_live_message_event(&ALICE, RoomMessageEventContent::text_plain("C")).await; + let item = assert_matches!(stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let event_id = item.as_event().unwrap().event_id().unwrap().to_owned(); + + timeline.inner.set_fully_read_event(event_id).await; + assert_matches!(stream.next().await, Some(VectorDiff::Remove { index: 3 })); + let marker = + assert_matches!(stream.next().await, Some(VectorDiff::Insert { index: 4, value }) => value); + assert_matches!(*marker, TimelineItem::Virtual(VirtualTimelineItem::ReadMarker)); +} diff --git a/crates/matrix-sdk/src/room/timeline/to_device.rs b/crates/matrix-sdk/src/room/timeline/to_device.rs index ea60055e3e4..5fde9a63412 100644 --- a/crates/matrix-sdk/src/room/timeline/to_device.rs +++ b/crates/matrix-sdk/src/room/timeline/to_device.rs @@ -1,3 +1,17 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::{iter, sync::Arc}; use ruma::{ diff --git a/crates/matrix-sdk/src/sliding_sync.rs b/crates/matrix-sdk/src/sliding_sync.rs deleted file mode 100644 index c4ab094e4b9..00000000000 --- a/crates/matrix-sdk/src/sliding_sync.rs +++ /dev/null @@ -1,1924 +0,0 @@ -// Copyright 2022 Benjamin Kampmann -// Copyright 2022 The Matrix.org Foundation C.I.C. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{ - collections::BTreeMap, - fmt::Debug, - ops::Deref, - sync::{ - atomic::{AtomicBool, AtomicU8, Ordering}, - Arc, Mutex, - }, - time::Duration, -}; - -use futures_core::stream::Stream; -use futures_signals::{ - signal::Mutable, - signal_map::{MutableBTreeMap, MutableBTreeMapLockRef}, - signal_vec::{MutableVec, MutableVecLockMut}, -}; -use matrix_sdk_base::{deserialized_responses::SyncTimelineEvent, sync::SyncResponse}; -use ruma::{ - api::client::{ - error::ErrorKind, - sync::sync_events::v4::{ - self, AccountDataConfig, E2EEConfig, ExtensionsConfig, ReceiptConfig, ToDeviceConfig, - TypingConfig, - }, - }, - assign, - events::TimelineEventType, - OwnedRoomId, RoomId, UInt, -}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use tracing::{debug, error, instrument, trace, warn}; -use url::Url; - -#[cfg(feature = "experimental-timeline")] -use crate::room::timeline::{EventTimelineItem, Timeline}; -use crate::{config::RequestConfig, Client, Result}; - -/// Internal representation of errors in Sliding Sync -#[derive(Error, Debug)] -#[non_exhaustive] -pub enum Error { - #[error("Received response for {found} lists, yet we have {expected}.")] - BadViewsCount { found: usize, expected: usize }, - #[error("The sliding sync response could not be handled: {0}")] - BadResponse(String), - #[error("Builder went wrong: {0}")] - SlidingSyncBuilder(#[from] SlidingSyncBuilderError), -} - -/// The 
state the [`SlidingSyncView`] is in. -/// -/// The lifetime of a SlidingSync usually starts at a `Preload`, getting a fast -/// response for the first given number of Rooms, then switches into -/// `CatchingUp` during which the view fetches the remaining rooms, usually in -/// order, some times in batches. Once that is ready, it switches into `Live`. -/// -/// If the client has been offline for a while, though, the SlidingSync might -/// return back to `CatchingUp` at any point. -#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum SlidingSyncState { - /// Hasn't started yet - #[default] - Cold, - /// We are quickly preloading a preview of the most important rooms - Preload, - /// We are trying to load all remaining rooms, might be in batches - CatchingUp, - /// We are all caught up and now only sync the live responses. - Live, -} - -/// The mode by which the the [`SlidingSyncView`] is in fetching the data. -#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum SlidingSyncMode { - /// fully sync all rooms in the background, page by page of `batch_size` - #[default] - #[serde(alias = "FullSync")] - PagingFullSync, - /// fully sync all rooms in the background, with a growing window of - /// `batch_size`, - GrowingFullSync, - /// Only sync the specific windows defined - Selective, -} - -/// The Entry in the sliding sync room list per sliding sync view -#[derive(Clone, Debug, Default, Serialize, Deserialize)] -pub enum RoomListEntry { - /// This entry isn't known at this point and thus considered `Empty` - #[default] - Empty, - /// There was `OwnedRoomId` but since the server told us to invalid this - /// entry. it is considered stale - Invalidated(OwnedRoomId), - /// This Entry is followed with `OwnedRoomId` - Filled(OwnedRoomId), -} - -impl RoomListEntry { - /// Is this entry empty or invalidated? 
- pub fn empty_or_invalidated(&self) -> bool { - matches!(self, RoomListEntry::Empty | RoomListEntry::Invalidated(_)) - } - - /// The inner room_id if given - pub fn as_room_id(&self) -> Option<&RoomId> { - match &self { - RoomListEntry::Empty => None, - RoomListEntry::Invalidated(b) | RoomListEntry::Filled(b) => Some(b.as_ref()), - } - } - - fn freeze(&self) -> RoomListEntry { - match &self { - RoomListEntry::Empty => RoomListEntry::Empty, - RoomListEntry::Invalidated(b) | RoomListEntry::Filled(b) => { - RoomListEntry::Invalidated(b.clone()) - } - } - } -} - -pub type AliveRoomTimeline = Arc>; - -/// Room info as giving by the SlidingSync Feature. -#[derive(Debug, Clone)] -pub struct SlidingSyncRoom { - client: Client, - room_id: OwnedRoomId, - inner: v4::SlidingSyncRoom, - is_loading_more: Mutable, - is_cold: Arc, - prev_batch: Mutable>, - timeline: AliveRoomTimeline, -} - -#[derive(Serialize, Deserialize)] -struct FrozenSlidingSyncRoom { - room_id: OwnedRoomId, - inner: v4::SlidingSyncRoom, - prev_batch: Option, - timeline: Vec, -} - -impl From<&SlidingSyncRoom> for FrozenSlidingSyncRoom { - fn from(value: &SlidingSyncRoom) -> Self { - let locked_tl = value.timeline.lock_ref(); - let tl_len = locked_tl.len(); - // To not overflow the database, we only freeze the newest 10 items. on doing - // so, we must drop the `prev_batch` key however, as we'd otherwise - // create a gap between what we have loaded and where the - // prev_batch-key will start loading when paginating backwards. 
- let (prev_batch, timeline) = if tl_len > 10 { - let pos = tl_len - 10; - (None, locked_tl.iter().skip(pos).cloned().collect()) - } else { - (value.prev_batch.lock_ref().clone(), locked_tl.to_vec()) - }; - FrozenSlidingSyncRoom { - prev_batch, - timeline, - room_id: value.room_id.clone(), - inner: value.inner.clone(), - } - } -} - -impl SlidingSyncRoom { - fn from_frozen(val: FrozenSlidingSyncRoom, client: Client) -> Self { - let FrozenSlidingSyncRoom { room_id, inner, prev_batch, timeline } = val; - SlidingSyncRoom { - client, - room_id, - inner, - is_loading_more: Mutable::new(false), - is_cold: Arc::new(AtomicBool::new(true)), - prev_batch: Mutable::new(prev_batch), - timeline: Arc::new(MutableVec::new_with_values(timeline)), - } - } -} - -impl SlidingSyncRoom { - fn from( - client: Client, - room_id: OwnedRoomId, - mut inner: v4::SlidingSyncRoom, - timeline: Vec, - ) -> Self { - // we overwrite to only keep one copy - inner.timeline = vec![]; - Self { - client, - room_id, - is_loading_more: Mutable::new(false), - is_cold: Arc::new(AtomicBool::new(false)), - prev_batch: Mutable::new(inner.prev_batch.clone()), - timeline: Arc::new(MutableVec::new_with_values(timeline)), - inner, - } - } - - /// RoomId of this SlidingSyncRoom - pub fn room_id(&self) -> &OwnedRoomId { - &self.room_id - } - - /// Are we currently fetching more timeline events in this room? 
- pub fn is_loading_more(&self) -> bool { - *self.is_loading_more.lock_ref() - } - - /// the `prev_batch` key to fetch more timeline events for this room - pub fn prev_batch(&self) -> Option { - self.prev_batch.lock_ref().clone() - } - - /// `AliveTimeline` of this room - #[cfg(not(feature = "experimental-timeline"))] - pub fn timeline(&self) -> AliveRoomTimeline { - self.timeline.clone() - } - - /// `Timeline` of this room - #[cfg(feature = "experimental-timeline")] - pub async fn timeline(&self) -> Option { - Some(self.timeline_no_fully_read_tracking().await?.with_fully_read_tracking().await) - } - - async fn timeline_no_fully_read_tracking(&self) -> Option { - if let Some(room) = self.client.get_room(&self.room_id) { - let current_timeline = self.timeline.lock_ref().to_vec(); - let prev_batch = self.prev_batch.lock_ref().clone(); - Some(Timeline::with_events(&room, prev_batch, current_timeline).await) - } else if let Some(invited_room) = self.client.get_invited_room(&self.room_id) { - Some(Timeline::with_events(&invited_room, None, vec![]).await) - } else { - error!( - room_id = ?self.room_id, - "Room not found in client. Can't provide a timeline for it" - ); - None - } - } - - /// The latest timeline item of this room. - /// - /// Use `Timeline::latest_event` instead if you already have a timeline for - /// this `SlidingSyncRoom`. - #[cfg(feature = "experimental-timeline")] - pub async fn latest_event(&self) -> Option { - self.timeline_no_fully_read_tracking().await?.latest_event() - } - - /// This rooms name as calculated by the server, if any - pub fn name(&self) -> Option<&str> { - self.inner.name.as_deref() - } - - fn update(&mut self, room_data: &v4::SlidingSyncRoom, timeline: Vec) { - let v4::SlidingSyncRoom { - name, - initial, - limited, - is_dm, - invite_state, - unread_notifications, - required_state, - prev_batch, - .. 
- } = room_data; - - self.inner.unread_notifications = unread_notifications.clone(); - - if name.is_some() { - self.inner.name = name.clone(); - } - if initial.is_some() { - self.inner.initial = *initial; - } - if is_dm.is_some() { - self.inner.is_dm = *is_dm; - } - if !invite_state.is_empty() { - self.inner.invite_state = invite_state.clone(); - } - if !required_state.is_empty() { - self.inner.required_state = required_state.clone(); - } - - if let Some(batch) = prev_batch { - self.prev_batch.lock_mut().replace(batch.clone()); - } - - if !timeline.is_empty() { - if self.is_cold.load(Ordering::SeqCst) { - // if we come from cold storage, we hard overwrite - self.timeline.lock_mut().replace_cloned(timeline); - self.is_cold.store(false, Ordering::SeqCst); - } else if *limited { - // the server alerted us that we missed items in between - self.timeline.lock_mut().replace_cloned(timeline); - } else { - let mut ref_timeline = self.timeline.lock_mut(); - for e in timeline { - ref_timeline.push_cloned(e); - } - } - } else if *limited { - // notihing but we were alerted that we are stale. 
clear up - self.timeline.lock_mut().clear(); - } - } -} - -impl Deref for SlidingSyncRoom { - type Target = v4::SlidingSyncRoom; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -type ViewState = Mutable; -type SyncMode = Mutable; -type StringState = Mutable>; -type RangeState = Mutable>; -type RoomsCount = Mutable>; -type RoomsList = Arc>; -type RoomsMap = Arc>; -type RoomsSubscriptions = Arc>; -type RoomUnsubscribe = Arc>; -type Views = Arc>; - -use derive_builder::Builder; - -/// The Summary of a new SlidingSync Update received -#[derive(Debug, Clone)] -pub struct UpdateSummary { - /// The views (according to their name), which have seen an update - pub views: Vec, - /// The Rooms that have seen updates - pub rooms: Vec, -} - -/// Configuration for a Sliding Sync Instance -#[derive(Clone, Debug, Builder)] -#[builder( - name = "SlidingSyncBuilder", - pattern = "owned", - build_fn(name = "build_no_cache"), - derive(Clone, Debug) -)] -pub struct SlidingSyncConfig { - /// The storage key to keep this cache at and load it from - #[builder(setter(strip_option), default)] - storage_key: Option, - /// Customize the homeserver for sliding sync only - #[builder(setter(strip_option), default)] - homeserver: Option, - - /// The client this sliding sync will be using - client: Client, - #[builder(private, default)] - views: BTreeMap, - #[builder(private, default)] - extensions: Option, - #[builder(private, default)] - subscriptions: BTreeMap, -} - -impl SlidingSyncConfig { - pub async fn build(self) -> Result { - let SlidingSyncConfig { - homeserver, - storage_key, - client, - mut views, - mut extensions, - subscriptions, - } = self; - let mut delta_token_inner = None; - let mut rooms_found: BTreeMap = BTreeMap::new(); - - if let Some(storage_key) = storage_key.as_ref() { - trace!(storage_key, "trying to load from cold"); - - for (name, view) in views.iter_mut() { - if let Some(frozen_view) = client - .store() - 
.get_custom_value(format!("{storage_key}::{name}").as_bytes()) - .await? - .map(|v| serde_json::from_slice::(&v)) - .transpose()? - { - trace!(name, "frozen for view found"); - - let FrozenSlidingSyncView { rooms_count, rooms_list, rooms } = frozen_view; - view.set_from_cold(rooms_count, rooms_list); - for (key, frozen_room) in rooms.into_iter() { - rooms_found.entry(key).or_insert_with(|| { - SlidingSyncRoom::from_frozen(frozen_room, client.clone()) - }); - } - } else { - trace!(name, "no frozen state for view found"); - } - } - - if let Some(FrozenSlidingSync { to_device_since, delta_token }) = client - .store() - .get_custom_value(storage_key.as_bytes()) - .await? - .map(|v| serde_json::from_slice::(&v)) - .transpose()? - { - trace!("frozen for generic found"); - if let Some(since) = to_device_since { - if let Some(to_device_ext) = - extensions.get_or_insert_with(Default::default).to_device.as_mut() - { - to_device_ext.since = Some(since); - } - } - delta_token_inner = delta_token; - } - trace!("sync unfrozen done"); - }; - - trace!(len = rooms_found.len(), "rooms unfrozen"); - let rooms = Arc::new(MutableBTreeMap::with_values(rooms_found)); - - let views = Arc::new(MutableBTreeMap::with_values(views)); - - Ok(SlidingSync { - homeserver, - client, - storage_key, - - views, - rooms, - - extensions: Mutex::new(extensions).into(), - sent_extensions: Mutex::new(None).into(), - failure_count: Default::default(), - - pos: Mutable::new(None), - delta_token: Mutable::new(delta_token_inner), - subscriptions: Arc::new(MutableBTreeMap::with_values(subscriptions)), - unsubscribe: Default::default(), - }) - } -} - -impl SlidingSyncBuilder { - /// Convenience function to add a full-sync view to the builder - pub fn add_fullsync_view(self) -> Self { - self.add_view( - SlidingSyncViewBuilder::default_with_fullsync() - .build() - .expect("Building default full sync view doesn't fail"), - ) - } - - /// The cold cache key to read from and store the frozen state at - pub fn 
cold_cache(mut self, name: T) -> Self { - self.storage_key = Some(Some(name.to_string())); - self - } - - /// Do not use the cold cache - pub fn no_cold_cache(mut self) -> Self { - self.storage_key = None; - self - } - - /// Reset the views to `None` - pub fn no_views(mut self) -> Self { - self.views = None; - self - } - - /// Add the given view to the views. - /// - /// Replace any view with the name. - pub fn add_view(mut self, v: SlidingSyncView) -> Self { - let views = self.views.get_or_insert_with(Default::default); - views.insert(v.name.clone(), v); - self - } - - /// Activate e2ee, to-device-message and account data extensions if not yet - /// configured. - /// - /// Will leave any extension configuration found untouched, so the order - /// does not matter. - pub fn with_common_extensions(mut self) -> Self { - { - let mut cfg = self - .extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default); - if cfg.to_device.is_none() { - cfg.to_device = Some(assign!(ToDeviceConfig::default(), { enabled: Some(true) })); - } - - if cfg.e2ee.is_none() { - cfg.e2ee = Some(assign!(E2EEConfig::default(), { enabled: Some(true) })); - } - - if cfg.account_data.is_none() { - cfg.account_data = - Some(assign!(AccountDataConfig::default(), { enabled: Some(true) })); - } - } - self - } - - /// Activate e2ee, to-device-message, account data, typing and receipt - /// extensions if not yet configured. - /// - /// Will leave any extension configuration found untouched, so the order - /// does not matter. 
- pub fn with_all_extensions(mut self) -> Self { - { - let mut cfg = self - .extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default); - if cfg.to_device.is_none() { - cfg.to_device = Some(assign!(ToDeviceConfig::default(), { enabled: Some(true) })); - } - - if cfg.e2ee.is_none() { - cfg.e2ee = Some(assign!(E2EEConfig::default(), { enabled: Some(true) })); - } - - if cfg.account_data.is_none() { - cfg.account_data = - Some(assign!(AccountDataConfig::default(), { enabled: Some(true) })); - } - - if cfg.receipt.is_none() { - cfg.receipt = Some(assign!(ReceiptConfig::default(), { enabled: Some(true) })); - } - - if cfg.typing.is_none() { - cfg.typing = Some(assign!(TypingConfig::default(), { enabled: Some(true) })); - } - } - self - } - - /// Set the E2EE extension configuration. - pub fn with_e2ee_extension(mut self, e2ee: E2EEConfig) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .e2ee = Some(e2ee); - self - } - - /// Unset the E2EE extension configuration. - pub fn without_e2ee_extension(mut self) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .e2ee = None; - self - } - - /// Set the ToDevice extension configuration. - pub fn with_to_device_extension(mut self, to_device: ToDeviceConfig) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .to_device = Some(to_device); - self - } - - /// Unset the ToDevice extension configuration. - pub fn without_to_device_extension(mut self) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .to_device = None; - self - } - - /// Set the account data extension configuration. 
- pub fn with_account_data_extension(mut self, account_data: AccountDataConfig) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .account_data = Some(account_data); - self - } - - /// Unset the account data extension configuration. - pub fn without_account_data_extension(mut self) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .account_data = None; - self - } - - /// Set the Typing extension configuration. - pub fn with_typing_extension(mut self, typing: TypingConfig) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .typing = Some(typing); - self - } - - /// Unset the Typing extension configuration. - pub fn without_typing_extension(mut self) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .typing = None; - self - } - - /// Set the Receipt extension configuration. - pub fn with_receipt_extension(mut self, receipt: ReceiptConfig) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .receipt = Some(receipt); - self - } - - /// Unset the Receipt extension configuration. 
- pub fn without_receipt_extension(mut self) -> Self { - self.extensions - .get_or_insert_with(Default::default) - .get_or_insert_with(Default::default) - .receipt = None; - self - } - - /// Build the Sliding Sync - /// - /// if configured, load the cached data from cold storage - pub async fn build(self) -> Result { - self.build_no_cache().map_err(Error::SlidingSyncBuilder)?.build().await - } -} - -/// The sliding sync instance -#[derive(Clone, Debug)] -pub struct SlidingSync { - /// Customize the homeserver for sliding sync only - homeserver: Option, - - client: Client, - - /// The storage key to keep this cache at and load it from - storage_key: Option, - - // ------ Internal state - pub(crate) pos: StringState, - delta_token: StringState, - - /// The views of this sliding sync instance - pub views: Views, - - subscriptions: RoomsSubscriptions, - unsubscribe: RoomUnsubscribe, - - /// The rooms details - rooms: RoomsMap, - - /// keeping track of retries and failure counts - failure_count: Arc, - - /// the intended state of the extensions being supplied to sliding /sync - /// calls. May contain the latest next_batch for to_devices, etc. - extensions: Arc>>, - - /// the last extensions known to be successfully sent to the server. - /// if the current extensions match this, we can avoid sending them again. 
- sent_extensions: Arc>>, -} - -#[derive(Serialize, Deserialize)] -struct FrozenSlidingSync { - #[serde(skip_serializing_if = "Option::is_none")] - to_device_since: Option, - #[serde(skip_serializing_if = "Option::is_none")] - delta_token: Option, -} - -impl From<&SlidingSync> for FrozenSlidingSync { - fn from(v: &SlidingSync) -> Self { - FrozenSlidingSync { - delta_token: v.delta_token.get_cloned(), - to_device_since: v - .extensions - .lock() - .unwrap() - .as_ref() - .and_then(|ext| ext.to_device.as_ref()?.since.clone()), - } - } -} - -impl SlidingSync { - async fn cache_to_storage(&self) -> Result<()> { - let Some(storage_key) = self.storage_key.as_ref() else { return Ok(()) }; - trace!(storage_key, "saving to storage for later use"); - let v = serde_json::to_vec(&FrozenSlidingSync::from(self))?; - self.client.store().set_custom_value(storage_key.as_bytes(), v).await?; - let frozen_views = { - let rooms_lock = self.rooms.lock_ref(); - self.views - .lock_ref() - .iter() - .map(|(name, view)| { - (name.clone(), FrozenSlidingSyncView::freeze(view, &rooms_lock)) - }) - .collect::>() - }; - for (name, frozen) in frozen_views { - trace!(storage_key, name, "saving to view for later use"); - self.client - .store() - .set_custom_value( - format!("{storage_key}::{name}").as_bytes(), - serde_json::to_vec(&frozen)?, - ) - .await?; // FIXME: parallilize? 
- } - Ok(()) - } -} - -impl SlidingSync { - /// Generate a new SlidingSyncBuilder with the same inner settings and views - /// but without the current state - pub fn new_builder_copy(&self) -> SlidingSyncBuilder { - let mut builder = SlidingSyncBuilder::default() - .client(self.client.clone()) - .subscriptions(self.subscriptions.lock_ref().to_owned()); - for view in self - .views - .lock_ref() - .values() - .map(|v| v.new_builder().build().expect("builder worked before, builder works now")) - { - builder = builder.add_view(view); - } - - if let Some(h) = &self.homeserver { - builder.homeserver(h.clone()) - } else { - builder - } - } - - /// Subscribe to a given room. - /// - /// Note: this does not cancel any pending request, so make sure to only - /// poll the stream after you've altered this. If you do that during, it - /// might take one round trip to take effect. - pub fn subscribe(&self, room_id: OwnedRoomId, settings: Option) { - self.subscriptions.lock_mut().insert_cloned(room_id, settings.unwrap_or_default()); - } - - /// Unsubscribe from a given room. - /// - /// Note: this does not cancel any pending request, so make sure to only - /// poll the stream after you've altered this. If you do that during, it - /// might take one round trip to take effect. 
- pub fn unsubscribe(&self, room_id: OwnedRoomId) { - if self.subscriptions.lock_mut().remove(&room_id).is_some() { - self.unsubscribe.lock_mut().push_cloned(room_id); - } - } - - /// Add the common extensions if not already configured - pub fn add_common_extensions(&self) { - let mut lock = self.extensions.lock().unwrap(); - let mut cfg = lock.get_or_insert_with(Default::default); - if cfg.to_device.is_none() { - cfg.to_device = Some(assign!(ToDeviceConfig::default(), { enabled: Some(true) })); - } - - if cfg.e2ee.is_none() { - cfg.e2ee = Some(assign!(E2EEConfig::default(), { enabled: Some(true) })); - } - - if cfg.account_data.is_none() { - cfg.account_data = Some(assign!(AccountDataConfig::default(), { enabled: Some(true) })); - } - } - - /// Lookup a specific room - pub fn get_room(&self, room_id: &RoomId) -> Option { - self.rooms.lock_ref().get(room_id).cloned() - } - - /// Check the number of rooms. - pub fn get_number_of_rooms(&self) -> usize { - self.rooms.lock_ref().len() - } - - fn update_to_device_since(&self, since: String) { - self.extensions - .lock() - .unwrap() - .get_or_insert_with(Default::default) - .to_device - .get_or_insert_with(Default::default) - .since = Some(since); - } - - /// Get access to the SlidingSyncView named `view_name` - /// - /// Note: Remember that this list might have been changed since you started - /// listening to the stream and is therefor not necessarily up to date - /// with the views used for the stream. - pub fn view(&self, view_name: &str) -> Option { - self.views.lock_ref().get(view_name).cloned() - } - - /// Remove the SlidingSyncView named `view_name` from the views list if - /// found - /// - /// Note: Remember that this change will only be applicable for any new - /// stream created after this. 
The old stream will still continue to use the - /// previous set of views - pub fn pop_view(&self, view_name: &String) -> Option { - self.views.lock_mut().remove(view_name) - } - - /// Add the view to the list of views - /// - /// As views need to have a unique `.name`, if a view with the same name - /// is found the new view will replace the old one and the return it or - /// `None`. - /// - /// Note: Remember that this change will only be applicable for any new - /// stream created after this. The old stream will still continue to use the - /// previous set of views - pub fn add_view(&self, view: SlidingSyncView) -> Option { - self.views.lock_mut().insert_cloned(view.name.clone(), view) - } - - /// Lookup a set of rooms - pub fn get_rooms>( - &self, - room_ids: I, - ) -> Vec> { - let rooms = self.rooms.lock_ref(); - room_ids.map(|room_id| rooms.get(&room_id).cloned()).collect() - } - - /// Get all rooms. - pub fn get_all_rooms(&self) -> Vec { - self.rooms.lock_ref().iter().map(|(_, room)| room.clone()).collect() - } - - #[instrument(skip_all, fields(views = views.len()))] - async fn handle_response( - &self, - resp: v4::Response, - extensions: Option, - views: &mut BTreeMap, - ) -> Result { - let mut processed = self.client.process_sliding_sync(resp.clone()).await?; - debug!("main client processed."); - self.pos.replace(Some(resp.pos)); - self.delta_token.replace(resp.delta_token); - let update = { - let mut rooms = Vec::new(); - let mut rooms_map = self.rooms.lock_mut(); - for (id, mut room_data) in resp.rooms.into_iter() { - let timeline = if let Some(joined_room) = processed.rooms.join.remove(&id) { - joined_room.timeline.events - } else { - let events = room_data.timeline.into_iter().map(Into::into).collect(); - room_data.timeline = vec![]; - events - }; - - if let Some(mut r) = rooms_map.remove(&id) { - r.update(&room_data, timeline); - rooms_map.insert_cloned(id.clone(), r); - rooms.push(id.clone()); - } else { - rooms_map.insert_cloned( - id.clone(), - 
SlidingSyncRoom::from(self.client.clone(), id.clone(), room_data, timeline), - ); - rooms.push(id); - } - } - - let mut updated_views = Vec::new(); - - for (name, updates) in resp.lists { - let Some(generator) = views.get_mut(&name) else { - error!("Response for view {name} - unknown to us. skipping"); - continue - }; - let count: u32 = - updates.count.try_into().expect("the list total count convertible into u32"); - if generator.handle_response(count, &updates.ops, &rooms)? { - updated_views.push(name.clone()); - } - } - - // Update the `to-device` next-batch if found. - if let Some(to_device_since) = resp.extensions.to_device.map(|t| t.next_batch) { - self.update_to_device_since(to_device_since) - } - - // track the most recently successfully sent extensions (needed for sticky - // semantics) - if extensions.is_some() { - *self.sent_extensions.lock().unwrap() = extensions; - } - - UpdateSummary { views: updated_views, rooms } - }; - - self.cache_to_storage().await?; - - Ok(update) - } - - /// Create the inner stream for the view. - /// - /// Run this stream to receive new updates from the server. - pub fn stream(&self) -> impl Stream> + '_ { - let mut views = { - let mut views = BTreeMap::new(); - let views_lock = self.views.lock_ref(); - for (name, view) in views_lock.deref().iter() { - views.insert(name.clone(), view.request_generator()); - } - views - }; - let client = self.client.clone(); - - debug!(?self.extensions, "Setting view stream going"); - async_stream::stream! 
{ - - loop { - debug!(?self.extensions, "Sync loop running"); - - let mut requests = BTreeMap::new(); - let mut to_remove = Vec::new(); - - for (name, generator) in views.iter_mut() { - if let Some(request) = generator.next() { - requests.insert(name.clone(), request); - } else { - to_remove.push(name.clone()); - } - } - for n in to_remove { - views.remove(&n); - } - - if views.is_empty() { - return - } - - let pos = self.pos.get_cloned(); - let delta_token = self.delta_token.get_cloned(); - let room_subscriptions = self.subscriptions.lock_ref().clone(); - let unsubscribe_rooms = { - let unsubs = self.unsubscribe.lock_ref().to_vec(); - if !unsubs.is_empty() { - self.unsubscribe.lock_mut().clear(); - } - unsubs - }; - let timeout = Duration::from_secs(30); - - // implement stickiness by only sending extensions if they have - // changed since the last time we sent them - let extensions = { - let extensions = self.extensions.lock().unwrap(); - if *extensions == *self.sent_extensions.lock().unwrap() { - None - } else { - extensions.clone() - } - }; - - let req = assign!(v4::Request::new(), { - lists: requests, - pos, - delta_token, - timeout: Some(timeout), - room_subscriptions, - unsubscribe_rooms, - extensions: extensions.clone().unwrap_or_default(), - }); - debug!("requesting"); - - // 30s for the long poll + 30s for network delays - let request_config = RequestConfig::default().timeout(timeout + Duration::from_secs(30)); - let req = client.send_with_homeserver(req, Some(request_config), self.homeserver.as_ref().map(ToString::to_string)); - - #[cfg(feature = "e2e-encryption")] - let resp_res = { - let (e2ee_uploads, resp) = futures_util::join!(client.send_outgoing_requests(), req); - if let Err(e) = e2ee_uploads { - error!(error = ?e, "Error while sending outgoing E2EE requests"); - } - resp - }; - #[cfg(not(feature = "e2e-encryption"))] - let resp_res = req.await; - - let resp = match resp_res { - Ok(r) => { - self.failure_count.store(0, Ordering::SeqCst); - r - }, 
- Err(e) => { - if e.client_api_error_kind() == Some(&ErrorKind::UnknownPos) { - // session expired, let's reset - if self.failure_count.fetch_add(1, Ordering::SeqCst) >= 3 { - error!("session expired three times in a row"); - yield Err(e.into()); - break - } - warn!("Session expired. Restarting sliding sync."); - *self.pos.lock_mut() = None; - - // reset our extensions to the last known good ones. - *self.extensions.lock().unwrap() = self.sent_extensions.lock().unwrap().take(); - - debug!(?self.extensions, "Resetting view stream"); - } - yield Err(e.into()); - continue - } - }; - - debug!("received"); - - let updates = match self.handle_response(resp, extensions, &mut views).await { - Ok(r) => r, - Err(e) => { - yield Err(e.into()); - continue - } - }; - debug!("handled"); - yield Ok(updates); - } - } - } -} - -/// Holding a specific filtered view within the concept of sliding sync. -/// Main entrypoint to the SlidingSync -/// -/// -/// ```no_run -/// # use futures::executor::block_on; -/// # use matrix_sdk::Client; -/// # use url::Url; -/// # block_on(async { -/// # let homeserver = Url::parse("http://example.com")?; -/// let client = Client::new(homeserver).await?; -/// let sliding_sync = client.sliding_sync().default_with_fullsync().build()?; -/// -/// # }) -/// ``` -#[derive(Clone, Debug, Builder)] -#[builder(build_fn(name = "finish_build"), pattern = "owned", derive(Clone, Debug))] -pub struct SlidingSyncView { - /// Which SyncMode to start this view under - #[builder(setter(custom), default)] - sync_mode: SyncMode, - - /// Sort the rooms list by this - #[builder(default = "SlidingSyncViewBuilder::default_sort()")] - sort: Vec, - - /// Required states to return per room - #[builder(default = "SlidingSyncViewBuilder::default_required_state()")] - required_state: Vec<(TimelineEventType, String)>, - - /// How many rooms request at a time when doing a full-sync catch up - #[builder(default = "20")] - batch_size: u32, - - /// Whether the view should send 
`UpdatedAt`-Diff signals for rooms - /// that have changed - #[builder(default = "false")] - send_updates_for_items: bool, - - /// How many rooms request a total hen doing a full-sync catch up - #[builder(setter(into), default)] - limit: Option, - - /// Any filters to apply to the query - #[builder(default)] - filters: Option, - - /// The maximum number of timeline events to query for - #[builder(setter(name = "timeline_limit_raw"), default)] - pub timeline_limit: Mutable>, - - // ----- Public state - /// Name of this view to easily recognise them - #[builder(setter(into))] - pub name: String, - - /// The state this view is in - #[builder(private, default)] - pub state: ViewState, - - /// The total known number of rooms, - #[builder(private, default)] - pub rooms_count: RoomsCount, - - /// The rooms in order - #[builder(private, default)] - pub rooms_list: RoomsList, - - /// The ranges windows of the view - #[builder(setter(name = "ranges_raw"), default)] - ranges: RangeState, - - /// Signaling updates on the roomlist after processing - #[builder(private)] - rooms_updated_signal: futures_signals::signal::Sender<()>, - - #[builder(private)] - is_cold: Arc, - - /// Get informed if anything in the room changed - /// - /// If you only care to know about changes once all of them have applied - /// (including the total) listen to a clone of this signal. 
- #[builder(private)] - pub rooms_updated_broadcaster: - futures_signals::signal::Broadcaster>, -} - -#[derive(Serialize, Deserialize)] -struct FrozenSlidingSyncView { - #[serde(default, skip_serializing_if = "Option::is_none")] - rooms_count: Option, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - rooms_list: Vec, - #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] - rooms: BTreeMap, -} - -impl FrozenSlidingSyncView { - fn freeze( - source_view: &SlidingSyncView, - rooms_map: &MutableBTreeMapLockRef<'_, OwnedRoomId, SlidingSyncRoom>, - ) -> Self { - let mut rooms = BTreeMap::new(); - let mut rooms_list = Vec::new(); - for entry in source_view.rooms_list.lock_ref().iter() { - match entry { - RoomListEntry::Filled(o) | RoomListEntry::Invalidated(o) => { - rooms.insert(o.clone(), rooms_map.get(o).expect("rooms always exists").into()); - } - _ => {} - }; - - rooms_list.push(entry.freeze()); - } - FrozenSlidingSyncView { - rooms_count: *source_view.rooms_count.lock_ref(), - rooms_list, - rooms, - } - } -} - -impl SlidingSyncView { - fn set_from_cold(&mut self, rooms_count: Option, rooms_list: Vec) { - self.state.set(SlidingSyncState::Preload); - self.is_cold.store(true, Ordering::SeqCst); - self.rooms_count.replace(rooms_count); - self.rooms_list.lock_mut().replace_cloned(rooms_list); - } -} - -// /// the default name for the full sync view -pub const FULL_SYNC_VIEW_NAME: &str = "full-sync"; - -impl SlidingSyncViewBuilder { - /// Create a Builder set up for full sync - pub fn default_with_fullsync() -> Self { - Self::default().name(FULL_SYNC_VIEW_NAME).sync_mode(SlidingSyncMode::PagingFullSync) - } - - /// Build the view - pub fn build(mut self) -> Result { - let (sender, receiver) = futures_signals::signal::channel(()); - self.is_cold = Some(Arc::new(AtomicBool::new(false))); - self.rooms_updated_signal = Some(sender); - self.rooms_updated_broadcaster = Some(futures_signals::signal::Broadcaster::new(receiver)); - self.finish_build() - } - - fn 
default_sort() -> Vec { - vec!["by_recency".to_owned(), "by_name".to_owned()] - } - - fn default_required_state() -> Vec<(TimelineEventType, String)> { - vec![ - (TimelineEventType::RoomEncryption, "".to_owned()), - (TimelineEventType::RoomTombstone, "".to_owned()), - ] - } - - /// Set the Syncing mode - pub fn sync_mode(mut self, sync_mode: SlidingSyncMode) -> Self { - self.sync_mode = Some(SyncMode::new(sync_mode)); - self - } - - /// Set the ranges to fetch - pub fn ranges>(mut self, range: Vec<(U, U)>) -> Self { - self.ranges = - Some(RangeState::new(range.into_iter().map(|(a, b)| (a.into(), b.into())).collect())); - self - } - - /// Set a single range fetch - pub fn set_range>(mut self, from: U, to: U) -> Self { - self.ranges = Some(RangeState::new(vec![(from.into(), to.into())])); - self - } - - /// Set the ranges to fetch - pub fn add_range>(mut self, from: U, to: U) -> Self { - let r = self.ranges.get_or_insert_with(|| RangeState::new(Vec::new())); - r.lock_mut().push((from.into(), to.into())); - self - } - - /// Set the ranges to fetch - pub fn reset_ranges(mut self) -> Self { - self.ranges = None; - self - } - - /// Set the limit of regular events to fetch for the timeline. - pub fn timeline_limit>(mut self, timeline_limit: U) -> Self { - self.timeline_limit = Some(Mutable::new(Some(timeline_limit.into()))); - self - } - - /// Reset the limit of regular events to fetch for the timeline. 
It is left - /// to the server to decide how many to send back - pub fn no_timeline_limit(mut self) -> Self { - self.timeline_limit = None; - self - } -} - -enum InnerSlidingSyncViewRequestGenerator { - GrowingFullSync { position: u32, batch_size: u32, limit: Option, live: bool }, - PagingFullSync { position: u32, batch_size: u32, limit: Option, live: bool }, - Live, -} - -struct SlidingSyncViewRequestGenerator { - view: SlidingSyncView, - ranges: Vec<(usize, usize)>, - inner: InnerSlidingSyncViewRequestGenerator, -} - -impl SlidingSyncViewRequestGenerator { - fn new_with_paging_syncup(view: SlidingSyncView) -> Self { - let batch_size = view.batch_size; - let limit = view.limit; - let position = view - .ranges - .get_cloned() - .first() - .map(|(_start, end)| u32::try_from(*end).unwrap()) - .unwrap_or_default(); - - SlidingSyncViewRequestGenerator { - view, - ranges: Default::default(), - inner: InnerSlidingSyncViewRequestGenerator::PagingFullSync { - position, - batch_size, - limit, - live: false, - }, - } - } - - fn new_with_growing_syncup(view: SlidingSyncView) -> Self { - let batch_size = view.batch_size; - let limit = view.limit; - let position = view - .ranges - .get_cloned() - .first() - .map(|(_start, end)| u32::try_from(*end).unwrap()) - .unwrap_or_default(); - - SlidingSyncViewRequestGenerator { - view, - ranges: Default::default(), - inner: InnerSlidingSyncViewRequestGenerator::GrowingFullSync { - position, - batch_size, - limit, - live: false, - }, - } - } - - fn new_live(view: SlidingSyncView) -> Self { - SlidingSyncViewRequestGenerator { - view, - ranges: Default::default(), - inner: InnerSlidingSyncViewRequestGenerator::Live, - } - } - - fn prefetch_request( - &mut self, - start: u32, - batch_size: u32, - limit: Option, - ) -> v4::SyncRequestList { - let calc_end = start + batch_size; - let end = match limit { - Some(l) => std::cmp::min(l, calc_end), - _ => calc_end, - }; - self.make_request_for_ranges(vec![(start.into(), end.into())]) - } - - 
#[instrument(skip(self), fields(name = self.view.name))] - fn make_request_for_ranges(&mut self, ranges: Vec<(UInt, UInt)>) -> v4::SyncRequestList { - let sort = self.view.sort.clone(); - let required_state = self.view.required_state.clone(); - let timeline_limit = self.view.timeline_limit.get_cloned(); - let filters = self.view.filters.clone(); - - self.ranges = ranges - .iter() - .map(|(a, b)| { - ( - usize::try_from(*a).expect("range is a valid u32"), - usize::try_from(*b).expect("range is a valid u32"), - ) - }) - .collect(); - - assign!(v4::SyncRequestList::default(), { - ranges: ranges, - room_details: assign!(v4::RoomDetailsConfig::default(), { - required_state, - timeline_limit, - }), - sort, - filters, - }) - } - - // generate the next live request - fn live_request(&mut self) -> v4::SyncRequestList { - let ranges = self.view.ranges.read_only().get_cloned(); - self.make_request_for_ranges(ranges) - } - - #[instrument(skip_all, fields(name = self.view.name, rooms_count, has_ops = !ops.is_empty()))] - fn handle_response( - &mut self, - rooms_count: u32, - ops: &Vec, - rooms: &Vec, - ) -> Result { - let res = self.view.handle_response(rooms_count, ops, &self.ranges, rooms)?; - self.update_state(rooms_count.saturating_sub(1)); // index is 0 based, count is 1 based - Ok(res) - } - - fn update_state(&mut self, max_index: u32) { - let Some((_start, range_end)) = self.ranges.first() else { - error!("Why don't we have any ranges?"); - return - }; - - let end = if &(max_index as usize) < range_end { max_index } else { *range_end as u32 }; - - trace!(end, max_index, range_end, name = self.view.name, "updating state"); - - match &mut self.inner { - InnerSlidingSyncViewRequestGenerator::PagingFullSync { - position, live, limit, .. - } - | InnerSlidingSyncViewRequestGenerator::GrowingFullSync { - position, live, limit, .. 
- } => { - let max = limit - .map(|limit| if limit > max_index { max_index } else { limit }) - .unwrap_or(max_index); - trace!(end, max, name = self.view.name, "updating state"); - if end >= max { - trace!(name = self.view.name, "going live"); - // we are switching to live mode - self.view.set_range(0, max); - *position = max; - *live = true; - - self.view.state.set_if(SlidingSyncState::Live, |before, _now| { - !matches!(before, SlidingSyncState::Live) - }); - } else { - *position = end; - *live = false; - self.view.set_range(0, end); - self.view.state.set_if(SlidingSyncState::CatchingUp, |before, _now| { - !matches!(before, SlidingSyncState::CatchingUp) - }); - } - } - InnerSlidingSyncViewRequestGenerator::Live => { - self.view.state.set_if(SlidingSyncState::Live, |before, _now| { - !matches!(before, SlidingSyncState::Live) - }); - } - } - } -} - -impl Iterator for SlidingSyncViewRequestGenerator { - type Item = v4::SyncRequestList; - - fn next(&mut self) -> Option { - match self.inner { - InnerSlidingSyncViewRequestGenerator::PagingFullSync { live, .. } - | InnerSlidingSyncViewRequestGenerator::GrowingFullSync { live, .. } - if live => - { - Some(self.live_request()) - } - InnerSlidingSyncViewRequestGenerator::PagingFullSync { - position, - batch_size, - limit, - .. - } => Some(self.prefetch_request(position, batch_size, limit)), - InnerSlidingSyncViewRequestGenerator::GrowingFullSync { - position, - batch_size, - limit, - .. 
- } => Some(self.prefetch_request(0, position + batch_size, limit)), - InnerSlidingSyncViewRequestGenerator::Live => Some(self.live_request()), - } - } -} - -#[instrument(skip(ops))] -fn room_ops( - rooms_list: &mut MutableVecLockMut<'_, RoomListEntry>, - ops: &Vec, - room_ranges: &Vec<(usize, usize)>, -) -> Result<(), Error> { - let index_in_range = |idx| room_ranges.iter().any(|(start, end)| idx >= *start && idx <= *end); - for op in ops { - match &op.op { - v4::SlidingOp::Sync => { - let start: u32 = op - .range - .ok_or_else(|| { - Error::BadResponse( - "`range` must be present for Sync and Update operation".to_owned(), - ) - })? - .0 - .try_into() - .map_err(|e| Error::BadResponse(format!("`range` not a valid int: {e:}")))?; - let room_ids = op.room_ids.clone(); - room_ids - .into_iter() - .enumerate() - .map(|(i, r)| { - let idx = start as usize + i; - if idx >= rooms_list.len() { - rooms_list.insert_cloned(idx, RoomListEntry::Filled(r)); - } else { - rooms_list.set_cloned(idx, RoomListEntry::Filled(r)); - } - }) - .count(); - } - v4::SlidingOp::Delete => { - let pos: u32 = op - .index - .ok_or_else(|| { - Error::BadResponse( - "`index` must be present for DELETE operation".to_owned(), - ) - })? - .try_into() - .map_err(|e| { - Error::BadResponse(format!("`index` not a valid int for DELETE: {e:}")) - })?; - rooms_list.set_cloned(pos as usize, RoomListEntry::Empty); - } - v4::SlidingOp::Insert => { - let pos: usize = op - .index - .ok_or_else(|| { - Error::BadResponse( - "`index` must be present for INSERT operation".to_owned(), - ) - })? 
- .try_into() - .map_err(|e| { - Error::BadResponse(format!("`index` not a valid int for INSERT: {e:}")) - })?; - let sliced = rooms_list.as_slice(); - let room = RoomListEntry::Filled(op.room_id.clone().ok_or_else(|| { - Error::BadResponse("`room_id` must be present for INSERT operation".to_owned()) - })?); - let mut dif = 0usize; - loop { - // find the next empty slot and drop it - let (prev_p, prev_overflow) = pos.overflowing_sub(dif); - let check_prev = !prev_overflow && index_in_range(prev_p); - let (next_p, overflown) = pos.overflowing_add(dif); - let check_after = !overflown && next_p < sliced.len() && index_in_range(next_p); - if !check_prev && !check_after { - return Err(Error::BadResponse( - "We were asked to insert but could not find any direction to shift to" - .to_owned(), - )); - } - - if check_prev && sliced[prev_p].empty_or_invalidated() { - // we only check for previous, if there are items left - rooms_list.remove(prev_p); - break; - } else if check_after && sliced[next_p].empty_or_invalidated() { - rooms_list.remove(next_p); - break; - } else { - // let's check the next position; - dif += 1; - } - } - rooms_list.insert_cloned(pos, room); - } - v4::SlidingOp::Invalidate => { - let max_len = rooms_list.len(); - let (mut pos, end): (u32, u32) = if let Some(range) = op.range { - ( - range.0.try_into().map_err(|e| { - Error::BadResponse(format!("`range.0` not a valid int: {e:}")) - })?, - range.1.try_into().map_err(|e| { - Error::BadResponse(format!("`range.1` not a valid int: {e:}")) - })?, - ) - } else { - return Err(Error::BadResponse( - "`range` must be given on `Invalidate` operation".to_owned(), - )); - }; - - if pos > end { - return Err(Error::BadResponse( - "Invalid invalidation, end smaller than start".to_owned(), - )); - } - - // ranges are inclusive up to the last index. e.g. `[0, 10]`; `[0, 0]`. - // ensure we pick them all up - while pos <= end { - if pos as usize >= max_len { - break; // how does this happen? 
- } - let idx = pos as usize; - let entry = if let Some(RoomListEntry::Filled(b)) = rooms_list.get(idx) { - Some(b.clone()) - } else { - None - }; - - if let Some(b) = entry { - rooms_list.set_cloned(pos as usize, RoomListEntry::Invalidated(b)); - } else { - rooms_list.set_cloned(pos as usize, RoomListEntry::Empty); - } - pos += 1; - } - } - s => { - warn!("Unknown operation occurred: {:?}", s); - } - } - } - - Ok(()) -} - -impl SlidingSyncView { - /// Return a builder with the same settings as before - pub fn new_builder(&self) -> SlidingSyncViewBuilder { - SlidingSyncViewBuilder::default() - .name(&self.name) - .sync_mode(self.sync_mode.lock_ref().clone()) - .sort(self.sort.clone()) - .required_state(self.required_state.clone()) - .batch_size(self.batch_size) - .ranges(self.ranges.read_only().get_cloned()) - } - - /// Set the ranges to fetch - /// - /// Remember to cancel the existing stream and fetch a new one as this will - /// only be applied on the next request. - pub fn set_ranges(&self, range: Vec<(u32, u32)>) -> &Self { - *self.ranges.lock_mut() = range.into_iter().map(|(a, b)| (a.into(), b.into())).collect(); - self - } - - /// Reset the ranges to a particular set - /// - /// Remember to cancel the existing stream and fetch a new one as this will - /// only be applied on the next request. - pub fn set_range(&self, start: u32, end: u32) -> &Self { - *self.ranges.lock_mut() = vec![(start.into(), end.into())]; - self - } - - /// Set the ranges to fetch - /// - /// Remember to cancel the existing stream and fetch a new one as this will - /// only be applied on the next request. 
- pub fn add_range(&self, start: u32, end: u32) -> &Self { - self.ranges.lock_mut().push((start.into(), end.into())); - self - } - - /// Set the ranges to fetch - /// - /// Note: sending an empty list of ranges is, according to the spec, to be - /// understood that the consumer doesn't care about changes of the room - /// order but you will only receive updates when for rooms entering or - /// leaving the set. - /// - /// Remember to cancel the existing stream and fetch a new one as this will - /// only be applied on the next request. - pub fn reset_ranges(&self) -> &Self { - self.ranges.lock_mut().clear(); - self - } - - /// Find the current valid position of the room in the view room_list. - /// - /// Only matches against the current ranges and only against filled items. - /// Invalid items are ignore. Return the total position the item was - /// found in the room_list, return None otherwise. - pub fn find_room_in_view(&self, room_id: &RoomId) -> Option { - let ranges = self.ranges.lock_ref(); - let listing = self.rooms_list.lock_ref(); - for (start_uint, end_uint) in ranges.iter() { - let mut cur_pos: usize = (*start_uint).try_into().unwrap(); - let end: usize = (*end_uint).try_into().unwrap(); - let iterator = listing.iter().skip(cur_pos); - for n in iterator { - if let RoomListEntry::Filled(r) = n { - if room_id == r { - return Some(cur_pos); - } - } - if cur_pos == end { - break; - } - cur_pos += 1; - } - } - None - } - - /// Find the current valid position of the rooms in the views room_list. - /// - /// Only matches against the current ranges and only against filled items. - /// Invalid items are ignore. Return the total position the items that were - /// found in the room_list, will skip any room not found in the rooms_list. 
- pub fn find_rooms_in_view(&self, room_ids: &[OwnedRoomId]) -> Vec<(usize, OwnedRoomId)> { - let ranges = self.ranges.lock_ref(); - let listing = self.rooms_list.lock_ref(); - let mut rooms_found = Vec::new(); - for (start_uint, end_uint) in ranges.iter() { - let mut cur_pos: usize = (*start_uint).try_into().unwrap(); - let end: usize = (*end_uint).try_into().unwrap(); - let iterator = listing.iter().skip(cur_pos); - for n in iterator { - if let RoomListEntry::Filled(r) = n { - if room_ids.contains(r) { - rooms_found.push((cur_pos, r.clone())); - } - } - if cur_pos == end { - break; - } - cur_pos += 1; - } - } - rooms_found - } - - /// Return the room_id at the given index - pub fn get_room_id(&self, index: usize) -> Option { - self.rooms_list.lock_ref().get(index).and_then(|e| e.as_room_id().map(ToOwned::to_owned)) - } - - #[instrument(skip(self, ops), fields(name = self.name, ops_count = ops.len()))] - fn handle_response( - &self, - rooms_count: u32, - ops: &Vec, - ranges: &Vec<(usize, usize)>, - rooms: &Vec, - ) -> Result { - let current_rooms_count = self.rooms_count.get(); - if current_rooms_count.is_none() - || current_rooms_count == Some(0) - || self.is_cold.load(Ordering::SeqCst) - { - debug!("first run, replacing roomslist"); - // first response, we do that slightly differentely - let rooms_list = - MutableVec::new_with_values(vec![RoomListEntry::Empty; rooms_count as usize]); - // then we apply it - let mut locked = rooms_list.lock_mut(); - room_ops(&mut locked, ops, ranges)?; - self.rooms_list.lock_mut().replace_cloned(locked.as_slice().to_vec()); - self.rooms_count.set(Some(rooms_count)); - self.is_cold.store(false, Ordering::SeqCst); - return Ok(true); - } - - debug!("regular update"); - let mut missing = - rooms_count.checked_sub(self.rooms_list.lock_ref().len() as u32).unwrap_or_default(); - let mut changed = false; - if missing > 0 { - let mut list = self.rooms_list.lock_mut(); - list.reserve_exact(missing as usize); - while missing > 0 { - 
list.push_cloned(RoomListEntry::Empty); - missing -= 1; - } - changed = true; - } - - { - // keep the lock scoped so that the later find_rooms_in_view doesn't deadlock - let mut rooms_list = self.rooms_list.lock_mut(); - - if !ops.is_empty() { - room_ops(&mut rooms_list, ops, ranges)?; - changed = true; - } else { - debug!("no rooms operations found"); - } - } - - if self.rooms_count.get() != Some(rooms_count) { - self.rooms_count.set(Some(rooms_count)); - changed = true; - } - - if self.send_updates_for_items && !rooms.is_empty() { - let found_views = self.find_rooms_in_view(rooms); - if !found_views.is_empty() { - debug!("room details found"); - let mut rooms_list = self.rooms_list.lock_mut(); - for (pos, room_id) in found_views { - // trigger an `UpdatedAt` update - rooms_list.set_cloned(pos, RoomListEntry::Filled(room_id)); - changed = true; - } - } - } - - if changed { - if let Err(e) = self.rooms_updated_signal.send(()) { - warn!("Could not inform about rooms updated: {:?}", e); - } - } - - Ok(changed) - } - - fn request_generator(&self) -> SlidingSyncViewRequestGenerator { - match self.sync_mode.read_only().get_cloned() { - SlidingSyncMode::PagingFullSync => { - SlidingSyncViewRequestGenerator::new_with_paging_syncup(self.clone()) - } - SlidingSyncMode::GrowingFullSync => { - SlidingSyncViewRequestGenerator::new_with_growing_syncup(self.clone()) - } - SlidingSyncMode::Selective => SlidingSyncViewRequestGenerator::new_live(self.clone()), - } - } -} - -impl Client { - /// Create a SlidingSyncBuilder tied to this client - pub async fn sliding_sync(&self) -> SlidingSyncBuilder { - SlidingSyncBuilder::default().client(self.clone()) - } - - #[instrument(skip(self, response))] - pub(crate) async fn process_sliding_sync( - &self, - response: v4::Response, - ) -> Result { - let response = self.base_client().process_sliding_sync(response).await?; - debug!("done processing on base_client"); - self.handle_sync_response(&response).await?; - Ok(response) - } -} - 
-#[cfg(test)] -mod test { - use ruma::room_id; - use serde_json::json; - - use super::*; - - #[tokio::test] - async fn check_find_room_in_view() -> Result<()> { - let view = SlidingSyncViewBuilder::default() - .name("testview") - .add_range(0u32, 9u32) - .build() - .unwrap(); - let full_window_update: v4::SyncOp = serde_json::from_value(json! ({ - "op": "SYNC", - "range": [0, 9], - "room_ids": [ - "!A00000:matrix.example", - "!A00001:matrix.example", - "!A00002:matrix.example", - "!A00003:matrix.example", - "!A00004:matrix.example", - "!A00005:matrix.example", - "!A00006:matrix.example", - "!A00007:matrix.example", - "!A00008:matrix.example", - "!A00009:matrix.example" - ], - })) - .unwrap(); - - view.handle_response(10u32, &vec![full_window_update], &vec![(0, 9)], &vec![]).unwrap(); - - let a02 = room_id!("!A00002:matrix.example").to_owned(); - let a05 = room_id!("!A00005:matrix.example").to_owned(); - let a09 = room_id!("!A00009:matrix.example").to_owned(); - - assert_eq!(view.find_room_in_view(&a02), Some(2)); - assert_eq!(view.find_room_in_view(&a05), Some(5)); - assert_eq!(view.find_room_in_view(&a09), Some(9)); - - assert_eq!( - view.find_rooms_in_view(&[a02.clone(), a05.clone(), a09.clone()]), - vec![(2, a02.clone()), (5, a05.clone()), (9, a09.clone())] - ); - - // we invalidate a few in the center - let update: v4::SyncOp = serde_json::from_value(json! 
({ - "op": "INVALIDATE", - "range": [4, 7], - })) - .unwrap(); - - view.handle_response(10u32, &vec![update], &vec![(0, 3), (8, 9)], &vec![]).unwrap(); - - assert_eq!(view.find_room_in_view(room_id!("!A00002:matrix.example")), Some(2)); - assert_eq!(view.find_room_in_view(room_id!("!A00005:matrix.example")), None); - assert_eq!(view.find_room_in_view(room_id!("!A00009:matrix.example")), Some(9)); - - assert_eq!( - view.find_rooms_in_view(&[a02.clone(), a05, a09.clone()]), - vec![(2, a02), (9, a09)] - ); - - Ok(()) - } -} diff --git a/crates/matrix-sdk/src/sliding_sync/builder.rs b/crates/matrix-sdk/src/sliding_sync/builder.rs new file mode 100644 index 00000000000..1147874876d --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/builder.rs @@ -0,0 +1,313 @@ +use std::{ + collections::BTreeMap, + fmt::Debug, + sync::{Mutex, RwLock as StdRwLock}, +}; + +use eyeball::unique::Observable; +use ruma::{ + api::client::sync::sync_events::v4::{ + self, AccountDataConfig, E2EEConfig, ExtensionsConfig, ReceiptsConfig, ToDeviceConfig, + TypingConfig, + }, + assign, OwnedRoomId, +}; +use tracing::trace; +use url::Url; + +use super::{ + Error, FrozenSlidingSync, FrozenSlidingSyncList, SlidingSync, SlidingSyncInner, + SlidingSyncList, SlidingSyncListBuilder, SlidingSyncPositionMarkers, SlidingSyncRoom, +}; +use crate::{Client, Result}; + +/// Configuration for a Sliding Sync instance. +/// +/// Get a new builder with methods like [`crate::Client::sliding_sync`], or +/// [`crate::SlidingSync::builder`]. +#[derive(Clone, Debug)] +pub struct SlidingSyncBuilder { + storage_key: Option, + homeserver: Option, + client: Option, + lists: BTreeMap, + extensions: Option, + subscriptions: BTreeMap, +} + +impl SlidingSyncBuilder { + pub(super) fn new() -> Self { + Self { + storage_key: None, + homeserver: None, + client: None, + lists: BTreeMap::new(), + extensions: None, + subscriptions: BTreeMap::new(), + } + } + + /// Set the storage key to keep this cache at and load it from. 
+ pub fn storage_key(mut self, value: Option) -> Self { + self.storage_key = value; + self + } + + /// Set the homeserver for sliding sync only. + pub fn homeserver(mut self, value: Url) -> Self { + self.homeserver = Some(value); + self + } + + /// Set the client this sliding sync will be using. + pub fn client(mut self, value: Client) -> Self { + self.client = Some(value); + self + } + + pub(super) fn subscriptions( + mut self, + value: BTreeMap, + ) -> Self { + self.subscriptions = value; + self + } + + /// Convenience function to add a full-sync list to the builder + pub fn add_fullsync_list(self) -> Self { + self.add_list( + SlidingSyncListBuilder::default_with_fullsync() + .build() + .expect("Building default full sync list doesn't fail"), + ) + } + + /// The cold cache key to read from and store the frozen state at + pub fn cold_cache(mut self, name: T) -> Self { + self.storage_key = Some(name.to_string()); + self + } + + /// Do not use the cold cache + pub fn no_cold_cache(mut self) -> Self { + self.storage_key = None; + self + } + + /// Reset the lists to `None` + pub fn no_lists(mut self) -> Self { + self.lists.clear(); + self + } + + /// Add the given list to the lists. + /// + /// Replace any list with the name. + pub fn add_list(mut self, list: SlidingSyncList) -> Self { + self.lists.insert(list.name.clone(), list); + + self + } + + /// Activate e2ee, to-device-message and account data extensions if not yet + /// configured. + /// + /// Will leave any extension configuration found untouched, so the order + /// does not matter. 
+ pub fn with_common_extensions(mut self) -> Self { + { + let mut cfg = self.extensions.get_or_insert_with(Default::default); + if cfg.to_device.is_none() { + cfg.to_device = Some(assign!(ToDeviceConfig::default(), { enabled: Some(true) })); + } + + if cfg.e2ee.is_none() { + cfg.e2ee = Some(assign!(E2EEConfig::default(), { enabled: Some(true) })); + } + + if cfg.account_data.is_none() { + cfg.account_data = + Some(assign!(AccountDataConfig::default(), { enabled: Some(true) })); + } + } + self + } + + /// Activate e2ee, to-device-message, account data, typing and receipt + /// extensions if not yet configured. + /// + /// Will leave any extension configuration found untouched, so the order + /// does not matter. + pub fn with_all_extensions(mut self) -> Self { + { + let mut cfg = self.extensions.get_or_insert_with(Default::default); + if cfg.to_device.is_none() { + cfg.to_device = Some(assign!(ToDeviceConfig::default(), { enabled: Some(true) })); + } + + if cfg.e2ee.is_none() { + cfg.e2ee = Some(assign!(E2EEConfig::default(), { enabled: Some(true) })); + } + + if cfg.account_data.is_none() { + cfg.account_data = + Some(assign!(AccountDataConfig::default(), { enabled: Some(true) })); + } + + if cfg.receipts.is_none() { + cfg.receipts = Some(assign!(ReceiptsConfig::default(), { enabled: Some(true) })); + } + + if cfg.typing.is_none() { + cfg.typing = Some(assign!(TypingConfig::default(), { enabled: Some(true) })); + } + } + self + } + + /// Set the E2EE extension configuration. + pub fn with_e2ee_extension(mut self, e2ee: E2EEConfig) -> Self { + self.extensions.get_or_insert_with(Default::default).e2ee = Some(e2ee); + self + } + + /// Unset the E2EE extension configuration. + pub fn without_e2ee_extension(mut self) -> Self { + self.extensions.get_or_insert_with(Default::default).e2ee = None; + self + } + + /// Set the ToDevice extension configuration. 
+ pub fn with_to_device_extension(mut self, to_device: ToDeviceConfig) -> Self { + self.extensions.get_or_insert_with(Default::default).to_device = Some(to_device); + self + } + + /// Unset the ToDevice extension configuration. + pub fn without_to_device_extension(mut self) -> Self { + self.extensions.get_or_insert_with(Default::default).to_device = None; + self + } + + /// Set the account data extension configuration. + pub fn with_account_data_extension(mut self, account_data: AccountDataConfig) -> Self { + self.extensions.get_or_insert_with(Default::default).account_data = Some(account_data); + self + } + + /// Unset the account data extension configuration. + pub fn without_account_data_extension(mut self) -> Self { + self.extensions.get_or_insert_with(Default::default).account_data = None; + self + } + + /// Set the Typing extension configuration. + pub fn with_typing_extension(mut self, typing: TypingConfig) -> Self { + self.extensions.get_or_insert_with(Default::default).typing = Some(typing); + self + } + + /// Unset the Typing extension configuration. + pub fn without_typing_extension(mut self) -> Self { + self.extensions.get_or_insert_with(Default::default).typing = None; + self + } + + /// Set the Receipt extension configuration. + pub fn with_receipt_extension(mut self, receipt: ReceiptsConfig) -> Self { + self.extensions.get_or_insert_with(Default::default).receipts = Some(receipt); + self + } + + /// Unset the Receipt extension configuration. + pub fn without_receipt_extension(mut self) -> Self { + self.extensions.get_or_insert_with(Default::default).receipts = None; + self + } + + /// Build the Sliding Sync. + /// + /// If `self.storage_key` is `Some(_)`, load the cached data from cold + /// storage. 
+ pub async fn build(mut self) -> Result { + let client = self.client.ok_or(Error::BuildMissingField("client"))?; + + let mut delta_token = None; + let mut rooms_found: BTreeMap = BTreeMap::new(); + + if let Some(storage_key) = &self.storage_key { + trace!(storage_key, "trying to load from cold"); + + for (name, list) in &mut self.lists { + if let Some(frozen_list) = client + .store() + .get_custom_value(format!("{storage_key}::{name}").as_bytes()) + .await? + .map(|v| serde_json::from_slice::(&v)) + .transpose()? + { + trace!(name, "frozen for list found"); + + let FrozenSlidingSyncList { rooms_count, rooms_list, rooms } = frozen_list; + list.set_from_cold(rooms_count, rooms_list); + + for (key, frozen_room) in rooms.into_iter() { + rooms_found.entry(key).or_insert_with(|| { + SlidingSyncRoom::from_frozen(frozen_room, client.clone()) + }); + } + } else { + trace!(name, "no frozen state for list found"); + } + } + + if let Some(FrozenSlidingSync { to_device_since, delta_token: frozen_delta_token }) = + client + .store() + .get_custom_value(storage_key.as_bytes()) + .await? + .map(|v| serde_json::from_slice::(&v)) + .transpose()? 
+ { + trace!("frozen for generic found"); + + if let Some(since) = to_device_since { + if let Some(to_device_ext) = + self.extensions.get_or_insert_with(Default::default).to_device.as_mut() + { + to_device_ext.since = Some(since); + } + } + + delta_token = frozen_delta_token; + } + + trace!("sync unfrozen done"); + }; + + trace!(len = rooms_found.len(), "rooms unfrozen"); + + let rooms = StdRwLock::new(rooms_found); + let lists = StdRwLock::new(self.lists); + + Ok(SlidingSync::new(SlidingSyncInner { + homeserver: self.homeserver, + client, + storage_key: self.storage_key, + + lists, + rooms, + + extensions: Mutex::new(self.extensions), + reset_counter: Default::default(), + + position: StdRwLock::new(SlidingSyncPositionMarkers { + pos: Observable::new(None), + delta_token: Observable::new(delta_token), + }), + + subscriptions: StdRwLock::new(self.subscriptions), + unsubscribe: Default::default(), + })) + } +} diff --git a/crates/matrix-sdk/src/sliding_sync/client.rs b/crates/matrix-sdk/src/sliding_sync/client.rs new file mode 100644 index 00000000000..68bb3ce6374 --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/client.rs @@ -0,0 +1,25 @@ +use matrix_sdk_base::sync::SyncResponse; +use ruma::api::client::sync::sync_events::v4; +use tracing::{debug, instrument}; + +use super::{SlidingSync, SlidingSyncBuilder}; +use crate::{Client, Result}; + +impl Client { + /// Create a [`SlidingSyncBuilder`] tied to this client. 
+ pub async fn sliding_sync(&self) -> SlidingSyncBuilder { + SlidingSync::builder().client(self.clone()) + } + + #[instrument(skip(self, response))] + pub(crate) async fn process_sliding_sync( + &self, + response: &v4::Response, + ) -> Result { + let response = self.base_client().process_sliding_sync(response).await?; + debug!("done processing on base_client"); + self.handle_sync_response(&response).await?; + + Ok(response) + } +} diff --git a/crates/matrix-sdk/src/sliding_sync/error.rs b/crates/matrix-sdk/src/sliding_sync/error.rs new file mode 100644 index 00000000000..658e779ea6e --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/error.rs @@ -0,0 +1,18 @@ +//! Sliding Sync errors. + +use thiserror::Error; + +/// Internal representation of errors in Sliding Sync. +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum Error { + /// The response we've received from the server can't be parsed or doesn't + /// match up with the current expectations on the client side. A + /// `sync`-restart might be required. + #[error("The sliding sync response could not be handled: {0}")] + BadResponse(String), + /// Called `.build()` on a builder type, but the given required field was + /// missing. + #[error("Required field missing: `{0}`")] + BuildMissingField(&'static str), +} diff --git a/crates/matrix-sdk/src/sliding_sync/list/builder.rs b/crates/matrix-sdk/src/sliding_sync/list/builder.rs new file mode 100644 index 00000000000..772e4d6fcd5 --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/list/builder.rs @@ -0,0 +1,173 @@ +//! Builder for [`SlidingSyncList`]. + +use std::{ + fmt::Debug, + sync::{atomic::AtomicBool, Arc, RwLock as StdRwLock}, +}; + +use eyeball::unique::Observable; +use eyeball_im::ObservableVector; +use im::Vector; +use ruma::{api::client::sync::sync_events::v4, events::StateEventType, UInt}; + +use super::{Error, RoomListEntry, SlidingSyncList, SlidingSyncMode, SlidingSyncState}; +use crate::Result; + +/// The default name for the full sync list. 
+pub const FULL_SYNC_LIST_NAME: &str = "full-sync"; + +/// Builder for [`SlidingSyncList`]. +#[derive(Clone, Debug)] +pub struct SlidingSyncListBuilder { + sync_mode: SlidingSyncMode, + sort: Vec, + required_state: Vec<(StateEventType, String)>, + batch_size: u32, + send_updates_for_items: bool, + limit: Option, + filters: Option, + timeline_limit: Option, + name: Option, + state: SlidingSyncState, + rooms_count: Option, + rooms_list: Vector, + ranges: Vec<(UInt, UInt)>, +} + +impl SlidingSyncListBuilder { + pub(super) fn new() -> Self { + Self { + sync_mode: SlidingSyncMode::default(), + sort: vec!["by_recency".to_owned(), "by_name".to_owned()], + required_state: vec![ + (StateEventType::RoomEncryption, "".to_owned()), + (StateEventType::RoomTombstone, "".to_owned()), + ], + batch_size: 20, + send_updates_for_items: false, + limit: None, + filters: None, + timeline_limit: None, + name: None, + state: SlidingSyncState::default(), + rooms_count: None, + rooms_list: Vector::new(), + ranges: Vec::new(), + } + } + + /// Create a Builder set up for full sync. + pub fn default_with_fullsync() -> Self { + Self::new().name(FULL_SYNC_LIST_NAME).sync_mode(SlidingSyncMode::PagingFullSync) + } + + /// Which SlidingSyncMode to start this list under. + pub fn sync_mode(mut self, value: SlidingSyncMode) -> Self { + self.sync_mode = value; + self + } + + /// Sort the rooms list by this. + pub fn sort(mut self, value: Vec) -> Self { + self.sort = value; + self + } + + /// Required states to return per room. + pub fn required_state(mut self, value: Vec<(StateEventType, String)>) -> Self { + self.required_state = value; + self + } + + /// How many rooms request at a time when doing a full-sync catch up. + pub fn batch_size(mut self, value: u32) -> Self { + self.batch_size = value; + self + } + + /// Whether the list should send `UpdatedAt`-Diff signals for rooms that + /// have changed. 
+ pub fn send_updates_for_items(mut self, value: bool) -> Self { + self.send_updates_for_items = value; + self + } + + /// How many rooms to request in total when doing a full-sync catch up. + pub fn limit(mut self, value: impl Into>) -> Self { + self.limit = value.into(); + self + } + + /// Any filters to apply to the query. + pub fn filters(mut self, value: Option) -> Self { + self.filters = value; + self + } + + /// Set the limit of regular events to fetch for the timeline. + pub fn timeline_limit>(mut self, timeline_limit: U) -> Self { + self.timeline_limit = Some(timeline_limit.into()); + self + } + + /// Reset the limit of regular events to fetch for the timeline. It is left + /// to the server to decide how many to send back + pub fn no_timeline_limit(mut self) -> Self { + self.timeline_limit = Default::default(); + self + } + + /// Set the name of this list, to easily recognize it. + pub fn name(mut self, value: impl Into) -> Self { + self.name = Some(value.into()); + self + } + + /// Set the ranges to fetch + pub fn ranges>(mut self, range: Vec<(U, U)>) -> Self { + self.ranges = range.into_iter().map(|(a, b)| (a.into(), b.into())).collect(); + self + } + + /// Set a single range fetch + pub fn set_range>(mut self, from: U, to: U) -> Self { + self.ranges = vec![(from.into(), to.into())]; + self + } + + /// Add a range to fetch + pub fn add_range>(mut self, from: U, to: U) -> Self { + self.ranges.push((from.into(), to.into())); + self + } + + /// Reset the ranges to fetch + pub fn reset_ranges(mut self) -> Self { + self.ranges = Default::default(); + self + } + + /// Build the list + pub fn build(self) -> Result { + let mut rooms_list = ObservableVector::new(); + rooms_list.append(self.rooms_list); + + Ok(SlidingSyncList { + sync_mode: self.sync_mode, + sort: self.sort, + required_state: self.required_state, + batch_size: self.batch_size, + send_updates_for_items: self.send_updates_for_items, + limit: self.limit, + filters: self.filters, + timeline_limit: 
Arc::new(StdRwLock::new(Observable::new(self.timeline_limit))), + name: self.name.ok_or(Error::BuildMissingField("name"))?, + state: Arc::new(StdRwLock::new(Observable::new(self.state))), + rooms_count: Arc::new(StdRwLock::new(Observable::new(self.rooms_count))), + rooms_list: Arc::new(StdRwLock::new(rooms_list)), + ranges: Arc::new(StdRwLock::new(Observable::new(self.ranges))), + is_cold: Arc::new(AtomicBool::new(false)), + rooms_updated_broadcast: Arc::new(StdRwLock::new(Observable::new(()))), + }) + } +} diff --git a/crates/matrix-sdk/src/sliding_sync/list/mod.rs b/crates/matrix-sdk/src/sliding_sync/list/mod.rs new file mode 100644 index 00000000000..af7c706d3e0 --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/list/mod.rs @@ -0,0 +1,681 @@ +mod builder; +mod request_generator; + +use std::{ + collections::BTreeMap, + fmt::Debug, + iter, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, RwLock as StdRwLock, + }, +}; + +pub use builder::*; +use eyeball::unique::Observable; +use eyeball_im::{ObservableVector, VectorDiff}; +use futures_core::Stream; +use im::Vector; +pub(super) use request_generator::*; +use ruma::{api::client::sync::sync_events::v4, events::StateEventType, OwnedRoomId, RoomId, UInt}; +use serde::{Deserialize, Serialize}; +use tracing::{debug, instrument, warn}; + +use super::{Error, FrozenSlidingSyncRoom, SlidingSyncRoom}; +use crate::Result; + +/// Holding a specific filtered list within the concept of sliding sync. 
+/// Main entrypoint to the SlidingSync +/// +/// ```no_run +/// # use futures::executor::block_on; +/// # use matrix_sdk::Client; +/// # use url::Url; +/// # block_on(async { +/// # let homeserver = Url::parse("http://example.com")?; +/// let client = Client::new(homeserver).await?; +/// let sliding_sync = +/// client.sliding_sync().await.add_fullsync_list().build().await?; +/// +/// # anyhow::Ok(()) +/// # }); +/// ``` +#[derive(Clone, Debug)] +pub struct SlidingSyncList { + /// Which SlidingSyncMode to start this list under + sync_mode: SlidingSyncMode, + + /// Sort the rooms list by this + sort: Vec, + + /// Required states to return per room + required_state: Vec<(StateEventType, String)>, + + /// How many rooms request at a time when doing a full-sync catch up + batch_size: u32, + + /// Whether the list should send `UpdatedAt`-Diff signals for rooms + /// that have changed + send_updates_for_items: bool, + + /// How many rooms to request in total when doing a full-sync catch up + limit: Option, + + /// Any filters to apply to the query + filters: Option, + + /// The maximum number of timeline events to query for + pub timeline_limit: Arc>>>, + + /// Name of this list to easily recognize it + pub name: String, + + /// The state this list is in + state: Arc>>, + + /// The total known number of rooms, + rooms_count: Arc>>>, + + /// The rooms in order + rooms_list: Arc>>, + + /// The ranges windows of the list + #[allow(clippy::type_complexity)] // temporarily + ranges: Arc>>>, + + /// Get informed if anything in the room changed. + /// + /// If you only care to know about changes once all of them have applied + /// (including the total), subscribe to this observable. 
+ pub rooms_updated_broadcast: Arc>>, + + is_cold: Arc, +} + +impl SlidingSyncList { + pub(crate) fn set_from_cold( + &mut self, + rooms_count: Option, + rooms_list: Vector, + ) { + Observable::set(&mut self.state.write().unwrap(), SlidingSyncState::Preload); + self.is_cold.store(true, Ordering::SeqCst); + Observable::set(&mut self.rooms_count.write().unwrap(), rooms_count); + + let mut lock = self.rooms_list.write().unwrap(); + lock.clear(); + lock.append(rooms_list); + } + + /// Create a new [`SlidingSyncListBuilder`]. + pub fn builder() -> SlidingSyncListBuilder { + SlidingSyncListBuilder::new() + } + + /// Return a builder with the same settings as before + pub fn new_builder(&self) -> SlidingSyncListBuilder { + Self::builder() + .name(&self.name) + .sync_mode(self.sync_mode.clone()) + .sort(self.sort.clone()) + .required_state(self.required_state.clone()) + .batch_size(self.batch_size) + .ranges(self.ranges.read().unwrap().clone()) + } + + /// Set the ranges to fetch. + /// + /// Remember to cancel the existing stream and fetch a new one as this will + /// only be applied on the next request. + pub fn set_ranges(&self, range: Vec<(u32, u32)>) -> &Self { + let value = range.into_iter().map(|(a, b)| (a.into(), b.into())).collect(); + Observable::set(&mut self.ranges.write().unwrap(), value); + + self + } + + /// Reset the ranges to a particular set + /// + /// Remember to cancel the existing stream and fetch a new one as this will + /// only be applied on the next request. + pub fn set_range(&self, start: u32, end: u32) -> &Self { + let value = vec![(start.into(), end.into())]; + Observable::set(&mut self.ranges.write().unwrap(), value); + + self + } + + /// Set the ranges to fetch + /// + /// Remember to cancel the existing stream and fetch a new one as this will + /// only be applied on the next request. 
+ pub fn add_range(&self, start: u32, end: u32) -> &Self { + Observable::update(&mut self.ranges.write().unwrap(), |ranges| { + ranges.push((start.into(), end.into())); + }); + + self + } + + /// Reset the ranges to fetch + /// + /// Note: sending an empty list of ranges is, according to the spec, to be + /// understood that the consumer doesn't care about changes of the room + /// order but you will only receive updates for rooms entering or + /// leaving the set. + /// + /// Remember to cancel the existing stream and fetch a new one as this will + /// only be applied on the next request. + pub fn reset_ranges(&self) -> &Self { + Observable::set(&mut self.ranges.write().unwrap(), Vec::new()); + + self + } + + /// Get the current state. + pub fn state(&self) -> SlidingSyncState { + self.state.read().unwrap().clone() + } + + /// Get a stream of state. + pub fn state_stream(&self) -> impl Stream { + Observable::subscribe(&self.state.read().unwrap()) + } + + /// Get the current rooms list. + pub fn rooms_list(&self) -> Vec + where + R: for<'a> From<&'a RoomListEntry>, + { + self.rooms_list.read().unwrap().iter().map(|e| R::from(e)).collect() + } + + /// Get a stream of rooms list. + pub fn rooms_list_stream(&self) -> impl Stream> { + ObservableVector::subscribe(&self.rooms_list.read().unwrap()) + } + + /// Get the current rooms count. + pub fn rooms_count(&self) -> Option { + **self.rooms_count.read().unwrap() + } + + /// Get a stream of rooms count. + pub fn rooms_count_stream(&self) -> impl Stream> { + Observable::subscribe(&self.rooms_count.read().unwrap()) + } + + /// Find the current valid position of the room in the list `room_list`. + /// + /// Only matches against the current ranges and only against filled items. + /// Invalid items are ignored. Return the total position the item was + /// found in the room_list, return None otherwise. 
+ pub fn find_room_in_list(&self, room_id: &RoomId) -> Option { + let ranges = self.ranges.read().unwrap(); + let listing = self.rooms_list.read().unwrap(); + + for (start_uint, end_uint) in ranges.iter() { + let mut current_position: usize = (*start_uint).try_into().unwrap(); + let end: usize = (*end_uint).try_into().unwrap(); + let room_list_entries = listing.iter().skip(current_position); + + for room_list_entry in room_list_entries { + if let RoomListEntry::Filled(this_room_id) = room_list_entry { + if room_id == this_room_id { + return Some(current_position); + } + } + + if current_position == end { + break; + } + + current_position += 1; + } + } + + None + } + + /// Find the current valid position of the rooms in the list `room_list`. + /// + /// Only matches against the current ranges and only against filled items. + /// Invalid items are ignored. Return the total position the items that were + /// found in the `room_list`, will skip any room not found in the + /// `rooms_list`. + pub fn find_rooms_in_list(&self, room_ids: &[OwnedRoomId]) -> Vec<(usize, OwnedRoomId)> { + let ranges = self.ranges.read().unwrap(); + let listing = self.rooms_list.read().unwrap(); + let mut rooms_found = Vec::new(); + + for (start_uint, end_uint) in ranges.iter() { + let mut current_position: usize = (*start_uint).try_into().unwrap(); + let end: usize = (*end_uint).try_into().unwrap(); + let room_list_entries = listing.iter().skip(current_position); + + for room_list_entry in room_list_entries { + if let RoomListEntry::Filled(room_id) = room_list_entry { + if room_ids.contains(room_id) { + rooms_found.push((current_position, room_id.clone())); + } + } + + if current_position == end { + break; + } + + current_position += 1; + } + } + + rooms_found + } + + /// Return the `room_id` at the given index. 
+ pub fn get_room_id(&self, index: usize) -> Option { + self.rooms_list + .read() + .unwrap() + .get(index) + .and_then(|room_list_entry| room_list_entry.as_room_id().map(ToOwned::to_owned)) + } + + #[instrument(skip(self, ops), fields(name = self.name, ops_count = ops.len()))] + pub(super) fn handle_response( + &self, + rooms_count: u32, + ops: &Vec, + ranges: &Vec<(usize, usize)>, + rooms: &Vec, + ) -> Result { + let current_rooms_count = **self.rooms_count.read().unwrap(); + + if current_rooms_count.is_none() + || current_rooms_count == Some(0) + || self.is_cold.load(Ordering::SeqCst) + { + debug!("first run, replacing rooms list"); + + // first response, we do that slightly differently + let mut rooms_list = ObservableVector::new(); + rooms_list + .append(iter::repeat(RoomListEntry::Empty).take(rooms_count as usize).collect()); + + // then we apply it + room_ops(&mut rooms_list, ops, ranges)?; + + { + let mut lock = self.rooms_list.write().unwrap(); + lock.clear(); + lock.append(rooms_list.into_inner()); + } + + Observable::set(&mut self.rooms_count.write().unwrap(), Some(rooms_count)); + self.is_cold.store(false, Ordering::SeqCst); + + return Ok(true); + } + + debug!("regular update"); + + let mut missing = rooms_count + .checked_sub(self.rooms_list.read().unwrap().len() as u32) + .unwrap_or_default(); + let mut changed = false; + + if missing > 0 { + let mut list = self.rooms_list.write().unwrap(); + + while missing > 0 { + list.push_back(RoomListEntry::Empty); + missing -= 1; + } + + changed = true; + } + + { + // keep the lock scoped so that the later `find_rooms_in_list` doesn't deadlock + let mut rooms_list = self.rooms_list.write().unwrap(); + + if !ops.is_empty() { + room_ops(&mut rooms_list, ops, ranges)?; + changed = true; + } else { + debug!("no rooms operations found"); + } + } + + { + let mut lock = self.rooms_count.write().unwrap(); + + if **lock != Some(rooms_count) { + Observable::set(&mut lock, Some(rooms_count)); + changed = true; + } + } + + 
if self.send_updates_for_items && !rooms.is_empty() { + let found_lists = self.find_rooms_in_list(rooms); + + if !found_lists.is_empty() { + debug!("room details found"); + let mut rooms_list = self.rooms_list.write().unwrap(); + + for (pos, room_id) in found_lists { + // trigger an `UpdatedAt` update + rooms_list.set(pos, RoomListEntry::Filled(room_id)); + changed = true; + } + } + } + + if changed { + Observable::set(&mut self.rooms_updated_broadcast.write().unwrap(), ()); + } + + Ok(changed) + } + + pub(super) fn request_generator(&self) -> SlidingSyncListRequestGenerator { + match &self.sync_mode { + SlidingSyncMode::PagingFullSync => { + SlidingSyncListRequestGenerator::new_with_paging_syncup(self.clone()) + } + + SlidingSyncMode::GrowingFullSync => { + SlidingSyncListRequestGenerator::new_with_growing_syncup(self.clone()) + } + + SlidingSyncMode::Selective => SlidingSyncListRequestGenerator::new_live(self.clone()), + } + } +} + +#[derive(Serialize, Deserialize)] +pub(super) struct FrozenSlidingSyncList { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub(super) rooms_count: Option, + #[serde(default, skip_serializing_if = "Vector::is_empty")] + pub(super) rooms_list: Vector, + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + pub(super) rooms: BTreeMap, +} + +impl FrozenSlidingSyncList { + pub(super) fn freeze( + source_list: &SlidingSyncList, + rooms_map: &BTreeMap, + ) -> Self { + let mut rooms = BTreeMap::new(); + let mut rooms_list = Vector::new(); + + for room_list_entry in source_list.rooms_list.read().unwrap().iter() { + match room_list_entry { + RoomListEntry::Filled(room_id) | RoomListEntry::Invalidated(room_id) => { + rooms.insert( + room_id.clone(), + rooms_map.get(room_id).expect("room doesn't exist").into(), + ); + } + + _ => {} + }; + + rooms_list.push_back(room_list_entry.freeze()); + } + + FrozenSlidingSyncList { + rooms_count: **source_list.rooms_count.read().unwrap(), + rooms_list, + rooms, + } + } +} + 
+#[instrument(skip(operations))] +fn room_ops( + rooms_list: &mut ObservableVector, + operations: &Vec, + room_ranges: &Vec<(usize, usize)>, +) -> Result<(), Error> { + let index_in_range = |idx| room_ranges.iter().any(|(start, end)| idx >= *start && idx <= *end); + + for operation in operations { + match &operation.op { + v4::SlidingOp::Sync => { + let start: u32 = operation + .range + .ok_or_else(|| { + Error::BadResponse( + "`range` must be present for Sync and Update operation".to_owned(), + ) + })? + .0 + .try_into() + .map_err(|error| { + Error::BadResponse(format!("`range` not a valid int: {error}")) + })?; + let room_ids = operation.room_ids.clone(); + + room_ids + .into_iter() + .enumerate() + .map(|(i, r)| { + let idx = start as usize + i; + + if idx >= rooms_list.len() { + rooms_list.insert(idx, RoomListEntry::Filled(r)); + } else { + rooms_list.set(idx, RoomListEntry::Filled(r)); + } + }) + .count(); + } + + v4::SlidingOp::Delete => { + let position: u32 = operation + .index + .ok_or_else(|| { + Error::BadResponse( + "`index` must be present for DELETE operation".to_owned(), + ) + })? + .try_into() + .map_err(|error| { + Error::BadResponse(format!("`index` not a valid int for DELETE: {error}")) + })?; + rooms_list.set(position as usize, RoomListEntry::Empty); + } + + v4::SlidingOp::Insert => { + let position: usize = operation + .index + .ok_or_else(|| { + Error::BadResponse( + "`index` must be present for INSERT operation".to_owned(), + ) + })? + .try_into() + .map_err(|error| { + Error::BadResponse(format!("`index` not a valid int for INSERT: {error}")) + })?; + + let room = RoomListEntry::Filled(operation.room_id.clone().ok_or_else(|| { + Error::BadResponse("`room_id` must be present for INSERT operation".to_owned()) + })?); + let mut offset = 0usize; + + loop { + // Find the next empty slot and drop it. 
+ let (previous_position, overflow) = position.overflowing_sub(offset); + let check_previous = !overflow && index_in_range(previous_position); + + let (next_position, overflow) = position.overflowing_add(offset); + let check_next = !overflow + && next_position < rooms_list.len() + && index_in_range(next_position); + + if !check_previous && !check_next { + return Err(Error::BadResponse( + "We were asked to insert but could not find any direction to shift to" + .to_owned(), + )); + } + + if check_previous && rooms_list[previous_position].is_empty_or_invalidated() { + // we only check for previous, if there are items left + rooms_list.remove(previous_position); + + break; + } else if check_next && rooms_list[next_position].is_empty_or_invalidated() { + rooms_list.remove(next_position); + + break; + } else { + // Let's check the next position. + offset += 1; + } + } + + rooms_list.insert(position, room); + } + + v4::SlidingOp::Invalidate => { + let max_len = rooms_list.len(); + let (mut position, end): (usize, usize) = if let Some(range) = operation.range { + ( + range.0.try_into().map_err(|error| { + Error::BadResponse(format!("`range.0` not a valid int: {error}")) + })?, + range.1.try_into().map_err(|error| { + Error::BadResponse(format!("`range.1` not a valid int: {error}")) + })?, + ) + } else { + return Err(Error::BadResponse( + "`range` must be given on `Invalidate` operation".to_owned(), + )); + }; + + if position > end { + return Err(Error::BadResponse( + "Invalid invalidation, end smaller than start".to_owned(), + )); + } + + // Ranges are inclusive up to the last index. e.g. `[0, 10]`; `[0, 0]`. + // ensure we pick them all up. + while position <= end { + if position >= max_len { + break; // how does this happen? 
+ } + + let room_id = + if let Some(RoomListEntry::Filled(room_id)) = rooms_list.get(position) { + Some(room_id.clone()) + } else { + None + }; + + if let Some(room_id) = room_id { + rooms_list.set(position, RoomListEntry::Invalidated(room_id)); + } else { + rooms_list.set(position, RoomListEntry::Empty); + } + + position += 1; + } + } + + unknown_operation => { + warn!("Unknown operation occurred: {unknown_operation:?}"); + } + } + } + + Ok(()) +} + +/// The state the [`SlidingSyncList`] is in. +/// +/// The lifetime of a SlidingSync usually starts at a `Preload`, getting a fast +/// response for the first given number of Rooms, then switches into +/// `CatchingUp` during which the list fetches the remaining rooms, usually in +/// order, sometimes in batches. Once that is ready, it switches into `Live`. +/// +/// If the client has been offline for a while, though, the SlidingSync might +/// return back to `CatchingUp` at any point. +#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SlidingSyncState { + /// Hasn't started yet + #[default] + Cold, + /// We are quickly preloading a preview of the most important rooms + Preload, + /// We are trying to load all remaining rooms, might be in batches + CatchingUp, + /// We are all caught up and now only sync the live responses. + Live, +} + +/// The mode by which the [`SlidingSyncList`] is fetching the data. +#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SlidingSyncMode { + /// Fully sync all rooms in the background, page by page of `batch_size`, + /// like `0..20`, `21..40`, `41..60` etc. assuming the `batch_size` is 20. + #[serde(alias = "FullSync")] + PagingFullSync, + /// Fully sync all rooms in the background, with a growing window of + /// `batch_size`, like `0..20`, `0..40`, `0..60` etc. assuming the + /// `batch_size` is 20. 
+ GrowingFullSync, + /// Only sync the specific windows defined + #[default] + Selective, +} + +/// The Entry in the Sliding Sync room list per Sliding Sync list. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub enum RoomListEntry { + /// This entry isn't known at this point and thus considered `Empty`. + #[default] + Empty, + /// There was an `OwnedRoomId`, but the server told us to invalidate this + /// entry, so it is considered stale. + Invalidated(OwnedRoomId), + /// This entry is filled with `OwnedRoomId`. + Filled(OwnedRoomId), +} + +impl RoomListEntry { + /// Is this entry empty or invalidated? + pub fn is_empty_or_invalidated(&self) -> bool { + matches!(self, Self::Empty | Self::Invalidated(_)) + } + + /// Return the inner `room_id` if the entry's state is not empty. + pub fn as_room_id(&self) -> Option<&RoomId> { + match &self { + Self::Empty => None, + Self::Invalidated(room_id) | Self::Filled(room_id) => Some(room_id.as_ref()), + } + } + + /// Clone this entry, but freeze it, i.e. if the entry is empty, it remains + /// empty, otherwise it is invalidated. 
+ fn freeze(&self) -> Self { + match &self { + Self::Empty => Self::Empty, + Self::Invalidated(room_id) | Self::Filled(room_id) => { + Self::Invalidated(room_id.clone()) + } + } + } +} + +impl<'a> From<&'a RoomListEntry> for RoomListEntry { + fn from(value: &'a RoomListEntry) -> Self { + value.clone() + } +} diff --git a/crates/matrix-sdk/src/sliding_sync/list/request_generator.rs b/crates/matrix-sdk/src/sliding_sync/list/request_generator.rs new file mode 100644 index 00000000000..76281e0b5c0 --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/list/request_generator.rs @@ -0,0 +1,196 @@ +use std::cmp::min; + +use eyeball::unique::Observable; +use ruma::{api::client::sync::sync_events::v4, assign, OwnedRoomId, UInt}; +use tracing::{error, instrument, trace}; + +use super::{Error, SlidingSyncList, SlidingSyncState}; + +enum GeneratorKind { + GrowingFullSync { position: u32, batch_size: u32, limit: Option, live: bool }, + PagingFullSync { position: u32, batch_size: u32, limit: Option, live: bool }, + Live, +} + +pub(in super::super) struct SlidingSyncListRequestGenerator { + list: SlidingSyncList, + ranges: Vec<(usize, usize)>, + kind: GeneratorKind, +} + +impl SlidingSyncListRequestGenerator { + pub(super) fn new_with_paging_syncup(list: SlidingSyncList) -> Self { + let batch_size = list.batch_size; + let limit = list.limit; + let position = list + .ranges + .read() + .unwrap() + .first() + .map(|(_start, end)| u32::try_from(*end).unwrap()) + .unwrap_or_default(); + + Self { + list, + ranges: Default::default(), + kind: GeneratorKind::PagingFullSync { position, batch_size, limit, live: false }, + } + } + + pub(super) fn new_with_growing_syncup(list: SlidingSyncList) -> Self { + let batch_size = list.batch_size; + let limit = list.limit; + let position = list + .ranges + .read() + .unwrap() + .first() + .map(|(_start, end)| u32::try_from(*end).unwrap()) + .unwrap_or_default(); + + Self { + list, + ranges: Default::default(), + kind: GeneratorKind::GrowingFullSync 
{ position, batch_size, limit, live: false }, + } + } + + pub(super) fn new_live(list: SlidingSyncList) -> Self { + Self { list, ranges: Default::default(), kind: GeneratorKind::Live } + } + + fn prefetch_request( + &mut self, + start: u32, + batch_size: u32, + limit: Option, + ) -> v4::SyncRequestList { + let calculated_end = start + batch_size; + + let mut end = match limit { + Some(limit) => min(limit, calculated_end), + _ => calculated_end, + }; + + end = match self.list.rooms_count() { + Some(total_room_count) => min(end, total_room_count - 1), + _ => end, + }; + + self.make_request_for_ranges(vec![(start.into(), end.into())]) + } + + #[instrument(skip(self), fields(name = self.list.name))] + fn make_request_for_ranges(&mut self, ranges: Vec<(UInt, UInt)>) -> v4::SyncRequestList { + let sort = self.list.sort.clone(); + let required_state = self.list.required_state.clone(); + let timeline_limit = **self.list.timeline_limit.read().unwrap(); + let filters = self.list.filters.clone(); + + self.ranges = ranges + .iter() + .map(|(a, b)| { + ( + usize::try_from(*a).expect("range is a valid u32"), + usize::try_from(*b).expect("range is a valid u32"), + ) + }) + .collect(); + + assign!(v4::SyncRequestList::default(), { + ranges: ranges, + room_details: assign!(v4::RoomDetailsConfig::default(), { + required_state, + timeline_limit, + }), + sort, + filters, + }) + } + + #[instrument(skip_all, fields(name = self.list.name, rooms_count, has_ops = !ops.is_empty()))] + pub(in super::super) fn handle_response( + &mut self, + rooms_count: u32, + ops: &Vec, + rooms: &Vec, + ) -> Result { + let response = self.list.handle_response(rooms_count, ops, &self.ranges, rooms)?; + self.update_state(rooms_count.saturating_sub(1)); // index is 0 based, count is 1 based + + Ok(response) + } + + fn update_state(&mut self, max_index: u32) { + let Some((_start, range_end)) = self.ranges.first() else { + error!("Why don't we have any ranges?"); + + return; + }; + + let end = if &(max_index as 
usize) < range_end { max_index } else { *range_end as u32 }; + + trace!(end, max_index, range_end, name = self.list.name, "updating state"); + + match &mut self.kind { + GeneratorKind::PagingFullSync { position, live, limit, .. } + | GeneratorKind::GrowingFullSync { position, live, limit, .. } => { + let max = limit.map(|limit| min(limit, max_index)).unwrap_or(max_index); + + trace!(end, max, name = self.list.name, "updating state"); + + if end >= max { + // Switching to live mode. + + trace!(name = self.list.name, "going live"); + + self.list.set_range(0, max); + *position = max; + *live = true; + + Observable::update_eq(&mut self.list.state.write().unwrap(), |state| { + *state = SlidingSyncState::Live; + }); + } else { + *position = end; + *live = false; + self.list.set_range(0, end); + + Observable::update_eq(&mut self.list.state.write().unwrap(), |state| { + *state = SlidingSyncState::CatchingUp; + }); + } + } + + GeneratorKind::Live => { + Observable::update_eq(&mut self.list.state.write().unwrap(), |state| { + *state = SlidingSyncState::Live; + }); + } + } + } +} + +impl Iterator for SlidingSyncListRequestGenerator { + type Item = v4::SyncRequestList; + + fn next(&mut self) -> Option { + match self.kind { + GeneratorKind::PagingFullSync { live: true, .. } + | GeneratorKind::GrowingFullSync { live: true, .. } + | GeneratorKind::Live => { + let ranges = self.list.ranges.read().unwrap().clone(); + + Some(self.make_request_for_ranges(ranges)) + } + + GeneratorKind::PagingFullSync { position, batch_size, limit, .. } => { + Some(self.prefetch_request(position, batch_size, limit)) + } + + GeneratorKind::GrowingFullSync { position, batch_size, limit, .. 
} => { + Some(self.prefetch_request(0, position + batch_size, limit)) + } + } + } +} diff --git a/crates/matrix-sdk/src/sliding_sync/mod.rs b/crates/matrix-sdk/src/sliding_sync/mod.rs new file mode 100644 index 00000000000..83855191b24 --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/mod.rs @@ -0,0 +1,1399 @@ +// Copyright 2022-2023 Benjamin Kampmann +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for that specific language governing permissions and +// limitations under the License. + +//! Sliding Sync Client implementation of [MSC3575][MSC] & extensions +//! +//! [`Sliding Sync`][MSC] is the third generation synchronization mechanism of +//! Matrix with a strong focus on bandwidth efficiency. This is made possible by +//! allowing the client to filter the content very specifically in its request +//! which, as a result, allows the server to reduce the data sent to the +//! absolute necessary minimum needed. The API is modeled after common patterns +//! and UI components end-user messenger clients typically offer. By allowing a +//! tight coupling of what a client shows and synchronizing that state over +//! the protocol to the server, the server always sends exactly the information +//! necessary for the currently displayed subset for the user rather than +//! filling the connection with data the user isn't interested in right now. +//! +//! Sliding Sync is a live-protocol using [long-polling](#long-polling) HTTP(S) +//! connections to stay up to date. 
On the client side these updates are applied +//! and propagated through an [asynchronous reactive API](#reactive-api). +//! +//! The protocol is split into three major sections for that: [lists][#lists], +//! the [room details](#rooms) and [extensions](#extensions), most notably the +//! end-to-end-encryption and to-device extensions to enable full +//! end-to-end-encryption support. +//! +//! ## Starting up +//! +//! To create a new Sliding Sync session, one must query an existing +//! (authenticated) `Client` for a new [`SlidingSyncBuilder`] by calling +//! [`Client::sliding_sync`](`super::Client::sliding_sync`). The +//! [`SlidingSyncBuilder`] is the baseline configuration to create a +//! [`SlidingSync`] session by calling `.build()` once everything is ready. +//! Typically one configures the custom homeserver endpoint. +//! +//! At the time of writing, no Matrix server natively supports Sliding Sync; +//! a sidecar called the [Sliding Sync Proxy][proxy] is needed. As that +//! typically runs on a separate domain, it can be configured on the +//! [`SlidingSyncBuilder`]: +//! +//! ```no_run +//! # use futures::executor::block_on; +//! # use matrix_sdk::Client; +//! # use url::Url; +//! # block_on(async { +//! # let homeserver = Url::parse("http://example.com")?; +//! # let client = Client::new(homeserver).await?; +//! let sliding_sync_builder = client +//! .sliding_sync() +//! .await +//! .homeserver(Url::parse("http://sliding-sync.example.org")?); +//! +//! # anyhow::Ok(()) +//! # }); +//! ``` +//! +//! After the general configuration, one typically wants to add a list via the +//! [`add_list`][`SlidingSyncBuilder::add_list`] function. +//! +//! ## Lists +//! +//! A list defines a subset of matching rooms one wants to filter for, and be +//! kept up about. The [`v4::SyncRequestListFilters`][] allows for a granular +//! specification of the exact rooms one wants the server to select and the way +//! one wants them to be ordered before receiving. 
Secondly each list has a set +//! of `ranges`: the subset of indexes of the entire list one is interested in +//! and a unique name to be identified with. +//! +//! For example, a user might be part of thousands of rooms, but if the client +//! app always starts by showing the most recent direct message conversations, +//! loading all rooms is an inefficient approach. Instead with Sliding Sync one +//! defines a list (e.g. named `"main_list"`) filtering for `is_dm`, ordered +//! by recency and select to list the top 10 via `ranges: [ [0,9] ]` (indexes +//! are **inclusive**) like so: +//! +//! ```rust +//! # use matrix_sdk::sliding_sync::{SlidingSyncList, SlidingSyncMode}; +//! use ruma::{assign, api::client::sync::sync_events::v4}; +//! +//! let list_builder = SlidingSyncList::builder() +//! .name("main_list") +//! .sync_mode(SlidingSyncMode::Selective) +//! .filters(Some(assign!( +//! v4::SyncRequestListFilters::default(), { is_dm: Some(true)} +//! ))) +//! .sort(vec!["by_recency".to_owned()]) +//! .set_range(0u32, 9u32); +//! ``` +//! +//! Please refer to the [specification][MSC], the [Ruma types][ruma-types], +//! specifically [`SyncRequestListFilter`](https://docs.rs/ruma/latest/ruma/api/client/sync/sync_events/v4/struct.SyncRequestListFilters.html) and the +//! [`SlidingSyncListBuilder`] for details on the filters, sort-order and +//! range-options and data one requests to be sent. Once the list is fully +//! configured, `build()` it and add the list to the sliding sync session +//! by supplying it to [`add_list`][`SlidingSyncBuilder::add_list`]. +//! +//! Lists are inherently stateful and all updates are applied on the shared +//! list-object. Once a list has been added to [`SlidingSync`], a cloned shared +//! copy can be retrieved by calling `SlidingSync::list()`, providing the name +//! of the list. Next to the configuration settings (like name and +//! `timeline_limit`), the list provides the stateful +//! 
[`rooms_count`](SlidingSyncList::rooms_count), +//! [`rooms_list`](SlidingSyncList::rooms_list) and +//! [`state`](SlidingSyncList::state): +//! +//! - `rooms_count` is the number of rooms _total_ there were found matching +//! the filters given. +//! - `rooms_list` is a vector of `rooms_count` [`RoomListEntry`]'s at the +//! current state. `RoomListEntry`'s only hold `the room_id` if given, the +//! [Rooms API](#rooms) holds the actual information about each room +//! - `state` is a [`SlidingSyncMode`] signalling meta information about the +//! list and its stateful data β€” whether this is the state loaded from local +//! cache, whether the [full sync](#helper-lists) is in progress or whether +//! this is the current live information +//! +//! These are updated upon every update received from the server. One can query +//! these for their current value at any time, or use the [Reactive API +//! to subscribe to changes](#reactive-api). +//! +//! ### Helper lists +//! +//! By default lists run in the [`Selective` mode](SlidingSyncMode::Selective). +//! That means one sets the desired range(s) to see explicitly (as described +//! above). Very often, one still wants to load up the entire room list in +//! background though. For that, the client implementation offers to run lists +//! in two additional full-sync-modes, which require additional configuration: +//! +//! - [`SlidingSyncMode::PagingFullSync`]: Pages through the entire list of +//! rooms one request at a time asking for the next `batch_size` number of +//! rooms up to the end or `limit` if configured +//! - [`SlidingSyncMode::GrowingFullSync`]: Grows the window by `batch_size` on +//! every request till all rooms or until `limit` of rooms are in list. +//! +//! For both, one should configure +//! [`batch_size`](SlidingSyncListBuilder::batch_size) and optionally +//! [`limit`](SlidingSyncListBuilder::limit) on the [`SlidingSyncListBuilder`]. +//! 
Both full-sync lists will notice if the number of rooms increased at runtime +//! and will attempt to catch up to that (barring the `limit`). +//! +//! ## Rooms +//! +//! Next to the room list, the details for rooms are the next important aspect. +//! Each [list](#lists) only references the [`OwnedRoomId`][ruma::OwnedRoomId] +//! of the room at the given position. The details (`required_state`s and +//! timeline items) requested by all lists are bundled, together with the common +//! details (e.g. whether it is a `dm` or its calculated name) and made +//! available on the Sliding Sync session struct as a [reactive](#reactive-api) +//! through [`.rooms`](SlidingSync::rooms), [`get_room`](SlidingSync::get_room) +//! and [`get_rooms`](SlidingSync::get_rooms) APIs. +//! +//! Notably, this map only knows about the rooms that have come down [Sliding +//! Sync protocol][MSC] and if the given room isn't in any active list range, it +//! may be stale. Additionally to selecting the room data via the room lists, +//! the [Sliding Sync protocol][MSC] allows to subscribe to specific rooms via +//! the [`subscribe()`](SlidingSync::subscribe). Any room subscribed to will +//! receive updates (with the given settings) regardless of whether they are +//! visible in any list. The most common case for using this API is when the +//! user enters a room - as we want to receive the incoming new messages +//! regardless of whether the room is pushed out of the lists room list. +//! +//! ### Room List Entries +//! +//! As the room list of each list is a vec of the `rooms_count` len but a room +//! may only know of a subset of entries for sure at any given time, these +//! entries are wrapped in [`RoomListEntry`][]. This type, in close proximity to +//! the [specification][MSC], can be either `Empty`, `Filled` or `Invalidated`, +//! signaling the state of each entry position. +//! - `Empty` we don't know what sits here at this position in the list. +//! 
- `Filled`: there is this `room_id` at this position. +//! - `Invalidated` in that sense means that we _knew_ what was here before, but +//! can't be sure anymore this is still accurate. This occurs when we move the +//! sliding window (by changing the ranges) or when a room might drop out of +//! the window we are looking at. For the sake of displaying, this is probably +//! still fine to display to be at this position, but we can't be sure +//! anymore. +//! +//! Because `Invalidated` occurs whenever a room we knew about before drops out +//! of focus, we aren't updated about its changes anymore either, there could be +//! duplicates rooms within invalidated rooms as well as in the union of +//! invalidated and filled rooms. Keep that in mind, as most UI frameworks don't +//! like it when their list entries aren't unique. +//! +//! When [restoring from cold cache][#caching] the room list also only +//! propagated with `Invalidated` rooms. So if you want to be able to display +//! data quickly, ensure you are able to render `Invalidated` entries. +//! +//! ### Unsubscribe +//! +//! Don't forget to [unsubscribe](`SlidingSync::subscribe`) when the data isn't +//! needed to be updated anymore, e.g. when the user leaves the room, to reduce +//! the bandwidth back down to what is really needed. +//! +//! ## Extensions +//! +//! Additionally to the rooms list and rooms with their state and latest +//! messages Matrix knows of many other exchange information. All these are +//! modeled as specific, optional extensions in the [sliding sync +//! protocol][MSC]. This includes end-to-end-encryption, to-device-messages, +//! typing- and presence-information and account-data, but can be extended by +//! any implementation as they please. Handling of the data of the e2ee, +//! to-device and typing-extensions takes place transparently within the SDK. +//! +//! By default [`SlidingSync`][] doesn't activate _any_ extensions to save on +//! 
bandwidth, but we generally recommend to use the [`with_common_extensions`
+//! when building sliding sync](`SlidingSyncBuilder::with_common_extensions`) to
+//! activate e2ee, to-device-messages and account-data-extensions.
+//!
+//! ## Timeline events
+//!
+//! Both the list configuration as well as the [room subscription
+//! settings](`v4::RoomSubscription`) allow to specify a `timeline_limit` to
+//! receive timeline events. If that is unset or set to 0, no events are sent by
+//! the server (which is the default), if multiple limits are found, the highest
+//! takes precedence. Any positive number indicates that on the first request a
+//! room should come into list, up to that count of messages are sent
+//! (depending how many the server has in cache). Following, whenever new events
+//! are found for the matching rooms, the server relays them to the client.
+//!
+//! All timeline events coming through Sliding Sync will be processed through
+//! the [`BaseClient`][`matrix_sdk_base::BaseClient`] as in previous sync. This
+//! allows for transparent decryption as well as triggering the `client_handlers`.
+//!
+//! The current and then following live events list can be queried via the
+//! [`timeline` API](`SlidingSyncRoom::timeline`). This is prefilled with already
+//! received data.
+//!
+//! ### Timeline trickling
+//!
+//! To allow for a quick startup, clients might want to request only a very low
+//! `timeline_limit` (maybe 1 or even 0) at first and update the count later on
+//! the list or room subscription (see [reactive api](#reactive-api)). Since
+//! `0.99.0-rc1` the [sliding sync proxy][proxy] will then "paginate back" and
+//! resend the now larger number of events. All this is handled transparently.
As there might be not much or
+//! a lot happening in short succession — from the client perspective we never
+//! know when new data is received.
+//!
+//! One principle of long-polling is, therefore, that it might also take one
+//! or two requests before the changes one asked for to actually be applied
+//! and the results come back for that. Just assume that at the same time one
+//! adds a room subscription, a new message comes in. The server might reply
+//! with that message immediately and will only kick off the process of
+//! calculating the rooms details and respond with that in the next request one
+//! does after.
+//!
+//! This is modelled as an [async `Stream`][`futures_core::stream::Stream`] in
+//! our API, that one basically wants to continue polling. Once one has made its
+//! setup ready and built its sliding sync session, one wants to acquire its
+//! [`.stream()`](`SlidingSync::stream`) and continuously poll it.
# api::client::sync::sync_events::v4, assign, +//! # }; +//! # use tracing::{debug, error, info, warn}; +//! # use url::Url; +//! # block_on(async { +//! # let homeserver = Url::parse("http://example.com")?; +//! # let client = Client::new(homeserver).await?; +//! let sliding_sync = client +//! .sliding_sync() +//! .await +//! // any lists you want are added here. +//! .build() +//! .await?; +//! +//! let stream = sliding_sync.stream(); +//! +//! // continuously poll for updates +//! pin_mut!(stream); +//! +//! loop { +//! let update = match stream.next().await { +//! Some(Ok(u)) => { +//! info!("Received an update. Summary: {u:?}"); +//! } +//! Some(Err(e)) => { +//! error!("loop was stopped by client error processing: {e}"); +//! } +//! None => { +//! error!("Streaming loop ended unexpectedly"); +//! break; +//! } +//! }; +//! } +//! +//! # anyhow::Ok(()) +//! # }); +//! ``` +//! +//! ### Quick refreshing +//! +//! A main purpose of [Sliding Sync][MSC] is to provide an API for snappy end +//! user applications. Long-polling on the other side means that we wait for the +//! server to respond and that can take quite some time, before sending the next +//! request with our updates, for example an update in a list's `range`. +//! +//! That is a bit unfortunate and leaks through the `stream` API as well. We are +//! waiting for a `stream.next().await` call before the next request is sent. +//! The [specification][MSC] on long polling also states, however, that if an +//! new request is found coming in, the previous one shall be sent out. In +//! practice that means one can just start a new stream and the old connection +//! will return immediately β€” with a proper response though. One just needs to +//! make sure to not call that stream any further. Additionally, as both +//! requests are sent with the same positional argument, the server might +//! respond with data, the client has already processed. This isn't a problem, +//! 
the [`SlidingSync`][] will only process new data and skip the processing +//! even across restarts. +//! +//! To support this, in practice one should usually wrap its `loop` in a +//! spawn with an atomic flag that tells it to stop, which one can set upon +//! restart. Something along the lines of: +//! +//! ```no_run +//! # use futures::executor::block_on; +//! # use futures::{pin_mut, StreamExt}; +//! # use matrix_sdk::{ +//! # sliding_sync::{SlidingSyncMode, SlidingSyncListBuilder, SlidingSync, Error}, +//! # Client, +//! # }; +//! # use ruma::{ +//! # api::client::sync::sync_events::v4, assign, +//! # }; +//! # use tracing::{debug, error, info, warn}; +//! # use url::Url; +//! # block_on(async { +//! # let homeserver = Url::parse("http://example.com")?; +//! # let client = Client::new(homeserver).await?; +//! # let sliding_sync = client +//! # .sliding_sync() +//! # .await +//! # // any lists you want are added here. +//! # .build() +//! # .await?; +//! use std::sync::{Arc, atomic::{AtomicBool, Ordering}}; +//! +//! struct MyRunner { lock: Arc, sliding_sync: SlidingSync }; +//! +//! impl MyRunner { +//! pub fn restart_sync(&mut self) { +//! self.lock.store(false, Ordering::SeqCst); +//! // create a new lock +//! self.lock = Arc::new(AtomicBool::new(false)); +//! +//! let stream_lock = self.lock.clone(); +//! let sliding_sync = self.sliding_sync.clone(); +//! +//! // continuously poll for updates +//! tokio::spawn(async move { +//! let stream = sliding_sync.stream(); +//! pin_mut!(stream); +//! loop { +//! match stream.next().await { +//! Some(Ok(u)) => { +//! info!("Received an update. Summary: {u:?}"); +//! } +//! Some(Err(e)) => { +//! error!("loop was stopped by client error processing: {e}"); +//! } +//! None => { +//! error!("Streaming loop ended unexpectedly"); +//! break; +//! } +//! }; +//! if !stream_lock.load(Ordering::SeqCst) { +//! info!("Asked to stop"); +//! break +//! } +//! }; +//! }); +//! } +//! } +//! +//! # anyhow::Ok(()) +//! # }); +//! 
``` +//! +//! +//! ## Reactive API +//! +//! As the main source of truth is the data coming from the server, all updates +//! must be applied transparently throughout to the data layer. The simplest +//! way to stay up to date on what objects have changed is by checking the +//! [`lists`](`UpdateSummary.lists`) and [`rooms`](`UpdateSummary.rooms`) of +//! each [`UpdateSummary`] given by each stream iteration and update the local +//! copies accordingly. Because of where the loop sits in the stack, that can +//! be a bit tedious though, so lists and rooms have an additional way of +//! subscribing to updates via [`eyeball`]. +//! +//! The `Timeline` one can receive per room by calling +//! [`.timeline()`][`SlidingSyncRoom::timeline`] will be populated with the +//! currently cached timeline events. +//! +//! ## Caching +//! +//! All room data, for filled but also _invalidated_ rooms, including the entire +//! timeline events as well as all list `room_lists` and `rooms_count` are held +//! in memory (unless one `pop`s the list out). +//! +//! This is a purely in-memory cache layer though. If one wants Sliding Sync to +//! persist and load from cold (storage) cache, one needs to set its key with +//! [`cold_cache(name)`][`SlidingSyncBuilder::cold_cache`] and for each list +//! present at `.build()`[`SlidingSyncBuilder::build`] sliding sync will attempt +//! to load their latest cached version from storage, as well as some overall +//! information of Sliding Sync. If that succeeded the lists `state` has been +//! set to [`Preload`][SlidingSyncListState::Preload]. Only room data of rooms +//! present in one of the lists is loaded from storage. +//! +//! Once [#1441](https://github.com/matrix-org/matrix-rust-sdk/pull/1441) is merged +//! one can disable caching on a per-list basis by setting +//! [`cold_cache(false)`][`SlidingSyncListBuilder::cold_cache`] when +//! constructing the builder. +//! +//! 
Notice that lists added after Sliding Sync has been built **will not be +//! loaded from cache** regardless of their settings (as this could lead to +//! inconsistencies between lists). The same goes for any extension: some +//! extension data (like the to-device-message position) are stored to storage, +//! but only retrieved upon `build()` of the `SlidingSyncBuilder`. So if one +//! only adds them later, they will not be reading the data from storage (to +//! avoid inconsistencies) and might require more data to be sent in their first +//! request than if they were loaded form cold-cache. +//! +//! When loading from storage `rooms_list` entries found are set to +//! `Invalidated` β€” the initial setting here is communicated as a single +//! `VecDiff::Replace` event through the [reactive API](#reactive-api). +//! +//! Only the latest 10 timeline items of each room are cached and they are reset +//! whenever a new set of timeline items is received by the server. +//! +//! ## Bot mode +//! +//! _Note_: This is not yet exposed via the API. See [#1475](https://github.com/matrix-org/matrix-rust-sdk/issues/1475) +//! +//! Sliding Sync is modeled for faster and more efficient user-facing client +//! applications, but offers significant speed ups even for bot cases through +//! its filtering mechanism. The sort-order and specific subsets, however, are +//! usually not of interest for bots. For that use case the the +//! [`v4::SyncRequestList`][] offers the +//! [`slow_get_all_rooms`](`v4::SyncRequestList::slow_get_all_rooms`) flag. +//! +//! Once switched on, this mode will not trigger any updates on "list +//! movements", ranges and sorting are ignored and all rooms matching the filter +//! will be returned with the given room details settings. Depending on the data +//! that is requested this will still be significantly faster as the response +//! only returns the matching rooms and states as per settings. +//! +//! 
Think about a bot that only interacts in `is_dm = true` and doesn't need +//! room topic, room avatar and all the other state. It will be a lot faster to +//! start up and retrieve only the data needed to actually run. +//! +//! # Full example +//! +//! ```no_run +//! # use futures::executor::block_on; +//! use matrix_sdk::{Client, sliding_sync::{SlidingSyncList, SlidingSyncMode}}; +//! use ruma::{assign, {api::client::sync::sync_events::v4, events::StateEventType}}; +//! use tracing::{warn, error, info, debug}; +//! use futures::{StreamExt, pin_mut}; +//! use url::Url; +//! # block_on(async { +//! # let homeserver = Url::parse("http://example.com")?; +//! # let client = Client::new(homeserver).await?; +//! let full_sync_list_name = "full-sync".to_owned(); +//! let active_list_name = "active-list".to_owned(); +//! let sliding_sync_builder = client +//! .sliding_sync() +//! .await +//! .homeserver(Url::parse("http://sliding-sync.example.org")?) // our proxy server +//! .with_common_extensions() // we want the e2ee and to-device enabled, please +//! .cold_cache("example-cache".to_owned()); // we want these to be loaded from and stored into the persistent storage +//! +//! let full_sync_list = SlidingSyncList::builder() +//! .sync_mode(SlidingSyncMode::GrowingFullSync) // sync up by growing the window +//! .name(&full_sync_list_name) // needed to lookup again. +//! .sort(vec!["by_recency".to_owned()]) // ordered by most recent +//! .required_state(vec![ +//! (StateEventType::RoomEncryption, "".to_owned()) +//! ]) // only want to know if the room is encrypted +//! .batch_size(50) // grow the window by 50 items at a time +//! .limit(500) // only sync up the top 500 rooms +//! .build()?; +//! +//! let active_list = SlidingSyncList::builder() +//! .name(&active_list_name) // the active window +//! .sync_mode(SlidingSyncMode::Selective) // sync up the specific range only +//! .set_range(0u32, 9u32) // only the top 10 items +//! 
.sort(vec!["by_recency".to_owned()]) // last active +//! .timeline_limit(5u32) // add the last 5 timeline items for room preview and faster timeline loading +//! .required_state(vec![ // we want to know immediately: +//! (StateEventType::RoomEncryption, "".to_owned()), // is it encrypted +//! (StateEventType::RoomTopic, "".to_owned()), // any topic if known +//! (StateEventType::RoomAvatar, "".to_owned()), // avatar if set +//! ]) +//! .build()?; +//! +//! let sliding_sync = sliding_sync_builder +//! .add_list(active_list) +//! .add_list(full_sync_list) +//! .build() +//! .await?; +//! +//! // subscribe to the list APIs for updates +//! +//! let active_list = sliding_sync.list(&active_list_name).unwrap(); +//! let list_state_stream = active_list.state_stream(); +//! let list_count_stream = active_list.rooms_count_stream(); +//! let list_stream = active_list.rooms_list_stream(); +//! +//! tokio::spawn(async move { +//! pin_mut!(list_state_stream); +//! while let Some(new_state) = list_state_stream.next().await { +//! info!("active-list switched state to {new_state:?}"); +//! } +//! }); +//! +//! tokio::spawn(async move { +//! pin_mut!(list_count_stream); +//! while let Some(new_count) = list_count_stream.next().await { +//! info!("active-list new count: {new_count:?}"); +//! } +//! }); +//! +//! tokio::spawn(async move { +//! pin_mut!(list_stream); +//! while let Some(v_diff) = list_stream.next().await { +//! info!("active-list rooms list diff update: {v_diff:?}"); +//! } +//! }); +//! +//! let stream = sliding_sync.stream(); +//! +//! // continuously poll for updates +//! pin_mut!(stream); +//! loop { +//! let update = match stream.next().await { +//! Some(Ok(u)) => { +//! info!("Received an update. Summary: {u:?}"); +//! }, +//! Some(Err(e)) => { +//! error!("loop was stopped by client error processing: {e}"); +//! } +//! None => { +//! error!("Streaming loop ended unexpectedly"); +//! break; +//! } +//! }; +//! } +//! +//! # anyhow::Ok(()) +//! # }); +//! 
``` +//! +//! +//! [MSC]: https://github.com/matrix-org/matrix-spec-proposals/pull/3575 +//! [proxy]: https://github.com/matrix-org/sliding-sync +//! [ruma-types]: https://docs.rs/ruma/latest/ruma/api/client/sync/sync_events/v4/index.html + +mod builder; +mod client; +mod error; +mod list; +mod room; + +use std::{ + borrow::BorrowMut, + collections::BTreeMap, + fmt::Debug, + mem, + sync::{ + atomic::{AtomicU8, Ordering}, + Arc, Mutex, RwLock as StdRwLock, + }, + time::Duration, +}; + +pub use builder::*; +pub use client::*; +pub use error::*; +use eyeball::unique::Observable; +use futures_core::stream::Stream; +pub use list::*; +use matrix_sdk_base::sync::SyncResponse; +use matrix_sdk_common::locks::Mutex as AsyncMutex; +pub use room::*; +use ruma::{ + api::client::{ + error::ErrorKind, + sync::sync_events::v4::{ + self, AccountDataConfig, E2EEConfig, ExtensionsConfig, ToDeviceConfig, + }, + }, + assign, OwnedRoomId, RoomId, +}; +use serde::{Deserialize, Serialize}; +use tokio::spawn; +use tracing::{debug, error, info_span, instrument, trace, warn, Instrument, Span}; +use url::Url; +use uuid::Uuid; + +use crate::{config::RequestConfig, Client, Result}; + +/// Number of times a Sliding Sync session can expire before raising an error. +/// +/// A Sliding Sync session can expire. In this case, it is reset. However, to +/// avoid entering an infinite loop of β€œit's expired, let's reset, it's expired, +/// let's reset…” (maybe if the network has an issue, or the server, or anything +/// else), we define a maximum times a session can expire before +/// raising a proper error. +const MAXIMUM_SLIDING_SYNC_SESSION_EXPIRATION: u8 = 3; + +/// The Sliding Sync instance. +/// +/// It is OK to clone this type as much as you need: cloning it is cheap. +#[derive(Clone, Debug)] +pub struct SlidingSync { + /// The Sliding Sync data. + inner: Arc, + + /// A lock to ensure that responses are handled one at a time. 
+ response_handling_lock: Arc>, +} + +#[derive(Debug)] +pub(super) struct SlidingSyncInner { + /// Customize the homeserver for sliding sync only + homeserver: Option, + + /// The HTTP Matrix client. + client: Client, + + /// The storage key to keep this cache at and load it from + storage_key: Option, + + /// The `pos` and `delta_token` markers. + position: StdRwLock, + + /// The lists of this Sliding Sync instance. + lists: StdRwLock>, + + /// The rooms details + rooms: StdRwLock>, + + subscriptions: StdRwLock>, + unsubscribe: StdRwLock>, + + /// Number of times a Sliding Session session has been reset. + reset_counter: AtomicU8, + + /// the intended state of the extensions being supplied to sliding /sync + /// calls. May contain the latest next_batch for to_devices, etc. + extensions: Mutex>, +} + +impl SlidingSync { + pub(super) fn new(inner: SlidingSyncInner) -> Self { + Self { inner: Arc::new(inner), response_handling_lock: Arc::new(AsyncMutex::new(())) } + } + + async fn cache_to_storage(&self) -> Result<(), crate::Error> { + let Some(storage_key) = self.inner.storage_key.as_ref() else { return Ok(()) }; + trace!(storage_key, "Saving to storage for later use"); + + let store = self.inner.client.store(); + + // Write this `SlidingSync` instance, as a `FrozenSlidingSync` instance, inside + // the client store. + store + .set_custom_value( + storage_key.as_bytes(), + serde_json::to_vec(&FrozenSlidingSync::from(self))?, + ) + .await?; + + // Write every `SlidingSyncList` inside the client the store. + let frozen_lists = { + let rooms_lock = self.inner.rooms.read().unwrap(); + + self.inner + .lists + .read() + .unwrap() + .iter() + .map(|(name, list)| { + Ok(( + format!("{storage_key}::{name}"), + serde_json::to_vec(&FrozenSlidingSyncList::freeze(list, &rooms_lock))?, + )) + }) + .collect::, crate::Error>>()? 
+ }; + + for (storage_key, frozen_list) in frozen_lists { + trace!(storage_key, "Saving the frozen Sliding Sync list"); + + store.set_custom_value(storage_key.as_bytes(), frozen_list).await?; + } + + Ok(()) + } + + /// Create a new [`SlidingSyncBuilder`]. + pub fn builder() -> SlidingSyncBuilder { + SlidingSyncBuilder::new() + } + + /// Generate a new [`SlidingSyncBuilder`] with the same inner settings and + /// lists but without the current state. + pub fn new_builder_copy(&self) -> SlidingSyncBuilder { + let mut builder = Self::builder() + .client(self.inner.client.clone()) + .subscriptions(self.inner.subscriptions.read().unwrap().to_owned()); + + for list in self.inner.lists.read().unwrap().values().map(|list| { + list.new_builder().build().expect("builder worked before, builder works now") + }) { + builder = builder.add_list(list); + } + + if let Some(homeserver) = &self.inner.homeserver { + builder.homeserver(homeserver.clone()) + } else { + builder + } + } + + /// Subscribe to a given room. + /// + /// Note: this does not cancel any pending request, so make sure to only + /// poll the stream after you've altered this. If you do that during, it + /// might take one round trip to take effect. + pub fn subscribe(&self, room_id: OwnedRoomId, settings: Option) { + self.inner.subscriptions.write().unwrap().insert(room_id, settings.unwrap_or_default()); + } + + /// Unsubscribe from a given room. + /// + /// Note: this does not cancel any pending request, so make sure to only + /// poll the stream after you've altered this. If you do that during, it + /// might take one round trip to take effect. + pub fn unsubscribe(&self, room_id: OwnedRoomId) { + if self.inner.subscriptions.write().unwrap().remove(&room_id).is_some() { + self.inner.unsubscribe.write().unwrap().push(room_id); + } + } + + /// Add the common extensions if not already configured. 
+ pub fn add_common_extensions(&self) { + let mut lock = self.inner.extensions.lock().unwrap(); + let mut cfg = lock.get_or_insert_with(Default::default); + + if cfg.to_device.is_none() { + cfg.to_device = Some(assign!(ToDeviceConfig::default(), { enabled: Some(true) })); + } + + if cfg.e2ee.is_none() { + cfg.e2ee = Some(assign!(E2EEConfig::default(), { enabled: Some(true) })); + } + + if cfg.account_data.is_none() { + cfg.account_data = Some(assign!(AccountDataConfig::default(), { enabled: Some(true) })); + } + } + + /// Lookup a specific room + pub fn get_room(&self, room_id: &RoomId) -> Option { + self.inner.rooms.read().unwrap().get(room_id).cloned() + } + + /// Check the number of rooms. + pub fn get_number_of_rooms(&self) -> usize { + self.inner.rooms.read().unwrap().len() + } + + #[instrument(skip(self))] + fn update_to_device_since(&self, since: String) { + // FIXME: Find a better place where the to-device since token should be + // persisted. + self.inner + .extensions + .lock() + .unwrap() + .get_or_insert_with(Default::default) + .to_device + .get_or_insert_with(Default::default) + .since = Some(since); + } + + /// Get access to the SlidingSyncList named `list_name`. + /// + /// Note: Remember that this list might have been changed since you started + /// listening to the stream and is therefor not necessarily up to date + /// with the lists used for the stream. + pub fn list(&self, list_name: &str) -> Option { + self.inner.lists.read().unwrap().get(list_name).cloned() + } + + /// Remove the SlidingSyncList named `list_name` from the lists list if + /// found. + /// + /// Note: Remember that this change will only be applicable for any new + /// stream created after this. The old stream will still continue to use the + /// previous set of lists. + pub fn pop_list(&self, list_name: &String) -> Option { + self.inner.lists.write().unwrap().remove(list_name) + } + + /// Add the list to the list of lists. 
+ /// + /// As lists need to have a unique `.name`, if a list with the same name + /// is found the new list will replace the old one and the return it or + /// `None`. + /// + /// Note: Remember that this change will only be applicable for any new + /// stream created after this. The old stream will still continue to use the + /// previous set of lists. + pub fn add_list(&self, list: SlidingSyncList) -> Option { + self.inner.lists.write().unwrap().insert(list.name.clone(), list) + } + + /// Lookup a set of rooms + pub fn get_rooms>( + &self, + room_ids: I, + ) -> Vec> { + let rooms = self.inner.rooms.read().unwrap(); + + room_ids.map(|room_id| rooms.get(&room_id).cloned()).collect() + } + + /// Get all rooms. + pub fn get_all_rooms(&self) -> Vec { + self.inner.rooms.read().unwrap().values().cloned().collect() + } + + fn prepare_extension_config(&self, pos: Option<&str>) -> ExtensionsConfig { + if pos.is_none() { + // The pos is `None`, it's either our initial sync or the proxy forgot about us + // and sent us an `UnknownPos` error. We need to send out the config for our + // extensions. + let mut extensions = self.inner.extensions.lock().unwrap().clone().unwrap_or_default(); + + // Always enable to-device events and the e2ee-extension on the initial request, + // no matter what the caller wants. + // + // The to-device `since` parameter is either `None` or guaranteed to be set + // because the `update_to_device_since()` method updates the + // self.extensions field and they get persisted to the store using the + // `cache_to_storage()` method. + // + // The token is also loaded from storage in the `SlidingSyncBuilder::build()` + // method. 
+ let mut to_device = extensions.to_device.unwrap_or_default(); + to_device.enabled = Some(true); + + extensions.to_device = Some(to_device); + extensions.e2ee = Some(assign!(E2EEConfig::default(), { enabled: Some(true) })); + + extensions + } else { + // We already enabled all the things, just fetch out the to-device since token + // out of self.extensions and set it in a new, and empty, `ExtensionsConfig`. + let since = self + .inner + .extensions + .lock() + .unwrap() + .as_ref() + .and_then(|e| e.to_device.as_ref()?.since.to_owned()); + + let mut extensions: ExtensionsConfig = Default::default(); + extensions.to_device = Some(assign!(ToDeviceConfig::default(), { since })); + + extensions + } + } + + /// Handle the HTTP response. + #[instrument(skip_all, fields(lists = list_generators.len()))] + fn handle_response( + &self, + sliding_sync_response: v4::Response, + mut sync_response: SyncResponse, + list_generators: &mut BTreeMap, + ) -> Result { + { + debug!( + pos = ?sliding_sync_response.pos, + delta_token = ?sliding_sync_response.delta_token, + "Update position markers`" + ); + + let mut position_lock = self.inner.position.write().unwrap(); + Observable::set(&mut position_lock.pos, Some(sliding_sync_response.pos)); + Observable::set(&mut position_lock.delta_token, sliding_sync_response.delta_token); + } + + let update_summary = { + let mut rooms = Vec::new(); + let mut rooms_map = self.inner.rooms.write().unwrap(); + + for (room_id, mut room_data) in sliding_sync_response.rooms.into_iter() { + // `sync_response` contains the rooms with decrypted events if any, so look at + // the timeline events here first if the room exists. + // Otherwise, let's look at the timeline inside the `sliding_sync_response`. 
+ let timeline = if let Some(joined_room) = sync_response.rooms.join.remove(&room_id) + { + joined_room.timeline.events + } else { + room_data.timeline.drain(..).map(Into::into).collect() + }; + + if let Some(mut room) = rooms_map.remove(&room_id) { + // The room existed before, let's update it. + + room.update(room_data, timeline); + rooms_map.insert(room_id.clone(), room); + } else { + // First time we need this room, let's create it. + + rooms_map.insert( + room_id.clone(), + SlidingSyncRoom::new( + self.inner.client.clone(), + room_id.clone(), + room_data, + timeline, + ), + ); + } + + rooms.push(room_id); + } + + let mut updated_lists = Vec::new(); + + for (name, updates) in sliding_sync_response.lists { + let Some(generator) = list_generators.get_mut(&name) else { + error!("Response for list `{name}` - unknown to us; skipping"); + + continue + }; + + let count: u32 = + updates.count.try_into().expect("the list total count convertible into u32"); + + if generator.handle_response(count, &updates.ops, &rooms)? { + updated_lists.push(name.clone()); + } + } + + // Update the `to-device` next-batch if any. 
+ if let Some(to_device) = sliding_sync_response.extensions.to_device { + self.update_to_device_since(to_device.next_batch); + } + + UpdateSummary { lists: updated_lists, rooms } + }; + + Ok(update_summary) + } + + async fn sync_once( + &self, + stream_id: &str, + list_generators: Arc>>, + ) -> Result> { + let mut lists = BTreeMap::new(); + + { + let mut list_generators_lock = list_generators.lock().unwrap(); + let list_generators = list_generators_lock.borrow_mut(); + let mut lists_to_remove = Vec::new(); + + for (name, generator) in list_generators.iter_mut() { + if let Some(request) = generator.next() { + lists.insert(name.clone(), request); + } else { + lists_to_remove.push(name.clone()); + } + } + + for list_name in lists_to_remove { + list_generators.remove(&list_name); + } + + if list_generators.is_empty() { + return Ok(None); + } + } + + let (pos, delta_token) = { + let position_lock = self.inner.position.read().unwrap(); + + (position_lock.pos.clone(), position_lock.delta_token.clone()) + }; + + let room_subscriptions = self.inner.subscriptions.read().unwrap().clone(); + let unsubscribe_rooms = mem::take(&mut *self.inner.unsubscribe.write().unwrap()); + let timeout = Duration::from_secs(30); + let extensions = self.prepare_extension_config(pos.as_deref()); + + debug!("Sending the sliding sync request"); + + // Configure long-polling. We need 30 seconds for the long-poll itself, in + // addition to 30 more extra seconds for the network delays. + let request_config = RequestConfig::default().timeout(timeout + Duration::from_secs(30)); + + // Prepare the request. + let request = self.inner.client.send_with_homeserver( + assign!(v4::Request::new(), { + pos, + delta_token, + // We want to track whether the incoming response maps to this + // request. We use the (optional) `txn_id` field for that. 
+ txn_id: Some(stream_id.to_owned()), + timeout: Some(timeout), + lists, + room_subscriptions, + unsubscribe_rooms, + extensions, + }), + Some(request_config), + self.inner.homeserver.as_ref().map(ToString::to_string), + ); + + // Send the request and get a response with end-to-end encryption support. + // + // Sending the `/sync` request out when end-to-end encryption is enabled means + // that we need to also send out any outgoing e2ee related request out + // coming from the `OlmMachine::outgoing_requests()` method. + #[cfg(feature = "e2e-encryption")] + let response = { + debug!("Sliding Sync is sending the request along with outgoing E2EE requests"); + + let (e2ee_uploads, response) = + futures_util::future::join(self.inner.client.send_outgoing_requests(), request) + .await; + + if let Err(error) = e2ee_uploads { + error!(?error, "Error while sending outgoing E2EE requests"); + } + + response + }?; + + // Send the request and get a response _without_ end-to-end encryption support. + #[cfg(not(feature = "e2e-encryption"))] + let response = { + debug!("Sliding Sync is sending the request"); + + request.await? + }; + + debug!(?response, "Sliding Sync response received"); + + // At this point, the request has been sent, and a response has been received. + // + // We must ensure the handling of the response cannot be stopped/ + // cancelled. It must be done entirely, otherwise we can have + // corrupted/incomplete states for Sliding Sync and other parts of + // the code. + // + // That's why we are running the handling of the response in a spawned + // future that cannot be cancelled by anything. + let this = self.clone(); + let stream_id = stream_id.to_owned(); + + // Spawn a new future to ensure that the code inside this future cannot be + // cancelled if this method is cancelled. 
+ spawn(async move { + debug!("Sliding Sync response handling starts"); + + // In case the task running this future is detached, we must + // ensure responses are handled one at a time, hence we lock the + // `response_handling_lock`. + let response_handling_lock = this.response_handling_lock.lock().await; + + match &response.txn_id { + None => { + error!(stream_id, "Sliding Sync has received an unexpected response: `txn_id` must match `stream_id`; it's missing"); + } + + Some(txn_id) if txn_id != &stream_id => { + error!( + stream_id, + txn_id, + "Sliding Sync has received an unexpected response: `txn_id` must match `stream_id`; they differ" + ); + } + + _ => {} + } + + // Handle and transform a Sliding Sync Response to a `SyncResponse`. + // + // We may not need the `sync_response` in the future (once `SyncResponse` will + // move to Sliding Sync, i.e. to `v4::Response`), but processing the + // `sliding_sync_response` is vital, so it must be done somewhere; for now it + // happens here. + let sync_response = this.inner.client.process_sliding_sync(&response).await?; + + debug!(?sync_response, "Sliding Sync response has been handled by the client"); + + let updates = this.handle_response(response, sync_response, list_generators.lock().unwrap().borrow_mut())?; + + this.cache_to_storage().await?; + + // Release the lock. + drop(response_handling_lock); + + debug!("Sliding Sync response has been fully handled"); + + Ok(Some(updates)) + }).await.unwrap() + } + + /// Create a _new_ Sliding Sync stream. + /// + /// This stream will send requests and will handle responses automatically, + /// hence updating the lists. + #[allow(unknown_lints, clippy::let_with_type_underscore)] // triggered by instrument macro + #[instrument(name = "sync_stream", skip_all, parent = &self.inner.client.inner.root_span)] + pub fn stream<'a>(&'a self) -> impl Stream> + 'a { + // Collect all the lists that need to be updated. 
+ let list_generators = { + let mut list_generators = BTreeMap::new(); + let lock = self.inner.lists.read().unwrap(); + + for (name, lists) in lock.iter() { + list_generators.insert(name.clone(), lists.request_generator()); + } + + list_generators + }; + + let stream_id = Uuid::new_v4().to_string(); + + debug!(?self.inner.extensions, stream_id, "About to run the sync stream"); + + let instrument_span = Span::current(); + let list_generators = Arc::new(Mutex::new(list_generators)); + + async_stream::stream! { + loop { + let sync_span = info_span!(parent: &instrument_span, "sync_once"); + + sync_span.in_scope(|| { + debug!(?self.inner.extensions, "Sync stream loop is running"); + }); + + match self.sync_once(&stream_id, list_generators.clone()).instrument(sync_span.clone()).await { + Ok(Some(updates)) => { + self.inner.reset_counter.store(0, Ordering::SeqCst); + + yield Ok(updates); + } + + Ok(None) => { + break; + } + + Err(error) => { + if error.client_api_error_kind() == Some(&ErrorKind::UnknownPos) { + // The session has expired. + + // Has it expired too many times? + if self.inner.reset_counter.fetch_add(1, Ordering::SeqCst) >= MAXIMUM_SLIDING_SYNC_SESSION_EXPIRATION { + sync_span.in_scope(|| error!("Session expired {MAXIMUM_SLIDING_SYNC_SESSION_EXPIRATION} times in a row")); + + // The session has expired too many times, let's raise an error! + yield Err(error.into()); + + break; + } + + // Let's reset the Sliding Sync session. + sync_span.in_scope(|| { + warn!("Session expired. Restarting Sliding Sync."); + + // To β€œrestart” a Sliding Sync session, we set `pos` to its initial value. + { + let mut position_lock = self.inner.position.write().unwrap(); + + Observable::set(&mut position_lock.pos, None); + } + + debug!(?self.inner.extensions, "Sliding Sync has been reset"); + }); + } + + yield Err(error.into()); + + continue; + } + } + } + } + } +} + +#[cfg(any(test, feature = "testing"))] +impl SlidingSync { + /// Get a copy of the `pos` value. 
+ pub fn pos(&self) -> Option { + let position_lock = self.inner.position.read().unwrap(); + + position_lock.pos.clone() + } + + /// Set a new value for `pos`. + pub fn set_pos(&self, new_pos: String) { + let mut position_lock = self.inner.position.write().unwrap(); + + Observable::set(&mut position_lock.pos, Some(new_pos)); + } +} + +#[derive(Debug)] +pub(super) struct SlidingSyncPositionMarkers { + pos: Observable>, + delta_token: Observable>, +} + +#[derive(Serialize, Deserialize)] +struct FrozenSlidingSync { + #[serde(skip_serializing_if = "Option::is_none")] + to_device_since: Option, + #[serde(skip_serializing_if = "Option::is_none")] + delta_token: Option, +} + +impl From<&SlidingSync> for FrozenSlidingSync { + fn from(sliding_sync: &SlidingSync) -> Self { + FrozenSlidingSync { + delta_token: sliding_sync.inner.position.read().unwrap().delta_token.clone(), + to_device_since: sliding_sync + .inner + .extensions + .lock() + .unwrap() + .as_ref() + .and_then(|ext| ext.to_device.as_ref()?.since.clone()), + } + } +} + +/// A summary of the updates received after a sync (like in +/// [`SlidingSync::stream`]). +#[derive(Debug, Clone)] +pub struct UpdateSummary { + /// The names of the lists that have seen an update. + pub lists: Vec, + /// The rooms that have seen updates + pub rooms: Vec, +} + +#[cfg(test)] +mod test { + use assert_matches::assert_matches; + use ruma::room_id; + use serde_json::json; + use wiremock::MockServer; + + use super::*; + use crate::test_utils::logged_in_client; + + #[tokio::test] + async fn check_find_room_in_list() -> Result<()> { + let list = + SlidingSyncList::builder().name("testlist").add_range(0u32, 9u32).build().unwrap(); + let full_window_update: v4::SyncOp = serde_json::from_value(json! 
({ + "op": "SYNC", + "range": [0, 9], + "room_ids": [ + "!A00000:matrix.example", + "!A00001:matrix.example", + "!A00002:matrix.example", + "!A00003:matrix.example", + "!A00004:matrix.example", + "!A00005:matrix.example", + "!A00006:matrix.example", + "!A00007:matrix.example", + "!A00008:matrix.example", + "!A00009:matrix.example" + ], + })) + .unwrap(); + + list.handle_response(10u32, &vec![full_window_update], &vec![(0, 9)], &vec![]).unwrap(); + + let a02 = room_id!("!A00002:matrix.example").to_owned(); + let a05 = room_id!("!A00005:matrix.example").to_owned(); + let a09 = room_id!("!A00009:matrix.example").to_owned(); + + assert_eq!(list.find_room_in_list(&a02), Some(2)); + assert_eq!(list.find_room_in_list(&a05), Some(5)); + assert_eq!(list.find_room_in_list(&a09), Some(9)); + + assert_eq!( + list.find_rooms_in_list(&[a02.clone(), a05.clone(), a09.clone()]), + vec![(2, a02.clone()), (5, a05.clone()), (9, a09.clone())] + ); + + // we invalidate a few in the center + let update: v4::SyncOp = serde_json::from_value(json! ({ + "op": "INVALIDATE", + "range": [4, 7], + })) + .unwrap(); + + list.handle_response(10u32, &vec![update], &vec![(0, 3), (8, 9)], &vec![]).unwrap(); + + assert_eq!(list.find_room_in_list(room_id!("!A00002:matrix.example")), Some(2)); + assert_eq!(list.find_room_in_list(room_id!("!A00005:matrix.example")), None); + assert_eq!(list.find_room_in_list(room_id!("!A00009:matrix.example")), Some(9)); + + assert_eq!( + list.find_rooms_in_list(&[a02.clone(), a05, a09.clone()]), + vec![(2, a02), (9, a09)] + ); + + Ok(()) + } + + #[tokio::test] + async fn to_device_is_enabled_when_pos_is_none() -> Result<()> { + let server = MockServer::start().await; + let client = logged_in_client(Some(server.uri())).await; + + let sync = client.sliding_sync().await.build().await?; + let extensions = sync.prepare_extension_config(None); + + // If the user doesn't provide any extension config, we enable to-device and + // e2ee anyways. 
+ assert_matches!( + extensions.to_device, + Some(ToDeviceConfig { enabled: Some(true), since: None, .. }) + ); + assert_matches!(extensions.e2ee, Some(E2EEConfig { enabled: Some(true), .. })); + + let some_since = "some_since".to_owned(); + sync.update_to_device_since(some_since.to_owned()); + let extensions = sync.prepare_extension_config(Some("foo")); + + // If there's a `pos` and to-device `since` token, we make sure we put the token + // into the extension config. The rest doesn't need to be re-enabled due to + // stickyness. + assert_matches!( + extensions.to_device, + Some(ToDeviceConfig { enabled: None, since: Some(since), .. }) if since == some_since + ); + assert_matches!(extensions.e2ee, None); + + let extensions = sync.prepare_extension_config(None); + // Even if there isn't a `pos`, if we have a to-device `since` token, we put it + // into the request. + assert_matches!( + extensions.to_device, + Some(ToDeviceConfig { enabled: Some(true), since: Some(since), .. }) if since == some_since + ); + + Ok(()) + } +} diff --git a/crates/matrix-sdk/src/sliding_sync/room.rs b/crates/matrix-sdk/src/sliding_sync/room.rs new file mode 100644 index 00000000000..48d270ba234 --- /dev/null +++ b/crates/matrix-sdk/src/sliding_sync/room.rs @@ -0,0 +1,391 @@ +use std::{ + fmt::Debug, + ops::Not, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, RwLock as StdRwLock, + }, +}; + +use eyeball::unique::Observable; +use eyeball_im::ObservableVector; +use im::Vector; +use matrix_sdk_base::deserialized_responses::SyncTimelineEvent; +use ruma::{ + api::client::sync::sync_events::{v4, UnreadNotificationsCount}, + events::AnySyncStateEvent, + serde::Raw, + OwnedRoomId, +}; +use serde::{Deserialize, Serialize}; +use tracing::{error, instrument}; + +use crate::{ + room::timeline::{EventTimelineItem, Timeline, TimelineBuilder}, + Client, +}; + +/// Room details, provided by a [`SlidingSync`] instance. 
+#[derive(Debug, Clone)] +pub struct SlidingSyncRoom { + client: Client, + room_id: OwnedRoomId, + inner: v4::SlidingSyncRoom, + is_loading_more: Arc>>, + is_cold: Arc, + prev_batch: Arc>>>, + timeline_queue: Arc>>, +} + +impl SlidingSyncRoom { + pub(super) fn new( + client: Client, + room_id: OwnedRoomId, + inner: v4::SlidingSyncRoom, + timeline: Vec, + ) -> Self { + let mut timeline_queue = ObservableVector::new(); + timeline_queue.append(timeline.into_iter().collect()); + + Self { + client, + room_id, + is_loading_more: Arc::new(StdRwLock::new(Observable::new(false))), + is_cold: Arc::new(AtomicBool::new(false)), + prev_batch: Arc::new(StdRwLock::new(Observable::new(inner.prev_batch.clone()))), + timeline_queue: Arc::new(StdRwLock::new(timeline_queue)), + inner, + } + } + + /// RoomId of this SlidingSyncRoom + pub fn room_id(&self) -> &OwnedRoomId { + &self.room_id + } + + /// Are we currently fetching more timeline events in this room? + pub fn is_loading_more(&self) -> bool { + **self.is_loading_more.read().unwrap() + } + + /// The `prev_batch` key to fetch more timeline events for this room. + pub fn prev_batch(&self) -> Option { + self.prev_batch.read().unwrap().clone() + } + + /// `Timeline` of this room + pub async fn timeline(&self) -> Option { + Some(self.timeline_builder()?.track_read_marker_and_receipts().build().await) + } + + fn timeline_builder(&self) -> Option { + if let Some(room) = self.client.get_room(&self.room_id) { + Some(Timeline::builder(&room).events( + self.prev_batch.read().unwrap().clone(), + self.timeline_queue.read().unwrap().clone(), + )) + } else if let Some(invited_room) = self.client.get_invited_room(&self.room_id) { + Some(Timeline::builder(&invited_room).events(None, Vector::new())) + } else { + error!( + room_id = ?self.room_id, + "Room not found in client. Can't provide a timeline for it" + ); + + None + } + } + + /// The latest timeline item of this room. 
+ /// + /// Use `Timeline::latest_event` instead if you already have a timeline for + /// this `SlidingSyncRoom`. + #[instrument(skip_all, parent = &self.client.inner.root_span)] + pub async fn latest_event(&self) -> Option { + self.timeline_builder()?.build().await.latest_event().await + } + + /// This rooms name as calculated by the server, if any + pub fn name(&self) -> Option<&str> { + self.inner.name.as_deref() + } + + /// Is this a direct message? + pub fn is_dm(&self) -> Option { + self.inner.is_dm + } + + /// Was this an initial response. + pub fn is_initial_response(&self) -> Option { + self.inner.initial + } + + /// Is there any unread notifications? + pub fn has_unread_notifications(&self) -> bool { + self.inner.unread_notifications.is_empty().not() + } + + /// Get unread notifications. + pub fn unread_notifications(&self) -> &UnreadNotificationsCount { + &self.inner.unread_notifications + } + + /// Get the required state. + pub fn required_state(&self) -> &Vec> { + &self.inner.required_state + } + + pub(super) fn update( + &mut self, + room_data: v4::SlidingSyncRoom, + timeline_updates: Vec, + ) { + let v4::SlidingSyncRoom { + name, + initial, + limited, + is_dm, + invite_state, + unread_notifications, + required_state, + prev_batch, + .. + } = room_data; + + self.inner.unread_notifications = unread_notifications; + + // The server might not send some parts of the response, because they were sent + // before and the server wants to save bandwidth. So let's update the values + // only when they exist. 
+ + if name.is_some() { + self.inner.name = name; + } + + if initial.is_some() { + self.inner.initial = initial; + } + + if is_dm.is_some() { + self.inner.is_dm = is_dm; + } + + if !invite_state.is_empty() { + self.inner.invite_state = invite_state; + } + + if !required_state.is_empty() { + self.inner.required_state = required_state; + } + + if prev_batch.is_some() { + Observable::set(&mut self.prev_batch.write().unwrap(), prev_batch); + } + + // There is timeline updates. + if !timeline_updates.is_empty() { + if self.is_cold.load(Ordering::SeqCst) { + // If we come from a cold storage, we overwrite the timeline queue with the + // timeline updates. + + let mut lock = self.timeline_queue.write().unwrap(); + lock.clear(); + for event in timeline_updates { + lock.push_back(event); + } + + self.is_cold.store(false, Ordering::SeqCst); + } else if limited { + // The server alerted us that we missed items in between. + + let mut lock = self.timeline_queue.write().unwrap(); + lock.clear(); + for event in timeline_updates { + lock.push_back(event); + } + } else { + // It's the hot path. We have new updates that must be added to the existing + // timeline queue. + + let mut timeline_queue = self.timeline_queue.write().unwrap(); + + // If the `timeline_queue` contains: + // [D, E, F] + // and if the `timeline_updates` contains: + // [A, B, C, D, E, F] + // the resulting `timeline_queue` must be: + // [A, B, C, D, E, F] + // + // To do that, we find the longest suffix between `timeline_queue` and + // `timeline_updates`, in this case: + // [D, E, F] + // Remove the suffix from `timeline_updates`, we get `[A, B, C]` that is + // prepended to `timeline_queue`. + // + // If the `timeline_queue` contains: + // [A, B, C, D, E, F] + // and if the `timeline_updates` contains: + // [D, E, F] + // the resulting `timeline_queue` must be: + // [A, B, C, D, E, F] + // + // To do that, we continue with the longest suffix. 
In this case, it is: + // [D, E, F] + // Remove the suffix from `timeline_updates`, we get `[]`. It's empty, we don't + // touch at `timeline_queue`. + + { + let timeline_queue_len = timeline_queue.len(); + let timeline_updates_len = timeline_updates.len(); + + let position = match timeline_queue + .iter() + .rev() + .zip(timeline_updates.iter().rev()) + .position(|(queue, update)| queue.event_id() != update.event_id()) + { + // We have found a suffix that equals the size of `timeline_queue` or + // `timeline_update`, typically: + // timeline_queue = [D, E, F] + // timeline_update = [A, B, C, D, E, F] + // or + // timeline_queue = [A, B, C, D, E, F] + // timeline_update = [D, E, F] + // in both case, `position` will return `None` because we are looking for + // (from the end) an item that is different. + None => std::cmp::min(timeline_queue_len, timeline_updates_len), + + // We may have found a suffix. + // + // If we have `Some(0)`, it means we don't have found a suffix. That's the + // hot path, `timeline_updates` will just be appended to `timeline_queue`. + // + // If we have `Some(n)` with `n > 0`, it means we have a prefix but it + // doesn't cover all `timeline_queue` or `timeline_update`, typically: + // timeline_queue = [B, D, E, F] + // timeline_update = [A, B, C, D, E, F] + // in this case, `position` will return `Some(3)`. + // That's annoying because it means we have an invalid `timeline_queue` or + // `timeline_update`, but let's try to do our best. + Some(position) => position, + }; + + if position == 0 { + // No prefix found. + + for event in timeline_updates { + timeline_queue.push_back(event); + } + } else { + // Prefix found. + + let new_timeline_updates = + &timeline_updates[..timeline_updates_len - position]; + + if !new_timeline_updates.is_empty() { + for (at, update) in new_timeline_updates.iter().cloned().enumerate() { + timeline_queue.insert(at, update); + } + } + } + } + } + } else if limited { + // The timeline updates are empty. 
But `limited` is set to true. It's a way to + // alert that we are stale. In this case, we should just clear the + // existing timeline. + + self.timeline_queue.write().unwrap().clear(); + } + } + + pub(super) fn from_frozen(frozen_room: FrozenSlidingSyncRoom, client: Client) -> Self { + let FrozenSlidingSyncRoom { room_id, inner, prev_batch, timeline_queue } = frozen_room; + + let mut timeline_queue_ob = ObservableVector::new(); + timeline_queue_ob.append(timeline_queue); + + Self { + client, + room_id, + inner, + is_loading_more: Arc::new(StdRwLock::new(Observable::new(false))), + is_cold: Arc::new(AtomicBool::new(true)), + prev_batch: Arc::new(StdRwLock::new(Observable::new(prev_batch))), + timeline_queue: Arc::new(StdRwLock::new(timeline_queue_ob)), + } + } +} + +/// A β€œfrozen” [`SlidingSyncRoom`], i.e. that can be written into, or read from +/// a store. +#[derive(Serialize, Deserialize)] +pub(super) struct FrozenSlidingSyncRoom { + room_id: OwnedRoomId, + inner: v4::SlidingSyncRoom, + prev_batch: Option, + #[serde(rename = "timeline")] + timeline_queue: Vector, +} + +impl From<&SlidingSyncRoom> for FrozenSlidingSyncRoom { + fn from(value: &SlidingSyncRoom) -> Self { + let timeline = value.timeline_queue.read().unwrap(); + let timeline_length = timeline.len(); + + // To not overflow the database, we only freeze the newest 10 items. On doing + // so, we must drop the `prev_batch` key however, as we'd otherwise + // create a gap between what we have loaded and where the + // prev_batch-key will start loading when paginating backwards. 
+ let (prev_batch, timeline) = if timeline_length > 10 { + let pos = timeline_length - 10; + (None, timeline.iter().skip(pos).cloned().collect()) + } else { + (value.prev_batch.read().unwrap().clone(), timeline.clone()) + }; + + Self { + prev_batch, + timeline_queue: timeline, + room_id: value.room_id.clone(), + inner: value.inner.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use im::vector; + use matrix_sdk_base::deserialized_responses::TimelineEvent; + use ruma::{events::room::message::RoomMessageEventContent, RoomId}; + use serde_json::json; + + use super::*; + + #[test] + fn test_frozen_sliding_sync_room_serialize() { + let frozen_sliding_sync_room = FrozenSlidingSyncRoom { + room_id: <&RoomId>::try_from("!29fhd83h92h0:example.com").unwrap().to_owned(), + inner: v4::SlidingSyncRoom::default(), + prev_batch: Some("let it go!".to_owned()), + timeline_queue: vector![TimelineEvent::new( + Raw::new(&json! ({ + "content": RoomMessageEventContent::text_plain("let it gooo!"), + "type": "m.room.message", + "event_id": "$xxxxx:example.org", + "room_id": "!someroom:example.com", + "origin_server_ts": 2189, + "sender": "@bob:example.com", + })) + .unwrap() + .cast() + ) + .into()], + }; + + assert_eq!( + serde_json::to_string(&frozen_sliding_sync_room).unwrap(), + "{\"room_id\":\"!29fhd83h92h0:example.com\",\"inner\":{},\"prev_batch\":\"let it go!\",\"timeline\":[{\"event\":{\"content\":{\"body\":\"let it gooo!\",\"msgtype\":\"m.text\"},\"event_id\":\"$xxxxx:example.org\",\"origin_server_ts\":2189,\"room_id\":\"!someroom:example.com\",\"sender\":\"@bob:example.com\",\"type\":\"m.room.message\"},\"encryption_info\":null}]}", + ); + } +} diff --git a/crates/matrix-sdk/src/sync.rs b/crates/matrix-sdk/src/sync.rs index b92160da153..67a99136f22 100644 --- a/crates/matrix-sdk/src/sync.rs +++ b/crates/matrix-sdk/src/sync.rs @@ -184,7 +184,7 @@ impl Client { async fn sleep() { #[cfg(target_arch = "wasm32")] - let _ = wasm_timer::Delay::new(Duration::from_secs(1)).await; + 
gloo_timers::future::TimeoutFuture::new(1_000).await; #[cfg(not(target_arch = "wasm32"))] tokio::time::sleep(Duration::from_secs(1)).await; diff --git a/crates/matrix-sdk/src/test_utils.rs b/crates/matrix-sdk/src/test_utils.rs index 30ebe50fea6..53efd0d349f 100644 --- a/crates/matrix-sdk/src/test_utils.rs +++ b/crates/matrix-sdk/src/test_utils.rs @@ -4,8 +4,6 @@ use matrix_sdk_base::Session; use ruma::{api::MatrixVersion, device_id, user_id}; -#[cfg(feature = "experimental-sliding-sync")] -use crate::sliding_sync::SlidingSync; use crate::{config::RequestConfig, Client, ClientBuilder}; pub(crate) fn test_client_builder(homeserver_url: Option) -> ClientBuilder { @@ -33,9 +31,3 @@ pub(crate) async fn logged_in_client(homeserver_url: Option) -> Client { client } - -/// Force a specific pos-value to be used for the given sliding-sync instance. -#[cfg(feature = "experimental-sliding-sync")] -pub fn force_sliding_sync_pos(sliding_sync: &SlidingSync, new_pos: String) { - sliding_sync.pos.set(Some(new_pos)); -} diff --git a/crates/matrix-sdk/tests/integration/main.rs b/crates/matrix-sdk/tests/integration/main.rs index 4be34961aad..2619a4c0a97 100644 --- a/crates/matrix-sdk/tests/integration/main.rs +++ b/crates/matrix-sdk/tests/integration/main.rs @@ -1,7 +1,10 @@ // The http mocking library is not supported for wasm32 #![cfg(not(target_arch = "wasm32"))] -use matrix_sdk::{config::RequestConfig, Client, ClientBuilder, Session}; +use matrix_sdk::{ + config::{RequestConfig, SyncSettings}, + Client, ClientBuilder, Session, +}; use matrix_sdk_test::test_json; use ruma::{api::MatrixVersion, device_id, user_id}; use serde::Serialize; @@ -51,6 +54,17 @@ async fn logged_in_client() -> (Client, MockServer) { (client, server) } +async fn synced_client() -> (Client, MockServer) { + let (client, server) = logged_in_client().await; + mock_sync(&server, &*test_json::SYNC, None).await; + + let sync_settings = SyncSettings::new(); + + let _response = 
client.sync_once(sync_settings).await.unwrap(); + + (client, server) +} + /// Mount a Mock on the given server to handle the `GET /sync` endpoint with /// an optional `since` param that returns a 200 status code with the given /// response body. diff --git a/crates/matrix-sdk/tests/integration/refresh_token.rs b/crates/matrix-sdk/tests/integration/refresh_token.rs index 5b1e6263f2b..a83a29bc75a 100644 --- a/crates/matrix-sdk/tests/integration/refresh_token.rs +++ b/crates/matrix-sdk/tests/integration/refresh_token.rs @@ -1,11 +1,7 @@ use std::time::Duration; use assert_matches::assert_matches; -use futures::{ - channel::{mpsc, oneshot}, - StreamExt, -}; -use futures_signals::signal::SignalExt; +use futures::{channel::mpsc, StreamExt}; use matrix_sdk::{config::RequestConfig, executor::spawn, HttpError, RefreshTokenError, Session}; use matrix_sdk_test::{async_test, test_json}; use ruma::{ @@ -248,27 +244,16 @@ async fn refresh_token_handled_success() { }; client.restore_session(session).await.unwrap(); - let mut tokens_stream = client.session_tokens_signal().to_stream(); - let (tokens_sender, tokens_receiver) = oneshot::channel::<()>(); - spawn(async move { - let tokens = tokens_stream.next().await.flatten().unwrap(); - assert_eq!(tokens.access_token, "1234"); - assert_eq!(tokens.refresh_token.as_deref(), Some("abcd")); - + let mut tokens_stream = client.session_tokens_stream(); + let tokens_join_handle = spawn(async move { let tokens = tokens_stream.next().await.flatten().unwrap(); assert_eq!(tokens.access_token, "5678"); assert_eq!(tokens.refresh_token.as_deref(), Some("abcd")); - - tokens_sender.send(()).unwrap(); }); - let mut tokens_changed_stream = client.session_tokens_changed_signal().to_stream(); - let (changed_sender, changed_receiver) = oneshot::channel::<()>(); - spawn(async move { + let mut tokens_changed_stream = client.session_tokens_changed_stream(); + let changed_join_handle = spawn(async move { tokens_changed_stream.next().await.unwrap(); - 
tokens_changed_stream.next().await.unwrap(); - - changed_sender.send(()).unwrap(); }); Mock::given(method("POST")) @@ -300,8 +285,8 @@ async fn refresh_token_handled_success() { .await; client.whoami().await.unwrap(); - tokens_receiver.await.unwrap(); - changed_receiver.await.unwrap(); + tokens_join_handle.await.unwrap(); + changed_join_handle.await.unwrap(); } #[async_test] diff --git a/crates/matrix-sdk/tests/integration/room/joined.rs b/crates/matrix-sdk/tests/integration/room/joined.rs index f5eaceaef55..adcb62049c9 100644 --- a/crates/matrix-sdk/tests/integration/room/joined.rs +++ b/crates/matrix-sdk/tests/integration/room/joined.rs @@ -1,25 +1,28 @@ use std::time::Duration; +use futures::future::join_all; use matrix_sdk::{ attachment::{ AttachmentConfig, AttachmentInfo, BaseImageInfo, BaseThumbnailInfo, BaseVideoInfo, Thumbnail, }, config::SyncSettings, + room::Receipts, }; use matrix_sdk_test::{async_test, test_json}; use ruma::{ - api::client::membership::Invite3pidInit, assign, event_id, - events::room::message::RoomMessageEventContent, mxc_uri, thirdparty, uint, user_id, - TransactionId, + api::client::{membership::Invite3pidInit, receipt::create_receipt::v3::ReceiptType}, + assign, event_id, + events::{receipt::ReceiptThread, room::message::RoomMessageEventContent}, + mxc_uri, thirdparty, uint, user_id, TransactionId, }; use serde_json::json; use wiremock::{ - matchers::{body_partial_json, header, method, path, path_regex}, + matchers::{body_json, body_partial_json, header, method, path, path_regex}, Mock, ResponseTemplate, }; -use crate::{logged_in_client, mock_encryption_state, mock_sync}; +use crate::{logged_in_client, mock_encryption_state, mock_sync, synced_client}; #[async_test] async fn invite_user_by_id() { @@ -145,7 +148,7 @@ async fn kick_user() { } #[async_test] -async fn read_receipt() { +async fn send_single_receipt() { let (client, server) = logged_in_client().await; Mock::given(method("POST")) @@ -161,14 +164,14 @@ async fn read_receipt() 
{ let _response = client.sync_once(sync_settings).await.unwrap(); - let event_id = event_id!("$xxxxxx:example.org"); + let event_id = event_id!("$xxxxxx:example.org").to_owned(); let room = client.get_joined_room(&test_json::DEFAULT_SYNC_ROOM_ID).unwrap(); - room.read_receipt(event_id).await.unwrap(); + room.send_single_receipt(ReceiptType::Read, ReceiptThread::Unthreaded, event_id).await.unwrap(); } #[async_test] -async fn read_marker() { +async fn send_multiple_receipts() { let (client, server) = logged_in_client().await; Mock::given(method("POST")) @@ -184,10 +187,11 @@ async fn read_marker() { let _response = client.sync_once(sync_settings).await.unwrap(); - let event_id = event_id!("$xxxxxx:example.org"); + let event_id = event_id!("$xxxxxx:example.org").to_owned(); let room = client.get_joined_room(&test_json::DEFAULT_SYNC_ROOM_ID).unwrap(); - room.read_marker(event_id, None).await.unwrap(); + let receipts = Receipts::new().fully_read_marker(event_id); + room.send_multiple_receipts(receipts).await.unwrap(); } #[async_test] @@ -493,7 +497,7 @@ async fn room_attachment_send_info_thumbnail() { #[async_test] async fn room_redact() { - let (client, server) = logged_in_client().await; + let (client, server) = synced_client().await; Mock::given(method("PUT")) .and(path_regex(r"^/_matrix/client/r0/rooms/.*/redact/.*?/.*?")) @@ -502,12 +506,6 @@ async fn room_redact() { .mount(&server) .await; - mock_sync(&server, &*test_json::SYNC, None).await; - - let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); - - let _response = client.sync_once(sync_settings).await.unwrap(); - let room = client.get_joined_room(&test_json::DEFAULT_SYNC_ROOM_ID).unwrap(); let event_id = event_id!("$xxxxxxxx:example.com"); @@ -518,3 +516,72 @@ async fn room_redact() { assert_eq!(event_id!("$h29iv0s8:example.com"), response.event_id) } + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn 
fetch_members_deduplication() { + let (client, server) = synced_client().await; + + // We don't need any members, we're just checking if we're correctly + // deduplicating calls to the method. + let response_body = json!({ + "chunk": [], + }); + + Mock::given(method("GET")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/members")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(response_body)) + // Expect that we're only going to send the request out once. + .expect(1..=1) + .mount(&server) + .await; + + let room = client.get_joined_room(&test_json::DEFAULT_SYNC_ROOM_ID).unwrap(); + + let mut tasks = Vec::new(); + + // Create N tasks that try to fetch the members. + for _ in 0..5 { + #[allow(unknown_lints, clippy::redundant_async_block)] // false positive + let task = tokio::spawn({ + let room = room.clone(); + async move { room.sync_members().await } + }); + + tasks.push(task); + } + + // Wait on all of them at once. + let results = join_all(tasks).await; + + // See how many of them sent a request and thus have a response. 
+ let response_count = + results.iter().filter(|r| r.as_ref().unwrap().as_ref().unwrap().is_some()).count(); + assert_eq!(response_count, 1); +} + +#[async_test] +async fn set_name() { + let (client, server) = synced_client().await; + + mock_sync(&server, &*test_json::SYNC, None).await; + let sync_settings = SyncSettings::new(); + client.sync_once(sync_settings).await.unwrap(); + + let room = client.get_joined_room(&test_json::DEFAULT_SYNC_ROOM_ID).unwrap(); + let name = "The room name"; + + Mock::given(method("PUT")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/state/m.room.name/$")) + .and(header("authorization", "Bearer 1234")) + .and(body_json(json!({ + "name": name, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::EVENT_ID)) + .expect(1) + .mount(&server) + .await; + + room.set_name(Some(name.to_owned())).await.unwrap(); +} diff --git a/crates/matrix-sdk/tests/integration/room/timeline.rs b/crates/matrix-sdk/tests/integration/room/timeline/mod.rs similarity index 67% rename from crates/matrix-sdk/tests/integration/room/timeline.rs rename to crates/matrix-sdk/tests/integration/room/timeline/mod.rs index 8aab02552cb..f43107f47c2 100644 --- a/crates/matrix-sdk/tests/integration/room/timeline.rs +++ b/crates/matrix-sdk/tests/integration/room/timeline/mod.rs @@ -3,13 +3,13 @@ use std::{sync::Arc, time::Duration}; use assert_matches::assert_matches; -use futures_signals::signal_vec::{SignalVecExt, VecDiff}; +use eyeball_im::VectorDiff; use futures_util::StreamExt; use matrix_sdk::{ config::SyncSettings, room::timeline::{ AnyOtherFullStateEventContent, Error as TimelineError, EventSendState, PaginationOptions, - TimelineDetails, TimelineItemContent, VirtualTimelineItem, + TimelineDetails, TimelineItem, TimelineItemContent, VirtualTimelineItem, }, ruma::MilliSecondsSinceUnixEpoch, Error, @@ -17,7 +17,7 @@ use matrix_sdk::{ use matrix_sdk_common::executor::spawn; use matrix_sdk_test::{ async_test, test_json, EventBuilder, 
JoinedRoomBuilder, RoomAccountDataTestEvent, - TimelineTestEvent, + StateTestEvent, TimelineTestEvent, }; use ruma::{ event_id, @@ -33,6 +33,8 @@ use wiremock::{ Mock, ResponseTemplate, }; +mod read_receipts; + use crate::{logged_in_client, mock_encryption_state, mock_sync}; #[async_test] @@ -50,7 +52,7 @@ async fn edit() { let room = client.get_room(room_id).unwrap(); let timeline = room.timeline().await; - let mut timeline_stream = timeline.signal().to_stream(); + let (_, mut timeline_stream) = timeline.subscribe().await; ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event( TimelineTestEvent::Custom(json!({ @@ -69,10 +71,14 @@ async fn edit() { let _response = client.sync_once(sync_settings.clone()).await.unwrap(); server.reset().await; - let _day_divider = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); - let first = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let _day_divider = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); + let first = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); let msg = assert_matches!( first.as_event().unwrap().content(), TimelineItemContent::Message(msg) => msg @@ -119,8 +125,7 @@ async fn edit() { let _response = client.sync_once(sync_settings.clone()).await.unwrap(); server.reset().await; - let second = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let second = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); let item = second.as_event().unwrap(); assert_eq!(item.timestamp(), MilliSecondsSinceUnixEpoch(uint!(152038280))); assert!(item.event_id().is_some()); @@ -134,7 +139,7 @@ async fn edit() { let edit = assert_matches!( timeline_stream.next().await, - Some(VecDiff::UpdateAt { index: 1, value }) => value + 
Some(VectorDiff::Set { index: 1, value }) => value ); let edited = assert_matches!( edit.as_event().unwrap().content(), @@ -161,7 +166,7 @@ async fn echo() { let room = client.get_room(room_id).unwrap(); let timeline = Arc::new(room.timeline().await); - let mut timeline_stream = timeline.signal().to_stream(); + let (_, mut timeline_stream) = timeline.subscribe().await; let event_id = event_id!("$wWgymRfo7ri1uQx0NXO40vLJ"); let txn_id: &TransactionId = "my-txn-id".into(); @@ -177,20 +182,19 @@ async fn echo() { // Don't move the original timeline, it must live until the end of the test let timeline = timeline.clone(); + #[allow(unknown_lints, clippy::redundant_async_block)] // false positive let send_hdl = spawn(async move { timeline .send(RoomMessageEventContent::text_plain("Hello, World!").into(), Some(txn_id)) .await }); - let _day_divider = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); - let local_echo = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let _day_divider = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let local_echo = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); let item = local_echo.as_event().unwrap().as_local().unwrap(); - assert_matches!(&item.send_state, EventSendState::NotSentYet); + assert_matches!(item.send_state(), EventSendState::NotSentYet); - let msg = assert_matches!(&item.content, TimelineItemContent::Message(msg) => msg); + let msg = assert_matches!(item.content(), TimelineItemContent::Message(msg) => msg); let text = assert_matches!(msg.msgtype(), MessageType::Text(text) => text); assert_eq!(text.body, "Hello, World!"); @@ -199,10 +203,10 @@ async fn echo() { let sent_confirmation = assert_matches!( timeline_stream.next().await, - Some(VecDiff::UpdateAt { index: 1, value }) => value + Some(VectorDiff::Set { index: 1, value }) => value ); let item = 
sent_confirmation.as_event().unwrap().as_local().unwrap(); - assert_matches!(&item.send_state, EventSendState::Sent { .. }); + assert_matches!(item.send_state(), EventSendState::Sent { .. }); ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event( TimelineTestEvent::Custom(json!({ @@ -223,17 +227,25 @@ async fn echo() { server.reset().await; // Local echo is removed - assert_matches!(timeline_stream.next().await, Some(VecDiff::Pop { .. })); - // Bug, will be fixed later. See comment in remote_echo_without_txn_id test - // from `room::timeline::tests`. - let _day_divider = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); - - let remote_echo = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + assert_matches!(timeline_stream.next().await, Some(VectorDiff::Remove { index: 1 })); + // Local echo day divider is removed + assert_matches!(timeline_stream.next().await, Some(VectorDiff::Remove { index: 0 })); + + // New day divider is added + let new_item = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); + assert_matches!(&*new_item, TimelineItem::Virtual(VirtualTimelineItem::DayDivider(_))); + + // Remote echo is added + let remote_echo = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); let item = remote_echo.as_event().unwrap().as_remote().unwrap(); - assert!(item.is_own); - assert_eq!(item.timestamp, MilliSecondsSinceUnixEpoch(uint!(152038280))); + assert!(item.is_own()); + assert_eq!(item.timestamp(), MilliSecondsSinceUnixEpoch(uint!(152038280))); } #[async_test] @@ -251,7 +263,7 @@ async fn back_pagination() { let room = client.get_room(room_id).unwrap(); let timeline = Arc::new(room.timeline().await); - let mut timeline_stream = timeline.signal().to_stream(); + let (_, mut timeline_stream) = timeline.subscribe().await; Mock::given(method("GET")) 
.and(path_regex(r"^/_matrix/client/r0/rooms/.*/messages$")) @@ -267,19 +279,19 @@ async fn back_pagination() { let loading = assert_matches!( timeline_stream.next().await, - Some(VecDiff::Push { value }) => value + Some(VectorDiff::PushFront { value }) => value ); assert_matches!(loading.as_virtual().unwrap(), VirtualTimelineItem::LoadingIndicator); let day_divider = assert_matches!( timeline_stream.next().await, - Some(VecDiff::Push { value }) => value + Some(VectorDiff::Insert { index: 1, value }) => value ); assert_matches!(day_divider.as_virtual().unwrap(), VirtualTimelineItem::DayDivider(_)); let message = assert_matches!( timeline_stream.next().await, - Some(VecDiff::Push { value }) => value + Some(VectorDiff::Insert { index: 2, value }) => value ); let msg = assert_matches!( message.as_event().unwrap().content(), @@ -290,7 +302,7 @@ async fn back_pagination() { let message = assert_matches!( timeline_stream.next().await, - Some(VecDiff::InsertAt { index: 2, value }) => value + Some(VectorDiff::Insert { index: 2, value }) => value ); let msg = assert_matches!( message.as_event().unwrap().content(), @@ -301,7 +313,7 @@ async fn back_pagination() { let message = assert_matches!( timeline_stream.next().await, - Some(VecDiff::InsertAt { index: 2, value }) => value + Some(VectorDiff::Insert { index: 2, value }) => value ); let state = assert_matches!( message.as_event().unwrap().content(), @@ -318,7 +330,7 @@ async fn back_pagination() { assert_eq!(prev_content.as_ref().unwrap().name.as_ref().unwrap(), "Old room name"); // Removal of the loading indicator - assert_matches!(timeline_stream.next().await, Some(VecDiff::RemoveAt { index: 0 })); + assert_matches!(timeline_stream.next().await, Some(VectorDiff::PopFront)); Mock::given(method("GET")) .and(path_regex(r"^/_matrix/client/r0/rooms/.*/messages$")) @@ -338,13 +350,13 @@ async fn back_pagination() { let loading = assert_matches!( timeline_stream.next().await, - Some(VecDiff::InsertAt { index: 0, value }) => 
value + Some(VectorDiff::PushFront { value }) => value ); assert_matches!(loading.as_virtual().unwrap(), VirtualTimelineItem::LoadingIndicator); let loading = assert_matches!( timeline_stream.next().await, - Some(VecDiff::UpdateAt { index: 0, value }) => value + Some(VectorDiff::Set { index: 0, value }) => value ); assert_matches!(loading.as_virtual().unwrap(), VirtualTimelineItem::TimelineStart); } @@ -364,7 +376,7 @@ async fn reaction() { let room = client.get_room(room_id).unwrap(); let timeline = room.timeline().await; - let mut timeline_stream = timeline.signal().to_stream(); + let (_, mut timeline_stream) = timeline.subscribe().await; ev_builder.add_joined_room( JoinedRoomBuilder::new(room_id) @@ -397,18 +409,22 @@ async fn reaction() { let _response = client.sync_once(sync_settings.clone()).await.unwrap(); server.reset().await; - let _day_divider = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); - let message = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let _day_divider = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); + let message = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); assert_matches!(message.as_event().unwrap().content(), TimelineItemContent::Message(_)); let updated_message = assert_matches!( timeline_stream.next().await, - Some(VecDiff::UpdateAt { index: 1, value }) => value + Some(VectorDiff::Set { index: 1, value }) => value ); let event_item = updated_message.as_event().unwrap().as_remote().unwrap(); - let msg = assert_matches!(&event_item.content, TimelineItemContent::Message(msg) => msg); + let msg = assert_matches!(event_item.content(), TimelineItemContent::Message(msg) => msg); assert!(!msg.is_edited()); assert_eq!(event_item.reactions().len(), 1); let group = &event_item.reactions()["πŸ‘"]; @@ -435,10 +451,10 @@ async fn reaction() { let 
updated_message = assert_matches!( timeline_stream.next().await, - Some(VecDiff::UpdateAt { index: 1, value }) => value + Some(VectorDiff::Set { index: 1, value }) => value ); let event_item = updated_message.as_event().unwrap().as_remote().unwrap(); - let msg = assert_matches!(&event_item.content, TimelineItemContent::Message(msg) => msg); + let msg = assert_matches!(event_item.content(), TimelineItemContent::Message(msg) => msg); assert!(!msg.is_edited()); assert_eq!(event_item.reactions().len(), 0); } @@ -458,7 +474,7 @@ async fn redacted_message() { let room = client.get_room(room_id).unwrap(); let timeline = room.timeline().await; - let mut timeline_stream = timeline.signal().to_stream(); + let (_, mut timeline_stream) = timeline.subscribe().await; ev_builder.add_joined_room( JoinedRoomBuilder::new(room_id) @@ -493,10 +509,14 @@ async fn redacted_message() { let _response = client.sync_once(sync_settings.clone()).await.unwrap(); server.reset().await; - let _day_divider = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); - let first = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let _day_divider = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); + let first = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); assert_matches!(first.as_event().unwrap().content(), TimelineItemContent::RedactedMessage); // TODO: After adding raw timeline items, check for one here @@ -517,7 +537,7 @@ async fn read_marker() { let room = client.get_room(room_id).unwrap(); let timeline = room.timeline().await; - let mut timeline_stream = timeline.signal().to_stream(); + let (_, mut timeline_stream) = timeline.subscribe().await; ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event( TimelineTestEvent::Custom(json!({ @@ -536,10 +556,8 @@ async fn read_marker() { let _response = 
client.sync_once(sync_settings.clone()).await.unwrap(); server.reset().await; - let _day_divider = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); - let message = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let _day_divider = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let message = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); assert_matches!(message.as_event().unwrap().content(), TimelineItemContent::Message(_)); ev_builder.add_joined_room( @@ -550,8 +568,10 @@ async fn read_marker() { let _response = client.sync_once(sync_settings.clone()).await.unwrap(); server.reset().await; - let marker = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let marker = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::Insert { index: 2, value }) => value + ); assert_matches!(marker.as_virtual().unwrap(), VirtualTimelineItem::ReadMarker); } @@ -570,7 +590,7 @@ async fn in_reply_to_details() { let room = client.get_room(room_id).unwrap(); let timeline = room.timeline().await; - let mut timeline_stream = timeline.signal().to_stream(); + let (_, mut timeline_stream) = timeline.subscribe().await; // The event doesn't exist. 
assert_matches!( @@ -611,24 +631,21 @@ async fn in_reply_to_details() { let _response = client.sync_once(sync_settings.clone()).await.unwrap(); server.reset().await; - let _day_divider = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); - let first = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let _day_divider = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let first = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); assert_matches!(first.as_event().unwrap().content(), TimelineItemContent::Message(_)); - let second = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let second = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); let second_event = second.as_event().unwrap().as_remote().unwrap(); let message = - assert_matches!(&second_event.content, TimelineItemContent::Message(message) => message); + assert_matches!(second_event.content(), TimelineItemContent::Message(message) => message); let in_reply_to = message.in_reply_to().unwrap(); assert_eq!(in_reply_to.event_id, event_id!("$event1")); assert_matches!(in_reply_to.details, TimelineDetails::Unavailable); // Fetch details locally first. 
- timeline.fetch_event_details(&second_event.event_id).await.unwrap(); + timeline.fetch_event_details(second_event.event_id()).await.unwrap(); - let second = assert_matches!(timeline_stream.next().await, Some(VecDiff::UpdateAt { index: 2, value }) => value); + let second = assert_matches!(timeline_stream.next().await, Some(VectorDiff::Set { index: 2, value }) => value); let message = assert_matches!(second.as_event().unwrap().content(), TimelineItemContent::Message(message) => message); assert_matches!(message.in_reply_to().unwrap().details, TimelineDetails::Ready(_)); @@ -654,11 +671,13 @@ async fn in_reply_to_details() { let _response = client.sync_once(sync_settings.clone()).await.unwrap(); server.reset().await; - let third = - assert_matches!(timeline_stream.next().await, Some(VecDiff::Push { value }) => value); + let _read_receipt_update = + assert_matches!(timeline_stream.next().await, Some(VectorDiff::Set { value, .. }) => value); + + let third = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); let third_event = third.as_event().unwrap().as_remote().unwrap(); let message = - assert_matches!(&third_event.content, TimelineItemContent::Message(message) => message); + assert_matches!(third_event.content(), TimelineItemContent::Message(message) => message); let in_reply_to = message.in_reply_to().unwrap(); assert_eq!(in_reply_to.event_id, event_id!("$remoteevent")); assert_matches!(in_reply_to.details, TimelineDetails::Unavailable); @@ -675,14 +694,14 @@ async fn in_reply_to_details() { .await; // Fetch details remotely if we can't find them locally. 
- timeline.fetch_event_details(&third_event.event_id).await.unwrap(); + timeline.fetch_event_details(third_event.event_id()).await.unwrap(); server.reset().await; - let third = assert_matches!(timeline_stream.next().await, Some(VecDiff::UpdateAt { index: 3, value }) => value); + let third = assert_matches!(timeline_stream.next().await, Some(VectorDiff::Set { index: 3, value }) => value); let message = assert_matches!(third.as_event().unwrap().content(), TimelineItemContent::Message(message) => message); assert_matches!(message.in_reply_to().unwrap().details, TimelineDetails::Pending); - let third = assert_matches!(timeline_stream.next().await, Some(VecDiff::UpdateAt { index: 3, value }) => value); + let third = assert_matches!(timeline_stream.next().await, Some(VectorDiff::Set { index: 3, value }) => value); let message = assert_matches!(third.as_event().unwrap().content(), TimelineItemContent::Message(message) => message); assert_matches!(message.in_reply_to().unwrap().details, TimelineDetails::Error(_)); @@ -704,13 +723,188 @@ async fn in_reply_to_details() { .mount(&server) .await; - timeline.fetch_event_details(&third_event.event_id).await.unwrap(); + timeline.fetch_event_details(third_event.event_id()).await.unwrap(); - let third = assert_matches!(timeline_stream.next().await, Some(VecDiff::UpdateAt { index: 3, value }) => value); + let third = assert_matches!(timeline_stream.next().await, Some(VectorDiff::Set { index: 3, value }) => value); let message = assert_matches!(third.as_event().unwrap().content(), TimelineItemContent::Message(message) => message); assert_matches!(message.in_reply_to().unwrap().details, TimelineDetails::Pending); - let third = assert_matches!(timeline_stream.next().await, Some(VecDiff::UpdateAt { index: 3, value }) => value); + let third = assert_matches!(timeline_stream.next().await, Some(VectorDiff::Set { index: 3, value }) => value); let message = assert_matches!(third.as_event().unwrap().content(), 
TimelineItemContent::Message(message) => message); assert_matches!(message.in_reply_to().unwrap().details, TimelineDetails::Ready(_)); } + +#[async_test] +async fn sync_highlighted() { + let room_id = room_id!("!a98sd12bjh:example.org"); + let (client, server) = logged_in_client().await; + let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); + + let mut ev_builder = EventBuilder::new(); + ev_builder + // We need the member event and power levels locally so the push rules processor works. + .add_joined_room( + JoinedRoomBuilder::new(room_id) + .add_state_event(StateTestEvent::Member) + .add_state_event(StateTestEvent::PowerLevels), + ); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let room = client.get_room(room_id).unwrap(); + let timeline = room.timeline().await; + let (_, mut timeline_stream) = timeline.subscribe().await; + + ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event( + TimelineTestEvent::Custom(json!({ + "content": { + "body": "hello", + "msgtype": "m.text", + }, + "event_id": "$msda7m0df9E9op3", + "origin_server_ts": 152037280, + "sender": "@example:localhost", + "type": "m.room.message", + })), + )); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let _day_divider = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); + let first = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); + let remote_event = first.as_event().unwrap().as_remote().unwrap(); + // Own events don't trigger push rules. 
+ assert!(!remote_event.is_highlighted()); + + ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event( + TimelineTestEvent::Custom(json!({ + "content": { + "body": "This room has been replaced", + "replacement_room": "!newroom:localhost", + }, + "event_id": "$foun39djjod0f", + "origin_server_ts": 152039280, + "sender": "@bob:localhost", + "state_key": "", + "type": "m.room.tombstone", + })), + )); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let second = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushBack { value }) => value + ); + let remote_event = second.as_event().unwrap().as_remote().unwrap(); + // `m.room.tombstone` should be highlighted by default. + assert!(remote_event.is_highlighted()); +} + +#[async_test] +async fn back_pagination_highlighted() { + let room_id = room_id!("!a98sd12bjh:example.org"); + let (client, server) = logged_in_client().await; + let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); + + let mut ev_builder = EventBuilder::new(); + ev_builder + // We need the member event and power levels locally so the push rules processor works. 
+ .add_joined_room( + JoinedRoomBuilder::new(room_id) + .add_state_event(StateTestEvent::Member) + .add_state_event(StateTestEvent::PowerLevels), + ); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let room = client.get_room(room_id).unwrap(); + let timeline = Arc::new(room.timeline().await); + let (_, mut timeline_stream) = timeline.subscribe().await; + + let response_json = json!({ + "chunk": [ + { + "content": { + "body": "hello", + "msgtype": "m.text", + }, + "event_id": "$msda7m0df9E9op3", + "origin_server_ts": 152037280, + "sender": "@example:localhost", + "type": "m.room.message", + "room_id": room_id, + }, + { + "content": { + "body": "This room has been replaced", + "replacement_room": "!newroom:localhost", + }, + "event_id": "$foun39djjod0f", + "origin_server_ts": 152039280, + "sender": "@bob:localhost", + "state_key": "", + "type": "m.room.tombstone", + "room_id": room_id, + }, + ], + "end": "t47409-4357353_219380_26003_2269", + "start": "t392-516_47314_0_7_1_1_1_11444_1" + }); + Mock::given(method("GET")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/messages$")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(response_json)) + .expect(1) + .named("messages_batch_1") + .mount(&server) + .await; + + timeline.paginate_backwards(PaginationOptions::single_request(10)).await.unwrap(); + server.reset().await; + + let loading = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::PushFront { value }) => value + ); + assert_matches!(loading.as_virtual().unwrap(), VirtualTimelineItem::LoadingIndicator); + + let day_divider = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::Insert { index: 1, value }) => value + ); + assert_matches!(day_divider.as_virtual().unwrap(), VirtualTimelineItem::DayDivider(_)); + + let first = assert_matches!( + 
timeline_stream.next().await, + Some(VectorDiff::Insert { index: 2, value }) => value + ); + let remote_event = first.as_event().unwrap().as_remote().unwrap(); + // Own events don't trigger push rules. + assert!(!remote_event.is_highlighted()); + + let second = assert_matches!( + timeline_stream.next().await, + Some(VectorDiff::Insert { index: 2, value }) => value + ); + let remote_event = second.as_event().unwrap().as_remote().unwrap(); + // `m.room.tombstone` should be highlighted by default. + assert!(remote_event.is_highlighted()); + + // Removal of the loading indicator + assert_matches!(timeline_stream.next().await, Some(VectorDiff::PopFront)); +} diff --git a/crates/matrix-sdk/tests/integration/room/timeline/read_receipts.rs b/crates/matrix-sdk/tests/integration/room/timeline/read_receipts.rs new file mode 100644 index 00000000000..9edc03a8ed5 --- /dev/null +++ b/crates/matrix-sdk/tests/integration/room/timeline/read_receipts.rs @@ -0,0 +1,778 @@ +use std::time::Duration; + +use assert_matches::assert_matches; +use eyeball_im::VectorDiff; +use futures_util::StreamExt; +use matrix_sdk::{config::SyncSettings, room::Receipts}; +use matrix_sdk_test::{ + async_test, EphemeralTestEvent, EventBuilder, JoinedRoomBuilder, RoomAccountDataTestEvent, + TimelineTestEvent, +}; +use ruma::{ + api::client::receipt::create_receipt::v3::ReceiptType, event_id, + events::receipt::ReceiptThread, room_id, user_id, +}; +use serde_json::json; +use wiremock::{ + matchers::{body_json, header, method, path_regex}, + Mock, ResponseTemplate, +}; + +use crate::{logged_in_client, mock_sync}; + +#[async_test] +async fn read_receipts_updates() { + let room_id = room_id!("!a98sd12bjh:example.org"); + let (client, server) = logged_in_client().await; + let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); + + let own_user_id = client.user_id().unwrap(); + let alice = user_id!("@alice:localhost"); + let bob = user_id!("@bob:localhost"); + + let second_event_id = 
event_id!("$e32037280er453l:localhost"); + let third_event_id = event_id!("$Sg2037280074GZr34:localhost"); + + let mut ev_builder = EventBuilder::new(); + ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id)); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let room = client.get_room(room_id).unwrap(); + let timeline = room.timeline().await; + let (items, mut timeline_stream) = timeline.subscribe().await; + + assert!(items.is_empty()); + + let own_receipt = timeline.latest_user_read_receipt(own_user_id).await; + assert_matches!(own_receipt, None); + let alice_receipt = timeline.latest_user_read_receipt(alice).await; + assert_matches!(alice_receipt, None); + let bob_receipt = timeline.latest_user_read_receipt(bob).await; + assert_matches!(bob_receipt, None); + + ev_builder.add_joined_room( + JoinedRoomBuilder::new(room_id) + .add_timeline_event(TimelineTestEvent::MessageText) + .add_timeline_event(TimelineTestEvent::Custom(json!({ + "content": { + "body": "I'm dancing too", + "msgtype": "m.text" + }, + "event_id": second_event_id, + "origin_server_ts": 152039280, + "sender": alice, + "type": "m.room.message", + }))) + .add_timeline_event(TimelineTestEvent::Custom(json!({ + "content": { + "body": "Viva la macarena!", + "msgtype": "m.text" + }, + "event_id": third_event_id, + "origin_server_ts": 152045280, + "sender": alice, + "type": "m.room.message", + }))), + ); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let _day_divider = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); + + // We don't list the read receipt of our own user on events. 
+ let first_item = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let first_event = first_item.as_event().unwrap().as_remote().unwrap(); + assert!(first_event.read_receipts().is_empty()); + + let (own_receipt_event_id, _) = timeline.latest_user_read_receipt(own_user_id).await.unwrap(); + assert_eq!(own_receipt_event_id, first_event.event_id()); + + // Implicit read receipt of @alice:localhost. + let second_item = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let second_event = second_item.as_event().unwrap().as_remote().unwrap(); + assert_eq!(second_event.read_receipts().len(), 1); + + // Read receipt of @alice:localhost is moved to third event. + let second_item = assert_matches!(timeline_stream.next().await, Some(VectorDiff::Set { index: 2, value }) => value); + let second_event = second_item.as_event().unwrap().as_remote().unwrap(); + assert!(second_event.read_receipts().is_empty()); + + let third_item = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => value); + let third_event = third_item.as_event().unwrap().as_remote().unwrap(); + assert_eq!(third_event.read_receipts().len(), 1); + + let (alice_receipt_event_id, _) = timeline.latest_user_read_receipt(alice).await.unwrap(); + assert_eq!(alice_receipt_event_id, third_event_id); + + // Read receipt on unknown event is ignored. 
+ ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_ephemeral_event( + EphemeralTestEvent::Custom(json!({ + "content": { + "$unknowneventid": { + "m.read": { + alice: { + "ts": 1436453550, + }, + }, + }, + }, + "type": "m.receipt", + })), + )); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let (alice_receipt_event_id, _) = timeline.latest_user_read_receipt(alice).await.unwrap(); + assert_eq!(alice_receipt_event_id, third_event.event_id()); + + // Read receipt on older event is ignored. + ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_ephemeral_event( + EphemeralTestEvent::Custom(json!({ + "content": { + second_event_id: { + "m.read": { + alice: { + "ts": 1436451550, + }, + }, + }, + }, + "type": "m.receipt", + })), + )); + + let (alice_receipt_event_id, _) = timeline.latest_user_read_receipt(alice).await.unwrap(); + assert_eq!(alice_receipt_event_id, third_event_id); + + // Read receipt on same event is ignored. + ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_ephemeral_event( + EphemeralTestEvent::Custom(json!({ + "content": { + third_event_id: { + "m.read": { + alice: { + "ts": 1436451550, + }, + }, + }, + }, + "type": "m.receipt", + })), + )); + + let (alice_receipt_event_id, _) = timeline.latest_user_read_receipt(alice).await.unwrap(); + assert_eq!(alice_receipt_event_id, third_event_id); + + // New user with explicit read receipt. 
+ ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_ephemeral_event( + EphemeralTestEvent::Custom(json!({ + "content": { + third_event_id: { + "m.read": { + bob: { + "ts": 1436451550, + }, + }, + }, + }, + "type": "m.receipt", + })), + )); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let third_item = assert_matches!(timeline_stream.next().await, Some(VectorDiff::Set { index: 3, value }) => value); + let third_event = third_item.as_event().unwrap().as_remote().unwrap(); + assert_eq!(third_event.read_receipts().len(), 2); + + let (bob_receipt_event_id, _) = timeline.latest_user_read_receipt(bob).await.unwrap(); + assert_eq!(bob_receipt_event_id, third_event_id); + + // Private read receipt is updated. + ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_ephemeral_event( + EphemeralTestEvent::Custom(json!({ + "content": { + second_event_id: { + "m.read.private": { + own_user_id: { + "ts": 1436453550, + }, + }, + }, + }, + "type": "m.receipt", + })), + )); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let (own_user_receipt_event_id, _) = + timeline.latest_user_read_receipt(own_user_id).await.unwrap(); + assert_eq!(own_user_receipt_event_id, second_event_id); +} + +#[async_test] +async fn send_single_receipt() { + let room_id = room_id!("!a98sd12bjh:example.org"); + let (client, server) = logged_in_client().await; + let sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); + + let own_user_id = client.user_id().unwrap(); + + let mut ev_builder = EventBuilder::new(); + ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id)); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = 
client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let room = client.get_room(room_id).unwrap(); + let timeline = room.timeline().await; + + // Unknown receipts are sent. + let first_receipts_event_id = event_id!("$first_receipts_event_id"); + + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/receipt/m\.read/")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .named("Public read receipt") + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/receipt/m\.read\.private/")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .named("Private read receipt") + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/receipt/m\.fully_read/")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .named("Fully-read marker") + .mount(&server) + .await; + + timeline + .send_single_receipt( + ReceiptType::Read, + ReceiptThread::Unthreaded, + first_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::ReadPrivate, + ReceiptThread::Unthreaded, + first_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::FullyRead, + ReceiptThread::Unthreaded, + first_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + server.reset().await; + + // Unchanged receipts are not sent. 
+ ev_builder.add_joined_room( + JoinedRoomBuilder::new(room_id) + .add_ephemeral_event(EphemeralTestEvent::Custom(json!({ + "content": { + first_receipts_event_id: { + "m.read.private": { + own_user_id: { + "ts": 1436453550, + }, + }, + "m.read": { + own_user_id: { + "ts": 1436453550, + }, + }, + }, + }, + "type": "m.receipt", + }))) + .add_account_data(RoomAccountDataTestEvent::Custom(json!({ + "content": { + "event_id": first_receipts_event_id, + }, + "type": "m.fully_read", + }))), + ); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + timeline + .send_single_receipt( + ReceiptType::Read, + ReceiptThread::Unthreaded, + first_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::ReadPrivate, + ReceiptThread::Unthreaded, + first_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::FullyRead, + ReceiptThread::Unthreaded, + first_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + server.reset().await; + + // Receipts with unknown previous receipts are always sent. 
+ let second_receipts_event_id = event_id!("$second_receipts_event_id"); + + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/receipt/m\.read/")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .named("Public read receipt") + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/receipt/m\.read\.private/")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .named("Private read receipt") + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/receipt/m\.fully_read/")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .named("Fully-read marker") + .mount(&server) + .await; + + timeline + .send_single_receipt( + ReceiptType::Read, + ReceiptThread::Unthreaded, + second_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::ReadPrivate, + ReceiptThread::Unthreaded, + second_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::FullyRead, + ReceiptThread::Unthreaded, + second_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + server.reset().await; + + // Newer receipts in the timeline are sent. 
+ let third_receipts_event_id = event_id!("$third_receipts_event_id"); + + ev_builder.add_joined_room( + JoinedRoomBuilder::new(room_id) + .add_timeline_event(TimelineTestEvent::Custom(json!({ + "content": { + "body": "I'm User A", + "msgtype": "m.text", + }, + "event_id": second_receipts_event_id, + "origin_server_ts": 152046694, + "sender": "@user_a:example.org", + "type": "m.room.message", + }))) + .add_timeline_event(TimelineTestEvent::Custom(json!({ + "content": { + "body": "I'm User B", + "msgtype": "m.text", + }, + "event_id": third_receipts_event_id, + "origin_server_ts": 152049794, + "sender": "@user_b:example.org", + "type": "m.room.message", + }))) + .add_ephemeral_event(EphemeralTestEvent::Custom(json!({ + "content": { + second_receipts_event_id: { + "m.read.private": { + own_user_id: { + "ts": 1436453550, + }, + }, + "m.read": { + own_user_id: { + "ts": 1436453550, + }, + }, + }, + }, + "type": "m.receipt", + }))) + .add_account_data(RoomAccountDataTestEvent::Custom(json!({ + "content": { + "event_id": second_receipts_event_id, + }, + "type": "m.fully_read", + }))), + ); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/receipt/m\.read/")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .named("Public read receipt") + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/receipt/m\.read\.private/")) + .and(header("authorization", "Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .named("Private read receipt") + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/receipt/m\.fully_read/")) + .and(header("authorization", 
"Bearer 1234")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .named("Fully-read marker") + .mount(&server) + .await; + + timeline + .send_single_receipt( + ReceiptType::Read, + ReceiptThread::Unthreaded, + third_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::ReadPrivate, + ReceiptThread::Unthreaded, + third_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::FullyRead, + ReceiptThread::Unthreaded, + third_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + server.reset().await; + + // Older receipts in the timeline are not sent. + ev_builder.add_joined_room( + JoinedRoomBuilder::new(room_id) + .add_ephemeral_event(EphemeralTestEvent::Custom(json!({ + "content": { + third_receipts_event_id: { + "m.read.private": { + own_user_id: { + "ts": 1436453550, + }, + }, + "m.read": { + own_user_id: { + "ts": 1436453550, + }, + }, + }, + }, + "type": "m.receipt", + }))) + .add_account_data(RoomAccountDataTestEvent::Custom(json!({ + "content": { + "event_id": third_receipts_event_id, + }, + "type": "m.fully_read", + }))), + ); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + timeline + .send_single_receipt( + ReceiptType::Read, + ReceiptThread::Unthreaded, + second_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::ReadPrivate, + ReceiptThread::Unthreaded, + second_receipts_event_id.to_owned(), + ) + .await + .unwrap(); + timeline + .send_single_receipt( + ReceiptType::FullyRead, + ReceiptThread::Unthreaded, + second_receipts_event_id.to_owned(), + ) + .await + .unwrap(); +} + +#[async_test] +async fn send_multiple_receipts() { + let room_id = room_id!("!a98sd12bjh:example.org"); + let (client, server) = logged_in_client().await; + let 
sync_settings = SyncSettings::new().timeout(Duration::from_millis(3000)); + + let own_user_id = client.user_id().unwrap(); + + let mut ev_builder = EventBuilder::new(); + ev_builder.add_joined_room(JoinedRoomBuilder::new(room_id)); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + let room = client.get_room(room_id).unwrap(); + let timeline = room.timeline().await; + + // Unknown receipts are sent. + let first_receipts_event_id = event_id!("$first_receipts_event_id"); + let first_receipts = Receipts::new() + .fully_read_marker(Some(first_receipts_event_id.to_owned())) + .public_read_receipt(Some(first_receipts_event_id.to_owned())) + .private_read_receipt(Some(first_receipts_event_id.to_owned())); + + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/read_markers$")) + .and(header("authorization", "Bearer 1234")) + .and(body_json(json!({ + "m.fully_read": first_receipts_event_id, + "m.read": first_receipts_event_id, + "m.read.private": first_receipts_event_id, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .mount(&server) + .await; + + timeline.send_multiple_receipts(first_receipts.clone()).await.unwrap(); + server.reset().await; + + // Unchanged receipts are not sent. 
+ ev_builder.add_joined_room( + JoinedRoomBuilder::new(room_id) + .add_ephemeral_event(EphemeralTestEvent::Custom(json!({ + "content": { + first_receipts_event_id: { + "m.read.private": { + own_user_id: { + "ts": 1436453550, + }, + }, + "m.read": { + own_user_id: { + "ts": 1436453550, + }, + }, + }, + }, + "type": "m.receipt", + }))) + .add_account_data(RoomAccountDataTestEvent::Custom(json!({ + "content": { + "event_id": first_receipts_event_id, + }, + "type": "m.fully_read", + }))), + ); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + timeline.send_multiple_receipts(first_receipts).await.unwrap(); + server.reset().await; + + // Receipts with unknown previous receipts are always sent. + let second_receipts_event_id = event_id!("$second_receipts_event_id"); + let second_receipts = Receipts::new() + .fully_read_marker(Some(second_receipts_event_id.to_owned())) + .public_read_receipt(Some(second_receipts_event_id.to_owned())) + .private_read_receipt(Some(second_receipts_event_id.to_owned())); + + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/read_markers$")) + .and(header("authorization", "Bearer 1234")) + .and(body_json(json!({ + "m.fully_read": second_receipts_event_id, + "m.read": second_receipts_event_id, + "m.read.private": second_receipts_event_id, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .mount(&server) + .await; + + timeline.send_multiple_receipts(second_receipts.clone()).await.unwrap(); + server.reset().await; + + // Newer receipts in the timeline are sent. 
+ let third_receipts_event_id = event_id!("$third_receipts_event_id"); + let third_receipts = Receipts::new() + .fully_read_marker(Some(third_receipts_event_id.to_owned())) + .public_read_receipt(Some(third_receipts_event_id.to_owned())) + .private_read_receipt(Some(third_receipts_event_id.to_owned())); + + ev_builder.add_joined_room( + JoinedRoomBuilder::new(room_id) + .add_timeline_event(TimelineTestEvent::Custom(json!({ + "content": { + "body": "I'm User A", + "msgtype": "m.text", + }, + "event_id": second_receipts_event_id, + "origin_server_ts": 152046694, + "sender": "@user_a:example.org", + "type": "m.room.message", + }))) + .add_timeline_event(TimelineTestEvent::Custom(json!({ + "content": { + "body": "I'm User B", + "msgtype": "m.text", + }, + "event_id": third_receipts_event_id, + "origin_server_ts": 152049794, + "sender": "@user_b:example.org", + "type": "m.room.message", + }))) + .add_ephemeral_event(EphemeralTestEvent::Custom(json!({ + "content": { + second_receipts_event_id: { + "m.read.private": { + own_user_id: { + "ts": 1436453550, + }, + }, + "m.read": { + own_user_id: { + "ts": 1436453550, + }, + }, + }, + }, + "type": "m.receipt", + }))) + .add_account_data(RoomAccountDataTestEvent::Custom(json!({ + "content": { + "event_id": second_receipts_event_id, + }, + "type": "m.fully_read", + }))), + ); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + Mock::given(method("POST")) + .and(path_regex(r"^/_matrix/client/r0/rooms/.*/read_markers$")) + .and(header("authorization", "Bearer 1234")) + .and(body_json(json!({ + "m.fully_read": third_receipts_event_id, + "m.read": third_receipts_event_id, + "m.read.private": third_receipts_event_id, + }))) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) + .expect(1) + .mount(&server) + .await; + + timeline.send_multiple_receipts(third_receipts.clone()).await.unwrap(); + 
server.reset().await; + + // Older receipts in the timeline are not sent. + ev_builder.add_joined_room( + JoinedRoomBuilder::new(room_id) + .add_ephemeral_event(EphemeralTestEvent::Custom(json!({ + "content": { + third_receipts_event_id: { + "m.read.private": { + own_user_id: { + "ts": 1436453550, + }, + }, + "m.read": { + own_user_id: { + "ts": 1436453550, + }, + }, + }, + }, + "type": "m.receipt", + }))) + .add_account_data(RoomAccountDataTestEvent::Custom(json!({ + "content": { + "event_id": third_receipts_event_id, + }, + "type": "m.fully_read", + }))), + ); + + mock_sync(&server, ev_builder.build_json_sync_response(), None).await; + let _response = client.sync_once(sync_settings.clone()).await.unwrap(); + server.reset().await; + + timeline.send_multiple_receipts(second_receipts.clone()).await.unwrap(); +} diff --git a/examples/appservice_autojoin/Cargo.toml b/examples/appservice_autojoin/Cargo.toml index e15c84c181b..9d8c1c2b7bf 100644 --- a/examples/appservice_autojoin/Cargo.toml +++ b/examples/appservice_autojoin/Cargo.toml @@ -10,7 +10,7 @@ test = false [dependencies] anyhow = "1" -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" tracing = { workspace = true } diff --git a/examples/autojoin/Cargo.toml b/examples/autojoin/Cargo.toml index de7ac7526d5..62b70b707f7 100644 --- a/examples/autojoin/Cargo.toml +++ b/examples/autojoin/Cargo.toml @@ -9,7 +9,7 @@ name = "example-autojoin" test = false [dependencies] -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } anyhow = "1" tracing-subscriber = "0.3.15" diff --git a/examples/autojoin/src/main.rs b/examples/autojoin/src/main.rs index bfb9b3b83bf..3964692bc29 100644 --- a/examples/autojoin/src/main.rs +++ b/examples/autojoin/src/main.rs @@ -43,22 +43,10 @@ async fn login_and_sync( username: 
&str, password: &str, ) -> anyhow::Result<()> { - #[allow(unused_mut)] - let mut client_builder = Client::builder().homeserver_url(homeserver_url); - - #[cfg(feature = "sled")] - { - // The location to save files to - let home = dirs::home_dir().expect("no home directory found").join("autojoin_bot"); - client_builder = client_builder.sled_store(home, None)?; - } - - #[cfg(feature = "indexeddb")] - { - client_builder = client_builder.indexeddb_store("autojoin_bot", None).await?; - } - - let client = client_builder.build().await?; + // Note that when encryption is enabled, you should use a persistent store to be + // able to restore the session with a working encryption setup. + // See the `persist_session` example. + let client = Client::builder().homeserver_url(homeserver_url).build().await?; client.login_username(username, password).initial_device_display_name("autojoin bot").await?; diff --git a/examples/command_bot/Cargo.toml b/examples/command_bot/Cargo.toml index 2c1ae729710..bb3ffd8071c 100644 --- a/examples/command_bot/Cargo.toml +++ b/examples/command_bot/Cargo.toml @@ -10,7 +10,7 @@ test = false [dependencies] anyhow = "1" -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" url = "2.2.2" diff --git a/examples/command_bot/src/main.rs b/examples/command_bot/src/main.rs index 52030d3362d..e5adc376bbb 100644 --- a/examples/command_bot/src/main.rs +++ b/examples/command_bot/src/main.rs @@ -35,29 +35,16 @@ async fn login_and_sync( username: String, password: String, ) -> anyhow::Result<()> { - #[allow(unused_mut)] - let mut client_builder = Client::builder().homeserver_url(homeserver_url); - - #[cfg(feature = "sled")] - { - // The location to save files to - let home = dirs::home_dir().expect("no home directory found").join("party_bot"); - client_builder = client_builder.sled_store(home, None)?; - } - - #[cfg(feature = "indexeddb")] - { - 
client_builder = client_builder.indexeddb_store("party_bot", None).await?; - } - - let client = client_builder.build().await.unwrap(); + // Note that when encryption is enabled, you should use a persistent store to be + // able to restore the session with a working encryption setup. + // See the `persist_session` example. + let client = Client::builder().homeserver_url(homeserver_url).build().await.unwrap(); client.login_username(&username, &password).initial_device_display_name("command bot").await?; println!("logged in as {username}"); // An initial sync to set up state and so our bot doesn't respond to old - // messages. If the `StateStore` finds saved state in the location given the - // initial sync will be skipped in favor of loading state from the store + // messages. let response = client.sync_once(SyncSettings::default()).await.unwrap(); // add our CommandBot to be notified of incoming messages, we do this after the // initial sync to avoid responding to messages before the bot was running. 
diff --git a/examples/cross_signing_bootstrap/Cargo.toml b/examples/cross_signing_bootstrap/Cargo.toml index f49b2ab3fed..714ccf829aa 100644 --- a/examples/cross_signing_bootstrap/Cargo.toml +++ b/examples/cross_signing_bootstrap/Cargo.toml @@ -10,7 +10,7 @@ test = false [dependencies] anyhow = "1" -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" url = "2.2.2" diff --git a/examples/custom_events/Cargo.toml b/examples/custom_events/Cargo.toml index 20eb3aea3cf..3ae2394fef6 100644 --- a/examples/custom_events/Cargo.toml +++ b/examples/custom_events/Cargo.toml @@ -12,7 +12,7 @@ test = false anyhow = "1" dirs = "4.0.0" serde = "1.0" -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" [dependencies.matrix-sdk] diff --git a/examples/custom_events/src/main.rs b/examples/custom_events/src/main.rs index a7000da4ded..62f8c724cc3 100644 --- a/examples/custom_events/src/main.rs +++ b/examples/custom_events/src/main.rs @@ -103,9 +103,7 @@ async fn login_and_sync( username: &str, password: &str, ) -> anyhow::Result<()> { - let home = dirs::data_dir().expect("no home directory found").join("getting_started"); - let client = - Client::builder().homeserver_url(homeserver_url).sled_store(home, None).build().await?; + let client = Client::builder().homeserver_url(homeserver_url).build().await?; client .login_username(username, password) .initial_device_display_name("getting started bot") diff --git a/examples/emoji_verification/Cargo.toml b/examples/emoji_verification/Cargo.toml index 79e09e541ec..6203b4ad8e5 100644 --- a/examples/emoji_verification/Cargo.toml +++ b/examples/emoji_verification/Cargo.toml @@ -10,7 +10,7 @@ test = false [dependencies] anyhow = "1" -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] 
} +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } clap = { version = "4.0.15", features = ["derive"] } futures = "0.3.24" tracing-subscriber = "0.3.16" diff --git a/examples/get_profiles/Cargo.toml b/examples/get_profiles/Cargo.toml index 0adcf257a5c..ba766ba52e5 100644 --- a/examples/get_profiles/Cargo.toml +++ b/examples/get_profiles/Cargo.toml @@ -10,7 +10,7 @@ test = false [dependencies] anyhow = "1" -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" url = "2.2.2" diff --git a/examples/getting_started/Cargo.toml b/examples/getting_started/Cargo.toml index a147f1e8e4d..7bdc9eed9f4 100644 --- a/examples/getting_started/Cargo.toml +++ b/examples/getting_started/Cargo.toml @@ -11,7 +11,7 @@ test = false [dependencies] anyhow = "1" dirs = "4.0.0" -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" [dependencies.matrix-sdk] diff --git a/examples/getting_started/src/main.rs b/examples/getting_started/src/main.rs index f4b8412c449..b5ffc79cd96 100644 --- a/examples/getting_started/src/main.rs +++ b/examples/getting_started/src/main.rs @@ -59,15 +59,12 @@ async fn login_and_sync( ) -> anyhow::Result<()> { // First, we set up the client. - let home = dirs::data_dir().expect("no home directory found").join("getting_started"); - + // Note that when encryption is enabled, you should use a persistent store to be + // able to restore the session with a working encryption setup. + // See the `persist_session` example. let client = Client::builder() // We use the convenient client builder to set our custom homeserver URL on it. 
.homeserver_url(homeserver_url) - // Matrix-SDK has support for pluggable, configurable state and crypto-store - // support we use the default sled-store (enabled by default on native - // architectures), to configure a local cache and store for our crypto keys - .sled_store(home, None) .build() .await?; diff --git a/examples/image_bot/Cargo.toml b/examples/image_bot/Cargo.toml index 7abc9a646ce..1b743e99b2c 100644 --- a/examples/image_bot/Cargo.toml +++ b/examples/image_bot/Cargo.toml @@ -11,7 +11,7 @@ test = false [dependencies] anyhow = "1" mime = "0.3.16" -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" url = "2.2.2" diff --git a/examples/login/Cargo.toml b/examples/login/Cargo.toml index 5186a576ca9..ef680556480 100644 --- a/examples/login/Cargo.toml +++ b/examples/login/Cargo.toml @@ -10,10 +10,11 @@ test = false [dependencies] anyhow = "1" -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" url = "2.2.2" [dependencies.matrix-sdk] path = "../../crates/matrix-sdk" version = "0.6.0" +features = ["sso-login"] diff --git a/examples/login/src/main.rs b/examples/login/src/main.rs index e75ab69f9ec..f954ee2eb40 100644 --- a/examples/login/src/main.rs +++ b/examples/login/src/main.rs @@ -1,64 +1,232 @@ -use std::{env, process::exit}; +use std::{ + env, fmt, + io::{self, Write}, + process::exit, +}; +use anyhow::anyhow; use matrix_sdk::{ self, config::SyncSettings, room::Room, - ruma::events::room::message::{ - MessageType, OriginalSyncRoomMessageEvent, RoomMessageEventContent, TextMessageEventContent, + ruma::{ + api::client::session::get_login_types::v3::{IdentityProvider, LoginType}, + events::room::message::{MessageType, OriginalSyncRoomMessageEvent}, }, Client, }; use url::Url; -async fn on_room_message(event: 
OriginalSyncRoomMessageEvent, room: Room) { - if let Room::Joined(room) = room { - if let OriginalSyncRoomMessageEvent { - content: - RoomMessageEventContent { - msgtype: MessageType::Text(TextMessageEventContent { body: msg_body, .. }), - .. - }, - sender, - .. - } = event - { - let member = room.get_member(&sender).await.unwrap().unwrap(); - let name = member.display_name().unwrap_or_else(|| member.user_id().as_str()); - println!("{name}: {msg_body}"); +/// The initial device name when logging in with a device for the first time. +const INITIAL_DEVICE_DISPLAY_NAME: &str = "login client"; + +/// A simple program that adapts to the different login methods offered by a +/// Matrix homeserver. +/// +/// Homeservers usually offer to login either via password, Single Sign-On (SSO) +/// or both. +#[tokio::main] +async fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt::init(); + + let Some(homeserver_url) = env::args().nth(1) else { + eprintln!( + "Usage: {} ", + env::args().next().unwrap() + ); + exit(1) + }; + + login_and_sync(homeserver_url).await?; + + Ok(()) +} + +/// Log in to the given homeserver and sync. +async fn login_and_sync(homeserver_url: String) -> anyhow::Result<()> { + let homeserver_url = Url::parse(&homeserver_url)?; + let client = Client::new(homeserver_url).await?; + + // First, let's figure out what login types are supported by the homeserver. + let mut choices = Vec::new(); + let login_types = client.get_login_types().await?.flows; + + for login_type in login_types { + match login_type { + LoginType::Password(_) => { + choices.push(LoginChoice::Password) + } + LoginType::Sso(sso) => { + if sso.identity_providers.is_empty() { + choices.push(LoginChoice::Sso) + } else { + choices.extend(sso.identity_providers.into_iter().map(LoginChoice::SsoIdp)) + } + } + // This is used for SSO, so it's not a separate choice. + LoginType::Token(_) | + // This is only for application services, ignore it here. 
+ LoginType::ApplicationService(_) => {}, + // We don't support unknown login types. + _ => {}, } } -} -async fn login(homeserver_url: String, username: &str, password: &str) -> matrix_sdk::Result<()> { - let homeserver_url = Url::parse(&homeserver_url).expect("Couldn't parse the homeserver URL"); - let client = Client::new(homeserver_url).await.unwrap(); + match choices.len() { + 0 => return Err(anyhow!("Homeserver login types incompatible with this client")), + 1 => choices[0].login(&client).await?, + _ => offer_choices_and_login(&client, choices).await?, + } + // Now that we are logged in, we can sync and listen to new messages. client.add_event_handler(on_room_message); - - client.login_username(username, password).initial_device_display_name("rust-sdk").await?; + // This will sync until an error happens or the program is killed. client.sync(SyncSettings::new()).await?; Ok(()) } -#[tokio::main] -async fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt::init(); +#[derive(Debug)] +enum LoginChoice { + /// Login with username and password. + Password, - let (homeserver_url, username, password) = - match (env::args().nth(1), env::args().nth(2), env::args().nth(3)) { - (Some(a), Some(b), Some(c)) => (a, b, c), - _ => { - eprintln!( - "Usage: {} ", - env::args().next().unwrap() - ); - exit(1) + /// Login with SSO. + Sso, + + /// Login with a specific SSO identity provider. + SsoIdp(IdentityProvider), +} + +impl LoginChoice { + /// Login with this login choice. 
+ async fn login(&self, client: &Client) -> anyhow::Result<()> { + match self { + LoginChoice::Password => login_with_password(client).await, + LoginChoice::Sso => login_with_sso(client, None).await, + LoginChoice::SsoIdp(idp) => login_with_sso(client, Some(idp)).await, + } + } +} + +impl fmt::Display for LoginChoice { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LoginChoice::Password => write!(f, "Username and password"), + LoginChoice::Sso => write!(f, "SSO"), + LoginChoice::SsoIdp(idp) => write!(f, "SSO via {}", idp.name), + } + } +} + +/// Offer the given choices to the user and login with the selected option. +async fn offer_choices_and_login(client: &Client, choices: Vec) -> anyhow::Result<()> { + println!("Several options are available to login with this homeserver:\n"); + + let choice = loop { + for (idx, login_choice) in choices.iter().enumerate() { + println!("{idx}) {login_choice}"); + } + + print!("\nEnter your choice: "); + io::stdout().flush().expect("Unable to write to stdout"); + let mut choice_str = String::new(); + io::stdin().read_line(&mut choice_str).expect("Unable to read user input"); + + match choice_str.trim().parse::() { + Ok(choice) => { + if choice >= choices.len() { + eprintln!("This is not a valid choice"); + } else { + break choice; + } } + Err(_) => eprintln!("This is not a valid choice. Try again.\n"), }; + }; - login(homeserver_url, &username, &password).await?; + choices[choice].login(client).await?; Ok(()) } + +/// Login with a username and password. 
+async fn login_with_password(client: &Client) -> anyhow::Result<()> { + println!("Logging in with username and password…"); + + loop { + print!("\nUsername: "); + io::stdout().flush().expect("Unable to write to stdout"); + let mut username = String::new(); + io::stdin().read_line(&mut username).expect("Unable to read user input"); + username = username.trim().to_owned(); + + print!("Password: "); + io::stdout().flush().expect("Unable to write to stdout"); + let mut password = String::new(); + io::stdin().read_line(&mut password).expect("Unable to read user input"); + password = password.trim().to_owned(); + + match client + .login_username(&username, &password) + .initial_device_display_name(INITIAL_DEVICE_DISPLAY_NAME) + .await + { + Ok(_) => { + println!("Logged in as {username}"); + break; + } + Err(error) => { + println!("Error logging in: {error}"); + println!("Please try again\n"); + } + } + } + + Ok(()) +} + +/// Login with SSO. +async fn login_with_sso(client: &Client, idp: Option<&IdentityProvider>) -> anyhow::Result<()> { + println!("Logging in with SSO…"); + + let mut login_builder = client.login_sso(|url| async move { + // Usually we would want to use a library to open the URL in the browser, but + // let's keep it simple. + println!("\nOpen this URL in your browser: {url}\n"); + println!("Waiting for login token…"); + Ok(()) + }); + + if let Some(idp) = idp { + login_builder = login_builder.identity_provider_id(&idp.id); + } + + login_builder.await?; + + println!("Logged in as {}", client.user_id().unwrap()); + + Ok(()) +} + +/// Handle room messages by logging them. +async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) { + // We only want to listen to joined rooms. + let Room::Joined(room) = room else { + return; + }; + + // We only want to log text messages. 
+ let MessageType::Text(msgtype) = &event.content.msgtype else { + return; + }; + + let member = room + .get_member(&event.sender) + .await + .expect("Couldn't get the room member") + .expect("The room member doesn't exist"); + let name = member.name(); + + println!("{name}: {}", msgtype.body); +} diff --git a/examples/persist_session/Cargo.toml b/examples/persist_session/Cargo.toml new file mode 100644 index 00000000000..2aa9663f7bc --- /dev/null +++ b/examples/persist_session/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "example-persist-session" +version = "0.1.0" +edition = "2021" +publish = false + +[[bin]] +name = "example-persist-session" +test = false + +[dependencies] +anyhow = "1" +dirs = "4.0.0" +rand = "0.8.5" +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } +tracing-subscriber = "0.3.15" + +[dependencies.matrix-sdk] +path = "../../crates/matrix-sdk" +version = "0.6.0" diff --git a/examples/persist_session/src/main.rs b/examples/persist_session/src/main.rs new file mode 100644 index 00000000000..f54aab31ff5 --- /dev/null +++ b/examples/persist_session/src/main.rs @@ -0,0 +1,312 @@ +use std::{ + io::{self, Write}, + path::{Path, PathBuf}, +}; + +use matrix_sdk::{ + config::SyncSettings, + room::Room, + ruma::{ + api::client::filter::{FilterDefinition, LazyLoadOptions, RoomEventFilter, RoomFilter}, + events::room::message::{MessageType, OriginalSyncRoomMessageEvent}, + }, + Client, Error, LoopCtrl, Session, +}; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use serde::{Deserialize, Serialize}; +use tokio::fs; + +/// The data needed to re-build a client. +#[derive(Debug, Serialize, Deserialize)] +struct ClientSession { + /// The URL of the homeserver of the user. + homeserver: String, + + /// The path of the database. + db_path: PathBuf, + + /// The passphrase of the database. + passphrase: String, +} + +/// The full session to persist. 
+#[derive(Debug, Serialize, Deserialize)]
+struct FullSession {
+    /// The data to re-build the client.
+    client_session: ClientSession,
+
+    /// The Matrix user session.
+    user_session: Session,
+
+    /// The latest sync token.
+    ///
+    /// It is only needed to persist it when using `Client::sync_once()` and we
+    /// want to make our syncs faster by not receiving all the initial sync
+    /// again.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    sync_token: Option<String>,
+}
+
+/// A simple example to show how to persist a client's data to be able to
+/// restore it.
+///
+/// Restoring a session with encryption without having a persisted store
+/// will break the encryption setup and the client will not be able to send or
+/// receive encrypted messages, hence the need to persist the session.
+///
+/// To use this, just run `cargo run -p example-persist-session`, and everything
+/// is interactive after that. You might want to set the `RUST_LOG` environment
+/// variable to `warn` to reduce the noise in the logs. The program exits
+/// whenever an unexpected error occurs.
+///
+/// To reset the login, simply delete the folder containing the session
+/// file, the location is shown in the logs. Note that the database must be
+/// deleted too as it can't be reused.
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
+    tracing_subscriber::fmt::init();
+
+    // The folder containing this example's data.
+    let data_dir = dirs::data_dir().expect("no data_dir directory found").join("persist_session");
+    // The file where the session is persisted.
+    let session_file = data_dir.join("session");
+
+    let (client, sync_token) = if session_file.exists() {
+        restore_session(&session_file).await?
+    } else {
+        (login(&data_dir, &session_file).await?, None)
+    };
+
+    sync(client, sync_token, &session_file).await.map_err(Into::into)
+}
+
+/// Restore a previous session.
+async fn restore_session(session_file: &Path) -> anyhow::Result<(Client, Option<String>)> {
+    println!("Previous session found in '{}'", session_file.to_string_lossy());
+
+    // The session was serialized as JSON in a file.
+    let serialized_session = fs::read_to_string(session_file).await?;
+    let FullSession { client_session, user_session, sync_token } =
+        serde_json::from_str(&serialized_session)?;
+
+    // Build the client with the previous settings from the session.
+    let client = Client::builder()
+        .homeserver_url(client_session.homeserver)
+        .sled_store(client_session.db_path, Some(&client_session.passphrase))
+        .build()
+        .await?;
+
+    println!("Restoring session for {}…", user_session.user_id);
+
+    // Restore the Matrix user session.
+    client.restore_session(user_session).await?;
+
+    Ok((client, sync_token))
+}
+
+/// Login with a new device.
+async fn login(data_dir: &Path, session_file: &Path) -> anyhow::Result<Client> {
+    println!("No previous session found, logging in…");
+
+    let (client, client_session) = build_client(data_dir).await?;
+
+    loop {
+        print!("\nUsername: ");
+        io::stdout().flush().expect("Unable to write to stdout");
+        let mut username = String::new();
+        io::stdin().read_line(&mut username).expect("Unable to read user input");
+        username = username.trim().to_owned();
+
+        print!("Password: ");
+        io::stdout().flush().expect("Unable to write to stdout");
+        let mut password = String::new();
+        io::stdin().read_line(&mut password).expect("Unable to read user input");
+        password = password.trim().to_owned();
+
+        match client
+            .login_username(&username, &password)
+            .initial_device_display_name("persist-session client")
+            .await
+        {
+            Ok(_) => {
+                println!("Logged in as {username}");
+                break;
+            }
+            Err(error) => {
+                println!("Error logging in: {error}");
+                println!("Please try again\n");
+            }
+        }
+    }
+
+    // Persist the session to reuse it later.
+    // This is not very secure, for simplicity.
If the system provides a way of + // storing secrets securely, it should be used instead. + // Note that we could also build the user session from the login response. + let user_session = client.session().expect("A logged-in client should have a session"); + let serialized_session = + serde_json::to_string(&FullSession { client_session, user_session, sync_token: None })?; + fs::write(session_file, serialized_session).await?; + + println!("Session persisted in {}", session_file.to_string_lossy()); + + // After logging in, you might want to verify this session with another one (see + // the `emoji_verification` example), or bootstrap cross-signing if this is your + // first session with encryption, or if you need to reset cross-signing because + // you don't have access to your old sessions (see the + // `cross_signing_bootstrap` example). + + Ok(client) +} + +/// Build a new client. +async fn build_client(data_dir: &Path) -> anyhow::Result<(Client, ClientSession)> { + let mut rng = thread_rng(); + + // Generating a subfolder for the database is not mandatory, but it is useful if + // you allow several clients to run at the same time. Each one must have a + // separate database, which is a different folder with the sled store. + let db_subfolder: String = + (&mut rng).sample_iter(Alphanumeric).take(7).map(char::from).collect(); + let db_path = data_dir.join(db_subfolder); + + // Generate a random passphrase. + let passphrase: String = + (&mut rng).sample_iter(Alphanumeric).take(32).map(char::from).collect(); + + // We create a loop here so the user can retry if an error happens. + loop { + let mut homeserver = String::new(); + + print!("Homeserver URL: "); + io::stdout().flush().expect("Unable to write to stdout"); + io::stdin().read_line(&mut homeserver).expect("Unable to read user input"); + + println!("\nChecking homeserver…"); + + match Client::builder() + .homeserver_url(&homeserver) + // We use the sled store, which is enabled by default. 
This is the crucial part to
+            // persist the encryption setup.
+            // Note that other store backends are available and you can even implement your own.
+            .sled_store(&db_path, Some(&passphrase))
+            .build()
+            .await
+        {
+            Ok(client) => return Ok((client, ClientSession { homeserver, db_path, passphrase })),
+            Err(error) => match &error {
+                matrix_sdk::ClientBuildError::AutoDiscovery(_)
+                | matrix_sdk::ClientBuildError::Url(_)
+                | matrix_sdk::ClientBuildError::Http(_) => {
+                    println!("Error checking the homeserver: {error}");
+                    println!("Please try again\n");
+                }
+                _ => {
+                    // Forward other errors, it's unlikely we can retry with a different outcome.
+                    return Err(error.into());
+                }
+            },
+        }
+    }
+}
+
+/// Set up the client to listen to new messages.
+async fn sync(
+    client: Client,
+    initial_sync_token: Option<String>,
+    session_file: &Path,
+) -> anyhow::Result<()> {
+    println!("Launching a first sync to ignore past messages…");
+
+    // Enable room members lazy-loading, it will speed up the initial sync a lot
+    // with accounts in lots of rooms.
+    // See .
+    let mut state_filter = RoomEventFilter::empty();
+    state_filter.lazy_load_options = LazyLoadOptions::Enabled { include_redundant_members: false };
+    let mut room_filter = RoomFilter::empty();
+    room_filter.state = state_filter;
+    let mut filter = FilterDefinition::empty();
+    filter.room = room_filter;
+
+    let mut sync_settings = SyncSettings::default().filter(filter.into());
+
+    // We restore the sync where we left.
+    // This is not necessary when not using `sync_once`. The other sync methods get
+    // the sync token from the store.
+    if let Some(sync_token) = initial_sync_token {
+        sync_settings = sync_settings.token(sync_token);
+    }
+
+    // Let's ignore messages before the program was launched.
+    // This is a loop in case the initial sync is longer than our timeout. The
+    // server should cache the response and it will ultimately take less time to
+    // receive.
+ loop { + match client.sync_once(sync_settings.clone()).await { + Ok(response) => { + // This is the last time we need to provide this token, the sync method after + // will handle it on its own. + sync_settings = sync_settings.token(response.next_batch.clone()); + persist_sync_token(session_file, response.next_batch).await?; + break; + } + Err(error) => { + println!("An error occurred during initial sync: {error}"); + println!("Trying again…"); + } + } + } + + println!("The client is ready! Listening to new messages…"); + + // Now that we've synced, let's attach a handler for incoming room messages. + client.add_event_handler(on_room_message); + + // This loops until we kill the program or an error happens. + client + .sync_with_result_callback(sync_settings, |sync_result| async move { + let response = sync_result?; + + // We persist the token each time to be able to restore our session + persist_sync_token(session_file, response.next_batch) + .await + .map_err(|err| Error::UnknownError(err.into()))?; + + Ok(LoopCtrl::Continue) + }) + .await?; + + Ok(()) +} + +/// Persist the sync token for a future session. +/// Note that this is needed only when using `sync_once`. Other sync methods get +/// the sync token from the store. +async fn persist_sync_token(session_file: &Path, sync_token: String) -> anyhow::Result<()> { + let serialized_session = fs::read_to_string(session_file).await?; + let mut full_session: FullSession = serde_json::from_str(&serialized_session)?; + + full_session.sync_token = Some(sync_token); + let serialized_session = serde_json::to_string(&full_session)?; + fs::write(session_file, serialized_session).await?; + + Ok(()) +} + +/// Handle room messages. +async fn on_room_message(event: OriginalSyncRoomMessageEvent, room: Room) { + // We only want to log text messages in joined rooms. 
+ let Room::Joined(room) = room else { return }; + let MessageType::Text(text_content) = &event.content.msgtype else { return }; + + let room_name = match room.display_name().await { + Ok(room_name) => room_name.to_string(), + Err(error) => { + println!("Error getting room display name: {error}"); + // Let's fallback to the room ID. + room.room_id().to_string() + } + }; + + println!("[{room_name}] {}: {}", event.sender, text_content.body) +} diff --git a/examples/timeline/Cargo.toml b/examples/timeline/Cargo.toml index 13dee4be673..bc9c979e6ca 100644 --- a/examples/timeline/Cargo.toml +++ b/examples/timeline/Cargo.toml @@ -12,8 +12,7 @@ test = false anyhow = "1" clap = "4.0.16" futures = "0.3" -futures-signals = { version = "0.3.30", default-features = false } -tokio = { version = "1.23.1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" url = "2.2.2" diff --git a/examples/timeline/src/main.rs b/examples/timeline/src/main.rs index 297d73071b9..994449833c6 100644 --- a/examples/timeline/src/main.rs +++ b/examples/timeline/src/main.rs @@ -1,7 +1,6 @@ use anyhow::Result; use clap::Parser; use futures::StreamExt; -use futures_signals::signal_vec::SignalVecExt; use matrix_sdk::{self, config::SyncSettings, ruma::OwnedRoomId, Client}; use url::Url; @@ -33,8 +32,10 @@ struct Cli { } async fn login(cli: Cli) -> Result { - let mut builder = - Client::builder().homeserver_url(cli.homeserver).sled_store("./", Some("some password")); + // Note that when encryption is enabled, you should use a persistent store to be + // able to restore the session with a working encryption setup. + // See the `persist_session` example. + let mut builder = Client::builder().homeserver_url(cli.homeserver); if let Some(proxy) = cli.proxy { builder = builder.proxy(proxy); @@ -68,11 +69,12 @@ async fn main() -> Result<()> { // Get the timeline stream and listen to it. 
let room = client.get_room(&room_id).unwrap(); let timeline = room.timeline().await; - let mut timeline_stream = timeline.signal().to_stream(); + let (timeline_items, mut timeline_stream) = timeline.subscribe().await; + println!("Initial timeline items: {timeline_items:#?}"); tokio::spawn(async move { while let Some(diff) = timeline_stream.next().await { - println!("Received a timeline diff {diff:#?}"); + println!("Received a timeline diff: {diff:#?}"); } }); diff --git a/labs/jack-in/Cargo.toml b/labs/jack-in/Cargo.toml index 5fc48724a8b..bae5ddd480c 100644 --- a/labs/jack-in/Cargo.toml +++ b/labs/jack-in/Cargo.toml @@ -13,9 +13,10 @@ app_dirs2 = "2" chrono = "0.4.23" clap = { version = "4.0.29", features = ["derive", "env"] } dialoguer = "0.10.2" +eyeball = { workspace = true } +eyeball-im = { workspace = true } eyre = "0.6" futures = { version = "0.3.1" } -futures-signals = "0.3.24" matrix-sdk = { path = "../../crates/matrix-sdk", default-features = false, features = ["e2e-encryption", "anyhow", "native-tls", "sled", "experimental-sliding-sync", "experimental-timeline"], version = "0.6.0" } matrix-sdk-common = { path = "../../crates/matrix-sdk-common", version = "0.6.0" } matrix-sdk-sled = { path = "../../crates/matrix-sdk-sled", features = ["state-store", "crypto-store"], version = "0.2.0" } diff --git a/labs/jack-in/src/app/model.rs b/labs/jack-in/src/app/model.rs index 5a9b93200fd..2f833ec996e 100644 --- a/labs/jack-in/src/app/model.rs +++ b/labs/jack-in/src/app/model.rs @@ -2,7 +2,7 @@ //! //! 
app model -use std::{ops::Deref, time::Duration}; +use std::time::Duration; use futures::executor::block_on; use matrix_sdk::{ruma::events::room::message::RoomMessageEventContent, Client}; @@ -212,7 +212,7 @@ impl Update for Model { None } Msg::SendMessage(m) => { - if let Some(tl) = self.sliding_sync.room_timeline.lock_ref().deref() { + if let Some(tl) = &*self.sliding_sync.room_timeline.read() { block_on(async move { // fire and forget tl.send(RoomMessageEventContent::text_plain(m).into(), None).await; diff --git a/labs/jack-in/src/client/mod.rs b/labs/jack-in/src/client/mod.rs index d649f3988da..d411211ed84 100644 --- a/labs/jack-in/src/client/mod.rs +++ b/labs/jack-in/src/client/mod.rs @@ -7,7 +7,7 @@ pub mod state; use matrix_sdk::{ ruma::{api::client::error::ErrorKind, OwnedRoomId}, - Client, SlidingSyncState, SlidingSyncViewBuilder, + Client, SlidingSyncListBuilder, SlidingSyncState, }; pub async fn run_client( @@ -17,7 +17,7 @@ pub async fn run_client( ) -> Result<()> { info!("Starting sliding sync now"); let builder = client.sliding_sync().await; - let mut full_sync_view_builder = SlidingSyncViewBuilder::default_with_fullsync() + let mut full_sync_view_builder = SlidingSyncListBuilder::default_with_fullsync() .timeline_limit(10u32) .sync_mode(config.full_sync_mode.into()); if let Some(size) = config.batch_size { @@ -35,15 +35,14 @@ pub async fn run_client( let syncer = builder .homeserver(config.proxy.parse().wrap_err("can't parse sync proxy")?) 
- .add_view(full_sync_view) + .add_list(full_sync_view) .with_common_extensions() .cold_cache("jack-in-default") .build() .await?; let stream = syncer.stream(); - let view = syncer.view("full-sync").expect("we have the full syncer there").clone(); - let state = view.state.clone(); - let mut ssync_state = state::SlidingSyncState::new(syncer.clone(), view); + let view = syncer.list("full-sync").expect("we have the full syncer there").clone(); + let mut ssync_state = state::SlidingSyncState::new(syncer.clone(), view.clone()); tx.send(ssync_state.clone()).await?; info!("starting polling"); @@ -64,7 +63,7 @@ pub async fn run_client( match stream.next().await { Some(Ok(_)) => { // we are switching into live updates mode next. ignoring - let state = state.read_only().get_cloned(); + let state = view.state(); ssync_state.set_view_state(state.clone()); if state == SlidingSyncState::Live { @@ -96,7 +95,7 @@ pub async fn run_client( while let Some(update) = stream.next().await { { - let selected_room = ssync_state.selected_room.lock_ref().clone(); + let selected_room = ssync_state.selected_room.get(); if let Some(room_id) = selected_room { if let Some(prev) = &prev_selected_room { if prev != &room_id { diff --git a/labs/jack-in/src/client/state.rs b/labs/jack-in/src/client/state.rs index 683345eead1..5f60af891ca 100644 --- a/labs/jack-in/src/client/state.rs +++ b/labs/jack-in/src/client/state.rs @@ -1,17 +1,15 @@ use std::{ - sync::Arc, + sync::{Arc, RwLock as StdRwLock}, time::{Duration, Instant}, }; +use eyeball::shared::Observable as SharedObservable; +use eyeball_im::{ObservableVector, VectorDiff}; use futures::{pin_mut, StreamExt}; -use futures_signals::{ - signal::Mutable, - signal_vec::{MutableVec, VecDiff}, -}; use matrix_sdk::{ room::timeline::{Timeline, TimelineItem}, ruma::{OwnedRoomId, RoomId}, - SlidingSync, SlidingSyncRoom, SlidingSyncState as ViewState, SlidingSyncView, + SlidingSync, SlidingSyncList, SlidingSyncRoom, SlidingSyncState as ViewState, }; use 
tokio::task::JoinHandle; @@ -26,19 +24,19 @@ pub struct CurrentRoomSummary { pub struct SlidingSyncState { started: Instant, syncer: SlidingSync, - view: SlidingSyncView, + view: SlidingSyncList, /// the current list selector for the room first_render: Option, full_sync: Option, current_state: ViewState, - tl_handle: Mutable>>, - pub selected_room: Mutable>, - pub current_timeline: MutableVec>, - pub room_timeline: Mutable>, + tl_handle: SharedObservable>>, + pub selected_room: SharedObservable>, + pub current_timeline: Arc>>>, + pub room_timeline: SharedObservable>, } impl SlidingSyncState { - pub fn new(syncer: SlidingSync, view: SlidingSyncView) -> Self { + pub fn new(syncer: SlidingSync, view: SlidingSyncList) -> Self { Self { started: Instant::now(), syncer, @@ -58,12 +56,12 @@ impl SlidingSyncState { } pub fn has_selected_room(&self) -> bool { - self.selected_room.lock_ref().is_some() + self.selected_room.read().is_some() } pub fn select_room(&self, r: Option) { - self.current_timeline.lock_mut().clear(); - if let Some(c) = self.tl_handle.lock_mut().take() { + self.current_timeline.write().unwrap().clear(); + if let Some(c) = self.tl_handle.take() { c.abort(); } if let Some(room) = r.as_ref().and_then(|room_id| self.get_room(room_id)) { @@ -71,41 +69,54 @@ impl SlidingSyncState { let room_timeline = self.room_timeline.clone(); let handle = tokio::spawn(async move { let timeline = room.timeline().await.unwrap(); - let listener = timeline.stream(); - *room_timeline.lock_mut() = Some(timeline); + let (items, listener) = timeline.subscribe().await; + room_timeline.set(Some(timeline)); + { + let mut lock = current_timeline.write().unwrap(); + lock.clear(); + lock.append(items.into_iter().collect()); + } pin_mut!(listener); while let Some(diff) = listener.next().await { match diff { - VecDiff::Clear {} => { - current_timeline.lock_mut().clear(); + VectorDiff::Append { values } => { + current_timeline.write().unwrap().append(values); + } + VectorDiff::Clear => { + 
current_timeline.write().unwrap().clear(); + } + VectorDiff::Insert { index, value } => { + current_timeline.write().unwrap().insert(index, value); } - VecDiff::InsertAt { index, value } => { - current_timeline.lock_mut().insert_cloned(index, value); + VectorDiff::PopBack => { + current_timeline.write().unwrap().pop_back(); } - VecDiff::Move { old_index, new_index } => { - current_timeline.lock_mut().move_from_to(old_index, new_index); + VectorDiff::PopFront => { + current_timeline.write().unwrap().pop_front(); } - VecDiff::Pop {} => { - current_timeline.lock_mut().pop(); + VectorDiff::PushBack { value } => { + current_timeline.write().unwrap().push_back(value); } - VecDiff::Push { value } => { - current_timeline.lock_mut().push_cloned(value); + VectorDiff::PushFront { value } => { + current_timeline.write().unwrap().push_front(value); } - VecDiff::RemoveAt { index } => { - current_timeline.lock_mut().remove(index); + VectorDiff::Remove { index } => { + current_timeline.write().unwrap().remove(index); } - VecDiff::Replace { values } => { - current_timeline.lock_mut().replace_cloned(values); + VectorDiff::Set { index, value } => { + current_timeline.write().unwrap().set(index, value); } - VecDiff::UpdateAt { index, value } => { - current_timeline.lock_mut().set_cloned(index, value); + VectorDiff::Reset { values } => { + let mut lock = current_timeline.write().unwrap(); + lock.clear(); + lock.append(values); } } } }); - *self.tl_handle.lock_mut() = Some(handle); + self.tl_handle.set(Some(handle)); } - self.selected_room.replace(r); + self.selected_room.set(r); } pub fn time_to_first_render(&self) -> Option { @@ -124,14 +135,14 @@ impl SlidingSyncState { } pub fn total_rooms_count(&self) -> Option { - self.view.rooms_count.get() + self.view.rooms_count() } pub fn set_first_render_now(&mut self) { self.first_render = Some(self.started.elapsed()) } - pub fn view(&self) -> &SlidingSyncView { + pub fn view(&self) -> &SlidingSyncList { &self.view } diff --git 
a/labs/jack-in/src/components/details.rs b/labs/jack-in/src/components/details.rs index efd38009b86..351a31c990e 100644 --- a/labs/jack-in/src/components/details.rs +++ b/labs/jack-in/src/components/details.rs @@ -46,15 +46,15 @@ impl Details { } pub fn refresh_data(&mut self) { - let Some(room_id) = self.sstate.selected_room.lock_ref().clone() else { return }; + let Some(room_id) = self.sstate.selected_room.get() else { return }; let Some(room_data) = self.sstate.get_room(&room_id) else { return; }; - let name = room_data.name.clone().unwrap_or_else(|| "unknown".to_owned()); + let name = room_data.name().unwrap_or("unknown").to_owned(); let state_events = room_data - .required_state + .required_state() .iter() .filter_map(|r| r.deserialize().ok()) .fold(BTreeMap::>::new(), |mut b, r| { @@ -72,7 +72,8 @@ impl Details { let timeline: Vec = self .sstate .current_timeline - .lock_ref() + .read() + .unwrap() .iter() .filter_map(|t| t.as_event()) // we ignore virtual events .map(|e| match e.content() { diff --git a/labs/jack-in/src/components/rooms.rs b/labs/jack-in/src/components/rooms.rs index bef90f34bdc..f6cd780cdff 100644 --- a/labs/jack-in/src/components/rooms.rs +++ b/labs/jack-in/src/components/rooms.rs @@ -64,9 +64,8 @@ impl MockComponent for Rooms { let mut paras = vec![]; for r in self.sstate.get_all_rooms() { - let mut cells = - vec![Cell::from(r.name.clone().unwrap_or_else(|| "unknown".to_owned()))]; - if let Some(c) = r.unread_notifications.notification_count { + let mut cells = vec![Cell::from(r.name().unwrap_or("unknown").to_owned())]; + if let Some(c) = r.unread_notifications().notification_count { let count: u32 = c.try_into().unwrap_or_default(); if count > 0 { cells.push(Cell::from(c.to_string())) diff --git a/labs/jack-in/src/main.rs b/labs/jack-in/src/main.rs index 40971216b97..cd97b27503d 100644 --- a/labs/jack-in/src/main.rs +++ b/labs/jack-in/src/main.rs @@ -7,8 +7,8 @@ use std::{path::Path, time::Duration}; use app_dirs2::{app_root, 
AppDataType, AppInfo}; use clap::Parser; use dialoguer::{theme::ColorfulTheme, Password}; +use eyeball_im::VectorDiff; use eyre::{eyre, Result}; -use futures_signals::signal_vec::VecDiff; use matrix_sdk::{ config::RequestConfig, room::timeline::TimelineItem, @@ -51,7 +51,7 @@ pub enum Msg { pub enum JackInEvent { Any, // match all SyncUpdate(client::state::SlidingSyncState), - RoomDataUpdate(VecDiff), + RoomDataUpdate(VectorDiff), } impl PartialOrd for JackInEvent { diff --git a/labs/sled-state-inspector/Cargo.toml b/labs/sled-state-inspector/Cargo.toml deleted file mode 100644 index b7126992d44..00000000000 --- a/labs/sled-state-inspector/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -name = "sled-state-inspector" -version = "0.1.0" -edition = "2021" -publish = false - -[[bin]] -name = "sled-state-inspector" -test = false - -[dependencies] -atty = "0.2.14" -clap = "3.2.4" -futures = { version = "0.3.21", default-features = false, features = ["executor"] } -matrix-sdk-base = { path = "../../crates/matrix-sdk-base", version = "0.6.0"} -matrix-sdk-sled = { path = "../../crates/matrix-sdk-sled", version = "0.2.0"} -ruma = { workspace = true } -rustyline = "10.0.0" -rustyline-derive = "0.7.0" -serde = { workspace = true } -serde_json = { workspace = true } -syntect = { version = "5.0.0", default-features = false, features = ["dump-load", "parsing", "regex-fancy"] } diff --git a/labs/sled-state-inspector/src/main.rs b/labs/sled-state-inspector/src/main.rs deleted file mode 100644 index 8c939d27544..00000000000 --- a/labs/sled-state-inspector/src/main.rs +++ /dev/null @@ -1,367 +0,0 @@ -use std::{fmt::Debug, sync::Arc}; - -use atty::Stream; -use clap::{Arg, ArgMatches, Command as Argparse}; -use futures::executor::block_on; -use matrix_sdk_base::{RoomInfo, StateStore}; -use matrix_sdk_sled::SledStateStore; -use ruma::{events::StateEventType, OwnedRoomId, OwnedUserId, RoomId}; -use rustyline::{ - completion::{Completer, Pair}, - error::ReadlineError, - 
highlight::{Highlighter, MatchingBracketHighlighter}, - hint::{Hinter, HistoryHinter}, - validate::{MatchingBracketValidator, Validator}, - CompletionType, Config, Context, EditMode, Editor, -}; -use rustyline_derive::Helper; -use serde::Serialize; -use syntect::{ - dumps::from_binary, - easy::HighlightLines, - highlighting::{Style, ThemeSet}, - parsing::SyntaxSet, - util::{as_24_bit_terminal_escaped, LinesWithEndings}, -}; - -#[derive(Clone)] -struct Inspector { - store: Arc, - printer: Printer, -} - -#[derive(Helper)] -struct InspectorHelper { - store: Arc, - _highlighter: MatchingBracketHighlighter, - _validator: MatchingBracketValidator, - _hinter: HistoryHinter, -} - -impl InspectorHelper { - const EVENT_TYPES: &'static [&'static str] = &[ - "m.room.aliases", - "m.room.avatar", - "m.room.canonical_alias", - "m.room.create", - "m.room.encryption", - "m.room.guest_access", - "m.room.history_visibility", - "m.room.join_rules", - "m.room.name", - "m.room.power_levels", - "m.room.tombstone", - "m.room.topic", - ]; - - fn new(store: Arc) -> Self { - Self { - store, - _highlighter: MatchingBracketHighlighter::new(), - _validator: MatchingBracketValidator::new(), - _hinter: HistoryHinter {}, - } - } - - fn complete_event_types(&self, arg: Option<&&str>) -> Vec { - Self::EVENT_TYPES - .iter() - .map(|&t| Pair { display: t.to_owned(), replacement: format!("{t} ") }) - .filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true }) - .collect() - } - - fn complete_rooms(&self, arg: Option<&&str>) -> Vec { - let rooms: Vec = - block_on(async { StateStore::get_room_infos(&*self.store).await.unwrap() }); - - rooms - .into_iter() - .map(|r| Pair { - display: r.room_id().to_string(), - replacement: format!("{} ", r.room_id()), - }) - .filter(|r| if let Some(arg) = arg { r.replacement.starts_with(arg) } else { true }) - .collect() - } -} - -impl Completer for InspectorHelper { - type Candidate = Pair; - - fn complete( - &self, - line: &str, - pos: usize, - 
_: &Context<'_>, - ) -> Result<(usize, Vec), ReadlineError> { - let args: Vec<&str> = line.split_ascii_whitespace().collect(); - - let commands = vec![ - ("get-state", "get a state event in the given room"), - ("get-profiles", "get all the stored profiles in the given room"), - ("list-rooms", "list all rooms"), - ("get-members", "get all the membership events in the given room"), - ] - .iter() - .map(|(r, d)| Pair { display: format!("{r} ({d})"), replacement: format!("{r} ") }) - .collect(); - - if args.is_empty() { - Ok((pos, commands)) - } else if args.len() == 1 { - if (args[0] == "get-state" || args[0] == "get-members" || args[0] == "get-profiles") - && line.ends_with(' ') - { - Ok((args[0].len() + 1, self.complete_rooms(args.get(1)))) - } else { - Ok(( - 0, - commands.into_iter().filter(|c| c.replacement.starts_with(args[0])).collect(), - )) - } - } else if args.len() == 2 { - if args[0] == "get-state" { - if line.ends_with(' ') { - Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2)))) - } else { - Ok((args[0].len() + 1, self.complete_rooms(args.get(1)))) - } - } else if args[0] == "get-members" || args[0] == "get-profiles" { - Ok((args[0].len() + 1, self.complete_rooms(args.get(1)))) - } else { - Ok((pos, vec![])) - } - } else if args.len() == 3 { - if args[0] == "get-state" { - Ok((args[0].len() + args[1].len() + 2, self.complete_event_types(args.get(2)))) - } else { - Ok((pos, vec![])) - } - } else { - Ok((pos, vec![])) - } - } -} - -impl Hinter for InspectorHelper { - type Hint = String; -} - -impl Highlighter for InspectorHelper {} - -impl Validator for InspectorHelper {} - -#[derive(Clone, Debug)] -struct Printer { - ps: Arc, - ts: Arc, - json: bool, - color: bool, -} - -impl Printer { - fn new(json: bool, color: bool) -> Self { - let syntax_set: SyntaxSet = from_binary(include_bytes!("../syntaxes.bin")); - let themes: ThemeSet = from_binary(include_bytes!("../themes.bin")); - - Self { ps: syntax_set.into(), ts: themes.into(), 
json, color } - } - - fn pretty_print_struct(&self, data: &T) { - let data = if self.json { - serde_json::to_string_pretty(data).expect("Can't serialize struct") - } else { - format!("{data:#?}") - }; - - let syntax = if self.json { - self.ps.find_syntax_by_extension("rs").expect("Can't find rust syntax extension") - } else { - self.ps.find_syntax_by_extension("json").expect("Can't find json syntax extension") - }; - - if self.color { - let mut h = HighlightLines::new(syntax, &self.ts.themes["Forest Night"]); - - for line in LinesWithEndings::from(&data) { - let ranges: Vec<(Style, &str)> = - h.highlight_line(line, &self.ps).expect("Failed to highlight line"); - let escaped = as_24_bit_terminal_escaped(&ranges[..], false); - print!("{escaped}"); - } - - // Clear the formatting - println!("\x1b[0m"); - } else { - println!("{data}"); - } - } -} - -impl Inspector { - fn new(database_path: &str, json: bool, color: bool) -> Self { - let printer = Printer::new(json, color); - let store = Arc::new( - SledStateStore::builder() - .path(database_path.into()) - .build() - .expect("Can't open sled database"), - ); - - Self { store, printer } - } - - async fn run(&self, matches: ArgMatches) { - match matches.subcommand() { - Some(("get-profiles", args)) => { - let room_id = RoomId::parse(args.value_of("room-id").unwrap()).unwrap(); - - self.get_profiles(room_id).await; - } - - Some(("get-members", args)) => { - let room_id = RoomId::parse(args.value_of("room-id").unwrap()).unwrap(); - - self.get_members(room_id).await; - } - Some(("list-rooms", _)) => self.list_rooms().await, - Some(("get-display-names", args)) => { - let room_id = RoomId::parse(args.value_of("room-id").unwrap()).unwrap(); - let display_name = args.value_of("display-name").unwrap().to_owned(); - self.get_display_name_owners(room_id, display_name).await; - } - Some(("get-state", args)) => { - let room_id = RoomId::parse(args.value_of("room-id").unwrap()).unwrap(); - let event_type = - 
StateEventType::try_from(args.value_of("event-type").unwrap()).unwrap(); - self.get_state(room_id, event_type).await; - } - _ => unreachable!(), - } - } - - async fn list_rooms(&self) { - let rooms: Vec = StateStore::get_room_infos(&*self.store).await.unwrap(); - self.printer.pretty_print_struct(&rooms); - } - - async fn get_display_name_owners(&self, room_id: OwnedRoomId, display_name: String) { - let users = self.store.get_users_with_display_name(&room_id, &display_name).await.unwrap(); - self.printer.pretty_print_struct(&users); - } - - async fn get_profiles(&self, room_id: OwnedRoomId) { - let joined: Vec = - StateStore::get_joined_user_ids(&*self.store, &room_id).await.unwrap(); - - for member in joined { - let event = self.store.get_profile(&room_id, &member).await.unwrap(); - self.printer.pretty_print_struct(&event); - } - } - - async fn get_members(&self, room_id: OwnedRoomId) { - let joined: Vec = - StateStore::get_joined_user_ids(&*self.store, &room_id).await.unwrap(); - - for member in joined { - let event = self.store.get_member_event(&room_id, &member).await.unwrap(); - self.printer.pretty_print_struct(&event); - } - } - - async fn get_state(&self, room_id: OwnedRoomId, event_type: StateEventType) { - self.printer.pretty_print_struct( - &self.store.get_state_event(&room_id, event_type, "").await.unwrap(), - ); - } - - fn subcommands() -> Vec> { - vec![ - Argparse::new("list-rooms"), - Argparse::new("get-members").arg(Arg::new("room-id").required(true).validator(|r| { - RoomId::parse(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned()) - })), - Argparse::new("get-profiles").arg(Arg::new("room-id").required(true).validator(|r| { - RoomId::parse(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned()) - })), - Argparse::new("get-display-names") - .arg(Arg::new("room-id").required(true).validator(|r| { - RoomId::parse(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned()) - })) - .arg(Arg::new("display-name").required(true)), - 
Argparse::new("get-state") - .arg(Arg::new("room-id").required(true).validator(|r| { - RoomId::parse(r).map(|_| ()).map_err(|_| "Invalid room id given".to_owned()) - })) - .arg(Arg::new("event-type").required(true).validator(|e| { - StateEventType::try_from(e) - .map(|_| ()) - .map_err(|_| "Invalid event type".to_owned()) - })), - ] - } - - async fn parse_and_run(&self, input: &str) { - let argparse = Argparse::new("state-inspector") - .disable_version_flag(true) - .disable_help_flag(true) - .no_binary_name(true) - .subcommand_required(true) - .arg_required_else_help(true) - .subcommands(Inspector::subcommands()); - - match argparse.try_get_matches_from(input.split_ascii_whitespace()) { - Ok(m) => { - self.run(m).await; - } - Err(e) => { - println!("{e}"); - } - } - } -} - -fn main() { - let argparse = Argparse::new("state-inspector") - .disable_version_flag(true) - .arg(Arg::new("database").required(true)) - .arg( - Arg::new("json") - .long("json") - .help("set the output to raw json instead of Rust structs") - .global(true) - .takes_value(false), - ) - .subcommands(Inspector::subcommands()); - - let matches = argparse.get_matches(); - - let database_path = matches.value_of("database").expect("No database path"); - let json = matches.is_present("json"); - let color = atty::is(Stream::Stdout); - - let inspector = Inspector::new(database_path, json, color); - - if matches.subcommand().is_none() { - let config = Config::builder() - .history_ignore_space(true) - .completion_type(CompletionType::List) - .edit_mode(EditMode::Emacs) - .build(); - - let helper = InspectorHelper::new(inspector.store.clone()); - - let mut rl = - Editor::::with_config(config).expect("Failed to create Editor"); - rl.set_helper(Some(helper)); - - while let Ok(input) = rl.readline(">> ") { - rl.add_history_entry(input.as_str()); - block_on(inspector.parse_and_run(input.as_str())); - } - } else { - block_on(inspector.run(matches)); - } -} diff --git a/labs/sled-state-inspector/syntaxes.bin 
b/labs/sled-state-inspector/syntaxes.bin deleted file mode 100644 index 71c64c84dff..00000000000 Binary files a/labs/sled-state-inspector/syntaxes.bin and /dev/null differ diff --git a/labs/sled-state-inspector/themes.bin b/labs/sled-state-inspector/themes.bin deleted file mode 100644 index 7342dbdba1f..00000000000 Binary files a/labs/sled-state-inspector/themes.bin and /dev/null differ diff --git a/tarpaulin.toml b/tarpaulin.toml index 6a74337d3e0..89b44e20ddc 100644 --- a/tarpaulin.toml +++ b/tarpaulin.toml @@ -7,6 +7,9 @@ exclude-files = [ "**/tests/*", ] workspace = true +# sqlite crypto store is not tested otherwise because it's only activated by +# matrix-sdk-crypto-ffi, which is excluded from testing below +features = "crypto-store" exclude = [ # bindings "matrix-sdk-crypto-ffi", @@ -20,7 +23,7 @@ exclude = [ "matrix-sdk-test-macros", # labs "jack-in", - "sled-state-inspector", # repo automation (ci, codegen) + "uniffi-bindgen", "xtask", ] diff --git a/testing/matrix-sdk-integration-testing/README.md b/testing/matrix-sdk-integration-testing/README.md index 36c4c8ad78f..dd572a7c1af 100644 --- a/testing/matrix-sdk-integration-testing/README.md +++ b/testing/matrix-sdk-integration-testing/README.md @@ -28,7 +28,7 @@ To drop the database of your docker-compose run: ```bash docker-compose -f assets/docker-compose.yml stop -docker volume rm -f assets_marix-rust-sdk-ci-data +docker volume rm -f assets_matrix-rust-sdk-ci-data ``` or simply: diff --git a/testing/matrix-sdk-integration-testing/assets/Dockerfile b/testing/matrix-sdk-integration-testing/assets/Dockerfile index 56dd700a796..b68b2519096 100644 --- a/testing/matrix-sdk-integration-testing/assets/Dockerfile +++ b/testing/matrix-sdk-integration-testing/assets/Dockerfile @@ -1,4 +1,4 @@ -FROM matrixdotorg/synapse:latest +FROM docker.io/matrixdotorg/synapse:latest ADD ci-start.sh /ci-start.sh RUN chmod 770 /ci-start.sh ENTRYPOINT /ci-start.sh diff --git 
a/testing/matrix-sdk-integration-testing/src/tests/repeated_join.rs b/testing/matrix-sdk-integration-testing/src/tests/repeated_join.rs index a6506082d4a..0ea46b04e3c 100644 --- a/testing/matrix-sdk-integration-testing/src/tests/repeated_join.rs +++ b/testing/matrix-sdk-integration-testing/src/tests/repeated_join.rs @@ -9,7 +9,7 @@ use matrix_sdk::{ api::client::room::create_room::v3::Request as CreateRoomRequest, events::room::member::{MembershipState, StrippedRoomMemberEvent}, }, - Client, RoomType, + Client, RoomState, }; use tokio::sync::Notify; @@ -136,7 +136,7 @@ async fn signal_on_invite( return; } - if room.room_type() != RoomType::Invited { + if room.state() != RoomState::Invited { return; } diff --git a/testing/matrix-sdk-test/Cargo.toml b/testing/matrix-sdk-test/Cargo.toml index 3260f6e9ae9..6e4208c3bb8 100644 --- a/testing/matrix-sdk-test/Cargo.toml +++ b/testing/matrix-sdk-test/Cargo.toml @@ -27,7 +27,7 @@ serde = { workspace = true } serde_json = { workspace = true } [target.'cfg(not(target_arch = "wasm32"))'.dependencies] -tokio = { version = "1.23.1", default-features = false, features = ["rt", "macros"] } +tokio = { version = "1.24.2", default-features = false, features = ["rt", "macros"] } [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen-test = "0.3.33" diff --git a/testing/matrix-sdk-test/src/test_json/api_responses.rs b/testing/matrix-sdk-test/src/test_json/api_responses.rs index 3405e8038c2..d534697717a 100644 --- a/testing/matrix-sdk-test/src/test_json/api_responses.rs +++ b/testing/matrix-sdk-test/src/test_json/api_responses.rs @@ -68,6 +68,91 @@ pub static KEYS_QUERY: Lazy = Lazy::new(|| { }) }); +/// `POST /_matrix/client/v3/keys/query` +/// For a set of 2 devices own by a user named web2. 
+/// First device is unsigned, second one is signed +pub static KEYS_QUERY_TWO_DEVICES_ONE_SIGNED: Lazy = Lazy::new(|| { + json!({ + "device_keys":{ + "@web2:localhost:8482":{ + "AVXFQWJUQA":{ + "algorithms":[ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2" + ], + "device_id":"AVXFQWJUQA", + "keys":{ + "curve25519:AVXFQWJUQA":"LTpv2DGMhggPAXO02+7f68CNEp6A40F0Yl8B094Y8gc", + "ed25519:AVXFQWJUQA":"loz5i40dP+azDtWvsD0L/xpnCjNkmrcvtXVXzCHX8Vw" + }, + "signatures":{ + "@web2:localhost:8482":{ + "ed25519:AVXFQWJUQA":"BmdzjXMwZaZ0ZK8T6h3pkTA+gZbD34Bzf8FNazBdAIE16fxVzrlSJkLfXnjdBqRO0Dlda5vKgGpqJazZP6obDw" + } + }, + "user_id":"@web2:localhost:8482" + }, + "JERTCKWUWG":{ + "algorithms":[ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2" + ], + "device_id":"JERTCKWUWG", + "keys":{ + "curve25519:JERTCKWUWG":"XJixbpnfIk+RqcK5T6moqVY9d9Q1veR8WjjSlNiQNT0", + "ed25519:JERTCKWUWG":"48f3WQAMGwYLBg5M5qUhqnEVA8yeibjZpPsShoWMFT8" + }, + "signatures":{ + "@web2:localhost:8482":{ + "ed25519:JERTCKWUWG":"Wc67XYem4IKCpshcslQ6ketCE5otubpX+Bh01OB8ghLxl1d6exlZsgaRA57N8RJ0EMvbeTWCweHXXC/UeeQ4DQ", + "ed25519:uXOM0Xlfts9SGysk/yNr0Vn9rgv1Ifh3R8oPhtic4BM":"dto9VPhhJbNw62j8NQyjnwukMd1NtYnDYSoUOzD5dABq1u2Kt/ZdthcTO42HyxG/3/hZdno8XPfJ47l1ZxuXBA" + } + }, + "user_id":"@web2:localhost:8482" + } + } + }, + "failures":{ + + }, + "master_keys":{ + "@web2:localhost:8482":{ + "user_id":"@web2:localhost:8482", + "usage":[ + "master" + ], + "keys":{ + "ed25519:Ct4QR+aXrzW4iYIgH1B/56NkPEtSPoN+h2TGoQ0xxYI":"Ct4QR+aXrzW4iYIgH1B/56NkPEtSPoN+h2TGoQ0xxYI" + }, + "signatures":{ + "@web2:localhost:8482":{ + "ed25519:JERTCKWUWG":"H9hEsUJ+alB5XAboDzU4loVb+SZajC4tsQzGaeU/FHMFAnWeVarTMCR+NmPSGsZfvPrNz2WVS2G7FIH5yhJfBg" + } + } + } + }, + "self_signing_keys":{ + "@web2:localhost:8482":{ + "user_id":"@web2:localhost:8482", + "usage":[ + "self_signing" + ], + "keys":{ + "ed25519:uXOM0Xlfts9SGysk/yNr0Vn9rgv1Ifh3R8oPhtic4BM":"uXOM0Xlfts9SGysk/yNr0Vn9rgv1Ifh3R8oPhtic4BM" + }, + "signatures":{ + 
"@web2:localhost:8482":{ + "ed25519:Ct4QR+aXrzW4iYIgH1B/56NkPEtSPoN+h2TGoQ0xxYI":"YbD6gTEwY078nllTxmlyea2VNvAElQ/ig7aPsyhA3h1gGwFvPdtyDbomjdIphUF/lXQ+Eyz4SzlUWeghr1b3BA" + } + } + } + }, + "user_signing_keys":{ + + } + }) +}); + /// `` pub static KEYS_UPLOAD: Lazy = Lazy::new(|| { json!({ diff --git a/testing/matrix-sdk-test/src/test_json/mod.rs b/testing/matrix-sdk-test/src/test_json/mod.rs index b7472d36a40..90c56413f48 100644 --- a/testing/matrix-sdk-test/src/test_json/mod.rs +++ b/testing/matrix-sdk-test/src/test_json/mod.rs @@ -14,10 +14,10 @@ pub mod sync; pub mod sync_events; pub use api_responses::{ - DEVICES, GET_ALIAS, KEYS_QUERY, KEYS_UPLOAD, LOGIN, LOGIN_RESPONSE_ERR, LOGIN_TYPES, - LOGIN_WITH_DISCOVERY, LOGIN_WITH_REFRESH_TOKEN, NOT_FOUND, PUBLIC_ROOMS, REFRESH_TOKEN, - REFRESH_TOKEN_WITH_REFRESH_TOKEN, REGISTRATION_RESPONSE_ERR, UNKNOWN_TOKEN_SOFT_LOGOUT, - VERSIONS, WELL_KNOWN, WHOAMI, + DEVICES, GET_ALIAS, KEYS_QUERY, KEYS_QUERY_TWO_DEVICES_ONE_SIGNED, KEYS_UPLOAD, LOGIN, + LOGIN_RESPONSE_ERR, LOGIN_TYPES, LOGIN_WITH_DISCOVERY, LOGIN_WITH_REFRESH_TOKEN, NOT_FOUND, + PUBLIC_ROOMS, REFRESH_TOKEN, REFRESH_TOKEN_WITH_REFRESH_TOKEN, REGISTRATION_RESPONSE_ERR, + UNKNOWN_TOKEN_SOFT_LOGOUT, VERSIONS, WELL_KNOWN, WHOAMI, }; pub use members::MEMBERS; pub use messages::{ROOM_MESSAGES, ROOM_MESSAGES_BATCH_1, ROOM_MESSAGES_BATCH_2}; diff --git a/testing/sliding-sync-integration-test/Cargo.toml b/testing/sliding-sync-integration-test/Cargo.toml index bab74a456c0..49870627cff 100644 --- a/testing/sliding-sync-integration-test/Cargo.toml +++ b/testing/sliding-sync-integration-test/Cargo.toml @@ -6,9 +6,11 @@ publish = false [dependencies] anyhow = { workspace = true } -ctor = { workspace = true } +eyeball = { workspace = true } +eyeball-im = { workspace = true } matrix-sdk-integration-testing = { path = "../matrix-sdk-integration-testing", features = ["helpers"] } matrix-sdk = { path = "../../crates/matrix-sdk", features = ["experimental-sliding-sync", "testing"] } 
tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros"] } futures = { version = "0.3.25" } uuid = { version = "1.2.2" } +assert_matches = "1.5.0" diff --git a/testing/sliding-sync-integration-test/assets/Dockerfile b/testing/sliding-sync-integration-test/assets/Dockerfile index 56dd700a796..b68b2519096 100644 --- a/testing/sliding-sync-integration-test/assets/Dockerfile +++ b/testing/sliding-sync-integration-test/assets/Dockerfile @@ -1,4 +1,4 @@ -FROM matrixdotorg/synapse:latest +FROM docker.io/matrixdotorg/synapse:latest ADD ci-start.sh /ci-start.sh RUN chmod 770 /ci-start.sh ENTRYPOINT /ci-start.sh diff --git a/testing/sliding-sync-integration-test/assets/docker-compose.yml b/testing/sliding-sync-integration-test/assets/docker-compose.yml index a79713f02fc..f906b9eddb5 100644 --- a/testing/sliding-sync-integration-test/assets/docker-compose.yml +++ b/testing/sliding-sync-integration-test/assets/docker-compose.yml @@ -9,27 +9,30 @@ services: disable: true volumes: - ./data/synapse:/data - + ports: - 8228:8008/tcp postgres: - image: postgres + image: docker.io/postgres environment: POSTGRES_PASSWORD: postgres POSTGRES_USER: postgres POSTGRES_DB: syncv3 healthcheck: - test: ["pg_isready"] + test: ["CMD", "pg_isready"] interval: 10s timeout: 5s retries: 5 volumes: - ./data/db:/var/lib/postgresql/data - + sliding-sync-proxy: - image: ghcr.io/matrix-org/sliding-sync:v0.99.0-rc1 - # image: ghcr.io/matrix-org/sliding-sync:latest + image: ghcr.io/matrix-org/sliding-sync:v0.99.1 + depends_on: + postgres: + condition: service_healthy + links: - synapse - postgres diff --git a/testing/sliding-sync-integration-test/src/lib.rs b/testing/sliding-sync-integration-test/src/lib.rs index 4eaa880ed22..90bd8b35481 100644 --- a/testing/sliding-sync-integration-test/src/lib.rs +++ b/testing/sliding-sync-integration-test/src/lib.rs @@ -72,14 +72,25 @@ mod tests { }; use anyhow::{bail, Context}; + use assert_matches::assert_matches; + use eyeball::unique::Observable; + 
use eyeball_im::VectorDiff; use futures::{pin_mut, stream::StreamExt}; use matrix_sdk::{ + room::timeline::EventTimelineItem, ruma::{ - api::client::error::ErrorKind as RumaError, - events::room::message::RoomMessageEventContent, + api::client::{ + error::ErrorKind as RumaError, + receipt::create_receipt::v3::ReceiptType as CreateReceiptType, + sync::sync_events::v4::ReceiptsConfig, + }, + events::{ + receipt::{ReceiptThread, ReceiptType}, + room::message::RoomMessageEventContent, + }, + uint, }, - test_utils::force_sliding_sync_pos, - SlidingSyncMode, SlidingSyncState, SlidingSyncViewBuilder, + SlidingSyncList, SlidingSyncMode, SlidingSyncState, }; use super::*; @@ -87,7 +98,7 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn it_works_smoke_test() -> anyhow::Result<()> { let (_client, sync_proxy_builder) = setup("odo".to_owned(), false).await?; - let sync_proxy = sync_proxy_builder.add_fullsync_view().build().await?; + let sync_proxy = sync_proxy_builder.add_fullsync_list().build().await?; let stream = sync_proxy.stream(); pin_mut!(stream); let room_summary = @@ -98,29 +109,250 @@ mod tests { } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] - async fn adding_view_later() -> anyhow::Result<()> { - let view_name_1 = "sliding1"; - let view_name_2 = "sliding2"; - let view_name_3 = "sliding3"; + async fn modifying_timeline_limit() -> anyhow::Result<()> { + let (client, sync_builder) = random_setup_with_rooms(1).await?; + + // List one room. + let room_id = { + let sync = sync_builder + .clone() + .add_list( + SlidingSyncList::builder() + .sync_mode(SlidingSyncMode::Selective) + .add_range(0u32, 1) + .timeline_limit(0u32) + .name("init_list") + .build()?, + ) + .build() + .await?; + + // Get the sync stream. + let stream = sync.stream(); + pin_mut!(stream); + + // Get the list to all rooms to check the list' state. 
+ let list = sync.list("init_list").context("list `init_list` isn't found")?; + assert_eq!(list.state(), SlidingSyncState::Cold); + + // Send the request and wait for a response. + let update_summary = stream + .next() + .await + .context("No room summary found, loop ended unsuccessfully")??; + + // Check the state has switched to `Live`. + assert_eq!(list.state(), SlidingSyncState::Live); + + // One room has received an update. + assert_eq!(update_summary.rooms.len(), 1); + + // Let's fetch the room ID then. + let room_id = update_summary.rooms[0].clone(); + + // Let's fetch the room ID from the list too. + assert_matches!(list.rooms_list().get(0), Some(RoomListEntry::Filled(same_room_id)) => { + assert_eq!(same_room_id, &room_id); + }); + + room_id + }; + + // Join a room and send 20 messages. + { + // Join the room. + let room = + client.get_joined_room(&room_id).context("Failed to join room `{room_id}`")?; + + // In this room, let's send 20 messages! + for nth in 0..20 { + let message = RoomMessageEventContent::text_plain(format!("Message #{nth}")); + + room.send(message, None).await?; + } + + // Wait on the server to receive all the messages. + tokio::time::sleep(Duration::from_secs(1)).await; + } + + let sync = sync_builder + .clone() + .add_list( + SlidingSyncList::builder() + .sync_mode(SlidingSyncMode::Selective) + .name("visible_rooms_list") + .add_range(0u32, 1) + .timeline_limit(1u32) + .build()?, + ) + .build() + .await?; + + // Get the sync stream. + let stream = sync.stream(); + pin_mut!(stream); + + // Get the list. + let list = + sync.list("visible_rooms_list").context("list `visible_rooms_list` isn't found")?; + + let mut all_event_ids = Vec::new(); + + // Sync to receive a message with a `timeline_limit` set to 1. + let (room, _timeline, mut timeline_stream) = { + let mut update_summary; + + loop { + // Wait for a response. 
+ update_summary = stream + .next() + .await + .context("No update summary found, loop ended unsuccessfully")??; + + if !update_summary.rooms.is_empty() { + break; + } + } + + // We see that one room has received an update, and it's our room! + assert_eq!(update_summary.rooms.len(), 1); + assert_eq!(room_id, update_summary.rooms[0]); + + // OK, now let's read the timeline! + let room = sync.get_room(&room_id).expect("Failed to get the room"); + + // Test the `Timeline`. + let timeline = room.timeline().await.unwrap(); + let (timeline_items, timeline_stream) = timeline.subscribe().await; + + // First timeline item. + assert_matches!(timeline_items[0].as_virtual(), Some(_)); + + // Second timeline item. + let latest_remote_event = assert_matches!( + timeline_items[1].as_event(), + Some(EventTimelineItem::Remote(remote_event)) => remote_event + ); + all_event_ids.push(latest_remote_event.event_id().to_owned()); + + // Test the room to see the last event. + assert_matches!(room.latest_event().await, Some(EventTimelineItem::Remote(remote_event)) => { + assert_eq!(remote_event.event_id(), latest_remote_event.event_id(), "Unexpected latest event"); + assert_eq!(remote_event.content().as_message().unwrap().body(), "Message #19"); + }); + + (room, timeline, timeline_stream) + }; + + // Sync to receive messages with a `timeline_limit` set to 20. + { + Observable::set(&mut list.timeline_limit.write().unwrap(), Some(uint!(20))); + + let mut update_summary; + + loop { + // Wait for a response. + update_summary = stream + .next() + .await + .context("No update summary found, loop ended unsuccessfully")??; + + if !update_summary.rooms.is_empty() { + break; + } + } + + // We see that one room has received an update, and it's our room! + assert_eq!(update_summary.rooms.len(), 1); + assert_eq!(room_id, update_summary.rooms[0]); + + // Let's fetch the room ID from the list too. 
+ assert_matches!(list.rooms_list().get(0), Some(RoomListEntry::Filled(same_room_id)) => { + assert_eq!(same_room_id, &room_id); + }); + + // Test the `Timeline`. + + // The first 19th items are `VectorDiff::PushBack`. + for nth in 0..19 { + assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => { + let remote_event = assert_matches!( + value.as_event(), + Some(EventTimelineItem::Remote(remote_event)) => remote_event + ); + + // Check messages arrived in the correct order. + assert_eq!( + remote_event.content().as_message().expect("Received event is not a message").body(), + format!("Message #{nth}"), + ); + + all_event_ids.push(remote_event.event_id().to_owned()); + }); + } + + // The 20th item is a `VectorDiff::Remove`, i.e. the first message is removed. + assert_matches!(timeline_stream.next().await, Some(VectorDiff::Remove { index }) => { + // Index 0 is for day divider. So our first event is at index 1. + assert_eq!(index, 1); + }); + + // And now, the initial message is pushed at the bottom, so the 21th item is a + // `VectorDiff::PushBack`. + let latest_remote_event = assert_matches!(timeline_stream.next().await, Some(VectorDiff::PushBack { value }) => { + let remote_event = assert_matches!( + value.as_event(), + Some(EventTimelineItem::Remote(remote_event)) => remote_event + ); + assert_eq!(remote_event.content().as_message().unwrap().body(), "Message #19"); + assert_eq!(remote_event.event_id(), all_event_ids[0]); + + remote_event.clone() + }); + + // Test the room to see the last event. + assert_matches!(room.latest_event().await, Some(EventTimelineItem::Remote(remote_event)) => { + assert_eq!(remote_event.content().as_message().unwrap().body(), "Message #19"); + assert_eq!(remote_event.event_id(), latest_remote_event.event_id(), "Unexpected latest event"); + }); + + // Ensure there is no event ID duplication. 
+ { + let mut dedup_event_ids = all_event_ids.clone(); + dedup_event_ids.sort(); + dedup_event_ids.dedup(); + + assert_eq!(dedup_event_ids.len(), all_event_ids.len(), "Found duplicated event ID"); + } + } + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn adding_list_later() -> anyhow::Result<()> { + let list_name_1 = "sliding1"; + let list_name_2 = "sliding2"; + let list_name_3 = "sliding3"; let (client, sync_proxy_builder) = random_setup_with_rooms(20).await?; - let build_view = |name| { - SlidingSyncViewBuilder::default() + let build_list = |name| { + SlidingSyncList::builder() .sync_mode(SlidingSyncMode::Selective) .set_range(0u32, 10u32) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name(name) .build() }; let sync_proxy = sync_proxy_builder - .add_view(build_view(view_name_1)?) - .add_view(build_view(view_name_2)?) + .add_list(build_list(list_name_1)?) + .add_list(build_list(list_name_2)?) 
.build() .await?; - let view1 = sync_proxy.view(view_name_1).context("but we just added that view!")?; - let _view2 = sync_proxy.view(view_name_2).context("but we just added that view!")?; + let list1 = sync_proxy.list(list_name_1).context("but we just added that list!")?; + let _list2 = sync_proxy.list(list_name_2).context("but we just added that list!")?; - assert!(sync_proxy.view(view_name_3).is_none()); + assert!(sync_proxy.list(list_name_3).is_none()); let stream = sync_proxy.stream(); pin_mut!(stream); @@ -128,11 +360,11 @@ mod tests { stream.next().await.context("No room summary found, loop ended unsuccessfully")?; let summary = room_summary?; // we only heard about the ones we had asked for - assert_eq!(summary.views, [view_name_1, view_name_2]); + assert_eq!(summary.lists, [list_name_1, list_name_2]); - assert!(sync_proxy.add_view(build_view(view_name_3)?).is_none()); + assert!(sync_proxy.add_list(build_list(list_name_3)?).is_none()); - // we need to restart the stream after every view listing update + // we need to restart the stream after every list listing update let stream = sync_proxy.stream(); pin_mut!(stream); @@ -141,10 +373,10 @@ mod tests { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones we had asked for - if !summary.views.is_empty() { + if !summary.lists.is_empty() { // only if we saw an update come through - assert_eq!(summary.views, [view_name_3]); - // we didn't update the other views, so only no 2 should se an update + assert_eq!(summary.lists, [list_name_3]); + // we didn't update the other lists, so only no 2 should se an update saw_update = true; break; } @@ -152,15 +384,8 @@ mod tests { assert!(saw_update, "We didn't see the update come through the pipe"); - // and let's update the order of all views again - let Some(RoomListEntry::Filled(room_id)) = view1 - .rooms_list - .lock_ref() - .iter() - .nth(4) - .map(Clone::clone) else { - panic!("4th 
room has moved? how?") - }; + // and let's update the order of all lists again + let room_id = assert_matches!(list1.rooms_list().get(4), Some(RoomListEntry::Filled(room_id)) => room_id.clone()); let room = client.get_joined_room(&room_id).context("No joined room {room_id}")?; @@ -173,10 +398,10 @@ mod tests { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones we had asked for - if !summary.views.is_empty() { + if !summary.lists.is_empty() { // only if we saw an update come through - assert_eq!(summary.views, [view_name_1, view_name_2, view_name_3,]); - // notice that our view 2 is now the last view, but all have seen updates + assert_eq!(summary.lists, [list_name_1, list_name_2, list_name_3,]); + // notice that our list 2 is now the last list, but all have seen updates saw_update = true; break; } @@ -187,39 +412,39 @@ mod tests { Ok(()) } - // index-based views don't support removing views. Leaving this test for an API + // index-based lists don't support removing lists. Leaving this test for an API // update later. // #[tokio::test(flavor = "multi_thread", worker_threads = 4)] - async fn live_views() -> anyhow::Result<()> { - let view_name_1 = "sliding1"; - let view_name_2 = "sliding2"; - let view_name_3 = "sliding3"; + async fn live_lists() -> anyhow::Result<()> { + let list_name_1 = "sliding1"; + let list_name_2 = "sliding2"; + let list_name_3 = "sliding3"; let (client, sync_proxy_builder) = random_setup_with_rooms(20).await?; - let build_view = |name| { - SlidingSyncViewBuilder::default() + let build_list = |name| { + SlidingSyncList::builder() .sync_mode(SlidingSyncMode::Selective) .set_range(0u32, 10u32) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name(name) .build() }; let sync_proxy = sync_proxy_builder - .add_view(build_view(view_name_1)?) - .add_view(build_view(view_name_2)?) 
- .add_view(build_view(view_name_3)?) + .add_list(build_list(list_name_1)?) + .add_list(build_list(list_name_2)?) + .add_list(build_list(list_name_3)?) .build() .await?; - let Some(view1 )= sync_proxy.view(view_name_1) else { - bail!("but we just added that view!"); + let Some(list1 )= sync_proxy.list(list_name_1) else { + bail!("but we just added that list!"); }; - let Some(_view2 )= sync_proxy.view(view_name_2) else { - bail!("but we just added that view!"); + let Some(_list2 )= sync_proxy.list(list_name_2) else { + bail!("but we just added that list!"); }; - let Some(_view3 )= sync_proxy.view(view_name_3) else { - bail!("but we just added that view!"); + let Some(_list3 )= sync_proxy.list(list_name_3) else { + bail!("but we just added that list!"); }; let stream = sync_proxy.stream(); @@ -229,26 +454,20 @@ mod tests { }; let summary = room_summary?; // we only heard about the ones we had asked for - assert_eq!(summary.views, [view_name_1, view_name_2, view_name_3]); + assert_eq!(summary.lists, [list_name_1, list_name_2, list_name_3]); - let Some(view_2) = sync_proxy.pop_view(&view_name_2.to_owned()) else { + let Some(list_2) = sync_proxy.pop_list(&list_name_2.to_owned()) else { bail!("Room exists"); }; - // we need to restart the stream after every view listing update + // we need to restart the stream after every list listing update let stream = sync_proxy.stream(); pin_mut!(stream); // Let's trigger an update by sending a message to room pos=3, making it move to // pos 0 - let Some(RoomListEntry::Filled(room_id)) = view1 - .rooms_list - .lock_ref() - .iter().nth(3).map(Clone::clone) else - { - bail!("2nd room has moved? 
how?"); - }; + let room_id = assert_matches!(list1.rooms_list().get(3), Some(RoomListEntry::Filled(room_id)) => room_id.clone()); let Some(room) = client.get_joined_room(&room_id) else { bail!("No joined room {room_id}"); @@ -265,9 +484,9 @@ mod tests { }; let summary = room_summary?; // we only heard about the ones we had asked for - if !summary.views.is_empty() { + if !summary.lists.is_empty() { // only if we saw an update come through - assert_eq!(summary.views, [view_name_1, view_name_3]); + assert_eq!(summary.lists, [list_name_1, list_name_3]); saw_update = true; break; } @@ -275,20 +494,14 @@ mod tests { assert!(saw_update, "We didn't see the update come through the pipe"); - assert!(sync_proxy.add_view(view_2).is_none()); + assert!(sync_proxy.add_list(list_2).is_none()); - // we need to restart the stream after every view listing update + // we need to restart the stream after every list listing update let stream = sync_proxy.stream(); pin_mut!(stream); - // and let's update the order of all views again - let Some(RoomListEntry::Filled(room_id)) = view1 - .rooms_list - .lock_ref() - .iter().nth(4).map(Clone::clone) else - { - bail!("4th room has moved? 
how?"); - }; + // and let's update the order of all lists again + let room_id = assert_matches!(list1.rooms_list().get(4), Some(RoomListEntry::Filled(room_id)) => room_id.clone()); let Some(room) = client.get_joined_room(&room_id) else { bail!("No joined room {room_id}"); @@ -305,9 +518,9 @@ mod tests { }; let summary = room_summary?; // we only heard about the ones we had asked for - if !summary.views.is_empty() { + if !summary.lists.is_empty() { // only if we saw an update come through - assert_eq!(summary.views, [view_name_1, view_name_2, view_name_3]); // all views are visible again + assert_eq!(summary.lists, [list_name_1, list_name_2, list_name_3]); // all lists are visible again saw_update = true; break; } @@ -319,28 +532,28 @@ mod tests { } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] - async fn view_goes_live() -> anyhow::Result<()> { + async fn list_goes_live() -> anyhow::Result<()> { let (_client, sync_proxy_builder) = random_setup_with_rooms(21).await?; - let sliding_window_view = SlidingSyncViewBuilder::default() + let sliding_window_list = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::Selective) .set_range(0u32, 10u32) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("sliding") .build()?; - let full = SlidingSyncViewBuilder::default() + let full = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::GrowingFullSync) .batch_size(10u32) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("full") .build()?; let sync_proxy = - sync_proxy_builder.add_view(sliding_window_view).add_view(full).build().await?; + sync_proxy_builder.add_list(sliding_window_list).add_list(full).build().await?; - let view = sync_proxy.view("sliding").context("but we just added that view!")?; - let full_view = sync_proxy.view("full").context("but we just added that view!")?; - 
assert_eq!(view.state.get_cloned(), SlidingSyncState::Cold, "view isn't cold"); - assert_eq!(full_view.state.get_cloned(), SlidingSyncState::Cold, "full isn't cold"); + let list = sync_proxy.list("sliding").context("but we just added that list!")?; + let full_list = sync_proxy.list("full").context("but we just added that list!")?; + assert_eq!(list.state(), SlidingSyncState::Cold, "list isn't cold"); + assert_eq!(full_list.state(), SlidingSyncState::Cold, "full isn't cold"); let stream = sync_proxy.stream(); pin_mut!(stream); @@ -351,41 +564,32 @@ mod tests { // we only heard about the ones we had asked for assert_eq!(room_summary.rooms.len(), 11); - assert_eq!(view.state.get_cloned(), SlidingSyncState::Live, "view isn't live"); - assert_eq!( - full_view.state.get_cloned(), - SlidingSyncState::CatchingUp, - "full isn't preloading" - ); + assert_eq!(list.state(), SlidingSyncState::Live, "list isn't live"); + assert_eq!(full_list.state(), SlidingSyncState::CatchingUp, "full isn't preloading"); // doing another two requests 0-20; 0-21 should bring full live, too let _room_summary = stream.next().await.context("No room summary found, loop ended unsuccessfully")??; - let rooms_list = full_view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let rooms_list = full_list.rooms_list::(); assert_eq!(rooms_list, repeat(RoomListEntryEasy::Filled).take(21).collect::>()); + assert_eq!(full_list.state(), SlidingSyncState::Live, "full isn't live yet"); - assert_eq!(full_view.state.get_cloned(), SlidingSyncState::Live, "full isn't live yet"); Ok(()) } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn resizing_sliding_window() -> anyhow::Result<()> { let (_client, sync_proxy_builder) = random_setup_with_rooms(20).await?; - let sliding_window_view = SlidingSyncViewBuilder::default() + let sliding_window_list = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::Selective) .set_range(0u32, 10u32) - .sort(vec!["by_recency".to_string(), 
"by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("sliding") .build()?; - let sync_proxy = sync_proxy_builder.add_view(sliding_window_view).build().await?; - let view = sync_proxy.view("sliding").context("but we just added that view!")?; + let sync_proxy = sync_proxy_builder.add_list(sliding_window_list).build().await?; + let list = sync_proxy.list("sliding").context("but we just added that list!")?; let stream = sync_proxy.stream(); pin_mut!(stream); let room_summary = @@ -393,12 +597,9 @@ mod tests { let summary = room_summary?; // we only heard about the ones we had asked for assert_eq!(summary.rooms.len(), 11); - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Filled) @@ -407,28 +608,24 @@ mod tests { .collect::>() ); - let _signal = view.rooms_list.signal_vec_cloned(); + let _signal = list.rooms_list_stream(); // let's move the window - view.set_range(1, 10); + list.set_range(1, 10); // Ensure 0-0 invalidation ranges work. 
for _n in 0..2 { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones we had asked for - if summary.views.iter().any(|s| s == "sliding") { + if summary.lists.iter().any(|s| s == "sliding") { break; } } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Invalid) @@ -438,23 +635,19 @@ mod tests { .collect::>() ); - view.set_range(5, 10); + list.set_range(5, 10); for _n in 0..2 { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones we had asked for - if summary.views.iter().any(|s| s == "sliding") { + if summary.lists.iter().any(|s| s == "sliding") { break; } } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Invalid) @@ -466,23 +659,19 @@ mod tests { // let's move the window - view.set_range(5, 15); + list.set_range(5, 15); for _n in 0..2 { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones we had asked for - if summary.views.iter().any(|s| s == "sliding") { + if summary.lists.iter().any(|s| s == "sliding") { break; } } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Invalid) @@ -497,14 +686,14 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn moving_out_of_sliding_window() -> anyhow::Result<()> { let (client, sync_proxy_builder) = random_setup_with_rooms(20).await?; - let 
sliding_window_view = SlidingSyncViewBuilder::default() + let sliding_window_list = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::Selective) .set_range(1u32, 10u32) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("sliding") .build()?; - let sync_proxy = sync_proxy_builder.add_view(sliding_window_view).build().await?; - let view = sync_proxy.view("sliding").context("but we just added that view!")?; + let sync_proxy = sync_proxy_builder.add_list(sliding_window_list).build().await?; + let list = sync_proxy.list("sliding").context("but we just added that list!")?; let stream = sync_proxy.stream(); pin_mut!(stream); let room_summary = @@ -512,12 +701,8 @@ mod tests { let summary = room_summary?; // we only heard about the ones we had asked for assert_eq!(summary.rooms.len(), 10); - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Empty) @@ -527,27 +712,23 @@ mod tests { .collect::>() ); - let _signal = view.rooms_list.signal_vec_cloned(); + let _signal = list.rooms_list_stream(); // let's move the window - view.set_range(0, 10); + list.set_range(0, 10); for _n in 0..2 { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones we had asked for - if summary.views.iter().any(|s| s == "sliding") { + if summary.lists.iter().any(|s| s == "sliding") { break; } } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Filled) @@ -558,23 +739,19 @@ mod tests { // let's move the window again - view.set_range(2, 12); + list.set_range(2, 12); for _n in 0..2 { let room_summary = 
stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones we had asked for - if summary.views.iter().any(|s| s == "sliding") { + if summary.lists.iter().any(|s| s == "sliding") { break; } } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Invalid) @@ -587,15 +764,7 @@ mod tests { // now we "move" the room of pos 3 to pos 0; // this is a bordering case - let Some(RoomListEntry::Filled(room_id)) = view - .rooms_list - .lock_ref() - .iter() - .nth(3) - .map(Clone::clone) else - { - panic!("2nd room has moved? how?"); - }; + let room_id = assert_matches!(list.rooms_list().get(3), Some(RoomListEntry::Filled(room_id)) => room_id.clone()); let room = client.get_joined_room(&room_id).context("No joined room {room_id}")?; @@ -607,17 +776,13 @@ mod tests { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones we had asked for - if summary.views.iter().any(|s| s == "sliding") { + if summary.lists.iter().any(|s| s == "sliding") { break; } } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Invalid) @@ -628,27 +793,25 @@ mod tests { ); // items has moved, thus we shouldn't find it where it was - assert!(view.rooms_list.lock_ref().iter().nth(3).unwrap().as_room_id().unwrap() != room_id); + assert!( + list.rooms_list::().get(3).unwrap().as_room_id().unwrap() != room_id + ); // let's move the window again - view.set_range(0, 10); + list.set_range(0, 10); for _n in 0..2 { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones 
we had asked for - if summary.views.iter().any(|s| s == "sliding") { + if summary.lists.iter().any(|s| s == "sliding") { break; } } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Filled) @@ -660,7 +823,7 @@ mod tests { // and check that our room move has been accepted properly, too. assert_eq!( - view.rooms_list.lock_ref().iter().next().unwrap().as_room_id().unwrap(), + list.rooms_list::().get(0).unwrap().as_room_id().unwrap(), &room_id ); @@ -672,39 +835,39 @@ mod tests { async fn fast_unfreeze() -> anyhow::Result<()> { let (_client, sync_proxy_builder) = random_setup_with_rooms(500).await?; print!("setup took its time"); - let build_views = || { - let sliding_window_view = SlidingSyncViewBuilder::default() + let build_lists = || { + let sliding_window_list = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::Selective) .set_range(1u32, 10u32) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("sliding") .build()?; - let growing_sync = SlidingSyncViewBuilder::default() + let growing_sync = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::GrowingFullSync) .limit(100) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("growing") .build()?; - anyhow::Ok((sliding_window_view, growing_sync)) + anyhow::Ok((sliding_window_list, growing_sync)) }; println!("starting the sliding sync setup"); { // SETUP - let (sliding_window_view, growing_sync) = build_views()?; + let (sliding_window_list, growing_sync) = build_lists()?; let sync_proxy = sync_proxy_builder .clone() .cold_cache("sliding_sync") - .add_view(sliding_window_view) - .add_view(growing_sync) + .add_list(sliding_window_list) + .add_list(growing_sync) .build() .await?; let growing_sync = - 
sync_proxy.view("growing").context("but we just added that view!")?; // let's catch it up fully. + sync_proxy.list("growing").context("but we just added that list!")?; // let's catch it up fully. let stream = sync_proxy.stream(); pin_mut!(stream); - while growing_sync.state.get_cloned() != SlidingSyncState::Live { + while growing_sync.state() != SlidingSyncState::Live { // we wait until growing sync is all done, too println!("awaiting"); let _room_summary = stream @@ -716,15 +879,15 @@ mod tests { println!("starting from cold"); // recover from frozen state. - let (sliding_window_view, growing_sync) = build_views()?; + let (sliding_window_list, growing_sync) = build_lists()?; // we recover only the window. this should be quick! let start = Instant::now(); let _sync_proxy = sync_proxy_builder .clone() .cold_cache("sliding_sync") - .add_view(sliding_window_view) - .add_view(growing_sync) + .add_list(sliding_window_list) + .add_list(growing_sync) .build() .await?; let duration = start.elapsed(); @@ -737,15 +900,15 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn growing_sync_keeps_going() -> anyhow::Result<()> { let (_client, sync_proxy_builder) = random_setup_with_rooms(50).await?; - let growing_sync = SlidingSyncViewBuilder::default() + let growing_sync = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::GrowingFullSync) .batch_size(10u32) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("growing") .build()?; - let sync_proxy = sync_proxy_builder.clone().add_view(growing_sync).build().await?; - let view = sync_proxy.view("growing").context("but we just added that view!")?; + let sync_proxy = sync_proxy_builder.clone().add_list(growing_sync).build().await?; + let list = sync_proxy.list("growing").context("but we just added that list!")?; let stream = sync_proxy.stream(); pin_mut!(stream); @@ -757,12 +920,8 @@ mod tests { let _summary = 
room_summary?; } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Filled) @@ -777,12 +936,8 @@ mod tests { let _summary = room_summary?; } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple, repeat(RoomListEntryEasy::Filled) @@ -797,15 +952,15 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn growing_sync_keeps_going_after_restart() -> anyhow::Result<()> { let (_client, sync_proxy_builder) = random_setup_with_rooms(50).await?; - let growing_sync = SlidingSyncViewBuilder::default() + let growing_sync = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::GrowingFullSync) .batch_size(10u32) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("growing") .build()?; - let sync_proxy = sync_proxy_builder.clone().add_view(growing_sync).build().await?; - let view = sync_proxy.view("growing").context("but we just added that view!")?; + let sync_proxy = sync_proxy_builder.clone().add_list(growing_sync).build().await?; + let list = sync_proxy.list("growing").context("but we just added that list!")?; let stream = sync_proxy.stream(); pin_mut!(stream); @@ -817,12 +972,8 @@ mod tests { let _summary = room_summary?; } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple.iter().fold(0, |acc, i| if *i == RoomListEntryEasy::Filled { acc + 1 @@ -843,12 +994,8 @@ mod tests { let _summary = room_summary?; } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = 
list.rooms_list::(); + assert_eq!( collection_simple.iter().fold(0, |acc, i| if *i == RoomListEntryEasy::Filled { acc + 1 @@ -865,10 +1012,10 @@ mod tests { async fn continue_on_reset() -> anyhow::Result<()> { let (_client, sync_proxy_builder) = random_setup_with_rooms(30).await?; print!("setup took its time"); - let growing_sync = SlidingSyncViewBuilder::default() + let growing_sync = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::GrowingFullSync) .limit(100) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("growing") .build()?; @@ -876,27 +1023,23 @@ mod tests { let sync_proxy = sync_proxy_builder .clone() .cold_cache("sliding_sync") - .add_view(growing_sync) + .add_list(growing_sync) .build() .await?; - let view = sync_proxy.view("growing").context("but we just added that view!")?; // let's catch it up fully. + let list = sync_proxy.list("growing").context("but we just added that list!")?; // let's catch it up fully. 
let stream = sync_proxy.stream(); pin_mut!(stream); for _n in 0..2 { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; - if summary.views.iter().any(|s| s == "growing") { + if summary.lists.iter().any(|s| s == "growing") { break; } } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple.iter().fold(0, |acc, i| if *i == RoomListEntryEasy::Filled { acc + 1 @@ -907,7 +1050,7 @@ mod tests { ); // force the pos to be invalid and thus this being reset internally - force_sliding_sync_pos(&sync_proxy, "100".to_owned()); + sync_proxy.set_pos("100".to_owned()); let mut error_seen = false; for _n in 0..2 { @@ -926,19 +1069,15 @@ mod tests { None => anyhow::bail!("Stream ended unexpectedly."), }; // we only heard about the ones we had asked for - if summary.views.iter().any(|s| s == "growing") { + if summary.lists.iter().any(|s| s == "growing") { break; } } assert!(error_seen, "We have not seen the UnknownPos error"); - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple.iter().fold(0, |acc, i| if *i == RoomListEntryEasy::Filled { acc + 1 @@ -955,10 +1094,10 @@ mod tests { async fn noticing_new_rooms_in_growing() -> anyhow::Result<()> { let (client, sync_proxy_builder) = random_setup_with_rooms(30).await?; print!("setup took its time"); - let growing_sync = SlidingSyncViewBuilder::default() + let growing_sync = SlidingSyncList::builder() .sync_mode(SlidingSyncMode::GrowingFullSync) .limit(100) - .sort(vec!["by_recency".to_string(), "by_name".to_string()]) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) .name("growing") .build()?; @@ -966,13 +1105,13 @@ mod tests { let sync_proxy = sync_proxy_builder .clone() .cold_cache("sliding_sync") - 
.add_view(growing_sync) + .add_list(growing_sync) .build() .await?; - let view = sync_proxy.view("growing").context("but we just added that view!")?; // let's catch it up fully. + let list = sync_proxy.list("growing").context("but we just added that list!")?; // let's catch it up fully. let stream = sync_proxy.stream(); pin_mut!(stream); - while view.state.get_cloned() != SlidingSyncState::Live { + while list.state() != SlidingSyncState::Live { // we wait until growing sync is all done, too println!("awaiting"); let _room_summary = stream @@ -981,12 +1120,8 @@ mod tests { .context("No room summary found, loop ended unsuccessfully")??; } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple.iter().fold(0, |acc, i| if *i == RoomListEntryEasy::Filled { acc + 1 @@ -1006,8 +1141,8 @@ mod tests { let room_summary = stream.next().await.context("sync has closed unexpectedly")?; let summary = room_summary?; // we only heard about the ones we had asked for - if summary.views.iter().any(|s| s == "growing") - && view.rooms_count.get_cloned().unwrap_or_default() == 32 + if summary.lists.iter().any(|s| s == "growing") + && list.rooms_count().unwrap_or_default() == 32 { if seen { // once we saw 32, we give it another loop to catch up! 
@@ -1018,12 +1153,8 @@ mod tests { } } - let collection_simple = view - .rooms_list - .lock_ref() - .iter() - .map(Into::::into) - .collect::>(); + let collection_simple = list.rooms_list::(); + assert_eq!( collection_simple.iter().fold(0, |acc, i| if *i == RoomListEntryEasy::Filled { acc + 1 @@ -1035,4 +1166,232 @@ mod tests { Ok(()) } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn restart_room_resubscription() -> anyhow::Result<()> { + let (client, sync_proxy_builder) = random_setup_with_rooms(3).await?; + + let sync_proxy = sync_proxy_builder + .add_list( + SlidingSyncList::builder() + .sync_mode(SlidingSyncMode::Selective) + .set_range(0u32, 2u32) + .sort(vec!["by_recency".to_owned(), "by_name".to_owned()]) + .name("sliding_list") + .build()?, + ) + .build() + .await?; + + let list = sync_proxy.list("sliding_list").context("list `sliding_list` isn't found")?; + + let stream = sync_proxy.stream(); + pin_mut!(stream); + + let room_summary = + stream.next().await.context("No room summary found, loop ended unsuccessfully")??; + + // we only heard about the ones we had asked for + assert_eq!(room_summary.rooms.len(), 3); + + let collection_simple = list.rooms_list::(); + + assert_eq!( + collection_simple, + repeat(RoomListEntryEasy::Filled).take(3).collect::>() + ); + + let _signal = list.rooms_list_stream(); + + // let's move the window + + list.set_range(1, 2); + + for _n in 0..2 { + let room_summary = stream.next().await.context("sync has closed unexpectedly")??; + + // we only heard about the ones we had asked for + if room_summary.lists.iter().any(|s| s == "sliding_list") { + break; + } + } + + let collection_simple = list.rooms_list::(); + + assert_eq!( + collection_simple, + repeat(RoomListEntryEasy::Invalid) + .take(1) + .chain(repeat(RoomListEntryEasy::Filled).take(2)) + .collect::>() + ); + + // let's get that first entry + + let room_id = assert_matches!(list.rooms_list().get(0), Some(RoomListEntry::Invalidated(room_id)) => 
room_id.clone()); + + // send a message + + let room = client.get_joined_room(&room_id).context("No joined room {room_id}")?; + + let content = RoomMessageEventContent::text_plain("Hello world"); + + room.send(content, None).await?; // this should put our room up to the most recent + + // let's subscribe + + sync_proxy.subscribe(room_id.clone(), Default::default()); + + let mut room_updated = false; + + for _n in 0..2 { + let room_summary = stream.next().await.context("sync has closed unexpectedly")??; + + // we only heard about the ones we had asked for + if room_summary.rooms.iter().any(|s| s == &room_id) { + room_updated = true; + break; + } + } + + assert!(room_updated, "Room update has not been seen"); + + // force the pos to be invalid and thus this being reset internally + sync_proxy.set_pos("100".to_owned()); + + let mut error_seen = false; + let mut room_updated = false; + + for _n in 0..2 { + let summary = match stream.next().await { + Some(Ok(e)) => e, + Some(Err(e)) => { + match e.client_api_error_kind() { + Some(RumaError::UnknownPos) => { + // we expect this to come through. 
+ error_seen = true; + continue; + } + _ => Err(e)?, + } + } + None => anyhow::bail!("Stream ended unexpectedly."), + }; + + // we only heard about the ones we had asked for + if summary.rooms.iter().any(|s| s == &room_id) { + room_updated = true; + break; + } + } + + assert!(error_seen, "We have not seen the UnknownPos error"); + assert!(room_updated, "Room update has not been seen"); + + // send another message + + let room = client.get_joined_room(&room_id).context("No joined room {room_id}")?; + + let content = RoomMessageEventContent::text_plain("Hello world"); + + let event_id = room.send(content, None).await?.event_id; // this should put our room up to the most recent + + // let's see for it to come down the pipe + let mut room_updated = false; + + for _n in 0..2 { + let room_summary = stream.next().await.context("sync has closed unexpectedly")??; + + // we only heard about the ones we had asked for + if room_summary.rooms.iter().any(|s| s == &room_id) { + room_updated = true; + break; + } + } + assert!(room_updated, "Room update has not been seen"); + + let sliding_sync_room = sync_proxy.get_room(&room_id).expect("Sliding Sync room not found"); + let event = sliding_sync_room.latest_event().await.expect("No event found"); + + let collection_simple = list.rooms_list::(); + + assert_eq!( + collection_simple, + repeat(RoomListEntryEasy::Invalid) + .take(1) + .chain(repeat(RoomListEntryEasy::Filled).take(2)) + .collect::>() + ); + + assert_eq!( + event.event_id().unwrap(), + event_id, + "Latest event is different than what we've sent" + ); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn receipts_extension_works() -> anyhow::Result<()> { + let (client, sync_proxy_builder) = random_setup_with_rooms(1).await?; + let list = SlidingSyncList::builder() + .sync_mode(SlidingSyncMode::Selective) + .ranges(vec![(0u32, 1u32)]) + .sort(vec!["by_recency".to_owned()]) + .name("a") + .build()?; + + let mut config = 
ReceiptsConfig::default(); + config.enabled = Some(true); + + let sync_proxy = sync_proxy_builder + .clone() + .add_list(list) + .with_receipt_extension(config) + .build() + .await?; + let list = sync_proxy.list("a").context("but we just added that list!")?; + + let stream = sync_proxy.stream(); + pin_mut!(stream); + + stream.next().await.context("sync has closed unexpectedly")??; + + // find the room and send an event which we will send a receipt for + let room_id = list.get_room_id(0).unwrap(); + let room = client.get_joined_room(&room_id).context("No joined room {room_id}")?; + let event_id = + room.send(RoomMessageEventContent::text_plain("Hello world"), None).await?.event_id; + + // now send a receipt + room.send_single_receipt( + CreateReceiptType::Read, + ReceiptThread::Unthreaded, + event_id.clone(), + ) + .await?; + + // we expect to see it because we have enabled the receipt extension. We don't + // know when we'll see it though + let mut found_receipt = false; + for _n in 0..3 { + stream.next().await.context("sync has closed unexpectedly")??; + + // try to find it + let room = client.get_room(&room_id).context("No joined room {room_id}")?; + let receipts = room + .event_receipts(ReceiptType::Read, ReceiptThread::Unthreaded, &event_id) + .await + .unwrap(); + + let expected_user_id = client.user_id().unwrap(); + found_receipt = receipts.iter().any(|(user_id, _)| user_id == expected_user_id); + if found_receipt { + break; + } + } + assert!(found_receipt); + Ok(()) + } } diff --git a/xtask/src/ci.rs b/xtask/src/ci.rs index d2c7636514e..44cc61bf937 100644 --- a/xtask/src/ci.rs +++ b/xtask/src/ci.rs @@ -61,7 +61,6 @@ enum CiCommand { #[derive(Subcommand, PartialEq, Eq, PartialOrd, Ord)] enum FeatureSet { - Default, NoEncryption, NoSled, NoEncryptionAndSled, @@ -192,7 +191,10 @@ fn check_docs() -> Result<()> { fn run_feature_tests(cmd: Option) -> Result<()> { let args = BTreeMap::from([ - (FeatureSet::NoEncryption, "--no-default-features --features 
sled,native-tls"), + ( + FeatureSet::NoEncryption, + "--no-default-features --features sled,native-tls,experimental-sliding-sync", + ), (FeatureSet::NoSled, "--no-default-features --features e2e-encryption,native-tls"), (FeatureSet::NoEncryptionAndSled, "--no-default-features --features native-tls"), ( @@ -233,6 +235,7 @@ fn run_crypto_tests() -> Result<()> { "rustup run stable cargo clippy -p matrix-sdk-crypto --features=backups_v1 -- -D warnings" ) .run()?; + cmd!("rustup run stable cargo nextest run -p matrix-sdk-crypto --no-default-features").run()?; cmd!("rustup run stable cargo nextest run -p matrix-sdk-crypto --features=backups_v1").run()?; cmd!("rustup run stable cargo test --doc -p matrix-sdk-crypto --features=backups_v1").run()?; cmd!( diff --git a/xtask/src/kotlin.rs b/xtask/src/kotlin.rs new file mode 100644 index 00000000000..4a49946e061 --- /dev/null +++ b/xtask/src/kotlin.rs @@ -0,0 +1,166 @@ +use std::{ + fs::create_dir_all, + path::{Path, PathBuf}, +}; + +use clap::{Args, Subcommand, ValueEnum}; +use xshell::{cmd, pushd}; + +use crate::{workspace, Result}; + +struct PackageValues { + name: &'static str, + udl_path: &'static str, +} + +#[derive(ValueEnum, Clone)] +enum Package { + CryptoSDK, + FullSDK, +} + +impl Package { + fn values(self) -> PackageValues { + match self { + Package::CryptoSDK => PackageValues { + name: "matrix-sdk-crypto-ffi", + udl_path: "bindings/matrix-sdk-crypto-ffi/src/olm.udl", + }, + Package::FullSDK => PackageValues { + name: "matrix-sdk-ffi", + udl_path: "bindings/matrix-sdk-ffi/src/api.udl", + }, + } + } +} + +#[derive(Args)] +pub struct KotlinArgs { + #[clap(subcommand)] + cmd: KotlinCommand, +} + +#[derive(Subcommand)] +enum KotlinCommand { + /// Builds the SDK for Android as an AAR. 
+ BuildAndroidLibrary { + #[clap(value_enum, long)] + package: Package, + /// Build with the release profile + #[clap(long)] + release: bool, + + /// Build with a custom profile, takes precedence over `--release` + #[clap(long)] + profile: Option, + + /// Build the given target only + #[clap(long)] + only_target: Option, + + /// Move the generated files into the given src directory + #[clap(long)] + src_dir: PathBuf, + }, +} + +impl KotlinArgs { + pub fn run(self) -> Result<()> { + let _p = pushd(workspace::root_path()?)?; + + match self.cmd { + KotlinCommand::BuildAndroidLibrary { + release, + profile, + src_dir, + only_target, + package, + } => { + let profile = profile.as_deref().unwrap_or(if release { "release" } else { "dev" }); + build_android_library(profile, only_target, src_dir, package) + } + } + } +} + +fn build_android_library( + profile: &str, + only_target: Option, + src_dir: PathBuf, + package: Package, +) -> Result<()> { + let root_dir = workspace::root_path()?; + + let package_values = package.values(); + let package_name = package_values.name; + let udl_path = root_dir.join(package_values.udl_path); + + let jni_libs_dir = src_dir.join("jniLibs"); + let jni_libs_dir_str = jni_libs_dir.to_str().unwrap(); + + let kotlin_generated_dir = src_dir.join("kotlin"); + create_dir_all(kotlin_generated_dir.clone())?; + + let uniffi_lib_path = if let Some(target) = only_target { + println!("-- Building for {target} [1/1]"); + build_for_android_target(target.as_str(), profile, jni_libs_dir_str, package_name)? 
+ } else { + println!("-- Building for x86_64-linux-android[1/4]"); + build_for_android_target("x86_64-linux-android", profile, jni_libs_dir_str, package_name)?; + println!("-- Building for aarch64-linux-android[2/4]"); + build_for_android_target("aarch64-linux-android", profile, jni_libs_dir_str, package_name)?; + println!("-- Building for armv7-linux-androideabi[3/4]"); + build_for_android_target( + "armv7-linux-androideabi", + profile, + jni_libs_dir_str, + package_name, + )?; + println!("-- Building for i686-linux-android[4/4]"); + build_for_android_target("i686-linux-android", profile, jni_libs_dir_str, package_name)? + }; + + println!("-- Generate uniffi files"); + generate_uniffi_bindings(&udl_path, &uniffi_lib_path, &kotlin_generated_dir)?; + + println!("-- All done and hunky dory. Enjoy!"); + Ok(()) +} + +fn generate_uniffi_bindings( + udl_path: &Path, + library_path: &Path, + ffi_generated_dir: &Path, +) -> Result<()> { + println!("-- library_path = {}", library_path.to_string_lossy()); + let udl_file = camino::Utf8Path::from_path(udl_path).unwrap(); + let out_dir_overwrite = camino::Utf8Path::from_path(ffi_generated_dir).unwrap(); + let library_file = camino::Utf8Path::from_path(library_path).unwrap(); + + uniffi_bindgen::generate_bindings( + udl_file, + None, + vec!["kotlin"], + Some(out_dir_overwrite), + Some(library_file), + false, + )?; + Ok(()) +} + +fn build_for_android_target( + target: &str, + profile: &str, + dest_dir: &str, + package_name: &str, +) -> Result { + cmd!("cargo ndk --target {target} -o {dest_dir} build --profile {profile} -p {package_name}") + .run()?; + + // The builtin dev profile has its files stored under target/debug, all + // other targets have matching directory names + let profile_dir_name = if profile == "dev" { "debug" } else { profile }; + let package_camel = package_name.replace('-', "_"); + let lib_name = format!("lib{package_camel}.so"); + 
Ok(workspace::target_path()?.join(target).join(profile_dir_name).join(lib_name)) +} diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 15cfd9443d8..4e1706e4aac 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -1,11 +1,13 @@ mod ci; mod fixup; +mod kotlin; mod swift; mod workspace; use ci::CiArgs; use clap::{Parser, Subcommand}; use fixup::FixupArgs; +use kotlin::KotlinArgs; use swift::SwiftArgs; use xshell::cmd; @@ -30,6 +32,7 @@ enum Command { open: bool, }, Swift(SwiftArgs), + Kotlin(KotlinArgs), } fn main() -> Result<()> { @@ -38,6 +41,7 @@ fn main() -> Result<()> { Command::Fixup(cfg) => cfg.run(), Command::Doc { open } => build_docs(open.then_some("--open"), DenyWarnings::No), Command::Swift(cfg) => cfg.run(), + Command::Kotlin(cfg) => cfg.run(), } }