Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new events for AI typing indicator #3516

Merged
merged 12 commits into from
Dec 2, 2024
105 changes: 105 additions & 0 deletions Sources/StreamChat/WebSocketClient/Events/AITypingEvents.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//
// Copyright © 2024 Stream.io Inc. All rights reserved.
//

import Foundation

/// An event that provides updates about the state of the AI typing indicator.
public struct AIIndicatorUpdateEvent: Event {
/// The state of the AI typing indicator.
public let state: AITypingState
/// The channel ID this event is related to.
public let cid: ChannelId?
/// The message ID this event is related to.
public let messageId: MessageId?
/// Optional server message, usually when an error occurs.
public let aiMessage: String?
}

class AIIndicatorUpdateEventDTO: EventDTO {
let payload: EventPayload

init(from response: EventPayload) throws {
payload = response
}

func toDomainEvent(session: DatabaseSession) -> Event? {
if let typingState = payload.aiState,
let aiTypingState = AITypingState(rawValue: typingState) {
return AIIndicatorUpdateEvent(
state: aiTypingState,
cid: payload.cid,
messageId: payload.messageId,
aiMessage: payload.aiMessage
)
} else {
return nil
}
}
}

/// An event that clears the AI typing indicator.
public struct AIIndicatorClearEvent: Event {
/// The channel ID this event is related to.
public let cid: ChannelId?
}

class AIIndicatorClearEventDTO: EventDTO {
let payload: EventPayload

init(from response: EventPayload) throws {
payload = response
}

func toDomainEvent(session: any DatabaseSession) -> (any Event)? {
AIIndicatorClearEvent(cid: payload.cid)
}
}

/// An event that indicates the AI has stopped generating the message.
public struct AIIndicatorStopEvent: CustomEventPayload, Event {
public static var eventType: EventType = .aiTypingIndicatorStop

/// The channel ID this event is related to.
public let cid: ChannelId?

public init(cid: ChannelId?) {
self.cid = cid
}
}

class AIIndicatorStopEventDTO: EventDTO {
let payload: EventPayload

init(from response: EventPayload) throws {
payload = response
}

func toDomainEvent(session: any DatabaseSession) -> (any Event)? {
AIIndicatorStopEvent(cid: payload.cid)
}
}

/// The state of the AI typing indicator.
public struct AITypingState: ExpressibleByStringLiteral, Hashable {
public var rawValue: String

public init?(rawValue: String) {
self.rawValue = rawValue
}

public init(stringLiteral value: String) {
rawValue = value
}
}

public extension AITypingState {
/// The AI is thinking.
static let thinking: Self = "AI_STATE_THINKING"
/// The AI is checking external sources.
static let checkingExternalSources: Self = "AI_STATE_EXTERNAL_SOURCES"
/// The AI is generating the message.
static let generating: Self = "AI_STATE_GENERATING"
/// There's an error with the message generation.
static let error: Self = "AI_STATE_ERROR"
}
18 changes: 17 additions & 1 deletion Sources/StreamChat/WebSocketClient/Events/EventPayload.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class EventPayload: Decodable {
case thread
case vote = "poll_vote"
case poll
case aiState = "ai_state"
case messageId = "message_id"
case aiMessage = "ai_message"
}

