Skip to content

Commit

Permalink
[Add] preq 생성 시 꼬리질문 세션 유지
Browse files Browse the repository at this point in the history
  • Loading branch information
Lightieey committed Sep 18, 2023
1 parent f101bcc commit 2d1c15e
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package kr.co.preq.domain.applicationChild.repository;

import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;

import kr.co.preq.domain.applicationChild.entity.ApplicationChild;

Expand All @@ -13,4 +14,21 @@ public interface ApplicationChildRepository extends JpaRepository<ApplicationChi
List<ApplicationChild> findByApplicationId(Long applicationId);

List<ApplicationChild> findApplicationChildByApplicationIdAndMemberIdAndIsDeletedOrderByCreatedAt(Long applicationId, Long memberId, Boolean isDeleted);

@Query(value = "WITH RECURSIVE find_division "
+ "as (\n"
+ " select *, 1 as DEPTH\n"
+ " from application_child\n"
+ " where 1=1\n"
+ " and id = 77\n"
+ "\n"
+ " union\n"
+ " select d.*, fd.DEPTH+1\n"
+ " from find_division fd\n"
+ " INNER JOIN application_child d on fd.parent_id = d.id\n"
+ ")\n"
+ "select * from find_division order by DEPTH desc;", nativeQuery = true)
List<ApplicationChild> findAllByApplicationChildIdOrderByParentIdAscNullsFirstCategoryIdAsc(Long applicationChildId);

ApplicationChild findApplicationChildById(Long id);
}
14 changes: 14 additions & 0 deletions src/main/java/kr/co/preq/domain/preq/dto/SessionDto.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package kr.co.preq.domain.preq.dto;

import java.util.List;

import lombok.AllArgsConstructor;
import lombok.Getter;

@Getter
@AllArgsConstructor
public class SessionDto {
private String question;
private String answer;
private List<String> preqList;
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import java.util.ArrayList;
import java.util.List;

import org.springframework.beans.factory.annotation.Value;

@Getter
@Setter
@NoArgsConstructor
Expand All @@ -29,7 +27,7 @@ public class OpenAIRequestDto implements Serializable {
private List<Message> messages;

@Builder
public OpenAIRequestDto(String model, Integer n, String command, String prompt) {
public OpenAIRequestDto(String model, Integer n, String command, List<Message> sessions, String prompt) {
this.model = model;
this.n = n;
//this.prompt = prompt;
Expand All @@ -38,7 +36,7 @@ public OpenAIRequestDto(String model, Integer n, String command, String prompt)
//this.topP = topP;
this.messages = new ArrayList<>();
this.messages.add(new Message("system", command));
this.messages.addAll(sessions);
this.messages.add(new Message("user", prompt));
}

}
22 changes: 20 additions & 2 deletions src/main/java/kr/co/preq/domain/preq/service/OpenAIService.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package kr.co.preq.domain.preq.service;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

Expand All @@ -12,6 +13,8 @@
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;

import kr.co.preq.domain.preq.dto.SessionDto;
import kr.co.preq.domain.preq.dto.Message;
import kr.co.preq.domain.preq.dto.request.OpenAIRequestDto;
import kr.co.preq.domain.preq.dto.response.OpenAIResponseDto;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -41,8 +44,18 @@ public class OpenAIService {
private String MEDIA_TYPE = "application/json; charset=UTF-8";
RestTemplate restTemplate = new RestTemplate(new HttpComponentsClientHttpRequestFactory());

public List<String> generateQuestions(String question, String answer) {
String prompt = "질문: " + question + "\n답변: " + answer;
public List<String> generateQuestions(String question, String answer, List<SessionDto> sessionDtoList) {
List<Message> sessions = new ArrayList<>();
sessionDtoList.forEach((a) -> {
sessions.add(new Message("user", makeUserPrompt(a.getQuestion(), a.getAnswer())));
sessions.add(new Message("assistant", a.getPreqList().toString()));
System.out.println("session--------");
System.out.println(makeUserPrompt(a.getQuestion(), a.getAnswer()));
System.out.println(a.getPreqList().toString());
});

String prompt = makeUserPrompt(question, answer);
System.out.println("new---------");
System.out.println(prompt);

OpenAIResponseDto response = this.getResponse(
Expand All @@ -51,6 +64,7 @@ public List<String> generateQuestions(String question, String answer) {
MODEL,
N,
COMMAND,
sessions,
prompt
)
)
Expand All @@ -76,4 +90,8 @@ private OpenAIResponseDto getResponse(HttpEntity<OpenAIRequestDto> chatGptReques

return responseEntity.getBody();
}

private String makeUserPrompt(String question, String answer) {
return "질문: " + question + "\n답변: " + answer;
}
}
16 changes: 14 additions & 2 deletions src/main/java/kr/co/preq/domain/preq/service/PreqService.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package kr.co.preq.domain.preq.service;

import kr.co.preq.domain.preq.dto.SessionDto;
import kr.co.preq.domain.preq.dto.Message;
import kr.co.preq.domain.preq.dto.request.KeywordAndSoftskillsRequestDto;
import kr.co.preq.domain.preq.dto.response.KeywordAndSoftskillsResponseDto;
import kr.co.preq.domain.auth.service.AuthService;
import kr.co.preq.domain.preq.dto.response.PreqAndKeywordResponseDto;
import kr.co.preq.domain.preq.dto.response.PreqMapper;
import kr.co.preq.domain.preq.entity.Preq;
Expand Down Expand Up @@ -46,8 +47,19 @@ public PreqAndKeywordResponseDto createPreqAndKeyword(Long applicationChildId) {
ApplicationChild applicationChild = applicationChildRepository.findById(applicationChildId)
.orElseThrow(() -> new BadRequestException(ErrorCode.NO_ID));

List<ApplicationChild> allApplicationChilds = applicationChildRepository
.findAllByApplicationChildIdOrderByParentIdAscNullsFirstCategoryIdAsc(applicationChildId);

List<SessionDto> parentApplicationChilds = allApplicationChilds
.subList(0, allApplicationChilds.size()-1) // remove first element (=self)
.stream().map((a) -> {
List<Preq> preqList = preqRepository.findPreqsByApplicationChildIdAndIsDeleted(a.getId(), false);
return new SessionDto(a.getQuestion(), a.getAnswer(), preqList.stream().map((p) -> p.getQuestion()).collect(
Collectors.toList()));
}).collect(Collectors.toList());

// generate preQuestions
List<String> questions = openAIService.generateQuestions(applicationChild.getQuestion(), applicationChild.getAnswer());
List<String> questions = openAIService.generateQuestions(applicationChild.getQuestion(), applicationChild.getAnswer(), parentApplicationChilds);

// manufacturing chatGPT response
List<String> cutQuestions = new ArrayList<>();
Expand Down

0 comments on commit 2d1c15e

Please sign in to comment.