import datetime
import threading
import uuid
from typing import Any, Literal
from langchain.agents.middleware import AgentState, AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from .energy_estimation_model import estimate_energy_and_emissions
from .models import EnergyDataPoint, EnergyGroupSummary
class EnergyMiddleware(AgentMiddleware):
    """
    Middleware that tracks token usage, estimates energy consumption,
    and calculates associated CO2 emissions for agent model calls.

    Recorded datapoints and the prompt bookkeeping are guarded by a single
    lock. Nested agent calls are supported through a prompt ID stack: a
    nested agent re-uses the prompt ID of the agent that is already running,
    so all of its model calls are attributed to the same top-level prompt.

    Note:
        The prompt ID stack is stored on the middleware instance, not per
        thread or per run. Agents that run concurrently on the same instance
        therefore share one stack.
    """

    # Maximum number of characters of the model response kept on a datapoint.
    _MESSAGE_PREVIEW_CHARS = 100

    def __init__(self):
        super().__init__()
        self.datapoints: list[EnergyDataPoint] = []
        # Guards datapoints, _prompt_id_stack and _prompt_order.
        # threading.Lock is not re-entrant: never call a method that takes
        # the lock while already holding it.
        self._lock = threading.Lock()
        # One entry per currently running (possibly nested) agent.
        self._prompt_id_stack: list[str] = []
        # IDs of top-level prompts in the order they were first seen.
        self._prompt_order: list[str] = []

    @property
    def _current_prompt_id(self) -> str | None:
        """
        Get the current prompt ID from the stack.

        Returns:
            str | None: The current prompt ID, or None if the stack is empty.
        """
        with self._lock:
            return self._prompt_id_stack[-1] if self._prompt_id_stack else None

    def _filter_datapoints(
        self,
        last_n_prompts: int | None,
        last_n_hours: float | None,
    ) -> list[EnergyDataPoint]:
        """
        Filter datapoints based on recent prompts and/or time.

        Both filters are applied when both are given.

        Args:
            last_n_prompts (int | None): If set, only include datapoints from the
                last N top-level prompts. Zero or a negative value selects no
                prompts and therefore yields an empty list.
            last_n_hours (float | None): If set, only include datapoints from the
                last N hours.

        Returns:
            list[EnergyDataPoint]: A filtered list of datapoints based on the
            provided criteria.
        """
        with self._lock:
            points = self.datapoints.copy()
            if last_n_prompts is None:
                recent_ids = None
            elif last_n_prompts <= 0:
                # A plain slice would be wrong here: [-0:] is the whole list
                # and a negative N would slice from the front.
                recent_ids = set()
            else:
                recent_ids = set(self._prompt_order[-last_n_prompts:])

        if recent_ids is not None:
            points = [dp for dp in points if dp.prompt_id in recent_ids]

        if last_n_hours is not None:
            # Naive local time, matching the timestamps written in after_model.
            cutoff = datetime.datetime.now() - datetime.timedelta(hours=last_n_hours)
            points = [dp for dp in points if dp.timestamp >= cutoff]

        return points

    def _group_datapoints(self, points: list[EnergyDataPoint], key: str) -> list[EnergyGroupSummary]:
        """
        Group datapoints by a specified key and aggregate energy/token usage.

        Args:
            points (list[EnergyDataPoint]): List of datapoints to group.
            key (str): Attribute to group by ("model_name" or "agent_name").

        Returns:
            list[EnergyGroupSummary]: One aggregated summary per distinct value
            of ``key``, in order of first appearance.
        """
        buckets: dict[str, EnergyGroupSummary] = {}
        for dp in points:
            name = getattr(dp, key)
            summary = buckets.get(name)
            if summary is None:
                summary = EnergyGroupSummary(
                    name=name,
                    total_energy_joule=0.0,
                    total_co2e_kg=0.0,
                    total_input_tokens=0,
                    total_output_tokens=0,
                    datapoint_count=0,
                )
                buckets[name] = summary
            summary.total_energy_joule += dp.estimated_energy_joule
            summary.total_co2e_kg += dp.estimated_co2e_kg
            summary.total_input_tokens += dp.input_token_count
            summary.total_output_tokens += dp.output_token_count
            summary.datapoint_count += 1
        return list(buckets.values())

    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """
        Hook executed before an agent runs.

        Ensures that a prompt ID is assigned and propagated through
        nested agent calls: a nested agent inherits the ID on top of the
        stack, while a top-level agent gets a fresh ID that is also recorded
        in the prompt order.

        Args:
            state (AgentState): Current agent state.
            runtime (Runtime): Runtime context.

        Returns:
            dict[str, Any] | None: Optional state updates (unused).
        """
        # Read-and-push happens in one critical section so the stack cannot
        # change between inspecting its top and appending to it.
        with self._lock:
            if self._prompt_id_stack:
                self._prompt_id_stack.append(self._prompt_id_stack[-1])
            else:
                new_id = str(uuid.uuid4())
                self._prompt_id_stack.append(new_id)
                self._prompt_order.append(new_id)
        return None

    def after_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """
        Hook executed after an agent finishes.

        Removes the latest prompt ID from the stack.

        Args:
            state (AgentState): Current agent state.
            runtime (Runtime): Runtime context.

        Returns:
            dict[str, Any] | None: Optional state updates (unused).
        """
        with self._lock:
            if self._prompt_id_stack:
                self._prompt_id_stack.pop()
        return None

    def after_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """
        Hook executed after a model generates a response.

        Extracts token usage, estimates energy and emissions, and
        records a datapoint for a model call. Responses whose content is
        None or blank are not recorded.

        Args:
            state (AgentState): Current agent state containing messages.
            runtime (Runtime): Runtime context.

        Returns:
            dict[str, Any] | None: Optional state updates (unused).
        """
        messages = state["messages"]
        if not messages:
            return None
        last_message = messages[-1]

        content = last_message.content
        if content is None or not str(content).strip():
            return None

        # Either metadata mapping may be absent or None; treat that as
        # "no information" instead of failing the agent run.
        response_metadata = getattr(last_message, "response_metadata", None) or {}
        usage_metadata = getattr(last_message, "usage_metadata", None) or {}
        model_name = response_metadata.get("model_name", "unknown_model")
        input_token_count = usage_metadata.get("input_tokens", 0)
        output_token_count = usage_metadata.get("output_tokens", 0)

        config = get_config()
        metadata = (config or {}).get("metadata") or {}
        agent_name = metadata.get("lc_agent_name", "unknown_agent")

        energy, co2e = estimate_energy_and_emissions(input_token_count, output_token_count, model_name)

        output_datapoint = EnergyDataPoint(
            input_token_count=input_token_count,
            output_token_count=output_token_count,
            estimated_energy_joule=energy,
            estimated_co2e_kg=co2e,
            model_name=model_name,
            timestamp=datetime.datetime.now(),
            message=str(content)[: self._MESSAGE_PREVIEW_CHARS],
            prompt_id=self._current_prompt_id,
            agent_name=agent_name,
        )
        with self._lock:
            self.datapoints.append(output_datapoint)
        return None

    # ── Raw ────────────────────────────────────────────────────────────

    def get_report(self) -> list[EnergyDataPoint]:
        """
        Retrieve a copy of all collected datapoints.

        Returns:
            list[EnergyDataPoint]: A shallow copy of the recorded datapoints.
        """
        with self._lock:
            return self.datapoints.copy()

    def get_prompt_count(self) -> int:
        """
        Number of top-level prompts seen so far.

        Returns:
            int: Count of unique top-level prompts.
        """
        with self._lock:
            return len(self._prompt_order)

    # ── Totals ────────────────────────────────────────────────────

    def get_totals(self) -> dict[str, float | int]:
        """
        Return all totals accumulated over the collected datapoints.

        Returns:
            dict[str, float | int]: A dictionary with the keys "energy"
            (total energy in J), "co2" (total CO2 in kg), "input_tokens"
            and "output_tokens".
        """
        total_energy = 0.0
        total_co2 = 0.0
        total_input_tokens = 0
        total_output_tokens = 0
        with self._lock:
            for dp in self.datapoints:
                total_energy += dp.estimated_energy_joule
                total_co2 += dp.estimated_co2e_kg
                total_input_tokens += dp.input_token_count
                total_output_tokens += dp.output_token_count
        return {
            "energy": total_energy,
            "co2": total_co2,
            "input_tokens": total_input_tokens,
            "output_tokens": total_output_tokens,
        }

    def get_total_energy(self) -> float:
        """
        Get the total estimated energy consumption of all collected datapoints in Joules.

        Returns:
            float: Total estimated energy in Joules.
        """
        return self.get_totals()["energy"]

    def get_total_co2(self) -> float:
        """
        Get the total estimated CO2 emissions of all collected datapoints in kilograms.

        Returns:
            float: Total estimated CO2 emissions in kilograms.
        """
        return self.get_totals()["co2"]

    def get_total_output_tokens(self) -> int:
        """
        Get the total number of output tokens across all collected datapoints.

        Returns:
            int: Total number of output tokens.
        """
        return self.get_totals()["output_tokens"]

    # ── Grouped / filtered Summary ───────────────────────────────────────────────────────

    def get_summary(
        self,
        group_by: Literal["model_name", "agent_name"],
        last_n_prompts: int | None = None,
        last_n_hours: float | None = None,
    ) -> list[EnergyGroupSummary]:
        """
        Return aggregated energy/token summaries grouped by model or agent.

        Args:
            group_by: "model_name" or "agent_name".
            last_n_prompts: If set, only include data from the last N top-level
                prompts. Zero or a negative value yields an empty result.
            last_n_hours: If set, only include data from the last N hours.

        Returns:
            list[EnergyGroupSummary]: One summary per distinct group value among
            the datapoints that pass the filters.
        """
        points = self._filter_datapoints(last_n_prompts, last_n_hours)
        return self._group_datapoints(points, group_by)