Dev/manage meta #518

Merged: 180 commits, Jan 3, 2025
Commits
63d430a
add api call
drcege Oct 23, 2024
6720da4
add call_api ops
drcege Oct 24, 2024
8daa6e1
clean
drcege Oct 29, 2024
ef11951
minor update
drcege Oct 29, 2024
5597d5c
more tests
drcege Oct 29, 2024
4b6e769
update tests
drcege Oct 29, 2024
835be22
Merge branch 'main' into dev/api_model
drcege Oct 29, 2024
325a753
update prompts
drcege Oct 29, 2024
4f04bdd
fix unittest
drcege Oct 30, 2024
0adbdcd
update tests
drcege Oct 30, 2024
0aa4069
add docs
drcege Nov 1, 2024
f007532
minor fix
drcege Nov 1, 2024
9aa7390
Merge branch 'main' into dev/api_model
drcege Nov 5, 2024
ee4f461
add API processor
drcege Nov 5, 2024
9bbfe47
Merge branch 'main' into dev/api_model
drcege Nov 5, 2024
b00b182
refine API processor
drcege Nov 5, 2024
b718de7
refine
drcege Nov 5, 2024
6d1d433
chunk and extract events
BeachWang Nov 6, 2024
4d1670f
fix bugs
drcege Nov 6, 2024
9e11aa3
fix tests
drcege Nov 6, 2024
cc40fc0
extract attribute
BeachWang Nov 7, 2024
4c262ad
Merge branch 'dev/api_model' of github.com:alibaba/data-juicer into d…
BeachWang Nov 7, 2024
347bc0f
refine tests
drcege Nov 7, 2024
c9d5051
extract nickname
BeachWang Nov 8, 2024
8a128ca
Merge branch 'dev/api_model' of github.com:alibaba/data-juicer into d…
BeachWang Nov 8, 2024
9262777
nickname test done
BeachWang Nov 8, 2024
58fc020
merge main
BeachWang Nov 8, 2024
c7dc28e
lightRAG to OP
BeachWang Nov 11, 2024
238869e
merge main
BeachWang Nov 11, 2024
0e51a43
doc done
BeachWang Nov 11, 2024
6d9d8a5
remove extra test
BeachWang Nov 11, 2024
a637a64
relavant -> relevant
BeachWang Nov 11, 2024
56e7988
fix minor error
BeachWang Nov 11, 2024
03880b7
group by op done
BeachWang Nov 12, 2024
23174fd
ValueError -> Exception
BeachWang Nov 12, 2024
e82cc06
merge main
BeachWang Nov 12, 2024
20a8dee
fix config_all error
BeachWang Nov 12, 2024
38a9511
fix prepare_api_model
BeachWang Nov 13, 2024
35f0eb3
fix rank sample None
BeachWang Nov 13, 2024
155d3dd
constant fix key
BeachWang Nov 13, 2024
f862897
aggregator op
BeachWang Nov 14, 2024
2d4da5e
merge llm_info_extract
BeachWang Nov 14, 2024
7e66057
init python_lambda_mapper
drcege Nov 20, 2024
a61859b
set default arg
drcege Nov 20, 2024
8031a31
fix init
drcege Nov 21, 2024
67711f9
add python_file_mapper
drcege Nov 21, 2024
cdeb692
support text & most relavant entities
BeachWang Nov 22, 2024
125a8f3
coverage ignore_errors
drcege Nov 25, 2024
0c68089
index sample
BeachWang Nov 25, 2024
651789d
role_playing_system_prompt_yaml
BeachWang Nov 25, 2024
c5d7b9e
merge python_file_mapper
BeachWang Nov 26, 2024
cf6a53a
Merge branch 'main' of github.com:alibaba/data-juicer into dev/group_…
BeachWang Nov 26, 2024
222790e
system_prompt begin
BeachWang Nov 27, 2024
75f2911
support batched
drcege Nov 27, 2024
11fa852
remove unforkable
BeachWang Nov 27, 2024
4af2bfb
support batched & add docs
drcege Nov 27, 2024
8867580
Merge branch 'main' into op/python_lambda
drcege Nov 28, 2024
553d5ad
add docs
drcege Nov 28, 2024
470ca19
fix docs
drcege Nov 28, 2024
399a238
update docs
drcege Nov 28, 2024
706365f
Merge branch 'main' into op/python_file
drcege Nov 28, 2024
115ab9a
pre-commit done
BeachWang Nov 28, 2024
ecb8635
fix batch bug
BeachWang Dec 2, 2024
03e3469
fix batch bug
BeachWang Dec 2, 2024
1788fa6
merge fix_batch_bug
BeachWang Dec 3, 2024
735ff4d
Merge branch 'main' of github.com:alibaba/data-juicer into debug/fix_…
BeachWang Dec 3, 2024
00ff624
fix filter batch
BeachWang Dec 3, 2024
8601519
fix filter batch
BeachWang Dec 3, 2024
eeefcab
system prompt recipe done
BeachWang Dec 3, 2024
6eaa50c
Merge branch 'main' of github.com:alibaba/data-juicer into dev/group_…
BeachWang Dec 3, 2024
1575717
not rank for filter
BeachWang Dec 5, 2024
2c5c4a1
limit pyav version
BeachWang Dec 5, 2024
5c96dd5
Merge branch 'debug/fix_batch_bug' of github.com:alibaba/data-juicer …
BeachWang Dec 5, 2024
49be467
add test for op
BeachWang Dec 5, 2024
9ab02fe
tmp
BeachWang Dec 5, 2024
ba086de
tmp
BeachWang Dec 5, 2024
f712131
doc done
BeachWang Dec 5, 2024
12b7616
Merge branch 'op/python_lambda' of github.com:alibaba/data-juicer int…
BeachWang Dec 5, 2024
e57b64a
merge python_lambda
BeachWang Dec 5, 2024
5f463cd
merge python_lambda
BeachWang Dec 5, 2024
a786070
skip api test
BeachWang Dec 6, 2024
73f4e77
merge main
BeachWang Dec 6, 2024
4b6f0b9
merge main
BeachWang Dec 6, 2024
788a212
add env dependency
BeachWang Dec 6, 2024
10242c4
install by recipe
BeachWang Dec 10, 2024
6a43eec
dialog sent intensity
BeachWang Dec 12, 2024
621a693
add query
BeachWang Dec 12, 2024
b46d105
change to dj_install
BeachWang Dec 12, 2024
a0da444
change to dj_install
BeachWang Dec 12, 2024
02f8dda
developer doc done
BeachWang Dec 12, 2024
635a8a9
merge dj_install
BeachWang Dec 12, 2024
083b665
+ add auto mode for analyzer: load all filters that produce stats to …
HYLcool Dec 12, 2024
662df5e
+ add default mem_required for those model-based OPs
HYLcool Dec 13, 2024
3b04908
query sent_int mapper
BeachWang Dec 13, 2024
6b4d525
query sentiment test done
BeachWang Dec 13, 2024
926c3da
- support wordcloud drawing for str or str list fields in stats
HYLcool Dec 13, 2024
27347c0
- take the minimum one of dataset length and auto num
HYLcool Dec 13, 2024
d19f92f
* update default export path
HYLcool Dec 13, 2024
fbd6726
* set version limit for wandb to avoid exception
HYLcool Dec 13, 2024
58288f7
change meta pass
BeachWang Dec 13, 2024
9f9f85b
+ add docs for auto mode
HYLcool Dec 13, 2024
b665c10
doc done
BeachWang Dec 13, 2024
07be552
merge main
BeachWang Dec 13, 2024
8ba4156
sentiment detection
BeachWang Dec 16, 2024
48b1761
diff label
BeachWang Dec 16, 2024
8160725
sentiment
BeachWang Dec 16, 2024
01846d1
test done
BeachWang Dec 16, 2024
566eb5b
+ support t-test for Measure
HYLcool Dec 16, 2024
7b8ee5c
* fix some bugs
HYLcool Dec 16, 2024
a76d975
dialog intent label
BeachWang Dec 17, 2024
2fb9fe4
fix typo
BeachWang Dec 17, 2024
324467f
prompt adjust
BeachWang Dec 17, 2024
4a3ad39
add more test
BeachWang Dec 17, 2024
937b3f1
query intent detection
BeachWang Dec 17, 2024
d4ca87b
for test
BeachWang Dec 17, 2024
8109c71
for test
BeachWang Dec 17, 2024
c749dcd
change model
BeachWang Dec 17, 2024
c7df0bc
fix typo
BeachWang Dec 17, 2024
c7662cb
fix typo
BeachWang Dec 17, 2024
6f44ec0
for test
BeachWang Dec 17, 2024
9b6652d
for test
BeachWang Dec 17, 2024
fa306dc
doc done
BeachWang Dec 17, 2024
601d9a2
- support analyze a dataset object
HYLcool Dec 17, 2024
34f2ab6
- support analysis on tags in meta
HYLcool Dec 17, 2024
8531a01
- support analysis with tagging OPs
HYLcool Dec 17, 2024
4d6b701
- move tags into the meta field
HYLcool Dec 18, 2024
767b2f0
dialog topic detection
BeachWang Dec 18, 2024
c088cb1
dialog topic detection
BeachWang Dec 18, 2024
12351db
dialog topic detection
BeachWang Dec 18, 2024
4b4e946
dialog topic detection
BeachWang Dec 18, 2024
4506a8e
dialog topic detection
BeachWang Dec 18, 2024
d21db85
dialog topic detection
BeachWang Dec 18, 2024
6f394ee
query topic detection
BeachWang Dec 18, 2024
abee815
query topic detection
BeachWang Dec 18, 2024
0494741
query topic detection
BeachWang Dec 18, 2024
38523a1
query topic detection
BeachWang Dec 18, 2024
b03a33a
query topic detection
BeachWang Dec 18, 2024
35aa6bd
- do not tell tags using their suffix
HYLcool Dec 18, 2024
ad226b1
doc done
BeachWang Dec 18, 2024
85e1392
- add insight mining
HYLcool Dec 18, 2024
b02745b
meta tags aggregator
BeachWang Dec 19, 2024
f2654f1
meta tags aggregator
BeachWang Dec 19, 2024
23e5d6f
meta tags aggregator
BeachWang Dec 19, 2024
1c74709
meta tags aggregator
BeachWang Dec 19, 2024
a997726
meta tags aggregator
BeachWang Dec 19, 2024
2642847
meta tags aggregator
BeachWang Dec 19, 2024
2dae3b8
meta tags aggregator
BeachWang Dec 19, 2024
8bb2509
meta tags aggregator
BeachWang Dec 19, 2024
90303ee
meta tags aggregator
BeachWang Dec 19, 2024
e4c6ff1
meta tags aggregator
BeachWang Dec 19, 2024
12f8946
meta tags aggregator
BeachWang Dec 19, 2024
09b1599
meta tags aggregator
BeachWang Dec 19, 2024
203bc64
naive reverse grouper
BeachWang Dec 19, 2024
cf01e7e
naive reverse grouper
BeachWang Dec 19, 2024
e3d7b8b
* resolve the bugs when running insight mining in multiprocessing mode
HYLcool Dec 19, 2024
3ca9994
Merge branch 'main' into feat/insight_mining
HYLcool Dec 19, 2024
16ca358
* update unittests
HYLcool Dec 20, 2024
dfb0bca
* update unittests
HYLcool Dec 20, 2024
f8b9539
* update unittests
HYLcool Dec 20, 2024
0ba6459
tags specified field
BeachWang Dec 20, 2024
45259e5
* update readme for analyzer
HYLcool Dec 20, 2024
174ee05
Merge branch 'main' into feat/insight_mining
HYLcool Dec 20, 2024
4ad8b8d
merge main
BeachWang Dec 20, 2024
9f098bd
doc done
BeachWang Dec 20, 2024
51f53dc
* use more detailed key
HYLcool Dec 20, 2024
58001ca
+ add reference
HYLcool Dec 20, 2024
892cb48
Merge branch 'feat/insight_mining' of github.com:alibaba/data-juicer …
BeachWang Dec 20, 2024
19fd15b
move mm tags
BeachWang Dec 20, 2024
8fec0f7
move meta key
BeachWang Dec 24, 2024
6fdc95b
done
BeachWang Dec 30, 2024
8e01f7e
merge main
BeachWang Dec 30, 2024
af9e14d
test done
BeachWang Dec 31, 2024
f57f454
rm nested set
BeachWang Dec 31, 2024
7a7e3df
Update constant.py
yxdyc Jan 3, 2025
250ecbd
rename agg to batch meta
BeachWang Jan 3, 2025
57fab1c
Merge branch 'dev/manage_meta' of github.com:alibaba/data-juicer into…
BeachWang Jan 3, 2025
fcfced2
export in naive reverse grouper
BeachWang Jan 3, 2025
0a02163
fix bug
BeachWang Jan 3, 2025
855c4a7
fix bug
BeachWang Jan 3, 2025
fb8b11b
fix bug
BeachWang Jan 3, 2025
74 changes: 46 additions & 28 deletions configs/config_all.yaml

Large diffs are not rendered by default.

39 changes: 23 additions & 16 deletions data_juicer/ops/aggregator/entity_attribute_aggregator.py
@@ -6,8 +6,8 @@

 from data_juicer.ops.base_op import OPERATORS, Aggregator
 from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
-                                            is_string_list, nested_access,
-                                            nested_set)
+                                            is_string_list)
+from data_juicer.utils.constant import BatchMetaKeys, Fields, MetaKeys
 from data_juicer.utils.model_utils import get_model, prepare_model

 from .nested_aggregator import NestedAggregator
@@ -53,8 +53,8 @@ def __init__(self,
                  api_model: str = 'gpt-4o',
                  entity: str = None,
                  attribute: str = None,
-                 input_key: str = None,
-                 output_key: str = None,
+                 input_key: str = MetaKeys.event_description,
+                 output_key: str = BatchMetaKeys.entity_attribute,
                  word_limit: PositiveInt = 100,
                  max_token_num: Optional[PositiveInt] = None,
                  *,
@@ -73,12 +73,10 @@ def __init__(self,
         :param api_model: API model name.
         :param entity: The given entity.
         :param attribute: The given attribute.
-        :param input_key: The input field key in the samples. Support for
-            nested keys such as "__dj__stats__.text_len". It is text_key
-            in default.
-        :param output_key: The output field key in the samples. Support for
-            nested keys such as "__dj__stats__.text_len". It is same as the
-            input_key in default.
+        :param input_key: The input key in the meta field of the samples.
+            It is "event_description" in default.
+        :param output_key: The output key in the aggregation field of the
+            samples. It is "entity_attribute" in default.
         :param word_limit: Prompt the output length.
         :param max_token_num: The max token num of the total tokens of the
             sub documents. Without limitation if it is None.
@@ -103,8 +101,8 @@ def __init__(self,

         self.entity = entity
         self.attribute = attribute
-        self.input_key = input_key or self.text_key
-        self.output_key = output_key or self.input_key
+        self.input_key = input_key
+        self.output_key = output_key
         self.word_limit = word_limit
         self.max_token_num = max_token_num

@@ -131,7 +129,7 @@ def __init__(self,
                                        **model_params)

         self.try_num = try_num
-        self.nested_sum = NestedAggregator(model=api_model,
+        self.nested_sum = NestedAggregator(api_model=api_model,
                                            max_token_num=max_token_num,
                                            api_endpoint=api_endpoint,
                                            response_path=response_path,
@@ -185,12 +183,21 @@ def attribute_summary(self, sub_docs, rank=None):

     def process_single(self, sample=None, rank=None):

+        if self.output_key in sample[Fields.batch_meta]:
+            return sample
+
+        if Fields.meta not in sample or self.input_key not in sample[
+                Fields.meta][0]:
+            logger.warning('The input key does not exist in the sample!')
+            return sample
+
+        sub_docs = [d[self.input_key] for d in sample[Fields.meta]]
         # if not batched sample
-        sub_docs = nested_access(sample, self.input_key)
         if not is_string_list(sub_docs):
+            logger.warning('Require string meta as input!')
             return sample

-        sample = nested_set(sample, self.output_key,
-                            self.attribute_summary(sub_docs, rank=rank))
+        sample[Fields.batch_meta][self.output_key] = self.attribute_summary(
+            sub_docs, rank=rank)

         return sample
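Reviewer sketch: the new process_single flow in this aggregator can be traced on a plain dict. The following is a minimal illustration, not the PR's code. The field names "meta" and "batch_meta" are placeholders standing in for Fields.meta and Fields.batch_meta (whose real values are not shown in this diff), and aggregate_batched_sample and the summarize callback are hypothetical names standing in for attribute_summary.

```python
# Placeholder field names; the real values come from data_juicer.utils.constant.
META = "meta"              # stands in for Fields.meta
BATCH_META = "batch_meta"  # stands in for Fields.batch_meta


def aggregate_batched_sample(sample, input_key, output_key, summarize):
    """Follow the control flow of the new process_single on a plain dict."""
    # Skip the work if an earlier run already stored this aggregation.
    if output_key in sample[BATCH_META]:
        return sample
    # A batched sample holds one meta dict per original sample.
    if META not in sample or input_key not in sample[META][0]:
        return sample
    sub_docs = [m[input_key] for m in sample[META]]
    # Simple stand-in for is_string_list.
    if not all(isinstance(doc, str) for doc in sub_docs):
        return sample
    # The result is stored once per batch instead of once per sample.
    sample[BATCH_META][output_key] = summarize(sub_docs)
    return sample


batched = {
    "text": ["doc a", "doc b"],
    META: [
        {"event_description": "A meets B"},
        {"event_description": "B leaves"},
    ],
    BATCH_META: {},
}
result = aggregate_batched_sample(
    batched, "event_description", "entity_attribute", summarize=" | ".join)
```

Calling the function a second time on the same sample returns it unchanged, which is what the early output_key check in the diff appears intended to achieve.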
39 changes: 23 additions & 16 deletions data_juicer/ops/aggregator/most_relavant_entities_aggregator.py
@@ -5,8 +5,8 @@
 from pydantic import PositiveInt

 from data_juicer.ops.base_op import OPERATORS, Aggregator
-from data_juicer.utils.common_utils import (is_string_list, nested_access,
-                                            nested_set)
+from data_juicer.utils.common_utils import is_string_list
+from data_juicer.utils.constant import BatchMetaKeys, Fields, MetaKeys
 from data_juicer.utils.model_utils import get_model, prepare_model

 from ..common import split_text_by_punctuation
@@ -44,8 +44,8 @@ def __init__(self,
                  api_model: str = 'gpt-4o',
                  entity: str = None,
                  query_entity_type: str = None,
-                 input_key: str = None,
-                 output_key: str = None,
+                 input_key: str = MetaKeys.event_description,
+                 output_key: str = BatchMetaKeys.most_relavant_entities,
                  max_token_num: Optional[PositiveInt] = None,
                  *,
                  api_endpoint: Optional[str] = None,
@@ -62,12 +62,10 @@ def __init__(self,
         :param api_model: API model name.
         :param entity: The given entity.
         :param query_entity_type: The type of queried relavant entities.
-        :param input_key: The input field key in the samples. Support for
-            nested keys such as "__dj__stats__.text_len". It is text_key
-            in default.
-        :param output_key: The output field key in the samples. Support for
-            nested keys such as "__dj__stats__.text_len". It is same as the
-            input_key in default.
+        :param input_key: The input key in the meta field of the samples.
+            It is "event_description" in default.
+        :param output_key: The output key in the aggregation field of the
+            samples. It is "most_relavant_entities" in default.
         :param max_token_num: The max token num of the total tokens of the
             sub documents. Without limitation if it is None.
         :param api_endpoint: URL endpoint for the API.
@@ -91,8 +89,8 @@ def __init__(self,

         self.entity = entity
         self.query_entity_type = query_entity_type
-        self.input_key = input_key or self.text_key
-        self.output_key = output_key or self.input_key
+        self.input_key = input_key
+        self.output_key = output_key
         self.max_token_num = max_token_num

         system_prompt_template = system_prompt_template or \
@@ -167,13 +165,22 @@ def query_most_relavant_entities(self, sub_docs, rank=None):

     def process_single(self, sample=None, rank=None):

+        if self.output_key in sample[Fields.batch_meta]:
+            return sample
+
+        if Fields.meta not in sample or self.input_key not in sample[
+                Fields.meta][0]:
+            logger.warning('The input key does not exist in the sample!')
+            return sample
+
+        sub_docs = [d[self.input_key] for d in sample[Fields.meta]]
+
         # if not batched sample
-        sub_docs = nested_access(sample, self.input_key)
         if not is_string_list(sub_docs):
             return sample

-        sample = nested_set(
-            sample, self.output_key,
-            self.query_most_relavant_entities(sub_docs, rank=rank))
+        sample[Fields.batch_meta][
+            self.output_key] = self.query_most_relavant_entities(sub_docs,
+                                                                 rank=rank)

         return sample
29 changes: 19 additions & 10 deletions data_juicer/ops/aggregator/nested_aggregator.py
@@ -5,7 +5,8 @@

 from data_juicer.ops.base_op import OPERATORS, Aggregator
 from data_juicer.utils.common_utils import (avg_split_string_list_under_limit,
-                                            is_string_list, nested_access)
+                                            is_string_list)
+from data_juicer.utils.constant import Fields, MetaKeys
 from data_juicer.utils.model_utils import get_model, prepare_model

 OP_NAME = 'nested_aggregator'
@@ -47,7 +48,7 @@ class NestedAggregator(Aggregator):

     def __init__(self,
                  api_model: str = 'gpt-4o',
-                 input_key: str = None,
+                 input_key: str = MetaKeys.event_description,
                  output_key: str = None,
                  max_token_num: Optional[PositiveInt] = None,
                  *,
@@ -63,12 +64,10 @@ def __init__(self,
         """
         Initialization method.
         :param api_model: API model name.
-        :param input_key: The input field key in the samples. Support for
-            nested keys such as "__dj__stats__.text_len". It is text_key
-            in default.
-        :param output_key: The output field key in the samples. Support for
-            nested keys such as "__dj__stats__.text_len". It is same as the
-            input_key in default.
+        :param input_key: The input key in the meta field of the samples.
+            It is "event_description" in default.
+        :param output_key: The output key in the aggregation field in the
+            samples. It is same as the input_key in default.
         :param max_token_num: The max token num of the total tokens of the
             sub documents. Without limitation if it is None.
         :param api_endpoint: URL endpoint for the API.
@@ -165,11 +164,21 @@ def recursive_summary(self, sub_docs, rank=None):

     def process_single(self, sample=None, rank=None):

+        if self.output_key in sample[Fields.batch_meta]:
+            return sample
+
+        if Fields.meta not in sample or self.input_key not in sample[
+                Fields.meta][0]:
+            logger.warning('The input key does not exist in the sample!')
+            return sample
+
+        sub_docs = [d[self.input_key] for d in sample[Fields.meta]]
+
         # if not batched sample
-        sub_docs = nested_access(sample, self.input_key)
         if not is_string_list(sub_docs):
             return sample

-        sample[self.output_key] = self.recursive_summary(sub_docs, rank=rank)
+        sample[Fields.batch_meta][self.output_key] = self.recursive_summary(
+            sub_docs, rank=rank)

         return sample
11 changes: 11 additions & 0 deletions data_juicer/ops/base_op.py
@@ -630,6 +630,17 @@ def process_single(self, sample):

     def run(self, dataset, *, exporter=None, tracer=None):
         dataset = super(Aggregator, self).run(dataset)
+        # add batched meta field for OPs that produce aggregations
+        if Fields.batch_meta not in dataset.features:
+            from data_juicer.core.data import add_same_content_to_new_column
+            dataset = dataset.map(add_same_content_to_new_column,
+                                  fn_kwargs={
+                                      'new_column_name': Fields.batch_meta,
+                                      'initial_value': {}
+                                  },
+                                  num_proc=self.runtime_np(),
+                                  batch_size=self.batch_size,
+                                  desc='Adding new column for aggregation')
         new_dataset = dataset.map(
             self.process,
             num_proc=self.runtime_np(),
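Reviewer sketch: the run() change above guarantees that every batched sample has a batch-meta dict before process_single touches it. A minimal plain-Python illustration of that precondition follows. It does not use the datasets library or the real add_same_content_to_new_column helper (its signature is not shown in this diff); ensure_batch_meta and the "batch_meta" field name are hypothetical.

```python
BATCH_META = "batch_meta"  # placeholder for Fields.batch_meta


def ensure_batch_meta(samples):
    """Return copies of the samples, each with its own batch-meta dict."""
    out = []
    for sample in samples:
        row = dict(sample)
        if BATCH_META not in row:
            # A fresh dict per row, so writes to one row never leak into another.
            row[BATCH_META] = {}
        out.append(row)
    return out


rows = ensure_batch_meta([{"text": ["a", "b"]}, {"text": ["c"]}])
rows[0][BATCH_META]["entity_attribute"] = "summary of a and b"
```

After this step an aggregator can write to sample[BATCH_META][output_key] without first checking that the field exists, which matches how the process_single hunks in this PR index into Fields.batch_meta directly.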
8 changes: 4 additions & 4 deletions data_juicer/ops/filter/video_tagging_from_frames_filter.py
@@ -3,7 +3,7 @@
 import numpy as np
 from pydantic import PositiveInt

-from data_juicer.utils.constant import Fields
+from data_juicer.utils.constant import Fields, MetaKeys

 from ..base_op import (NON_STATS_FILTERS, OPERATORS, TAGGING_OPS, UNFORKABLE,
                        Filter)
@@ -30,7 +30,7 @@ def __init__(self,
                  contain: str = 'any',
                  frame_sampling_method: str = 'all_keyframes',
                  frame_num: PositiveInt = 3,
-                 tag_field_name: str = Fields.video_frame_tags,
+                 tag_field_name: str = MetaKeys.video_frame_tags,
                  any_or_all: str = 'any',
                  *args,
                  **kwargs):
@@ -55,8 +55,8 @@ def __init__(self,
             the first and the last frames will be extracted. If it's larger
             than 2, in addition to the first and the last frames, other frames
             will be extracted uniformly within the video duration.
-        :param tag_field_name: the field name to store the tags. It's
-            "__dj__video_frame_tags__" in default.
+        :param tag_field_name: the key name to store the tags in the meta
+            field. It's "video_frame_tags" in default.
         :param any_or_all: keep this sample with 'any' or 'all' strategy of
             all videos. 'any': keep this sample if any videos meet the
             condition. 'all': keep this sample only if all videos meet the
24 changes: 23 additions & 1 deletion data_juicer/ops/grouper/naive_reverse_grouper.py
@@ -1,26 +1,48 @@
+import json
+import os
+
+from data_juicer.utils.constant import Fields
+from data_juicer.utils.file_utils import create_directory_if_not_exists
+
 from ..base_op import OPERATORS, Grouper, convert_dict_list_to_list_dict


 @OPERATORS.register_module('naive_reverse_grouper')
 class NaiveReverseGrouper(Grouper):
     """Split batched samples to samples. """

-    def __init__(self, *args, **kwargs):
+    def __init__(self, batch_meta_export_path=None, *args, **kwargs):
         """
         Initialization method.

+        :param batch_meta_export_path: the path to export the batch meta.
+            Just drop the batch meta if it is None.
         :param args: extra args
         :param kwargs: extra args
         """
         super().__init__(*args, **kwargs)
+        self.batch_meta_export_path = batch_meta_export_path

     def process(self, dataset):

         if len(dataset) == 0:
             return dataset

         samples = []
+        batch_metas = []
         for sample in dataset:
+            if Fields.batch_meta in sample:
+                batch_metas.append(sample[Fields.batch_meta])
+                sample = {
+                    k: sample[k]
+                    for k in sample if k != Fields.batch_meta
+                }
             samples.extend(convert_dict_list_to_list_dict(sample))
+        if self.batch_meta_export_path is not None:
+            create_directory_if_not_exists(
+                os.path.dirname(self.batch_meta_export_path))
+            with open(self.batch_meta_export_path, 'w') as f:
+                for batch_meta in batch_metas:
+                    f.write(json.dumps(batch_meta, ensure_ascii=False) + '\n')

         return samples
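Reviewer sketch: the reverse-grouper change can be exercised end to end with the standard library only. The snippet below is an approximation under stated assumptions: dict_of_lists_to_list_of_dicts is a hypothetical stand-in for convert_dict_list_to_list_dict, "batch_meta" is a placeholder for Fields.batch_meta, and os.makedirs replaces create_directory_if_not_exists.

```python
import json
import os
import tempfile

BATCH_META = "batch_meta"  # placeholder for Fields.batch_meta


def dict_of_lists_to_list_of_dicts(batched):
    """Hypothetical stand-in for convert_dict_list_to_list_dict."""
    keys = list(batched)
    size = len(batched[keys[0]]) if keys else 0
    return [{k: batched[k][i] for k in keys} for i in range(size)]


def reverse_group(dataset, export_path=None):
    samples, batch_metas = [], []
    for sample in dataset:
        if BATCH_META in sample:
            # Batch-level results are set aside before the batch is split.
            batch_metas.append(sample[BATCH_META])
            sample = {k: v for k, v in sample.items() if k != BATCH_META}
        samples.extend(dict_of_lists_to_list_of_dicts(sample))
    if export_path is not None:
        # Fall back to the current directory when the path has no folder part.
        os.makedirs(os.path.dirname(export_path) or ".", exist_ok=True)
        with open(export_path, "w", encoding="utf-8") as f:
            for batch_meta in batch_metas:
                f.write(json.dumps(batch_meta, ensure_ascii=False) + "\n")
    return samples


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "out", "batch_meta.jsonl")
    dataset = [{"text": ["a", "b"], BATCH_META: {"entity_attribute": "summary"}}]
    flat = reverse_group(dataset, path)
    with open(path, encoding="utf-8") as f:
        exported = [json.loads(line) for line in f]
```

The export is one JSON object per line, one line per batch. Whether the PR's create_directory_if_not_exists accepts the empty string that os.path.dirname returns for a bare filename is not visible in this diff, so the fallback to "." here is this sketch's own choice.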
28 changes: 19 additions & 9 deletions data_juicer/ops/mapper/dialog_intent_detection_mapper.py
@@ -4,23 +4,21 @@
 from loguru import logger
 from pydantic import NonNegativeInt, PositiveInt

-from data_juicer.ops.base_op import OPERATORS, Mapper
-from data_juicer.utils.common_utils import nested_set
+from data_juicer.ops.base_op import OPERATORS, TAGGING_OPS, Mapper
 from data_juicer.utils.constant import Fields, MetaKeys
 from data_juicer.utils.model_utils import get_model, prepare_model

 OP_NAME = 'dialog_intent_detection_mapper'


 # TODO: LLM-based inference.
+@TAGGING_OPS.register_module(OP_NAME)
 @OPERATORS.register_module(OP_NAME)
 class DialogIntentDetectionMapper(Mapper):
     """
     Mapper to generate user's intent labels in dialog. Input from
     history_key, query_key and response_key. Output lists of
-    labels and analysis for queries in the dialog, which is
-    store in 'dialog_intent_labels' and
-    'dialog_intent_labels_analysis' in Data-Juicer meta field.
+    labels and analysis for queries in the dialog.
     """

     DEFAULT_SYSTEM_PROMPT = (
@@ -60,6 +58,8 @@ def __init__(self,
                  intent_candidates: Optional[List[str]] = None,
                  max_round: NonNegativeInt = 10,
                  *,
+                 labels_key: str = MetaKeys.dialog_intent_labels,
+                 analysis_key: str = MetaKeys.dialog_intent_labels_analysis,
                  api_endpoint: Optional[str] = None,
                  response_path: Optional[str] = None,
                  system_prompt: Optional[str] = None,
@@ -82,6 +82,11 @@ def __init__(self,
             intent labels of the open domain if it is None.
         :param max_round: The max num of round in the dialog to build the
             prompt.
+        :param labels_key: The key name in the meta field to store the
+            output labels. It is 'dialog_intent_labels' in default.
+        :param analysis_key: The key name in the meta field to store the
+            corresponding analysis. It is 'dialog_intent_labels_analysis'
+            in default.
         :param api_endpoint: URL endpoint for the API.
         :param response_path: Path to extract content from the API response.
             Defaults to 'choices.0.message.content'.
@@ -111,6 +116,8 @@ def __init__(self,

         self.intent_candidates = intent_candidates
         self.max_round = max_round
+        self.labels_key = labels_key
+        self.analysis_key = analysis_key

         self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
         self.query_template = query_template or self.DEFAULT_QUERY_TEMPLATE
@@ -167,6 +174,11 @@ def parse_output(self, response):
         return analysis, labels

     def process_single(self, sample, rank=None):
+
+        meta = sample[Fields.meta]
+        if self.labels_key in meta and self.analysis_key in meta:
+            return sample
+
         client = get_model(self.model_key, rank=rank)

         analysis_list = []
@@ -208,9 +220,7 @@ def process_single(self, sample, rank=None):
             history.append(self.labels_template.format(labels=labels))
             history.append(self.response_template.format(response=qa[1]))

-        analysis_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels_analysis}'  # noqa: E501
-        sample = nested_set(sample, analysis_key, analysis_list)
-        labels_key = f'{Fields.meta}.{MetaKeys.dialog_intent_labels}'
-        sample = nested_set(sample, labels_key, labels_list)
+        meta[self.labels_key] = labels_list
+        meta[self.analysis_key] = analysis_list

         return sample
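Reviewer sketch: the early return added to process_single makes the mapper skip samples that already carry both tags, so a re-run does not call the model again. The following illustrates that pattern with a fake classifier. The "meta" and "history" field names, tag_intents, and fake_classify are all hypothetical; only the two default key names are taken from the docstring above.

```python
META = "meta"  # placeholder for Fields.meta


def tag_intents(sample, classify,
                labels_key="dialog_intent_labels",
                analysis_key="dialog_intent_labels_analysis"):
    meta = sample[META]
    # Already tagged: return without calling the (expensive) classifier.
    if labels_key in meta and analysis_key in meta:
        return sample
    labels, analyses = [], []
    for query, _response in sample["history"]:
        analysis, label = classify(query)
        analyses.append(analysis)
        labels.append(label)
    meta[labels_key] = labels
    meta[analysis_key] = analyses
    return sample


calls = []


def fake_classify(query):
    """Stand-in for the API model; records each call it receives."""
    calls.append(query)
    label = "question" if query.endswith("?") else "statement"
    return f"analysis of {query!r}", label


sample = {
    "history": [("Where is the station?", "Two blocks north."),
                ("Thanks.", "You're welcome.")],
    META: {},
}
tag_intents(sample, fake_classify)
tag_intents(sample, fake_classify)  # second call is skipped by the early return
```

Making labels_key and analysis_key parameters, as the diff does, also lets two differently configured instances of the mapper write to separate meta keys on the same dataset.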