let eventType: EventType
Expand Down Expand Up @@ -68,6 +71,10 @@ class EventPayload: Decodable {
/// Thread Data, it is stored in Result, to be easier to debug decoding errors
let threadDetails: Result<ThreadDetailsPayload, Error>?
let threadPartial: Result<ThreadPartialPayload, Error>?

let aiState: String?
let messageId: String?
let aiMessage: String?

init(
eventType: EventType,
Expand Down Expand Up @@ -96,7 +103,10 @@ class EventPayload: Decodable {
threadDetails: Result<ThreadDetailsPayload, Error>? = nil,
threadPartial: Result<ThreadPartialPayload, Error>? = nil,
poll: PollPayload? = nil,
vote: PollVotePayload? = nil
vote: PollVotePayload? = nil,
aiState: String? = nil,
messageId: String? = nil,
aiMessage: String? = nil
) {
self.eventType = eventType
self.connectionId = connectionId
Expand Down Expand Up @@ -125,6 +135,9 @@ class EventPayload: Decodable {
self.threadDetails = threadDetails
self.poll = poll
self.vote = vote
self.aiState = aiState
self.messageId = messageId
self.aiMessage = aiMessage
}

required init(from decoder: Decoder) throws {
Expand Down Expand Up @@ -158,6 +171,9 @@ class EventPayload: Decodable {
threadPartial = container.decodeAsResultIfPresent(ThreadPartialPayload.self, forKey: .thread)
vote = try container.decodeIfPresent(PollVotePayload.self, forKey: .vote)
poll = try container.decodeIfPresent(PollPayload.self, forKey: .poll)
aiState = try container.decodeIfPresent(String.self, forKey: .aiState)
messageId = try container.decodeIfPresent(String.self, forKey: .messageId)
aiMessage = try container.decodeIfPresent(String.self, forKey: .aiMessage)
}

func event() throws -> Event {
Expand Down
14 changes: 14 additions & 0 deletions Sources/StreamChat/WebSocketClient/Events/EventType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,17 @@ public extension EventType {

/// When a thread has a new reply.
static let threadMessageNew: Self = "notification.thread_message_new"

// MARK: - AI

// When an AI typing indicator's state has changed.
static let aiTypingIndicatorChanged: Self = "ai_indicator.update"

// When an AI typing indicator has been cleared.
static let aiTypingIndicatorClear: Self = "ai_indicator.clear"

// When an AI typing indicator has been stopped.
static let aiTypingIndicatorStop: Self = "ai_indicator.stop"
}

extension EventType {
Expand Down Expand Up @@ -208,6 +219,9 @@ extension EventType {
case .pollVoteRemoved: return try PollVoteRemovedEventDTO(from: response)
case .threadUpdated: return try ThreadUpdatedEventDTO(from: response)
case .threadMessageNew: return try ThreadMessageNewEventDTO(from: response)
case .aiTypingIndicatorChanged: return try AIIndicatorUpdateEventDTO(from: response)
case .aiTypingIndicatorClear: return try AIIndicatorClearEventDTO(from: response)
case .aiTypingIndicatorStop: return try AIIndicatorStopEventDTO(from: response)
default:
if response.cid == nil {
throw ClientError.UnknownUserEvent(response.eventType)
Expand Down
30 changes: 30 additions & 0 deletions StreamChat.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,8 @@
847E946E269C687300E31D0C /* EventsController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 847E946D269C687300E31D0C /* EventsController.swift */; };
847F3CEA2689FDEB00D240E0 /* ChatMessageCell.swift in Sources */ = {isa = PBXBuildFile; fileRef = 847F3CE92689FDEB00D240E0 /* ChatMessageCell.swift */; };
8486CAF926FA51EE00A9AD96 /* EventDTOConverterMiddleware_Tests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8486CAF826FA51EE00A9AD96 /* EventDTOConverterMiddleware_Tests.swift */; };
848849B62CEE01070010E7CA /* AITypingEvents.swift in Sources */ = {isa = PBXBuildFile; fileRef = 848849B52CEE01070010E7CA /* AITypingEvents.swift */; };
848849B72CEE01070010E7CA /* AITypingEvents.swift in Sources */ = {isa = PBXBuildFile; fileRef = 848849B52CEE01070010E7CA /* AITypingEvents.swift */; };
849980F1277246DB00ABA58B /* UIScrollView+Extensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = 849980F0277246DB00ABA58B /* UIScrollView+Extensions.swift */; };
849AE664270CB14000423A20 /* VideoAttachmentComposerPreview_Tests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 849AE663270CB14000423A20 /* VideoAttachmentComposerPreview_Tests.swift */; };
849AE666270CB55F00423A20 /* VideoAttachmentGalleryCell_Tests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 849AE665270CB55F00423A20 /* VideoAttachmentGalleryCell_Tests.swift */; };
Expand Down Expand Up @@ -812,6 +814,10 @@
84DCB851269F4D31006CDF32 /* EventsController+Combine_Tests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84DCB850269F4D31006CDF32 /* EventsController+Combine_Tests.swift */; };
84DCB853269F569A006CDF32 /* EventsController+SwiftUI.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84DCB852269F569A006CDF32 /* EventsController+SwiftUI.swift */; };
84DCB855269F56A7006CDF32 /* EventsController+SwiftUI_Tests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84DCB854269F56A7006CDF32 /* EventsController+SwiftUI_Tests.swift */; };
84E46A372CFA1B8E000CBDDE /* AIIndicatorClear.json in Resources */ = {isa = PBXBuildFile; fileRef = 84E46A342CFA1B8E000CBDDE /* AIIndicatorClear.json */; };
84E46A382CFA1B8E000CBDDE /* AIIndicatorStop.json in Resources */ = {isa = PBXBuildFile; fileRef = 84E46A352CFA1B8E000CBDDE /* AIIndicatorStop.json */; };
84E46A392CFA1B8E000CBDDE /* AIIndicatorUpdate.json in Resources */ = {isa = PBXBuildFile; fileRef = 84E46A362CFA1B8E000CBDDE /* AIIndicatorUpdate.json */; };
84E46A3B2CFA1BB9000CBDDE /* AIIndicatorEvents_Tests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84E46A3A2CFA1BB9000CBDDE /* AIIndicatorEvents_Tests.swift */; };
84EB4E76276A012900E47E73 /* ClientError_Tests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84EB4E75276A012900E47E73 /* ClientError_Tests.swift */; };
84EB4E78276A03DE00E47E73 /* ErrorPayload_Tests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84EB4E77276A03DE00E47E73 /* ErrorPayload_Tests.swift */; };
84EE53B12BBC32AD00FD2A13 /* Chat_Mock.swift in Sources */ = {isa = PBXBuildFile; fileRef = 84EE53B02BBC32AD00FD2A13 /* Chat_Mock.swift */; };
Expand Down Expand Up @@ -3660,6 +3666,7 @@
847E946D269C687300E31D0C /* EventsController.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = EventsController.swift; sourceTree = "<group>"; };
847F3CE92689FDEB00D240E0 /* ChatMessageCell.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessageCell.swift; sourceTree = "<group>"; };
8486CAF826FA51EE00A9AD96 /* EventDTOConverterMiddleware_Tests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = EventDTOConverterMiddleware_Tests.swift; sourceTree = "<group>"; };
848849B52CEE01070010E7CA /* AITypingEvents.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AITypingEvents.swift; sourceTree = "<group>"; };
849980F0277246DB00ABA58B /* UIScrollView+Extensions.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "UIScrollView+Extensions.swift"; sourceTree = "<group>"; };
849AE661270CB00000423A20 /* VideoLoader_Mock.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VideoLoader_Mock.swift; sourceTree = "<group>"; };
849AE663270CB14000423A20 /* VideoAttachmentComposerPreview_Tests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = VideoAttachmentComposerPreview_Tests.swift; sourceTree = "<group>"; };
Expand Down Expand Up @@ -3711,6 +3718,10 @@
84DCB852269F569A006CDF32 /* EventsController+SwiftUI.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "EventsController+SwiftUI.swift"; sourceTree = "<group>"; };
84DCB854269F56A7006CDF32 /* EventsController+SwiftUI_Tests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "EventsController+SwiftUI_Tests.swift"; sourceTree = "<group>"; };
84E32EA4276C9AB200A27112 /* InternetConnection_Mock.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InternetConnection_Mock.swift; sourceTree = "<group>"; };
84E46A342CFA1B8E000CBDDE /* AIIndicatorClear.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = AIIndicatorClear.json; sourceTree = "<group>"; };
84E46A352CFA1B8E000CBDDE /* AIIndicatorStop.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = AIIndicatorStop.json; sourceTree = "<group>"; };
84E46A362CFA1B8E000CBDDE /* AIIndicatorUpdate.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = AIIndicatorUpdate.json; sourceTree = "<group>"; };
84E46A3A2CFA1BB9000CBDDE /* AIIndicatorEvents_Tests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AIIndicatorEvents_Tests.swift; sourceTree = "<group>"; };
84EB4E732769F76500E47E73 /* BackgroundTaskScheduler_Mock.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BackgroundTaskScheduler_Mock.swift; sourceTree = "<group>"; };
84EB4E75276A012900E47E73 /* ClientError_Tests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ClientError_Tests.swift; sourceTree = "<group>"; };
84EB4E77276A03DE00E47E73 /* ErrorPayload_Tests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ErrorPayload_Tests.swift; sourceTree = "<group>"; };
Expand Down Expand Up @@ -5515,6 +5526,7 @@
8A0C3BBB24C0947400CAFD19 /* UserEvents.swift */,
841BAA0F2BCEADAC000C73E4 /* PollsEvents.swift */,
AD7BE1692C209888000A5756 /* ThreadEvents.swift */,
848849B52CEE01070010E7CA /* AITypingEvents.swift */,
);
path = Events;
sourceTree = "<group>";
Expand Down Expand Up @@ -6247,6 +6259,16 @@
path = Cells;
sourceTree = "<group>";
};
84E46A332CFA1B73000CBDDE /* AIIndicator */ = {
isa = PBXGroup;
children = (
84E46A342CFA1B8E000CBDDE /* AIIndicatorClear.json */,
84E46A352CFA1B8E000CBDDE /* AIIndicatorStop.json */,
84E46A362CFA1B8E000CBDDE /* AIIndicatorUpdate.json */,
);
path = AIIndicator;
sourceTree = "<group>";
};
84EE53AF2BBC329300FD2A13 /* State */ = {
isa = PBXGroup;
children = (
Expand Down Expand Up @@ -6543,6 +6565,7 @@
8A62705F24BE31B20040BFD6 /* Events */ = {
isa = PBXGroup;
children = (
84E46A332CFA1B73000CBDDE /* AIIndicator */,
ADE57B802C3C5C4600DD6B88 /* Thread */,
8A0C3BCA24C1C38C00CAFD19 /* Channel */,
E7DB9F2526329C0C0090D9C7 /* HealthCheck */,
Expand Down Expand Up @@ -6961,6 +6984,7 @@
8A62705B24BE2BC00040BFD6 /* TypingEvent_Tests.swift */,
8A0C3BC824C0BBAB00CAFD19 /* UserEvents_Tests.swift */,
ADE57B872C3C60CB00DD6B88 /* ThreadEvents_Tests.swift */,
84E46A3A2CFA1BB9000CBDDE /* AIIndicatorEvents_Tests.swift */,
);
path = Events;
sourceTree = "<group>";
Expand Down Expand Up @@ -10213,6 +10237,9 @@
A311B40327E8B9AD00CFCF6D /* NotificationInviteAccepted.json in Resources */,
A311B3D327E8B98C00CFCF6D /* Message.json in Resources */,
A311B3E427E8B98C00CFCF6D /* MessagePayloadWithCustom.json in Resources */,
84E46A372CFA1B8E000CBDDE /* AIIndicatorClear.json in Resources */,
84E46A382CFA1B8E000CBDDE /* AIIndicatorStop.json in Resources */,
84E46A392CFA1B8E000CBDDE /* AIIndicatorUpdate.json in Resources */,
A3D9D68727EDE3B900725066 /* yoda.jpg in Resources */,
A311B42727E8B9CE00CFCF6D /* MessageReactionPayload+CustomExtraData.json in Resources */,
A311B3E527E8B98C00CFCF6D /* MessageWithBrokenAttachments.json in Resources */,
Expand Down Expand Up @@ -11551,6 +11578,7 @@
40789D1729F6AC500018C2BB /* AudioPlaybackContextAccessor.swift in Sources */,
79D5CDD427EA1BE300BE7D8B /* MessageTranslationsPayload.swift in Sources */,
88D85D97252F168000AE1030 /* MemberController+SwiftUI.swift in Sources */,
848849B72CEE01070010E7CA /* AITypingEvents.swift in Sources */,
43D3F0FC28410A0200B74921 /* CreateCallRequestBody.swift in Sources */,
79877A0D2498E4BC00015F8B /* CurrentUser.swift in Sources */,
4042967D29FAC9DA0089126D /* AudioAnalysisContext.swift in Sources */,
Expand Down Expand Up @@ -11703,6 +11731,7 @@
84EB4E76276A012900E47E73 /* ClientError_Tests.swift in Sources */,
DA8407232525E871005A0F62 /* UserListPayload_Tests.swift in Sources */,
437FCA1926D906B20000223C /* ChatPushNotificationContent_Tests.swift in Sources */,
84E46A3B2CFA1BB9000CBDDE /* AIIndicatorEvents_Tests.swift in Sources */,
F61D7C3124FF9D1F00188A0E /* MessageEndpoints_Tests.swift in Sources */,
8459C9EE2BFB673E00F0D235 /* PollVoteListController+Combine_Tests.swift in Sources */,
8836FFC325408210009FDF73 /* FlagUserPayload_Tests.swift in Sources */,
Expand Down Expand Up @@ -12497,6 +12526,7 @@
C121E8C2274544B100023E4C /* ChannelEventsController.swift in Sources */,
C121E8C3274544B100023E4C /* ListChange.swift in Sources */,
C15C8839286C7BF300E6A72C /* BackgroundListDatabaseObserver.swift in Sources */,
848849B62CEE01070010E7CA /* AITypingEvents.swift in Sources */,
ADB951B3291C3CE900800554 /* AnyAttachmentUpdater.swift in Sources */,
40789D3829F6AC500018C2BB /* AssetPropertyLoading.swift in Sources */,
C121E8C4274544B100023E4C /* EntityChange.swift in Sources */,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"type": "ai_indicator.clear",
"cid": "messaging:general-a4ea1bed-f233-4021-b9f8-f9519367cefd",
"channel_id": "general-a4ea1bed-f233-4021-b9f8-f9519367cefd",
"channel_type": "messaging",
"user": {
"id": "ai-b076753a-830e-40b0-816d-0929bb73d7ce",
"role": "user",
"created_at": "2024-11-27T15:57:45.157276Z",
"updated_at": "2024-11-27T15:57:45.157276Z",
"last_active": "2024-11-27T15:57:45.177077Z",
"banned": false,
"online": true
},
"created_at": "2024-11-27T15:57:45.32967Z"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"type": "ai_indicator.stop",
"cid": "messaging:general-3ac667a1-6113-4b16-b1e3-50dbff0ffb89",
"channel_id": "general-3ac667a1-6113-4b16-b1e3-50dbff0ffb89",
"channel_type": "messaging",
"user": {
"id": "ai-1e9666df-9b9e-429e-b7e4-e2f54446d5ac",
"role": "user",
"created_at": "2024-11-27T15:51:55.649597Z",
"updated_at": "2024-11-27T15:51:55.649597Z",
"last_active": "2024-11-27T15:51:55.668787Z",
"banned": false,
"online": true
},
"created_at": "2024-11-27T15:51:55.803339Z"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"type": "ai_indicator.update",
"cid": "messaging:general-3ac667a1-6113-4b16-b1e3-50dbff0ffb89",
"channel_id": "general-3ac667a1-6113-4b16-b1e3-50dbff0ffb89",
"channel_type": "messaging",
"message_id": "aba120c6-c845-4c5a-968d-31ed0429c31e",
"ai_state": "AI_STATE_ERROR",
"ai_message": "failure",
"user": {
"id": "ai-1e9666df-9b9e-429e-b7e4-e2f54446d5ac",
"role": "user",
"created_at": "2024-11-27T15:51:55.649597Z",
"updated_at": "2024-11-27T15:51:55.649597Z",
"last_active": "2024-11-27T15:51:55.668787Z",
"banned": false,
"online": true
},
"created_at": "2024-11-27T15:51:55.757904Z"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//
// Copyright © 2024 Stream.io Inc. All rights reserved.
//

@testable import StreamChat
@testable import StreamChatTestTools
import XCTest

final class AIIndicatorEvents_Tests: XCTestCase {
var eventDecoder: EventDecoder!

override func setUp() {
super.setUp()
eventDecoder = EventDecoder()
}

override func tearDown() {
super.tearDown()
eventDecoder = nil
}

func test_aiIndicatorUpdate() throws {
let json = XCTestCase.mockData(fromJSONFile: "AIIndicatorUpdate")
let event = try XCTUnwrap(try eventDecoder.decode(from: json) as? AIIndicatorUpdateEventDTO)
XCTAssertEqual(event.payload.cid?.rawValue, "messaging:general-3ac667a1-6113-4b16-b1e3-50dbff0ffb89")
XCTAssertEqual(event.payload.messageId, "aba120c6-c845-4c5a-968d-31ed0429c31e")
XCTAssertEqual(event.payload.aiState, "AI_STATE_ERROR")
XCTAssertEqual(event.payload.aiMessage, "failure")
}

func test_aiIndicatorClear() throws {
let json = XCTestCase.mockData(fromJSONFile: "AIIndicatorClear")
let event = try XCTUnwrap(try eventDecoder.decode(from: json) as? AIIndicatorClearEventDTO)
XCTAssertEqual(event.payload.cid?.rawValue, "messaging:general-a4ea1bed-f233-4021-b9f8-f9519367cefd")
}

func test_aiIndicatorStop() throws {
let json = XCTestCase.mockData(fromJSONFile: "AIIndicatorStop")
let event = try XCTUnwrap(try eventDecoder.decode(from: json) as? AIIndicatorStopEventDTO)
XCTAssertEqual(event.payload.cid?.rawValue, "messaging:general-3ac667a1-6113-4b16-b1e3-50dbff0ffb89")
}
}
Loading