BUG: custom model related bugs (#364)
UranusSeven authored Aug 17, 2023
1 parent 7bfb004 commit 5b53b52
Showing 2 changed files with 23 additions and 2 deletions.
4 changes: 3 additions & 1 deletion xinference/model/llm/__init__.py
@@ -70,6 +70,8 @@ def _install():
     user_defined_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "llm")
     if os.path.isdir(user_defined_llm_dir):
         for f in os.listdir(user_defined_llm_dir):
-            with codecs.open(f, encoding="utf-8") as fd:
+            with codecs.open(
+                os.path.join(user_defined_llm_dir, f), encoding="utf-8"
+            ) as fd:
                 user_defined_llm_family = LLMFamilyV1.parse_obj(json.load(fd))
                 register_llm(user_defined_llm_family, persist=False)
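Note on this hunk: os.listdir() yields bare file names, so the old code opened f relative to the process working directory rather than the user-defined model directory. A minimal standalone sketch of the corrected pattern (the helper name and return type are illustrative, not from the repository):

import codecs
import json
import os

def load_user_defined_specs(spec_dir: str) -> list:
    """Parse every JSON spec file found in spec_dir (illustrative helper)."""
    specs = []
    if os.path.isdir(spec_dir):
        for name in os.listdir(spec_dir):
            # os.listdir returns bare names; join with the directory,
            # otherwise open() resolves against the current working directory.
            path = os.path.join(spec_dir, name)
            with codecs.open(path, encoding="utf-8") as fd:
                specs.append(json.load(fd))
    return specs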
21 changes: 20 additions & 1 deletion xinference/model/llm/llm_family.py
@@ -156,6 +156,8 @@ def copy(
         f"-{llm_spec.model_size_in_billions}b"
     )
     cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name))
+    if os.path.exists(cache_dir):
+        return cache_dir

     assert llm_spec.model_uri is not None
     src_scheme, src_root = parse_uri(llm_spec.model_uri)
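Note on this hunk: the early return makes caching idempotent. If a previous launch already created the cache directory (for local URIs it is expected to be a soft link, as the unregister hunk below shows), it is reused instead of being recreated. A rough sketch of the pattern with hypothetical names, assuming the cache is created as a symlink to the source directory:

import os

def ensure_cached(src_root: str, cache_dir: str) -> str:
    """Return a cache path, creating the symlink only on the first call."""
    if os.path.exists(cache_dir):
        # Already cached by an earlier launch; re-creating the link would
        # raise FileExistsError, so just reuse it.
        return cache_dir
    os.symlink(src_root, cache_dir, target_is_directory=True)
    return cache_dir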
@@ -183,7 +185,6 @@ def copy(
                 src_path = f"{path}/{file}"
                 local_path = src_path.replace(src_root, cache_dir)
                 files_to_download.append((src_path, local_path))
-        print(files_to_download)

         from concurrent.futures import ThreadPoolExecutor

@@ -350,6 +351,24 @@ def unregister_llm(model_name: str):
             )
             if os.path.exists(persist_path):
                 os.remove(persist_path)
+
+            llm_spec = llm_family.model_specs[0]
+            cache_dir_name = (
+                f"{llm_family.model_name}-{llm_spec.model_format}"
+                f"-{llm_spec.model_size_in_billions}b"
+            )
+            cache_dir = os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name)
+            if os.path.exists(cache_dir):
+                logger.warning(
+                    f"Remove the cache of user-defined model {llm_family.model_name}. "
+                    f"Cache directory: {cache_dir}"
+                )
+                if os.path.islink(cache_dir):
+                    os.remove(cache_dir)
+                else:
+                    logger.warning(
+                        f"Cache directory is not a soft link, please remove it manually."
+                    )
         else:
             raise ValueError(f"Model {model_name} not found")
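Note on this hunk: only symlinked cache entries are deleted on unregister. For a user-defined model registered from a local URI, the cache is just a soft link into the user's own directory, so removing the link is safe; a real directory might hold the only copy of the model files, so it is left for the user to remove. A hypothetical helper (not part of the commit) expressing the same policy:

import os

def remove_cache_entry(cache_dir: str) -> bool:
    """Delete a cache entry only when it is a symlink; report whether it was removed."""
    if not os.path.exists(cache_dir):
        return False
    if os.path.islink(cache_dir):
        os.remove(cache_dir)  # drops the link only; the target files stay put
        return True
    # A real directory is not touched: deleting it could destroy the only
    # copy of the model, so removal is left to the user.
    return False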
