fix: 🐛 Llama Index imports and track costs and token counts in the class #47

Merged · 8 commits · Jun 4, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -506,7 +506,7 @@ pip install `'tokencost[llama-index]'`
To use the base callback handler, you may import it:

```python
from tokencost.callbacks.llama_index import BaseCallbackHandler
from tokencost.callbacks.llama_index import TokenCostHandler
```

and pass it to your framework's callback handler.
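For example, with llama-index's global settings (a minimal sketch assuming llama-index >= 0.10; the model name is illustrative):

```python
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager
from tokencost.callbacks.llama_index import TokenCostHandler

# Register the handler globally so every LLM event gets costed
# (model name below is only an example)
Settings.callback_manager = CallbackManager([TokenCostHandler(model="gpt-3.5-turbo")])
```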
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ dev = [
"coverage[toml]>=7.4.0",
]
llama-index = [
"llama-index>=0.9.24"
"llama-index>=0.10.23"
]

[project.urls]
64 changes: 32 additions & 32 deletions tests/test_costs.py
@@ -131,21 +131,21 @@ def test_count_string_invalid_model():
@pytest.mark.parametrize(
"prompt,model,expected_output",
[
(MESSAGES, "gpt-3.5-turbo", Decimal('0.0000225')),
(MESSAGES, "gpt-3.5-turbo-0301", Decimal('0.0000255')),
(MESSAGES, "gpt-3.5-turbo-0613", Decimal('0.0000225')),
(MESSAGES, "gpt-3.5-turbo-16k", Decimal('0.000045')),
(MESSAGES, "gpt-3.5-turbo-16k-0613", Decimal('0.000045')),
(MESSAGES, "gpt-3.5-turbo-1106", Decimal('0.000015')),
(MESSAGES, "gpt-3.5-turbo-instruct", Decimal('0.0000225')),
(MESSAGES, "gpt-4", Decimal('0.00045')),
(MESSAGES, "gpt-4-0314", Decimal('0.00045')),
(MESSAGES, "gpt-4-32k", Decimal('0.00090')),
(MESSAGES, "gpt-4-32k-0314", Decimal('0.00090')),
(MESSAGES, "gpt-4-0613", Decimal('0.00045')),
(MESSAGES, "gpt-4-1106-preview", Decimal('0.00015')),
(MESSAGES, "gpt-4-vision-preview", Decimal('0.00015')),
(STRING, "text-embedding-ada-002", Decimal('0.0000004')),
(MESSAGES, "gpt-3.5-turbo", Decimal("0.0000225")),
(MESSAGES, "gpt-3.5-turbo-0301", Decimal("0.0000255")),
(MESSAGES, "gpt-3.5-turbo-0613", Decimal("0.0000225")),
(MESSAGES, "gpt-3.5-turbo-16k", Decimal("0.000045")),
(MESSAGES, "gpt-3.5-turbo-16k-0613", Decimal("0.000045")),
(MESSAGES, "gpt-3.5-turbo-1106", Decimal("0.000015")),
(MESSAGES, "gpt-3.5-turbo-instruct", Decimal("0.0000225")),
(MESSAGES, "gpt-4", Decimal("0.00045")),
(MESSAGES, "gpt-4-0314", Decimal("0.00045")),
(MESSAGES, "gpt-4-32k", Decimal("0.00090")),
(MESSAGES, "gpt-4-32k-0314", Decimal("0.00090")),
(MESSAGES, "gpt-4-0613", Decimal("0.00045")),
(MESSAGES, "gpt-4-1106-preview", Decimal("0.00015")),
(MESSAGES, "gpt-4-vision-preview", Decimal("0.00015")),
(STRING, "text-embedding-ada-002", Decimal("0.0000004")),
],
)
def test_calculate_prompt_cost(prompt, model, expected_output):
@@ -165,20 +165,20 @@ def test_invalid_prompt_format():
@pytest.mark.parametrize(
"prompt,model,expected_output",
[
(STRING, "gpt-3.5-turbo", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-0301", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-0613", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-16k", Decimal('0.000016')),
(STRING, "gpt-3.5-turbo-16k-0613", Decimal('0.000016')),
(STRING, "gpt-3.5-turbo-1106", Decimal('0.000008')),
(STRING, "gpt-3.5-turbo-instruct", Decimal('0.000008')),
(STRING, "gpt-4", Decimal('0.00024')),
(STRING, "gpt-4-0314", Decimal('0.00024')),
(STRING, "gpt-4-32k", Decimal('0.00048')),
(STRING, "gpt-4-32k-0314", Decimal('0.00048')),
(STRING, "gpt-4-0613", Decimal('0.00024')),
(STRING, "gpt-4-1106-preview", Decimal('0.00012')),
(STRING, "gpt-4-vision-preview", Decimal('0.00012')),
(STRING, "gpt-3.5-turbo", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-0301", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-0613", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-16k", Decimal("0.000016")),
(STRING, "gpt-3.5-turbo-16k-0613", Decimal("0.000016")),
(STRING, "gpt-3.5-turbo-1106", Decimal("0.000008")),
(STRING, "gpt-3.5-turbo-instruct", Decimal("0.000008")),
(STRING, "gpt-4", Decimal("0.00024")),
(STRING, "gpt-4-0314", Decimal("0.00024")),
(STRING, "gpt-4-32k", Decimal("0.00048")),
(STRING, "gpt-4-32k-0314", Decimal("0.00048")),
(STRING, "gpt-4-0613", Decimal("0.00024")),
(STRING, "gpt-4-1106-preview", Decimal("0.00012")),
(STRING, "gpt-4-vision-preview", Decimal("0.00012")),
(STRING, "text-embedding-ada-002", 0),
],
)
@@ -213,9 +213,9 @@ def test_calculate_invalid_input_types():
@pytest.mark.parametrize(
"num_tokens,model,token_type,expected_output",
[
(10, "gpt-3.5-turbo", 'input', Decimal('0.0000150')), # Example values
(5, "gpt-4", 'output', Decimal('0.00030')), # Example values
(10, "ai21.j2-mid-v1", 'input', Decimal('0.0001250')), # Example values
(10, "gpt-3.5-turbo", "input", Decimal("0.0000150")), # Example values
(5, "gpt-4", "output", Decimal("0.00030")), # Example values
(10, "ai21.j2-mid-v1", "input", Decimal("0.0001250")), # Example values
],
)
def test_calculate_cost_by_tokens(num_tokens, model, token_type, expected_output):
23 changes: 13 additions & 10 deletions tests/test_llama_index_callbacks.py
@@ -1,7 +1,7 @@
# test_llama_index.py
import pytest
from tokencost.callbacks import llama_index
from llama_index.callbacks.schema import CBEventType, EventPayload
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from unittest.mock import MagicMock

# Mock the calculate_prompt_cost and calculate_completion_cost functions
@@ -20,36 +20,39 @@ def __init__(self, text):
def __str__(self):
return self.text

monkeypatch.setattr('llama_index.llms.ChatMessage', MockChatMessage)
monkeypatch.setattr("llama_index.core.llms.ChatMessage", MockChatMessage)
return MockChatMessage


# Test the _calc_llm_event_cost method for prompt and completion


def test_calc_llm_event_cost_prompt_completion(capsys):
handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
payload = {
EventPayload.PROMPT: STRING,
EventPayload.COMPLETION: STRING
}
handler = llama_index.TokenCostHandler(model="gpt-3.5-turbo")
payload = {EventPayload.PROMPT: STRING, EventPayload.COMPLETION: STRING}
handler._calc_llm_event_cost(payload)
captured = capsys.readouterr()
assert "# Prompt cost: 0.0000060" in captured.out
assert "# Completion: 0.000008" in captured.out


# Test the _calc_llm_event_cost method for messages and response


def test_calc_llm_event_cost_messages_response(mock_chat_message, capsys):
handler = llama_index.TokenCostHandler(model='gpt-3.5-turbo')
handler = llama_index.TokenCostHandler(model="gpt-3.5-turbo")
payload = {
EventPayload.MESSAGES: [mock_chat_message("message 1"), mock_chat_message("message 2")],
EventPayload.RESPONSE: "test response"
EventPayload.MESSAGES: [
mock_chat_message("message 1"),
mock_chat_message("message 2"),
],
EventPayload.RESPONSE: "test response",
}
handler._calc_llm_event_cost(payload)
captured = capsys.readouterr()
assert "# Prompt cost: 0.0000105" in captured.out
assert "# Completion: 0.000004" in captured.out


# Additional tests can be written for start_trace, end_trace, on_event_start, and on_event_end
# depending on the specific logic and requirements of those methods.
1 change: 1 addition & 0 deletions tokencost/__init__.py
@@ -3,5 +3,6 @@
count_string_tokens,
calculate_completion_cost,
calculate_prompt_cost,
calculate_all_costs_and_tokens,
)
from .constants import TOKEN_COSTS_STATIC, TOKEN_COSTS, update_token_costs
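With `calculate_all_costs_and_tokens` now re-exported, the combined estimate can be pulled straight from the package root (a minimal sketch; the prompt and completion strings are illustrative):

```python
from tokencost import calculate_all_costs_and_tokens

# One call returns prompt/completion costs (Decimal, USD) and token counts
estimates = calculate_all_costs_and_tokens(
    prompt="Hello world",
    completion="How may I assist you today?",
    model="gpt-3.5-turbo",
)
print(estimates["prompt_cost"], estimates["completion_cost"])
print(estimates["prompt_tokens"], estimates["completion_tokens"])
```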
38 changes: 25 additions & 13 deletions tokencost/callbacks/llama_index.py
@@ -1,7 +1,8 @@
from typing import Any, Dict, List, Optional, cast
from llama_index.callbacks.base_handler import BaseCallbackHandler
from llama_index.callbacks.schema import CBEventType, EventPayload
from tokencost import calculate_prompt_cost, calculate_completion_cost
from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from llama_index.core.llms import ChatMessage
from tokencost import calculate_all_costs_and_tokens


class TokenCostHandler(BaseCallbackHandler):
@@ -10,6 +11,10 @@ class TokenCostHandler(BaseCallbackHandler):
def __init__(self, model) -> None:
super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
self.model = model
self.prompt_cost = 0
self.completion_cost = 0
self.prompt_tokens = 0
self.completion_tokens = 0

def start_trace(self, trace_id: Optional[str] = None) -> None:
return
@@ -22,27 +27,34 @@ def end_trace(
return

def _calc_llm_event_cost(self, payload: dict) -> None:
from llama_index.llms import ChatMessage

prompt_cost = 0
completion_cost = 0
if EventPayload.PROMPT in payload:
prompt = str(payload.get(EventPayload.PROMPT))
completion = str(payload.get(EventPayload.COMPLETION))
prompt_cost = calculate_prompt_cost(prompt, self.model)
completion_cost = calculate_completion_cost(completion, self.model)
estimates = calculate_all_costs_and_tokens(prompt, completion, self.model)

elif EventPayload.MESSAGES in payload:
messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, []))
messages_str = "\n".join([str(x) for x in messages])
prompt_cost = calculate_prompt_cost(messages_str, self.model)
response = str(payload.get(EventPayload.RESPONSE))
completion_cost = calculate_completion_cost(response, self.model)
estimates = calculate_all_costs_and_tokens(
messages_str, response, self.model
)

self.prompt_cost += estimates["prompt_cost"]
self.completion_cost += estimates["completion_cost"]
self.prompt_tokens += estimates["prompt_tokens"]
self.completion_tokens += estimates["completion_tokens"]

print(f"# Prompt cost: {prompt_cost}")
print(f"# Completion: {completion_cost}")
print(f"# Prompt cost: {estimates['prompt_cost']}")
print(f"# Completion: {estimates['completion_cost']}")
print("\n")

def reset_counts(self) -> None:
self.prompt_cost = 0
self.completion_cost = 0
self.prompt_tokens = 0
self.completion_tokens = 0

def on_event_start(
self,
event_type: CBEventType,
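Because the handler now accumulates totals on the instance, running cost and token counts can be read after a batch of calls and zeroed with `reset_counts` (a sketch; the model name and the surrounding query loop are illustrative):

```python
from tokencost.callbacks.llama_index import TokenCostHandler

handler = TokenCostHandler(model="gpt-3.5-turbo")  # illustrative model name
# ... run queries with `handler` registered in your CallbackManager ...

print(f"Prompt cost:       {handler.prompt_cost}")
print(f"Completion cost:   {handler.completion_cost}")
print(f"Prompt tokens:     {handler.prompt_tokens}")
print(f"Completion tokens: {handler.completion_tokens}")

handler.reset_counts()  # start the next run from zero
```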
54 changes: 49 additions & 5 deletions tokencost/costs.py
@@ -1,6 +1,7 @@
"""
Costs dictionary and utility tool for counting tokens
"""

import tiktoken
from typing import Union, List, Dict
from .constants import TOKEN_COSTS
@@ -57,10 +58,14 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
tokens_per_message = 4
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
logging.warning("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
logging.warning(
"gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
)
return count_message_tokens(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
logging.warning("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
logging.warning(
"gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
)
return count_message_tokens(messages, model="gpt-4-0613")
else:
raise KeyError(
@@ -118,7 +123,9 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: str) -> Decimal:
Double-check your spelling, or submit an issue/PR"""
)

cost_per_token_key = 'input_cost_per_token' if token_type == 'input' else 'output_cost_per_token'
cost_per_token_key = (
"input_cost_per_token" if token_type == "input" else "output_cost_per_token"
)
cost_per_token = TOKEN_COSTS[model][cost_per_token_key]

return Decimal(str(cost_per_token)) * Decimal(num_tokens)
@@ -164,7 +171,7 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
else count_message_tokens(prompt, model)
)

return calculate_cost_by_tokens(prompt_tokens, model, 'input')
return calculate_cost_by_tokens(prompt_tokens, model, "input")


def calculate_completion_cost(completion: str, model: str) -> Decimal:
@@ -191,4 +198,41 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
)
completion_tokens = count_string_tokens(completion, model)

return calculate_cost_by_tokens(completion_tokens, model, 'output')
return calculate_cost_by_tokens(completion_tokens, model, "output")


def calculate_all_costs_and_tokens(
prompt: Union[List[dict], str], completion: str, model: str
) -> dict:
"""
Calculate the prompt and completion costs and tokens in USD.

Args:
prompt (Union[List[dict], str]): List of message objects or single string prompt.
completion (str): Completion string.
model (str): The model name.

Returns:
dict: The calculated cost and tokens in USD.

e.g.:
>>> prompt = "Hello world"
>>> completion = "How may I assist you today?"
>>> calculate_all_costs_and_tokens(prompt, completion, "gpt-3.5-turbo")
{'prompt_cost': Decimal('0.0000030'), 'prompt_tokens': 2, 'completion_cost': Decimal('0.000014'), 'completion_tokens': 7}
"""
prompt_cost = calculate_prompt_cost(prompt, model)
completion_cost = calculate_completion_cost(completion, model)
prompt_tokens = (
count_string_tokens(prompt, model)
if isinstance(prompt, str)
else count_message_tokens(prompt, model)
)
completion_tokens = count_string_tokens(completion, model)

return {
"prompt_cost": prompt_cost,
"prompt_tokens": prompt_tokens,
"completion_cost": completion_cost,
"completion_tokens": completion_tokens,
}