Skip to content

Commit

Permalink
remove method from process_distance function to class
Browse files Browse the repository at this point in the history
  • Loading branch information
chraibi committed Oct 2, 2024
1 parent 62ebf30 commit 7fc1b04
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions src/scripts/proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple

from datetime import datetime
import numpy as np
import pandas as pd
import pedpy
from joblib import Parallel, delayed
from tqdm import tqdm
from enum import Enum

parent_dir = str(Path(__file__).resolve().parent.parent)
print(parent_dir)
print("parent dir", parent_dir)
if parent_dir not in sys.path:
print(parent_dir)
sys.path.append(parent_dir)
Expand All @@ -34,12 +35,23 @@
# print(f"distances {len(path_distances)}")


class Method(Enum):
    """Select the algorithm used to compute neighbour distances.

    ``VECT`` and ``MERGE`` are both Euclidean; ``ARC`` is the third option.
    """

    VECT = "vect"
    ARC = "arc"
    MERGE = "merge"


@dataclass
class InitData:
    """Bundle the settings built by ``init()`` for one analysis run."""

    countries: List[str]  # country paths; the final path component keys `files`
    files: Dict[str, List[str]]  # CSV file paths per country key
    result_csv: Path  # location of the results CSV
    fps: int  # frame step: distances are calculated every fps-th frame
    method: Method  # distance-calculation method (was annotated `str`)


def load_file(file: str) -> pedpy.TrajectoryData:
Expand Down Expand Up @@ -74,7 +86,6 @@ def filter_frames(data: pd.DataFrame, nagents: int) -> pd.DataFrame:
With this method we basically skip all these frames and keep
only those with <nagents> agents
"""

# Group data by frame and count the unique IDs in each frame
frame_counts = data.groupby("frame")["id"].nunique()

Expand Down Expand Up @@ -314,7 +325,7 @@ def init_gender_code(filename: str) -> str:
return name


def calculate_proximity_analysis(country: str, filename: str, data: pd.DataFrame, fps: int = 25) -> List[Dict[str, Any]]:
def calculate_proximity_analysis(country: str, filename: str, data: pd.DataFrame, fps: int = 25, method: Method = Method.VECT) -> List[Dict[str, Any]]:
"""
Perform proximity analysis on given data of agents.
Expand All @@ -337,33 +348,32 @@ def calculate_proximity_analysis(country: str, filename: str, data: pd.DataFrame
'gender_of_next_neighbor', 'gender_of_prev_neighbor', 'distance_to_next_neighbor',
and 'distance_to_prev_neighbor' columns.
"""

pp = Path(filename).parts
country_short = Path(pp[-2]).name
if country_short != "pal":
# print(f"WARNINNG in calculate_proximity_analysis(): method={method} is hard-coded")
# merge and vect are basically the same, just testing which implementation is faster
# arc uses a different calculation of distances based on arc.
method = "arc"
if method == "merge":
if method == method.MERGE:
processed_data = compute_distance_merge(data)
elif method == "vect":
elif method == method.VECT:
nagents = extract_first_number(filename)
cleaned_data = filter_frames(data, nagents)
processed_data = calculate_circular_distance_and_gender_vect(cleaned_data)
else: # arc
elif method == method.ARC:
nagents = extract_first_number(filename)
cleaned_data = filter_frames(data, nagents)
processed_data = calculate_circular_distance_and_gender_arc(cleaned_data)
fps = 25
else:
print(f"Method {method} is not recognised.")
else:
processed_data = calculate_circular_distance_and_gender_pal(data)
fps = 1

proximity_analysis_res = []
if processed_data.empty:
print("========")
print(filename)
print("Processed_data empty", filename)
print(processed_data)
print("========")
frames_to_include = set(range(0, processed_data["frame"].max(), fps))
Expand Down Expand Up @@ -420,11 +430,11 @@ def unpack_and_process(args: Any) -> List[Dict[str, Any]]:
return calculate_proximity_analysis(*args)


def prepare_data(
    country: str, selected_file: str, fps: int, method: Method
) -> Tuple[str, str, pd.DataFrame, int, Method]:
    """Load one trajectory file and bundle it into a task tuple.

    Args:
        country: Country label, passed through unchanged.
        selected_file: Path of the trajectory file, read via ``load_file``.
        fps: Frame step, passed through unchanged.
        method: Distance-calculation method, passed through unchanged.

    Returns:
        ``(country, selected_file, data, fps, method)`` where ``data`` is the
        ``.data`` attribute of the loaded trajectory. The order matches the
        positional parameters of ``calculate_proximity_analysis`` so the
        tuple can be unpacked straight into it.
    """
    trajectory_data = load_file(selected_file)
    return country, selected_file, trajectory_data.data, fps, method


def calculate_with_joblib(init_data: InitData) -> pd.DataFrame:
Expand All @@ -435,7 +445,7 @@ def calculate_with_joblib(init_data: InitData) -> pd.DataFrame:
print(f"prepare tasks: {country} with {key}")

for filename in init_data.files[key]:
tasks.append(prepare_data(country, filename, init_data.fps))
tasks.append(prepare_data(country, filename, init_data.fps, init_data.method))

# Define a function to be executed in parallel
def process_task(task: List[Any]) -> List[Dict[str, Any]]:
Expand All @@ -456,13 +466,15 @@ def process_task(task: List[Any]) -> List[Dict[str, Any]]:

def init() -> InitData:
"""
Initializes the application by setting up the countries, file paths, result CSV path, and FPS (frames per second).
Initialize the application by setting up the countries, file paths, result CSV path, and FPS (frames per second).
Returns:
InitData: A data class containing initialized data including countries, files dictionary, result CSV path, and FPS.
"""
print("Enter Init")
result_csv = Path(f"{parent_dir}/app_data/proximity_analysis_results_arc.csv")
method = Method.VECT
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
result_csv = Path(f"{parent_dir}/app_data/proximity_analysis_results_{method}_{timestamp}.csv")
result_csv.parent.mkdir(parents=True, exist_ok=True)
print(f"Created file: {result_csv}")
fps = 25 # For distance calculations, calculate every fps-frame
Expand All @@ -478,7 +490,7 @@ def init() -> InitData:
key = Path(country).name
files[key] = [str(path) for path in Path(country).glob("*.csv")]

return InitData(countries=countries, files=files, result_csv=result_csv, fps=fps)
return InitData(countries=countries, files=files, result_csv=result_csv, fps=fps, method=method)


if __name__ == "__main__":
Expand Down

0 comments on commit 7fc1b04

Please sign in to comment.