Skip to content

Commit

Permalink
Merge pull request #15 from dh031200/update
Browse files Browse the repository at this point in the history
fix add_extension method
  • Loading branch information
dh031200 authored May 10, 2023
2 parents 185070a + 382bbf2 commit e2e9d4a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
11 changes: 3 additions & 8 deletions src/dis_inference/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
pre_processing,
device,
post_processing,
get_name,
read,
check_params,
write,
)

Expand All @@ -26,19 +25,15 @@ def dis_inference(source, silent=False):
inference(source, save=True, silent=silent)


def inference(source: Union[str, np.ndarray], save=False, silent=False, output='output'):
def inference(source: Union[str, np.ndarray], save=False, silent=False, output=None):
"""
:param source: Source image for inference.
:param save: Whether to save output image.
:param silent: Whether to print verbose.
:param output: The name of output image file
:return: (numpy.ndarray)dichotomous segmentation image
"""
if type(source) == str:
output, extension = get_name(source)
source = read(source)
else:
extension = '.png' if not any([output.endswith('png'), output.endswith('jpg'), output.endswith('jpeg')]) else ''
source, output, extension = check_params(source, output)
net = init_model()
image = pre_processing(source).to(device)
result = net(image)
Expand Down
21 changes: 19 additions & 2 deletions src/dis_inference/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,24 @@ def write(path, result, silent):
cv2.imwrite(path, result)


def check_params(source, output):
if type(source) == str:
if output:
extension = (
'.png' if not any([output.endswith('png'), output.endswith('jpg'), output.endswith('jpeg')]) else ''
)
else:
output, extension = get_name(source)
source = read(source)
else:
if not output:
output = 'output'
extension = (
'.png' if not any([output.endswith('.png'), output.endswith('.jpg'), output.endswith('.jpeg')]) else ''
)
return source, output, extension


def pre_processing(image):
if len(image.shape) < 3:
image = image[:, :, np.newaxis]
Expand Down Expand Up @@ -125,8 +143,7 @@ def get_user_config_dir(sub_dir='DIS-inference'):

__all__ = (
"init_model",
"get_name",
"read",
"check_params",
"write",
"pre_processing",
"post_processing",
Expand Down

0 comments on commit e2e9d4a

Please sign in to comment.