|
12 | 12 | # limitations under the License. |
13 | 13 |
|
14 | 14 | from __future__ import annotations |
15 | | -from typing import Optional, TYPE_CHECKING |
| 15 | +from typing import Optional, TYPE_CHECKING, cast |
16 | 16 | import math |
17 | 17 | from sys import version_info |
18 | 18 |
|
|
29 | 29 | # prevent circular dependenacy by skipping import at runtime |
30 | 30 | from .project_config import ProjectConfig |
31 | 31 | from .entities import Experiment, Variation, Holdout |
32 | | - from .helpers.types import TrafficAllocation, VariationDict |
| 32 | + from .helpers.types import TrafficAllocation |
33 | 33 |
|
34 | 34 |
|
35 | 35 | MAX_TRAFFIC_VALUE: Final = 10000 |
@@ -105,7 +105,7 @@ def find_bucket( |
105 | 105 | def bucket( |
106 | 106 | self, project_config: ProjectConfig, |
107 | 107 | experiment: Experiment | Holdout, user_id: str, bucketing_id: str |
108 | | - ) -> tuple[Optional[Variation | VariationDict], list[str]]: |
| 108 | + ) -> tuple[Variation | None, list[str]]: |
109 | 109 | """ For a given experiment and bucketing ID determines variation to be shown to user. |
110 | 110 |
|
111 | 111 | Args: |
@@ -137,7 +137,8 @@ def bucket( |
137 | 137 | variation_id, decide_reasons = self.bucket_to_entity_id(project_config, experiment, user_id, bucketing_id) |
138 | 138 | if variation_id: |
139 | 139 | variation = project_config.get_variation_from_id_by_experiment_id(experiment_id, variation_id) |
140 | | - return variation, decide_reasons |
| 140 | + # Cast is safe here because experiments always use Variation entities, not VariationDict |
| 141 | + return cast(Optional[Variation], variation), decide_reasons |
141 | 142 |
|
142 | 143 | # No variation found - log message for empty traffic range |
143 | 144 | message = 'Bucketed into an empty traffic range. Returning nil.' |
|
0 commit comments