diff --git a/xinference/web/ui/package-lock.json b/xinference/web/ui/package-lock.json index 0730d3b275..7f15648e74 100644 --- a/xinference/web/ui/package-lock.json +++ b/xinference/web/ui/package-lock.json @@ -29,6 +29,7 @@ "@testing-library/user-event": "^13.5.0", "clipboard": "^2.0.11", "formik": "^2.4.2", + "nunjucks": "^3.2.4", "prop-types": "^15.8.1", "react": "^18.2.0", "react-cookie": "^6.1.1", @@ -5799,6 +5800,11 @@ "resolved": "https://registry.npmjs.org/@xtuc/long/-/long-4.2.2.tgz", "integrity": "sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ==" }, + "node_modules/a-sync-waterfall": { + "version": "1.0.1", + "resolved": "https://registry.npmmirror.com/a-sync-waterfall/-/a-sync-waterfall-1.0.1.tgz", + "integrity": "sha512-RYTOHHdWipFUliRFMCS4X2Yn2X8M87V/OpSqWzKKOGhzqyUxzyVmhHDH9sAvG+ZuQf/TAOFsLCpMw09I1ufUnA==" + }, "node_modules/abab": { "version": "2.0.6", "resolved": "https://registry.npmjs.org/abab/-/abab-2.0.6.tgz", @@ -13750,6 +13756,38 @@ "url": "https://github.com/fb55/nth-check?sponsor=1" } }, + "node_modules/nunjucks": { + "version": "3.2.4", + "resolved": "https://registry.npmmirror.com/nunjucks/-/nunjucks-3.2.4.tgz", + "integrity": "sha512-26XRV6BhkgK0VOxfbU5cQI+ICFUtMLixv1noZn1tGU38kQH5A5nmmbk/O45xdyBhD1esk47nKrY0mvQpZIhRjQ==", + "dependencies": { + "a-sync-waterfall": "^1.0.0", + "asap": "^2.0.3", + "commander": "^5.1.0" + }, + "bin": { + "nunjucks-precompile": "bin/precompile" + }, + "engines": { + "node": ">= 6.9.0" + }, + "peerDependencies": { + "chokidar": "^3.3.0" + }, + "peerDependenciesMeta": { + "chokidar": { + "optional": true + } + } + }, + "node_modules/nunjucks/node_modules/commander": { + "version": "5.1.0", + "resolved": "https://registry.npmmirror.com/commander/-/commander-5.1.0.tgz", + "integrity": "sha512-P0CysNDQ7rtVw4QIQtm+MRxV66vKFSvlsQvGYXZWR3qFU0jlMKHZZZgw8e+8DSah4UDKMqnknRDQz+xuQXQ/Zg==", + "engines": { + "node": ">= 6" + } + }, "node_modules/nwsapi": { "version": "2.2.7", "resolved": "https://registry.npmjs.org/nwsapi/-/nwsapi-2.2.7.tgz", diff --git a/xinference/web/ui/package.json b/xinference/web/ui/package.json index 0a163ec52b..1bda015ba8 100644 --- a/xinference/web/ui/package.json +++ b/xinference/web/ui/package.json @@ -25,6 +25,7 @@ "@testing-library/user-event": "^13.5.0", "clipboard": "^2.0.11", "formik": "^2.4.2", + "nunjucks": "^3.2.4", "prop-types": "^15.8.1", "react": "^18.2.0", "react-cookie": "^6.1.1", diff --git a/xinference/web/ui/src/scenes/register_model/components/addStop.js b/xinference/web/ui/src/scenes/register_model/components/addStop.js new file mode 100644 index 0000000000..0acca09981 --- /dev/null +++ b/xinference/web/ui/src/scenes/register_model/components/addStop.js @@ -0,0 +1,107 @@ +import AddIcon from '@mui/icons-material/Add' +import DeleteIcon from '@mui/icons-material/Delete' +import { Alert, Button, TextField } from '@mui/material' +import React, { useEffect, useState } from 'react' + +const regex = /^[1-9]\d*$/ + +const AddStop = ({ label, onGetData, arrItemType, formData, onGetError }) => { + const [dataArr, setDataArr] = useState(formData?.length ? formData : ['']) + const arr = [] + + useEffect(() => { + if (arrItemType === 'number') { + const newDataArr = dataArr.map((item) => { + if (item && regex.test(item)) { + arr.push('true') + return Number(item) + } + if (item && !regex.test(item)) arr.push('false') + return item + }) + onGetError(arr) + onGetData(newDataArr) + } else { + onGetData(dataArr) + } + }, [dataArr]) + + const handleChange = (value, index) => { + const arr = [...dataArr] + arr[index] = value + setDataArr([...arr]) + } + + const handleAdd = () => { + if (dataArr[dataArr.length - 1]) { + setDataArr([...dataArr, '']) + } + } + + const handleDelete = (index) => { + setDataArr(dataArr.filter((_, subIndex) => index !== subIndex)) + } + + const handleShowAlert = (item) => { + return item !== '' && !regex.test(item) && arrItemType === 'number' + } + + return ( + <> +
+
+ + +
+
+ {dataArr.map((item, index) => ( +
+
+ handleChange(e.target.value, index)} + size="small" + style={{ width: '100%' }} + /> + {dataArr.length > 1 && ( + handleDelete(index)} + style={{ cursor: 'pointer', color: '#1976d2' }} + /> + )} +
+ + {handleShowAlert(item) && ( + + Please enter an integer greater than 0. + + )} +
+ ))} +
+
+ + ) +} + +export default AddStop diff --git a/xinference/web/ui/src/scenes/register_model/index.js b/xinference/web/ui/src/scenes/register_model/index.js index eb5b0a9e77..6aa0146bc9 100644 --- a/xinference/web/ui/src/scenes/register_model/index.js +++ b/xinference/web/ui/src/scenes/register_model/index.js @@ -63,7 +63,6 @@ const RegisterModel = () => { context_length: 2048, model_lang: ['en'], model_ability: ['generate'], - model_family: '', model_specs: [ { model_uri: '/path/to/llama-1', @@ -72,7 +71,7 @@ const RegisterModel = () => { quantizations: ['none'], }, ], - prompt_style: undefined, + model_family: '', }} /> diff --git a/xinference/web/ui/src/scenes/register_model/registerModel.js b/xinference/web/ui/src/scenes/register_model/registerModel.js index 06cc582927..717587d6d6 100644 --- a/xinference/web/ui/src/scenes/register_model/registerModel.js +++ b/xinference/web/ui/src/scenes/register_model/registerModel.js @@ -1,14 +1,20 @@ import './styles/registerModelStyle.css' -import CheckIcon from '@mui/icons-material/Check' +import Cancel from '@mui/icons-material/Cancel' +import CheckCircleIcon from '@mui/icons-material/CheckCircle' import KeyboardDoubleArrowRightIcon from '@mui/icons-material/KeyboardDoubleArrowRight' import NotesIcon from '@mui/icons-material/Notes' +import OpenInFullIcon from '@mui/icons-material/OpenInFull' import { Alert, Box, Button, Checkbox, Chip, + Dialog, + DialogActions, + DialogContent, + DialogTitle, FormControl, FormControlLabel, InputLabel, @@ -21,6 +27,7 @@ import { TextField, Tooltip, } from '@mui/material' +import nunjucks from 'nunjucks' import React, { useContext, useEffect, useRef, useState } from 'react' import { useCookies } from 'react-cookie' import { useNavigate, useParams } from 'react-router-dom' @@ -31,22 +38,27 @@ import fetchWrapper from '../../components/fetchWrapper' import { isValidBearerToken } from '../../components/utils' import AddControlnet from './components/addControlnet' import AddModelSpecs from './components/addModelSpecs' +import AddStop from './components/addStop' import languages from './data/languages' const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } const SUPPORTED_FEATURES = ['Generate', 'Chat', 'Vision'] +const messages = [ + { + role: 'assistant', + content: 'This is the message content replied by the assistant previously', + }, + { + role: 'user', + content: 'This is the message content sent by the user currently', + }, +] // Convert dictionary of supported languages into list const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) const RegisterModelComponent = ({ modelType, customData }) => { - const endPoint = useContext(ApiContext).endPoint const { setErrorMsg } = useContext(ApiContext) const [formData, setFormData] = useState(customData) - const [promptStyles, setPromptStyles] = useState([]) - const [family, setFamily] = useState({ - chat: [], - generate: [], - }) const [languagesArr, setLanguagesArr] = useState([]) const [isContextLengthAlert, setIsContextLengthAlert] = useState(false) const [isDimensionsAlert, setIsDimensionsAlert] = useState(false) @@ -73,6 +85,11 @@ const RegisterModelComponent = ({ modelType, customData }) => { ) const [contrastObj, setContrastObj] = useState({}) const [isEqual, setIsEqual] = useState(true) + const [testRes, setTestRes] = useState('') + const [isOpenMessages, setIsOpenMessages] = useState(false) + const [testErrorInfo, setTestErrorInfo] = useState('') + const [isTestSuccess, setIsTestSuccess] = useState(false) + const [isStopTokenIdsAlert, setIsStopTokenIdsAlert] = useState(false) useEffect(() => { if (model_name) { @@ -93,7 +110,9 @@ const RegisterModelComponent = ({ modelType, customData }) => { model_ability, model_family, model_specs, - prompt_style, + chat_template, + stop_token_ids, + stop, } = data const specsDataArr = model_specs.map((item) => { const { @@ -120,8 +139,10 @@ const RegisterModelComponent = ({ modelType, customData }) => { model_ability, model_family, model_specs: specsDataArr, + chat_template, + stop_token_ids, + stop, } - prompt_style ? (llmData.prompt_style = prompt_style) : '' setFormData(llmData) setContrastObj(llmData) setSpecsArr(specsDataArr) @@ -217,79 +238,6 @@ const RegisterModelComponent = ({ modelType, customData }) => { navigate('/login', { replace: true }) return } - - const getBuiltinFamilies = async () => { - const response = await fetch(endPoint + '/v1/models/families', { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - }, - }) - if (!response.ok) { - const errorData = await response.json() // Assuming the server returns error details in JSON format - setErrorMsg( - `Server error: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - } else { - const data = await response.json() - data.chat.push('other') - data.generate.push('other') - setFamily(data) - } - } - - const getBuiltInPromptStyles = async () => { - const response = await fetch(endPoint + '/v1/models/prompts', { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - }, - }) - if (!response.ok) { - const errorData = await response.json() // Assuming the server returns error details in JSON format - setErrorMsg( - `Server error: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - } else { - const data = await response.json() - let res = [] - for (const key in data) { - let v = data[key] - v['name'] = key - res.push(v) - } - setPromptStyles(res) - } - } - - if ( - Object.prototype.hasOwnProperty.call(customData, 'model_ability') && - Object.prototype.hasOwnProperty.call(customData, 'model_family') - ) { - // avoid keep requesting backend to get prompts - if (promptStyles.length === 0) { - getBuiltInPromptStyles().catch((error) => { - setErrorMsg( - error.message || - 'An unexpected error occurred when getting builtin prompt styles.' - ) - console.error('Error: ', error) - }) - } - if (family.chat.length === 0) { - getBuiltinFamilies().catch((error) => { - setErrorMsg( - error.message || - 'An unexpected error occurred when getting builtin prompt styles.' - ) - console.error('Error: ', error) - }) - } - } }, [cookie.token]) useEffect(() => { @@ -299,34 +247,7 @@ const RegisterModelComponent = ({ modelType, customData }) => { } }, [formData]) - const getFamilyByAbility = () => { - if ( - formData.model_ability.includes('chat') || - formData.model_ability.includes('vision') - ) { - return family.chat - } else { - return family.generate - } - } - - const sortStringsByFirstLetter = (arr) => { - return arr.sort((a, b) => { - const firstCharA = a.charAt(0).toLowerCase() - const firstCharB = b.charAt(0).toLowerCase() - if (firstCharA < firstCharB) { - return -1 - } - if (firstCharA > firstCharB) { - return 1 - } - return 0 - }) - } - const handleClick = async () => { - console.log('formData', modelType, formData) - for (let key in formData) { const type = Object.prototype.toString.call(formData[key]).slice(8, -1) if ( @@ -427,61 +348,26 @@ const RegisterModelComponent = ({ modelType, customData }) => { } const toggleAbility = (ability) => { + const obj = JSON.parse(JSON.stringify(formData)) if (formData.model_ability.includes(ability)) { - const obj = JSON.parse(JSON.stringify(formData)) if (ability === 'chat') { - delete obj.prompt_style + delete obj.chat_template + delete obj.stop_token_ids + delete obj.stop } setFormData({ ...obj, model_ability: formData.model_ability.filter((a) => a !== ability), - model_family: '', }) } else { - setFormData({ - ...formData, - model_ability: [...formData.model_ability, ability], - model_family: '', - }) - } - } - - const toggleFamily = (value) => { - const ps = promptStyles.find((item) => item.name === value) - if (formData.model_ability.includes('chat') && ps) { - const prompt_style = { - style_name: ps.style_name, - system_prompt: ps.system_prompt, - roles: ps.roles, - intra_message_sep: ps.intra_message_sep, - inter_message_sep: ps.inter_message_sep, - stop: ps.stop ?? null, - stop_token_ids: ps.stop_token_ids ?? null, + if (ability === 'chat') { + obj.chat_template = '' + obj.stop_token_ids = [] + obj.stop = [] } setFormData({ - ...formData, - model_family: value, - prompt_style, - }) - } else { - const { - version, - model_name, - model_description, - context_length, - model_lang, - model_ability, - model_specs, - } = formData - setFormData({ - version, - model_name, - model_description, - context_length, - model_lang, - model_ability, - model_family: value, - model_specs, + ...obj, + model_ability: [...formData.model_ability, ability], }) } } @@ -569,6 +455,58 @@ const RegisterModelComponent = ({ modelType, customData }) => { return true } + const handleTest = () => { + setTestRes('') + if (formData.chat_template) { + try { + nunjucks.configure({ autoescape: false }) + const test_res = nunjucks.renderString(formData.chat_template, { + messages: messages, + }) + if (test_res === '') { + setTestRes(test_res) + setTestErrorInfo('error') + setIsTestSuccess(false) + } else { + setTestRes(test_res) + setTestErrorInfo('') + setIsTestSuccess(true) + } + } catch (error) { + setTestErrorInfo(`${error}`) + setIsTestSuccess(false) + } + } + } + + const getStopTokenIds = (value) => { + if (value.length === 1 && value[0] === '') { + setFormData({ + ...formData, + stop_token_ids: [], + }) + } else { + setFormData({ + ...formData, + stop_token_ids: value, + }) + } + } + + const getStop = (value) => { + if (value.length === 1 && value[0] === '') { + setFormData({ + ...formData, + stop: [], + }) + } else { + setFormData({ + ...formData, + stop: value, + }) + } + } + return (
@@ -845,66 +783,162 @@ const RegisterModelComponent = ({ modelType, customData }) => { {/* family */} {(customData.model_family === '' || customData.model_family) && ( - - - {modelType === 'LLM' && formData.model_family && ( - } - severity="success" - > - Please be careful to select the family name corresponding to - the model you want to register. If not found, please choose - other - . - - )} - {modelType === 'LLM' && !formData.model_family && ( - - Please be careful to select the family name corresponding to - the model you want to register. If not found, please choose - other - . - + <> + {modelType === 'LLM' && ( + <> + + setFormData({ + ...formData, + model_family: event.target.value, + }) + } + /> + + )} - { - toggleFamily(e.target.value) - }} - > - - {modelType === 'LLM' && - sortStringsByFirstLetter(getFamilyByAbility()).map((v) => ( - + {(modelType === 'image' || modelType === 'audio') && ( + <> + + + + } - label={v} + label={formData.model_family} /> - ))} - {(modelType === 'image' || modelType === 'audio') && ( - } - label={formData.model_family} + + + + + )} + + )} + + {/* chat_template */} + {formData.model_ability?.includes('chat') && ( + <> +
+ + setFormData({ + ...formData, + chat_template: event.target.value, + }) + } + style={{ flex: 1 }} + /> + +
+
+ messages example + setIsOpenMessages(true)} + style={{ fontSize: 14, color: '#666', cursor: 'pointer' }} /> - )} - - +
+
+ + test result + {testErrorInfo ? ( + + ) : testRes ? ( + + ) : ( + '' + )} + +
+ {testErrorInfo !== '' + ? testErrorInfo + : testRes + ? testRes + : 'No test results...'} +
+
+
+
+ + + )} + + {/* stop_token_ids */} + {formData.model_ability?.includes('chat') && ( + <> + { + if (value.includes('false')) { + setIsStopTokenIdsAlert(true) + } else { + setIsStopTokenIdsAlert(false) + } + }} + /> -
+ + )} + + {/* stop */} + {formData.model_ability?.includes('chat') && ( + <> + + + )} {/* specs */} @@ -1011,6 +1045,21 @@ const RegisterModelComponent = ({ modelType, customData }) => { color="primary" type="submit" onClick={handleClick} + disabled={ + isContextLengthAlert || + isDimensionsAlert || + isMaxTokensAlert || + formData.model_lang?.length === 0 || + formData.language?.length === 0 || + formData.model_ability?.length === 0 || + (modelType === 'LLM' && !formData.model_family) || + (formData.model_ability?.includes('chat') && + !formData.chat_template) || + (formData.model_ability?.includes('chat') && + formData.chat_template && + !isTestSuccess) || + isStopTokenIdsAlert + } > Register Model @@ -1018,6 +1067,32 @@ const RegisterModelComponent = ({ modelType, customData }) => { )}
+ setIsOpenMessages(false)} + aria-labelledby="alert-dialog-title" + aria-describedby="alert-dialog-description" + > + Messages Example + +