Skip to content

Commit

Permalink
Refactor pull #15
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed Jul 28, 2024
1 parent b24fc67 commit 436fa6e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 7 deletions.
4 changes: 2 additions & 2 deletions R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ show <- function(name, verbose = FALSE, output = c("jsonlist", "resp", "raw"), e
delete <- function(name, endpoint = "/api/delete", host = NULL) {
if (!model_avail(name)) {
message("Available models listed below.")
print(list_models(output = 'text', host = host))
print(list_models(output = "text", host = host))
return(invisible())
}

Expand Down Expand Up @@ -424,7 +424,7 @@ pull <- function(name, stream = TRUE, insecure = FALSE, endpoint = "/api/pull",
req <- create_request(endpoint, host)
req <- httr2::req_method(req, "POST")

body_json <- list(name = name, stream = stream, insecure = insecure)
body_json <- list(name = name, insecure = insecure)
req <- httr2::req_body_json(req, body_json, stream = stream)

if (!stream) {
Expand Down
29 changes: 24 additions & 5 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ stream_handler <- function(x, env, endpoint) {
#' resp_process(resp, "resp") # return input response object
#' resp_process(resp, "text") # return text/character vector
resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text")) {

if (is.null(resp) || resp$status_code != 200) {
warning("Cannot process response")
return(NULL)
}

endpoints_to_skip <- c("api/pull")
endpoints_to_skip <- c("api/delete")
for (endpoint in endpoints_to_skip) {
if (grepl(endpoint, resp$url)) {
message("Returning response object because resp_process not supported for this endpoint.")
return(resp)
}
}
Expand All @@ -105,7 +105,7 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
}

# endpoints that should never be processed with resp_process_stream
endpoints_without_stream <- c("api/tags", "api/delete", "api/show", "api/pull")
endpoints_without_stream <- c("api/tags", "api/delete", "api/show")

# process stream resp separately
stream <- FALSE
Expand Down Expand Up @@ -248,8 +248,27 @@ resp_process_stream <- function(resp, output) {
if (output[1] == "text") {
return(paste0(df_response$content, collapse = ""))
}
} else if (grepl("api/tags", resp$url)) { # process tags endpoint
return(NULL) # TODO fill in
} else if (grepl("api/pull", resp$url)) {
json_lines <- strsplit(rawToChar(resp$body), "\n")[[1]]
json_lines_output <- vector("list", length = length(json_lines))
df_response <- tibble::tibble(
status = character(length(json_lines_output)),
)

for (i in seq_along(json_lines)) {
json_lines_output[[i]] <- jsonlite::fromJSON(json_lines[[i]])
df_response$status[i] <- json_lines_output[[i]]$status
}

if (output[1] == "jsonlist") {
return(json_lines_output)
}
if (output[1] == "df") {
return(df_response)
}
if (output[1] == "text") {
return(paste0(df_response$status, collapse = ""))
}
}
}

Expand Down
36 changes: 36 additions & 0 deletions tests/testthat/test-pull.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ library(ollamar)
test_that("pull function works", {
skip_if_not(test_connection()$status_code == 200, "Ollama server not available")

# streaming is TRUE by default
# wrong model
result <- pull('WRONGMODEL')
expect_s3_class(result, "httr2_response")
Expand All @@ -15,5 +16,40 @@ test_that("pull function works", {
expect_s3_class(result, "httr2_response")
expect_equal(result$status_code, 200)
expect_vector(result$body)

expect_s3_class(result, "httr2_response")
expect_s3_class(resp_process(result), "data.frame")
expect_s3_class(resp_process(result, "df"), "data.frame")
expect_type(resp_process(result, "text"), "character")
expect_type(resp_process(result, "raw"), "character")
expect_type(resp_process(result, "jsonlist"), "list")

# streaming is FALSE
result <- pull('WRONGMODEL', stream = FALSE)
expect_s3_class(result, "httr2_response")
expect_equal(result$status_code, 200)
expect_vector(result$body)

# correct model
result <- pull('llama3', stream = FALSE)
# for this endpoint, even when stream = FALSE, the response is chunked)
expect_true(httr2::resp_headers(result)$`Transfer-Encoding` == "chunked")
expect_s3_class(result, "httr2_response")
expect_equal(result$status_code, 200)
expect_vector(result$body)

expect_s3_class(result, "httr2_response")
expect_s3_class(resp_process(result), "data.frame")
expect_s3_class(resp_process(result, "df"), "data.frame")
expect_type(resp_process(result, "text"), "character")
expect_type(resp_process(result, "raw"), "character")
expect_type(resp_process(result, "jsonlist"), "list")

# insecure parameter
expect_s3_class(pull('llama3', stream = FALSE, insecure = TRUE), "httr2_response")
expect_s3_class(pull('sdafd', stream = FALSE, insecure = FALSE), "httr2_response")
expect_s3_class(pull('sdafd', stream = TRUE, insecure = TRUE), "httr2_response")
expect_s3_class(pull('sdafd', stream = TRUE, insecure = FALSE), "httr2_response")

})

0 comments on commit 436fa6e

Please sign in to comment.