diff --git a/xinference/web/ui/package-lock.json b/xinference/web/ui/package-lock.json
index 0730d3b275..7f15648e74 100644
--- a/xinference/web/ui/package-lock.json
+++ b/xinference/web/ui/package-lock.json
@@ -29,6 +29,7 @@
"@testing-library/user-event": "^13.5.0",
"clipboard": "^2.0.11",
"formik": "^2.4.2",
+ "nunjucks": "^3.2.4",
"prop-types": "^15.8.1",
"react": "^18.2.0",
"react-cookie": "^6.1.1",
@@ -5799,6 +5800,11 @@
"resolved": "https://registry.npmjs.org/@xtuc/long/-/long-4.2.2.tgz",
"integrity": "sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ=="
},
+ "node_modules/a-sync-waterfall": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmmirror.com/a-sync-waterfall/-/a-sync-waterfall-1.0.1.tgz",
+ "integrity": "sha512-RYTOHHdWipFUliRFMCS4X2Yn2X8M87V/OpSqWzKKOGhzqyUxzyVmhHDH9sAvG+ZuQf/TAOFsLCpMw09I1ufUnA=="
+ },
"node_modules/abab": {
"version": "2.0.6",
"resolved": "https://registry.npmjs.org/abab/-/abab-2.0.6.tgz",
@@ -13750,6 +13756,38 @@
"url": "https://github.com/fb55/nth-check?sponsor=1"
}
},
+ "node_modules/nunjucks": {
+ "version": "3.2.4",
+ "resolved": "https://registry.npmmirror.com/nunjucks/-/nunjucks-3.2.4.tgz",
+ "integrity": "sha512-26XRV6BhkgK0VOxfbU5cQI+ICFUtMLixv1noZn1tGU38kQH5A5nmmbk/O45xdyBhD1esk47nKrY0mvQpZIhRjQ==",
+ "dependencies": {
+ "a-sync-waterfall": "^1.0.0",
+ "asap": "^2.0.3",
+ "commander": "^5.1.0"
+ },
+ "bin": {
+ "nunjucks-precompile": "bin/precompile"
+ },
+ "engines": {
+ "node": ">= 6.9.0"
+ },
+ "peerDependencies": {
+ "chokidar": "^3.3.0"
+ },
+ "peerDependenciesMeta": {
+ "chokidar": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/nunjucks/node_modules/commander": {
+ "version": "5.1.0",
+ "resolved": "https://registry.npmmirror.com/commander/-/commander-5.1.0.tgz",
+ "integrity": "sha512-P0CysNDQ7rtVw4QIQtm+MRxV66vKFSvlsQvGYXZWR3qFU0jlMKHZZZgw8e+8DSah4UDKMqnknRDQz+xuQXQ/Zg==",
+ "engines": {
+ "node": ">= 6"
+ }
+ },
"node_modules/nwsapi": {
"version": "2.2.7",
"resolved": "https://registry.npmjs.org/nwsapi/-/nwsapi-2.2.7.tgz",
diff --git a/xinference/web/ui/package.json b/xinference/web/ui/package.json
index 0a163ec52b..1bda015ba8 100644
--- a/xinference/web/ui/package.json
+++ b/xinference/web/ui/package.json
@@ -25,6 +25,7 @@
"@testing-library/user-event": "^13.5.0",
"clipboard": "^2.0.11",
"formik": "^2.4.2",
+ "nunjucks": "^3.2.4",
"prop-types": "^15.8.1",
"react": "^18.2.0",
"react-cookie": "^6.1.1",
diff --git a/xinference/web/ui/src/scenes/register_model/components/addStop.js b/xinference/web/ui/src/scenes/register_model/components/addStop.js
new file mode 100644
index 0000000000..0acca09981
--- /dev/null
+++ b/xinference/web/ui/src/scenes/register_model/components/addStop.js
@@ -0,0 +1,107 @@
+import AddIcon from '@mui/icons-material/Add'
+import DeleteIcon from '@mui/icons-material/Delete'
+import { Alert, Button, TextField } from '@mui/material'
+import React, { useEffect, useState } from 'react'
+
+const regex = /^[1-9]\d*$/
+
+const AddStop = ({ label, onGetData, arrItemType, formData, onGetError }) => {
+ const [dataArr, setDataArr] = useState(formData?.length ? formData : [''])
+ const arr = []
+
+ useEffect(() => {
+ if (arrItemType === 'number') {
+ const newDataArr = dataArr.map((item) => {
+ if (item && regex.test(item)) {
+ arr.push('true')
+ return Number(item)
+ }
+ if (item && !regex.test(item)) arr.push('false')
+ return item
+ })
+ onGetError(arr)
+ onGetData(newDataArr)
+ } else {
+ onGetData(dataArr)
+ }
+ }, [dataArr])
+
+ const handleChange = (value, index) => {
+ const arr = [...dataArr]
+ arr[index] = value
+ setDataArr([...arr])
+ }
+
+ const handleAdd = () => {
+ if (dataArr[dataArr.length - 1]) {
+ setDataArr([...dataArr, ''])
+ }
+ }
+
+ const handleDelete = (index) => {
+ setDataArr(dataArr.filter((_, subIndex) => index !== subIndex))
+ }
+
+ const handleShowAlert = (item) => {
+ return item !== '' && !regex.test(item) && arrItemType === 'number'
+ }
+
+ return (
+ <>
+
+
+
+ }
+ className="addBtn"
+ onClick={handleAdd}
+ >
+ more
+
+
+
+ {dataArr.map((item, index) => (
+
+
+ handleChange(e.target.value, index)}
+ size="small"
+ style={{ width: '100%' }}
+ />
+ {dataArr.length > 1 && (
+ handleDelete(index)}
+ style={{ cursor: 'pointer', color: '#1976d2' }}
+ />
+ )}
+
+
+ {handleShowAlert(item) && (
+
+ Please enter an integer greater than 0.
+
+ )}
+
+ ))}
+
+
+ >
+ )
+}
+
+export default AddStop
diff --git a/xinference/web/ui/src/scenes/register_model/index.js b/xinference/web/ui/src/scenes/register_model/index.js
index eb5b0a9e77..6aa0146bc9 100644
--- a/xinference/web/ui/src/scenes/register_model/index.js
+++ b/xinference/web/ui/src/scenes/register_model/index.js
@@ -63,7 +63,6 @@ const RegisterModel = () => {
context_length: 2048,
model_lang: ['en'],
model_ability: ['generate'],
- model_family: '',
model_specs: [
{
model_uri: '/path/to/llama-1',
@@ -72,7 +71,7 @@ const RegisterModel = () => {
quantizations: ['none'],
},
],
- prompt_style: undefined,
+ model_family: '',
}}
/>
diff --git a/xinference/web/ui/src/scenes/register_model/registerModel.js b/xinference/web/ui/src/scenes/register_model/registerModel.js
index 06cc582927..717587d6d6 100644
--- a/xinference/web/ui/src/scenes/register_model/registerModel.js
+++ b/xinference/web/ui/src/scenes/register_model/registerModel.js
@@ -1,14 +1,20 @@
import './styles/registerModelStyle.css'
-import CheckIcon from '@mui/icons-material/Check'
+import Cancel from '@mui/icons-material/Cancel'
+import CheckCircleIcon from '@mui/icons-material/CheckCircle'
import KeyboardDoubleArrowRightIcon from '@mui/icons-material/KeyboardDoubleArrowRight'
import NotesIcon from '@mui/icons-material/Notes'
+import OpenInFullIcon from '@mui/icons-material/OpenInFull'
import {
Alert,
Box,
Button,
Checkbox,
Chip,
+ Dialog,
+ DialogActions,
+ DialogContent,
+ DialogTitle,
FormControl,
FormControlLabel,
InputLabel,
@@ -21,6 +27,7 @@ import {
TextField,
Tooltip,
} from '@mui/material'
+import nunjucks from 'nunjucks'
import React, { useContext, useEffect, useRef, useState } from 'react'
import { useCookies } from 'react-cookie'
import { useNavigate, useParams } from 'react-router-dom'
@@ -31,22 +38,27 @@ import fetchWrapper from '../../components/fetchWrapper'
import { isValidBearerToken } from '../../components/utils'
import AddControlnet from './components/addControlnet'
import AddModelSpecs from './components/addModelSpecs'
+import AddStop from './components/addStop'
import languages from './data/languages'
const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' }
const SUPPORTED_FEATURES = ['Generate', 'Chat', 'Vision']
+const messages = [
+ {
+ role: 'assistant',
+ content: 'This is the message content replied by the assistant previously',
+ },
+ {
+ role: 'user',
+ content: 'This is the message content sent by the user currently',
+ },
+]
// Convert dictionary of supported languages into list
const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT)
const RegisterModelComponent = ({ modelType, customData }) => {
- const endPoint = useContext(ApiContext).endPoint
const { setErrorMsg } = useContext(ApiContext)
const [formData, setFormData] = useState(customData)
- const [promptStyles, setPromptStyles] = useState([])
- const [family, setFamily] = useState({
- chat: [],
- generate: [],
- })
const [languagesArr, setLanguagesArr] = useState([])
const [isContextLengthAlert, setIsContextLengthAlert] = useState(false)
const [isDimensionsAlert, setIsDimensionsAlert] = useState(false)
@@ -73,6 +85,11 @@ const RegisterModelComponent = ({ modelType, customData }) => {
)
const [contrastObj, setContrastObj] = useState({})
const [isEqual, setIsEqual] = useState(true)
+ const [testRes, setTestRes] = useState('')
+ const [isOpenMessages, setIsOpenMessages] = useState(false)
+ const [testErrorInfo, setTestErrorInfo] = useState('')
+ const [isTestSuccess, setIsTestSuccess] = useState(false)
+ const [isStopTokenIdsAlert, setIsStopTokenIdsAlert] = useState(false)
useEffect(() => {
if (model_name) {
@@ -93,7 +110,9 @@ const RegisterModelComponent = ({ modelType, customData }) => {
model_ability,
model_family,
model_specs,
- prompt_style,
+ chat_template,
+ stop_token_ids,
+ stop,
} = data
const specsDataArr = model_specs.map((item) => {
const {
@@ -120,8 +139,10 @@ const RegisterModelComponent = ({ modelType, customData }) => {
model_ability,
model_family,
model_specs: specsDataArr,
+ chat_template,
+ stop_token_ids,
+ stop,
}
- prompt_style ? (llmData.prompt_style = prompt_style) : ''
setFormData(llmData)
setContrastObj(llmData)
setSpecsArr(specsDataArr)
@@ -217,79 +238,6 @@ const RegisterModelComponent = ({ modelType, customData }) => {
navigate('/login', { replace: true })
return
}
-
- const getBuiltinFamilies = async () => {
- const response = await fetch(endPoint + '/v1/models/families', {
- method: 'GET',
- headers: {
- 'Content-Type': 'application/json',
- },
- })
- if (!response.ok) {
- const errorData = await response.json() // Assuming the server returns error details in JSON format
- setErrorMsg(
- `Server error: ${response.status} - ${
- errorData.detail || 'Unknown error'
- }`
- )
- } else {
- const data = await response.json()
- data.chat.push('other')
- data.generate.push('other')
- setFamily(data)
- }
- }
-
- const getBuiltInPromptStyles = async () => {
- const response = await fetch(endPoint + '/v1/models/prompts', {
- method: 'GET',
- headers: {
- 'Content-Type': 'application/json',
- },
- })
- if (!response.ok) {
- const errorData = await response.json() // Assuming the server returns error details in JSON format
- setErrorMsg(
- `Server error: ${response.status} - ${
- errorData.detail || 'Unknown error'
- }`
- )
- } else {
- const data = await response.json()
- let res = []
- for (const key in data) {
- let v = data[key]
- v['name'] = key
- res.push(v)
- }
- setPromptStyles(res)
- }
- }
-
- if (
- Object.prototype.hasOwnProperty.call(customData, 'model_ability') &&
- Object.prototype.hasOwnProperty.call(customData, 'model_family')
- ) {
- // avoid keep requesting backend to get prompts
- if (promptStyles.length === 0) {
- getBuiltInPromptStyles().catch((error) => {
- setErrorMsg(
- error.message ||
- 'An unexpected error occurred when getting builtin prompt styles.'
- )
- console.error('Error: ', error)
- })
- }
- if (family.chat.length === 0) {
- getBuiltinFamilies().catch((error) => {
- setErrorMsg(
- error.message ||
- 'An unexpected error occurred when getting builtin prompt styles.'
- )
- console.error('Error: ', error)
- })
- }
- }
}, [cookie.token])
useEffect(() => {
@@ -299,34 +247,7 @@ const RegisterModelComponent = ({ modelType, customData }) => {
}
}, [formData])
- const getFamilyByAbility = () => {
- if (
- formData.model_ability.includes('chat') ||
- formData.model_ability.includes('vision')
- ) {
- return family.chat
- } else {
- return family.generate
- }
- }
-
- const sortStringsByFirstLetter = (arr) => {
- return arr.sort((a, b) => {
- const firstCharA = a.charAt(0).toLowerCase()
- const firstCharB = b.charAt(0).toLowerCase()
- if (firstCharA < firstCharB) {
- return -1
- }
- if (firstCharA > firstCharB) {
- return 1
- }
- return 0
- })
- }
-
const handleClick = async () => {
- console.log('formData', modelType, formData)
-
for (let key in formData) {
const type = Object.prototype.toString.call(formData[key]).slice(8, -1)
if (
@@ -427,61 +348,26 @@ const RegisterModelComponent = ({ modelType, customData }) => {
}
const toggleAbility = (ability) => {
+ const obj = JSON.parse(JSON.stringify(formData))
if (formData.model_ability.includes(ability)) {
- const obj = JSON.parse(JSON.stringify(formData))
if (ability === 'chat') {
- delete obj.prompt_style
+ delete obj.chat_template
+ delete obj.stop_token_ids
+ delete obj.stop
}
setFormData({
...obj,
model_ability: formData.model_ability.filter((a) => a !== ability),
- model_family: '',
})
} else {
- setFormData({
- ...formData,
- model_ability: [...formData.model_ability, ability],
- model_family: '',
- })
- }
- }
-
- const toggleFamily = (value) => {
- const ps = promptStyles.find((item) => item.name === value)
- if (formData.model_ability.includes('chat') && ps) {
- const prompt_style = {
- style_name: ps.style_name,
- system_prompt: ps.system_prompt,
- roles: ps.roles,
- intra_message_sep: ps.intra_message_sep,
- inter_message_sep: ps.inter_message_sep,
- stop: ps.stop ?? null,
- stop_token_ids: ps.stop_token_ids ?? null,
+ if (ability === 'chat') {
+ obj.chat_template = ''
+ obj.stop_token_ids = []
+ obj.stop = []
}
setFormData({
- ...formData,
- model_family: value,
- prompt_style,
- })
- } else {
- const {
- version,
- model_name,
- model_description,
- context_length,
- model_lang,
- model_ability,
- model_specs,
- } = formData
- setFormData({
- version,
- model_name,
- model_description,
- context_length,
- model_lang,
- model_ability,
- model_family: value,
- model_specs,
+ ...obj,
+ model_ability: [...formData.model_ability, ability],
})
}
}
@@ -569,6 +455,58 @@ const RegisterModelComponent = ({ modelType, customData }) => {
return true
}
+ const handleTest = () => {
+ setTestRes('')
+ if (formData.chat_template) {
+ try {
+ nunjucks.configure({ autoescape: false })
+ const test_res = nunjucks.renderString(formData.chat_template, {
+ messages: messages,
+ })
+ if (test_res === '') {
+ setTestRes(test_res)
+ setTestErrorInfo('error')
+ setIsTestSuccess(false)
+ } else {
+ setTestRes(test_res)
+ setTestErrorInfo('')
+ setIsTestSuccess(true)
+ }
+ } catch (error) {
+ setTestErrorInfo(`${error}`)
+ setIsTestSuccess(false)
+ }
+ }
+ }
+
+ const getStopTokenIds = (value) => {
+ if (value.length === 1 && value[0] === '') {
+ setFormData({
+ ...formData,
+ stop_token_ids: [],
+ })
+ } else {
+ setFormData({
+ ...formData,
+ stop_token_ids: value,
+ })
+ }
+ }
+
+ const getStop = (value) => {
+ if (value.length === 1 && value[0] === '') {
+ setFormData({
+ ...formData,
+ stop: [],
+ })
+ } else {
+ setFormData({
+ ...formData,
+ stop: value,
+ })
+ }
+ }
+
return (
@@ -845,66 +783,162 @@ const RegisterModelComponent = ({ modelType, customData }) => {
{/* family */}
{(customData.model_family === '' || customData.model_family) && (
-
-
- {modelType === 'LLM' && formData.model_family && (
- }
- severity="success"
- >
- Please be careful to select the family name corresponding to
- the model you want to register. If not found, please choose
- other
- .
-
- )}
- {modelType === 'LLM' && !formData.model_family && (
-
- Please be careful to select the family name corresponding to
- the model you want to register. If not found, please choose
- other
- .
-
+ <>
+ {modelType === 'LLM' && (
+ <>
+
+ setFormData({
+ ...formData,
+ model_family: event.target.value,
+ })
+ }
+ />
+
+ >
)}
- {
- toggleFamily(e.target.value)
- }}
- >
-
- {modelType === 'LLM' &&
- sortStringsByFirstLetter(getFamilyByAbility()).map((v) => (
-
+ {(modelType === 'image' || modelType === 'audio') && (
+ <>
+
+
+
+
}
- label={v}
+ label={formData.model_family}
/>
- ))}
- {(modelType === 'image' || modelType === 'audio') && (
- }
- label={formData.model_family}
+
+
+
+ >
+ )}
+ >
+ )}
+
+ {/* chat_template */}
+ {formData.model_ability?.includes('chat') && (
+ <>
+
+
+ setFormData({
+ ...formData,
+ chat_template: event.target.value,
+ })
+ }
+ style={{ flex: 1 }}
+ />
+
+
+
+ messages example
+ setIsOpenMessages(true)}
+ style={{ fontSize: 14, color: '#666', cursor: 'pointer' }}
/>
- )}
-
-
+
+
+
+ test result
+ {testErrorInfo ? (
+
+ ) : testRes ? (
+
+ ) : (
+ ''
+ )}
+
+
+ {testErrorInfo !== ''
+ ? testErrorInfo
+ : testRes
+ ? testRes
+ : 'No test results...'}
+
+
+
+
+
+ >
+ )}
+
+ {/* stop_token_ids */}
+ {formData.model_ability?.includes('chat') && (
+ <>
+ {
+ if (value.includes('false')) {
+ setIsStopTokenIdsAlert(true)
+ } else {
+ setIsStopTokenIdsAlert(false)
+ }
+ }}
+ />
-
+ >
+ )}
+
+ {/* stop */}
+ {formData.model_ability?.includes('chat') && (
+ <>
+
+
+ >
)}
{/* specs */}
@@ -1011,6 +1045,21 @@ const RegisterModelComponent = ({ modelType, customData }) => {
color="primary"
type="submit"
onClick={handleClick}
+ disabled={
+ isContextLengthAlert ||
+ isDimensionsAlert ||
+ isMaxTokensAlert ||
+ formData.model_lang?.length === 0 ||
+ formData.language?.length === 0 ||
+ formData.model_ability?.length === 0 ||
+ (modelType === 'LLM' && !formData.model_family) ||
+ (formData.model_ability?.includes('chat') &&
+ !formData.chat_template) ||
+ (formData.model_ability?.includes('chat') &&
+ formData.chat_template &&
+ !isTestSuccess) ||
+ isStopTokenIdsAlert
+ }
>
Register Model
@@ -1018,6 +1067,32 @@ const RegisterModelComponent = ({ modelType, customData }) => {
)}
+
+
{/* JSON */}
diff --git a/xinference/web/ui/src/scenes/register_model/styles/registerModelStyle.css b/xinference/web/ui/src/scenes/register_model/styles/registerModelStyle.css
index e7d8b9fd68..7d4c167bb9 100644
--- a/xinference/web/ui/src/scenes/register_model/styles/registerModelStyle.css
+++ b/xinference/web/ui/src/scenes/register_model/styles/registerModelStyle.css
@@ -119,3 +119,26 @@
font-size: 28px !important;
color: #fff;
}
+
+.chat_template_box {
+ display: flex;
+ align-items: center;
+ gap: 10px;
+}
+
+.chat_template_test {
+ height: 137px;
+ width: 30%;
+ padding: 10px;
+ border: 1px solid #ccc;
+ border-radius: 4px;
+ overflow: scroll;
+}
+
+.test_res_box {
+ background-color: #eee;
+ min-height: 55px;
+ padding: 10px;
+ margin-top: 5px;
+ border-radius: 4px;
+}