Skip to content

Commit

Permalink
feat: changes the way arguments are collected and fixes system prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
Npahlfer committed Dec 27, 2023
1 parent 217274d commit 13d8a15
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ooo"
version = "0.1.0"
version = "0.1.1"
edition = "2021"

[dependencies]
Expand Down
74 changes: 33 additions & 41 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,6 @@ enum PromptResponse {
Error(String),
}

impl PromptResponse {
fn as_str(&self) -> &str {
match self {
PromptResponse::Response(res) => res.as_str(),
PromptResponse::Error(res) => res.as_str(),
}
}
}

impl std::fmt::Display for PromptResponse {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand All @@ -48,7 +39,7 @@ impl std::fmt::Display for PromptResponse {
}

#[tokio::main]
async fn main() -> Result<(), io::Error> {
async fn main() -> Result<(), String> {
let Arguments {
system,
user,
Expand All @@ -74,18 +65,26 @@ async fn main() -> Result<(), io::Error> {
system
};

let stdin_input = format!(". Input: {}", stdin_text);
let prompt = format!("{} {}.\n{} {}{}", SYSTEM_PROMPT, USER_PROMPT, system, user, stdin_input);
let res = prompt_ollama(prompt, &ollama, model.to_string()).await;
let stdin_input = if stdin_text.is_empty() {
format!(". Input: {}", stdin_text)
} else {
String::new()
};

if let PromptResponse::Error(res) = res {
println!("error prompting ollama with model {}", model);
return Err(io::Error::new(io::ErrorKind::Other, res));
}
let prompt = format!(
"{} {}.\n{} {}{}",
SYSTEM_PROMPT, system, USER_PROMPT, user, stdin_input
);

output_to_stdout(res.as_str());
let res = prompt_ollama(prompt, &ollama, model.to_string()).await;

Ok(())
match res {
PromptResponse::Error(_) => Err("unable to request Ollama, is Ollama running?".to_string()),
PromptResponse::Response(res) => {
output_to_stdout(res.as_str());
Ok(())
}
}
}

fn read_stdin_lines() -> String {
Expand All @@ -106,22 +105,6 @@ fn output_to_stdout(output: &str) {
stdout.flush().unwrap();
}

fn continue_gathering_args(
args: &mut std::iter::Skip<std::env::Args>,
) -> bool {
if let Some(arg) = args.next() {
if arg == SYSTEM_FLAG
|| arg == USER_FLAG
|| arg == MODEL_FLAG
|| arg == URL_FLAG
|| arg == PORT_FLAG
{
return true;
}
}
false
}

fn get_parsed_arguments() -> Arguments {
let mut args = std::env::args().skip(1);
let mut system = Vec::new();
Expand All @@ -130,22 +113,32 @@ fn get_parsed_arguments() -> Arguments {
let mut url = None;
let mut port = None;

let mut last_active_flag = None;

while let Some(arg) = args.next() {
match arg.as_str() {
SYSTEM_FLAG => {
while continue_gathering_args(&mut args) {
system.push(args.next().unwrap_or_default())
}
system.push(args.next().unwrap());
last_active_flag = Some(SYSTEM_FLAG);
}
USER_FLAG => {
user.push(args.next().unwrap());
last_active_flag = Some(USER_FLAG);
}
USER_FLAG => user.push(args.next().unwrap_or_default()),
MODEL_FLAG => model = Some(args.next().unwrap_or_default()),
URL_FLAG => url = Some(args.next().unwrap_or_default()),
PORT_FLAG => {
if let Some(port_str) = args.next() {
port = port_str.parse::<u16>().ok();
}
}
_ => user.push(arg),
_ => {
if last_active_flag == Some(SYSTEM_FLAG) {
system.push(arg)
} else if last_active_flag == Some(USER_FLAG) {
user.push(arg)
}
}
}
}

Expand Down Expand Up @@ -184,4 +177,3 @@ fn get_default_system_prompt() -> String {
]
.join(" ")
}

0 comments on commit 13d8a15

Please sign in to comment.