diff --git a/sqlite/GraphNet.db b/sqlite/GraphNet.db
new file mode 100644
index 000000000..cefb9086a
Binary files /dev/null and b/sqlite/GraphNet.db differ
diff --git a/sqlite/Readme.md b/sqlite/Readme.md
new file mode 100644
index 000000000..2ffe098f4
--- /dev/null
+++ b/sqlite/Readme.md
@@ -0,0 +1,14 @@
+# Work from the /GraphNet/ repository root
+
+mkdir -p sqlite/logs
+
+# Migrate (recreate) the database
+# Use the default database path
+python ./sqlite/init_db.py
+
+# Or specify a custom database path
+python ./sqlite/init_db.py --db_path sqlite/GraphNet.db
+
+
+# Insert sample data into the database
+bash ./sqlite/graphsample_insert.sh | tee "sqlite/logs/insert_$(date +'%Y-%m-%d-%H%M%S').log"
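Once the insert script has run, a quick per-type row count confirms the data landed. A minimal sketch; the table name and default path come from the migration and Readme in this patch:

import sqlite3

# Count inserted samples per sample_type (schema defined in the migration below).
conn = sqlite3.connect("sqlite/GraphNet.db")
for sample_type, n in conn.execute(
    "SELECT sample_type, COUNT(*) FROM graph_sample GROUP BY sample_type"
):
    print(f"{sample_type}: {n}")
conn.close()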
diff --git a/sqlite/graphsample_insert.py b/sqlite/graphsample_insert.py
new file mode 100644
index 000000000..faecbea95
--- /dev/null
+++ b/sqlite/graphsample_insert.py
@@ -0,0 +1,352 @@
+import json
+import argparse
+from pathlib import Path
+from datetime import datetime
+from typing import Optional
+import uuid as uuid_lib
+import re
+from orm_models import (
+    get_session,
+    GraphSample,
+    SubgraphSource,
+    DimensionGeneralizationSource,
+    DataTypeGeneralizationSource,
+)
+from sqlalchemy.exc import IntegrityError
+
+
+# graph_sample insert helpers
+def get_graph_sample_data(
+    model_path_prefix: str,
+    relative_model_path: str,
+    repo_uid: str,
+    sample_type: str,
+    order_value: int,
+) -> dict:
+    model_path = Path(model_path_prefix) / relative_model_path
+    data = {
+        "uuid": _get_uuid(),
+        "repo_uid": repo_uid,
+        "relative_model_path": relative_model_path,
+        "sample_type": sample_type,
+        "is_subgraph": _is_subgraph(sample_type),
+        "num_ops": _get_num_ops(model_path, sample_type),
+        "graph_hash": _get_graph_hash(model_path),
+        "order_value": order_value,
+        "create_at": _get_create_at(),
+        "deleted": False,
+        "delete_at": None,
+    }
+    return data
+
+
+def insert_graph_sample(db_path: str, data: dict):
+    session = get_session(db_path)
+    try:
+        graph_sample = GraphSample(**data)
+        session.add(graph_sample)
+        session.commit()
+        return graph_sample
+    except IntegrityError:
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+
+# subgraph_source insert helpers
+def insert_subgraph_source(
+    subgraph_uuid: str, model_path_prefix: str, relative_model_path: str, db_path: str
+):
+    session = get_session(db_path)
+    try:
+        parent_relative_path = get_parent_relative_path(relative_model_path)
+        if parent_relative_path is None:
+            raise ValueError(f"Cannot derive parent path from: {relative_model_path}")
+        full_graph = (
+            session.query(GraphSample)
+            .filter(
+                GraphSample.relative_model_path == parent_relative_path,
+                GraphSample.sample_type == "full_graph",
+            )
+            .first()
+        )
+
+        if not full_graph:
+            raise ValueError(f"Full graph not found for path: {parent_relative_path}")
+
+        range_info = _get_range_info(model_path_prefix, relative_model_path)
+        subgraph_source = SubgraphSource(
+            subgraph_uuid=subgraph_uuid,
+            full_graph_uuid=full_graph.uuid,
+            range_start=range_info["start"],
+            range_end=range_info["end"],
+            create_at=datetime.now(),
+            deleted=False,
+            delete_at=None,
+        )
+        session.add(subgraph_source)
+        session.commit()
+
+        return {
+            "subgraph_uuid": subgraph_source.subgraph_uuid,
+            "full_graph_uuid": subgraph_source.full_graph_uuid,
+            "range_start": subgraph_source.range_start,
+            "range_end": subgraph_source.range_end,
+        }
+    except IntegrityError:
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+
+def _get_range_info(model_path_prefix: str, relative_model_path: str):
+    model_path = Path(model_path_prefix) / relative_model_path
+    subgraph_sources_file = model_path / "subgraph_sources.json"
+    if not subgraph_sources_file.exists():
+        return {"start": -1, "end": -1}
+
+    try:
+        with open(subgraph_sources_file) as f:
+            data = json.load(f)
+        # Use the first [start, end] pair found in the file.
+        for key, ranges in data.items():
+            if isinstance(ranges, list) and ranges:
+                r = ranges[0]
+                if isinstance(r, list) and len(r) == 2:
+                    return {"start": r[0], "end": r[1]}
+        return {"start": -1, "end": -1}
+    except (json.JSONDecodeError, KeyError, TypeError, IndexError) as e:
+        print(f"Warning: Failed to parse {subgraph_sources_file}: {e}")
+        return {"start": -1, "end": -1}
+
+
+def get_parent_relative_path(relative_path: str) -> Optional[str]:
+    if "_decomposed" not in relative_path:
+        return None
+
+    parts = relative_path.split("/")
+    if len(parts) < 2:
+        return None
+
+    parent_parts = []
+    for part in parts:
+        if part == "_decomposed":
+            break
+        parent_parts.append(part)
+
+    return "/".join(parent_parts)
+
+
+# full_graph insert helpers
+def _get_uuid() -> str:
+    return uuid_lib.uuid4().hex
+
+
+def _is_subgraph(sample_type: str) -> bool:
+    # Note the trailing comma: ("full_graph") without it is a plain string,
+    # and `not in` on a string would do a substring test instead.
+    return sample_type not in ("full_graph",)
+
+
+def _get_num_ops(model_path: Path, sample_type: str):
+    if sample_type == "full_graph":
+        return -1
+    subgraph_sources_file = model_path / "subgraph_sources.json"
+    if not subgraph_sources_file.exists():
+        return -1
+
+    try:
+        with open(subgraph_sources_file) as f:
+            data = json.load(f)
+        for key, ranges in data.items():
+            if isinstance(ranges, list) and ranges:
+                r = ranges[0]
+                if isinstance(r, list) and len(r) == 2:
+                    return r[1] - r[0]
+
+        return -1
+    except (json.JSONDecodeError, KeyError, TypeError, IndexError) as e:
+        print(f"Warning: Failed to parse {subgraph_sources_file}: {e}")
+        return -1
+
+
+def _get_graph_hash(model_path: Path) -> str:
+    hash_file = model_path / "graph_hash.txt"
+    if hash_file.exists():
+        return hash_file.read_text().strip()
+    return ""
+
+
+def _get_create_at() -> datetime:
+    return datetime.now()
+
+
+# DimensionGeneralizationSource insert helpers
+def insert_dimension_generalization_source(
+    generalized_graph_uuid: str,
+    original_graph_uuid: str,
+    model_path_prefix: str,
+    relative_model_path: str,
+    db_path: str,
+):
+    session = get_session(db_path)
+    try:
+        dimension_source = DimensionGeneralizationSource(
+            generalized_graph_uuid=generalized_graph_uuid,
+            original_graph_uuid=original_graph_uuid,
+            total_element_size=_get_total_element_size(
+                model_path_prefix, relative_model_path
+            ),
+            create_at=datetime.now(),
+            deleted=False,
+            delete_at=None,
+        )
+        session.add(dimension_source)
+        session.commit()
+    except IntegrityError:
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+
+def _get_total_element_size(model_path_prefix: str, relative_model_path: str):
+    model_path = Path(model_path_prefix) / relative_model_path
+    weight_meta_file = model_path / "weight_meta.py"
+    try:
+        with open(weight_meta_file) as f:
+            content = f.read()
+
+        # Capture everything between the brackets of each `shape = [...]`.
+        shape_matches = re.findall(r"shape\s*=\s*\[([^\]]*)\]", content)
+        total_element_size = 0
+        for match in shape_matches:
+            shape_str = match.strip()
+            shape_element_size = 1
+            numbers = re.findall(r"[0-9]+(?:\.[0-9]+)?", shape_str)
+            for num_str in numbers:
+                num = float(num_str) if "." in num_str else int(num_str)
+                shape_element_size *= num
+
+            total_element_size += shape_element_size
+
+        return total_element_size
+    except Exception as e:
+        print(f"Warning: Failed to parse {weight_meta_file}: {e}")
+        return -1
+
+
+# DataTypeGeneralizationSource insert helpers
+def insert_datatype_generalization_source(
+    generalized_graph_uuid: str,
+    original_graph_uuid: str,
+    model_path_prefix: str,
+    relative_model_path: str,
+    db_path: str,
+):
+    session = get_session(db_path)
+    try:
+        data_type_source = DataTypeGeneralizationSource(
+            generalized_graph_uuid=generalized_graph_uuid,
+            original_graph_uuid=original_graph_uuid,
+            data_type=_get_data_type(model_path_prefix, relative_model_path),
+            create_at=datetime.now(),
+            deleted=False,
+            delete_at=None,
+        )
+        session.add(data_type_source)
+        session.commit()
+    except IntegrityError:
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+
+def _get_data_type(model_path_prefix: str, relative_model_path: str):
+    return "todo"
+
+
+# main func
+def main(args):
+    data = get_graph_sample_data(
+        model_path_prefix=args.model_path_prefix,
+        relative_model_path=args.relative_model_path,
+        repo_uid=args.repo_uid,
+        sample_type=args.sample_type,
+        order_value=args.order_value,
+    )
+    print(f"\ninserting into database: {args.db_path}")
+    try:
+        insert_graph_sample(args.db_path, data)
+        if data["is_subgraph"]:
+            subgraph_source_data = insert_subgraph_source(
+                data["uuid"],
+                args.model_path_prefix,
+                data["relative_model_path"],
+                args.db_path,
+            )
+            if args.sample_type in ["fusible_graph"]:
+                insert_dimension_generalization_source(
+                    subgraph_source_data["subgraph_uuid"],
+                    subgraph_source_data["full_graph_uuid"],
+                    args.model_path_prefix,
+                    args.relative_model_path,
+                    args.db_path,
+                )
+                insert_datatype_generalization_source(
+                    subgraph_source_data["subgraph_uuid"],
+                    subgraph_source_data["full_graph_uuid"],
+                    args.model_path_prefix,
+                    args.relative_model_path,
+                    args.db_path,
+                )
+        print(f"successfully inserted: {data['relative_model_path']}")
+    except IntegrityError as e:
+        print(
+            "insert failed: integrity error "
+            "(possible duplicate uuid or (relative_model_path, repo_uid))"
+        )
+        print(f"error info: {e}")
+    except Exception as e:
+        print(f"insert failed: {e}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="insert a graph sample into the database")
+    parser.add_argument(
+        "--model_path_prefix",
+        type=str,
+        required=True,
+        help="Root prefix of the model path, e.g. 'GraphNet'",
+    )
+    parser.add_argument(
+        "--relative_model_path",
+        type=str,
+        required=True,
+        help="Path to the model folder, e.g. '../../samples/torch/resnet18'",
+    )
+    parser.add_argument(
+        "--repo_uid",
+        type=str,
+        required=True,
+        help="Repository uid, e.g. 'github_torch_samples', 'github_paddle_samples'",
+    )
+    parser.add_argument(
+        "--sample_type",
+        type=str,
+        required=True,
+        help="Sample type, e.g. 'full_graph', 'fusible_graph'",
+    )
+    parser.add_argument(
+        "--order_value",
+        type=int,
+        required=True,
+        help="Order value, e.g. 1",
+    )
+    parser.add_argument(
+        "--db_path",
+        type=str,
+        required=False,
+        default="sqlite/GraphNet.db",
+        help="Database file path, e.g. 'sqlite/GraphNet.db'",
+    )
+    args = parser.parse_args()
+    main(args)
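For a single full graph, the script boils down to two calls. A minimal programmatic sketch; the sample path is illustrative, the function names are from the script above:

from graphsample_insert import get_graph_sample_data, insert_graph_sample

# Build the row dict from on-disk metadata, then persist it.
data = get_graph_sample_data(
    model_path_prefix="GraphNet",
    relative_model_path="samples/torch/resnet18",  # hypothetical sample path
    repo_uid="github_torch_samples",
    sample_type="full_graph",
    order_value=0,
)
insert_graph_sample("sqlite/GraphNet.db", data)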
diff --git a/sqlite/graphsample_insert.sh b/sqlite/graphsample_insert.sh
new file mode 100644
index 000000000..9af4664a8
--- /dev/null
+++ b/sqlite/graphsample_insert.sh
@@ -0,0 +1,89 @@
+#!/bin/bash
+set -x
+
+GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
+DB_PATH="${GRAPH_NET_ROOT}/sqlite/GraphNet.db"
+TORCH_MODEL_LIST="graph_net/config/test.txt"
+PADDLE_MODEL_LIST="graph_net/config/small10_paddle_samples_list.txt"
+TYPICAL_GRAPH_SAMPLES_LIST="tututu/range_decomposed_subgraph_sample_list.txt"
+FUSIBLE_GRAPH_SAMPLES_LIST="tututu/fusible_subgraph_sample_list.txt"
+SOLE_OP_GRAPH_SAMPLES_LIST="sole_graph/single_operator_sample_list.txt"
+ORDER_VALUE=0
+
+
+if [ ! -f "$DB_PATH" ]; then
+    echo "Error: database not found: $DB_PATH (run init_db.py first)"
+    exit 1
+fi
+
+# Insert every sample listed in $1, using the given path prefix,
+# repo uid, and sample type. ORDER_VALUE keeps counting across calls.
+insert_samples() {
+    local list_file="$1"
+    local prefix="$2"
+    local repo_uid="$3"
+    local sample_type="$4"
+
+    while IFS= read -r model_rel_path; do
+        echo "insert : $model_rel_path"
+        python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
+            --model_path_prefix "$prefix" \
+            --relative_model_path "$model_rel_path" \
+            --repo_uid "$repo_uid" \
+            --sample_type "$sample_type" \
+            --order_value "$ORDER_VALUE" \
+            --db_path "$DB_PATH"
+
+        ((ORDER_VALUE++))
+
+    done < "$list_file"
+}
+
+insert_samples "$TORCH_MODEL_LIST" "$GRAPH_NET_ROOT" "github_torch_samples" "full_graph"
+insert_samples "$PADDLE_MODEL_LIST" "$GRAPH_NET_ROOT" "github_paddle_samples" "full_graph"
+insert_samples "$TYPICAL_GRAPH_SAMPLES_LIST" "${GRAPH_NET_ROOT}/tututu/range_decompose" "github_torch_samples" "typical_graph"
+insert_samples "$FUSIBLE_GRAPH_SAMPLES_LIST" "${GRAPH_NET_ROOT}/tututu/fusible_subgraph_samples" "github_torch_samples" "fusible_graph"
+insert_samples "$SOLE_OP_GRAPH_SAMPLES_LIST" "${GRAPH_NET_ROOT}/sole_graph" "github_torch_samples" "sole_op_graph"
+
+echo "all done"
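Should the list-driven inserts ever need to run without bash, one loop of the script translates directly to Python. A sketch under the assumption that the current directory is the GraphNet root (the script above derives it from the installed graph_net package instead):

import subprocess

graph_net_root = "."  # stand-in for GRAPH_NET_ROOT
order_value = 0
with open("graph_net/config/test.txt") as f:
    for line in f:
        model_rel_path = line.strip()
        if not model_rel_path:
            continue
        subprocess.run(
            [
                "python3", f"{graph_net_root}/sqlite/graphsample_insert.py",
                "--model_path_prefix", graph_net_root,
                "--relative_model_path", model_rel_path,
                "--repo_uid", "github_torch_samples",
                "--sample_type", "full_graph",
                "--order_value", str(order_value),
                "--db_path", f"{graph_net_root}/sqlite/GraphNet.db",
            ],
            check=False,  # mirror the shell script: keep going on failures
        )
        order_value += 1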
diff --git a/sqlite/init_db.py b/sqlite/init_db.py
new file mode 100644
index 000000000..da1d0ab57
--- /dev/null
+++ b/sqlite/init_db.py
@@ -0,0 +1,65 @@
+import sqlite3
+import re
+import argparse
+from pathlib import Path
+
+
+def parse_timestamp(filename: str) -> int:
+    # Extract a sortable integer from a 'YYYY-MM-DD-HHMMSS' filename suffix.
+    match = re.search(r"(\d{4}-\d{2}-\d{2}-\d{6})", filename)
+    if match:
+        timestamp_str = match.group(1).replace("-", "")
+        return int(timestamp_str)
+    return 0
+
+
+def migrate(db_path: str = "sqlite/GraphNet.db", migrates_dir: str = "sqlite/migrates"):
+    db_path_obj = Path(db_path)
+    migrates_path = Path(migrates_dir)
+
+    if db_path_obj.exists():
+        db_path_obj.unlink()
+        print(f"Deleted existing database: {db_path}")
+
+    db_path_obj.parent.mkdir(parents=True, exist_ok=True)
+    db_path_obj.touch()
+    print(f"Created new database: {db_path}")
+
+    sql_files = list(migrates_path.glob("*.sql"))
+    if not sql_files:
+        print(f"No migration files found in {migrates_dir}")
+        return
+
+    sql_files.sort(key=lambda f: parse_timestamp(f.name))
+    print(f"Found {len(sql_files)} migration file(s)")
+    print("=" * 50)
+    for sql_file in sql_files:
+        print(f"\nExecuting: {sql_file.name}")
+        with open(sql_file, "r", encoding="utf-8") as f:
+            sql_content = f.read()
+
+        try:
+            conn = sqlite3.connect(db_path)
+            conn.executescript(sql_content)
+            conn.commit()
+            conn.close()
+            print(f"  ✓ Completed: {sql_file.name}")
+        except Exception as e:
+            print(f"  ✗ Failed: {sql_file.name}")
+            print(f"  Error: {e}")
+            # Remove the half-migrated database and stop, instead of
+            # falling through to the success message below.
+            if db_path_obj.exists():
+                db_path_obj.unlink()
+            return
+
+    print("\n" + "=" * 50)
+    print(f"Migration completed. Database: {db_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="GraphNet database migration tool")
+    parser.add_argument(
+        "--db_path",
+        type=str,
+        default="sqlite/GraphNet.db",
+        help="Database file path (default: sqlite/GraphNet.db)",
+    )
+    args = parser.parse_args()
+    migrate(args.db_path)
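When init_db.py reports success, all six tables created by the migration below should exist. A quick check, assuming the default database path:

import sqlite3

conn = sqlite3.connect("sqlite/GraphNet.db")
tables = [
    name
    for (name,) in conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )
]
conn.close()
print(tables)
# Expected: backward_graph_source, datatype_generalization_source,
# dimension_generalization_source, graph_sample, repo, subgraph_source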
diff --git a/sqlite/migrates/create_main_tables_2026-02-02-031353.sql b/sqlite/migrates/create_main_tables_2026-02-02-031353.sql
new file mode 100644
index 000000000..67cd89393
--- /dev/null
+++ b/sqlite/migrates/create_main_tables_2026-02-02-031353.sql
@@ -0,0 +1,92 @@
+-- SQLite
+-- create repo table
+CREATE TABLE IF NOT EXISTS repo (
+    repo_uid VARCHAR(255) NOT NULL PRIMARY KEY,
+    repo_type VARCHAR(50) NOT NULL,
+    repo_name VARCHAR(255) NOT NULL,
+    repo_url TEXT
+);
+INSERT OR IGNORE INTO repo (repo_uid, repo_type, repo_name, repo_url) VALUES
+('github_torch_samples', 'github', 'GraphNet', 'https://github.com/PaddlePaddle/GraphNet'),
+('github_paddle_samples', 'github', 'GraphNet', 'https://github.com/PaddlePaddle/GraphNet');
+
+
+-- create graph_sample table
+CREATE TABLE IF NOT EXISTS graph_sample (
+    uuid VARCHAR(255) NOT NULL PRIMARY KEY,
+    repo_uid VARCHAR(255) NOT NULL,
+    relative_model_path TEXT NOT NULL,
+    sample_type VARCHAR(50) NOT NULL,
+    is_subgraph BOOLEAN DEFAULT FALSE,
+    num_ops INTEGER DEFAULT -1,
+    graph_hash VARCHAR(255) NOT NULL,
+    order_value INTEGER,
+    create_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    deleted BOOLEAN DEFAULT FALSE,
+    delete_at DATETIME,
+    FOREIGN KEY (repo_uid) REFERENCES repo(repo_uid)
+);
+CREATE INDEX IF NOT EXISTS idx_relative_model_path ON graph_sample (relative_model_path);
+CREATE INDEX IF NOT EXISTS idx_graph_hash ON graph_sample (graph_hash);
+CREATE INDEX IF NOT EXISTS idx_order_value ON graph_sample (order_value);
+CREATE UNIQUE INDEX IF NOT EXISTS uq_relative_model_path_repo_uid ON graph_sample (relative_model_path, repo_uid);
+
+-- create subgraph_source table
+CREATE TABLE IF NOT EXISTS subgraph_source (
+    subgraph_uuid VARCHAR(255) NOT NULL PRIMARY KEY,
+    full_graph_uuid VARCHAR(255) NOT NULL,
+    range_start INTEGER NOT NULL,
+    range_end INTEGER NOT NULL,
+    create_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    deleted BOOLEAN DEFAULT FALSE,
+    delete_at DATETIME,
+    FOREIGN KEY (subgraph_uuid) REFERENCES graph_sample(uuid),
+    FOREIGN KEY (full_graph_uuid) REFERENCES graph_sample(uuid)
+);
+CREATE INDEX IF NOT EXISTS idx_subgraph_uuid ON subgraph_source (subgraph_uuid);
+CREATE INDEX IF NOT EXISTS idx_full_graph_uuid ON subgraph_source (full_graph_uuid);
+
+-- create dimension_generalization_source table
+CREATE TABLE IF NOT EXISTS dimension_generalization_source (
+    generalized_graph_uuid VARCHAR(255) NOT NULL PRIMARY KEY,
+    original_graph_uuid VARCHAR(255) NOT NULL,
+    total_element_size INTEGER NOT NULL,
+    create_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    deleted BOOLEAN DEFAULT FALSE,
+    delete_at DATETIME,
+    FOREIGN KEY (generalized_graph_uuid) REFERENCES graph_sample(uuid),
+    FOREIGN KEY (original_graph_uuid) REFERENCES graph_sample(uuid)
+);
+CREATE INDEX IF NOT EXISTS idx_dimension_generalized_graph_uuid ON dimension_generalization_source (generalized_graph_uuid);
+CREATE INDEX IF NOT EXISTS idx_dimension_original_graph_uuid ON dimension_generalization_source (original_graph_uuid);
+CREATE INDEX IF NOT EXISTS idx_total_element_size ON dimension_generalization_source (total_element_size);
+
+-- create datatype_generalization_source table
+CREATE TABLE IF NOT EXISTS datatype_generalization_source (
+    generalized_graph_uuid VARCHAR(255) NOT NULL PRIMARY KEY,
+    original_graph_uuid VARCHAR(255) NOT NULL,
+    data_type VARCHAR(50) NOT NULL,
+    create_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    deleted BOOLEAN DEFAULT FALSE,
+    delete_at DATETIME,
+    FOREIGN KEY (generalized_graph_uuid) REFERENCES graph_sample(uuid),
+    FOREIGN KEY (original_graph_uuid) REFERENCES graph_sample(uuid)
+);
+CREATE INDEX IF NOT EXISTS idx_datatype_generalized_graph_uuid ON datatype_generalization_source (generalized_graph_uuid);
+CREATE INDEX IF NOT EXISTS idx_datatype_original_graph_uuid ON datatype_generalization_source (original_graph_uuid);
+
+-- create backward_graph_source table
+CREATE TABLE IF NOT EXISTS backward_graph_source (
+    forward_graph_uuid VARCHAR(255) NOT NULL PRIMARY KEY,
+    backward_graph_uuid VARCHAR(255) NOT NULL,
+    original_graph_uuid VARCHAR(255) NOT NULL,
+    create_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+    deleted BOOLEAN DEFAULT FALSE,
+    delete_at DATETIME,
+    FOREIGN KEY (forward_graph_uuid) REFERENCES graph_sample(uuid),
+    FOREIGN KEY (backward_graph_uuid) REFERENCES graph_sample(uuid),
+    FOREIGN KEY (original_graph_uuid) REFERENCES graph_sample(uuid)
+);
+CREATE INDEX IF NOT EXISTS idx_forward_graph_uuid ON backward_graph_source (forward_graph_uuid);
+CREATE INDEX IF NOT EXISTS idx_backward_graph_uuid ON backward_graph_source (backward_graph_uuid);
+CREATE INDEX IF NOT EXISTS idx_backward_original_graph_uuid ON backward_graph_source (original_graph_uuid);
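Because subgraph_source references graph_sample twice, recovering a subgraph's parent path and op range takes a double join. A sketch against the schema above, assuming the default database path:

import sqlite3

conn = sqlite3.connect("sqlite/GraphNet.db")
query = """
    SELECT sub.relative_model_path, fg.relative_model_path,
           src.range_start, src.range_end
    FROM subgraph_source AS src
    JOIN graph_sample AS sub ON sub.uuid = src.subgraph_uuid
    JOIN graph_sample AS fg ON fg.uuid = src.full_graph_uuid
"""
for subgraph_path, full_graph_path, start, end in conn.execute(query):
    print(f"{subgraph_path} <- {full_graph_path} [{start}, {end}]")
conn.close()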
diff --git a/sqlite/orm_models.py b/sqlite/orm_models.py
new file mode 100644
index 000000000..a9fd5cb9a
--- /dev/null
+++ b/sqlite/orm_models.py
@@ -0,0 +1,235 @@
+from sqlalchemy.orm import declarative_base, sessionmaker, relationship
+from datetime import datetime
+from sqlalchemy import (
+    create_engine,
+    Column,
+    String,
+    Integer,
+    Boolean,
+    DateTime,
+    ForeignKey,
+    Index,
+    Text,
+    UniqueConstraint,
+)
+
+Base = declarative_base()
+
+
+class Repo(Base):
+    __tablename__ = "repo"
+
+    repo_uid = Column(String(255), primary_key=True)
+    repo_type = Column(String(50), nullable=False)
+    repo_name = Column(String(255), nullable=False)
+    # Matches the migration, where repo_url is a nullable TEXT column.
+    repo_url = Column(Text)
+    graph_samples = relationship("GraphSample", back_populates="repo")
+
+
+class GraphSample(Base):
+    __tablename__ = "graph_sample"
+
+    uuid = Column(String(255), primary_key=True)
+    repo_uid = Column(String(255), ForeignKey("repo.repo_uid"), nullable=False)
+    relative_model_path = Column(String, nullable=False)
+    sample_type = Column(String(50), nullable=False)
+    is_subgraph = Column(Boolean, default=False)
+    num_ops = Column(Integer, default=-1)
+    graph_hash = Column(String(255), nullable=False)
+    order_value = Column(Integer)
+    create_at = Column(DateTime, default=datetime.now)
+    deleted = Column(Boolean, default=False)
+    delete_at = Column(DateTime)
+
+    __table_args__ = (
+        Index("idx_relative_model_path", "relative_model_path"),
+        Index("idx_graph_hash", "graph_hash"),
+        Index("idx_order_value", "order_value"),
+        UniqueConstraint(
+            "relative_model_path", "repo_uid", name="uq_relative_model_path_repo_uid"
+        ),
+    )
+
+    repo = relationship("Repo", back_populates="graph_samples")
+    # Each source table references graph_sample twice, so every relationship
+    # below pins its join column explicitly via foreign_keys.
+    subgraph_sources = relationship(
+        "SubgraphSource",
+        foreign_keys="SubgraphSource.subgraph_uuid",
+        back_populates="subgraph",
+    )
+    subgraph_as_full_graph = relationship(
+        "SubgraphSource",
+        foreign_keys="SubgraphSource.full_graph_uuid",
+        back_populates="full_graph",
+    )
+    dimension_sources_as_generalized = relationship(
+        "DimensionGeneralizationSource",
+        foreign_keys="DimensionGeneralizationSource.generalized_graph_uuid",
+        back_populates="generalized_graph",
+    )
+    dimension_sources_as_original = relationship(
+        "DimensionGeneralizationSource",
+        foreign_keys="DimensionGeneralizationSource.original_graph_uuid",
+        back_populates="original_graph",
+    )
+    data_type_sources_as_original = relationship(
+        "DataTypeGeneralizationSource",
+        foreign_keys="DataTypeGeneralizationSource.original_graph_uuid",
+        back_populates="original_graph",
+    )
+    data_type_sources_as_generalized = relationship(
+        "DataTypeGeneralizationSource",
+        foreign_keys="DataTypeGeneralizationSource.generalized_graph_uuid",
+        back_populates="generalized_graph",
+    )
+    backward_graph_sources_as_forward = relationship(
+        "BackwardGraphSource",
+        foreign_keys="BackwardGraphSource.forward_graph_uuid",
+        back_populates="forward_graph",
+    )
+    backward_graph_as_backward = relationship(
+        "BackwardGraphSource",
+        foreign_keys="BackwardGraphSource.backward_graph_uuid",
+        back_populates="backward_graph",
+    )
+    backward_graph_as_original = relationship(
+        "BackwardGraphSource",
+        foreign_keys="BackwardGraphSource.original_graph_uuid",
+        back_populates="original_graph",
+    )
Index("idx_relative_model_path", "relative_model_path"), + Index("idx_graph_hash", "graph_hash"), + Index("idx_order_value", "order_value"), + UniqueConstraint( + "relative_model_path", "repo_uid", name="uq_relative_model_path_repo_uid" + ), + ) + + repo = relationship("Repo", back_populates="graph_samples") + subgraph_sources = relationship( + "SubgraphSource", + foreign_keys="SubgraphSource.subgraph_uuid", + back_populates="subgraph", + ) + subgraph_as_full_graph = relationship( + "SubgraphSource", + foreign_keys="SubgraphSource.full_graph_uuid", + back_populates="full_graph", + ) + dimension_sources_as_generalized = relationship( + "DimensionGeneralizationSource", + foreign_keys="DimensionGeneralizationSource.generalized_graph_uuid", + back_populates="generalized_graph", + ) + dimension_sources_as_original = relationship( + "DimensionGeneralizationSource", + foreign_keys="DimensionGeneralizationSource.original_graph_uuid", + back_populates="original_graph", + ) + data_type_sources_as_original = relationship( + "DataTypeGeneralizationSource", + foreign_keys="DataTypeGeneralizationSource.original_graph_uuid", + back_populates="original_graph", + ) + data_type_sources_as_generalized = relationship( + "DataTypeGeneralizationSource", + foreign_keys="DataTypeGeneralizationSource.generalized_graph_uuid", + back_populates="generalized_graph", + ) + backward_graph_sources_as_forward = relationship( + "BackwardGraphSource", + foreign_keys="BackwardGraphSource.forward_graph_uuid", + back_populates="forward_graph", + ) + backward_graph_as_backward = relationship( + "BackwardGraphSource", + foreign_keys="BackwardGraphSource.backward_graph_uuid", + back_populates="backward_graph", + ) + backward_graph_as_original = relationship( + "BackwardGraphSource", + foreign_keys="BackwardGraphSource.original_graph_uuid", + back_populates="original_graph", + ) + + +class SubgraphSource(Base): + __tablename__ = "subgraph_source" + + subgraph_uuid = Column( + String(255), ForeignKey("graph_sample.uuid"), nullable=False, primary_key=True + ) + full_graph_uuid = Column( + String(255), ForeignKey("graph_sample.uuid"), nullable=False + ) + range_start = Column(Integer, nullable=False) + range_end = Column(Integer, nullable=False) + create_at = Column(DateTime, default=datetime.now) + deleted = Column(Boolean, default=False) + delete_at = Column(DateTime) + + __table_args__ = ( + Index("idx_subgraph_uuid", "subgraph_uuid"), + Index("idx_full_graph_uuid", "full_graph_uuid"), + ) + + subgraph = relationship( + "GraphSample", foreign_keys=[subgraph_uuid], back_populates="subgraph_sources" + ) + full_graph = relationship( + "GraphSample", + foreign_keys=[full_graph_uuid], + back_populates="subgraph_as_full_graph", + ) + + +class DimensionGeneralizationSource(Base): + __tablename__ = "dimension_generalization_source" + + generalized_graph_uuid = Column( + String(255), ForeignKey("graph_sample.uuid"), nullable=False, primary_key=True + ) + original_graph_uuid = Column( + String(255), ForeignKey("graph_sample.uuid"), nullable=False + ) + total_element_size = Column(Integer, nullable=False) + create_at = Column(DateTime, default=datetime.now) + deleted = Column(Boolean, default=False) + delete_at = Column(DateTime) + + __table_args__ = ( + Index("idx_dimension_generalized_graph_uuid", "generalized_graph_uuid"), + Index("idx_dimension_original_graph_uuid", "original_graph_uuid"), + Index("idx_total_element_size", "total_element_size"), + ) + + generalized_graph = relationship( + "GraphSample", + 
+
+
+# Records a datatype-generalized graph together with its original graph.
+class DataTypeGeneralizationSource(Base):
+    __tablename__ = "datatype_generalization_source"
+
+    generalized_graph_uuid = Column(
+        String(255), ForeignKey("graph_sample.uuid"), nullable=False, primary_key=True
+    )
+    original_graph_uuid = Column(
+        String(255), ForeignKey("graph_sample.uuid"), nullable=False
+    )
+    data_type = Column(String(50), nullable=False)
+    create_at = Column(DateTime, default=datetime.now)
+    deleted = Column(Boolean, default=False)
+    delete_at = Column(DateTime)
+
+    __table_args__ = (
+        Index("idx_datatype_generalized_graph_uuid", "generalized_graph_uuid"),
+        Index("idx_datatype_original_graph_uuid", "original_graph_uuid"),
+    )
+
+    generalized_graph = relationship(
+        "GraphSample",
+        foreign_keys=[generalized_graph_uuid],
+        back_populates="data_type_sources_as_generalized",
+    )
+    original_graph = relationship(
+        "GraphSample",
+        foreign_keys=[original_graph_uuid],
+        back_populates="data_type_sources_as_original",
+    )
+
+
+# Links a backward graph to its forward graph and to the original graph
+# both were derived from.
+class BackwardGraphSource(Base):
+    __tablename__ = "backward_graph_source"
+
+    forward_graph_uuid = Column(
+        String(255), ForeignKey("graph_sample.uuid"), nullable=False, primary_key=True
+    )
+    backward_graph_uuid = Column(
+        String(255), ForeignKey("graph_sample.uuid"), nullable=False
+    )
+    original_graph_uuid = Column(
+        String(255), ForeignKey("graph_sample.uuid"), nullable=False
+    )
+    create_at = Column(DateTime, default=datetime.now)
+    deleted = Column(Boolean, default=False)
+    delete_at = Column(DateTime)
+
+    __table_args__ = (
+        Index("idx_forward_graph_uuid", "forward_graph_uuid"),
+        Index("idx_backward_graph_uuid", "backward_graph_uuid"),
+        Index("idx_backward_original_graph_uuid", "original_graph_uuid"),
+    )
+
+    forward_graph = relationship(
+        "GraphSample",
+        foreign_keys=[forward_graph_uuid],
+        back_populates="backward_graph_sources_as_forward",
+    )
+    backward_graph = relationship(
+        "GraphSample",
+        foreign_keys=[backward_graph_uuid],
+        back_populates="backward_graph_as_backward",
+    )
+    original_graph = relationship(
+        "GraphSample",
+        foreign_keys=[original_graph_uuid],
+        back_populates="backward_graph_as_original",
+    )
+
+
+def get_session(db_path: str, echo: bool = False):
+    # Creating a fresh engine per call is cheap enough for these one-shot
+    # CLI scripts; each caller is responsible for closing the session.
+    engine = create_engine(f"sqlite:///{db_path}", echo=echo)
+    Session = sessionmaker(bind=engine)
+    return Session()
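The ORM mirrors the SQL schema above, so queries can go through get_session directly instead of raw sqlite3. A minimal sketch, assuming the default database path from the Readme:

from orm_models import get_session, GraphSample

session = get_session("sqlite/GraphNet.db")
try:
    # Ten most recently inserted samples, newest first.
    for sample in (
        session.query(GraphSample)
        .order_by(GraphSample.create_at.desc())
        .limit(10)
    ):
        print(sample.uuid, sample.sample_type, sample.relative_model_path)
finally:
    session.close()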