Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add adaptors for ollama and openai #12

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class LlmLsCompletionProvider: InlineCompletionProvider {
val queryParams = settings.queryParams
val fimParams = settings.fim
val tokenizerConfig = settings.tokenizer
val params = CompletionParams(textDocument, position, request_params = queryParams, fim = fimParams, api_token = secrets.getSecretSetting(), model = settings.model, tokens_to_clear = settings.tokensToClear, tokenizer_config = tokenizerConfig, context_window = settings.contextWindow)
val params = CompletionParams(textDocument, position, request_params = queryParams, fim = fimParams, api_token = secrets.getSecretSetting(), model = settings.model, tokens_to_clear = settings.tokensToClear, tokenizer_config = tokenizerConfig, context_window = settings.contextWindow, adaptor = settings.adaptor, request_body = settings.requestBody)
lspServer.requestExecutor.sendRequestAsync(LlmLsGetCompletionsRequest(lspServer, params)) { response ->
CoroutineScope(Dispatchers.Default).launch {
if (response != null) {
Expand All @@ -57,4 +57,4 @@ class LlmLsCompletionProvider: InlineCompletionProvider {
val settings = LlmSettingsState.instance
return settings.ghostTextEnabled
}
}
}
52 changes: 52 additions & 0 deletions src/main/kotlin/co/huggingface/llmintellij/LlmSettingsComponent.kt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class LlmSettingsComponent {
private val contextWindowLabel: JBLabel
private val contextWindow: JBTextField
private val enableGhostText: JBCheckBox
private val adaptorLabel: JBLabel
private val adaptor: JComboBox<String>
private val requestBodyModelLabel: JBLabel
private val requestBodyModel: JBTextField

init {
// Used for all text fields with browse button
Expand Down Expand Up @@ -218,6 +222,16 @@ class LlmSettingsComponent {
llmLsSubsectionPanel.add(lspLogLevelLabel)
llmLsSubsectionPanel.add(lspLogLevel)

val adaptorPanel = createSectionPanel("Adaptors", rootPanel)
val adaptorOptions = arrayOf("huggingface", "tgi", "ollama", "openai")
adaptorLabel = JBLabel("Adaptor provider")
adaptor = JComboBox(adaptorOptions)
adaptorPanel.add(adaptorLabel)
adaptorPanel.add(adaptor)
requestBodyModelLabel = JBLabel("Adaptor provider API request model")
requestBodyModel = JBTextField("")
adaptorPanel.add(requestBodyModelLabel)
adaptorPanel.add(requestBodyModel)
}

val preferredFocusedComponent: JComponent
Expand Down Expand Up @@ -415,6 +429,44 @@ class LlmSettingsComponent {
contextWindow.text = value.toString()
}

fun getAdaptor(): String? {
var adaptorValue = adaptor.getItemAt(adaptor.selectedIndex)
return if ( adaptorValue == "" ) {
null
} else {
adaptorValue
}
}

fun setAdaptor(adaptorValue: String) {
adaptor.selectedItem = adaptorValue
}

fun getRequestBody(): RequestBody {
return RequestBody(model = getRequestBodyModel())
}

fun getRequestBodyModel(): String? {
var model = requestBodyModel.text
return if (model == "") {
null
} else {
model
}
}

fun setRequestBody(requestBody: RequestBody) {
var model = requestBody.model
when (model) {
null -> {
requestBodyModel.text = ""
}
else -> {
requestBodyModel.text = model
}
}
}

private fun createSectionPanel(title: String, parentPanel: JPanel): JPanel {
val panel = JPanel()
panel.setLayout(VerticalLayout(5))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class LlmSettingsConfigurable : Configurable {
modified = modified or (settingsComponent?.getLspLogLevel() != settings.lsp.logLevel)
modified = modified or (settingsComponent?.getTokenizerConfig() != settings.tokenizer)
modified = modified or (settingsComponent?.getContextWindow() != settings.contextWindow)
modified = modified or (settingsComponent?.getAdaptor() != settings.adaptor)
modified = modified or (settingsComponent?.getRequestBody() != settings.requestBody)
return modified
}

Expand All @@ -65,6 +67,8 @@ class LlmSettingsConfigurable : Configurable {
settings.lsp.logLevel = settingsComponent?.getLspLogLevel() ?: ""
settings.tokenizer = settingsComponent?.getTokenizerConfig()
settings.contextWindow = settingsComponent?.getContextWindow() ?: 0u
settings.adaptor = settingsComponent?.getAdaptor()
settings.requestBody = settingsComponent?.getRequestBody()
}

override fun reset() {
Expand All @@ -86,9 +90,11 @@ class LlmSettingsConfigurable : Configurable {
settingsComponent?.setLspLogLevel(settings.lsp.logLevel)
settingsComponent?.setTokenizerConfig(settings.tokenizer)
settingsComponent?.setContextWindow(settings.contextWindow)
settingsComponent?.setAdaptor(settings.adaptor ?: "")
settingsComponent?.setRequestBody(settings.requestBody ?: RequestBody(model = null))
}

override fun disposeUIResources() {
settingsComponent = null
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ sealed class TokenizerConfig {
data class Download(val url: String, val to: String) : TokenizerConfig()
}

class RequestBody(
var model: String? = null,
)

@State(
name = "co.huggingface.llmintellij.LlmSettingsState",
storages = [Storage("LlmSettingsPlugin.xml")]
Expand All @@ -48,6 +52,8 @@ class LlmSettingsState: PersistentStateComponent<LlmSettingsState?> {
var lsp = LspSettings()
var tokenizer: TokenizerConfig? = TokenizerConfig.HuggingFace("bigcode/starcoder")
var contextWindow = 8192u
var adaptor: String? = "huggingface"
var requestBody: RequestBody? = null

override fun getState(): LlmSettingsState {
return this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package co.huggingface.llmintellij.lsp
import co.huggingface.llmintellij.FimParams
import co.huggingface.llmintellij.QueryParams
import co.huggingface.llmintellij.TokenizerConfig
import co.huggingface.llmintellij.RequestBody
import org.eclipse.lsp4j.TextDocumentIdentifier
import org.eclipse.lsp4j.jsonrpc.services.JsonRequest
import org.eclipse.lsp4j.jsonrpc.services.JsonSegment
Expand All @@ -25,10 +26,12 @@ class CompletionParams(
val tokenizer_config: TokenizerConfig?,
val context_window: UInt,
val tls_skip_verify_insecure: Boolean = false,
val adaptor: String?,
val request_body: RequestBody?,
)

@JsonSegment("llm-ls")
public interface LlmLsLanguageServer: LanguageServer {
@JsonRequest
fun getCompletions(params: CompletionParams): CompletableFuture<CompletionResponse>;
}
}