Skip to content

Commit

Permalink
Merge pull request #593 from bounswe/be-520/semantic-search
Browse files Browse the repository at this point in the history
Be 520/semantic search
  • Loading branch information
BatuhanIlhan authored Dec 17, 2023
2 parents 6b2ca66 + 431eedc commit 76664cb
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 5 deletions.
4 changes: 4 additions & 0 deletions app/backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"test:e2e": "jest --config ./test/jest-e2e.json"
},
"dependencies": {
"@langchain/google-genai": "^0.0.3",
"@nestjs-modules/mailer": "^1.9.1",
"@nestjs/common": "^10.0.0",
"@nestjs/config": "^3.1.1",
Expand All @@ -28,13 +29,16 @@
"@nestjs/platform-express": "^10.0.0",
"@nestjs/swagger": "^7.1.14",
"@nestjs/typeorm": "^10.0.0",
"@pinecone-database/pinecone": "^1.1.2",
"bcrypt": "^5.1.1",
"chai": "^4.3.10",
"class-transformer": "^0.5.1",
"class-validator": "^0.14.0",
"helmet": "^7.0.0",
"langchain": "^0.0.208",
"nodemailer": "^6.9.7",
"pg": "^8.11.3",
"pinecone-client": "^2.0.0",
"reflect-metadata": "^0.1.13",
"rxjs": "^7.8.1",
"typeorm": "^0.3.17"
Expand Down
4 changes: 4 additions & 0 deletions app/backend/src/comment/comment.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import { Poll } from '../poll/entities/poll.entity';
import { BadgeService } from '../badge/badge.service';
import { Report } from '../user/entities/report.entity';
import { TagModule } from '../tag/tag.module';
import { Pinecone } from '@pinecone-database/pinecone';
import { GoogleGenerativeAIEmbeddings } from '@langchain/google-genai';
import { RankingService } from '../ranking/ranking.service';
import { Ranking } from '../ranking/entities/ranking.entity';
import { Vote } from '../vote/entities/vote.entity';
Expand All @@ -37,6 +39,8 @@ import { Vote } from '../vote/entities/vote.entity';
PollRepository,
UserService,
BadgeService,
Pinecone,
GoogleGenerativeAIEmbeddings,
RankingService
],
})
Expand Down
4 changes: 4 additions & 0 deletions app/backend/src/like/like.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import { UserService } from '../user/user.service';
import { BadgeService } from '../badge/badge.service';
import { Report } from '../user/entities/report.entity';
import { TagModule } from '../tag/tag.module';
import { Pinecone } from '@pinecone-database/pinecone';
import { GoogleGenerativeAIEmbeddings } from '@langchain/google-genai';
import { RankingService } from '../ranking/ranking.service';
import { Ranking } from '../ranking/entities/ranking.entity';
import { Vote } from '../vote/entities/vote.entity';
Expand Down Expand Up @@ -42,6 +44,8 @@ import { Vote } from '../vote/entities/vote.entity';
PollRepository,
UserService,
BadgeService,
Pinecone,
GoogleGenerativeAIEmbeddings,
RankingService
],
})
Expand Down
17 changes: 17 additions & 0 deletions app/backend/src/poll/dto/responses/semantic-search-response.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { ApiProperty } from '@nestjs/swagger';

export class SemanticSearchResponseDto {
@ApiProperty({
example: 'Who will be the champion?'
})
pageContent: string;

@ApiProperty({
example: {
id: 'f9b9a7d1-8e5a-4e3e-8f0a-7f0a8e5a4e3e',
},
})
metaData: {
id: string;
};
}
40 changes: 40 additions & 0 deletions app/backend/src/poll/poll.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {
} from './dto/settle-poll-request.dto';
import { ModeratorGuard } from '../moderator/guards/moderator.guard';
import { VerificationModeratorGuard } from '../moderator/guards/verification-moderator.guard';
import { Poll } from './entities/poll.entity';

