diff --git a/.gitignore b/.gitignore index 9c15483..b3c6eb3 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ /tests/data/from_list.star /tests/data/test_overwrite_flag.star /tests/data/test_write_with_float_format.star +/tests/data/test_overwrite_backup.star /build/ /dist/ /m2relion/ diff --git a/src/starfile/__init__.py b/src/starfile/__init__.py index 0501c47..679a0e7 100644 --- a/src/starfile/__init__.py +++ b/src/starfile/__init__.py @@ -1 +1 @@ -from .functions import read, write +from .functions import read, write, to_string diff --git a/src/starfile/functions.py b/src/starfile/functions.py index b563f8a..792f666 100644 --- a/src/starfile/functions.py +++ b/src/starfile/functions.py @@ -81,4 +81,39 @@ def write( separator=sep, quote_character=quote_character, quote_all_strings=quote_all_strings, + ).write() + + +def to_string( + data: Union[DataBlock, Dict[str, DataBlock], List[DataBlock]], + float_format: str = '%.6f', + sep: str = '\t', + na_rep: str = '', + quote_character: str = '"', + quote_all_strings: bool = False, + **kwargs +): + """Represent data in the STAR format. + + Parameters + ---------- + data: DataBlock | Dict[str, DataBlock] | List[DataBlock] + Data to represent. DataBlocks are dictionaries or dataframes. + If a dictionary of datablocks are passed the keys will be the data block names. + float_format: str + Float format string which will be passed to pandas. + sep: str + Separator between values, will be passed to pandas. + na_rep: str + Representation of null values, will be passed to pandas. + """ + writer = StarWriter( + data, + filename=None, + float_format=float_format, + na_rep=na_rep, + separator=sep, + quote_character=quote_character, + quote_all_strings=quote_all_strings, ) + return ''.join(line + '\n' for line in writer.lines()) diff --git a/src/starfile/writer.py b/src/starfile/writer.py index 70843f3..be8d784 100644 --- a/src/starfile/writer.py +++ b/src/starfile/writer.py @@ -2,7 +2,7 @@ from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Union, Dict, List +from typing import TYPE_CHECKING, Union, Dict, List, Generator, Optional from importlib.metadata import version import csv @@ -21,7 +21,7 @@ class StarWriter: def __init__( self, data_blocks: Union[DataBlock, Dict[str, DataBlock], List[DataBlock]], - filename: PathLike, + filename: Optional[PathLike] = None, float_format: str = '%.6f', separator: str = '\t', na_rep: str = '', @@ -31,8 +31,10 @@ def __init__( # coerce data self.data_blocks = self.coerce_data_blocks(data_blocks) - # write - self.filename = Path(filename) + if filename is not None: + self.filename = Path(filename) + else: + self.filename = None self.float_format = float_format self.sep = separator self.na_rep = na_rep @@ -40,7 +42,6 @@ def __init__( self.quote_all_strings = quote_all_strings self.buffer = TextBuffer() self.backup_if_file_exists() - self.write() def coerce_data_blocks( self, @@ -61,24 +62,31 @@ def coerce_data_blocks( got {type(data_blocks)}' ) + def lines(self) -> Generator[str, None, None]: + yield package_info() + yield '' + yield '' + for line in self.data_block_generator(): + yield line + def write(self): - write_package_info(self.filename) - write_blank_lines(self.filename, n=2) - self.write_data_blocks() + if self.filename is None: + raise ValueError('Cannot write nameless file!') + with open(self.filename, mode='w+') as f: + f.writelines(line + '\n' for line in self.lines()) - def write_data_blocks(self): + def data_block_generator(self) -> Generator[str, None, None]: for block_name, block in self.data_blocks.items(): if isinstance(block, dict): - write_simple_block( - file=self.filename, + for line in simple_block( block_name=block_name, data=block, quote_character=self.quote_character, quote_all_strings=self.quote_all_strings - ) + ): + yield line elif isinstance(block, pd.DataFrame): - write_loop_block( - file=self.filename, + for line in loop_block( block_name=block_name, df=block, float_format=self.float_format, @@ -86,10 +94,11 @@ def write_data_blocks(self): na_rep=self.na_rep, quote_character=self.quote_character, quote_all_strings=self.quote_all_strings - ) + ): + yield line def backup_if_file_exists(self): - if self.filename.exists(): + if self.filename and self.filename.exists(): new_name = self.filename.name + '~' backup_path = self.filename.resolve().parent / new_name if backup_path.exists(): @@ -118,48 +127,36 @@ def coerce_list(data_blocks: List[DataBlock]) -> Dict[str, DataBlock]: return {f'{idx}': df for idx, df in enumerate(data_blocks)} -def write_blank_lines(file: Path, n: int): - with open(file, mode='a') as f: - f.write('\n' * n) - - -def write_package_info(file: Path): +def package_info(): date = datetime.now().strftime('%d/%m/%Y') time = datetime.now().strftime('%H:%M:%S') - line = f'# Created by the starfile Python package (version {__version__}) at {time} on {date}' - with open(file, mode='w+') as f: - f.write(f'{line}\n') + return f'# Created by the starfile Python package (version {__version__}) at {time} on {date}' + +def quote(x, *, + quote_character: str = '"', + quote_all_strings: bool = False) -> str: + if isinstance(x, str) and (quote_all_strings or ' ' in x or not x): + return f'{quote_character}{x}{quote_character}' + return x -def write_simple_block( - file: Path, + +def simple_block( block_name: str, data: Dict[str, Union[str, int, float]], quote_character: str = '"', quote_all_strings: bool = False -): - quoted_data = { - k: f"{quote_character}{v}{quote_character}" - if isinstance(v, str) and (quote_all_strings or " " in v or v == "") - else v - for k, v - in data.items() - } - formatted_lines = '\n'.join( - [ - f'_{k}\t\t\t{v}' - for k, v - in quoted_data.items() - ] - ) - with open(file, mode='a') as f: - f.write(f'data_{block_name}\n\n') - f.write(formatted_lines) - f.write('\n\n\n') - - -def write_loop_block( - file: Path, +) -> Generator[str, None, None]: + + yield f'data_{block_name}' + yield '' + for k, v in data.items(): + yield f'_{k}\t\t\t{quote(v, quote_character=quote_character, quote_all_strings=quote_all_strings)}' + yield '' + yield '' + + +def loop_block( block_name: str, df: pd.DataFrame, float_format: str = '%.6f', @@ -167,26 +164,21 @@ def write_loop_block( na_rep: str = '', quote_character: str = '"', quote_all_strings: bool = False -): - # write header - header_lines = [ - f'_{column_name} #{idx}' - for idx, column_name - in enumerate(df.columns, 1) - ] - with open(file, mode='a') as f: - f.write(f'data_{block_name}\n\n') - f.write('loop_\n') - f.write('\n'.join(header_lines)) - f.write('\n') - - df = df.map(lambda x: f'{quote_character}{x}{quote_character}' - if isinstance(x, str) and (quote_all_strings or " " in x or x == "") - else x) - - # write data - df.to_csv( - path_or_buf=file, +) -> Generator[str, None, None]: + + # Header + yield f'data_{block_name}' + yield '' + yield 'loop_' + for idx, column_name in enumerate(df.columns, 1): + yield f'_{column_name} #{idx}' + + # Data + for line in df.map(lambda x: + quote(x, + quote_character=quote_character, + quote_all_strings=quote_all_strings) + ).to_csv( mode='a', sep=separator, header=False, @@ -194,5 +186,8 @@ def write_loop_block( float_format=float_format, na_rep=na_rep, quoting=csv.QUOTE_NONE - ) - write_blank_lines(file, n=2) + ).split('\n'): + yield line + + yield '' + yield '' diff --git a/tests/test_functional_interface.py b/tests/test_functional_interface.py index a609362..2f24c40 100644 --- a/tests/test_functional_interface.py +++ b/tests/test_functional_interface.py @@ -47,3 +47,11 @@ def test_read_non_existent_file(): with pytest.raises(FileNotFoundError): starfile.read(f) + + +def test_generate_string(): + star_string = starfile.to_string(test_df) + output_file = test_data_directory / "test_write.star" + starfile.write(test_df, output_file, overwrite=True) + with open(output_file, "r") as f: + assert f.read() == star_string diff --git a/tests/test_read_write_round_trip.py b/tests/test_read_write_round_trip.py index a86544b..68e8d54 100644 --- a/tests/test_read_write_round_trip.py +++ b/tests/test_read_write_round_trip.py @@ -42,8 +42,8 @@ def test_round_trip_postprocess(tmp_path): assert _actual == _expected -def test_write_read_write_read(): - filename = 'tmp.star' +def test_write_read_write_read(tmp_path): + filename = tmp_path / 'tmp.star' df_a = pd.DataFrame({'a': [0, 1], 'b': [2, 3]}) starfile.write(df_a, filename) diff --git a/tests/test_writing.py b/tests/test_writing.py index 1396520..8e0ee00 100644 --- a/tests/test_writing.py +++ b/tests/test_writing.py @@ -14,28 +14,28 @@ def test_write_simple_block(): s = StarParser(postprocess) output_file = test_data_directory / 'basic_block.star' - StarWriter(s.data_blocks, output_file) + StarWriter(s.data_blocks, output_file).write() assert output_file.exists() def test_write_loop(): s = StarParser(loop_simple) output_file = test_data_directory / 'loop_block.star' - StarWriter(s.data_blocks, output_file) + StarWriter(s.data_blocks, output_file).write() assert output_file.exists() def test_write_multiblock(): s = StarParser(postprocess) output_file = test_data_directory / 'multiblock.star' - StarWriter(s.data_blocks, output_file) + StarWriter(s.data_blocks, output_file).write() assert output_file.exists() def test_from_single_dataframe(): output_file = test_data_directory / 'from_df.star' - StarWriter(test_df, output_file) + StarWriter(test_df, output_file).write() assert output_file.exists() s = StarParser(output_file) @@ -45,7 +45,7 @@ def test_create_from_dataframes(): dfs = [test_df, test_df] output_file = test_data_directory / 'from_list.star' - StarWriter(dfs, output_file) + StarWriter(dfs, output_file).write() assert output_file.exists() s = StarParser(output_file) @@ -59,7 +59,7 @@ def test_can_write_non_zero_indexed_one_row_dataframe(): with TemporaryDirectory() as directory: filename = join_path(directory, "test.star") - StarWriter(df, filename) + StarWriter(df, filename).write() with open(filename) as output_file: output = output_file.read() @@ -83,7 +83,7 @@ def test_string_quoting_loop_datablock(quote_character, quote_all_strings, num_q columns=["a_number","string_without_space", "string_space", "just_space", "empty_string"]) filename = tmp_path / "test.star" - StarWriter(df, filename, quote_character=quote_character, quote_all_strings=quote_all_strings) + StarWriter(df, filename, quote_character=quote_character, quote_all_strings=quote_all_strings).write() # Test for the appropriate number of quotes with open(filename) as f: @@ -118,7 +118,7 @@ def test_string_quoting_simple_datablock(quote_character, quote_all_strings,num_ } filename = tmp_path / "test.star" - StarWriter(o, filename, quote_character=quote_character, quote_all_strings=quote_all_strings) + StarWriter(o, filename, quote_character=quote_character, quote_all_strings=quote_all_strings).write() # Test for the appropriate number of quotes with open(filename) as f: @@ -127,3 +127,8 @@ def test_string_quoting_simple_datablock(quote_character, quote_all_strings,num_ s = StarParser(filename) assert o == s.data_blocks[""] + + +def test_no_filename_error(): + with pytest.raises(ValueError): + StarWriter(test_df).write()