-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
9 changed files
with
394 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import json | ||
import logging | ||
import re | ||
from dataclasses import dataclass | ||
|
||
import pandas as pd | ||
|
||
from eureka_ml_insights.data_utils import DFTransformBase | ||
|
||
|
||
@dataclass
class NPHARDTSPExtractAnswer(DFTransformBase):
    """DataFrame transform that turns raw model responses into parsed TSP tours.

    Each value of the source column is passed through
    ``parse_path_from_model_output`` and the result is written to the
    destination column.
    """

    # Name of the column that holds the raw model response.
    model_output_column: str
    # Name of the column that receives the parsed tour.
    model_answer_column: str

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Fill ``model_answer_column`` with the tour parsed from ``model_output_column``.

        The frame is modified in place and the same object is returned.
        """
        raw_outputs = df[self.model_output_column]
        df[self.model_answer_column] = raw_outputs.apply(parse_path_from_model_output)
        return df
|
||
|
||
def extract_final_answer(model_output):
    """Return the text inside the last ``<final_answer>...</final_answer>`` pair.

    Matching is non-greedy and spans newlines. ``None`` is returned when the
    output contains no complete tag pair.
    """
    last_payload = None
    # Walk every non-overlapping tag pair and remember only the most recent one.
    for tagged in re.finditer(r"<final_answer>(.*?)</final_answer>", model_output, flags=re.DOTALL):
        last_payload = tagged.group(1)
    return last_payload
|
||
|
||
def extract_path(final_answer):
    """Return the ``Path`` entry of a final-answer payload, or ``None``.

    The payload is first parsed as JSON after converting single quotes to
    double quotes, so both ``{"Path": "..."}`` and ``{'Path': '...'}`` are
    accepted. If the payload is not a JSON object (parsing fails, or it
    decodes to something other than a dict), a regex fallback looks for a
    quoted ``Path`` key followed by a quoted value, using either quote style.

    Parameters:
        final_answer (str): Text found between the final-answer tags.
    Returns:
        The value stored under ``Path`` (normally a string such as
        ``"0->1->2->0"``), or ``None`` when no path can be found or the
        input is not a string.
    """
    if not isinstance(final_answer, str):
        return None

    try:
        # Convert single quotes to double quotes for valid JSON parsing.
        parsed = json.loads(final_answer.replace("'", '"'))
    except json.JSONDecodeError:
        parsed = None

    if isinstance(parsed, dict):
        return parsed.get("Path")

    # Fallback: search the ORIGINAL text, so both quote styles must be accepted
    # (the prompt templates ask for the single-quoted form).
    match = re.search(r"""["']Path["']\s*:\s*["']([^"']+)["']""", final_answer)
    return match.group(1) if match else None
|
||
|
||
def parse_path_from_model_output(model_output_string):
    """Parse a model response into a comma-separated TSP tour string.

    The last ``<final_answer>`` payload is located, its ``Path`` entry is
    read, and every character other than digits and ``->`` separators is
    dropped before the nodes are converted to integers.

    Parameters:
        model_output_string: Raw model response. Non-string values are
            tolerated and treated as "no answer".
    Returns:
        str: The tour as ``"a,b,c,..."``, or the placeholder ``"0,0,0,0"``
        when no path can be extracted or parsed.
    """
    try:
        final_answer = extract_final_answer(model_output_string)
        tour_string = extract_path(final_answer) if final_answer else None

        if tour_string is None:
            return "0,0,0,0"

        # Keep only node numbers and arrows, e.g. "0 -> 1 -> 0" becomes "0->1->0".
        parts = re.findall(r"\d+|->", tour_string)
        tour = [int(node) for node in "".join(parts).split("->")]
    except (AttributeError, TypeError, ValueError) as e:
        # TypeError covers a non-string model output and a non-string "Path"
        # value (for example a JSON list); without it one bad row would abort
        # the whole DataFrame.apply call.
        logging.info("There is no valid path: %s", e)
        return "0,0,0,0"

    return ",".join(map(str, tour))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import logging | ||
|
||
from .metrics_base import Metric | ||
|
||
|
||
class NPHardTSPMetric(Metric):
    """
    A metric class for evaluating solutions to the Traveling Salesman Problem (TSP).
    A prediction is considered correct if it is a valid TSP tour, its length equals the
    ground-truth length, and it matches one of the listed optimal solutions.
    """

    def __init__(self):
        super().__init__()

    def is_valid_tsp_path(self, path, cities, distance_matrix=None):
        """
        Validates a TSP path and if valid, evaluates its length.
        Parameters:
            path (list of int): The TSP path, a list of city indices.
            cities (list of int): The list of all cities.
            distance_matrix (list of lists, optional): Matrix representing distances between cities,
                indexed by each city's position in `cities`.
        Returns:
            tuple: (bool, float): Whether the path is valid and its length (if valid). Length is None if invalid.
                The length is 0 when no (or an empty) distance matrix is given.
        """
        # Ensure the path is not empty and has the correct number of cities
        if not path or len(path) != len(cities) + 1:
            logging.info("Invalid: Path is empty or has incorrect number of nodes.")
            return False, None

        # Ensure the path starts and ends at the same city
        if path[0] != path[-1]:
            logging.info("Invalid: Path does not start and end at the same city.")
            return False, None

        # Ensure all cities are visited exactly once (except the start/end city).
        # Together with the length check above, set equality implies no repeats.
        if set(path[:-1]) != set(cities):
            logging.info("Invalid: Path does not include all cities exactly once.")
            return False, None

        # If a distance matrix is provided, calculate the path length
        path_length = 0
        if distance_matrix:
            # Map each city to its first position in `cities` once, instead of
            # calling cities.index() twice per edge.
            position = {}
            for idx, city in enumerate(cities):
                position.setdefault(city, idx)
            try:
                for current, following in zip(path, path[1:]):
                    path_length += distance_matrix[position[current]][position[following]]
            except (IndexError, KeyError, ValueError):
                logging.info("Invalid: Path contains cities not in the provided distance matrix.")
                return False, None

        return True, path_length

    def is_tour_present(self, optimal_tour_curr, tour_string):
        """
        Checks whether a predicted tour is one of the listed optimal tours.
        Parameters:
            optimal_tour_curr (str): Python-literal text of a sequence of tours, each a sequence of
                city indices without the closing return to the start; a trailing "." is tolerated.
            tour_string (str): Predicted tour as comma-separated city indices, including the
                closing return to the start city.
        Returns:
            bool: True if `tour_string` equals one of the optimal tours once closed.
        """
        # Local import keeps this block self-contained.
        import ast

        # SECURITY: the original code called eval() on this dataset-supplied text.
        # ast.literal_eval accepts only Python literals, so it cannot execute code.
        optimal_tour_list = ast.literal_eval(optimal_tour_curr.strip().strip(".").strip())

        # Close each tour by appending its first city; works for tuples and lists alike.
        optimal_tour_strings = {",".join(map(str, (*tour, tour[0]))) for tour in optimal_tour_list}

        # Check if tour_string is in the set of optimal tour strings
        return tour_string in optimal_tour_strings

    def __evaluate__(self, x):
        """
        Evaluates whether the model's output is a correct TSP tour.
        Parameters:
            x: Row-like mapping with the keys "is_valid", "optimal_tour", "weight_matrix",
                "ground_truth" and "model_output".
        Returns:
            str: "none" when the row is flagged invalid, otherwise "correct" or "incorrect".
        """
        if not x["is_valid"]:
            return "none"

        optimal_tour_curr = x["optimal_tour"]
        weight_matrix_curr = x["weight_matrix"]
        ground_truth_curr = x["ground_truth"]
        tour_string = x["model_output"]

        # Convert tour string into a list of integers representing the city sequence.
        # An output that cannot be parsed is scored as a wrong answer rather than
        # aborting the whole evaluation.
        try:
            tour = [int(node) for node in tour_string.split(",")]
        except (AttributeError, ValueError) as e:
            logging.info("Invalid: Model output is not a comma-separated list of integers: %s", e)
            return "incorrect"
        cities = list(range(len(weight_matrix_curr)))

        # Validate the TSP tour and compute its length
        is_tsp_path_valid, total_tsp_path_length = self.is_valid_tsp_path(tour, cities, weight_matrix_curr)

        # The prediction is incorrect if the tour is invalid or the length is incorrect.
        # NOTE(review): this is an exact comparison; it assumes ground_truth has the same
        # numeric type as the matrix entries (not a string) - confirm against the dataset.
        if not is_tsp_path_valid or total_tsp_path_length != ground_truth_curr:
            return "incorrect"

        # Check if the predicted tour is one of the optimal solutions
        if not self.is_tour_present(optimal_tour_curr, tour_string):
            return "incorrect"

        return "correct"
8 changes: 8 additions & 0 deletions
8
eureka_ml_insights/prompt_templates/nphard_tsp_templates/Template_tsp_cot.jinja
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
The traveling salesman problem (TSP) is a classic optimization problem that aims to find the shortest possible route that visits a set of cities, with each city being visited exactly once and the route returning to the original city. | ||
|
||
You must find the shortest path that visits all cities. The distances between each pair of cities are provided. | ||
Please list each city in the order they are visited. Provide the total distance of the trip. | ||
Reflect on what the problem is asking. Think step by step and explain your reasoning in detail. | ||
The final output of the result path and total distance wrapped by final_answer tag, like <final_answer>{'Path': '0->1->2->...->N->0', 'TotalDistance': 'INT_TOTAL_DISTANCE'}</final_answer>. | ||
|
||
{{prompt}} |
7 changes: 7 additions & 0 deletions
7
eureka_ml_insights/prompt_templates/nphard_tsp_templates/Template_tsp_o1.jinja
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
The traveling salesman problem (TSP) is a classic optimization problem that aims to find the shortest possible route that visits a set of cities, with each city being visited exactly once and the route returning to the original city. | ||
|
||
You must find the shortest path that visits all cities. The distances between each pair of cities are provided. | ||
Please list each city in the order they are visited. Provide the total distance of the trip. | ||
The final output of the result path and total distance wrapped by final_answer tag, like <final_answer>{'Path': '0->1->2->...->N->0', 'TotalDistance': 'INT_TOTAL_DISTANCE'}</final_answer> | ||
|
||
{{prompt}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import os | ||
from typing import Any | ||
|
||
from eureka_ml_insights.configs import ( | ||
AggregatorConfig, | ||
DataProcessingConfig, | ||
DataSetConfig, | ||
EvalReportingConfig, | ||
ExperimentConfig, | ||
InferenceConfig, | ||
MetricConfig, | ||
ModelConfig, | ||
PipelineConfig, | ||
PromptProcessingConfig, | ||
) | ||
from eureka_ml_insights.core import ( | ||
DataProcessing, | ||
EvalReporting, | ||
Inference, | ||
PromptProcessing, | ||
) | ||
from eureka_ml_insights.data_utils import ( | ||
AddColumn, | ||
ColumnRename, | ||
DataReader, | ||
HFDataReader, | ||
MajorityVoteTransform, | ||
MMDataLoader, | ||
MultiplyTransform, | ||
SequenceTransform, | ||
) | ||
from eureka_ml_insights.data_utils.nphard_tsp_utils import ( | ||
NPHARDTSPExtractAnswer, | ||
) | ||
from eureka_ml_insights.metrics import CountAggregator, NPHardTSPMetric | ||
|
||
"""This file contains user-defined configuration classes for the Traveling Salesman Problem (TSP)."""
|
||
|
||
class NPHARD_TSP_PIPELINE(ExperimentConfig):
    """Experiment configuration for the NP-hard TSP benchmark."""

    def configure_pipeline(
        self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any]
    ) -> PipelineConfig:
        """Build the five pipeline stages and return them as a ``PipelineConfig``.

        Stages, in order: prompt creation, inference, answer extraction,
        evaluation of the extracted answers, and evaluation after a majority
        vote. Each stage's config is also stored on ``self``. ``kwargs`` is
        accepted but not used here.
        """
        transformed_file = "transformed_data.jsonl"

        # Stage 1: read the dataset, rename its columns, and render the prompt template.
        hf_reader_args = {
            "path": "GeoMeterData/nphard_tsp2",
            "split": "train",
            "transform": SequenceTransform(
                [
                    ColumnRename(name_mapping={"query_text": "prompt", "target_text": "ground_truth"}),
                ]
            ),
        }
        template_path = os.path.join(
            os.path.dirname(__file__), "../prompt_templates/nphard_tsp_templates/Template_tsp_o1.jinja"
        )
        self.data_processing_comp = PromptProcessingConfig(
            component_type=PromptProcessing,
            data_reader_config=DataSetConfig(HFDataReader, hf_reader_args),
            prompt_template_path=template_path,
            output_dir=os.path.join(self.log_dir, "data_processing_output"),
        )

        # Stage 2: run the model over the prompts written by stage 1.
        prompts_path = os.path.join(self.data_processing_comp.output_dir, transformed_file)
        self.inference_comp = InferenceConfig(
            component_type=Inference,
            model_config=model_config,
            data_loader_config=DataSetConfig(MMDataLoader, {"path": prompts_path}),
            output_dir=os.path.join(self.log_dir, "inference_result"),
            resume_from=resume_from,
            max_concurrent=1,
        )

        # Stage 3: keep the raw response as "raw_output" and store the parsed
        # tour in a fresh "model_output" column.
        extraction_steps = SequenceTransform(
            [
                ColumnRename(name_mapping={"model_output": "raw_output"}),
                AddColumn("model_output"),
                NPHARDTSPExtractAnswer("raw_output", "model_output"),
            ]
        )
        inference_reader_args = {
            "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"),
            "format": ".jsonl",
            "transform": extraction_steps,
        }
        self.data_post_processing = DataProcessingConfig(
            component_type=DataProcessing,
            data_reader_config=DataSetConfig(DataReader, inference_reader_args),
            output_dir=os.path.join(self.log_dir, "data_post_processing_output"),
        )

        # Stage 4: score every extracted answer and report normalized counts.
        per_run_reader_args = {
            "path": os.path.join(self.data_post_processing.output_dir, transformed_file),
            "format": ".jsonl",
        }
        self.evalreporting_comp = EvalReportingConfig(
            component_type=EvalReporting,
            data_reader_config=DataSetConfig(DataReader, per_run_reader_args),
            metric_config=MetricConfig(NPHardTSPMetric),
            aggregator_configs=[
                AggregatorConfig(CountAggregator, {"column_names": ["NPHardTSPMetric_result"], "normalize": True}),
            ],
            output_dir=os.path.join(self.log_dir, "eval_report"),
        )

        # Stage 5: take a majority vote per data point, make the voted answer
        # the new "model_output", and score it the same way.
        vote_steps = SequenceTransform(
            [
                MajorityVoteTransform(id_col="data_point_id"),
                ColumnRename(
                    name_mapping={
                        "model_output": "model_output_onerun",
                        "majority_vote": "model_output",
                    }
                ),
            ]
        )
        voted_reader_args = {
            "path": os.path.join(self.data_post_processing.output_dir, transformed_file),
            "format": ".jsonl",
            "transform": vote_steps,
        }
        self.postevalprocess_comp = EvalReportingConfig(
            component_type=EvalReporting,
            data_reader_config=DataSetConfig(DataReader, voted_reader_args),
            metric_config=MetricConfig(NPHardTSPMetric),
            aggregator_configs=[
                AggregatorConfig(CountAggregator, {"column_names": ["NPHardTSPMetric_result"], "normalize": True}),
            ],
            output_dir=os.path.join(self.log_dir, "eval_report_majorityVote"),
        )

        # Hand the stages to the pipeline in execution order.
        stages = [
            self.data_processing_comp,
            self.inference_comp,
            self.data_post_processing,
            self.evalreporting_comp,
            self.postevalprocess_comp,
        ]
        return PipelineConfig(stages, self.log_dir)
|
||
|
||
class NPHARD_TSP_PIPELINE_MULTIPLE_RUNS(NPHARD_TSP_PIPELINE):
    """Config for running the TSP benchmark with repeated copies of every data point."""

    def configure_pipeline(
        self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any]
    ) -> PipelineConfig:
        """Build the parent pipeline, then append a ``MultiplyTransform`` to the dataset reader's transforms.

        ``kwargs`` is accepted but not forwarded to the parent.
        """
        pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from)

        # The reader's transform sequence is extended in place after the
        # pipeline has been assembled.
        # NOTE(review): n_repeats is hard-coded to 1 - confirm the intended number of repeats.
        reader_transform = self.data_processing_comp.data_reader_config.init_args["transform"]
        reader_transform.transforms.append(MultiplyTransform(n_repeats=1))
        return pipeline
Oops, something went wrong.