const statusMap = new Map<string, boolean>();
statusMap.set('pending', null);
Expand All @@ -37,6 +38,30 @@ statusMap.set('rejected', false);
@ApiTags('poll')
export class PollController {
constructor(private readonly pollService: PollService) {}
@Get('pinecone')
public async pinecone(): Promise<any> {
return await this.pollService.pineconeTest();
}

@Post('pinecone/sync')
public async pineconeSync(): Promise<any> {
return await this.pollService.syncVectorStore();
}

@UseGuards(AuthGuard, VerificationGuard)
@ApiResponse({
status: 200,
description: 'Polls are searched successfully.',
type: [GetPollResponseDto],
})
@ApiResponse({
status: 500,
description: 'Internal server error, contact with backend team.',
})
@Post('pinecone/search')
public async pineconeSearch(@Query('searchQuery') searchQuery: string): Promise<Poll[]> {
return await this.pollService.searchSemanticPolls(searchQuery);
}

@UseGuards(AuthGuard, VerificationGuard)
@Post()
Expand Down Expand Up @@ -259,6 +284,21 @@ export class PollController {
return await this.pollService.findPollById(pollId, userId);
}

@UseGuards(AuthGuard, VerificationGuard)
@ApiResponse({
status: 200,
description: 'Polls are removed successfully.',
})
@ApiResponse({ status: 404, description: 'Poll not found.' })
@ApiResponse({
status: 500,
description: 'Internal server error, contact with backend team.',
})
@Delete()
public async removeAll() {
return await this.pollService.removeAll();
}

@ApiResponse({ status: 200, description: 'Poll deleted successfully.' })
@ApiResponse({ status: 404, description: 'Poll not found.' })
@ApiResponse({
Expand Down
4 changes: 4 additions & 0 deletions app/backend/src/poll/poll.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import { Comment } from '../comment/entities/comment.entity';
import { Report } from '../user/entities/report.entity';
import { TagModule } from '../tag/tag.module';
import { TokenDecoderMiddleware } from '../auth/middlewares/tokenDecoder.middleware';
import { Pinecone } from '@pinecone-database/pinecone';
import { GoogleGenerativeAIEmbeddings } from '@langchain/google-genai';
import { RankingService } from '../ranking/ranking.service';
import { Ranking } from '../ranking/entities/ranking.entity';
import { Vote } from '../vote/entities/vote.entity';
Expand Down Expand Up @@ -52,6 +54,8 @@ import { Vote } from '../vote/entities/vote.entity';
BadgeService,
ModeratorService,
TagService,
Pinecone,
GoogleGenerativeAIEmbeddings,
RankingService
],
exports: [PollService],
Expand Down
72 changes: 69 additions & 3 deletions app/backend/src/poll/poll.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
} from '@nestjs/common';
import { InjectRepository } from '@nestjs/typeorm';
import { Poll } from './entities/poll.entity';
import { IsNull, Repository } from 'typeorm';
import { In, IsNull, Repository } from 'typeorm';
import { Option } from '../option/entities/option.entity';
import { Tag } from '../tag/entities/tag.entity';
import { PollRepository } from './repository/poll.repository';
Expand All @@ -16,22 +16,42 @@ import { Like } from '../like/entities/like.entity';
import { Comment } from '../comment/entities/comment.entity';
import { Sort } from './enums/sort.enum';
import { TagService } from '../tag/tag.service';
import { Document } from "langchain/document";
import { Pinecone } from "@pinecone-database/pinecone";
import { PineconeStore } from "langchain/vectorstores/pinecone";
import { GoogleGenerativeAIEmbeddings } from "@langchain/google-genai";
import { TaskType } from "@google/generative-ai";
import { RankingService } from '../ranking/ranking.service';

@Injectable()
export class PollService {
private pineconeStore: PineconeStore;
private embeddings: GoogleGenerativeAIEmbeddings
constructor(
private readonly pollRepository: PollRepository,
@InjectRepository(Option)
private readonly optionRepository: Repository<Option>,
@InjectRepository(Tag) private readonly tagRepository: Repository<Tag>,
@InjectRepository(Tag)
private readonly tagRepository: Repository<Tag>,
@InjectRepository(Like)
private readonly likeRepository: Repository<Like>,
@InjectRepository(Comment)
private readonly commentRepository: Repository<Comment>,
private readonly tagService: TagService,
private readonly pinecone: Pinecone,
private readonly rankingService: RankingService
) {}
) {
this.embeddings = new GoogleGenerativeAIEmbeddings({
modelName: "embedding-001", // 768 dimensions
taskType: TaskType.RETRIEVAL_DOCUMENT,
title: "Prediction Polls",
});
this.pineconeStore = new PineconeStore(
this.embeddings,
{
pineconeIndex: pinecone.Index("prediction-polls"),
});
}

public async createPoll(createPollDto: any): Promise<Poll> {
const poll = new Poll();
Expand Down Expand Up @@ -66,6 +86,18 @@ export class PollService {
);

savedPoll.tags = tags;
try {
await this.pineconeStore.addDocuments([
new Document({
metadata: {id: savedPoll.id},
pageContent: savedPoll.question + ' ' + savedPoll.description + ' ' + savedPoll.tags.map((tag) => tag.name).join(' ')
}),
], {
ids: [savedPoll.id],
});
} catch (e) {
console.log(e);
}

return await this.pollRepository.save(savedPoll);
}
Expand Down Expand Up @@ -221,4 +253,38 @@ export class PollService {
public async removeById(id: string): Promise<void> {
await this.pollRepository.delete(id);
}

public async removeAll(): Promise<void> {
await this.pollRepository.delete({});
}

public async pineconeTest(): Promise<any> {
return this.pineconeStore.pineconeIndex;
}

public async syncVectorStore(): Promise<any> {
const polls = await this.pollRepository.find({
relations: ['options', 'tags'],
});

const documents = polls.map((poll) => {
return new Document({
metadata: {id: poll.id},
pageContent: poll.question + ' ' + poll.description + ' ' + poll.tags.map((tag) => tag.name).join(' '),
});
});

await this.pineconeStore.addDocuments(documents, {
ids: polls.map((poll) => poll.id),
});
}

public async searchSemanticPolls(query: string): Promise<Poll[]> {
let results = await this.pineconeStore.similaritySearchWithScore(query, 5);
results = results.filter((result) => result[1] > 0.7).map((result) => result[0].metadata.id);
return await this.pollRepository.find({
where: {id: In(results)},
relations: ['options', 'tags', 'creator', 'outcome', 'likes', 'comments', 'votes', 'annotations'],
})
}
}
5 changes: 3 additions & 2 deletions app/backend/src/ranking/ranking.controller.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Controller, Get, Post, Body, Patch, Param, Delete } from '@nestjs/common';
import { RankingService } from './ranking.service';
import { CreateRankingDto } from './dto/create-ranking.dto';
import { UpdateRankingDto } from './dto/update-ranking.dto';
import { ApiBearerAuth, ApiTags } from '@nestjs/swagger';

@ApiBearerAuth()
@Controller('ranking')
@ApiTags('ranking')
export class RankingController {
constructor(private readonly rankingService: RankingService) {}

Expand Down

0 comments on commit 76664cb

Please sign in to comment.