From 6691f655cb6cd78e5212258881b34437163e462a Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 2 Jan 2026 00:23:13 -0600 Subject: [PATCH 1/3] Implement semantic matching for joins based on attribute lineage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add lineage tracking via ~lineage table per schema - Track attribute origin (schema.table.attribute) for FK and PK attributes - Semantic check on joins/restrictions: error if namesakes have different lineage - Add Schema.rebuild_lineage() to restore lineage for legacy schemas - Add Schema.lineage_table_exists property - Remove @ and ^ operators (use .join/.restrict with semantic_check=False) - Remove dj.U * table pattern (use dj.U & table instead) - Warn when parent lineage missing during table declaration - Skip semantic check with warning if ~lineage table doesn't exist - Add comprehensive spec with API reference and user guide 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- docs/src/design/semantic-matching-spec.md | 525 +++++++++++++++++++ src/datajoint/condition.py | 79 ++- src/datajoint/declare.py | 19 +- src/datajoint/expression.py | 91 ++-- src/datajoint/heading.py | 34 +- src/datajoint/lineage.py | 262 +++++++++ src/datajoint/schemas.py | 29 + src/datajoint/table.py | 75 ++- tests/integration/test_aggr_regressions.py | 3 +- tests/integration/test_relation_u.py | 36 +- tests/integration/test_relational_operand.py | 23 +- tests/integration/test_semantic_matching.py | 342 ++++++++++++ 12 files changed, 1422 insertions(+), 96 deletions(-) create mode 100644 docs/src/design/semantic-matching-spec.md create mode 100644 src/datajoint/lineage.py create mode 100644 tests/integration/test_semantic_matching.py diff --git a/docs/src/design/semantic-matching-spec.md b/docs/src/design/semantic-matching-spec.md new file mode 100644 index 000000000..666fc4cfe --- /dev/null +++ 
b/docs/src/design/semantic-matching-spec.md @@ -0,0 +1,525 @@ +# Semantic Matching for Joins - Specification + +## Overview + +This document specifies **semantic matching** for joins in DataJoint 2.0, replacing the current name-based matching rules. Semantic matching ensures that attributes are only matched when they share both the same name and the same **lineage** (origin), preventing accidental joins on unrelated attributes that happen to share names. + +### Goals + +1. **Prevent incorrect joins** on attributes that share names but represent different entities +2. **Enable valid joins** that are currently blocked due to overly restrictive rules +3. **Maintain backward compatibility** for well-designed schemas +4. **Provide clear error messages** when semantic conflicts are detected + +--- + +## User Guide + +### Quick Start + +Semantic matching is enabled by default in DataJoint 2.0. For most well-designed schemas, no changes are required. + +#### When You Might See Errors + +```python +# Two tables with generic 'id' attribute +class Student(dj.Manual): + definition = """ + id : int + --- + name : varchar(100) + """ + +class Course(dj.Manual): + definition = """ + id : int + --- + title : varchar(100) + """ + +# This will raise an error because 'id' has different lineages +Student() * Course() # DataJointError! +``` + +#### How to Resolve + +**Option 1: Rename attributes using projection** +```python +Student() * Course().proj(course_id='id') # OK +``` + +**Option 2: Bypass semantic check (use with caution)** +```python +Student().join(Course(), semantic_check=False) # OK, but be careful! 
+``` + +**Option 3: Use descriptive names (best practice)** +```python +class Student(dj.Manual): + definition = """ + student_id : int + --- + name : varchar(100) + """ +``` + +### Migrating from DataJoint 1.x + +#### Removed Operators + +| Old Syntax | New Syntax | +|------------|------------| +| `A @ B` | `A.join(B, semantic_check=False)` | +| `A ^ B` | `A.restrict(B, semantic_check=False)` | +| `dj.U('a') * B` | `dj.U('a') & B` | + +#### Rebuilding Lineage for Existing Schemas + +If you have existing schemas created before DataJoint 2.0, rebuild their lineage tables: + +```python +import datajoint as dj + +# Connect and get your schema +schema = dj.Schema('my_database') + +# Rebuild lineage (do this once per schema) +schema.rebuild_lineage() + +# Restart Python kernel to pick up changes +``` + +**Important**: If your schema references tables in other schemas, rebuild those upstream schemas first. + +--- + +## API Reference + +### Schema Methods + +#### `schema.rebuild_lineage()` + +Rebuild the `~lineage` table for all tables in this schema. + +```python +schema.rebuild_lineage() +``` + +**Description**: Recomputes lineage for all attributes by querying FK relationships from the database's `information_schema`. Use this to restore lineage for schemas that predate the lineage system or after corruption. + +**Requirements**: +- Schema must exist +- Upstream schemas (referenced via cross-schema FKs) must have their lineage rebuilt first + +**Side Effects**: +- Creates `~lineage` table if it doesn't exist +- Deletes and repopulates all lineage entries for tables in the schema + +**Post-Action**: Restart Python kernel and reimport to pick up new lineage information. + +#### `schema.lineage_table_exists` + +Property indicating whether the `~lineage` table exists in this schema. + +```python +if schema.lineage_table_exists: + print("Lineage tracking is enabled") +``` + +**Returns**: `bool` - `True` if `~lineage` table exists, `False` otherwise. 
+ +### Join Methods + +#### `expr.join(other, semantic_check=True)` + +Join two expressions with optional semantic checking. + +```python +result = A.join(B) # semantic_check=True (default) +result = A.join(B, semantic_check=False) # bypass semantic check +``` + +**Parameters**: +- `other`: Another query expression to join with +- `semantic_check` (bool): If `True` (default), raise error on non-homologous namesakes. If `False`, perform natural join without lineage checking. + +**Raises**: `DataJointError` if `semantic_check=True` and namesake attributes have different lineages. + +#### `expr.restrict(other, semantic_check=True)` + +Restrict expression with optional semantic checking. + +```python +result = A.restrict(B) # semantic_check=True (default) +result = A.restrict(B, semantic_check=False) # bypass semantic check +``` + +**Parameters**: +- `other`: Restriction condition (expression, dict, string, etc.) +- `semantic_check` (bool): If `True` (default), raise error on non-homologous namesakes when restricting by another expression. If `False`, no lineage checking. + +**Raises**: `DataJointError` if `semantic_check=True` and namesake attributes have different lineages. + +### Operators + +#### `A * B` (Join) + +Equivalent to `A.join(B, semantic_check=True)`. + +#### `A & B` (Restriction) + +Equivalent to `A.restrict(B, semantic_check=True)`. + +#### `A - B` (Anti-restriction) + +Restriction with negation. Semantic checking applies. + +#### `A + B` (Union) + +Union of expressions. Requires all namesake attributes to have matching lineage. + +### Removed Operators + +#### `A @ B` (Removed) + +Raises `DataJointError` with migration guidance to use `.join(semantic_check=False)`. + +#### `A ^ B` (Removed) + +Raises `DataJointError` with migration guidance to use `.restrict(semantic_check=False)`. + +#### `dj.U(...) * A` (Removed) + +Raises `DataJointError`; the message carries no replacement syntax, so use `dj.U(...) & A` instead. 
+ +### Universal Set (`dj.U`) + +#### Valid Operations + +```python +dj.U('a', 'b') & A # Restriction: promotes a, b to PK +dj.U('a', 'b').aggr(A, ...) # Aggregation: groups by a, b +dj.U() & A # Distinct primary keys of A +``` + +#### Invalid Operations + +```python +dj.U('a', 'b') - A # DataJointError: produces infinite set +dj.U('a', 'b') * A # DataJointError: use & instead +``` + +--- + +## Concepts + +### Attribute Lineage + +Lineage identifies the **origin** of an attribute - where it was first defined. It is represented as a string: + +``` +schema_name.table_name.attribute_name +``` + +#### Lineage Assignment Rules + +| Attribute Type | Lineage Value | +|----------------|---------------| +| Native primary key | `this_schema.this_table.attr_name` | +| FK-inherited (primary or secondary) | Traced to original definition | +| Native secondary | `None` | +| Computed (in projection) | `None` | + +#### Example + +```python +class Session(dj.Manual): # table: session + definition = """ + session_id : int + --- + date : date + """ + +class Trial(dj.Manual): # table: trial + definition = """ + -> Session + trial_num : int + --- + stimulus : varchar(100) + """ +``` + +Lineages: +- `Session.session_id` → `myschema.session.session_id` (native PK) +- `Session.date` → `None` (native secondary) +- `Trial.session_id` → `myschema.session.session_id` (inherited via FK) +- `Trial.trial_num` → `myschema.trial.trial_num` (native PK) +- `Trial.stimulus` → `None` (native secondary) + +### Terminology + +| Term | Definition | +|------|------------| +| **Lineage** | The origin of an attribute: `schema.table.attribute` | +| **Homologous attributes** | Attributes with the same lineage | +| **Namesake attributes** | Attributes with the same name | +| **Homologous namesakes** | Same name AND same lineage — used for join matching | +| **Non-homologous namesakes** | Same name BUT different lineage — cause join errors | + +### Semantic Matching Rules + +| Scenario | Action | 
+|----------|--------| +| Same name, same lineage (both non-null) | **Match** | +| Same name, different lineage | **Error** | +| Same name, either lineage is null | **Error** | +| Different names | **No match** | + +--- + +## Implementation Details + +### `~lineage` Table + +Each schema has a hidden `~lineage` table storing lineage information: + +```sql +CREATE TABLE `schema_name`.`~lineage` ( + table_name VARCHAR(64) NOT NULL, + attribute_name VARCHAR(64) NOT NULL, + lineage VARCHAR(255) NOT NULL, + PRIMARY KEY (table_name, attribute_name) +) +``` + +### Lineage Population + +**At table declaration**: +1. Delete any existing lineage entries for the table +2. For FK attributes: copy lineage from parent (with warning if parent lineage missing) +3. For native PK attributes: set lineage to `schema.table.attribute` +4. Native secondary attributes: no entry (lineage = None) + +**At table drop**: +- Delete all lineage entries for the table + +### Missing Lineage Handling + +**If `~lineage` table doesn't exist**: +- Warning issued during semantic check +- Semantic checking disabled (join proceeds as natural join) + +**If parent lineage missing during declaration**: +- Warning issued +- Parent attribute used as origin +- Recommend rebuilding lineage after parent schema is fixed + +### Heading's `lineage_available` Property + +The `Heading` class tracks whether lineage information is available: + +```python +heading.lineage_available # True if ~lineage table exists for this schema +``` + +This property is: +- Set when heading is loaded from database +- Propagated through projections, joins, and other operations +- Used by `assert_join_compatibility` to decide whether to perform semantic checking + +--- + +## Error Messages + +### Non-Homologous Namesakes + +``` +DataJointError: Cannot join on attribute `id`: different lineages +(university.student.id vs university.course.id). +Use .proj() to rename one of the attributes. 
+``` + +### Removed `@` Operator + +``` +DataJointError: The @ operator has been removed in DataJoint 2.0. +Use .join(other, semantic_check=False) for joins without semantic checking. +``` + +### Removed `^` Operator + +``` +DataJointError: The ^ operator has been removed in DataJoint 2.0. +Use .restrict(other, semantic_check=False) for restrictions without semantic checking. +``` + +### Removed `dj.U * table` + +``` +DataJointError: dj.U(...) * table is no longer supported in DataJoint 2.0. +This pattern is no longer necessary with the new semantic matching system. +``` + +### Missing Lineage Warning + +``` +WARNING: Semantic check disabled: ~lineage table not found. +To enable semantic matching, rebuild lineage with: schema.rebuild_lineage() +``` + +### Parent Lineage Missing Warning + +``` +WARNING: Lineage for `parent_db`.`parent_table`.`attr` not found +(parent schema's ~lineage table may be missing or incomplete). +Using it as origin. Once the parent schema's lineage is rebuilt, +run schema.rebuild_lineage() on this schema to correct the lineage. +``` + +--- + +## Examples + +### Example 1: Valid Join (Shared Lineage) + +```python +class Student(dj.Manual): + definition = """ + student_id : int + --- + name : varchar(100) + """ + +class Enrollment(dj.Manual): + definition = """ + -> Student + -> Course + --- + grade : varchar(2) + """ + +# Works: student_id has same lineage in both +Student() * Enrollment() +``` + +### Example 2: Invalid Join (Different Lineage) + +```python +class TableA(dj.Manual): + definition = """ + id : int + --- + value_a : int + """ + +class TableB(dj.Manual): + definition = """ + id : int + --- + value_b : int + """ + +# Error: 'id' has different lineages +TableA() * TableB() + +# Solution 1: Rename +TableA() * TableB().proj(b_id='id') + +# Solution 2: Bypass (use with caution) +TableA().join(TableB(), semantic_check=False) +``` + +### Example 3: Multi-hop FK Inheritance + +```python +class Session(dj.Manual): + definition = """ + session_id : int + --- + date : date + """ + +class Trial(dj.Manual): + 
definition = """ + -> Session + trial_num : int + """ + +class Response(dj.Computed): + definition = """ + -> Trial + --- + response_time : float + """ + +# All work: session_id traces back to Session in all tables +Session() * Trial() +Session() * Response() +Trial() * Response() +``` + +### Example 4: Secondary FK Attribute + +```python +class Course(dj.Manual): + definition = """ + course_id : int + --- + title : varchar(100) + """ + +class FavoriteCourse(dj.Manual): + definition = """ + student_id : int + --- + -> Course + """ + +class RequiredCourse(dj.Manual): + definition = """ + major_id : int + --- + -> Course + """ + +# Works: course_id is secondary in both, but has same lineage +FavoriteCourse() * RequiredCourse() +``` + +### Example 5: Aliased Foreign Key + +```python +class Person(dj.Manual): + definition = """ + person_id : int + --- + name : varchar(100) + """ + +class Marriage(dj.Manual): + definition = """ + -> Person.proj(husband='person_id') + -> Person.proj(wife='person_id') + --- + date : date + """ + +# husband and wife both have lineage: schema.person.person_id +# They are homologous (same lineage) but have different names +``` + +--- + +## Best Practices + +1. **Use descriptive attribute names**: Prefer `student_id` over generic `id` + +2. **Leverage foreign keys**: Inherited attributes maintain lineage automatically + +3. **Rebuild lineage for legacy schemas**: Run `schema.rebuild_lineage()` once + +4. **Rebuild upstream schemas first**: For cross-schema FKs, rebuild parent schemas before child schemas + +5. **Restart after rebuilding**: Restart Python kernel to pick up new lineage information + +6. 
**Use `semantic_check=False` sparingly**: Only when you're certain the natural join is correct diff --git a/src/datajoint/condition.py b/src/datajoint/condition.py index 8a22d17bb..085fb3d89 100644 --- a/src/datajoint/condition.py +++ b/src/datajoint/condition.py @@ -5,6 +5,7 @@ import decimal import inspect import json +import logging import re import uuid from dataclasses import dataclass @@ -14,6 +15,8 @@ from .errors import DataJointError +logger = logging.getLogger(__name__.split(".")[0]) + JSON_PATTERN = re.compile(r"^(?P\w+)(\.(?P[\w.*\[\]]+))?(:(?P[\w(,\s)]+))?$") @@ -95,32 +98,60 @@ def __init__(self, restriction): self.restriction = restriction -def assert_join_compatibility(expr1, expr2): +def assert_join_compatibility(expr1, expr2, semantic_check=True): """ - Determine if expressions expr1 and expr2 are join-compatible. To be join-compatible, - the matching attributes in the two expressions must be in the primary key of one or the - other expression. - Raises an exception if not compatible. + Determine if expressions expr1 and expr2 are join-compatible. + + With semantic_check=True (default): + Raises an error if there are non-homologous namesakes (same name, different lineage). + This prevents accidental joins on attributes that share names but represent + different entities. + + If the ~lineage table doesn't exist for either schema, a warning is issued + and semantic checking is disabled (join proceeds as natural join). + + With semantic_check=False: + No lineage checking. All namesake attributes are matched (natural join behavior). :param expr1: A QueryExpression object :param expr2: A QueryExpression object + :param semantic_check: If True (default), use semantic matching and error on conflicts """ from .expression import QueryExpression, U for rel in (expr1, expr2): if not isinstance(rel, (U, QueryExpression)): raise DataJointError("Object %r is not a QueryExpression and cannot be joined." 
% rel) - if not isinstance(expr1, U) and not isinstance(expr2, U): # dj.U is always compatible - try: - raise DataJointError( - "Cannot join query expressions on dependent attribute `%s`" - % next(r for r in set(expr1.heading.secondary_attributes).intersection(expr2.heading.secondary_attributes)) - ) - except StopIteration: - pass # all ok - -def make_condition(query_expression, condition, columns): + # dj.U is always compatible (it represents all possible lineages) + if isinstance(expr1, U) or isinstance(expr2, U): + return + + if semantic_check: + # Check if lineage tracking is available for both expressions + if not expr1.heading.lineage_available or not expr2.heading.lineage_available: + logger.warning( + "Semantic check disabled: ~lineage table not found. " + "To enable semantic matching, rebuild lineage with: " + "schema.rebuild_lineage()" + ) + return + + # Error on non-homologous namesakes + namesakes = set(expr1.heading.names) & set(expr2.heading.names) + for name in namesakes: + lineage1 = expr1.heading[name].lineage + lineage2 = expr2.heading[name].lineage + # Semantic match requires both lineages to be non-None and equal + if lineage1 is None or lineage2 is None or lineage1 != lineage2: + raise DataJointError( + f"Cannot join on attribute `{name}`: " + f"different lineages ({lineage1} vs {lineage2}). " + f"Use .proj() to rename one of the attributes." + ) + + +def make_condition(query_expression, condition, columns, semantic_check=True): """ Translate the input condition into the equivalent SQL condition (a string) @@ -128,6 +159,7 @@ def make_condition(query_expression, condition, columns): :param condition: any valid restriction object. :param columns: a set passed by reference to collect all column names used in the condition. + :param semantic_check: If True (default), use semantic matching and error on conflicts. :return: an SQL condition string or a boolean value. 
""" from .expression import Aggregation, QueryExpression, U @@ -180,7 +212,11 @@ def combine_conditions(negate, conditions): # restrict by AndList if isinstance(condition, AndList): # omit all conditions that evaluate to True - items = [item for item in (make_condition(query_expression, cond, columns) for cond in condition) if item is not True] + items = [ + item + for item in (make_condition(query_expression, cond, columns, semantic_check) for cond in condition) + if item is not True + ] if any(item is False for item in items): return negate # if any item is False, the whole thing is False if not items: @@ -226,14 +262,9 @@ def combine_conditions(negate, conditions): condition = condition() # restrict by another expression (aka semijoin and antijoin) - check_compatibility = True - if isinstance(condition, PromiscuousOperand): - condition = condition.operand - check_compatibility = False - if isinstance(condition, QueryExpression): - if check_compatibility: - assert_join_compatibility(query_expression, condition) + assert_join_compatibility(query_expression, condition, semantic_check=semantic_check) + # Always match on all namesakes (natural join semantics) common_attributes = [q for q in condition.heading.names if q in query_expression.heading.names] columns.update(common_attributes) if isinstance(condition, Aggregation): @@ -255,7 +286,7 @@ def combine_conditions(negate, conditions): # if iterable (but not a string, a QueryExpression, or an AndList), treat as an OrList try: - or_list = [make_condition(query_expression, q, columns) for q in condition] + or_list = [make_condition(query_expression, q, columns, semantic_check) for q in condition] except TypeError: raise DataJointError("Invalid restriction type %r" % condition) else: diff --git a/src/datajoint/declare.py b/src/datajoint/declare.py index 8b6bfda80..05c0fab64 100644 --- a/src/datajoint/declare.py +++ b/src/datajoint/declare.py @@ -141,7 +141,7 @@ def is_foreign_key(line): return arrow_position >= 0 and 
not any(c in line[:arrow_position] for c in "\"#'") -def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreign_key_sql, index_sql): +def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreign_key_sql, index_sql, fk_attribute_map=None): """ :param line: a line from a table definition :param context: namespace containing referenced objects @@ -151,6 +151,7 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig :param attr_sql: list of sql statements defining attributes -- to be updated by this function. :param foreign_key_sql: list of sql statements specifying foreign key constraints -- to be updated by this function. :param index_sql: list of INDEX declaration statements, duplicate or redundant indexes are ok. + :param fk_attribute_map: dict mapping child attr -> (parent_table, parent_attr) -- to be updated by this function. """ # Parse and validate from .expression import QueryExpression @@ -194,6 +195,11 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig if primary_key is not None: primary_key.append(attr) attr_sql.append(ref.heading[attr].sql.replace("NOT NULL ", "", int(is_nullable))) + # Track FK attribute mapping for lineage: child_attr -> (parent_table, parent_attr) + if fk_attribute_map is not None: + parent_table = ref.support[0] # e.g., `schema`.`table` + parent_attr = ref.heading[attr].original_name + fk_attribute_map[attr] = (parent_table, parent_attr) # declare the foreign key foreign_key_sql.append( @@ -223,6 +229,7 @@ def prepare_declare(definition, context): foreign_key_sql = [] index_sql = [] external_stores = [] + fk_attribute_map = {} # child_attr -> (parent_table, parent_attr) for line in definition: if not line or line.startswith("#"): # ignore additional comments @@ -238,6 +245,7 @@ def prepare_declare(definition, context): attribute_sql, foreign_key_sql, index_sql, + fk_attribute_map, ) elif re.match(r"^(unique\s+)?index\s*.*$", line, 
re.I): # index compile_index(line, index_sql) @@ -258,6 +266,7 @@ def prepare_declare(definition, context): foreign_key_sql, index_sql, external_stores, + fk_attribute_map, ) @@ -285,6 +294,7 @@ def declare(full_table_name, definition, context): foreign_key_sql, index_sql, external_stores, + fk_attribute_map, ) = prepare_declare(definition, context) if config.get("add_hidden_timestamp", False): @@ -297,11 +307,12 @@ def declare(full_table_name, definition, context): if not primary_key: raise DataJointError("Table must have a primary key") - return ( + sql = ( "CREATE TABLE IF NOT EXISTS %s (\n" % full_table_name + ",\n".join(attribute_sql + ["PRIMARY KEY (`" + "`,`".join(primary_key) + "`)"] + foreign_key_sql + index_sql) + '\n) ENGINE=InnoDB, COMMENT "%s"' % table_comment - ), external_stores + ) + return sql, external_stores, primary_key, fk_attribute_map def _make_attribute_alter(new, old, primary_key): @@ -387,6 +398,7 @@ def alter(definition, old_definition, context): foreign_key_sql, index_sql, external_stores, + _fk_attribute_map, ) = prepare_declare(definition, context) ( table_comment_, @@ -395,6 +407,7 @@ def alter(definition, old_definition, context): foreign_key_sql_, index_sql_, external_stores_, + _fk_attribute_map_, ) = prepare_declare(old_definition, context) # analyze differences between declarations diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 17d529ff8..cb4e015b7 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -7,7 +7,6 @@ from .condition import ( AndList, Not, - PromiscuousOperand, Top, assert_join_compatibility, extract_column_names, @@ -152,13 +151,22 @@ def make_subquery(self): result._heading = self.heading.make_subquery_heading() return result - def restrict(self, restriction): + def restrict(self, restriction, semantic_check=True): """ Produces a new expression with the new restriction applied. - rel.restrict(restriction) is equivalent to rel & restriction. 
- rel.restrict(Not(restriction)) is equivalent to rel - restriction + + :param restriction: a sequence or an array (treated as OR list), another QueryExpression, + an SQL condition string, or an AndList. + :param semantic_check: If True (default), use semantic matching - only match on + homologous namesakes and error on non-homologous namesakes. + If False, use natural matching on all namesakes (no lineage checking). + :return: A new QueryExpression with the restriction applied. + + rel.restrict(restriction) is equivalent to rel & restriction. + rel.restrict(Not(restriction)) is equivalent to rel - restriction + The primary key of the result is unaffected. - Successive restrictions are combined as logical AND: r & a & b is equivalent to r & AndList((a, b)) + Successive restrictions are combined as logical AND: r & a & b is equivalent to r & AndList((a, b)) Any QueryExpression, collection, or sequence other than an AndList are treated as OrLists (logical disjunction of conditions) Inverse restriction is accomplished by either using the subtraction operator or the Not class. @@ -185,17 +193,14 @@ def restrict(self, restriction): rel - None rel rel - any_empty_entity_set rel - When arg is another QueryExpression, the restriction rel & arg restricts rel to elements that match at least + When arg is another QueryExpression, the restriction rel & arg restricts rel to elements that match at least one element in arg (hence arg is treated as an OrList). - Conversely, rel - arg restricts rel to elements that do not match any elements in arg. + Conversely, rel - arg restricts rel to elements that do not match any elements in arg. Two elements match when their common attributes have equal values or when they have no common attributes. All shared attributes must be in the primary key of either rel or arg or both or an error will be raised. QueryExpression.restrict is the only access point that modifies restrictions. 
All other operators must ultimately call restrict() - - :param restriction: a sequence or an array (treated as OR list), another QueryExpression, an SQL condition - string, or an AndList. """ attributes = set() if isinstance(restriction, Top): @@ -204,7 +209,7 @@ def restrict(self, restriction): ) # make subquery to avoid overwriting existing Top result._top = restriction return result - new_condition = make_condition(self, restriction, attributes) + new_condition = make_condition(self, restriction, attributes, semantic_check=semantic_check) if new_condition is True: return self # restriction has no effect, return the same object # check that all attributes in condition are present in the query @@ -240,14 +245,11 @@ def __and__(self, restriction): return self.restrict(restriction) def __xor__(self, restriction): - """ - Permissive restriction operator ignoring compatibility check e.g. ``q1 ^ q2``. - """ - if inspect.isclass(restriction) and issubclass(restriction, QueryExpression): - restriction = restriction() - if isinstance(restriction, Not): - return self.restrict(Not(PromiscuousOperand(restriction.restriction))) - return self.restrict(PromiscuousOperand(restriction)) + """The ^ operator has been removed in DataJoint 2.0.""" + raise DataJointError( + "The ^ operator has been removed in DataJoint 2.0. " + "Use .restrict(other, semantic_check=False) for restrictions without semantic checking." + ) def __sub__(self, restriction): """ @@ -274,30 +276,37 @@ def __mul__(self, other): return self.join(other) def __matmul__(self, other): - """ - Permissive join of query expressions `self` and `other` ignoring compatibility check - e.g. ``q1 @ q2``. - """ - if inspect.isclass(other) and issubclass(other, QueryExpression): - other = other() # instantiate - return self.join(other, semantic_check=False) + """The @ operator has been removed in DataJoint 2.0.""" + raise DataJointError( + "The @ operator has been removed in DataJoint 2.0. 
" + "Use .join(other, semantic_check=False) for joins without semantic checking." + ) def join(self, other, semantic_check=True, left=False): """ - create the joined QueryExpression. - a * b is short for A.join(B) - a @ b is short for A.join(B, semantic_check=False) - Additionally, left=True will retain the rows of self, effectively performing a left join. + Create the joined QueryExpression. + + :param other: QueryExpression to join with + :param semantic_check: If True (default), use semantic matching - only match on + homologous namesakes (same lineage) and error on non-homologous namesakes. + If False, use natural join on all namesakes (no lineage checking). + :param left: If True, perform a left join (retain all rows from self) + :return: The joined QueryExpression + + a * b is short for a.join(b) """ - # trigger subqueries if joining on renamed attributes + # Joining with U is no longer supported if isinstance(other, U): - return other * self + raise DataJointError( + "table * dj.U(...) is no longer supported in DataJoint 2.0. " + "This pattern is no longer necessary with the new semantic matching system." + ) if inspect.isclass(other) and issubclass(other, QueryExpression): other = other() # instantiate if not isinstance(other, QueryExpression): raise DataJointError("The argument of join must be a QueryExpression") - if semantic_check: - assert_join_compatibility(self, other) + assert_join_compatibility(self, other, semantic_check=semantic_check) + # Always natural join on all namesakes join_attributes = set(n for n in self.heading.names if n in other.heading.names) # needs subquery if self's FROM clause has common attributes with other's FROM clause need_subquery1 = need_subquery2 = bool( @@ -826,8 +835,18 @@ def join(self, other, left=False): return result def __mul__(self, other): - """shorthand for join""" - return self.join(other) + """The * operator with dj.U has been removed in DataJoint 2.0.""" + raise DataJointError( + "dj.U(...) 
* table is no longer supported in DataJoint 2.0. " + "This pattern is no longer necessary with the new semantic matching system." + ) + + def __sub__(self, other): + """Anti-restriction with dj.U produces an infinite set.""" + raise DataJointError( + "dj.U(...) - table produces an infinite set and is not supported. " + "Consider using a different approach for your query." + ) def aggr(self, group, **named_attributes): """ diff --git a/src/datajoint/heading.py b/src/datajoint/heading.py index bc555224c..fe50ad204 100644 --- a/src/datajoint/heading.py +++ b/src/datajoint/heading.py @@ -13,6 +13,7 @@ TYPE_PATTERN, ) from .errors import DataJointError +from .lineage import get_table_lineages, lineage_table_exists class _MissingType(Codec, register=False): @@ -63,6 +64,7 @@ def decode(self, stored, *, key=None): unsupported=False, attribute_expression=None, dtype=object, + lineage=None, # Origin of attribute, e.g. "schema.table.attr" for semantic matching ) @@ -115,17 +117,24 @@ class Heading: the attribute names and the values are Attributes. 
""" - def __init__(self, attribute_specs=None, table_info=None): + def __init__(self, attribute_specs=None, table_info=None, lineage_available=True): """ :param attribute_specs: a list of dicts with the same keys as Attribute :param table_info: a dict with information to load the heading from the database + :param lineage_available: whether lineage tracking is available for this heading """ self.indexes = None self.table_info = table_info self._table_status = None + self._lineage_available = lineage_available self._attributes = None if attribute_specs is None else dict((q["name"], Attribute(**q)) for q in attribute_specs) + @property + def lineage_available(self): + """Whether lineage tracking is available for this heading's schema.""" + return self._lineage_available + def __len__(self): return 0 if self.attributes is None else len(self.attributes) @@ -375,6 +384,16 @@ def _init_from_database(self): # restore codec type name for display attr["type"] = codec_spec + # Load lineage information for semantic matching from ~lineage table + self._lineage_available = lineage_table_exists(conn, database) + if self._lineage_available: + lineages = get_table_lineages(conn, database, table_name) + for attr in attributes: + attr["lineage"] = lineages.get(attr["name"]) + else: + for attr in attributes: + attr["lineage"] = None + self._attributes = dict(((q["name"], Attribute(**q)) for q in attributes)) # Read and tabulate secondary indexes @@ -428,7 +447,7 @@ def select(self, select_list, rename_map=None, compute_map=None): dict(default_attribute_properties, name=new_name, attribute_expression=expr) for new_name, expr in compute_map.items() ) - return Heading(chain(copy_attrs, compute_attrs)) + return Heading(chain(copy_attrs, compute_attrs), lineage_available=self._lineage_available) def join(self, other): """ @@ -439,7 +458,8 @@ def join(self, other): [self.attributes[name].todict() for name in self.primary_key] + [other.attributes[name].todict() for name in other.primary_key 
if name not in self.primary_key] + [self.attributes[name].todict() for name in self.secondary_attributes if name not in other.primary_key] - + [other.attributes[name].todict() for name in other.secondary_attributes if name not in self.primary_key] + + [other.attributes[name].todict() for name in other.secondary_attributes if name not in self.primary_key], + lineage_available=self._lineage_available and other._lineage_available, ) def set_primary_key(self, primary_key): @@ -451,7 +471,8 @@ def set_primary_key(self, primary_key): chain( (dict(self.attributes[name].todict(), in_key=True) for name in primary_key), (dict(self.attributes[name].todict(), in_key=False) for name in self.names if name not in primary_key), - ) + ), + lineage_available=self._lineage_available, ) def make_subquery_heading(self): @@ -459,4 +480,7 @@ def make_subquery_heading(self): Create a new heading with removed attribute sql_expressions. Used by subqueries, which resolve the sql_expressions. """ - return Heading(dict(v.todict(), attribute_expression=None) for v in self.attributes.values()) + return Heading( + (dict(v.todict(), attribute_expression=None) for v in self.attributes.values()), + lineage_available=self._lineage_available, + ) diff --git a/src/datajoint/lineage.py b/src/datajoint/lineage.py new file mode 100644 index 000000000..d5fa5d3a1 --- /dev/null +++ b/src/datajoint/lineage.py @@ -0,0 +1,262 @@ +""" +Lineage management for semantic matching in DataJoint. + +Lineage identifies the origin of an attribute - where it was first defined. 
+It is represented as a string in the format: schema_name.table_name.attribute_name + +Semantic matching is applied to all binary operations that match attributes by name: +- Join (A * B): matches on homologous namesakes +- Restriction (A & B, A - B): matches on homologous namesakes +- Aggregation (A.aggr(B, ...)): requires homologous namesakes for grouping +- Union (A + B): requires all namesakes to have matching lineage + +If namesake attributes have different lineages (including either being None), +a DataJointError is raised. + +If the ~lineage table doesn't exist for a schema, a warning is issued and +semantic checking is disabled for operations involving that schema. + +The ~lineage table stores lineage information for each schema, populated at table +declaration time. Use schema.rebuild_lineage() to restore lineage for legacy schemas. +""" + +import logging + +from .errors import DataJointError + +logger = logging.getLogger(__name__.split(".")[0]) + + +def ensure_lineage_table(connection, database): + """ + Create the ~lineage table in the schema if it doesn't exist. + + :param connection: A DataJoint connection object + :param database: The schema/database name + """ + connection.query( + """ + CREATE TABLE IF NOT EXISTS `{database}`.`~lineage` ( + table_name VARCHAR(64) NOT NULL COMMENT 'table name within the schema', + attribute_name VARCHAR(64) NOT NULL COMMENT 'attribute name', + lineage VARCHAR(255) NOT NULL COMMENT 'origin: schema.table.attribute', + PRIMARY KEY (table_name, attribute_name) + ) ENGINE=InnoDB + """.format(database=database) + ) + + +def lineage_table_exists(connection, database): + """ + Check if the ~lineage table exists in the schema. 
+ + :param connection: A DataJoint connection object + :param database: The schema/database name + :return: True if the table exists, False otherwise + """ + result = connection.query( + """ + SELECT COUNT(*) FROM information_schema.tables + WHERE table_schema = %s AND table_name = '~lineage' + """, + args=(database,), + ).fetchone() + return result[0] > 0 + + +def get_lineage(connection, database, table_name, attribute_name): + """ + Get the lineage for an attribute from the ~lineage table. + + :param connection: A DataJoint connection object + :param database: The schema/database name + :param table_name: The table name + :param attribute_name: The attribute name + :return: The lineage string, or None if not found + """ + if not lineage_table_exists(connection, database): + return None + + result = connection.query( + """ + SELECT lineage FROM `{database}`.`~lineage` + WHERE table_name = %s AND attribute_name = %s + """.format(database=database), + args=(table_name, attribute_name), + ).fetchone() + return result[0] if result else None + + +def get_table_lineages(connection, database, table_name): + """ + Get all lineages for a table from the ~lineage table. + + :param connection: A DataJoint connection object + :param database: The schema/database name + :param table_name: The table name + :return: A dict mapping attribute names to lineage strings + """ + if not lineage_table_exists(connection, database): + return {} + + results = connection.query( + """ + SELECT attribute_name, lineage FROM `{database}`.`~lineage` + WHERE table_name = %s + """.format(database=database), + args=(table_name,), + ).fetchall() + return {row[0]: row[1] for row in results} + + +def insert_lineages(connection, database, entries): + """ + Insert multiple lineage entries in the ~lineage table as a single transaction. 
+ + :param connection: A DataJoint connection object + :param database: The schema/database name + :param entries: A list of (table_name, attribute_name, lineage) tuples + """ + if not entries: + return + ensure_lineage_table(connection, database) + # Build a single INSERT statement with multiple values for atomicity + placeholders = ", ".join(["(%s, %s, %s)"] * len(entries)) + # Flatten the entries into a single args tuple + args = tuple(val for entry in entries for val in entry) + connection.query( + """ + INSERT INTO `{database}`.`~lineage` (table_name, attribute_name, lineage) + VALUES {placeholders} + ON DUPLICATE KEY UPDATE lineage = VALUES(lineage) + """.format(database=database, placeholders=placeholders), + args=args, + ) + + +def delete_table_lineages(connection, database, table_name): + """ + Delete all lineage entries for a table. + + :param connection: A DataJoint connection object + :param database: The schema/database name + :param table_name: The table name + """ + if not lineage_table_exists(connection, database): + return + connection.query( + """ + DELETE FROM `{database}`.`~lineage` + WHERE table_name = %s + """.format(database=database), + args=(table_name,), + ) + + +def rebuild_schema_lineage(connection, database): + """ + Rebuild the ~lineage table for all tables in a schema. + + This utility recomputes lineage for all attributes in all tables + by querying FK relationships from the information_schema. Use this + to restore lineage after corruption or for schemas that predate + the lineage system. + + This function assumes that any upstream schemas (referenced via + cross-schema foreign keys) have already had their lineage rebuilt. + If a referenced attribute in another schema has no lineage entry, + a DataJointError is raised. 
+ + :param connection: A DataJoint connection object + :param database: The schema/database name + """ + # Ensure the lineage table exists + ensure_lineage_table(connection, database) + + # Clear all existing lineage entries for this schema + connection.query(f"DELETE FROM `{database}`.`~lineage`") + + # Get all tables in the schema (excluding hidden tables) + tables_result = connection.query( + """ + SELECT TABLE_NAME FROM information_schema.tables + WHERE TABLE_SCHEMA = %s AND TABLE_NAME NOT LIKE '~%%' + """, + args=(database,), + ).fetchall() + all_tables = {row[0] for row in tables_result} + + if not all_tables: + return + + # Get all primary key columns for all tables + pk_result = connection.query( + """ + SELECT TABLE_NAME, COLUMN_NAME FROM information_schema.KEY_COLUMN_USAGE + WHERE TABLE_SCHEMA = %s AND CONSTRAINT_NAME = 'PRIMARY' + """, + args=(database,), + ).fetchall() + # table -> set of PK columns + pk_columns = {} + for table, col in pk_result: + pk_columns.setdefault(table, set()).add(col) + + # Get all FK relationships within and across schemas + fk_result = connection.query( + """ + SELECT TABLE_NAME, COLUMN_NAME, + REFERENCED_TABLE_SCHEMA, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME + FROM information_schema.KEY_COLUMN_USAGE + WHERE TABLE_SCHEMA = %s AND REFERENCED_TABLE_NAME IS NOT NULL + """, + args=(database,), + ).fetchall() + + # Build FK map: (table, column) -> (parent_schema, parent_table, parent_column) + fk_map = {(table, col): (ref_schema, ref_table, ref_col) for table, col, ref_schema, ref_table, ref_col in fk_result} + + # Lineage cache: (table, column) -> lineage string (for this schema) + lineage_cache = {} + + def resolve_lineage(table, col): + """Recursively resolve lineage for an attribute.""" + if (table, col) in lineage_cache: + return lineage_cache[(table, col)] + + if (table, col) in fk_map: + # FK attribute - get parent's lineage + parent_schema, parent_table, parent_col = fk_map[(table, col)] + if parent_schema == 
database: + # Same schema - recurse + lineage = resolve_lineage(parent_table, parent_col) + else: + # Cross-schema - query parent's lineage table + lineage = get_lineage(connection, parent_schema, parent_table, parent_col) + if not lineage: + raise DataJointError( + f"Cannot rebuild lineage for `{database}`.`{table}`: " + f"referenced attribute `{parent_schema}`.`{parent_table}`.`{parent_col}` " + f"has no lineage. Rebuild lineage for schema `{parent_schema}` first." + ) + else: + # Native PK attribute - lineage is self + lineage = f"{database}.{table}.{col}" + + lineage_cache[(table, col)] = lineage + return lineage + + # Resolve lineage for all PK and FK attributes + for table in all_tables: + table_pk = pk_columns.get(table, set()) + table_fk_cols = {col for (t, col) in fk_map if t == table} + + # Process all attributes that need lineage (PK and FK) + for col in table_pk | table_fk_cols: + if not col.startswith("_"): + resolve_lineage(table, col) + + # Insert all lineages in one batch + if lineage_cache: + entries = [(table, col, lineage) for (table, col), lineage in lineage_cache.items()] + insert_lineages(connection, database, entries) diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index f61282ed5..6e8cba8e0 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -323,6 +323,35 @@ def exists(self): ).rowcount ) + @property + def lineage_table_exists(self): + """ + :return: true if the ~lineage table exists in this schema + """ + from .lineage import lineage_table_exists + + self._assert_exists() + return lineage_table_exists(self.connection, self.database) + + def rebuild_lineage(self): + """ + Rebuild the ~lineage table for all tables in this schema. + + This recomputes lineage for all attributes by querying FK relationships + from the information_schema. Use this to restore lineage for schemas + that predate the lineage system or after corruption. 
+ + After rebuilding, restart the Python kernel and reimport to pick up + the new lineage information. + + Note: Upstream schemas (referenced via cross-schema foreign keys) must + have their lineage rebuilt first. + """ + from .lineage import rebuild_schema_lineage + + self._assert_exists() + rebuild_schema_lineage(self.connection, self.database) + @property def jobs(self): """ diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 23648e1d7..a09c66559 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -100,13 +100,78 @@ def declare(self, context=None): "Table class name `{name}` is invalid. Please use CamelCase. ".format(name=self.class_name) + "Classes defining tables should be formatted in strict CamelCase." ) - sql, _external_stores = declare(self.full_table_name, self.definition, context) + sql, _external_stores, primary_key, fk_attribute_map = declare(self.full_table_name, self.definition, context) sql = sql.format(database=self.database) try: self.connection.query(sql) except AccessError: # skip if no create privilege - pass + return + + # Populate lineage table for this table's attributes + self._populate_lineage(primary_key, fk_attribute_map) + + def _populate_lineage(self, primary_key, fk_attribute_map): + """ + Populate the ~lineage table with lineage information for this table's attributes. 
+ + Lineage is stored for: + - All FK attributes (traced to their origin) + - Native primary key attributes (lineage = self) + + :param primary_key: list of primary key attribute names + :param fk_attribute_map: dict mapping child_attr -> (parent_table, parent_attr) + """ + from .lineage import ( + ensure_lineage_table, + get_lineage, + delete_table_lineages, + insert_lineages, + ) + + # Ensure the ~lineage table exists + ensure_lineage_table(self.connection, self.database) + + # Delete any existing lineage entries for this table (for idempotent re-declaration) + delete_table_lineages(self.connection, self.database, self.table_name) + + entries = [] + + # FK attributes: copy lineage from parent (whether in PK or not) + for attr, (parent_table, parent_attr) in fk_attribute_map.items(): + # Parse parent table name: `schema`.`table` -> (schema, table) + parent_clean = parent_table.replace("`", "") + if "." in parent_clean: + parent_db, parent_tbl = parent_clean.split(".", 1) + else: + parent_db = self.database + parent_tbl = parent_clean + + # Get parent's lineage for this attribute + parent_lineage = get_lineage(self.connection, parent_db, parent_tbl, parent_attr) + if parent_lineage: + # Copy parent's lineage + entries.append((self.table_name, attr, parent_lineage)) + else: + # Parent doesn't have lineage entry - use parent as origin + # This can happen for legacy/external schemas without lineage tracking + lineage = f"{parent_db}.{parent_tbl}.{parent_attr}" + entries.append((self.table_name, attr, lineage)) + logger.warning( + f"Lineage for `{parent_db}`.`{parent_tbl}`.`{parent_attr}` not found " + f"(parent schema's ~lineage table may be missing or incomplete). " + f"Using it as origin. Once the parent schema's lineage is rebuilt, " + f"run schema.rebuild_lineage() on this schema to correct the lineage." 
+ ) + + # Native PK attributes (in PK but not FK): this table is the origin + for attr in primary_key: + if attr not in fk_attribute_map: + lineage = f"{self.database}.{self.table_name}.{attr}" + entries.append((self.table_name, attr, lineage)) + + if entries: + insert_lineages(self.connection, self.database, entries) def alter(self, prompt=True, context=None): """ @@ -608,6 +673,12 @@ def drop_quick(self): Drops the table without cascading to dependent tables and without user prompt. """ if self.is_declared: + # Clean up lineage entries for this table + from .lineage import delete_table_lineages, lineage_table_exists + + if lineage_table_exists(self.connection, self.database): + delete_table_lineages(self.connection, self.database, self.table_name) + query = "DROP TABLE %s" % self.full_table_name self.connection.query(query) logger.info("Dropped table %s" % self.full_table_name) diff --git a/tests/integration/test_aggr_regressions.py b/tests/integration/test_aggr_regressions.py index d87fa37e0..de2cb078a 100644 --- a/tests/integration/test_aggr_regressions.py +++ b/tests/integration/test_aggr_regressions.py @@ -59,8 +59,9 @@ def test_issue449(schema_aggr_reg): """ ---------------- ISSUE 449 ------------------ Issue 449 arises from incorrect group by attributes after joining with a dj.U() + Note: dj.U() * table pattern is no longer supported in 2.0, use dj.U() & table instead """ - result = dj.U("n") * R.aggr(S, n="max(s)") + result = dj.U("n") & R.aggr(S, n="max(s)") result.fetch() diff --git a/tests/integration/test_relation_u.py b/tests/integration/test_relation_u.py index 6af159e7a..e0c96877b 100644 --- a/tests/integration/test_relation_u.py +++ b/tests/integration/test_relation_u.py @@ -7,15 +7,13 @@ def test_restriction(lang, languages, trial): + """Test dj.U restriction semantics.""" language_set = {s[1] for s in languages} rel = dj.U("language") & lang assert list(rel.heading.names) == ["language"] assert len(rel) == len(language_set) assert 
set(rel.fetch("language")) == language_set - # Test for issue #342 - rel = trial * dj.U("start_time") - assert list(rel.primary_key) == trial.primary_key + ["start_time"] - assert list(rel.primary_key) == list((rel & "trial_id>3").primary_key) + # dj.U & table promotes attributes to PK assert list((dj.U("start_time") & trial).primary_key) == ["start_time"] @@ -29,17 +27,17 @@ def test_ineffective_restriction(lang): assert rel.make_sql() == lang.make_sql() -def test_join(experiment): - rel = experiment * dj.U("experiment_date") - assert experiment.primary_key == ["subject_id", "experiment_id"] - assert rel.primary_key == experiment.primary_key + ["experiment_date"] +def test_join_with_u_removed(experiment): + """Test that table * dj.U(...) raises an error (removed in 2.0).""" + with raises(dj.DataJointError): + experiment * dj.U("experiment_date") - rel = dj.U("experiment_date") * experiment - assert experiment.primary_key == ["subject_id", "experiment_id"] - assert rel.primary_key == experiment.primary_key + ["experiment_date"] + with raises(dj.DataJointError): + dj.U("experiment_date") * experiment def test_invalid_join(schema_any): + """Test that dj.U * non-QueryExpression raises an error.""" with raises(dj.DataJointError): dj.U("language") * dict(language="English") @@ -64,14 +62,20 @@ def test_aggregations(schema_any): def test_argmax(schema_any): + """Test argmax pattern using aggregation and restriction.""" rel = TTest() - # get the tuples corresponding to the maximum value - mx = (rel * dj.U().aggr(rel, mx="max(value)")) & "mx=value" + # Get the maximum value using aggregation + max_val = dj.U().aggr(rel, mx="max(value)").fetch1("mx") + # Get tuples with that value + mx = rel & f"value={max_val}" assert mx.fetch("value")[0] == max(rel.fetch("value")) def test_aggr(schema_any, schema_simp): + """Test aggregation with dj.U - the old * pattern is removed.""" rel = ArgmaxTest() - amax1 = (dj.U("val") * rel) & dj.U("secondary_key").aggr(rel, val="min(val)") - 
amax2 = (dj.U("val") * rel) * dj.U("secondary_key").aggr(rel, val="min(val)") - assert len(amax1) == len(amax2) == rel.n, "Aggregated argmax with join and restriction does not yield the same length." + # The old pattern using dj.U("val") * rel is no longer supported + # Use aggregation directly instead + agg = dj.U("secondary_key").aggr(rel, min_val="min(val)") + # Verify aggregation works + assert len(agg) > 0 diff --git a/tests/integration/test_relational_operand.py b/tests/integration/test_relational_operand.py index 3f15a7319..60857b0d5 100644 --- a/tests/integration/test_relational_operand.py +++ b/tests/integration/test_relational_operand.py @@ -183,7 +183,7 @@ def test_project(schema_simp_pop): def test_rename_non_dj_attribute(connection_test, schema_simp_pop, schema_any_pop, prefix): schema = prefix + "_test1" connection_test.query(f"CREATE TABLE {schema}.test_table (oldID int PRIMARY KEY)").fetchall() - mySchema = dj.VirtualModule(schema, schema) + mySchema = dj.VirtualModule(schema, schema, connection=connection_test) assert ( "oldID" not in mySchema.TestTable.proj(new_name="oldID").heading.attributes.keys() ), "Failed to rename attribute correctly" @@ -193,7 +193,8 @@ def test_rename_non_dj_attribute(connection_test, schema_simp_pop, schema_any_po def test_union(schema_simp_pop): x = set(zip(*IJ.fetch("i", "j"))) y = set(zip(*JI.fetch("i", "j"))) - assert len(x) > 0 and len(y) > 0 and len(IJ() * JI()) < len(x) # ensure the IJ and JI are non-trivial + # IJ and JI have attributes i,j from different origins, so use semantic_check=False + assert len(x) > 0 and len(y) > 0 and len(IJ().join(JI(), semantic_check=False)) < len(x) z = set(zip(*(IJ + JI).fetch("i", "j"))) # union assert x.union(y) == z assert len(IJ + JI) == len(z) @@ -296,14 +297,15 @@ def test_semijoin(schema_simp_pop): """ x = IJ() y = JI() + # IJ and JI have i,j from different origins - use semantic_check=False n = len(x & y.fetch(as_dict=True)) m = len(x - y.fetch(as_dict=True)) assert n > 0 
and m > 0 assert len(x) == m + n assert len(x & y.fetch()) == n assert len(x - y.fetch()) == m - semi = x & y - anti = x - y + semi = x.restrict(y, semantic_check=False) + anti = x.restrict(dj.Not(y), semantic_check=False) assert len(semi) == n assert len(anti) == m @@ -388,7 +390,8 @@ def test_date(schema_simp_pop): def test_join_project(schema_simp_pop): """Test join of projected relations with matching non-primary key""" - q = DataA.proj() * DataB.proj() + # DataA and DataB have 'idx' from different origins, so use semantic_check=False + q = DataA.proj().join(DataB.proj(), semantic_check=False) assert len(q) == len(DataA()) == len(DataB()), "Join of projected relations does not work" @@ -459,13 +462,15 @@ def test_reserved_words2(schema_simp_pop): def test_permissive_join_basic(schema_any_pop): - """Verify join compatibility check is skipped for join""" - Child @ Parent + """Verify join compatibility check can be skipped with semantic_check=False""" + # The @ operator has been removed in 2.0, use .join(semantic_check=False) instead + Child().join(Parent(), semantic_check=False) def test_permissive_restriction_basic(schema_any_pop): - """Verify join compatibility check is skipped for restriction""" - Child ^ Parent + """Verify restriction compatibility check can be skipped with semantic_check=False""" + # The ^ operator has been removed in 2.0, use .restrict(semantic_check=False) instead + Child().restrict(Parent(), semantic_check=False) def test_complex_date_restriction(schema_simp_pop): diff --git a/tests/integration/test_semantic_matching.py b/tests/integration/test_semantic_matching.py new file mode 100644 index 000000000..d8dff27fa --- /dev/null +++ b/tests/integration/test_semantic_matching.py @@ -0,0 +1,342 @@ +""" +Tests for semantic matching in joins. + +These tests verify the lineage-based semantic matching system +that prevents incorrect joins on attributes with the same name +but different origins. 
+""" + +import pytest + +import datajoint as dj +from datajoint.errors import DataJointError + + +# Schema definitions for semantic matching tests +LOCALS_SEMANTIC = {} + + +class Student(dj.Manual): + definition = """ + student_id : int + --- + name : varchar(100) + """ + + +class Course(dj.Manual): + definition = """ + course_id : int + --- + title : varchar(100) + """ + + +class Enrollment(dj.Manual): + definition = """ + -> Student + -> Course + --- + grade : varchar(2) + """ + + +class Session(dj.Manual): + definition = """ + session_id : int + --- + date : date + """ + + +class Trial(dj.Manual): + definition = """ + -> Session + trial_num : int + --- + stimulus : varchar(100) + """ + + +class Response(dj.Computed): + definition = """ + -> Trial + --- + response_time : float + """ + + +# Tables with generic 'id' attribute for collision testing +class TableWithId1(dj.Manual): + definition = """ + id : int + --- + value1 : int + """ + + +class TableWithId2(dj.Manual): + definition = """ + id : int + --- + value2 : int + """ + + +# Register all classes in LOCALS_SEMANTIC +for cls in [ + Student, + Course, + Enrollment, + Session, + Trial, + Response, + TableWithId1, + TableWithId2, +]: + LOCALS_SEMANTIC[cls.__name__] = cls + + +@pytest.fixture(scope="module") +def schema_semantic(connection_test, prefix): + """Schema for semantic matching tests.""" + schema = dj.Schema( + prefix + "_semantic", + context=LOCALS_SEMANTIC, + connection=connection_test, + ) + # Declare tables + schema(Student) + schema(Course) + schema(Enrollment) + schema(Session) + schema(Trial) + # Skip Response for now - it's a computed table + schema(TableWithId1) + schema(TableWithId2) + + yield schema + schema.drop() + + +class TestLineageComputation: + """Tests for lineage computation from dependency graph.""" + + def test_native_primary_key_has_lineage(self, schema_semantic): + """Native primary key attributes should have lineage pointing to their table.""" + student = Student() + lineage = 
student.heading["student_id"].lineage + assert lineage is not None + assert "student_id" in lineage + # The lineage should include schema and table name + assert "student" in lineage.lower() + + def test_inherited_attribute_traces_to_origin(self, schema_semantic): + """FK-inherited attributes should trace lineage to the original table.""" + enrollment = Enrollment() + # student_id is inherited from Student + student_lineage = enrollment.heading["student_id"].lineage + assert student_lineage is not None + assert "student" in student_lineage.lower() + + # course_id is inherited from Course + course_lineage = enrollment.heading["course_id"].lineage + assert course_lineage is not None + assert "course" in course_lineage.lower() + + def test_secondary_attribute_no_lineage(self, schema_semantic): + """Native secondary attributes should have no lineage.""" + student = Student() + name_lineage = student.heading["name"].lineage + assert name_lineage is None + + def test_multi_hop_inheritance(self, schema_semantic): + """Lineage should trace through multiple FK hops.""" + trial = Trial() + # session_id in Trial is inherited from Session + session_lineage = trial.heading["session_id"].lineage + assert session_lineage is not None + assert "session" in session_lineage.lower() + + +class TestJoinCompatibility: + """Tests for join compatibility checking.""" + + def test_join_on_shared_lineage_works(self, schema_semantic): + """Joining tables with shared lineage should work.""" + student = Student() + enrollment = Enrollment() + + # This should work - student_id has same lineage in both + result = student * enrollment + assert "student_id" in result.heading.names + + def test_join_different_lineage_default_fails(self, schema_semantic): + """By default (semantic_check=True), non-homologous namesakes cause an error.""" + table1 = TableWithId1() + table2 = TableWithId2() + + # Default is semantic_check=True, this should fail + with pytest.raises(DataJointError) as exc_info: + table1 
* table2 + + assert "lineage" in str(exc_info.value).lower() + assert "id" in str(exc_info.value) + + def test_join_different_lineage_semantic_check_false_works(self, schema_semantic): + """With semantic_check=False, no lineage checking - natural join proceeds.""" + table1 = TableWithId1() + table2 = TableWithId2() + + # With semantic_check=False, no error even with different lineages + result = table1.join(table2, semantic_check=False) + assert "id" in result.heading.names + + +class TestRestrictCompatibility: + """Tests for restriction compatibility checking.""" + + def test_restrict_shared_lineage_works(self, schema_semantic): + """Restricting with shared lineage should work.""" + student = Student() + enrollment = Enrollment() + + # This should work - student_id has same lineage + result = student & enrollment + assert "student_id" in result.heading.names + + def test_restrict_different_lineage_default_fails(self, schema_semantic): + """By default (semantic_check=True), non-homologous namesakes cause an error.""" + table1 = TableWithId1() + table2 = TableWithId2() + + # Default is semantic_check=True, this should fail + with pytest.raises(DataJointError) as exc_info: + table1 & table2 + + assert "lineage" in str(exc_info.value).lower() + + def test_restrict_different_lineage_semantic_check_false_works(self, schema_semantic): + """With semantic_check=False, no lineage checking - restriction proceeds.""" + table1 = TableWithId1() + table2 = TableWithId2() + + # With semantic_check=False, no error even with different lineages + result = table1.restrict(table2, semantic_check=False) + assert "id" in result.heading.names + + +class TestProjectionLineage: + """Tests for lineage preservation in projections.""" + + def test_projection_preserves_lineage(self, schema_semantic): + """Projected attributes should preserve their lineage.""" + enrollment = Enrollment() + + projected = enrollment.proj("grade") + # Primary key attributes should still have lineage + assert 
projected.heading["student_id"].lineage is not None + + def test_renamed_attribute_preserves_lineage(self, schema_semantic): + """Renamed attributes should preserve their original lineage.""" + student = Student() + + renamed = student.proj(sid="student_id") + # The renamed attribute should have the same lineage as original + original_lineage = student.heading["student_id"].lineage + renamed_lineage = renamed.heading["sid"].lineage + assert renamed_lineage == original_lineage + + def test_computed_attribute_no_lineage(self, schema_semantic): + """Computed attributes should have no lineage.""" + student = Student() + + computed = student.proj(doubled="student_id * 2") + # Computed attributes have no lineage + assert computed.heading["doubled"].lineage is None + + +class TestRemovedOperators: + """Tests for removed operators.""" + + def test_matmul_operator_removed(self, schema_semantic): + """The @ operator should raise an error.""" + student = Student() + course = Course() + + with pytest.raises(DataJointError) as exc_info: + student @ course + + assert "@" in str(exc_info.value) or "matmul" in str(exc_info.value).lower() + assert "removed" in str(exc_info.value).lower() + + def test_xor_operator_removed(self, schema_semantic): + """The ^ operator should raise an error.""" + student = Student() + course = Course() + + with pytest.raises(DataJointError) as exc_info: + student ^ course + + assert "^" in str(exc_info.value) or "removed" in str(exc_info.value).lower() + + +class TestUniversalSetOperators: + """Tests for dj.U operations.""" + + def test_u_mul_raises_error(self, schema_semantic): + """dj.U * table should raise an error.""" + student = Student() + + with pytest.raises(DataJointError) as exc_info: + dj.U("student_id") * student + + assert "no longer supported" in str(exc_info.value).lower() + + def test_table_mul_u_raises_error(self, schema_semantic): + """table * dj.U should raise an error.""" + student = Student() + + with pytest.raises(DataJointError) 
as exc_info: + student * dj.U("student_id") + + assert "no longer supported" in str(exc_info.value).lower() + + def test_u_sub_raises_error(self, schema_semantic): + """dj.U - table should raise an error (infinite set).""" + student = Student() + + with pytest.raises(DataJointError) as exc_info: + dj.U("student_id") - student + + assert "infinite" in str(exc_info.value).lower() + + def test_u_and_works(self, schema_semantic): + """dj.U & table should work for restriction.""" + student = Student() + student.insert([{"student_id": 1, "name": "Alice"}, {"student_id": 2, "name": "Bob"}]) + + result = dj.U("student_id") & student + assert len(result) == 2 + + +class TestRebuildLineageUtility: + """Tests for the lineage rebuild utility.""" + + def test_rebuild_lineage_method_exists(self): + """The rebuild_lineage method should exist on Schema.""" + assert hasattr(dj.Schema, "rebuild_lineage") + + def test_rebuild_lineage_populates_table(self, schema_semantic): + """schema.rebuild_lineage() should populate the ~lineage table.""" + from datajoint.lineage import get_table_lineages + + # Run rebuild using Schema method + schema_semantic.rebuild_lineage() + + # Check that ~lineage table was created + assert schema_semantic.lineage_table_exists + + # Check that lineages were populated for Student table + lineages = get_table_lineages(schema_semantic.connection, schema_semantic.database, "student") + assert "student_id" in lineages From a3dcabfdd876d78d84e77312fb3a326868bc5cd7 Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 2 Jan 2026 00:49:02 -0600 Subject: [PATCH 2/3] Simplify delete_table_lineages call and update spec types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove redundant lineage_table_exists check in table.py (already handled inside delete_table_lineages) - Update spec examples to use core DataJoint types (uint32, uint16) instead of native types (int) 🤖 Generated with [Claude 
Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- docs/src/design/semantic-matching-spec.md | 44 +++++++++++------------ src/datajoint/table.py | 5 ++- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/docs/src/design/semantic-matching-spec.md b/docs/src/design/semantic-matching-spec.md index 666fc4cfe..f860e8487 100644 --- a/docs/src/design/semantic-matching-spec.md +++ b/docs/src/design/semantic-matching-spec.md @@ -25,14 +25,14 @@ Semantic matching is enabled by default in DataJoint 2.0. For most well-designed # Two tables with generic 'id' attribute class Student(dj.Manual): definition = """ - id : int + id : uint32 --- name : varchar(100) """ class Course(dj.Manual): definition = """ - id : int + id : uint32 --- title : varchar(100) """ @@ -57,7 +57,7 @@ Student().join(Course(), semantic_check=False) # OK, but be careful! ```python class Student(dj.Manual): definition = """ - student_id : int + student_id : uint32 --- name : varchar(100) """ @@ -235,15 +235,15 @@ schema_name.table_name.attribute_name ```python class Session(dj.Manual): # table: session definition = """ - session_id : int + session_id : uint32 --- - date : date + session_date : date """ class Trial(dj.Manual): # table: trial definition = """ -> Session - trial_num : int + trial_num : uint16 --- stimulus : varchar(100) """ @@ -251,7 +251,7 @@ class Trial(dj.Manual): # table: trial Lineages: - `Session.session_id` → `myschema.session.session_id` (native PK) -- `Session.date` → `None` (native secondary) +- `Session.session_date` → `None` (native secondary) - `Trial.session_id` → `myschema.session.session_id` (inherited via FK) - `Trial.trial_num` → `myschema.trial.trial_num` (native PK) - `Trial.stimulus` → `None` (native secondary) @@ -385,7 +385,7 @@ run schema.rebuild_lineage() on this schema to correct the lineage. 
```python class Student(dj.Manual): definition = """ - student_id : int + student_id : uint32 --- name : varchar(100) """ @@ -407,16 +407,16 @@ Student() * Enrollment() ```python class TableA(dj.Manual): definition = """ - id : int + id : uint32 --- - value_a : int + value_a : int32 """ class TableB(dj.Manual): definition = """ - id : int + id : uint32 --- - value_b : int + value_b : int32 """ # Error: 'id' has different lineages @@ -434,22 +434,22 @@ TableA().join(TableB(), semantic_check=False) ```python class Session(dj.Manual): definition = """ - session_id : int + session_id : uint32 --- - date : date + session_date : date """ class Trial(dj.Manual): definition = """ -> Session - trial_num : int + trial_num : uint16 """ class Response(dj.Computed): definition = """ -> Trial --- - response_time : float + response_time : float64 """ # All work: session_id traces back to Session in all tables @@ -463,21 +463,21 @@ Trial() * Response() ```python class Course(dj.Manual): definition = """ - course_id : int + course_id : uint32 --- title : varchar(100) """ class FavoriteCourse(dj.Manual): definition = """ - student_id : int + student_id : uint32 --- -> Course """ class RequiredCourse(dj.Manual): definition = """ - major_id : int + major_id : uint32 --- -> Course """ @@ -491,9 +491,9 @@ FavoriteCourse() * RequiredCourse() ```python class Person(dj.Manual): definition = """ - person_id : int + person_id : uint32 --- - name : varchar(100) + full_name : varchar(100) """ class Marriage(dj.Manual): @@ -501,7 +501,7 @@ class Marriage(dj.Manual): -> Person.proj(husband='person_id') -> Person.proj(wife='person_id') --- - date : date + marriage_date : date """ # husband and wife both have lineage: schema.person.person_id diff --git a/src/datajoint/table.py b/src/datajoint/table.py index a09c66559..00d1e8de8 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -674,10 +674,9 @@ def drop_quick(self): """ if self.is_declared: # Clean up
lineage entries for this table - from .lineage import delete_table_lineages, lineage_table_exists + from .lineage import delete_table_lineages - if lineage_table_exists(self.connection, self.database): - delete_table_lineages(self.connection, self.database, self.table_name) + delete_table_lineages(self.connection, self.database, self.table_name) query = "DROP TABLE %s" % self.full_table_name self.connection.query(query) From 19cde1cba9e4cf64c91723ef0bd57146738fb51e Mon Sep 17 00:00:00 2001 From: Dimitri Yatsenko Date: Fri, 2 Jan 2026 01:01:16 -0600 Subject: [PATCH 3/3] Add schema.lineage property to view all lineages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add get_schema_lineages() function in lineage.py - Add schema.lineage property returning flat dict mapping 'schema.table.attribute' to its lineage origin - Add note about A - B without semantic check in spec - Document schema.lineage in API reference 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- docs/src/design/semantic-matching-spec.md | 15 +++++++++++++++ src/datajoint/lineage.py | 20 ++++++++++++++++++++ src/datajoint/schemas.py | 12 ++++++++++++ 3 files changed, 47 insertions(+) diff --git a/docs/src/design/semantic-matching-spec.md b/docs/src/design/semantic-matching-spec.md index f860e8487..b3333a873 100644 --- a/docs/src/design/semantic-matching-spec.md +++ b/docs/src/design/semantic-matching-spec.md @@ -128,6 +128,19 @@ if schema.lineage_table_exists: **Returns**: `bool` - `True` if `~lineage` table exists, `False` otherwise. +#### `schema.lineage` + +Property returning all lineage entries for the schema. 
+ +```python +schema.lineage +# {'myschema.session.session_id': 'myschema.session.session_id', +# 'myschema.trial.session_id': 'myschema.session.session_id', +# 'myschema.trial.trial_num': 'myschema.trial.trial_num'} +``` + +**Returns**: `dict` - Maps `'schema.table.attribute'` to its lineage origin + ### Join Methods #### `expr.join(other, semantic_check=True)` @@ -174,6 +187,8 @@ Equivalent to `A.restrict(B, semantic_check=True)`. Restriction with negation. Semantic checking applies. +To bypass semantic checking: `A.restrict(dj.Not(B), semantic_check=False)` + #### `A + B` (Union) Union of expressions. Requires all namesake attributes to have matching lineage. diff --git a/src/datajoint/lineage.py b/src/datajoint/lineage.py index d5fa5d3a1..63a2d675b 100644 --- a/src/datajoint/lineage.py +++ b/src/datajoint/lineage.py @@ -109,6 +109,26 @@ def get_table_lineages(connection, database, table_name): return {row[0]: row[1] for row in results} +def get_schema_lineages(connection, database): + """ + Get all lineages for a schema from the ~lineage table. + + :param connection: A DataJoint connection object + :param database: The schema/database name + :return: A dict mapping 'schema.table.attribute' to its lineage + """ + if not lineage_table_exists(connection, database): + return {} + + results = connection.query( + """ + SELECT table_name, attribute_name, lineage FROM `{database}`.`~lineage` + """.format(database=database), + ).fetchall() + + return {f"{database}.{table}.{attr}": lineage for table, attr, lineage in results} + + def insert_lineages(connection, database, entries): """ Insert multiple lineage entries in the ~lineage table as a single transaction. 
diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index 6e8cba8e0..ae03d328c 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -333,6 +333,18 @@ def lineage_table_exists(self): self._assert_exists() return lineage_table_exists(self.connection, self.database) + @property + def lineage(self): + """ + Get all lineages for tables in this schema. + + :return: A dict mapping 'schema.table.attribute' to its lineage + """ + from .lineage import get_schema_lineages + + self._assert_exists() + return get_schema_lineages(self.connection, self.database) + def rebuild_lineage(self): """ Rebuild the ~lineage table for all tables in this schema.