add remove bg model selection
This commit is contained in:
@@ -20,8 +20,8 @@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "./ui/tabs"
|
||||
import { useEffect, useState } from "react"
|
||||
import { cn } from "@/lib/utils"
|
||||
import { useQuery } from "@tanstack/react-query"
|
||||
import { fetchModelInfos, switchModel } from "@/lib/api"
|
||||
import { ModelInfo } from "@/lib/types"
|
||||
import { getServerConfig, switchModel, switchPluginModel } from "@/lib/api"
|
||||
import { ModelInfo, PluginName } from "@/lib/types"
|
||||
import { useStore } from "@/lib/states"
|
||||
import { ScrollArea } from "./ui/scroll-area"
|
||||
import { useToast } from "./ui/use-toast"
|
||||
@@ -39,6 +39,14 @@ import {
|
||||
MODEL_TYPE_OTHER,
|
||||
} from "@/lib/const"
|
||||
import useHotKey from "@/hooks/useHotkey"
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectGroup,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "./ui/select"
|
||||
|
||||
const formSchema = z.object({
|
||||
enableFileManager: z.boolean(),
|
||||
@@ -48,42 +56,45 @@ const formSchema = z.object({
|
||||
enableManualInpainting: z.boolean(),
|
||||
enableUploadMask: z.boolean(),
|
||||
enableAutoExtractPrompt: z.boolean(),
|
||||
removeBGModel: z.string(),
|
||||
})
|
||||
|
||||
const TAB_GENERAL = "General"
|
||||
const TAB_MODEL = "Model"
|
||||
const TAB_PLUGINS = "Plugins"
|
||||
// const TAB_FILE_MANAGER = "File Manager"
|
||||
|
||||
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL]
|
||||
const TAB_NAMES = [TAB_MODEL, TAB_GENERAL, TAB_PLUGINS]
|
||||
|
||||
export function SettingsDialog() {
|
||||
const [open, toggleOpen] = useToggle(false)
|
||||
const [openModelSwitching, toggleOpenModelSwitching] = useToggle(false)
|
||||
const [tab, setTab] = useState(TAB_MODEL)
|
||||
const [
|
||||
updateAppState,
|
||||
settings,
|
||||
updateSettings,
|
||||
fileManagerState,
|
||||
updateFileManagerState,
|
||||
setAppModel,
|
||||
setServerConfig,
|
||||
] = useStore((state) => [
|
||||
state.updateAppState,
|
||||
state.settings,
|
||||
state.updateSettings,
|
||||
state.fileManagerState,
|
||||
state.updateFileManagerState,
|
||||
state.setModel,
|
||||
state.setServerConfig,
|
||||
])
|
||||
const { toast } = useToast()
|
||||
const [model, setModel] = useState<ModelInfo>(settings.model)
|
||||
const [modelSwitchingTexts, setModelSwitchingTexts] = useState<string[]>([])
|
||||
const openModelSwitching = modelSwitchingTexts.length > 0
|
||||
useEffect(() => {
|
||||
setModel(settings.model)
|
||||
}, [settings.model])
|
||||
|
||||
const { data: modelInfos, status } = useQuery({
|
||||
queryKey: ["modelInfos"],
|
||||
queryFn: fetchModelInfos,
|
||||
const { data: serverConfig, status } = useQuery({
|
||||
queryKey: ["serverConfig"],
|
||||
queryFn: getServerConfig,
|
||||
})
|
||||
|
||||
// 1. Define your form.
|
||||
@@ -96,9 +107,17 @@ export function SettingsDialog() {
|
||||
enableAutoExtractPrompt: settings.enableAutoExtractPrompt,
|
||||
inputDirectory: fileManagerState.inputDirectory,
|
||||
outputDirectory: fileManagerState.outputDirectory,
|
||||
removeBGModel: serverConfig?.removeBGModel,
|
||||
},
|
||||
})
|
||||
|
||||
useEffect(() => {
|
||||
if (serverConfig) {
|
||||
setServerConfig(serverConfig)
|
||||
form.setValue("removeBGModel", serverConfig.removeBGModel)
|
||||
}
|
||||
}, [form, serverConfig])
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
// Do something with the form values. ✅ This will be type-safe and validated.
|
||||
updateSettings({
|
||||
@@ -109,29 +128,67 @@ export function SettingsDialog() {
|
||||
})
|
||||
|
||||
// TODO: validate input/output Directory
|
||||
updateFileManagerState({
|
||||
inputDirectory: values.inputDirectory,
|
||||
outputDirectory: values.outputDirectory,
|
||||
})
|
||||
if (model.name !== settings.model.name) {
|
||||
toggleOpenModelSwitching()
|
||||
updateAppState({ disableShortCuts: true })
|
||||
try {
|
||||
const newModel = await switchModel(model.name)
|
||||
toast({
|
||||
title: `Switch to ${newModel.name} success`,
|
||||
})
|
||||
setAppModel(model)
|
||||
} catch (error: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: `Switch to ${model.name} failed: ${error}`,
|
||||
})
|
||||
setModel(settings.model)
|
||||
} finally {
|
||||
toggleOpenModelSwitching()
|
||||
updateAppState({ disableShortCuts: false })
|
||||
// updateFileManagerState({
|
||||
// inputDirectory: values.inputDirectory,
|
||||
// outputDirectory: values.outputDirectory,
|
||||
// })
|
||||
|
||||
const shouldSwitchModel = model.name !== settings.model.name
|
||||
const shouldSwitchRemoveBGModel =
|
||||
serverConfig?.removeBGModel !== values.removeBGModel
|
||||
const showModelSwitching = shouldSwitchModel || shouldSwitchRemoveBGModel
|
||||
|
||||
if (showModelSwitching) {
|
||||
const newModelSwitchingTexts: string[] = []
|
||||
if (shouldSwitchModel) {
|
||||
newModelSwitchingTexts.push(
|
||||
`Switching model from ${settings.model.name} to ${model.name}`
|
||||
)
|
||||
}
|
||||
if (shouldSwitchRemoveBGModel) {
|
||||
newModelSwitchingTexts.push(
|
||||
`Switching removebg model from ${serverConfig?.removeBGModel} to ${values.removeBGModel}`
|
||||
)
|
||||
}
|
||||
setModelSwitchingTexts(newModelSwitchingTexts)
|
||||
|
||||
updateAppState({ disableShortCuts: true })
|
||||
|
||||
if (shouldSwitchModel) {
|
||||
try {
|
||||
const newModel = await switchModel(model.name)
|
||||
toast({
|
||||
title: `Switch to ${newModel.name} success`,
|
||||
})
|
||||
setAppModel(model)
|
||||
} catch (error: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: `Switch to ${model.name} failed: ${error}`,
|
||||
})
|
||||
setModel(settings.model)
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldSwitchRemoveBGModel) {
|
||||
try {
|
||||
const res = await switchPluginModel(
|
||||
PluginName.RemoveBG,
|
||||
values.removeBGModel
|
||||
)
|
||||
if (res.status !== 200) {
|
||||
throw new Error(res.statusText)
|
||||
}
|
||||
} catch (error: any) {
|
||||
toast({
|
||||
variant: "destructive",
|
||||
title: `Switch removebg model to ${model.name} failed: ${error}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
setModelSwitchingTexts([])
|
||||
updateAppState({ disableShortCuts: false })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,7 +200,17 @@ export function SettingsDialog() {
|
||||
onSubmit(form.getValues())
|
||||
}
|
||||
},
|
||||
[open, form, model]
|
||||
[open, form, model, serverConfig]
|
||||
)
|
||||
|
||||
if (status !== "success") {
|
||||
return <></>
|
||||
}
|
||||
|
||||
const modelInfos = serverConfig.modelInfos
|
||||
const plugins = serverConfig.plugins
|
||||
const removeBGEnabled = plugins.some(
|
||||
(plugin) => plugin.name === PluginName.RemoveBG
|
||||
)
|
||||
|
||||
function onOpenChange(value: boolean) {
|
||||
@@ -186,10 +253,6 @@ export function SettingsDialog() {
|
||||
}
|
||||
|
||||
function renderModelSettings() {
|
||||
if (status !== "success") {
|
||||
return <></>
|
||||
}
|
||||
|
||||
let defaultTab = MODEL_TYPE_INPAINT
|
||||
for (let info of modelInfos) {
|
||||
if (model.name === info.name) {
|
||||
@@ -356,6 +419,44 @@ export function SettingsDialog() {
|
||||
)
|
||||
}
|
||||
|
||||
function renderPluginsSettings() {
|
||||
return (
|
||||
<div className="space-y-4 w-[510px]">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="removeBGModel"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex items-center justify-between">
|
||||
<div className="space-y-0.5">
|
||||
<FormLabel>Remove Background</FormLabel>
|
||||
<FormDescription>Remove background model</FormDescription>
|
||||
</div>
|
||||
<Select
|
||||
onValueChange={field.onChange}
|
||||
defaultValue={field.value}
|
||||
disabled={!removeBGEnabled}
|
||||
>
|
||||
<FormControl>
|
||||
<SelectTrigger className="w-[200px]">
|
||||
<SelectValue placeholder="Select removebg model" />
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
<SelectContent align="end">
|
||||
<SelectGroup>
|
||||
{serverConfig?.removeBGModels.map((model) => (
|
||||
<SelectItem key={model} value={model}>
|
||||
{model}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
// function renderFileManagerSettings() {
|
||||
// return (
|
||||
// <div className="flex flex-col justify-between rounded-lg gap-4 w-[400px]">
|
||||
@@ -446,7 +547,9 @@ export function SettingsDialog() {
|
||||
<span className="sr-only">Loading...</span>
|
||||
</div>
|
||||
|
||||
<div>Switching to {model.name}</div>
|
||||
{modelSwitchingTexts.map((text, index) => (
|
||||
<div key={index}>{text}</div>
|
||||
))}
|
||||
</div>
|
||||
{/* </AlertDialogDescription> */}
|
||||
</AlertDialogHeader>
|
||||
@@ -473,6 +576,7 @@ export function SettingsDialog() {
|
||||
<Button
|
||||
key={item}
|
||||
variant="ghost"
|
||||
disabled={item === TAB_PLUGINS && !removeBGEnabled}
|
||||
onClick={() => setTab(item)}
|
||||
className={cn(
|
||||
tab === item ? "bg-muted " : "hover:bg-muted",
|
||||
@@ -489,6 +593,7 @@ export function SettingsDialog() {
|
||||
<form onSubmit={form.handleSubmit(onSubmit)}>
|
||||
{tab === TAB_MODEL ? renderModelSettings() : <></>}
|
||||
{tab === TAB_GENERAL ? renderGeneralSettings() : <></>}
|
||||
{tab === TAB_PLUGINS ? renderPluginsSettings() : <></>}
|
||||
{/* {tab === TAB_FILE_MANAGER ? (
|
||||
renderFileManagerSettings()
|
||||
) : (
|
||||
|
||||
Reference in New Issue
Block a user