feat: support switch model (#5575)
* feat: ai settings page

* chore: integrate client api

* chore: replace open ai calls

* chore: disable gen image from ai

* chore: clippy

* chore: remove learn about ai

* chore: fix warnings

* chore: fix restart button title

* chore: remove await

* chore: remove loading indicator

---------

Co-authored-by: nathan <[email protected]>
Co-authored-by: Lucas.Xu <[email protected]>
3 people authored Jun 24, 2024
1 parent 40312f4 commit 54c9d12
Showing 78 changed files with 1,383 additions and 499 deletions.
@@ -46,7 +46,6 @@ void main() {
await tester.ime.insertText(inputContent);
expect(find.text(inputContent, findRichText: true), findsOneWidget);

// TODO(nathan): remove the await
// 6 seconds for data sync
await tester.waitForSeconds(6);

@@ -1,4 +1,4 @@
-import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/openai_client.dart';
+import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/ai_client.dart';
import 'package:appflowy/startup/startup.dart';
import 'package:appflowy_editor/appflowy_editor.dart';
import 'package:appflowy_editor/src/render/toolbar/toolbar_widget.dart';
@@ -103,8 +103,8 @@ Future<AppFlowyEditor> setUpOpenAITesting(WidgetTester tester) async {
}

Future<void> mockOpenAIRepository() async {
-await getIt.unregister<OpenAIRepository>();
-getIt.registerFactoryAsync<OpenAIRepository>(
+await getIt.unregister<AIRepository>();
+getIt.registerFactoryAsync<AIRepository>(
() => Future.value(
MockOpenAIRepository(),
),
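The hunk above swaps the test's get_it binding from OpenAIRepository to the new AIRepository interface. The same re-binding is what makes the backend switchable outside of tests; the following is a minimal sketch only (not part of this commit), assuming getIt is the locator exported by startup.dart, that HttpOpenAIRepository takes an http.Client, and that bindOpenAIBackend is a hypothetical helper.

```dart
// Hypothetical helper: re-bind AIRepository so call sites pick up a
// different backend. Assumes the import paths introduced in this commit
// and that HttpOpenAIRepository(client:, apiKey:) matches the diff below.
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/ai_client.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/openai_client.dart';
import 'package:appflowy/startup/startup.dart';
import 'package:http/http.dart' as http;

Future<void> bindOpenAIBackend(String apiKey) async {
  // Drop the current binding, if any, before registering the replacement.
  if (getIt.isRegistered<AIRepository>()) {
    await getIt.unregister<AIRepository>();
  }
  getIt.registerFactoryAsync<AIRepository>(
    () async => HttpOpenAIRepository(client: http.Client(), apiKey: apiKey),
  );
}
```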
@@ -44,7 +44,7 @@ class MockOpenAIRepository extends HttpOpenAIRepository {
required Future<void> Function() onStart,
required Future<void> Function(TextCompletionResponse response) onProcess,
required Future<void> Function() onEnd,
-required void Function(OpenAIError error) onError,
+required void Function(AIError error) onError,
String? suffix,
int maxTokens = 2048,
double temperature = 0.3,
@@ -90,7 +90,7 @@ class ImagePlaceholderState extends State<ImagePlaceholder> {
UploadImageType.local,
UploadImageType.url,
UploadImageType.unsplash,
-UploadImageType.openAI,
+// UploadImageType.openAI,
UploadImageType.stabilityAI,
],
onSelectedLocalImage: (path) {
@@ -1,8 +1,8 @@
import 'dart:async';

import 'package:appflowy/generated/locale_keys.g.dart';
+import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/ai_client.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/error.dart';
-import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/openai_client.dart';
import 'package:appflowy/startup/startup.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:easy_localization/easy_localization.dart';
@@ -22,7 +22,7 @@ class OpenAIImageWidget extends StatefulWidget {
}

class _OpenAIImageWidgetState extends State<OpenAIImageWidget> {
-Future<FlowyResult<List<String>, OpenAIError>>? future;
+Future<FlowyResult<List<String>, AIError>>? future;
String query = '';

@override
@@ -93,7 +93,7 @@ class _OpenAIImageWidgetState extends State<OpenAIImageWidget> {
}

void _search() async {
-final openAI = await getIt.getAsync<OpenAIRepository>();
+final openAI = await getIt.getAsync<AIRepository>();
setState(() {
future = openAI.generateImage(
prompt: query,
@@ -1,7 +1,6 @@
import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/header/cover_editor.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/image/embed_image_url_widget.dart';
-import 'package:appflowy/plugins/document/presentation/editor_plugins/image/open_ai_image_widget.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/image/stability_ai_image_widget.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/image/unsplash_image_widget.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/image/upload_image_file_widget.dart';
@@ -19,7 +18,7 @@ enum UploadImageType {
url,
unsplash,
stabilityAI,
-openAI,
+// openAI,
color;

String get description {
@@ -30,8 +29,8 @@ enum UploadImageType {
return LocaleKeys.document_imageBlock_embedLink_label.tr();
case UploadImageType.unsplash:
return LocaleKeys.document_imageBlock_unsplash_label.tr();
-case UploadImageType.openAI:
-return LocaleKeys.document_imageBlock_ai_label.tr();
+// case UploadImageType.openAI:
+// return LocaleKeys.document_imageBlock_ai_label.tr();
case UploadImageType.stabilityAI:
return LocaleKeys.document_imageBlock_stability_ai_label.tr();
case UploadImageType.color:
@@ -186,23 +185,23 @@ class _UploadImageMenuState extends State<UploadImageMenu> {
),
),
);
-case UploadImageType.openAI:
-return supportOpenAI
-? Expanded(
-child: Container(
-padding: const EdgeInsets.all(8.0),
-constraints: constraints,
-child: OpenAIImageWidget(
-onSelectNetworkImage: widget.onSelectedAIImage,
-),
-),
-)
-: Padding(
-padding: const EdgeInsets.all(8.0),
-child: FlowyText(
-LocaleKeys.document_imageBlock_pleaseInputYourOpenAIKey.tr(),
-),
-);
+// case UploadImageType.openAI:
+// return supportOpenAI
+// ? Expanded(
+// child: Container(
+// padding: const EdgeInsets.all(8.0),
+// constraints: constraints,
+// child: OpenAIImageWidget(
+// onSelectNetworkImage: widget.onSelectedAIImage,
+// ),
+// ),
+// )
+// : Padding(
+// padding: const EdgeInsets.all(8.0),
+// child: FlowyText(
+// LocaleKeys.document_imageBlock_pleaseInputYourOpenAIKey.tr(),
+// ),
+// );
case UploadImageType.stabilityAI:
return supportStabilityAI
? Expanded(
@@ -0,0 +1,32 @@
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/error.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/text_completion.dart';
import 'package:appflowy_backend/protobuf/flowy-chat/entities.pb.dart';
import 'package:appflowy_result/appflowy_result.dart';

abstract class AIRepository {
Future<void> getStreamedCompletions({
required String prompt,
required Future<void> Function() onStart,
required Future<void> Function(TextCompletionResponse response) onProcess,
required Future<void> Function() onEnd,
required void Function(AIError error) onError,
String? suffix,
int maxTokens = 2048,
double temperature = 0.3,
bool useAction = false,
});

Future<void> streamCompletion({
required String text,
required CompletionTypePB completionType,
required Future<void> Function() onStart,
required Future<void> Function(String text) onProcess,
required Future<void> Function() onEnd,
required void Function(AIError error) onError,
});

Future<FlowyResult<List<String>, AIError>> generateImage({
required String prompt,
int n = 1,
});
}
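The new file above is the heart of the change: editor plugins now program against this AIRepository interface instead of a concrete OpenAI client. As a hedged illustration only, a minimal alternative backend would look roughly like the sketch below; EchoAIRepository is hypothetical and the import paths assume the file locations used elsewhere in this diff.

```dart
// EchoAIRepository is a hypothetical stand-in backend, sketched only to
// show that implementing the three AIRepository methods is all a new
// model/provider needs. Import paths follow the rest of this diff.
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/ai_client.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/error.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/text_completion.dart';
import 'package:appflowy_backend/protobuf/flowy-chat/entities.pb.dart';
import 'package:appflowy_result/appflowy_result.dart';

class EchoAIRepository implements AIRepository {
  const EchoAIRepository();

  @override
  Future<void> getStreamedCompletions({
    required String prompt,
    required Future<void> Function() onStart,
    required Future<void> Function(TextCompletionResponse response) onProcess,
    required Future<void> Function() onEnd,
    required void Function(AIError error) onError,
    String? suffix,
    int maxTokens = 2048,
    double temperature = 0.3,
    bool useAction = false,
  }) async {
    // A real backend would stream tokens and forward each chunk via onProcess.
    await onStart();
    await onEnd();
  }

  @override
  Future<void> streamCompletion({
    required String text,
    required CompletionTypePB completionType,
    required Future<void> Function() onStart,
    required Future<void> Function(String text) onProcess,
    required Future<void> Function() onEnd,
    required void Function(AIError error) onError,
  }) async {
    await onStart();
    await onProcess(text); // Echo the input back as the "completion".
    await onEnd();
  }

  @override
  Future<FlowyResult<List<String>, AIError>> generateImage({
    required String prompt,
    int n = 1,
  }) async {
    // Image generation is disabled in this commit, so return no URLs.
    return FlowyResult.success(const <String>[]);
  }
}
```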
@@ -3,12 +3,12 @@ part 'error.freezed.dart';
part 'error.g.dart';

@freezed
-class OpenAIError with _$OpenAIError {
-const factory OpenAIError({
+class AIError with _$AIError {
+const factory AIError({
String? code,
required String message,
-}) = _OpenAIError;
+}) = _AIError;

-factory OpenAIError.fromJson(Map<String, Object?> json) =>
-_$OpenAIErrorFromJson(json);
+factory AIError.fromJson(Map<String, Object?> json) =>
+_$AIErrorFromJson(json);
}
@@ -1,7 +1,8 @@
import 'dart:async';
import 'dart:convert';

-import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/text_edit.dart';
+import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/ai_client.dart';
+import 'package:appflowy_backend/protobuf/flowy-chat/entities.pbenum.dart';
import 'package:appflowy_result/appflowy_result.dart';
import 'package:http/http.dart' as http;

@@ -25,58 +26,7 @@ enum OpenAIRequestType {
}
}

-abstract class OpenAIRepository {
-/// Get completions from GPT-3
-///
-/// [prompt] is the prompt text
-/// [suffix] is the suffix text
-/// [maxTokens] is the maximum number of tokens to generate
-/// [temperature] is the temperature of the model
-///
-Future<FlowyResult<TextCompletionResponse, OpenAIError>> getCompletions({
-required String prompt,
-String? suffix,
-int maxTokens = 2048,
-double temperature = .3,
-});
-
-Future<void> getStreamedCompletions({
-required String prompt,
-required Future<void> Function() onStart,
-required Future<void> Function(TextCompletionResponse response) onProcess,
-required Future<void> Function() onEnd,
-required void Function(OpenAIError error) onError,
-String? suffix,
-int maxTokens = 2048,
-double temperature = 0.3,
-bool useAction = false,
-});
-
-/// Get edits from GPT-3
-///
-/// [input] is the input text
-/// [instruction] is the instruction text
-/// [temperature] is the temperature of the model
-///
-Future<FlowyResult<TextEditResponse, OpenAIError>> getEdits({
-required String input,
-required String instruction,
-double temperature = 0.3,
-});
-
-/// Generate image from GPT-3
-///
-/// [prompt] is the prompt text
-/// [n] is the number of images to generate
-///
-/// the result is a list of urls
-Future<FlowyResult<List<String>, OpenAIError>> generateImage({
-required String prompt,
-int n = 1,
-});
-}
-
-class HttpOpenAIRepository implements OpenAIRepository {
+class HttpOpenAIRepository implements AIRepository {
const HttpOpenAIRepository({
required this.client,
required this.apiKey,
@@ -90,50 +40,13 @@ class HttpOpenAIRepository implements OpenAIRepository {
'Content-Type': 'application/json',
};

-@override
-Future<FlowyResult<TextCompletionResponse, OpenAIError>> getCompletions({
-required String prompt,
-String? suffix,
-int maxTokens = 2048,
-double temperature = 0.3,
-}) async {
-final parameters = {
-'model': 'gpt-3.5-turbo-instruct',
-'prompt': prompt,
-'suffix': suffix,
-'max_tokens': maxTokens,
-'temperature': temperature,
-'stream': false,
-};
-
-final response = await client.post(
-OpenAIRequestType.textCompletion.uri,
-headers: headers,
-body: json.encode(parameters),
-);
-
-if (response.statusCode == 200) {
-return FlowyResult.success(
-TextCompletionResponse.fromJson(
-json.decode(
-utf8.decode(response.bodyBytes),
-),
-),
-);
-} else {
-return FlowyResult.failure(
-OpenAIError.fromJson(json.decode(response.body)['error']),
-);
-}
-}
-
@override
Future<void> getStreamedCompletions({
required String prompt,
required Future<void> Function() onStart,
required Future<void> Function(TextCompletionResponse response) onProcess,
required Future<void> Function() onEnd,
-required void Function(OpenAIError error) onError,
+required void Function(AIError error) onError,
String? suffix,
int maxTokens = 2048,
double temperature = 0.3,
@@ -201,50 +114,14 @@ class HttpOpenAIRepository implements OpenAIRepository {
} else {
final body = await response.stream.bytesToString();
onError(
-OpenAIError.fromJson(json.decode(body)['error']),
+AIError.fromJson(json.decode(body)['error']),
);
}
return;
}

-@override
-Future<FlowyResult<TextEditResponse, OpenAIError>> getEdits({
-required String input,
-required String instruction,
-double temperature = 0.3,
-int n = 1,
-}) async {
-final parameters = {
-'model': 'gpt-4',
-'input': input,
-'instruction': instruction,
-'temperature': temperature,
-'n': n,
-};
-
-final response = await client.post(
-OpenAIRequestType.textEdit.uri,
-headers: headers,
-body: json.encode(parameters),
-);
-
-if (response.statusCode == 200) {
-return FlowyResult.success(
-TextEditResponse.fromJson(
-json.decode(
-utf8.decode(response.bodyBytes),
-),
-),
-);
-} else {
-return FlowyResult.failure(
-OpenAIError.fromJson(json.decode(response.body)['error']),
-);
-}
-}
-
@override
-Future<FlowyResult<List<String>, OpenAIError>> generateImage({
+Future<FlowyResult<List<String>, AIError>> generateImage({
required String prompt,
int n = 1,
}) async {
@@ -273,11 +150,23 @@ class HttpOpenAIRepository implements OpenAIRepository {
return FlowyResult.success(urls);
} else {
return FlowyResult.failure(
-OpenAIError.fromJson(json.decode(response.body)['error']),
+AIError.fromJson(json.decode(response.body)['error']),
);
}
} catch (error) {
-return FlowyResult.failure(OpenAIError(message: error.toString()));
+return FlowyResult.failure(AIError(message: error.toString()));
}
}

+@override
+Future<void> streamCompletion({
+required String text,
+required CompletionTypePB completionType,
+required Future<void> Function() onStart,
+required Future<void> Function(String text) onProcess,
+required Future<void> Function() onEnd,
+required void Function(AIError error) onError,
+}) {
+throw UnimplementedError();
+}
}
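Because call sites resolve the repository through get_it (see the getAsync<AIRepository>() change in the image widget above), switching models never touches widget code. Below is a rough usage sketch; runCompletion is a hypothetical caller, not code from this commit.

```dart
// Hypothetical call site: resolve whatever AIRepository is currently
// registered and stream a completion from it. Mirrors the pattern used
// by _OpenAIImageWidgetState._search in this commit.
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/ai_client.dart';
import 'package:appflowy/plugins/document/presentation/editor_plugins/openai/service/error.dart';
import 'package:appflowy/startup/startup.dart';

Future<void> runCompletion(String prompt) async {
  final ai = await getIt.getAsync<AIRepository>();
  await ai.getStreamedCompletions(
    prompt: prompt,
    onStart: () async {},
    onProcess: (response) async {
      // A real caller would append the streamed text to the document here.
    },
    onEnd: () async {},
    onError: (AIError error) {
      // Surface the error to the user.
    },
  );
}
```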